import { useQueries, useQuery } from '@tanstack/react-query';
import { AxiosError, AxiosResponse } from 'axios';
import isEmpty from 'lodash/isEmpty';
import { useAppMetadata } from '../contexts/app-metadata/AppMetadata';
import {
  ClassificationMetricsInfo,
  ComputedMetricsForRecording,
  EvaluationsClassLevelMetricsResponse,
  GetComputedMetricsForRecordingResponse,
  MultipleEvaluationsClassificationMetricsAllClasses,
} from '../generated/api';
import { evaluationsApi, recordingsApi } from '../lib/api';
import { capitalize } from '../lib/ui';
import { useChartQuery } from './charts-query-wrapper';
import { EVALUATION } from './queryConstants';

/** One labelled metric value prepared for display. */
export interface IMetricDetail {
  /** Display label, e.g. the capitalized metric key. */
  name: string;
  /** Explanatory text for the metric; may be an empty string when none is known. */
  description: string;
  /** The metric's numeric value. */
  value: number;
}

/** Display-ready summary of a recording's computed metrics. */
export interface IMetricsSummary {
  /** Averaged metrics, one row per metric. */
  averageMetrics: IMetricDetail[];
  /** Per-label counts for the ground truth (`actual`) and the model's predictions (`model`). */
  distributions: {
    model: Record<string, number>;
    actual: Record<string, number>;
  };
}

/**
 * Metric keys for which the client holds a hard-coded description.
 *
 * NOTE: these descriptions are meant to come from the backend eventually; the schema of this
 * endpoint needs a general refactor first (MKV-243 in JIRA).
 */
type MetricName = 'precision' | 'recall' | 'f1-score';

/**
 * Human-readable explanation for each {@link MetricName}, used as the `description` of the
 * metric rows built by the selectors below.
 *
 * NOTE(review): the recall text omits "correct" (recall counts correct positive predictions out
 * of all actual positives). It is user-facing copy, so it is left byte-identical here.
 */
const metricDescriptionMap: Record<MetricName, string> = {
  precision:
    'Precision quantifies the number of positive class predictions that actually belong to the positive class',
  recall:
    'Recall quantifies the number of positive class predictions made out of all positive examples in the dataset',
  'f1-score':
    'F-Score provides a single score that balances both the concerns of precision and recall in one number',
};

/**
 * React Query `select` transform for the computed-metrics-for-recording endpoint.
 *
 * Returns `undefined` when the response payload is empty. Otherwise:
 * - turns every weighted-average metric except "support" into a display row;
 * - appends an "Accuracy" row taken from `metrics.accuracy`;
 * - exposes the true/predicted distributions as `actual`/`model`.
 *
 * @param data - Raw Axios response for the recording's computed metrics.
 * @returns The display summary, or `undefined` for an empty payload.
 */
const metricSummarySelector = (
  data: AxiosResponse<GetComputedMetricsForRecordingResponse>,
): IMetricsSummary | undefined => {
  const payload = data.data.response as ComputedMetricsForRecording;
  if (isEmpty(payload)) {
    return undefined;
  }

  const { metrics, distributions } = payload;
  const weighted: Record<MetricName, number> = metrics.weightedAvg;

  // filter out "support" from reported metrics (for demo)
  const scoreRows: IMetricDetail[] = Object.entries(weighted)
    .filter(([metricName]) => metricName !== 'support')
    .map(([metricName, score]) => ({
      name: capitalize(metricName),
      value: score,
      description: metricDescriptionMap[metricName as MetricName] || '',
    }));

  // Accuracy lives outside `weightedAvg`, so it is appended as its own row.
  // TODO: Get description from the BE
  const accuracyRow: IMetricDetail = {
    name: 'Accuracy',
    value: metrics.accuracy,
    description:
      'Accuracy quantifies the number of correct predictions out of the total number of predictions',
  };

  return {
    distributions: {
      actual: distributions.trueDistribution,
      model: distributions.predictedDistribution,
    },
    averageMetrics: [...scoreRows, accuracyRow],
  };
};

/**
 * Fetches the computed metrics of one recording in the current workspace and reduces them
 * with `metricSummarySelector`.
 *
 * No request is made while either `workspaceId` or `recordingId` is falsy.
 *
 * @param recordingId - Recording whose metrics are requested.
 */
export const useMetricsQuery = (recordingId: string) => {
  const { workspaceId } = useAppMetadata();
  const canFetch = Boolean(workspaceId && recordingId);

  return useQuery(
    [EVALUATION.GET_EVAL_METRICS, workspaceId, recordingId],
    () => recordingsApi.workspaceGetComputedMetricsForRecordingV1(workspaceId, recordingId),
    { select: metricSummarySelector, enabled: canFetch },
  );
};

/**
 * Runs one computed-metrics query per recording, in parallel, each reduced with
 * `metricSummarySelector`.
 *
 * Each entry uses the same query key as `useMetricsQuery`, so results are shared with that hook.
 * A query stays idle while `workspaceId` or its recording id is falsy.
 *
 * @param recordingIds - Recordings to fetch; the result array follows this order.
 */
