(* SimpleSeq: sequences represented as array slices, with tabulate/map/
 * reduce/scan built on the ForkJoin primitives.
 *
 * NOTE(review): ForkJoin is defined outside this file. The comments below
 * assume `ForkJoin.alloc n` yields a fresh array of length n and that
 * `ForkJoin.parform (lo, hi) f` applies f to every index in [lo, hi),
 * possibly in parallel -- confirm against the ForkJoin implementation. *)
structure SimpleSeq: SIMPLE_SEQ =
struct
  type 'a seq = 'a ArraySlice.slice
  type 'a t = 'a seq

  (* The representation already is an array slice, so both conversions
   * are the identity. *)
  fun from_array_seq a = a
  fun to_array_seq s = s

  (* Number of elements in s. *)
  fun length s = ArraySlice.length s

  (* Element i of s (0-based, relative to the start of the slice). *)
  fun nth s i = ArraySlice.sub (s, i)

  (* tabulate n f: builds the n-element sequence whose i-th element is
   * f i. Each index is written exactly once via ForkJoin.parform. *)
  fun tabulate n f =
    let
      val a = ForkJoin.alloc n
    in
      ForkJoin.parform (0, n) (fn i => Array.update (a, i, f i));
      ArraySlice.full a
    end

  (* map s f: the sequence obtained by applying f to every element of s.
   * Note the argument order: sequence first, function second. *)
  fun map s f =
    tabulate (length s) (fn i => f (nth s i))

  (* reduce g z s: combines the elements of s with g and z by delegating
   * to ForkJoin.reducem over the index range [0, length s). *)
  fun reduce g z s =
    ForkJoin.reducem g z (0, length s) (fn i => nth s i)


  (* Inputs of at most this many elements are scanned sequentially;
   * longer inputs are cut into blocks of this size. *)
  val block_size = 1000

  (* foldl g b (lo, hi) f: sequential left fold over f lo, ..., f (hi-1),
   * starting from b. Returns b unchanged when lo >= hi.
   * Shadows the Basis top-level foldl inside this structure. *)
  fun foldl g b (lo, hi) f =
    if lo >= hi then b
    else let val b' = g (b, f lo) in foldl g b' (lo + 1, hi) f end

  (* scan_block g out_f result (j0, b0) (lo, hi) f:
   * sequentially walks the indices lo, ..., hi-1. Before consuming f i it
   * stores out_f of the running prefix into result at the current output
   * position (starting at j0), then extends the prefix with g.
   * Returns (next output position, final prefix). *)
  fun scan_block g out_f result (j0, b0) (lo, hi) f =
    let
      fun bump ((j, acc), x) =
        (Array.update (result, j, out_f acc); (j + 1, g (acc, x)))
    in
      foldl bump (j0, b0) (lo, hi) f
    end

  (* scan_with scan_sums g b input out_f:
   * shared implementation of scan and scan_map. Produces the exclusive
   * prefix "sums" of input under g starting from b, with out_f applied to
   * every stored prefix and to the returned total.
   *
   * Short inputs (length <= block_size) are handled by one sequential
   * pass. Longer inputs are processed in three steps:
   *   1. fold each block of block_size elements, starting from b;
   *   2. scan those block sums with scan_sums to get each block's
   *      starting prefix and the overall total;
   *   3. re-walk every block from its starting prefix, writing outputs.
   * Because every block sum is folded from b and then combined with g,
   * the blocked result agrees with the sequential one only when g is
   * associative and b is an identity for g.
   *
   * scan_sums is passed in (rather than called recursively) because it
   * is used at the element type, not at the out_f result type. *)
  fun scan_with scan_sums g b input out_f =
    let
      val n = length input
      val f = nth input
    in
      if n <= block_size then
        let
          val result = ForkJoin.alloc n
          val (_, total) = scan_block g out_f result (0, b) (0, n) f
        in
          (ArraySlice.full result, out_f total)
        end
      else
        let
          val k = block_size
          val m = 1 + (n - 1) div k (* number of blocks *)
          (* Index range of block i; the last block may be shorter. *)
          fun bounds i =
            let val start = i * k
            in (start, Int.min (start + k, n))
            end
          val sums = tabulate m (fn i => foldl g b (bounds i) f)
          val (partials, total) = scan_sums sums
          val result = ForkJoin.alloc n
        in
          (* Blocks write to disjoint output ranges [i*k, i*k + k). *)
          ForkJoin.parform (0, m) (fn i =>
            (scan_block g out_f result (i * k, nth partials i) (bounds i) f;
             ()));

          (ArraySlice.full result, out_f total)
        end
    end

  (* scan: ('a * 'a -> 'a) -> 'a -> 'a seq -> 'a seq * 'a
   * Exclusive scan: element j of the result is the combination (under g,
   * starting from b) of input elements 0..j-1; the second component is
   * the combination of all elements. Returns (empty, b) on empty input. *)
  fun scan g b input =
    scan_with (fn sums => scan g b sums) g b input (fn x => x)


  (* scan_map: ('a * 'a -> 'a) -> 'a -> 'a seq -> ('a -> 'b) -> 'b seq * 'b
   * Same as scan, but out_f is applied to every prefix before it is
   * stored and to the total before it is returned. *)
  fun scan_map g b input out_f =
    scan_with (scan g b) g b input out_f

end
