import { FontAwesomeIcon } from '@fortawesome/react-fontawesome';
import { faArrowDown, faArrowUp } from '@fortawesome/free-solid-svg-icons';
import { BehaviorSubject } from 'rxjs';
import {
    Column,
    ColumnSorting,
    ColumnsFor,
    DataTableModel,
    ObjectDataTableStore,
    objectDataTableStore,
} from '../../components/DataTable';
import { PlotlyFigure } from '../../components/Plot';
import { tryGetAssayValueGuess } from '../../lib/assays/util';
import { asyncCacheAdd } from '../../lib/util/async-cache';
import { DefaultPlotlyFigureLayout, PLOT_COLORS } from '../../lib/util/plot';
import { ModelAction, ReactiveModel } from '../../lib/util/reactive-model';
import { CompoundAPI } from '../Compounds/compound-api';
import {
    BoxAPI,
    BoxDetailsData,
    BoxTrainingDataResult,
    BoxVersion,
    boxVersionCompare,
    boxVersionToString,
} from './box-api';
import { BaseColors } from '../../lib/services/theme';

/**
 * Titles of the steps in the box info view, in display order.
 * `as const` keeps each title as a literal type so the step union below is exact.
 */
export const BOX_INFO_STEPS = [
    'Box info',
    'Metadata',
    'Training data',
    'Conda environment',
    'Input and result structures',
    'Sample code',
] as const;
/** Readonly tuple type of all step titles. */
type BoxInfoStepsList = typeof BOX_INFO_STEPS;
/** Union of the individual step titles (e.g. `'Box info' | 'Metadata' | ...`). */
type BoxInfoStep = BoxInfoStepsList[number];

/**
 * One row of the box-version table.
 * Optional fields are undefined when the corresponding value is missing from the box metadata.
 */
export interface BoxVersionTable {
    /** Box id (also used to identify the selected row). */
    id: number;
    /** Version identifier of the box. */
    version: BoxVersion;
    /** Training + validation + test set size taken from the first data split. */
    training_size?: number;
    /** `'R2 Score'` metric. */
    test_rsquared?: number;
    /** `'Mean Absolute Error'` metric. */
    mae?: number;
    /** `'Mean Squared Error'` metric. */
    mse?: number;
    /** `'UQ Area Under Confidence Oracle'` metric. */
    uq_area_under_confidence_oracle?: number;
    /** `'UQ Coefficient of Variation'` metric. */
    uq_coefficient_of_variation?: number;
    /** `'UQ Confidence Monotonic Non-Increase Ratio'` metric. */
    uq_confidence_monotonic_non_increase_ratio?: number;
    /** `'UQ Error Drop'` metric. */
    uq_error_drop?: number;
    /** `'UQ Expected Calibration Error'` metric. */
    uq_expected_calibration_error?: number;
    /** `'UQ Spearman'` metric. */
    uq_spearman?: number;
}

/** Direction in which a metric improves; `'neither'` means the header shows no arrow. */
type BetterMetric = 'higher' | 'lower' | 'neither';

/**
 * Column header for a metric: the metric name, its value range underneath,
 * and an up/down arrow when higher/lower values are better.
 * Non-finite bounds are rendered as `-∞` / `∞`.
 */
function BoxMetricHeader({ name, bounds, better }: { name: string; bounds: [number, number]; better: BetterMetric }) {
    const [lower, upper] = bounds;
    const lowerLabel = Number.isFinite(lower) ? lower : '-∞';
    const upperLabel = Number.isFinite(upper) ? upper : '∞';

    // 'neither' gets no arrow at all.
    let arrowIcon;
    if (better === 'higher') {
        arrowIcon = faArrowUp;
    } else if (better === 'lower') {
        arrowIcon = faArrowDown;
    }

    return (
        <div className='d-flex flex-column'>
            {name}
            <div>
                <span className='me-1'>
                    ({lowerLabel}, {upperLabel})
                </span>
                {arrowIcon && <FontAwesomeIcon icon={arrowIcon} fixedWidth />}
            </div>
        </div>
    );
}

/**
 * Float column settings shared by every metric column: 3 significant digits, never scientific.
 * A factory (rather than one shared constant) so each column gets its own formatting object.
 */
function metricFloatKind() {
    return Column.float({ defaultFormatting: { significantDigits: 3, scientific: false } });
}

