structure UpDownPrefixSums:
sig
  val up_down_prefix_sums: int Seq.t -> int Seq.t
end =
struct
  (* Summation tree produced by the up-sweep. An internal node caches the
   * sum of every leaf beneath it, so a subtree's total is read in O(1). *)
  datatype tree = Leaf of int | Node of int * tree * tree

  (* Sum of all leaves in the given tree (the cached value at its root). *)
  fun total (Leaf v) = v
    | total (Node (v, _, _)) = v

  (* Up-sweep. Expects a non-empty sequence: splits it so the left part gets
   * the first (length div 2) elements, builds both parts in parallel via
   * ForkJoin.par, and stores the combined sum at the new node. *)
  fun build (xs: int Seq.t) : tree =
    let
      val len = Seq.length xs
    in
      if len <> 1 then
        let
          val mid = len div 2
          fun buildLeft () = build (Seq.take xs mid)
          fun buildRight () = build (Seq.drop xs mid)
          val (l, r) = ForkJoin.par (buildLeft, buildRight)
        in
          Node (total l + total r, l, r)
        end
      else
        Leaf (Seq.nth xs 0)
    end

  (* Down-sweep. `carry` is the sum of every input element strictly to the
   * left of `node`; each leaf writes its carry into out[pos], giving an
   * exclusive prefix sum. `size` is the number of leaves under `node` and
   * is split with the same (size div 2) rule that `build` used, so the
   * write positions line up with the original element order. *)
  fun scatter (Leaf _, carry, out, pos, _) = Array.update (out, pos, carry)
    | scatter (Node (_, l, r), carry, out, pos, size) =
        let
          val leftSize = size div 2
          fun goLeft () = scatter (l, carry, out, pos, leftSize)
          fun goRight () =
            scatter (r, total l + carry, out, pos + leftSize, size - leftSize)
          val _ = ForkJoin.par (goLeft, goRight)
        in
          ()
        end

  (* Returns n+1 values for an input of length n: positions 0..n-1 hold the
   * exclusive prefix sums and position n holds the grand total. An empty
   * input yields the one-element sequence [0]. *)
  fun up_down_prefix_sums (s: int Seq.t) =
    case Seq.length s of
      0 => Seq.singleton 0
    | n =>
        let
          val root = build s
          (* Every one of the n+1 cells is written below: n by scatter,
           * plus the final total. *)
          val out = ForkJoin.alloc (n + 1)
          val () = scatter (root, 0, out, 0, n)
          val () = Array.update (out, n, total root)
        in
          ArraySlice.full out
        end
end
