import first from 'lodash/first';
import noop from 'lodash/noop';
import {
  PropsWithChildren,
  createContext,
  useCallback,
  useContext,
  useMemo,
  useState,
} from 'react';
import {
  ChartMode,
  ScatterChartType,
  ScatterData,
} from '../../charts/providers/deck-gl/layers.util';
import {
  PanelTab,
  getUniqueLassoSelectionPoints,
} from '../../components/dataset-details/clusters/util';
import { useInputState } from '../../design-system/v2';
import { useQueryParams } from '../../hooks/useQueryParams';

/**
 * Optional camera/view parameters for the cluster chart. Every field is optional.
 * Not referenced elsewhere in this file.
 * NOTE(review): the field names look like deck.gl view-state props (this file imports from
 * `charts/providers/deck-gl`) — confirm units and semantics against the consuming view.
 */
export interface ChartViewState {
  target?: number[];
  rotationX?: number;
  rotationOrbit?: number;
  latitude?: number;
  longitude?: number;
  minZoom?: number;
  maxZoom?: number;
  zoom?: number;
}

/**
 * Values accepted by `ClusterContext.colorBy`. String-valued, so each member serializes
 * to its own name.
 */
export enum ColorByOptions {
  CLASSES = 'CLASSES',
  CLUSTERS = 'CLUSTERS',
  TOPICS = 'TOPICS',
  OUTLIERS = 'OUTLIERS',
}

/**
 * Shape of the shared cluster-view state exposed by `DatasetClusterProvider` and read via
 * `useDatasetClusterContext`. Most fields come in value/setter pairs.
 */
export interface ClusterContext {
  hideInactivePoints: boolean;
  setHideInactivePoints: (val: boolean) => void;

  colorBy: ColorByOptions;
  setColorBy: (val: ColorByOptions) => void;

  colorMap: Record<string, string>;
  setColorMap: (map: Record<string, string>) => void;

  // Read-only here: supplied by the provider's `datasetId` prop.
  datasetId: string;

  embeddingId: string;
  setEmbeddingId: (val: ClusterContext['embeddingId']) => void;

  searchText: string[];
  setSearchText: (val: string[]) => void;
  taskletId: string;
  setTaskletId: (val: string) => void;

  // Derived flag: true when the first lasso selection is non-empty, any class/cluster/
  // outlier/segment/topic filter is set, or a mislabelled-data filter is active.
  hasSelection: boolean;

  // Clears the current selection state.
  resetSelection: () => void;
  // Drops the most recently added lasso selection (no-op when there is none).
  undoSelection: () => void;

  selectedClasses: string[];
  setSelectedClasses: (val: string[]) => void;

  selectedClusters: string[];
  setSelectedClusters: (val: string[]) => void;

  selectedTopics: string[];
  setSelectedTopics: (val: string[]) => void;

  selectedOutliers: string[];
  setSelectedOutliers: (val: string[]) => void;

  selectedSegments: string[];
  setSelectedSegments: (val: string[]) => void;

  // When value is undefined, no filtering is applied (all data is included).
  // A value of false will hide all mislabeled points; true will show only mislabeled points.
  showMislabelledData: boolean | undefined;
  setShowMislabelledData: (val: boolean | undefined) => void;

  // Raw lasso selections, one inner array per lasso gesture, in the order they were added.
  lassoSelection: ScatterData[][];
  // `lassoSelection` passed through `getUniqueLassoSelectionPoints`.
  uniqLassoSelection: ScatterData[][];
  // `uniqLassoSelection` flattened by one level.
  flatLassoSelection: ScatterData[];
  // Appends one lasso selection to `lassoSelection`.
  addLassoSelection: (val: ScatterData[]) => void;

  // `undefined` until something sets it; note the setter only accepts an array.
  selectedSubsetIds: string[] | undefined;
  setSelectedSubsetIds: (val: string[]) => void;

  selectedSimilarityPoints?: string[];
  setSelectedSimilarityPoints: (val: string[] | undefined) => void;

  chartMode?: ChartMode;
  setChartMode: (mode: ChartMode) => void;

  scatterViewMode?: ScatterChartType;
  setScatterViewMode: (mode: ScatterChartType) => void;

  hoverPointId: string;
  setHoverPointId: (p: string) => void;

  tableSelectedPointId: string;
  setTableSelectedPointId: (p: string) => void;

  pointSize: number;
  setPointSize: (val: number) => void;
}

/**
 * Fallback context value. `useDatasetClusterContext` returns it when no
 * `DatasetClusterProvider` is mounted above the caller; in that case every setter is a
 * `noop`, so writes are silently ignored. The provider also seeds most of its initial state
 * from these fields.
 */