/**
 * Column definitions for the box-version table.
 * Metric headers show the metric's value range and which direction is better.
 */
const BoxVersionTableSchema: ColumnsFor<BoxVersionTable> = {
    id: Column.create({
        kind: 'int',
        noHeaderTooltip: true,
        width: 80,
    }),
    version: Column.create({
        kind: 'obj',
        noHeaderTooltip: true,
        width: 80,
        format: (value: BoxVersion) => boxVersionToString(value),
        compare: ColumnSorting.comparerWithBlanks(boxVersionCompare),
        render: ({ value }: { value: BoxVersion }) => boxVersionToString(value),
        disableGlobalFilter: true,
    }),
    training_size: Column.create({
        kind: 'int',
        noHeaderTooltip: true,
        width: 100,
    }),
    test_rsquared: Column.create({
        ...metricFloatKind(),
        noHeaderTooltip: true,
        width: 100,
        // R² is bounded above by 1 but unbounded below: it is negative whenever the
        // model does worse than always predicting the mean.
        header: () => <BoxMetricHeader name='Test R2' bounds={[-Infinity, 1]} better='higher' />,
    }),
    mae: Column.create({
        ...metricFloatKind(),
        noHeaderTooltip: true,
        width: 80,
        header: () => <BoxMetricHeader name='MAE' bounds={[0, Infinity]} better='lower' />,
    }),
    mse: Column.create({
        ...metricFloatKind(),
        noHeaderTooltip: true,
        width: 80,
        header: () => <BoxMetricHeader name='MSE' bounds={[0, Infinity]} better='lower' />,
    }),
    uq_area_under_confidence_oracle: Column.create({
        ...metricFloatKind(),
        noHeaderTooltip: true,
        width: 225,
        header: () => <BoxMetricHeader name='UQ Area Under Confidence Oracle' bounds={[0, 100]} better='lower' />,
    }),
    uq_coefficient_of_variation: Column.create({
        ...metricFloatKind(),
        noHeaderTooltip: true,
        width: 175,
        header: () => <BoxMetricHeader name='UQ Coefficient of Variation' bounds={[0, 1]} better='higher' />,
    }),
    uq_confidence_monotonic_non_increase_ratio: Column.create({
        ...metricFloatKind(),
        noHeaderTooltip: true,
        width: 280,
        header: () => (
            <BoxMetricHeader name='UQ Confidence Monotonic Non-Increase Ratio' bounds={[0, 1]} better='higher' />
        ),
    }),
    uq_error_drop: Column.create({
        ...metricFloatKind(),
        noHeaderTooltip: true,
        width: 110,
        header: () => <BoxMetricHeader name='UQ Error Drop' bounds={[0, Infinity]} better='neither' />,
    }),
    uq_expected_calibration_error: Column.create({
        ...metricFloatKind(),
        noHeaderTooltip: true,
        width: 200,
        header: () => <BoxMetricHeader name='UQ Expected Calibration Error' bounds={[0, 1]} better='lower' />,
    }),
    uq_spearman: Column.create({
        ...metricFloatKind(),
        noHeaderTooltip: true,
        width: 110,
        header: () => <BoxMetricHeader name='UQ Spearman' bounds={[-1, 1]} better='higher' />,
    }),
};

/**
 * Builds the box-version table over `store`, with the `id` and `version` columns sticky.
 * If a row's `id` matches `id`, that row starts out selected.
 */
function createTable(store: ObjectDataTableStore<BoxVersionTable, BoxVersionTable>, id?: number | null) {
    const model = new DataTableModel<BoxVersionTable>(store, {
        columns: BoxVersionTableSchema,
        hideNonSchemaColumns: true,
    });

    for (const column of ['id', 'version'] as const) {
        model.setColumnStickiness(column, true);
    }

    const preselectedRow = store.findValueIndex('id', id);
    if (preselectedRow >= 0) {
        model.setSelected(preselectedRow, true);
    }

    return model;
}

/**
 * Total sample count (training + validation + test) for a box,
 * or undefined when no split sizes are recorded.
 */
function getTrainingSize(
    data_size_per_split: { training_set_size: number; validation_set_size: number; test_set_size: number }[]
): number | undefined {
    if (!data_size_per_split?.length) return undefined;
    // Every split is expected to report identical sizes, so the first one is representative.
    const { training_set_size, validation_set_size, test_set_size } = data_size_per_split[0];
    return training_set_size + validation_set_size + test_set_size;
}

