(* Parallel: parallel primitives over half-open integer index ranges.
 *
 * Every function takes a range (lo, hi) and a function defined on the
 * indices lo, ..., hi - 1. Sequence results are array slices over freshly
 * allocated arrays. The testing variants of reduce and scan combine the
 * elements in different orders and groupings, which makes them useful for
 * checking that a combining function really is associative.
 *)
structure Parallel:
sig
  (* parfor (lo, hi) f: evaluates f i for every i in [lo, hi), potentially
   * in parallel. *)
  val parfor: int * int -> (int -> unit) -> unit

  (* tabulate (lo, hi) f: sequence of length hi - lo whose element i is
   * f (lo + i). *)
  val tabulate: (int * int) -> (int -> 'a) -> 'a Seq.t

  (* reduce g b (lo, hi) f: combines f lo, ..., f (hi - 1) with g; the work
   * is delegated to ForkJoin.reducem.
   * NOTE(review): presumably g must be associative and b an identity of g. *)
  val reduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a

  (* scan g b (lo, hi) f: element j of the result is b combined with
   * f lo, ..., f (lo + j - 1), so element 0 is b and the last element is the
   * total. g is assumed associative with identity b; b seeds every block. *)
  val scan: ('a * 'a -> 'a)
            -> 'a
            -> (int * int)
            -> (int -> 'a)
            -> 'a Seq.t (* length N+1: drop the last element for an exclusive
                           scan, or the first element for an inclusive one *)

  (* filter (lo, hi) f g: the values f i, in increasing order of i, for every
   * i in [lo, hi) with g i = true. g is evaluated twice per index. *)
  val filter: (int * int) -> (int -> 'a) -> (int -> bool) -> 'a Seq.t

  (* tabFilter (lo, hi) f: the values v, in increasing order of i, for every
   * i in [lo, hi) with f i = SOME v. f is evaluated once per index. *)
  val tabFilter: (int * int) -> (int -> 'a option) -> 'a Seq.t

  (* These functions use different evaluation orders of the combination
   * function and are useful for testing for associativity. They are not as
   * efficient as `reduce` or `scan` above, but give the same result as long
   * as the combining function is associative and the base value is its
   * identity.
   *)
  val reduce_testing_1: ('a * 'a -> 'a)
                        -> 'a
                        -> (int * int)
                        -> (int -> 'a)
                        -> 'a
  val reduce_testing_2: ('a * 'a -> 'a)
                        -> 'a
                        -> (int * int)
                        -> (int -> 'a)
                        -> 'a
  val reduce_testing_3: ('a * 'a -> 'a)
                        -> 'a
                        -> (int * int)
                        -> (int -> 'a)
                        -> 'a
  val reduce_testing_4: ('a * 'a -> 'a)
                        -> 'a
                        -> (int * int)
                        -> (int -> 'a)
                        -> 'a

  val scan_testing_1: ('a * 'a -> 'a)
                      -> 'a
                      -> (int * int)
                      -> (int -> 'a)
                      -> 'a Seq.t

  val scan_testing_2: ('a * 'a -> 'a)
                      -> 'a
                      -> (int * int)
                      -> (int -> 'a)
                      -> 'a Seq.t

  val scan_testing_3: ('a * 'a -> 'a)
                      -> 'a
                      -> (int * int)
                      -> (int -> 'a)
                      -> 'a Seq.t
end =
struct

  structure A = Array
  structure AS = ArraySlice


  (* Sequential cutoff and block size used by scan, filter and tabFilter. *)
  val block_size = 100


  (* Evaluates f i for every i in [lo, hi), potentially in parallel. No grain
   * size is passed here; granularity is left to ForkJoin.parform. *)
  fun parfor (lo, hi) f =
    ForkJoin.parform (lo, hi) f


  fun upd a i x = A.update (a, i, x)
  fun nth a i = A.sub (a, i)


  val par = ForkJoin.par

  (* NOTE(review): presumably returns an array whose contents are unspecified
   * until written; every use below writes each slot before it is read. *)
  val allocate = ForkJoin.alloc


  (* Number of blocks of size k needed to cover n elements, i.e. the ceiling
   * of n / k. SML's div rounds toward negative infinity, so n = 0 gives 0. *)
  fun numBlocks n k = 1 + (n - 1) div k


  (* blockRange (lo, hi) k i: index range of the i-th block of size k within
   * [lo, hi); the final block may be shorter than k. *)
  fun blockRange (lo, hi) k i =
    let val start = lo + i * k
    in (start, Int.min (start + k, hi))
    end


  (* Array of length hi - lo whose element i is f (lo + i), filled in
   * parallel. *)
  fun tabulate_ (lo, hi) f =
    let
      val n = hi - lo
      val result = allocate n
    in
      (* the lo = 0 case merely avoids the offset addition *)
      if lo = 0 then parfor (0, n) (fn i => upd result i (f i))
      else parfor (0, n) (fn i => upd result i (f (lo + i)));

      result
    end


  (* Sequential left fold over the indices [lo, hi):
   * g (... g (g (b, f lo), f (lo + 1)) ..., f (hi - 1)). *)
  fun foldl g b (lo, hi) f =
    if lo >= hi then b
    else foldl g (g (b, f lo)) (lo + 1, hi) f


  (* Sequential right fold over the indices [lo, hi):
   * g (f lo, g (f (lo + 1), ... g (f (hi - 1), b) ...)).
   * The element is passed as the first argument of g, so non-commutative
   * combining functions see the elements in index order. *)
  fun foldr g b (lo, hi) f =
    if lo >= hi then b
    else foldr g (g (f (hi - 1), b)) (lo, hi - 1) f


  (* Delegates to ForkJoin.reducem.
   * NOTE(review): presumably g must be associative and b an identity of g
   * for the result to be independent of how the range is split. *)
  fun reduce g b (lo, hi) f =
    ForkJoin.reducem g b (lo, hi) f


  (* Sequential exclusive scan written into an array: for each index lo + j
   * in [lo, hi), out[off + j] receives init combined with
   * f lo, ..., f (lo + j - 1). Returns the combination after the last index.
   *)
  fun scanInto g out off init (lo, hi) f =
    let
      fun step ((pos, acc), x) = (upd out pos acc; (pos + 1, g (acc, x)))
      val (_, total) = foldl step (off, init) (lo, hi) f
    in
      total
    end


  (* Blocked parallel scan producing an array of length (hi - lo) + 1 whose
   * last element is the total.
   *
   * Ranges of at most k elements are scanned sequentially. Larger ranges are
   * cut into blocks of k elements. Each block is summed with red, the block
   * sums are scanned recursively, and each block is then rescanned in
   * parallel starting from its own prefix. The block sums are seeded with b,
   * so b must be an identity of g.
   *)
  fun blockedScan k red g b (lo, hi) f =
    let
      val n = hi - lo
    in
      if n <= k then
        let
          val result = allocate (n + 1)
        in
          upd result n (scanInto g result 0 b (lo, hi) f);
          result
        end
      else
        let
          val m = numBlocks n k
          val block = blockRange (lo, hi) k
          val sums = tabulate_ (0, m) (fn i => red g b (block i) f)
          (* partials[i] = combination of all blocks before block i *)
          val partials = blockedScan k red g b (0, m) (nth sums)
          val result = allocate (n + 1)
        in
          parfor (0, m) (fn i =>
            ignore (scanInto g result (i * k) (nth partials i) (block i) f));
          upd result n (nth partials m);
          result
        end
    end


  fun scan_ g b (lo, hi) f =
    blockedScan block_size reduce g b (lo, hi) f


  (* Two-pass blocked filter: count the kept elements of each block, scan the
   * counts into output offsets, then let each block write its kept elements
   * in parallel. g is evaluated once in each pass. *)
  fun filter_ (lo, hi) f g =
    let
      val k = block_size
      val m = numBlocks (hi - lo) k
      val block = blockRange (lo, hi) k
      val counts = tabulate_ (0, m) (fn i =>
        reduce op+ 0 (block i) (fn j => if g j then 1 else 0))
      (* offsets[i] = output position of block i's first kept element;
       * offsets[m] = total number of kept elements *)
      val offsets = scan_ op+ 0 (0, m) (nth counts)
      val result = allocate (nth offsets m)
      fun copyKept (i, j) c =
        if i >= j then ()
        else if g i then (upd result c (f i); copyKept (i + 1, j) (c + 1))
        else copyKept (i + 1, j) c
    in
      parfor (0, m) (fn i => copyKept (block i) (nth offsets i));
      result
    end


  (* Blocked filter that evaluates f exactly once per index. Each block packs
   * its SOME values into a scratch array starting at that block's own offset
   * i * k. The per-block counts are scanned into output offsets, and the
   * packed prefixes are then copied into the result in parallel. *)
  fun tabFilter_ (lo, hi) (f: int -> 'a option) =
    let
      val n = hi - lo
      val k = block_size
      val m = numBlocks n k
      val block = blockRange (lo, hi) k
      val tmp = allocate n

      (* Evaluates f on [i, j), writing each SOME value to tmp starting at
       * position c; returns the position after the last value written. *)
      fun packBlock (i, j) c =
        if i >= j then
          c
        else
          case f i of
            NONE => packBlock (i + 1, j) c
          | SOME v => (upd tmp c v; packBlock (i + 1, j) (c + 1))

      val counts = tabulate_ (0, m) (fn i =>
        packBlock (block i) (i * k) - i * k)

      val outOff = scan_ op+ 0 (0, m) (nth counts)
      val result = allocate (nth outOff m)
    in
      parfor (0, m) (fn i =>
        let
          val src = i * k
          val dst = nth outOff i
        in
          parfor (0, nth outOff (i + 1) - dst) (fn j =>
            upd result (dst + j) (nth tmp (src + j)))
        end);
      result
    end


  fun tabulate (lo, hi) f =
    AS.full (tabulate_ (lo, hi) f)
  fun scan g z (lo, hi) f =
    AS.full (scan_ g z (lo, hi) f)
  fun filter (lo, hi) f g =
    AS.full (filter_ (lo, hi) f g)
  fun tabFilter (lo, hi) f =
    AS.full (tabFilter_ (lo, hi) f)


  (* ========================================================================
   * implementations of reduce for testing
   *)


  (* Balanced binary divide and conquer; both halves are reduced in parallel.
   * A single element is returned as-is, without being combined with b. *)
  fun reduce_testing_1 g b (lo, hi) f =
    if lo >= hi then
      b
    else if lo + 1 = hi then
      f lo
    else
      let
        val mid = lo + (hi - lo) div 2
        val (left, right) =
          par (fn () => reduce_testing_1 g b (lo, mid) f,
               fn () => reduce_testing_1 g b (mid, hi) f)
      in
        g (left, right)
      end


  (* Sequential right-to-left fold:
   * g (f lo, g (f (lo + 1), ... f (hi - 1))). b is returned only for an
   * empty range. *)
  fun reduce_testing_2 g b (lo, hi) f =
    let
      fun loop i acc =
        if i <= lo then acc else loop (i - 1) (g (f (i - 1), acc))
    in
      if lo >= hi then b else loop (hi - 1) (f (hi - 1))
    end


  (* Divide and conquer around a middle element: f mid is evaluated first,
   * the two sides are reduced in parallel, and the results are combined as
   * g (left, g (f mid, right)). *)
  fun reduce_testing_3 g b (lo, hi) f =
    if lo >= hi then
      b
    else if lo + 1 = hi then
      f lo
    else if lo + 2 = hi then
      g (f lo, f (lo + 1))
    else
      let
        val mid = lo + (hi - lo) div 2
        val mid_elem = f mid
        val (left, right) =
          par (fn () => reduce_testing_3 g b (lo, mid) f,
               fn () => reduce_testing_3 g b (mid + 1, hi) f)
      in
        g (left, g (mid_elem, right))
      end


  (* Splits [lo, hi) into pseudo-random chunks of 2 to 11 elements, with
   * sizes taken from Util.hash; the last chunk is cut off at hi. Each chunk
   * is reduced with one of reduce_testing_1/2/3, chosen by hashing the chunk
   * index, and the chunk results are combined with reduce_testing_1. *)
  fun reduce_testing_4 g b (lo, hi) f =
    let
      val n = hi - lo
      val target_num_chunks = Util.ceilDiv n 10

      (* Generates slop_factor * target_num_chunks chunk sizes and keeps the
       * offsets that start inside [0, n). If even the total does not reach n,
       * retries with a larger slop factor. *)
      fun make_chunks slop_factor =
        let
          val offsets =
            scan_ op+ 0 (0, slop_factor * target_num_chunks) (fn i =>
              2 + Util.hash i mod 10)
          val i = BinarySearch.countLess Int.compare (AS.full offsets) n
        (* val _ = print
          ("make_chunks " ^ Int.toString n ^ " " ^ Int.toString slop_factor
           ^ " " ^ Int.toString target_num_chunks ^ " " ^ Int.toString i
           ^ "\n") *)
        in
          if i > slop_factor * target_num_chunks then
            make_chunks (slop_factor + 1)
          else
            AS.slice (offsets, 0, SOME i)
        end

      val offsets = make_chunks 2
    (* val _ = print ("num chunks " ^ Int.toString (Seq.length offsets) ^ "\n") *)
    in
      reduce_testing_1 g b (0, Seq.length offsets) (fn ci =>
        let
          val start = lo + Seq.nth offsets ci
          val stop =
            if ci = Seq.length offsets - 1 then hi
            else (lo + Seq.nth offsets (ci + 1))
        in
          case Util.hash ci mod 3 of
            0 => reduce_testing_1 g b (start, stop) f
          | 1 => reduce_testing_2 g b (start, stop) f
          | 2 => reduce_testing_3 g b (start, stop) f
          | _ => Util.die ("bug! error in Parallel.reduce_testing_4\n")
        end)
    end


  (* Tracing wrappers: uncomment to print which reduce variant is called.
  fun reduce_testing_1_ a b c d =
    (print ("hello from reduce_testing_1\n"); reduce_testing_1 a b c d)
  fun reduce_testing_2_ a b c d =
    (print ("hello from reduce_testing_2\n"); reduce_testing_2 a b c d)
  fun reduce_testing_3_ a b c d =
    (print ("hello from reduce_testing_3\n"); reduce_testing_3 a b c d)
  fun reduce_testing_4_ a b c d =
    (print ("hello from reduce_testing_4\n"); reduce_testing_4 a b c d)

  val reduce_testing_1 = reduce_testing_1_
  val reduce_testing_2 = reduce_testing_2_
  val reduce_testing_3 = reduce_testing_3_
  val reduce_testing_4 = reduce_testing_4_ *)


  (* ========================================================================
   * implementations of scan for testing
   *)

  (* Balanced divide and conquer: both halves are scanned in parallel, and
   * every prefix of the right half is then offset by the left half's total.
   * A single element yields [b, f lo], so b is assumed to be an identity. *)
  fun scan_testing_1 g b (lo, hi) (f: int -> 'a) =
    if lo >= hi then
      Seq.singleton b
    else if lo + 1 = hi then
      Seq.fromList [b, f lo]
    else
      let
        val mid = lo + (hi - lo) div 2
        val (leftScan, rightScan) =
          par (fn () => scan_testing_1 g b (lo, mid) f,
               fn () => scan_testing_1 g b (mid, hi) f)
        val leftTotal = Seq.nth leftScan (mid - lo)
      in
        Seq.append (Seq.subseq leftScan (0, mid - lo),
                    Seq.map (fn x => g (leftTotal, x)) rightScan)
      end


  (* Contraction: combine adjacent pairs, scan the pair sums recursively,
   * then expand. Even output positions come straight from the recursive
   * scan; odd positions add the one element that precedes them. *)
  fun scan_testing_2 g b (lo, hi) f =
    if lo >= hi then
      Seq.singleton b
    else if lo + 1 = hi then
      Seq.fromList [b, f lo]
    else
      let
        val n = hi - lo
        val half = Util.ceilDiv n 2
        fun get i = f (lo + i)
        (* when n is odd, the last element is carried alone *)
        val pairs = tabulate (0, half) (fn i =>
          if 2 * i + 1 = n then get (2 * i)
          else g (get (2 * i), get (2 * i + 1)))
        val pairScan = scan_testing_2 g b (0, half) (Seq.nth pairs)
        fun outputAt i =
          if i mod 2 = 0 then Seq.nth pairScan (i div 2)
          else g (Seq.nth pairScan (i div 2), get (i - 1))
      in
        tabulate (0, n + 1) outputAt
      end


  (* Same blocked algorithm as scan, but with tiny blocks of 3 elements whose
   * sums are computed right-to-left by reduce_testing_2. *)
  fun scan_testing_3_ g b (lo, hi) f =
    blockedScan 3 reduce_testing_2 g b (lo, hi) f

  fun scan_testing_3 a b c d =
    AS.full (scan_testing_3_ a b c d)


  (* Tracing wrappers: uncomment to print which scan variant is called.
  fun scan_testing_1_ a b c d =
    (print ("hello from scan_testing_1\n"); scan_testing_1 a b c d)
  fun scan_testing_2_ a b c d =
    (print ("hello from scan_testing_2\n"); scan_testing_2 a b c d)
  fun scan_testing_3_ a b c d =
    (print ("hello from scan_testing_3\n"); scan_testing_3 a b c d)

  val scan_testing_1 = scan_testing_1_
  val scan_testing_2 = scan_testing_2_
  val scan_testing_3 = scan_testing_3_ *)

end
