import React, { ReactNode, useCallback, useState } from "react";
import { extent, NumberValue, bisectCenter } from "d3";
import { GridRows, GridColumns } from "@visx/grid";
import { Group } from "@visx/group";
import { LinePath, Line } from "@visx/shape";
import { GlyphCircle } from "@visx/glyph";
import { scaleLinear, scaleOrdinal } from "@visx/scale";
import { AxisLeft, AxisBottom } from "@visx/axis";
import { Text } from "@visx/text";
import { localPoint } from "@visx/event";
import { useTooltip, TooltipWithBounds, defaultStyles } from "@visx/tooltip";
import { LegendOrdinal, LegendLabel, LegendItem } from "@visx/legend";
import { EventType } from "@visx/event/lib/types";

/** Side length, in px, of the colored square drawn beside each legend label. */
const legendGlyphSize = 10;

/** Vertical space, in px, reserved beneath the plot for the legend row. */
const legendHeight = 24;

/**
 * One series rendered by `LineChart`.
 */
export interface Line {
  /** Series name; shown in the legend and passed to click handlers. */
  label: string;
  /** Stroke color; also used for the legend swatch and marker fill. */
  color: string;
  /**
   * Disjoint runs of points. Each inner array is drawn as its own path, so
   * gaps between segments are left unconnected.
   */
  segments: Array<Array<{ x: number; y: number }>>;
  /** Stroke width in px; defaults to 1 when not provided. */
  strokeWidth?: number;
  /** SVG `stroke-dasharray` value for the path, e.g. "4 2". */
  strokeDashArray?: string;
  /**
   * Optional x value at which to draw a circular marker; it is only drawn
   * when one of the line's points has exactly this x.
   */
  marker?: number;
  /** Path opacity; defaults to 1 when not provided. */
  opacity?: number;
}

/** Default spacing, in px, between the SVG edge and the plot area. */
const DEFAULT_MARGIN = {
  top: 20,
  right: 10,
  bottom: 40,
  left: 40,
};

/**
 * Multi-series line chart built on visx.
 *
 * Each entry in `lines` is drawn as one `LinePath` per segment, optionally
 * with a circle marker. The x domain spans every point plus `xTicks`. The y
 * domain spans every point, `yTicks`, reference ranges and reference lines,
 * unless `yDomain` is given. In that case `yDomain` is used instead and
 * plotted points are clamped to `[yDomain[0], yDomain[1]]`.
 *
 * Interaction happens on a transparent overlay covering the plot area:
 * - hovering (or touching) shows `tooltip` with the nearest point per line;
 * - pressing and dragging shows `dragTooltip` (when provided) for the dragged
 *   x-range;
 * - a press/release pair that moves at most 10px (Manhattan distance) counts
 *   as a click. The line whose nearest point is vertically closest to the
 *   cursor is reported through `handleClick(label, x)`.
 *
 * Returns `null` when `width` or `height` is 0.
 */
