(* Simple implementations of deduplicating sort/merge and of a two-way merge.
 * The signature is opaque, so only the values listed here are visible to
 * clients; helper functions defined in the body stay private.
 *)
structure ReferenceImplementations :>
sig
  (* Given a comparison function and a sequence, produce a sequence of
   * (element, count) pairs: one pair per distinct element (distinct meaning
   * the comparison never returns EQUAL for two different output entries),
   * in ascending order, where count is the number of input occurrences.
   * Implemented as a serial insertion sort.
   *)
  val dedup_insertion_sort_serial: ('a * 'a -> order)
                                   -> 'a Seq.t
                                   -> ('a * int) Seq.t
  (* Serially merge two sequences of (element, count) pairs. Entries whose
   * elements compare EQUAL are combined into one entry with summed counts.
   * NOTE(review): the merge walks both inputs front to back, so it presumably
   * expects each input to be sorted and duplicate-free already -- confirm
   * with callers.
   *)
  val dedup_merge_serial: ('a * 'a -> order)
                          -> ('a * int) Seq.t * ('a * int) Seq.t
                          -> ('a * int) Seq.t
  (* Same result shape as dedup_insertion_sort_serial, computed by a serial
   * top-down merge sort built on dedup_merge_serial.
   *)
  val dedup_sort_serial: ('a * 'a -> order) -> 'a Seq.t -> ('a * int) Seq.t

  (* Merge two sequences into one using the given comparison. The body builds
   * a TFlatten tree via ForkJoin.par and then flattens it.
   * NOTE(review): presumably both inputs must already be sorted -- confirm.
   *)
  val merge: ('a * 'a -> order) -> 'a Seq.t * 'a Seq.t -> 'a Seq.t
end =
struct

  (* Serial insertion sort that collapses duplicates.
   *
   * Walks the input `s` left to right, maintaining a sorted, duplicate-free
   * prefix in two parallel scratch arrays (one for keys, one for occurrence
   * counts). Each input element either bumps the count of an existing key
   * (when `cmp` reports EQUAL) or is inserted at its sorted position with
   * count 1.
   *
   * Returns the filled prefix as a sequence of (key, count) pairs in
   * ascending `cmp` order.
   *)
  fun dedup_insertion_sort_serial (cmp: 'a * 'a -> order) (s: 'a Seq.t) =
    let
      val n = Seq.length s

      (* Slot k of `keys` and slot k of `mults` together form one entry. *)
      val keys = ForkJoin.alloc n
      val mults = ForkJoin.alloc n

      fun entry k = (Array.sub (keys, k), Array.sub (mults, k))

      fun store k (key, mult) =
        (Array.update (keys, k, key); Array.update (mults, k, mult))

      fun bump k =
        Array.update (mults, k, 1 + Array.sub (mults, k))

      datatype spot = AlreadyPresentAt of int | ShouldInsertAt of int

      (* Scan the filled prefix [0, hi) from its right end towards the left,
       * looking for either an entry equal to x or the position where x
       * belongs.
       *)
      fun locate hi x =
        if hi = 0 then
          ShouldInsertAt 0
        else
          let
            val below = hi - 1
          in
            case cmp (Array.sub (keys, below), x) of
              LESS => ShouldInsertAt hi
            | EQUAL => AlreadyPresentAt below
            | GREATER => locate below x
          end

      (* Write e at slot k and push every entry in [k, stop) one slot to the
       * right; slot `stop` is the first unused slot and receives the last
       * displaced entry.
       *)
      fun shift_in k stop e =
        if k = stop then
          store k e
        else
          let
            val displaced = entry k
          in
            store k e;
            shift_in (k + 1) stop displaced
          end

      (* Process input index i given `filled` entries so far; yields the new
       * number of filled entries.
       *)
      fun step (filled, i) =
        let
          val x = Seq.nth s i
        in
          case locate filled x of
            AlreadyPresentAt k => (bump k; filled)
          | ShouldInsertAt k => (shift_in k filled (x, 1); filled + 1)
        end

      val output_size = Util.loop (0, n) 0 step
    in
      ArraySlice.full (Array.tabulate (output_size, entry))
    end


  (* Serial merge of two sequences of (key, count) pairs.
   *
   * Advances one cursor through `s` and one through `t`, always emitting the
   * entry with the smaller key; when the two current keys compare EQUAL a
   * single entry is emitted whose count is the sum of both counts. Once one
   * input is exhausted the remainder of the other is copied over unchanged.
   *
   * The result is a slice of a buffer sized for the worst case (no keys in
   * common), trimmed to the number of entries actually written.
   *)
  fun dedup_merge_serial cmp (s, t) =
    let
      val ns = Seq.length s
      val nt = Seq.length t
      val output = ForkJoin.alloc (ns + nt)

      fun put k e = Array.update (output, k, e)

      (* Copy src[from .. len) into output starting at slot k and return the
       * index just past the last slot written.
       *)
      fun copy_tail src from len k =
        let
          val remaining = len - from
        in
          Util.for (0, remaining) (fn d =>
            put (k + d) (Seq.nth src (from + d)));
          k + remaining
        end

      (* i: cursor into s, j: cursor into t, k: next free slot of output.
       * Returns the total number of entries written.
       *)
      fun walk i j k =
        if i >= ns then
          copy_tail t j nt k
        else if j >= nt then
          copy_tail s i ns k
        else
          let
            val (x, xcount) = Seq.nth s i
            val (y, ycount) = Seq.nth t j
          in
            case cmp (x, y) of
              LESS => (put k (x, xcount); walk (i + 1) j (k + 1))
            | EQUAL =>
                ( put k (x, xcount + ycount)
                ; walk (i + 1) (j + 1) (k + 1)
                )
            | GREATER => (put k (y, ycount); walk i (j + 1) (k + 1))
          end

      val count = walk 0 0 0
    in
      ArraySlice.slice (output, 0, SOME count)
    end


  (* Serial top-down merge sort with duplicate collapsing.
   *
   * A sequence of length 0 or 1 is turned directly into (element, 1) pairs.
   * Otherwise the input is split at its midpoint, each half is sorted
   * recursively (left half first), and the two results are combined with
   * dedup_merge_serial.
   *)
  fun dedup_sort_serial cmp s =
    let
      val n = Seq.length s
    in
      if n <= 1 then
        Seq.map (fn x => (x, 1)) s
      else
        let
          val half = n div 2
          val left = dedup_sort_serial cmp (Seq.take s half)
          val right = dedup_sort_serial cmp (Seq.drop s half)
        in
          dedup_merge_serial cmp (left, right)
        end
    end

  (* Build a TFlatten tree describing the merge of s1 and s2.
   *
   * If either input is empty the other one becomes a leaf. Otherwise the
   * middle element of s1 is used as a pivot: s1 is split around it, s2 is
   * split at the index returned by BinarySearch.search for the pivot, and
   * the two (left, left) / (right, right) sub-merges are run through
   * ForkJoin.par. The pivot itself is placed between the two sub-results.
   *
   * NOTE(review): relies on BinarySearch.search returning a split index of
   * s2 that is valid for the pivot -- confirm against that module.
   *)
  fun merge_tree cmp (s1, s2) =
    case (Seq.length s1, Seq.length s2) of
      (0, _) => TFlatten.leaf s2
    | (_, 0) => TFlatten.leaf s1
    | (n1, _) =>
        let
          val pivot_idx = n1 div 2
          val pivot = Seq.nth s1 pivot_idx
          val split2 = BinarySearch.search cmp s2 pivot

          (* The pivot is excluded from both halves of s1. *)
          val lower1 = Seq.take s1 pivot_idx
          val upper1 = Seq.drop s1 (pivot_idx + 1)
          val lower2 = Seq.take s2 split2
          val upper2 = Seq.drop s2 split2

          fun merge_lower () = merge_tree cmp (lower1, lower2)
          fun merge_upper () = merge_tree cmp (upper1, upper2)

          val (outl, outr) = ForkJoin.par (merge_lower, merge_upper)

          val pivot_leaf = TFlatten.leaf (Seq.singleton pivot)
        in
          TFlatten.node (TFlatten.node (outl, pivot_leaf), outr)
        end

  (* Merge s1 and s2 under cmp: build the merge tree with merge_tree, then
   * hand it to TFlatten.flatten to obtain the output sequence.
   *)
  fun merge cmp (s1, s2) =
    let
      val tree = merge_tree cmp (s1, s2)
    in
      TFlatten.flatten tree
    end

end
