structure MaxSums =
struct

  local
    (* A prefix summary of a range is a pair (p, t):
     *   p = maximum prefix sum of the range, where the empty prefix counts
     *       as 0, so p >= 0;
     *   t = total sum of the range. *)

    (* Prefix summary of a single element x. *)
    fun prefixLeaf x = (Int.max (x, 0), x)

    (* Prefix summary of two adjacent ranges, left one first. The best prefix
     * either stays inside the left range, or covers all of it and continues
     * into a prefix of the right range. (0, 0) is the summary of an empty
     * range. *)
    fun prefixCombine ((p1, t1), (p2, t2)) =
      (Int.max (p1, t1+p2), t1+t2)
  in

  (* mpst (lo, hi) f: the pair (max prefix sum, total) of
   * f lo, ..., f (hi-1). The empty prefix counts, so the max prefix sum is
   * never negative. An empty range (lo >= hi) yields (0, 0).
   *
   * Divide and conquer: the range is halved and both halves are evaluated
   * with ForkJoin.par, recursively down to single elements (no sequential
   * grain). *)
  fun mpst (lo, hi) f =
    if lo >= hi then
      (0, 0)
    else if lo+1 = hi then
      prefixLeaf (f lo)
    else
      let
        (* written this way rather than (lo+hi) div 2 to avoid overflow *)
        val mid = lo + (hi-lo) div 2
        val (left, right) =
          ForkJoin.par (fn () => mpst (lo, mid) f,
                        fn () => mpst (mid, hi) f)
      in
        prefixCombine (left, right)
      end

  (* mps (lo, hi) f: maximum prefix sum of f lo, ..., f (hi-1), counting the
   * empty prefix as 0. Drops the total computed by mpst. *)
  fun mps (lo, hi) f =
    let
      val (p, _) = mpst (lo, hi) f
    in
      p
    end

  (* mps_as_reduce (lo, hi) f: the same result as mps, but expressed as a
   * single ForkJoin.reducem over prefix summaries. The summaries use (0, 0)
   * as the identity and prefixCombine as the combining function. *)
  fun mps_as_reduce (lo, hi) f =
    let
      val (p, _) =
        ForkJoin.reducem prefixCombine (0, 0) (lo, hi) (prefixLeaf o f)
    in
      p
    end

  end (* local *)


  (* ===================================================================== *)

  (* reduceWithGrain grain combine zero (lo, hi) f: reduce the values
   * f lo, ..., f (hi-1) with combine, seeded with zero.
   *
   * The range is cut into chunks of `grain` consecutive indices (the last
   * chunk may be shorter). Each chunk is folded sequentially with
   * SeqBasis.foldl, and the chunk results are then reduced in parallel with
   * ForkJoin.reducem. A grain below 1 is treated as 1, so the chunk count
   * computation never divides by zero.
   *
   * zero seeds every chunk's fold as well as the outer reduction, so it
   * must be an identity for combine.
   * NOTE(review): the chunking also assumes combine is associative; confirm
   * that this holds for every caller. *)
  fun reduceWithGrain grain combine zero (lo, hi) f =
    let
      val chunk_size = Int.max (grain, 1)
      val n = hi - lo
      val num_chunks = Util.ceilDiv n chunk_size
    in
      ForkJoin.reducem combine zero (0, num_chunks) (fn ci =>
        let
          val start = lo + ci * chunk_size
          val stop = Int.min (start + chunk_size, hi)
        in
          SeqBasis.foldl combine zero (start, stop) f
        end)
    end

  (* reduce' combine zero (lo, hi) f: reduceWithGrain with a fixed grain
   * of 100 indices per sequential chunk. *)
  fun reduce' combine zero (lo, hi) f =
    reduceWithGrain 100 combine zero (lo, hi) f

  (* mcss (lo, hi) f: maximum contiguous subsequence sum of
   * f lo, ..., f (hi-1). The empty subsequence counts as 0, so the result
   * is never negative.
   *
   * Each range is summarized as (p, t, s, b):
   *   p = best prefix sum, t = total sum,
   *   s = best suffix sum, b = best contiguous subsequence sum,
   * and the summaries are reduced with reduce'. *)
  fun mcss (lo, hi) f =
    let
      fun combine ((p1, t1, s1, b1), (p2, t2, s2, b2)) =
        let
          (* best prefix: inside the left, or all of left plus a prefix of right *)
          val p = Int.max (p1, t1+p2)
          val t = t1+t2
          (* best suffix: inside the right, or a suffix of left plus all of right *)
          val s = Int.max (s1+t2, s2)
          (* best subsequence: entirely within one side, or straddling the split *)
          val b = Int.max (s1+p2, Int.max (b1, b2))
        in
          (p, t, s, b)
        end

      (* summary of a single element: its positive part (or the empty
       * sequence, worth 0) is the best prefix, suffix and subsequence *)
      fun leaf x =
        let
          val v = Int.max (x, 0)
        in
          (v, x, v, v)
        end

      val zero = (0, 0, 0, 0)

      val (_, _, _, b) = reduce' combine zero (lo, hi) (leaf o f)
    in
      b
    end

end