const defaultValue: ClusterContext = {
  hideInactivePoints: false,
  setHideInactivePoints: noop,

  colorBy: ColorByOptions.CLASSES,
  setColorBy: noop,

  colorMap: {},
  setColorMap: noop,

  datasetId: '',

  embeddingId: 'DEFAULT',
  setEmbeddingId: noop,

  searchText: [],
  setSearchText: noop,
  taskletId: '',
  setTaskletId: noop,

  hasSelection: false,

  resetSelection: noop,
  undoSelection: noop,

  selectedClasses: [],
  setSelectedClasses: noop,

  selectedClusters: [],
  setSelectedClusters: noop,

  selectedTopics: [],
  setSelectedTopics: noop,

  selectedOutliers: [],
  setSelectedOutliers: noop,

  selectedSegments: [],
  setSelectedSegments: noop,

  // undefined = no mislabelled-data filtering.
  showMislabelledData: undefined,
  setShowMislabelledData: noop,

  lassoSelection: [],
  uniqLassoSelection: [],
  flatLassoSelection: [],
  addLassoSelection: noop,

  selectedSubsetIds: undefined,
  setSelectedSubsetIds: noop,

  selectedSimilarityPoints: undefined,
  setSelectedSimilarityPoints: noop,

  // The provider overrides this with PAN when the page opens on the subsets panel tab.
  chartMode: ChartMode.SELECT_RECTANGLE,
  setChartMode: noop,

  scatterViewMode: ScatterChartType.SCATTER_2D,
  setScatterViewMode: noop,

  hoverPointId: '',
  setHoverPointId: noop,

  tableSelectedPointId: '',
  setTableSelectedPointId: noop,

  // NOTE(review): unit not established in this file (presumably pixels) — confirm.
  pointSize: 6,
  setPointSize: noop,
};

const DatasetClusterContext = createContext<ClusterContext>(defaultValue);

/**
 * Returns the value of the nearest enclosing `DatasetClusterProvider`. Outside any provider
 * it yields `defaultValue`, whose setters are no-ops.
 */
export function useDatasetClusterContext(): ClusterContext {
  return useContext(DatasetClusterContext);
}

/** Props for `DatasetClusterProvider` (children are added via `PropsWithChildren`). */
interface DatasetClusterProviderProps {
  // Exposed unchanged as `ClusterContext.datasetId`; the provider defaults it to ''.
  datasetId?: string;
}

/**
 * Owns the cluster-view UI state: coloring, class/cluster/topic/outlier/segment filters,
 * lasso selections, chart and scatter modes, hover/table focus and point size. It exposes
 * that state to descendants through `DatasetClusterContext`.
 *
 * The context value is memoized, so consumers re-render only when some exposed piece of
 * state actually changes, not on every render of this provider.
 *
 * @param props.datasetId - Passed through to the context unchanged; defaults to `''`.
 * @param props.children - Subtree that reads the state via `useDatasetClusterContext`.
 */
