(* Shorthand for building a sequence from a list literal: %[1, 2, 3]. *)
fun % xs = Seq.fromList xs

(* A sequence of (key, count) pairs, the element type used by dedup_merge. *)
type ipair_seq = (int * int) Seq.t

(* Pair every element of s with a pseudo-random count in 1..3. The count
 * depends only on the element's index and `seed`, so results are
 * reproducible. *)
fun rand_counts seed s =
  let
    fun attach (idx, key) =
      (key, 1 + (Util.hash (idx + seed) mod 3))
  in
    Seq.mapIdx attach s
  end

(* Test inputs for DedupSort.dedup_merge. Each case is a pair of
 * (key, count) sequences; check_valid_input (below) requires every count
 * to be >= 1 and keys to be strictly increasing under Int.compare.
 * Add more if you'd like! *)
val dedup_merge_inputs: (ipair_seq * ipair_seq) list =
  (* Both sides empty. *)
  [ (%[], %[])
  (* Exactly one side empty. *)
  , (%[(1, 1), (2, 1), (3, 1), (4, 1), (5, 1)], %[])
  , (%[], %[(1, 1), (2, 1), (3, 1), (4, 1), (5, 1)])
  (* Hand-written case: keys 0, 3 and 11 appear on both sides; the right
   * side also extends below and above the left side's key range. *)
  , ( %[(0, 5), (2, 10), (3, 1), (10, 1), (11, 1), (12, 1)]
    , %[ (~4, 4)
       , (~3, 3)
       , (~2, 2)
       , (~1, 1)
       , (0, 1)
       , (3, 5)
       , (11, 42)
       , (13, 1)
       , (14, 2)
       , (15, 3)
       , (16, 4)
       ]
    )
  (* Disjoint, perfectly interleaved keys: evens on the left, odds on the
   * right, 500 each, all counts 1. *)
  , ( Seq.tabulate (fn i => (2 * i, 1)) 500
    , Seq.tabulate (fn i => (2 * i + 1, 1)) 500
    )
  (* Left holds every key in [0, n) with counts in 1..3; right holds keys
   * taken from left at positions `ids`, with counts in 1..10. `ids` is a
   * running sum of pseudo-random gaps in 1..4.
   * NOTE(review): assumes Parallel.scan yields prefix sums starting at 0
   * (so `ids` is strictly increasing) and countLess counts entries < n,
   * making every selected position valid -- confirm against the library. *)
  , let
      val n = 100000
      val left_nums = Seq.tabulate (fn i => i) n
      val ids = Parallel.scan op+ 0 (0, n) (fn i => 1 + Util.hash i mod 4)
      val right_len = BinarySearch.countLess Int.compare ids n
      val right_nums = Seq.tabulate (Seq.nth left_nums o Seq.nth ids) right_len
    in
      ( Seq.map (fn x => (x, 1 + Util.hash x mod 3)) left_nums
      , Seq.map (fn x => (x, 1 + Util.hash (n + x) mod 10)) right_nums
      )
    end
  ]
  @
  (* Large cases with rand_counts counts (1..3):
   *  - disjoint, adjacent ranges [0, 100000) and [100000, 200000);
   *  - the same sorted key set on both sides (values i or ~i by parity, so
   *    negatives and positives are mixed), with different count seeds. *)
  [ ( rand_counts 15210 (Seq.tabulate (fn i => i) 100000)
    , rand_counts 15211 (Seq.tabulate (fn i => i + 100000) 100000)
    )
  , ( rand_counts 15212 (Mergesort.sort Int.compare
        (Seq.tabulate (fn i => if i mod 2 = 0 then i else ~i) 100000))
    , rand_counts 15213 (Mergesort.sort Int.compare
        (Seq.tabulate (fn i => if i mod 2 = 0 then i else ~i) 100000))
    )
  ]
  @
  (* 10 random cases: each side is the reference dedup_sort of 5000..9999
   * pseudo-random values drawn from [0, 7500), so the two key sets are
   * likely to overlap substantially. *)
  List.tabulate (10, fn seed =>
    let
      val seed = 15214 + seed
      val xs = Seq.tabulate (fn i => Util.hash (seed + i) mod 7500)
        (5000 + (Util.hash seed mod 5000))
      val seed' = Util.hash seed
      val ys = Seq.tabulate (fn i => Util.hash (seed' + i) mod 7500)
        (5000 + (Util.hash seed' mod 5000))
    in
      ( ReferenceImplementations.dedup_sort_serial Int.compare xs
      , ReferenceImplementations.dedup_sort_serial Int.compare ys
      )
    end)
  @
  (* 12 random cases: each side is an independently seeded subset of
   * [0, n) (n = 10000), selected with pseudo-random gaps in 1..4 the same
   * way as the n = 100000 case above, with counts from rand_counts.
   * NOTE(review): both sides use count seed 15215 -- confirm whether
   * distinct seeds were intended. *)
  List.tabulate (12, fn seed =>
    let
      val n = 10000
      val seed = Util.hash (seed + 3033121)
      val data = Seq.tabulate (fn i => i) n

      val right_ids = Parallel.scan op+ 0 (0, n) (fn i =>
        1 + Util.hash (i + seed) mod 4)
      val right_nums = Seq.tabulate (Seq.nth data o Seq.nth right_ids)
        (BinarySearch.countLess Int.compare right_ids n)

      (* Re-hash the seed so the left subset is chosen independently. *)
      val seed = Util.hash seed
      val left_ids = Parallel.scan op+ 0 (0, n) (fn i =>
        1 + Util.hash (i + seed) mod 4)
      val left_nums = Seq.tabulate (Seq.nth data o Seq.nth left_ids)
        (BinarySearch.countLess Int.compare left_ids n)
    in
      (rand_counts 15215 left_nums, rand_counts 15215 right_nums)
    end)


(* Test inputs for DedupSort.dedup_sort: plain int sequences, many of which
 * contain duplicates. Add more if you'd like! *)
val dedup_sort_inputs: int Seq.t list =
  (* Small hand-written cases, then 0..499 each appearing twice (passed
   * through Shuffle.shuffle) and 100000 pseudo-random values in [0, 10000). *)
  [ %[]
  , %[42]
  , %[5, 4, 5, 3, 5, 2, 5, 1, 5, 0, 5, 1, 5, 2, 5, 3, 5, 4, 5]
  , Shuffle.shuffle (Seq.tabulate (fn i => i mod 500) 1000) 15210
  , Seq.tabulate (fn i => Util.hash i mod 10000) 100000
  ]
  @
  (* Structured large cases: already sorted; distinct values whose sign
   * alternates with parity; all zeros; and 10 blocks of 1513 copies each
   * of (i mod 3), i.e. long runs of 0, 1, 2, 0, ... *)
  [ Seq.tabulate (fn i => i) 100000
  , Seq.tabulate (fn i => if i mod 2 = 0 then i else ~i) 100000
  , Seq.tabulate (fn i => 0) 100000
  , Seq.flatten (Seq.tabulate (fn i => Seq.tabulate (fn j => i mod 3) 1513) 10)
  ]
  @
  (* 11 cases: [0, k) for k in 5000..9999, passed through Shuffle.shuffle
   * (presumably a permutation, hence no duplicates -- confirm). *)
  List.tabulate (11, fn seed =>
    Shuffle.shuffle
      (Seq.tabulate (fn i => i) (5000 + (Util.hash seed mod 5000))) seed)
  @
  (* 10 cases: 10000 pseudo-random values in [0, 10000) run through
   * Scramble.scramble.
   * NOTE(review): target_sortedness = 60 presumably controls how nearly
   * sorted the output is -- confirm in Scramble. *)
  List.tabulate (10, fn seed =>
    let
      val seed = Util.hash (15210 + seed)
    in
      Scramble.scramble {target_sortedness = 60, cmp = Int.compare}
        (Seq.tabulate (fn i => Util.hash (seed + i) mod 10000) 10000)
    end)


(* ======================================================================== *)

(* Render an int pair as "(a,b)", using SML's "~" sign for negatives. *)
fun iptos (a, b) =
  String.concat ["(", Int.toString a, ",", Int.toString b, ")"]

(* Combine two options, preferring the first one that is present. *)
fun leftmost (x, y) =
  case x of
    SOME _ => x
  | NONE => y

(* Index of the first position (within the shorter length) where x and y
 * differ according to `eq`, or NONE if they agree on that prefix. *)
fun diff eq (x, y) =
  let
    val common = Int.min (Seq.length x, Seq.length y)
    fun mismatch_at idx =
      if not (eq (Seq.nth x idx, Seq.nth y idx)) then SOME idx else NONE
  in
    Parallel.reduce leftmost NONE (0, common) mismatch_at
  end

(* Render a bool as "true" / "false". *)
fun btos b = if b then "true" else "false"

(* Build a fresh sequence holding the same elements as s. *)
fun copy s =
  Seq.tabulate (Seq.nth s) (Seq.length s)

(* Sanity check for a dedup_merge input. Returns true iff every count is at
 * least 1 and keys are strictly increasing under Int.compare. Each element's
 * count is checked individually, so single-element inputs are validated too;
 * the empty sequence is trivially valid. *)
fun check_valid_input kvs =
  let
    val n = Seq.length kvs

    (* Element i is valid if its count is positive and, when it has a
     * successor, its key is strictly smaller than the successor's key. *)
    fun ok_at i =
      let
        val (k, c) = Seq.nth kvs i
        fun precedes_next () =
          let
            val (k', _) = Seq.nth kvs (i + 1)
          in
            Int.compare (k, k') = LESS
          end
      in
        c >= 1 andalso (i + 1 >= n orelse precedes_next ())
      end
  in
    Parallel.reduce (fn (a, b) => a andalso b) true (0, n) ok_at
  end

(* Report returned when an output matches the reference. *)
val passed = {summary = "Passed", score = 1.0, details = []}

(* Report for a failing case; `msgs` explains what went wrong. *)
fun failed msgs = {summary = "Failed", score = 0.0, details = msgs}

(* Compare an implementation's (key, count) output with the reference output.
 * A length mismatch is reported first; otherwise the first differing index
 * (if any) is reported. Returns `passed` when the two sequences agree. *)
fun compare_outputs expected output =
  if Seq.length expected <> Seq.length output then
    failed
      ["expected output of length " ^ Int.toString (Seq.length expected)
       ^ " but got output of length " ^ Int.toString (Seq.length output)]
  else
    case diff op= (output, expected) of
      NONE => passed
    | SOME i =>
        failed
          ["expected " ^ iptos (Seq.nth expected i) ^ " at index "
           ^ Int.toString i ^ " but got " ^ iptos (Seq.nth output i)]

(* Report for DedupSort.dedup_merge. Raises Fail if the test fixture itself
 * violates the input contract (see check_valid_input); otherwise reports an
 * exception from the implementation, or compares its output against
 * ReferenceImplementations.dedup_merge_serial. *)
fun report_dedup_merge (inputs as (inp1, inp2)) result =
  let
    (* Validate the fixture before looking at the result, so a broken input
     * is caught even when the implementation raised instead of returning. *)
    val () =
      if check_valid_input inp1 andalso check_valid_input inp2 then ()
      else raise Fail "whoops! invalid input!"
  in
    case result of
      Tester.Result.Raised exn => failed ["raised exception: " ^ exnMessage exn]
    | Tester.Result.Okay output =>
        compare_outputs
          (ReferenceImplementations.dedup_merge_serial Int.compare inputs)
          output
  end


(* Report for DedupSort.dedup_sort: reports an exception from the
 * implementation, or compares its output against
 * ReferenceImplementations.dedup_sort_serial. *)
fun report_dedup_sort input result =
  case result of
    Tester.Result.Raised exn => failed ["raised exception: " ^ exnMessage exn]
  | Tester.Result.Okay output =>
      compare_outputs
        (ReferenceImplementations.dedup_sort_serial Int.compare input)
        output


(* Build a Tester case for dedup_merge. The input thunk returns fresh copies
 * of both sequences on every call. *)
fun make_dedup_merge_test (left, right) =
  let
    fun fresh () = (copy left, copy right)
    val impl = DedupSort.dedup_merge Int.compare
  in
    Tester.T {input = fresh, func = impl, report = report_dedup_merge}
  end

(* Build a Tester case for dedup_sort. The input thunk returns a fresh copy
 * of the sequence on every call. *)
fun make_dedup_sort_test seq =
  let
    fun fresh () = copy seq
    val impl = DedupSort.dedup_sort Int.compare
  in
    Tester.T {input = fresh, func = impl, report = report_dedup_sort}
  end

(* Run the dedup_merge suite, then the dedup_sort suite. *)
val _ =
  let
    val _ = Tester.run_tests "Testing DedupSort.dedup_merge"
      (List.map make_dedup_merge_test dedup_merge_inputs)
  in
    Tester.run_tests "Testing DedupSort.dedup_sort"
      (List.map make_dedup_sort_test dedup_sort_inputs)
  end
