structure CLA = CommandLineArgs
structure Seq = ArraySequence

(* Command-line parameters, each with a default, followed by an echo of
 * every parameter to stdout as "<name> <value>", one per line. *)
val n = CLA.parseInt "n" (1000 * 10)
val impl = CLA.parseString "impl" "reference"
val outfile = CLA.parseString "output" ""
val resolution = CLA.parseInt "resolution" 1000
val heatmap_block_size = CLA.parseInt "heatmap-block-size" 5
val seed = CLA.parseInt "seed" 15210
val _ =
  List.app (fn (label, value) => print (label ^ value ^ "\n"))
    [ ("n ", Int.toString n)
    , ("impl ", impl)
    , ("output ", outfile)
    , ("resolution ", Int.toString resolution)
    , ("heatmap-block-size ", Int.toString heatmap_block_size)
    , ("seed ", Int.toString seed)
    ]

(* The output path is mandatory. It defaults to the empty string, so an
 * empty value means -output was not given; in that case the usage message
 * is handed to Util.die. *)
val () =
  case outfile of
    "" =>
      Util.die
        ("missing argument: -output FILE.ppm\n"
         ^ "for example: heatmap -n 10000 -resolution 1000 -output heatmap.ppm")
  | _ => ()

(* Format a real with three digits after the decimal point, writing the
 * sign of a negative value as "-" instead of SML's "~". *)
fun rtos x =
  let
    val fixed3 = Real.fmt (StringCvt.FIX (SOME 3))
  in
    if x < 0.0 then "-" ^ fixed3 (~x) else fixed3 x
  end
(* Render a point as "(x,y)", formatting each coordinate with rtos. *)
fun pttos (x, y) =
  "(" ^ rtos x ^ "," ^ rtos y ^ ")"


(* Map a seed to a real in [0, 1): Util.hash seed is reduced modulo
 * 1000000 (non-negative, since the divisor is positive) and divided by
 * that same bucket count. *)
fun rand_real seed =
  let
    val buckets = 1000000
    val bucket = Util.hash seed mod buckets
  in
    Real.fromInt bucket / Real.fromInt buckets
  end

(* Pick a point inside the given circle from two rand_real draws: `seed`
 * determines the distance from the center and `seed + 1` the angle.
 * Taking the square root of the first draw presumably keeps the density
 * even over the disc's area (this relies on rand_real being roughly
 * uniform). *)
fun rand_point_in_circle {center = (cx, cy), radius} seed =
  let
    val dist = radius * Math.sqrt (rand_real seed)
    val angle = rand_real (seed + 1) * 2.0 * Math.pi
    val dx = dist * Math.cos angle
    val dy = dist * Math.sin angle
  in
    (cx + dx, cy + dy)
  end

(* The share `alpha` of `n`, rounded to the nearest integer by Real.round. *)
fun frac n alpha =
  let
    val scaled = Real.fromInt n * alpha
  in
    Real.round scaled
  end

(* The input point set: four random clusters followed by one fixed point.
 *
 *   share of n   circle                          base seed
 *   0.7          center (0, 0),       radius 1.0   seed
 *   0.1          center (0.8, 0.1),   radius 0.4   seed hashed once
 *   0.1          center (-0.8, -0.8), radius 0.2   seed hashed twice
 *   0.1          center (0, 1),       radius 0.2   seed hashed three times
 *
 * Sample i of a cluster is seeded with its base seed plus 2 * i. *)
val inputPts =
  let
    (* Apply Util.hash to s the given number of times. *)
    fun rehash 0 s = s
      | rehash k s = rehash (k - 1) (Util.hash s)

    fun cluster (circle, rounds, share) =
      Seq.tabulate
        (fn i => rand_point_in_circle circle (rehash rounds seed + 2 * i))
        (frac n share)

    val parts =
      [ cluster ({center = (0.0, 0.0), radius = 1.0}, 0, 0.7)
      , cluster ({center = (0.8, 0.1), radius = 0.4}, 1, 0.1)
      , cluster ({center = (~0.8, ~0.8), radius = 0.2}, 2, 0.1)
      , cluster ({center = (0.0, 1.0), radius = 0.2}, 3, 0.1)
      , Seq.fromList [(~1.05, 0.05)]
      ]
  in
    Seq.flatten (Seq.fromList parts)
  end