export const DatasetClusterProvider = ({
  children,
  datasetId = '',
}: PropsWithChildren<DatasetClusterProviderProps>) => {
  const queryParams = useQueryParams();

  const [hideInactivePoints, setHideInactivePoints] = useState(defaultValue.hideInactivePoints);
  const [colorBy, setColorBy] = useState<ColorByOptions>(defaultValue.colorBy);
  const [colorMap, setColorMap] = useState<Record<string, string>>(defaultValue.colorMap);
  const [embeddingId, setEmbeddingId] = useState(defaultValue.embeddingId);
  const [searchText, setSearchText] = useInputState<string[]>(defaultValue.searchText);
  const [taskletId, setTaskletId] = useState(defaultValue.taskletId);
  const [selectedClasses, setSelectedClasses] = useState<string[]>(defaultValue.selectedClasses);
  const [selectedClusters, setSelectedClusters] = useState<string[]>(defaultValue.selectedClusters);
  const [selectedTopics, setSelectedTopics] = useState<string[]>(defaultValue.selectedTopics);
  const [selectedOutliers, setSelectedOutliers] = useState<string[]>(defaultValue.selectedOutliers);
  const [selectedSegments, setSelectedSegments] = useState<string[]>(defaultValue.selectedSegments);
  const [showMislabelledData, setShowMislabelledData] = useState<boolean | undefined>(
    defaultValue.showMislabelledData,
  );

  const [lassoSelection, setLassoSelection] = useState<ScatterData[][]>([]);
  const [selectedSubsetIds, setSelectedSubsetIds] = useState<string[] | undefined>(
    defaultValue.selectedSubsetIds,
  );
  const [selectedSimilarityPoints, setSelectedSimilarityPoints] = useState<string[] | undefined>(
    defaultValue.selectedSimilarityPoints,
  );
  const [scatterViewMode, setScatterViewMode] = useState<ScatterChartType>(
    defaultValue.scatterViewMode ?? ScatterChartType.SCATTER_2D,
  );

  // Active panel tab impacts the values of PanelMode and ChartMode. The `panelTab` query
  // param only seeds the *initial* chart mode: the subsets tab starts in PAN mode. A lazy
  // initializer keeps this lookup out of every later render.
  const [chartMode, setChartMode] = useState<ChartMode>(() => {
    const initialPanelTab = (queryParams.get('panelTab') as PanelTab | null) ?? PanelTab.EXPLORE;
    return initialPanelTab === PanelTab.VIEW_SUBSETS
      ? ChartMode.PAN
      : (defaultValue.chartMode ?? ChartMode.SELECT_RECTANGLE);
  });

  const [hoverPointId, setHoverPointId] = useState(defaultValue.hoverPointId);
  const [tableSelectedPointId, setTableSelectedPointId] = useState<string>(
    defaultValue.tableSelectedPointId,
  );
  const [pointSize, setPointSize] = useState(defaultValue.pointSize);

  const uniqLassoSelection = useMemo(
    () => getUniqueLassoSelectionPoints(lassoSelection),
    [lassoSelection],
  );

  const flatLassoSelection = useMemo(() => uniqLassoSelection.flat(1), [uniqLassoSelection]);

  const addLassoSelection = useCallback((newVal: ScatterData[]) => {
    setLassoSelection(prevSelection => prevSelection.concat([newVal]));
  }, []);

  // NOTE(review): only the first lasso selection is inspected, so `[[], [point]]` reports no
  // lasso selection. Confirm whether `lassoSelection.some(s => s.length > 0)` was intended.
  const hasSelection =
    (first(lassoSelection)?.length ?? 0) > 0 ||
    selectedClasses.length > 0 ||
    selectedClusters.length > 0 ||
    selectedOutliers.length > 0 ||
    selectedSegments.length > 0 ||
    selectedTopics.length > 0 ||
    showMislabelledData !== undefined;

  // Functional update: always pops from the latest state, not from a value captured in an
  // earlier render. Returning `prev` unchanged when empty avoids a needless re-render.
  const undoSelection = useCallback(() => {
    setLassoSelection(prev => (prev.length > 0 ? prev.slice(0, -1) : prev));
  }, []);

  // Clears every filter that `hasSelection` looks at (including topics and outliers, which
  // were previously left set), plus lasso, subset and similarity selections.
  const resetSelection = useCallback(() => {
    setSelectedClasses([]);
    setSelectedClusters([]);
    setSelectedTopics([]);
    setSelectedOutliers([]);
    setSelectedSegments([]);
    setLassoSelection([]);
    setSelectedSubsetIds([]);
    setSelectedSimilarityPoints(undefined);
    setShowMislabelledData(undefined);
  }, []);

  // useState setters are referentially stable, so they are passed through directly instead
  // of being wrapped in fresh arrow functions, and they are omitted from the dependency list.
  const contextValue = useMemo<ClusterContext>(
    () => ({
      hideInactivePoints,
      setHideInactivePoints,

      colorBy,
      setColorBy,

      colorMap,
      setColorMap,

      datasetId,

      embeddingId,
      setEmbeddingId,

      searchText,
      setSearchText,
      taskletId,
      setTaskletId,

      hasSelection,

      resetSelection,
      undoSelection,

      selectedClasses,
      setSelectedClasses,

      selectedClusters,
      setSelectedClusters,

      selectedTopics,
      setSelectedTopics,

      selectedOutliers,
      setSelectedOutliers,

      selectedSegments,
      setSelectedSegments,

      showMislabelledData,
      setShowMislabelledData,

      lassoSelection,
      uniqLassoSelection,
      flatLassoSelection,
      addLassoSelection,

      selectedSimilarityPoints,
      setSelectedSimilarityPoints,

      chartMode,
      setChartMode,

      scatterViewMode,
      setScatterViewMode,

      hoverPointId,
      setHoverPointId,

      tableSelectedPointId,
      setTableSelectedPointId,

      selectedSubsetIds,
      setSelectedSubsetIds,

      pointSize,
      setPointSize,
    }),
    [
      hideInactivePoints,
      colorBy,
      colorMap,
      datasetId,
      embeddingId,
      searchText,
      setSearchText,
      taskletId,
      hasSelection,
      resetSelection,
      undoSelection,
      selectedClasses,
      selectedClusters,
      selectedTopics,
      selectedOutliers,
      selectedSegments,
      showMislabelledData,
      lassoSelection,
      uniqLassoSelection,
      flatLassoSelection,
      addLassoSelection,
      selectedSimilarityPoints,
      chartMode,
      scatterViewMode,
      hoverPointId,
      tableSelectedPointId,
      selectedSubsetIds,
      pointSize,
    ],
  );

  return (
    <DatasetClusterContext.Provider value={contextValue}>{children}</DatasetClusterContext.Provider>
  );
};
