import React, { ReactNode, useCallback, useMemo } from "react";
import { Group } from "@visx/group";
import { AxisBottom } from "@visx/axis";
import { scaleLinear } from "@visx/scale";
import { extent } from "d3-array";
import { Circle } from "@visx/shape";
import { GridColumns } from "@visx/grid";
import { localPoint } from "@visx/event";
import { useTooltip, TooltipWithBounds, defaultStyles } from "@visx/tooltip";
import { BoxPlot } from "@visx/stats";
import { EventType } from "@visx/event/lib/types";

import { percentile } from "../../util/Util";

/**
 * Padding, in SVG user units (pixels), between the edge of the SVG and the
 * plotted region. Used to derive the inner drawing area and the x range.
 */
const margin = {
  top: 40,
  right: 20,
  bottom: 10,
  left: 20,
};

/**
 * Radius of every plotted circle. Twice this value is also the minimum
 * centre-to-centre spacing used when packing points so they do not overlap.
 */
const circleRadius = 4;

/** One input value plotted by SwarmChart. */
type SwarmDatum = {
  value: number;
  stdErr?: number;
  label: string;
  color: string;
};

/** The subset of a datum handed to the caller's `tooltip` renderer. */
type SwarmTooltipDatum = {
  value: number;
  stdErr?: number;
  label: string;
};

/** A datum after layout, carrying the centre of its circle in SVG coordinates. */
type PlacedPoint = SwarmDatum & { x: number; y: number };

// Constant, so built once rather than on every render.
const tooltipStyles = {
  ...defaultStyles,
  minWidth: 60,
  color: "black",
};

/**
 * Walks away from `centerY` in whole-diameter steps (`direction` +1 = down,
 * -1 = up) and returns the first step count whose slot does not collide with
 * an already-placed point. Step 0 (the centre line itself) is always tried;
 * further steps are tried only while `step * diameter <= maxOffset`.
 * Returns `null` when every such slot is occupied.
 */
function findFreeStep(
  x: number,
  centerY: number,
  direction: 1 | -1,
  maxOffset: number,
  placed: readonly PlacedPoint[]
): number | null {
  const diameter = circleRadius * 2;
  for (let step = 0; step === 0 || step * diameter <= maxOffset; step++) {
    if (!collidesWithPt(x, centerY + direction * step * diameter, placed as PlacedPoint[])) {
      return step;
    }
  }
  return null;
}

/**
 * Assigns each datum a position so that no two circles overlap.
 *
 * Every point starts on the horizontal centre line (`innerHeight / 2`) at
 * `toX(value)`. It then moves up or down in whole-diameter steps to the
 * nearest free slot within half the plot height. Both directions are
 * searched independently, and a point is dropped only when neither side has
 * room. Each candidate slot is checked against every already-placed point,
 * so this is costly for large inputs.
 */
function layoutSwarm(
  data: readonly SwarmDatum[],
  toX: (value: number) => number,
  innerHeight: number
): PlacedPoint[] {
  const diameter = circleRadius * 2;
  const centerY = innerHeight / 2;
  const maxOffset = innerHeight / 2;
  const placed: PlacedPoint[] = [];

  for (const d of data) {
    const x = toX(d.value);
    const down = findFreeStep(x, centerY, 1, maxOffset, placed);
    const up = findFreeStep(x, centerY, -1, maxOffset, placed);

    // Prefer the nearer free slot; ties go upward, so repeated equal values
    // alternate above and below the centre line.
    let y: number;
    if (down !== null && (up === null || down < up)) {
      y = centerY + down * diameter;
    } else if (up !== null) {
      y = centerY - up * diameter;
    } else {
      // No room on either side of the centre line: drop this point.
      continue;
    }

    placed.push({
      x,
      y,
      value: d.value,
      label: d.label,
      color: d.color,
      stdErr: d.stdErr,
    });
  }
  return placed;
}

/**
 * Renders a one-dimensional "beeswarm" chart. Each datum is a circle placed
 * horizontally by `value` and nudged vertically so that circles don't
 * overlap. A horizontal box plot is drawn over the points: whiskers at the
 * 5th/95th percentiles and the box at the quartiles. Hovering or touching a
 * circle shows `tooltip(datum)`. Points that cannot be fitted are omitted,
 * and a notice is rendered below the chart.
 *
 * @param props.height - Total SVG height in pixels, margins included.
 * @param props.width - Total SVG width in pixels, margins included.
 * @param props.label - Label drawn under the x axis.
 * @param props.data - Values to plot; `color` is used as each circle's fill.
 * @param props.format - Formats the x-axis tick values.
 * @param props.tooltip - Renders tooltip content for the hovered datum.
 */
