structure ChunkedContractionPrefixSums:
sig
  (* chunked_contraction_prefix_sums {chunk_size} s
   *
   * Parallel prefix sums of `s` by chunked contraction:
   *   1. split `s` into ceilDiv n chunk_size chunks of chunk_size elements
   *      (the last chunk takes whatever remains),
   *   2. sum every chunk in parallel,
   *   3. recursively compute prefix sums of the chunk totals,
   *   4. in parallel, expand each chunk sequentially, seeding chunk i with
   *      entry i of the recursive result.
   * Sequences with at most chunk_size elements are handed directly to
   * SequentialPrefixSums.sequential_prefix_sums {init = 0}.
   *
   * In the recursive case the result has n + 1 entries (n = Seq.length s);
   * presumably this is an exclusive scan with the grand total appended --
   * TODO confirm against SequentialPrefixSums.
   *
   * chunk_size only controls granularity. Values below 2 are raised to 2,
   * because smaller values cannot make the recursion shrink.
   *)
  val chunked_contraction_prefix_sums: {chunk_size: int}
                                       -> int Seq.t
                                       -> int Seq.t
end =
struct

  fun tabulate f n =
    Parallel.tabulate (0, n) f


  (* Smallest usable chunk size. For chunk_size >= 2 and n > chunk_size,
   * ceilDiv n chunk_size < n, so every recursive call works on a strictly
   * shorter sequence. With chunk_size = 1 the contracted sequence is exactly
   * as long as the input and the recursion never terminates; chunk_size <= 0
   * would be used as the divisor in Util.ceilDiv. *)
  val min_chunk_size = 2

  fun chunked_contraction_prefix_sums {chunk_size} (s: int Seq.t) =
    let
      val chunk_size = Int.max (min_chunk_size, chunk_size)

      fun go (s: int Seq.t) =
        if Seq.length s <= chunk_size then
          SequentialPrefixSums.sequential_prefix_sums {init = 0} s
        else
          let
            val n = Seq.length s
            val num_chunks = Util.ceilDiv n chunk_size

            (* Input chunk i; the last chunk absorbs the remainder. *)
            fun input_chunk i =
              let
                val start = i * chunk_size
                val stop = if i = num_chunks - 1 then n else start + chunk_size
              in
                Seq.subseq s (start, stop - start)
              end

            (* Total of input chunk i, folded sequentially. *)
            fun contract i =
              let val c = input_chunk i
              in SeqBasis.foldl op+ 0 (0, Seq.length c) (Seq.nth c)
              end

            val contracted = tabulate contract num_chunks

            (* Entry i is used below as the starting value for chunk i. *)
            val recursive_sums = go contracted

            (* Freshly allocated with n + 1 slots; the parallel loop below is
             * expected to fill every slot (assuming
             * write_sequential_prefix_sums writes its whole output slice --
             * TODO confirm). *)
            val output = ArraySlice.full (ForkJoin.alloc (n + 1))

            (* Output chunk i lines up with input chunk i, except that the
             * last one also covers the extra slot at index n. *)
            fun output_chunk i =
              let
                val start = i * chunk_size
                val stop =
                  if i = num_chunks - 1 then n + 1 else start + chunk_size
              in
                Seq.subseq output (start, stop - start)
              end

            fun write_output_chunk i =
              SequentialPrefixSums.write_sequential_prefix_sums
                {init = Seq.nth recursive_sums i}
                {input = input_chunk i, output = output_chunk i}
          in
            ForkJoin.parform (0, num_chunks) write_output_chunk;
            output
          end
    in
      go s
    end

end