/**
 * View model for the box info page.
 * Lists every version of the box identified by (`kind`, `name`) in a table, tracks the selected
 * version, and loads and caches that version's training data and substance lookups.
 */
export class BoxInfoModel extends ReactiveModel {
    /** All box versions returned by `BoxAPI.query(kind, name)`; populated by `init`. */
    public boxes: BoxDetailsData[] = undefined as any;
    /** Version table built from `boxes` by `init`. */
    public table: DataTableModel<BoxVersionTable> = undefined as any;

    state = {
        /** Step of the info view currently shown; starts at the first step. */
        step: new BehaviorSubject<BoxInfoStep>(BOX_INFO_STEPS[0]),
        /** Id of the box version selected in the table, or undefined when nothing is selected. */
        currentBoxId: new BehaviorSubject<number | undefined>(undefined),
    };

    /** Training-data requests keyed by box id, so repeat calls reuse the same request. */
    private trainingDataCache = new Map<number, Promise<BoxTrainingDataResult>>();
    /** Substance lookups for a box's training SMILES, keyed by box id. */
    private trainingSubstanceCache = new Map<number, Promise<Record<string, string | undefined>>>();

    actions = {
        loadTrainingData: new ModelAction<BoxTrainingDataResult>({
            onError: 'state',
        }),
        querySubstances: new ModelAction<Record<string, string | undefined>>({
            onError: 'state',
        }),
    };

    /** Returns the training data for box `id`, starting a download only if none is cached yet. */
    loadTrainingData(id: number): Promise<BoxTrainingDataResult> {
        const cached = this.trainingDataCache.get(id);
        if (cached) return cached;

        const promise: Promise<BoxTrainingDataResult> = asyncCacheAdd(
            this.trainingDataCache,
            id,
            BoxAPI.downloadTrainingData(id)
        );
        // NOTE(review): asyncCacheAdd already receives the cache and may store the entry itself;
        // the explicit set is kept to preserve existing behavior — confirm whether it is redundant.
        this.trainingDataCache.set(id, promise);
        return promise;
    }

    /**
     * Looks up substances for the SMILES in box `id`'s training data, caching the request per box.
     * Uses the STANDARDIZED_SMILES column when present, otherwise SMILES; empty values are skipped.
     */
    querySubstances(id: number, data: BoxTrainingDataResult): Promise<Record<string, string | undefined>> {
        const cached = this.trainingSubstanceCache.get(id);
        if (cached) return cached;

        const smilesColumn = data.dataframe.hasColumn('STANDARDIZED_SMILES') ? 'STANDARDIZED_SMILES' : 'SMILES';
        const smilesList = data.dataframe.getColumnValues(smilesColumn).filter((v) => !!v) as string[];
        const promise: Promise<Record<string, string | undefined>> = asyncCacheAdd(
            this.trainingSubstanceCache,
            id,
            CompoundAPI.querySubstances(smilesList)
        );
        // NOTE(review): see loadTrainingData — explicit set kept to preserve existing behavior.
        this.trainingSubstanceCache.set(id, promise);
        return promise;
    }

    // Plotly does not support log scale x-axes for histograms
    // https://github.com/plotly/plotly.js/issues/6200
    // so we have to fake it with a bar chart + manually created bins
    /**
     * Bar chart of y-value counts per bin of `data.bins`, on a log x-axis.
     * Each bin covers (bins[i-1], bins[i]]; the first bin also includes bins[0], so values
     * sitting exactly on the lowest edge are not dropped. Non-numeric values are ignored.
     */
    private getLogHistogram(data: BoxTrainingDataResult): PlotlyFigure {
        // Parse every value once instead of re-reading and re-parsing the column for each bin.
        const values = data.dataframe
            .getColumnValues('y')
            .map((v) => tryGetAssayValueGuess(v))
            .filter((v): v is number => typeof v === 'number');

        const edges = [...data.bins];
        const counts: number[] = [];
        for (let i = 1; i < edges.length; i++) {
            const min = edges[i - 1];
            const max = edges[i];
            const includeMin = i === 1;
            counts.push(values.filter((v) => (includeMin ? v >= min : v > min) && v <= max).length);
        }

        return {
            data: [
                {
                    type: 'bar',
                    marker: { color: PLOT_COLORS.pink },
                    // One bar per bin, positioned at the bin's left edge (x and y have equal length).
                    x: edges.slice(0, -1),
                    y: counts,
                    width: 0.2,
                },
            ],
            layout: {
                ...DefaultPlotlyFigureLayout,
                font: { color: BaseColors.body },
                xaxis: {
                    ...DefaultPlotlyFigureLayout.xaxis,
                    title: `Value (${data.units})`,
                    type: 'log',
                },
            },
        } as PlotlyFigure;
    }