(* From here on n is the actual number of generated points; it shadows the
 * command-line value of the same name. *)
val n = Seq.length inputPts
val _ = print (String.concat ["num points ", Int.toString n, "\n"])


(* Axis-aligned bounding box of inputPts, computed with Parallel.reduce:
 * every point contributes the degenerate box (x, x, y, y) and boxes are
 * merged by taking coordinate-wise minima and maxima. *)
val (minx, maxx, miny, maxy) =
  let
    val empty_box = (Real.posInf, Real.negInf, Real.posInf, Real.negInf)

    fun merge ((xlo1, xhi1, ylo1, yhi1), (xlo2, xhi2, ylo2, yhi2)) =
      ( Real.min (xlo1, xlo2)
      , Real.max (xhi1, xhi2)
      , Real.min (ylo1, ylo2)
      , Real.max (yhi1, yhi2)
      )

    fun box_of i =
      let val (x, y) = Seq.nth inputPts i
      in (x, x, y, y)
      end
  in
    Parallel.reduce merge empty_box (0, n) box_of
  end


(*
val _ = print
  (rtos minx ^ " " ^ rtos maxx ^ " " ^ rtos miny ^ " " ^ rtos maxy ^ "\n") *)


(* Heatmap bookkeeping: the block size as an int and as a real, the
 * dimensions of the block grid, and `log`, a zero-initialised array of
 * log_width * log_height per-block counters.
 *
 * NOTE(review): the dimensions scale with the bounding-box extent, while
 * logged_tri_area first normalises each coordinate to [0, resolution] and
 * only then divides by hmbs, so block indices do not grow with the extent.
 * Confirm that this sizing is intended. *)
val hmbs = heatmap_block_size
val rhmbs = Real.fromInt heatmap_block_size
val (log_width, log_height) =
  let
    fun blocks_for extent =
      Real.ceil (Real.fromInt resolution * extent / rhmbs)
  in
    (blocks_for (maxx - minx), blocks_for (maxy - miny))
  end
val log = SeqBasis.tabulate 1000 (0, log_width * log_height) (fn _ => 0)
(* Geometry2D.Point.triArea with instrumentation: before delegating, the
 * counter of the heatmap block containing the third point `r` is
 * atomically incremented in `log`.
 *
 * NOTE(review): the counter index uses log_height as the row stride where
 * log_width would be the conventional choice for a log_width * log_height
 * grid. The heatmap drawing code near the end of the file indexes the same
 * way, so the two must be changed together if at all. *)
fun logged_tri_area (p, q, r) =
  let
    (* Block coordinate of v along an axis spanning [lo, hi]. *)
    fun block_of (v, lo, hi) =
      Real.floor (Real.fromInt resolution * (v - lo) / (hi - lo) + 0.5)
      div hmbs

    val (rx, ry) = r
    val col = block_of (rx, minx, maxx)
    val row = block_of (ry, miny, maxy)
    val _ = MLton.Parallel.arrayFetchAndAdd (log, row * log_height + col) 1
  in
    Geometry2D.Point.triArea (p, q, r)
  end

(* Largest counter currently stored in the global log (0 when the log is
 * empty or holds no positive counts). *)
fun max_num_calls () =
  let
    fun count_at i = Array.sub (log, i)
  in
    Parallel.reduce Int.max 0 (0, Array.length log) count_at
  end
(* Sum of all counters in the array passed in. The parameter shadows the
 * global log on purpose, so that saved snapshots can be summed as well. *)
fun total_num_calls log =
  let
    val len = Array.length log
  in
    Parallel.reduce (op +) 0 (0, len) (fn i => Array.sub (log, i))
  end
(* Reset every counter of the global log to zero. *)
fun clear_log () =
  let
    fun zero i = Array.update (log, i, 0)
  in
    ForkJoin.parform (0, Array.length log) zero
  end

(* Both quickhull functors are instantiated over points of type
 * real * real, with the instrumented logged_tri_area as their tri_area so
 * that each call is recorded in the heatmap log. *)