export function LineChart(props: {
  height: number;
  width: number;
  lines: Line[];
  tooltip?: (data: {
    data: { color: string; x: number; y: number; label: string }[];
    xVal: number;
    yVal: number;
  }) => ReactNode;
  dragTooltip?: (data: {
    data: { color: string; x: number; y: number; label: string }[];
    data2: { color: string; x: number; y: number; label: string }[];
    xVal: number;
    xVal2: number;
    yVal: number;
  }) => ReactNode;
  handleClick?: (label: string, xVal: number) => void;
  xTicks?: number[];
  yTicks?: number[];
  yDomain?: number[];
  yDirection?: "asc" | "desc";
  numXTicks?: number;
  xTickFormat?: (val: NumberValue) => string;
  yTickFormat?: (val: NumberValue) => string;
  yTickPadding?: number;
  title?: string;
  xAxisLabel?: string;
  yAxisLabel?: string;
  showLegend: boolean;
  referenceLines?: { value: number; color?: string; label?: string }[];
  referenceRanges?: {
    y1: number;
    y2: number;
    color: string;
    opacity: number;
  }[];
  margin?: { top?: number; right?: number; bottom?: number; left?: number };
  hideAxis?: boolean;
  hideGrid?: boolean;
  labelsForLegend?: string[];
}) {
  const {
    height,
    width,
    lines,
    tooltip,
    dragTooltip,
    handleClick,
    xTicks,
    yTicks,
    yDomain,
    numXTicks,
    xTickFormat,
    yTickFormat,
    yTickPadding,
    title,
    xAxisLabel,
    yAxisLabel,
    showLegend,
    referenceLines,
    referenceRanges,
    hideAxis,
    hideGrid,
    labelsForLegend,
  } = props;
  const {
    tooltipData,
    tooltipLeft = 0,
    tooltipTop = 0,
    showTooltip,
    hideTooltip,
  } = useTooltip<{
    data: { color: string; x: number; y: number; label: string }[];
    xVal: number;
    yVal: number;
    // Below are only used in drag variant of tooltip.
    data2: { color: string; x: number; y: number; label: string }[];
    xVal2: number;
  }>();

  // SVG-local position of the last mouse-down on the plot overlay; cleared
  // by mouse-up on the overlay. Non-null means a press/drag is in progress.
  const [mouseDownPoint, setMouseDownPoint] = useState<{
    x: number;
    y: number;
  } | null>(null);

  // Values at the top of the graph should be larger and desc as you read down.
  const yDirection = props.yDirection || "desc";

  const margin = !props.margin
    ? DEFAULT_MARGIN
    : Object.assign({}, DEFAULT_MARGIN, props.margin);

  const innerWidth = width - margin.left - margin.right;
  const innerHeight =
    height - margin.top - margin.bottom - (showLegend ? legendHeight : 0);

  const allData = lines.map((l) => l.segments.flat()).flat();
  const uniqueXVals = new Set(allData.map((a) => a.x));

  // Make sure our domain fits all points and ticks.
  const allXData = allData.map((d) => d.x).concat(xTicks ? xTicks : []);
  const xScale = scaleLinear<number>({
    domain: extent(allXData) as [number, number],
    range: [0, innerWidth],
  });

  // Make sure our domain fits all points, ticks, and reference ranges. If a
  // yDomain is provided just use that instead.
  const allYData = yDomain
    ? yDomain
    : allData
        .map((d) => d.y)
        .concat(yTicks ? yTicks : [])
        .concat(
          referenceRanges
            ? referenceRanges.map((rr) => [rr.y1, rr.y2]).flat()
            : []
        )
        .concat(referenceLines ? referenceLines.map((rl) => rl.value) : []);

  const yExtent = extent(allYData) as [number, number];

  const yExtentSize = yExtent[1] - yExtent[0];

  // If user provides some yTickPadding apply that here.
  const padding = yTickPadding || 0;
  const paddedYExtent = [
    yExtent[yDirection === "desc" ? 0 : 1] - yExtentSize * padding,
    yExtent[yDirection === "desc" ? 1 : 0] + yExtentSize * padding,
  ];

  const yScale = scaleLinear<number>({
    domain: paddedYExtent,
    range: [innerHeight, 0],
  });

  const tooltipStyles = {
    ...defaultStyles,
    minWidth: 60,
    color: "black",
  };

  const handleMouseDown = useCallback((event: EventType) => {
    const { x, y } = localPoint(event) || { x: 0, y: 0 };
    setMouseDownPoint({ x, y });
  }, []);

  const handleMouseUp = useCallback(
    (event: EventType) => {
      setMouseDownPoint(null);
      // A click needs a matching press on the overlay; without one there is
      // nothing to measure the pointer movement against.
      if (!handleClick || !mouseDownPoint) return;

      const { x, y } = localPoint(event) || { x: 0, y: 0 };

      // Skip squaring/square rooting b/c we don't need the *real* distance.
      const distFromMouseDown =
        Math.abs(mouseDownPoint.x - x) + Math.abs(mouseDownPoint.y - y);
      if (distFromMouseDown > 10) return;

      const yVal = yScale.invert(y - margin.top);
      const xVal = xScale.invert(x - margin.left);
      const closest = findClosestLineAtX(lines, xVal, yVal);
      // No line spans the clicked x: report nothing rather than a bogus hit.
      if (closest) {
        handleClick(closest.label, closest.x);
      }
    },
    [
      handleClick,
      lines,
      margin.left,
      margin.top,
      mouseDownPoint,
      xScale,
      yScale,
    ]
  );

  const handleTooltip = useCallback(
    (event: EventType) => {
      const { x, y } = localPoint(event) || { x: 0, y: 0 };
      if (dragTooltip && mouseDownPoint) {
        // Make sure nothing is selected.
        const documentSelection = document.getSelection();
        if (documentSelection) {
          documentSelection.empty();
        }
        const windowSelection = window.getSelection();
        if (windowSelection) {
          windowSelection.removeAllRanges();
        }

        const x1 = xScale.invert(x - margin.left);
        const x2 = xScale.invert(mouseDownPoint.x - margin.left);
        // Make sure that xVal2 and data2 are always the most right point portion
        // of the drag so we don't have to figure it out in callers of this.
        const xVal = Math.min(x1, x2);
        const xVal2 = Math.max(x1, x2);
        const yVal = yScale.invert(y - margin.top);
        const data = getTooltipDataForLines(lines, xVal);
        const data2 = getTooltipDataForLines(lines, xVal2);
        showTooltip({
          tooltipData: { data, data2, xVal, xVal2, yVal },
          tooltipLeft: x,
          tooltipTop: y,
        });
      } else if (tooltip) {
        const xVal = xScale.invert(x - margin.left);
        const yVal = yScale.invert(y - margin.top);
        const data = getTooltipDataForLines(lines, xVal);
        showTooltip({
          tooltipData: { data, xVal, yVal, data2: [], xVal2: 0 },
          tooltipLeft: x,
          tooltipTop: y,
        });
      }
    },
    [
      tooltip,
      dragTooltip,
      mouseDownPoint,
      xScale,
      margin.left,
      margin.top,
      yScale,
      lines,
      showTooltip,
    ]
  );

  const legendScale = scaleOrdinal({
    domain: lines.map((l) => l.label),
    range: lines.map((l) => l.color),
  });

  const showAxis = !hideAxis;
  const showGrid = !hideGrid;

  // While a drag is in progress and a drag tooltip exists, render that;
  // otherwise fall back to the hover tooltip. This lets a chart configured
  // with only `dragTooltip` still display it.
  const activeTooltip =
    mouseDownPoint !== null && dragTooltip ? dragTooltip : tooltip;

  if (height === 0 || width === 0) return null;

  return (
    <div>
      <svg width={width} height={height - (showLegend ? legendHeight : 0)}>
        <Group left={margin.left} top={margin.top}>
          {showGrid && (
            <GridRows
              scale={yScale}
              width={innerWidth}
              height={innerHeight - margin.top}
              stroke="#EDF2F7"
              strokeOpacity={1}
            />
          )}
          {showGrid && (
            <GridColumns
              scale={xScale}
              width={innerWidth}
              height={innerHeight}
              stroke="#EDF2F7"
              strokeOpacity={1}
            />
          )}
          {referenceRanges &&
            referenceRanges.map((rr, i) => (
              <rect
                key={i}
                fill={rr.color}
                fillOpacity={rr.opacity}
                x={0}
                y={yScale(rr.y2)}
                width={innerWidth}
                height={yScale(rr.y1) - yScale(rr.y2)}
              />
            ))}
          {referenceLines &&
            referenceLines.map((rl, i) => (
              <Group key={i}>
                {rl.label && (
                  <Text
                    fontSize={10}
                    dx={innerWidth}
                    // The + 2 center aligns the text to the line.
                    dy={yScale(rl.value) + 2}
                    fill={rl.color}
                  >
                    {rl.label}
                  </Text>
                )}
                <Line
                  from={{ x: 0, y: yScale(rl.value) }}
                  to={{ x: innerWidth, y: yScale(rl.value) }}
                  stroke={rl.color || "black"}
                  opacity={0.5}
                  strokeDasharray={"3 3"}
                />
              </Group>
            ))}
          {lines.map((line, i) => {
            return line.segments.map((segment, j) => {
              return (
                <LinePath
                  key={i + " " + j}
                  onClick={(evt) => {
                    handleClick &&
                      handleClick(
                        line.label,
                        xScale.invert((localPoint(evt)?.x || 0) - margin.left)
                      );
                  }}
                  data={
                    yDomain
                      ? segment.map((d) => {
                          return {
                            x: d.x,
                            y: Math.min(
                              Math.max(d.y, yDomain[0] || 0),
                              yDomain[1] || 0
                            ),
                          };
                        })
                      : segment
                  }
                  x={(d) => xScale(d.x)}
                  y={(d) => yScale(d.y)}
                  stroke={line.color}
                  strokeWidth={line.strokeWidth || 1}
                  strokeDasharray={line.strokeDashArray}
                  // `??` so an explicit opacity of 0 is honoured.
                  opacity={line.opacity ?? 1}
                />
              );
            });
          })}
          {lines.map((l, i) => {
            // Compare against undefined so a marker at x = 0 is still drawn.
            if (l.marker === undefined) return null;
            const markerPoint = l.segments
              .flat()
              .find((d) => d.x === l.marker);
            if (!markerPoint) return null;
            return (
              <GlyphCircle
                key={i}
                left={xScale(markerPoint.x)}
                top={yScale(markerPoint.y)}
                size={50}
                fill={l.color}
                strokeWidth={2}
              />
            );
          })}
          {showAxis && (
            <AxisLeft
              scale={yScale}
              label={yAxisLabel}
              tickFormat={yTickFormat}
              tickValues={yTicks}
              numTicks={yTicks && yTicks.length}
            />
          )}
          {showAxis && (
            <AxisBottom
              scale={xScale}
              tickValues={xTicks}
              numTicks={
                (xTicks && xTicks.length) || numXTicks || uniqueXVals.size
              }
              top={innerHeight}
              tickFormat={xTickFormat}
              label={xAxisLabel}
            />
          )}
          {title && (
            <Text dx={10} style={{ fontWeight: "bold" }}>
              {title}
            </Text>
          )}
          {dragTooltip && mouseDownPoint && (
            <g>
              <Line
                from={{ x: mouseDownPoint.x - margin.left, y: 0 }}
                to={{ x: mouseDownPoint.x - margin.left, y: innerHeight }}
                stroke={"gray"}
                strokeDasharray={"3 3"}
                pointerEvents="none"
              />
            </g>
          )}
          {tooltipData && (
            <g>
              <Line
                from={{ x: tooltipLeft - margin.left, y: 0 }}
                to={{ x: tooltipLeft - margin.left, y: innerHeight }}
                stroke={"gray"}
                strokeDasharray={"3 3"}
                pointerEvents="none"
              />
            </g>
          )}
          <rect
            x={0}
            y={0}
            width={innerWidth}
            height={innerHeight}
            onMouseDown={handleMouseDown}
            onMouseUp={handleMouseUp}
            onTouchStart={handleTooltip}
            fill={"transparent"}
            onTouchMove={handleTooltip}
            onMouseMove={handleTooltip}
            onMouseLeave={() => hideTooltip()}
          />
        </Group>
      </svg>
      {showLegend && (
        <LegendOrdinal scale={legendScale}>
          {(labels) => (
            <div
              style={{
                height: legendHeight,
                display: "flex",
                flexDirection: "row",
                justifyContent: "center",
              }}
            >
              {labels
                .filter(
                  (l) => !labelsForLegend || labelsForLegend.includes(l.text)
                )
                .map((label, i) => (
                  <LegendItem key={i} margin="0 4px">
                    <svg width={legendGlyphSize} height={legendGlyphSize}>
                      <rect
                        fill={label.value}
                        width={legendGlyphSize}
                        height={legendGlyphSize}
                      />
                    </svg>
                    <LegendLabel align="left" margin="0 0 0 4px">
                      {label.text}
                    </LegendLabel>
                  </LegendItem>
                ))}
            </div>
          )}
        </LegendOrdinal>
      )}
      {activeTooltip &&
        tooltipData &&
        (tooltipData.data[0] || tooltipData.data2[0]) && (
          <TooltipWithBounds
            top={tooltipTop}
            left={tooltipLeft}
            style={tooltipStyles}
          >
            {activeTooltip(tooltipData)}
          </TooltipWithBounds>
        )}
    </div>
  );
}