export function SwarmChart(props: {
  height: number;
  width: number;
  label: string;
  data: SwarmDatum[];
  format: (value: number | null) => string;
  tooltip: (data: SwarmTooltipDatum) => ReactNode;
}) {
  const { height, width, label, data, format, tooltip } = props;

  const innerHeight = height - (margin.top + margin.bottom);
  const innerWidth = width - (margin.left + margin.right);

  // Memoized so the `points` and `handleTooltip` memos below stay valid
  // between renders. A fresh scale object on every render would invalidate
  // them every time.
  const xScale = useMemo(() => {
    const [minValue, maxValue] = extent(data, (d) => d.value);
    return scaleLinear({
      range: [margin.left, width - margin.right],
      // `extent` yields [undefined, undefined] for empty data; fall back to a
      // unit domain so the axis and grid still render.
      domain:
        minValue === undefined || maxValue === undefined
          ? [0, 1]
          : [minValue, maxValue],
      nice: true,
    });
  }, [data, width]);

  // Non-overlapping layout is expensive, so it is memoized.
  const points = useMemo(
    () => layoutSwarm(data, xScale, innerHeight),
    [data, xScale, innerHeight]
  );

  // Box-plot statistics, or null when there is nothing to summarise.
  // Values are sorted ascending before being passed to `percentile`
  // (presumably required by that helper — confirm in ../../util/Util).
  const boxStats = useMemo(() => {
    if (data.length === 0) {
      return null;
    }
    const values = data.map((d) => d.value).sort((a, b) => a - b);
    return {
      min: percentile(values, 0.05),
      firstQuartile: percentile(values, 0.25),
      median: percentile(values, 0.5),
      thirdQuartile: percentile(values, 0.75),
      max: percentile(values, 0.95),
    };
  }, [data]);

  const {
    tooltipData,
    tooltipLeft = 0,
    tooltipTop = 0,
    showTooltip,
    hideTooltip,
  } = useTooltip<SwarmTooltipDatum>();

  const boxPlotHeight = innerHeight * 0.35;

  // Shows the tooltip for the point under the pointer, or hides it if none
  // is hit. The hit test is a square of half-width `circleRadius` around
  // each point centre.
  const handleTooltip = useCallback(
    (event: EventType) => {
      const cursor = localPoint(event);
      const hit =
        cursor &&
        points.find(
          (point) =>
            Math.abs(point.x - cursor.x) < circleRadius &&
            Math.abs(point.y - cursor.y) < circleRadius
        );
      if (!hit) {
        hideTooltip();
        return;
      }
      showTooltip({
        tooltipData: {
          value: hit.value,
          stdErr: hit.stdErr,
          label: hit.label,
        },
        tooltipLeft: hit.x,
        tooltipTop: hit.y,
      });
    },
    [points, showTooltip, hideTooltip]
  );

  return (
    <div style={{ position: "relative" }}>
      <svg width={width} height={height}>
        <AxisBottom
          scale={xScale}
          top={innerHeight}
          labelOffset={20}
          label={label}
          labelProps={{ fill: "black", opacity: 0.6, textAnchor: "middle" }}
          hideTicks={true}
          hideAxisLine={true}
          tickLabelProps={() => ({
            fill: "black",
            opacity: 0.4,
            textAnchor: "middle",
          })}
          tickFormat={(val) => format(val as number)}
        />
        <GridColumns
          scale={xScale}
          width={innerWidth}
          height={innerHeight}
          stroke="black"
          strokeOpacity={0.2}
        />
        {boxStats && (
          <BoxPlot
            boxWidth={boxPlotHeight}
            top={innerHeight / 2 - boxPlotHeight / 2}
            min={boxStats.min}
            valueScale={xScale}
            firstQuartile={boxStats.firstQuartile}
            median={boxStats.median}
            thirdQuartile={boxStats.thirdQuartile}
            max={boxStats.max}
            stroke="rgba(0,0,0,.8)"
            fill="transparent"
            horizontal={true}
          />
        )}
        {/* Point x-coordinates already include margin.left (via xScale's
            range), and hit-testing uses SVG-local coordinates, so this
            group must not apply an additional offset. */}
        <Group>
          {points.map((d, i) => {
            return (
              <Circle
                key={`circle-${i}`}
                cx={d.x}
                cy={d.y}
                r={circleRadius}
                fill={d.color}
              />
            );
          })}
          <rect
            x={margin.left}
            width={innerWidth}
            height={innerHeight}
            onTouchStart={handleTooltip}
            fill={"transparent"}
            onTouchMove={handleTooltip}
            onMouseMove={handleTooltip}
            onMouseLeave={() => hideTooltip()}
          />
        </Group>
      </svg>
      {tooltipData && (
        <TooltipWithBounds
          top={tooltipTop}
          left={tooltipLeft}
          style={tooltipStyles}
        >
          {tooltip(tooltipData)}
        </TooltipWithBounds>
      )}
      {points.length !== data.length && (
        <div style={{ textAlign: "center", color: "#888" }}>
          (Too many points had the same value and were not able to fit on the
          chart)
        </div>
      )}
    </div>
  );
}

/**
 * Reports whether a circle centred at (x, y) would overlap any circle in
 * `drawnPoints`. Overlap is approximated by an axis-aligned box test: a
 * drawn point collides when both its horizontal and its vertical centre
 * distance are strictly less than one diameter (`circleRadius * 2`).
 */
const collidesWithPt = (
  x: number,
  y: number,
  drawnPoints: { x: number; y: number }[]
) => {
  const diameter = circleRadius * 2;
  return drawnPoints.some(
    (other) =>
      Math.abs(other.x - x) < diameter && Math.abs(other.y - y) < diameter
  );
};