    /** Plotly histogram of the parsed y values on a linear x-axis. */
    private getLinearHistogram(data: BoxTrainingDataResult): PlotlyFigure {
        return {
            data: [
                {
                    type: 'histogram',
                    marker: { color: PLOT_COLORS.pink },
                    x: data.dataframe.getColumnValues('y').map((v) => tryGetAssayValueGuess(v)),
                },
            ],
            layout: {
                ...DefaultPlotlyFigureLayout,
                font: { color: BaseColors.body },
                xaxis: {
                    ...DefaultPlotlyFigureLayout.xaxis,
                    title: `Value (${data.units})`,
                    type: 'linear',
                },
            },
        } as PlotlyFigure;
    }

    /**
     * Histogram of the training values. Uses the log-binned variant when the selected box's
     * training transform is 'log'; otherwise, or when no box is selected/loaded, the linear one.
     */
    getHistogram(data: BoxTrainingDataResult) {
        const id = this.state.currentBoxId.value;
        // `boxes` is only populated after init, and `id` may not match any box.
        const box = this.boxes?.find((b) => b.id === id);
        const transform = box?.metadata.training_task?.transform_specification?.transformer;
        return transform === 'log' ? this.getLogHistogram(data) : this.getLinearHistogram(data);
    }

    /**
     * Fetches all versions of the box, builds the version table from their metadata,
     * and preselects the row whose id equals `id` (if any).
     */
    async init(id?: number | null) {
        const box_details = await BoxAPI.query(this.kind, this.name);
        this.boxes = box_details;
        const store = objectDataTableStore<BoxVersionTable, BoxVersionTable>(
            [
                { name: 'id' },
                { name: 'version' },
                { name: 'training_size' },
                { name: 'test_rsquared' },
                { name: 'mae' },
                { name: 'mse' },
                { name: 'uq_area_under_confidence_oracle' },
                { name: 'uq_coefficient_of_variation' },
                { name: 'uq_confidence_monotonic_non_increase_ratio' },
                { name: 'uq_error_drop' },
                { name: 'uq_expected_calibration_error' },
                { name: 'uq_spearman' },
            ],
            this.boxes.map((b) => ({
                id: b.id,
                version: b.identifier.version,
                training_size: getTrainingSize(b.metadata.data_size_per_split),
                test_rsquared: b.metadata.metrics?.['R2 Score'],
                mae: b.metadata.metrics?.['Mean Absolute Error'],
                mse: b.metadata.metrics?.['Mean Squared Error'],
                uq_area_under_confidence_oracle: b.metadata.metrics?.['UQ Area Under Confidence Oracle'],
                uq_coefficient_of_variation: b.metadata.metrics?.['UQ Coefficient of Variation'],
                uq_confidence_monotonic_non_increase_ratio:
                    b.metadata.metrics?.['UQ Confidence Monotonic Non-Increase Ratio'],
                uq_error_drop: b.metadata.metrics?.['UQ Error Drop'],
                uq_expected_calibration_error: b.metadata.metrics?.['UQ Expected Calibration Error'],
                uq_spearman: b.metadata.metrics?.['UQ Spearman'],
            }))
        );
        this.table = createTable(store, id);
    }

    /**
     * Keeps `currentBoxId` in sync with the table selection. Whenever a different row is
     * selected, its id is published and loading of its training data is started.
     */
    mount() {
        this.subscribe(this.table.version, () => {
            const currentBoxId = this.state.currentBoxId.value;
            const selectedIndices = this.table.getSelectedRowIndices();
            if (selectedIndices.length > 0) {
                const id = this.table.store.getValue('id', selectedIndices[0]);
                if (currentBoxId !== id) {
                    this.state.currentBoxId.next(id);
                    this.actions.loadTrainingData.run(this.loadTrainingData(id));
                }
            } else if (currentBoxId !== undefined) {
                this.state.currentBoxId.next(undefined);
            }
        });
    }

    constructor(public kind: string, public name: string) {
        super();
    }
}