export const useMetricsQueries = (recordingIds: string[]) => {
  const { workspaceId } = useAppMetadata();

  return useQueries({
    queries: recordingIds.map(id => {
      const enabled = Boolean(workspaceId && id);
      return {
        queryKey: [EVALUATION.GET_EVAL_METRICS, workspaceId, id],
        queryFn: () => recordingsApi.workspaceGetComputedMetricsForRecordingV1(workspaceId, id),
        enabled,
        select: metricSummarySelector,
      };
    }),
  });
};

/**
 * Fetches class-level classification metrics for several evaluations in a single request and
 * returns the unwrapped `response` payload.
 *
 * The request is held back until a workspace id is available and at least one evaluation id is
 * given. This mirrors the `enabled` guards on `useMetricsQuery`/`useMetricsQueries`; without it,
 * the endpoint would be called with an unset workspace or an empty id list.
 *
 * @param evaluationIds - Evaluations whose class-level metrics are requested.
 */
export const useMultiEvalClassLevelMetricsQuery = (evaluationIds: string[]) => {
  const { workspaceId } = useAppMetadata();

  return useQuery<
    AxiosResponse<EvaluationsClassLevelMetricsResponse>,
    AxiosError,
    MultipleEvaluationsClassificationMetricsAllClasses
  >(
    [EVALUATION.GET_EVAL_METRICS, workspaceId, evaluationIds],
    () => evaluationsApi.getClassLevelMetricsForEvaluationsV1(workspaceId, evaluationIds),
    {
      enabled: Boolean(workspaceId) && evaluationIds.length > 0,
      select: res => res.data.response,
    },
  );
};

/** Display rows for a single class. */
interface ClassLevelMetrics {
  /** Metric rows for this class. */
  metrics: IMetricDetail[];
  /** Name of the class, as reported by the backend. */
  className: string;
}

/** Output of the class-level metrics selector. */
interface ClassMetricsResult {
  /** Per-class display rows. */
  classLevelMetrics: ClassLevelMetrics[];
  /** Metric metadata passed through from the backend response. */
  metricsInfo: ClassificationMetricsInfo[];
}

/**
 * React Query `select` transform for the class-level metrics endpoint.
 *
 * For each class, converts the metrics of the first entry in its `classLevelMetrics` list into
 * display rows, dropping "support". The list presumably holds one entry per requested
 * evaluation (TODO confirm); the single-evaluation hook requests exactly one. A class with no
 * entries yields an empty `metrics` list instead of throwing. "support" is also removed from
 * `metricsInfo`.
 *
 * @param data - Raw Axios response as returned by `getClassLevelMetricsForEvaluationsV1`.
 * @returns Per-class display metrics plus the filtered metric metadata.
 */
const classLevelMetricsSelector = (
  data: AxiosResponse<EvaluationsClassLevelMetricsResponse>,
): ClassMetricsResult => {
  const { metricsPerClass, metricsInfo } = data.data.response;

  const classLevelMetrics = metricsPerClass.map(
    ({ className, classLevelMetrics: entries }): ClassLevelMetrics => {
      // Indexing [0] on an empty list would throw inside `select`. That would put the whole
      // query into an error state and hide every other class's metrics.
      if (entries.length === 0) {
        return { className, metrics: [] };
      }

      const clsMetrics = entries[0].metrics;
      const metrics = (
        Object.keys(clsMetrics)
          // filter out "support" from reported metrics (for demo)
          .filter(key => key !== 'support') as Array<MetricName>
      ).map(key => ({
        name: capitalize(key),
        value: clsMetrics[key],
        description: metricDescriptionMap[key] || '',
      }));

      return { className, metrics };
    },
  );

  return {
    classLevelMetrics,
    // Keep the metadata consistent with the per-class rows, which also omit "support".
    metricsInfo: metricsInfo.filter(info => info.name !== 'support'),
  };
};

/**
 * Class-level metrics for a single evaluation, shaped by `classLevelMetricsSelector`.
 *
 * Calls the multi-evaluation endpoint with a one-element id list. Its query key therefore equals
 * the one `useMultiEvalClassLevelMetricsQuery([evaluationId])` would use, so both hooks share
 * the cached raw response.
 *
 * @param evaluationId - Evaluation whose class-level metrics are requested.
 */
export const useEvalClassLevelMetricsQuery = (evaluationId: string) => {
  const { workspaceId } = useAppMetadata();
  const ids = [evaluationId];
  const fetchClassLevelMetrics = () =>
    evaluationsApi.getClassLevelMetricsForEvaluationsV1(workspaceId, ids);

  return useChartQuery<
    AxiosResponse<EvaluationsClassLevelMetricsResponse>,
    AxiosError,
    ClassMetricsResult
  >([EVALUATION.GET_EVAL_METRICS, workspaceId, ids], fetchClassLevelMetrics, {
    select: classLevelMetricsSelector,
  });
};