structure RQH =
  ReferenceQuickhull
    (struct
       type point = real * real
       val tri_area = logged_tri_area
     end)
structure MQH =
  MyQuickhull
    (struct
       type point = real * real
       val tri_area = logged_tri_area
     end)


(* Run the reference implementation on a freshly cleared log, keep a copy
 * of the per-block counters it produced, and report the hull size and the
 * total number of tri_area calls. *)
val _ = print
  ("================= running ReferenceQuickhull =================\n")
val (ref_result, ref_log) =
  let
    val _ = clear_log ()
    val hull = RQH.hull inputPts
    (* Snapshot of the log; the next run clears and reuses the original. *)
    val calls =
      SeqBasis.tabulate 1000 (0, Array.length log) (fn i =>
        Array.sub (log, i))
  in
    (hull, calls)
  end
val ref_total = total_num_calls ref_log
val _ = print
  (String.concat ["hull size ", Int.toString (Seq.length ref_result), "\n"])
val _ = print
  (String.concat
     ["total num calls to tri_area ", Int.toString ref_total, "\n"])


(* Run MyQuickhull on a freshly cleared log, keep a copy of the per-block
 * counters it produced, and report the hull size and the total number of
 * tri_area calls. *)
val _ = print
  ("==================== running MyQuickhull ====================\n")
val (my_result, my_log) =
  let
    val _ = clear_log ()
    val hull = MQH.hull inputPts
    (* Snapshot of the log as left behind by this run. *)
    val calls =
      SeqBasis.tabulate 1000 (0, Array.length log) (fn i =>
        Array.sub (log, i))
  in
    (hull, calls)
  end

val my_total = total_num_calls my_log
val _ = print
  (String.concat ["hull size ", Int.toString (Seq.length my_result), "\n"])
val _ = print
  (String.concat
     ["total num calls to tri_area ", Int.toString my_total, "\n"])
(* Relative change in tri_area calls, in percent of the reference total.
 * The value is negative when MyQuickhull made fewer calls, so it is
 * negated when reported as "percent fewer".
 * NOTE(review): a reference total of zero makes this a division by 0.0. *)
val total_improvement =
  let
    val mine = Real.fromInt my_total
    val theirs = Real.fromInt ref_total
  in
    100.0 * (mine - theirs) / theirs
  end
val _ = print
  (String.concat
     [ "percent fewer calls to tri_area: "
     , Real.fmt (StringCvt.FIX (SOME 2)) (~total_improvement)
     , "%\n"
     ])

(* MyQuickhull counts as correct when its result is element-for-element
 * equal to the reference result. Print the verdict and a closing rule. *)
val correct = Seq.equal (op =) (ref_result, my_result)
val _ =
  let
    val verdict = if correct then "yes" else "NO"
  in
    print ("correct (matches reference result)? " ^ verdict ^ "\n");
    print
      ("=============================================================\n")
  end

(* val _ = print
  ("heaviest heatmap block " ^ Int.toString (max_num_calls ()) ^ "\n") *)

(* ==========================================================================
 * output result image
 *)

(* Start of the image-generation timing. *)
val t0 = Time.now ()

(* Blank white canvas covering the bounding box of the input points plus a
 * margin of 0.1 on every side. *)
val img =
  let
    val pad = 0.1
  in
    Image.fresh
      { resolution = resolution
      , bottom_left = (minx - pad, miny - pad)
      , top_right = (maxx + pad, maxy + pad)
      , background = Color.white
      }
  end

(* val num_points_in_block =
  SeqBasis.tabulate 1000 (0, log_width * log_height) (fn _ => 0)
val () = ForkJoin.parform (0, Seq.length inputPts) (fn i =>
  let
    val (rx, ry) = Seq.nth inputPts i
    val log_x =
      Real.floor (Real.fromInt resolution * (rx - minx) / (maxx - minx) + 0.5)
      div hmbs
    val log_y =
      Real.floor (Real.fromInt resolution * (ry - miny) / (maxy - miny) + 0.5)
      div hmbs
  in
    MLton.Parallel.arrayFetchAndAdd (log, log_y * log_height + log_x) 1;
    ()
  end) *)