/**
 * Picks the line to report for a click at (`xVal`, `yVal`) in data units.
 *
 * For every line whose x-range contains `xVal`, take the point nearest to
 * `xVal` along x. The winner is the line whose chosen point is vertically
 * closest to `yVal`; ties keep the earlier line.
 *
 * @returns The winning line's label and the x of its chosen point, or
 *   `undefined` when no line spans `xVal`.
 */
function findClosestLineAtX(
  lines: Line[],
  xVal: number,
  yVal: number
): { label: string; x: number } | undefined {
  let best: { label: string; x: number } | undefined;
  let bestDist = Number.MAX_VALUE;

  for (const line of lines) {
    const points = line.segments.flat().sort((a, b) => a.x - b.x);
    const xs = points.map((d) => d.x);

    // Ignore lines whose x-range does not include the target x.
    const [minX, maxX] = extent(xs);
    if (minX === undefined || maxX === undefined) continue;
    if (xVal < minX || xVal > maxX) continue;

    const point = points[bisectCenter(xs, xVal)];
    if (!point) continue;

    const dist = Math.abs(point.y - yVal);
    if (dist < bestDist) {
      bestDist = dist;
      best = { label: line.label, x: point.x };
    }
  }
  return best;
}

/**
 * Samples each line at `xVal` for tooltip display.
 *
 * For every line, the first segment whose points lie on both sides of (or
 * at) `xVal` is used, and its point nearest to `xVal` along x is reported.
 * Lines without such a segment are skipped, so the result may be shorter
 * than `lines`.
 *
 * @param lines - Series to sample.
 * @param xVal - X position in data units.
 * @returns One entry per covering line, in the same order as `lines`.
 */
function getTooltipDataForLines(lines: Line[], xVal: number) {
  const data: { color: string; label: string; x: number; y: number }[] = [];

  for (const line of lines) {
    const segmentContainingXVal = line.segments.find(
      (s) => s.some((pt) => pt.x >= xVal) && s.some((pt) => pt.x <= xVal)
    );
    if (!segmentContainingXVal) {
      continue;
    }

    // Single linear pass for the nearest point instead of copying and
    // sorting the segment on every mouse move. Strict `<` keeps the earliest
    // point on ties, matching what a stable sort by distance would pick.
    let point: { x: number; y: number } | undefined;
    let pointDist = Infinity;
    for (const pt of segmentContainingXVal) {
      const dist = Math.abs(pt.x - xVal);
      if (point === undefined || dist < pointDist) {
        point = pt;
        pointDist = dist;
      }
    }

    if (!point) {
      continue;
    }

    data.push({
      color: line.color,
      label: line.label,
      y: point.y,
      x: point.x,
    });
  }
  return data;
}
