import React, {
  useRef,
  useState,
  useEffect,
  useCallback,
  useMemo,
} from "react";
import { Group } from "@visx/group";
import { GridRows } from "@visx/grid";
import { useTooltip, useTooltipInPortal } from "@visx/tooltip";
import { scaleBand, scaleLinear, scaleOrdinal } from "@visx/scale";
import { ParentSize } from "@visx/responsive";
import { waitForElementToExist } from "@figmentjs/utils";
import { theme } from "../../../../../theme/tailwind.config";
import {
  margin,
  tooltipStyles,
  barChartColorRange,
  bucketKey,
} from "../../utils/constants";
import { AxisLeft, AxisBottom, LEFT_LABEL_OFFSET } from "../index";
import { Legend } from "../legend";
import { ContainerProps } from "./container.types";
import { DisplayedKey } from "../../types";

// Neutral ("basic") colour palette from the Tailwind theme config.
const { basic } = theme.colors;

// Extra space added, together with LEFT_LABEL_OFFSET, to the measured width
// of the left axis when computing the chart's left margin.
const MARGIN_LEFT_PADDING = 10;

/**
 * Chart shell shared by the chart variants.
 *
 * What it does:
 * - Sizes itself to its parent via `ParentSize`.
 * - Builds a band x-scale and a linear y-scale from `xDomainFunc` / `yDomainFunc`.
 * - Draws the grid and both axes.
 * - Optionally renders a legend and a tooltip.
 *
 * The series themselves are drawn by the `children` render prop, which receives
 * the scales, the (possibly filtered) data and the currently active keys.
 *
 * Filtering: every key starts enabled, and all toggles reset whenever `keys`
 * changes. When `enableFiltering` is set, clicking a legend item toggles that
 * key. If `filterFunc` is provided, it is applied to `data` with the current
 * toggle map. Once every key is disabled, an empty data set is plotted.
 */