(* val max_block_average =
  Parallel.reduce Real.max 0.0 (0, Array.length log) (fn i =>
    Real.fromInt (Array.sub (log, i))
    / Real.fromInt (Array.sub (num_points_in_block, i))) *)

(* Number of tri_area calls saved in heatmap block i: the reference count
 * minus the MyQuickhull count. Negative when MyQuickhull made more calls. *)
fun improvement_at i =
  let
    val reference_calls = Array.sub (ref_log, i)
    val my_calls = Array.sub (my_log, i)
  in
    reference_calls - my_calls
  end

(* Extremes of improvement_at over all heatmap blocks. The maximum starts
 * from 0, so best_improvement is never negative; the minimum starts from
 * the largest representable int. *)
val (best_improvement, worst_improvement) =
  let
    val blocks = (0, Array.length ref_log)
  in
    ( Parallel.reduce Int.max 0 blocks improvement_at
    , Parallel.reduce Int.min (valOf Int.maxInt) blocks improvement_at
    )
  end

(* val _ = print ("best improvement: " ^ Int.toString best_improvement ^ "\n")
val _ = print ("worst improvement: " ^ Int.toString worst_improvement ^ "\n") *)

(* Heatmap colour for a block's improvement (reference calls minus
 * MyQuickhull calls):
 *   zero     -> fully transparent
 *   positive -> hue 120.0
 *   negative -> hue 0.0
 * For non-zero values, saturation is 0.5 + magnitude / 10.0 and alpha is
 * magnitude / 20.0, both capped at 1.0. *)
fun choose_color improvement =
  let
    fun shade (hue, magnitude) =
      Color.hsva
        { h = hue
        , s = Real.min (1.0, 0.5 + magnitude / 10.0)
        , v = 1.0
        , a = Real.min (1.0, magnitude / 20.0)
        }
  in
    case Int.compare (improvement, 0) of
      EQUAL => Color.hsva {h = 0.0, s = 0.0, v = 0.0, a = 0.0}
    | GREATER => shade (120.0, Real.fromInt improvement)
    | LESS => shade (0.0, Real.fromInt (~improvement))
  end

(* let
  val range_size = best_improvement - worst_improvement
  val goodness = Real.fromInt improvement / 10.0
  (* val heaviness = Real.min (1.0, Real.fromInt count / Real.fromInt heaviest) *)

  val hue = 120.0 * (1.0 - heaviness)
  val sat = 0.5 + heaviness / 2.0
  val alpha = heaviness

  val color = Color.hsva {h = hue, s = sat, v = 1.0, a = alpha}
in
  color
end *)


(* val distinct_counts =
  let
    val cs = Mergesort.sort Int.compare (ArraySlice.full log)
    val ids = Parallel.filter (0, Seq.length cs) (fn i => i) (fn i =>
      i = 0 orelse Seq.nth cs i <> Seq.nth cs (i - 1))
  in
    Parallel.tabulate (0, Seq.length ids) (fn j =>
      let
        val i1 = Seq.nth ids j
        val i2 =
          if j = Seq.length ids - 1 then Seq.length cs else Seq.nth ids (j + 1)
      in
        (Seq.nth cs i1, i2 - i1)
      end)
  end

fun right_pad n str =
  let
    val rem = Int.max (n - String.size str, 0)
  in
    CharVector.tabulate (String.size str + rem, fn i =>
      if i < String.size str then String.sub (str, i) else #" ")
  end

val () = print "\n"
val h1 = "num tri_area calls"
val h2 = "num heatmap blocks"
val () = print (h1 ^ "    " ^ h2 ^ "\n")
val () = Util.for (0, Seq.length distinct_counts) (fn i =>
  let
    val (x, c) = Seq.nth distinct_counts i
  in
    if x = 0 then
      ()
    else
      let
        val color = choose_color x
        val {red, green, blue, ...} =
          Color.pixelToColor (Color.colorToPixel color)
        val color = TerminalColors.rgb {red = red, green = green, blue = blue}
      in
        TerminalColorString.print
          (TerminalColorString.append
             ( TerminalColorString.fromString
                 (right_pad (String.size h1) (Int.toString x) ^ "    "
                  ^ right_pad (String.size h2) (Int.toString c) ^ " ")
             , TerminalColorString.background color
                 (TerminalColorString.fromString "    ")
             ));
        print "\n"
      end
  end)
val () = print "\n" *)

(* Paint the improvement heatmap onto img: one box per heatmap block whose
 * tri_area call count differs between ReferenceQuickhull and MyQuickhull,
 * coloured by choose_color.
 *
 * Block (log_x, log_y) lives at index log_y * log_height + log_x, the same
 * addressing logged_tri_area uses when it records a call. With that row
 * stride, a column index >= log_height would alias a cell of the following
 * row, and a row index >= log_width would run past the end of the
 * log_width * log_height counter arrays. Both loops are therefore limited
 * to min (log_width, log_height), which keeps every lookup in bounds and
 * free of aliasing.
 *
 * A block is skipped only when both logs agree on it. Blocks in which
 * MyQuickhull made no calls at all while the reference did are drawn too,
 * since they are improvements like any other. *)
val _ =
  let
    val side = Int.min (log_width, log_height)

    (* World coordinate of the lower edge of block k on an axis spanning
     * [lo, hi]; the inverse of the block computation in logged_tri_area. *)
    fun edge (k, lo, hi) =
      rhmbs * Real.fromInt k * (hi - lo) / Real.fromInt resolution + lo
  in
    ForkJoin.parform (0, side) (fn log_y =>
      ForkJoin.parform (0, side) (fn log_x =>
        let
          val improvement = improvement_at (log_y * log_height + log_x)
        in
          if improvement = 0 then
            ()
          else
            Image.draw_box img
              {color = Color.colorToPixel (choose_color improvement)}
              (edge (log_x, minx, maxx), edge (log_y, miny, maxy))
              (edge (log_x + 1, minx, maxx), edge (log_y + 1, miny, maxy))
        end))
  end

(* Draw the coordinate axes in gray: the vertical line x = 0 and the
 * horizontal line y = 0, each extended by the 0.1 image margin. *)
val gray = {red = 0w100, green = 0w100, blue = 0w100}
val _ =
  List.app (fn (a, b) => Image.draw_line img {color = gray} a b)
    [ ((0.0, miny - 0.1), (0.0, maxy + 0.1))
    , ((minx - 0.1, 0.0), (maxx + 0.1, 0.0))
    ]


(* Outline the reference hull in blue. ref_result holds indices into
 * inputPts; each vertex is joined to its successor, wrapping around from
 * the last vertex to the first. *)
val _ =
  let
    val hull_size = Seq.length ref_result
    fun vertex k = Seq.nth inputPts (Seq.nth ref_result k)
  in
    ForkJoin.parform (0, hull_size) (fn i =>
      let
        val a = vertex i
        val b = vertex ((i + 1) mod hull_size)
      in
        Image.draw_line img {color = Color.blue} a b
      end)
  end

(* Outline the MyQuickhull hull in purple, on top of the reference outline.
 * my_result holds indices into inputPts; each vertex is joined to its
 * successor, wrapping around from the last vertex to the first. *)
val _ =
  let
    val purple: Color.pixel = {red = 0w255, green = 0w0, blue = 0w255}
    val hull_size = Seq.length my_result
    fun vertex k = Seq.nth inputPts (Seq.nth my_result k)
  in
    ForkJoin.parform (0, hull_size) (fn i =>
      let
        val a = vertex i
        val b = vertex ((i + 1) mod hull_size)
      in
        Image.draw_line img {color = purple} a b
      end)
  end

(* Plot every input point in black. *)
val _ =
  let
    fun plot i =
      Image.draw_point img {color = Color.black} (Seq.nth inputPts i)
  in
    ForkJoin.parform (0, Seq.length inputPts) plot
  end

(* Report how long the drawing took (measured from t0), then write the
 * image to the requested file and report the time Util.getTime measured
 * for the write. *)
val t1 = Time.now ()

val _ = print
  (String.concat ["generated image in ", Time.fmt 4 (Time.- (t1, t0)), "s\n"])

val (_, tm) = Util.getTime (fn _ => Image.write_to_file outfile img)
val _ = print
  (String.concat ["wrote to ", outfile, " in ", Time.fmt 4 tm, "s\n"])