const Container = <DataPoint, Key extends keyof any>({
  keys,
  bucketedKeys,
  xTickFormatter,
  yTickFormatter,
  numXTicks,
  numYTicks,
  renderTooltipChildren,
  children,
  data,
  filterFunc,
  xDomainFunc,
  yDomainFunc,
  xPadding,
  showLegend = true,
  customColors,
  customLegendLabels,
  enableFiltering,
}: ContainerProps<DataPoint, Key>): React.ReactElement => {
  const tooltipRef = useRef<HTMLDivElement>(null);
  const containerRef = useRef<HTMLDivElement>(null);
  const [marginLeft, setMarginLeft] = useState(margin.left);

  const {
    tooltipOpen,
    tooltipLeft,
    tooltipTop,
    tooltipData,
    hideTooltip,
    showTooltip,
  } = useTooltip<DataPoint>();
  const { containerRef: tooltipContainerRef, TooltipInPortal } =
    useTooltipInPortal({
      detectBounds: true,
      scroll: true,
    });

  // Measure the rendered left axis so the plot area leaves room for its tick
  // labels. Re-measured whenever the set of keys changes.
  useEffect(() => {
    // Capture the node now: the ref may be cleared before the promise resolves.
    const container = containerRef.current;
    if (!container) {
      return;
    }

    // Set by the cleanup so a late resolution (after unmount or a `keys`
    // change) neither updates state nor reads a detached node.
    let cancelled = false;

    waitForElementToExist(container, ".axis-left").then(
      (leftAxis: Element) => {
        if (cancelled) {
          return;
        }

        const width = leftAxis?.getBoundingClientRect()?.width;
        if (!width) {
          return;
        }

        const extraWidth = LEFT_LABEL_OFFSET + MARGIN_LEFT_PADDING;

        // Due to a @visx bug in Firefox where the left axis width is too big,
        // we need to do a width check and use the width of the largest child.
        // If the left axis width is not too big, we can use it as-is.
        const isWidthTooBig = width > container.clientWidth / 2;
        if (!isWidthTooBig) {
          setMarginLeft(width + extraWidth);
          return;
        }

        const childrenMaxWidth = Array.from(leftAxis.children).reduce(
          (widest, child) =>
            Math.max(
              widest,
              child.querySelector("text")?.getBoundingClientRect()?.width || 0
            ),
          0
        );

        if (childrenMaxWidth) {
          setMarginLeft(childrenMaxWidth + extraWidth);
        }
      }
    );

    return () => {
      cancelled = true;
    };
  }, [keys]);

  // Every key starts enabled; the map is rebuilt whenever `keys` changes.
  const initialFilterKeys = useMemo(
    () =>
      (keys as Key[]).reduce<{ [key in Key]: boolean }>((acc, key) => {
        acc[key] = true;
        return acc;
      }, {} as { [key in Key]: boolean }),
    [keys]
  );

  const [filterKeys, setFilterKeys] =
    useState<{ [key in Key]: boolean }>(initialFilterKeys);

  // Reset all legend toggles when the set of keys changes.
  useEffect(() => {
    setFilterKeys(initialFilterKeys);
  }, [initialFilterKeys]);

  const activeKeys = useMemo(
    () => keys.filter((key) => filterKeys[key as Key]),
    [keys, filterKeys]
  );

  // Data handed to the series and the domain functions. It is derived during
  // render, not copied into state by an effect, so it always reflects the
  // latest `data`, including when no `filterFunc` is supplied.
  const filtered = useMemo<DataPoint[]>(() => {
    if (!filterFunc) {
      return data;
    }
    // With every key toggled off there is nothing to plot.
    if (activeKeys.length === 0) {
      return [];
    }
    return filterFunc(data, filterKeys);
  }, [data, filterFunc, filterKeys, activeKeys]);

  const handleFilter = useCallback((key: Key) => {
    setFilterKeys((prev) => ({
      ...prev,
      [key]: !prev[key],
    }));
  }, []);

  // Deliberately returns a fresh scale on every call. An ordinal scale can grow
  // its domain when asked for an unknown value, so a shared instance could
  // leak such values into the legend.
  const getColorScale = () =>
    scaleOrdinal<DisplayedKey<Key>, string>({
      domain: keys,
      range: customColors
        ? keys.map((key) => customColors[key])
        : barChartColorRange,
    });

  // Recompute the domains only when their inputs (including the domain
  // functions themselves) change.
  const xDomain = useMemo(
    () => xDomainFunc(filtered),
    [xDomainFunc, filtered]
  );
  const yDomain = useMemo(
    () => yDomainFunc(filtered, activeKeys as Key[]),
    [yDomainFunc, filtered, activeKeys]
  );

  // Ranges depend on the measured size and are assigned inside ParentSize.
  const xScale = scaleBand<string>({
    domain: xDomain,
    padding: xPadding,
  });

  const yScale = scaleLinear<number>({
    domain: yDomain,
    nice: true,
  });

  return (
    <div className="flex flex-1 h-80" ref={containerRef}>
      <div className="flex w-full h-full" ref={tooltipContainerRef}>
        <ParentSize debounceTime={50}>
          {({ width, height }) => {
            if (width < 1 || height < 1) {
              return null;
            }

            const xMax = width - marginLeft - margin.right;
            const yMax = height - margin.top - margin.bottom;

            xScale.range([0, xMax]);
            yScale.range([yMax, 0]);

            return (
              <div className="relative">
                {/* Shrink the SVG by 16px when the legend is shown to close an unexplained gap between the x-axis and the legend. */}
                <svg width={width} height={showLegend ? height - 16 : height}>
                  <rect
                    x={0}
                    y={0}
                    width={width}
                    height={height}
                    className="fill-white"
                    rx={14}
                  />

                  <Group left={marginLeft} top={margin.top}>
                    <GridRows
                      scale={yScale}
                      width={xMax}
                      stroke={basic["200"]}
                    />
                    {children({
                      showTooltip,
                      hideTooltip,
                      colorScale: getColorScale(),
                      marginLeft,
                      data: filtered,
                      xScale,
                      yScale,
                      activeKeys,
                    })}
                    <AxisLeft
                      scale={yScale}
                      tickFormat={yTickFormatter}
                      numTicks={numYTicks}
                    />
                    <AxisBottom
                      top={yMax}
                      scale={xScale}
                      tickFormat={xTickFormatter}
                      numTicks={numXTicks}
                    />
                  </Group>
                </svg>

                {showLegend && (
                  <Legend<Key>
                    colorScale={getColorScale}
                    filterKeys={filterKeys}
                    customLabels={customLegendLabels}
                    onClick={
                      enableFiltering
                        ? (label) => handleFilter(label.datum as Key)
                        : undefined
                    }
                  />
                )}

                {tooltipOpen && tooltipData && (
                  <TooltipInPortal
                    // NOTE(review): a random key remounts the tooltip on every
                    // render, presumably so bounds detection re-measures —
                    // confirm before replacing with a stable key.
                    key={Math.random()}
                    top={tooltipTop}
                    left={tooltipLeft}
                    style={tooltipStyles}
                  >
                    <div ref={tooltipRef}>
                      {renderTooltipChildren({
                        tooltipData,
                        colorScale: getColorScale(),
                        bucketedKeys,
                        bucketKey,
                      })}
                    </div>
                  </TooltipInPortal>
                )}
              </div>
            );
          }}
        </ParentSize>
      </div>
    </div>
  );
};

// The chart container is this module's default export.
export { Container as default };
