import { ReactNode } from 'react';
import { BehaviorSubject, Subject } from 'rxjs';
import { columnNameToHeader } from '../../api/data';
import { ErrorWithMessage, ErrorIconWithMessage, isErrorWithMessage } from '../../components/common/Error';
import { InfoTooltip } from '../../components/common/Tooltips';
import {
    Column,
    columnDataTableStore,
    DefaultRowHeight,
    DataTableModel,
    DataTableStore,
    DefaultFloatColumnFormatOptions,
    formatFloat,
} from '../../components/DataTable';
import {
    DownloadArtifactColumn,
    HoverBatchLink,
    NumberListColumnSchema,
    SelectionColumn,
    SmilesColumn,
} from '../../components/DataTable/common';
import { DefaultFigureLayout, PlotlyFigure } from '../../components/Plot';
import {
    assayValuePotensToNM,
    assayValueTypeCompare,
    assayValueTypeToCsvString,
    isCurveMeasurement,
} from '../../lib/assays/util';
import { AssayValueView, formatPotensName } from '../../lib/assays/display';
import type { AssayValueCreate, AssayValueGraph, AssayValueType } from '../../lib/assays/models';
import { AssayValueHeader, PotensAsNmProp } from '../../lib/assays/table';
import { AsyncQueue } from '../../lib/util/async-queue';
import { AsyncMoleculeDrawer } from '../../lib/util/draw-molecules';
import { reportErrorAsToast } from '../../lib/util/errors';
import { ReactiveModel } from '../../lib/util/reactive-model';
import { AssayAPI, AssayDetail, AssayValueDetail } from './assay-api';
import { updateAssayPlot } from './assay-common';
import { getCustomFits } from './upload/bayes-api';

// Multiplier applied to the base row height so rows are tall enough for the
// SMILES structure drawings (passed to SmilesColumn and used for the table rowHeight).
const SmilesRowHeightFactor = 2;
// Base table row height before the SMILES multiplier is applied.
const rowHeight = DefaultRowHeight;

/**
 * Builds a table column for an assay value cell (measurement, IC90, LE, fit parameters, ...).
 *
 * Each cell holds either an `AssayValueType` or an `ErrorWithMessage`; errors render as an
 * error icon and export their message to CSV. When `customState.units[columnId]` is
 * `'potens'` and the table's `PotensAsNmProp` flag is set, values are rendered and exported
 * converted with `assayValuePotensToNM`, and the CSV header is built with `formatPotensName`.
 *
 * @param options.columnName - display name used for the column label and headers.
 */
function AssayDetailValueColumn(options: { columnName: string }): Column<AssayValueType | ErrorWithMessage, {}> {
    return {
        ...Column.obj({
            // Cells are drawn by `render` below, so plain-text formatting is never shown.
            format: () => '<unused>',
            compare: assayValueTypeCompare,
            csvFormat: (value: unknown, _opt, table, columnName) => {
                if (isErrorWithMessage(value)) return (value as ErrorWithMessage).message;
                const custom = table.state.customState;
                const convertToNM = custom.units[columnName] === 'potens' && custom[PotensAsNmProp];
                const cellValue = value as AssayValueType | undefined;
                return assayValueTypeToCsvString(convertToNM ? assayValuePotensToNM(cellValue) : cellValue);
            },
        }),
        csvHeader: (table, { columnName }) => {
            const custom = table.state.customState;
            if (custom.units[columnName] !== 'potens') {
                return columnNameToHeader(options.columnName);
            }
            return formatPotensName(options.columnName, { asNM: custom[PotensAsNmProp], asString: true }) as string;
        },
        width: 175,
        label: options.columnName,
        header: (tbl, { columnName }) => (
            <AssayValueHeader table={tbl} columnName={options?.columnName ?? columnName} columnId={columnName} />
        ),
        render: ({ value, columnName, table }) => {
            if (isErrorWithMessage(value)) return <ErrorIconWithMessage value={value as ErrorWithMessage} />;
            const custom = table.state.customState;
            const asNM = custom.units[columnName] === 'potens' && custom[PotensAsNmProp];
            return <AssayValueView value={value as AssayValueType | undefined} options={{ asNM }} />;
        },
    } as Column<AssayValueType | ErrorWithMessage, {}>;
}

/**
 * Column schema for the assay values table, as `[columnName, column]` pairs in display order.
 *
 * @param assay - the assay; `assay.property.measurement` labels the main `value` column.
 * @param drawer - molecule drawer used by the SMILES column.
 * @returns the schema passed as `columns` to `DataTableModel`.
 */
function AssayValueTableSchema(
    assay: AssayDetail,
    drawer: AsyncMoleculeDrawer
): [colName: keyof AssayValueDetail, column: Column][] {
    // Factories (not shared constants) so every column gets its own options object,
    // exactly as when each column spelled out its own literal.
    const float3 = () => Column.float({ defaultFormatting: { significantDigits: 3, scientific: false } });
    const plainStr = () => Column.create({ kind: 'str', noHeaderTooltip: true });

    return [
        [
            'smiles',
            SmilesColumn(drawer, SmilesRowHeightFactor, {
                width: 150,
                identifierPadding: 18,
                getIdentifierElement: ({ rowIndex, table }) => {
                    const identifier = table.store.getValue('identifier', rowIndex);
                    if (!identifier) return '-';
                    return <HoverBatchLink identifier={identifier} withQuery />;
                },
            }),
        ],
        ['identifier', Column.str()],
        [
            'supplier_id',
            Column.create({
                kind: 'str',
                header: 'Supplier ID',
                noHeaderTooltip: true,
            }),
        ],
        ['value', AssayDetailValueColumn({ columnName: assay.property.measurement })],
        ['ic90', AssayDetailValueColumn({ columnName: 'IC90' })],
        ['y_values', NumberListColumnSchema({ defaultFormatting: DefaultFloatColumnFormatOptions })],
        ['y_std_dev', Column.create({ ...float3(), noHeaderTooltip: true })],
        ['obs_min', Column.create({ ...float3(), label: 'Obs. Min', noHeaderTooltip: true })],
        ['obs_max', Column.create({ ...float3(), label: 'Obs. Max', noHeaderTooltip: true })],
        ['emin', Column.create({ ...float3(), label: 'Emin', noHeaderTooltip: true })],
        ['emax', Column.create({ ...float3(), label: 'Emax', noHeaderTooltip: true })],
        [
            'auc',
            Column.create({
                ...float3(),
                label: 'Normalized AUC',
                header: 'Normalized AUC',
                noHeaderTooltip: true,
                render: ({ value, table, columnName }) => {
                    const formatting = table.state.columnFormatting[columnName];
                    // A dynamic AUC (written by AssayInformationModel.updateDynamicAUC) is an
                    // object `{ dynamic, range }`. Detect it by shape, not by truthiness of
                    // `dynamic`: a dynamic AUC of 0, or NaN after a failed calculation, must
                    // still take this branch, and an empty cell must not throw.
                    const isDynamic = value != null && typeof value === 'object' && 'dynamic' in value;
                    if (!isDynamic) return formatFloat(value, formatting);

                    const min = formatFloat(value.range[0], formatting);
                    const max = formatFloat(value.range[1], formatting);
                    return (
                        <div>
                            <div>
                                {formatFloat(value.dynamic, formatting)}
                                <InfoTooltip
                                    className='text-info'
                                    buttonClassName='px-1 py-0'
                                    tooltip='Dynamically generated'
                                />
                            </div>
                            <span className='font-body-xxsmall text-secondary'>
                                {min} - {max} nM
                            </span>
                        </div>
                    );
                },
            }),
        ],
        ['ligand_efficiency', AssayDetailValueColumn({ columnName: 'LE' })],
        [
            'scaled_clearance',
            Column.create({ ...AssayDetailValueColumn({ columnName: 'Scaled Clearance' }), width: 200 }),
        ],
        [
            'hepatic_extraction_ratio',
            Column.create({ ...AssayDetailValueColumn({ columnName: 'Hepatic Extraction Ratio' }), width: 200 }),
        ],
        [
            'performed_on',
            Column.create({
                ...Column.datetime({ format: 'full' }),
                noHeaderTooltip: true,
                width: 175,
            }),
        ],
        [
            'uploaded_on',
            Column.create({
                ...Column.datetime({ format: 'full' }),
                noHeaderTooltip: true,
                width: 175,
            }),
        ],
        ['fit_source', plainStr()],
        [
            'r2',
            Column.create({
                ...float3(),
                header: () => (
                    <>
                        Fit R<sup>2</sup>
                    </>
                ),
                noHeaderTooltip: true,
            }),
        ],
        ['slope', AssayDetailValueColumn({ columnName: 'Fit Slope' })],
        ['min', AssayDetailValueColumn({ columnName: 'Fit Min' })],
        ['max', AssayDetailValueColumn({ columnName: 'Fit Max' })],
        ['doi', plainStr()],
        ['patent_no', plainStr()],
        ['artifacts', DownloadArtifactColumn()],
    ];
}

/**
 * Creates the assay values table model.
 *
 * Only schema columns are shown, and the `identifier` column is hidden. Rows are sorted by
 * `uploaded_on` and sized for SMILES drawings, and the `smiles` column is sticky. For
 * curve measurements a sticky row-selection column is added as well.
 *
 * @param assay - the assay whose values are shown.
 * @param units - per-column unit names, stored in `customState.units` for the value columns.
 * @param values - backing store holding the assay value rows.
 */
function createTable(assay: AssayDetail, units: Record<string, string>, values: DataTableStore<AssayValueDetail>) {
    const curve = isCurveMeasurement(assay.property);

    const table = new DataTableModel(values, {
        columns: AssayValueTableSchema(assay, new AsyncMoleculeDrawer()),
        // Row selection only makes sense when there is a curve to plot.
        actions: curve ? [SelectionColumn()] : undefined,
        hideNonSchemaColumns: true,
        rowHeight: rowHeight * SmilesRowHeightFactor,
        customState: { units, 'show-smiles': true, [PotensAsNmProp]: false },
        globalFilterHiddenColumns: true,
    });

    table.setHiddenColumns(['identifier']);
    table.sortBy('uploaded_on', true);
    if (curve) table.setColumnStickiness('selection', true);
    table.setColumnStickiness('smiles', true);

    return table;
}

/**
 * Fetches an assay and all of its values in parallel, then builds and mounts the page model.
 *
 * @param id - assay id.
 * @param message - not used by this function; kept for callers' interface compatibility.
 * @returns the mounted `AssayInformationModel`.
 */
export async function loadAssay(id: number, message: BehaviorSubject<ReactNode>) {
    const [assay, full] = await Promise.all([AssayAPI.get(id), AssayAPI.getValuesFull(id)] as const);

    const store = columnDataTableStore(full.dataframe);
    const model = new AssayInformationModel(assay, createTable(assay, full.units, store), full.values);
    model.mount();
    return model;
}

/**
 * Page model for a single assay.
 *
 * It owns the assay values table and keeps the Plotly figure in sync with the selected
 * rows. When exactly one row is selected and the user selects points on its curve, it
 * recomputes a dynamic AUC for that value.
 */
export class AssayInformationModel extends ReactiveModel {
    moleculeDrawer = new AsyncMoleculeDrawer();

    /** `isCurveMeasurement(assay.property)`, evaluated once in the constructor. */
    readonly isCurveMeasurement: boolean;

    state = {
        figure: new BehaviorSubject<PlotlyFigure>({
            data: [],
            layout: DefaultFigureLayout,
        }),
        plot: {
            selectPointsEvent: new Subject<number[]>(),
        },
        valueColumn: new BehaviorSubject('value'),
    };

    /**
     * Plotly points the user selected on each value's curve, keyed by value id.
     * `undefined` means "no selection"; Plotly reads `[]` as "no points selected".
     */
    graphSelectionPerValue: Record<number, Plotly.PlotDatum[] | undefined> = {};

    /** Runs dynamic AUC recalculations one at a time (`maxConcurrent: 1`). */
    private queue = new AsyncQueue({ maxConcurrent: 1 });

    /** Selected rows as reported by the table (`table.selectedRows`). */
    get selectedRowIds() {
        return this.table.selectedRows;
    }

    /**
     * Stored point selection for the selected value. Returns `undefined` unless exactly one
     * row is selected and it matches an entry in `values`.
     */
    get currentSelectedGraphPoints() {
        const selectedRows = this.table.getSelectedRowIndices();
        if (selectedRows.length !== 1) return undefined;
        const selectedValue = this.selectedValues[0];
        // The selected row may have no matching entry in `values`.
        if (!selectedValue) return undefined;
        return this.graphSelectionPerValue[selectedValue.id!];
    }

    /** Entries of `values` whose `id` matches the `id` column of a selected table row. */
    get selectedValues() {
        const selectedRows = this.table.getSelectedRowIndices();
        const selectedIds = new Set(this.table.store.getColumnValues('id', selectedRows));
        return this.values.filter((v) => selectedIds.has(v.id!));
    }

    /**
     * Handler for point selection on the plot.
     *
     * Only acts when exactly one row is selected: it remembers the selection for that value
     * and queues a dynamic AUC recalculation.
     */
    onSelection = async (selectedPoints: Plotly.PlotDatum[]) => {
        const selectedIndices = this.table.getSelectedRowIndices();
        if (selectedIndices.length !== 1) return;
        const value = this.selectedValues[0];
        if (!value) return;
        // if we deselected (selectedPoints === []) we want to return to
        // all points selected / no selection, which is represented by undefined
        // plotly selectedpoints = [] means no points selected, selectedpoints = undefined
        // means no selection
        this.graphSelectionPerValue[value.id!] = selectedPoints.length > 0 ? selectedPoints : undefined;
        this.queue.execute(() => this.updateDynamicAUC(value, selectedPoints));
    };

    /**
     * Recomputes the AUC of `value`'s curve from the selected points only and writes it into
     * the `auc` cell of that value's row.
     *
     * With a non-empty selection the cell receives `{ dynamic, range }`, where `range` is the
     * min/max numeric x of the selected points; otherwise it receives the plain AUC. API errors
     * are reported as a toast and stored as NaN. This is a no-op when the value has no graph or
     * the table has no `auc` column.
     */
    private async updateDynamicAUC(value: AssayValueCreate, selectedPoints: Plotly.PlotDatum[]) {
        if (!value.graph || !this.table.store.hasColumn('auc')) return;

        const newGraphData = updateMaskFromSelection(value.graph, selectedPoints);
        const newGraph = { ...value.graph, data: newGraphData };

        let newAUC: AssayValueType;
        try {
            newAUC = await AssayAPI.calculateDynamicAUC(newGraph);
        } catch (err) {
            reportErrorAsToast('Error calculating dynamic AUC', err);
            newAUC = Number.NaN;
        }

        // Look the row up by value id, not by batch identifier: one batch can have several
        // values in the same assay, and the first identifier match may be a different row.
        // The lookup happens after the await so it reflects the table's current state.
        const rowIndex = this.table.store.findValueIndex('id', value.id!);
        // NOTE(review): assumes a negative (or non-numeric) index means "not found" — confirm.
        if (!(rowIndex >= 0)) return;

        // if we have selected some points (no points selected is treated as all points)
        // then we store the dynamic AUC value in the 'auc' column for the selected row
        let aucValue: AssayValueType | { dynamic: AssayValueType; range: [number, number] };
        if (selectedPoints.length > 0) {
            let minX = Infinity;
            let maxX = -Infinity;
            for (const point of selectedPoints) {
                if (typeof point.x !== 'number') continue;
                minX = Math.min(minX, point.x);
                maxX = Math.max(maxX, point.x);
            }
            aucValue = { dynamic: newAUC, range: [minX, maxX] };
        } else {
            aucValue = newAUC;
        }
        this.table.store.setValue('auc', rowIndex, aucValue);
        this.table.updated({ clearRenderCache: true, columnId: 'auc', rowIndex });
    }

    /** Rebuilds the plot figure from the currently selected values. */
    private updatePlotData() {
        const selectedRows = this.table.getSelectedRowIndices();
        const selectedValues = this.selectedValues;
        const bayesValues = selectedValues.filter((v) => v.value_details.source === 'bayes_sigmoid_fit');
        const bayesCurves =
            bayesValues.length > 0 ? bayesValues.map((v) => v.value_details.bayes_sigmoid_fit_details!) : undefined;

        // IC90 of each selected row keyed by batch identifier (only if the table has IC90s).
        const selectedIC90: Record<string, number> = {};
        if (this.table.store.hasColumn('ic90')) {
            for (const rowId of selectedRows) {
                const ic90Value = this.table.store.getValue('ic90', rowId);
                const batchIdentifier = this.table.store.getValue('identifier', rowId);
                selectedIC90[batchIdentifier] = ic90Value;
            }
        }

        updateAssayPlot(
            this.assay,
            selectedValues,
            this.state.figure,
            {
                showAlternateFit: false,
                customFits: getCustomFits(bayesCurves),
            },
            selectedIC90,
            this.currentSelectedGraphPoints
        );
    }

    /** Subscribes to table selection changes and draws the initial plot. */
    mount() {
        // if the set of selected rows changes in the table, update the plot
        this.subscribe(this.table.events.selectionChanged, () => {
            this.updatePlotData();
        });

        this.updatePlotData();
    }

    /**
     * Applies the `compound-identifier` URL query parameter as the table's global filter.
     * For curve measurements it also pre-selects the first table row, if there is one.
     */
    constructor(
        public assay: AssayDetail,
        public table: DataTableModel<AssayValueDetail>,
        public values: AssayValueCreate[]
    ) {
        super();

        table.setGlobalFilter(new URLSearchParams(window.location.search).get('compound-identifier') ?? '');

        const isCurve = isCurveMeasurement(assay.property);
        if (isCurve && table.rows.length > 0) {
            table.setSelection([table.rows[0]]);
        }
        this.isCurveMeasurement = isCurve;
    }
}

/**
 * Returns a copy of `graph.data` in which the traces touched by the selection are masked
 * down to the selected points.
 *
 * For every trace containing at least one selected point, the new mask is "selected AND not
 * already masked out". Points masked out in QC are never turned back on. Traces with no
 * selected points, or all traces when `selectedPoints` is empty, keep their original mask.
 * `graph` itself is never modified.
 *
 * @param graph - value graph whose traces are copied.
 * @param selectedPoints - Plotly points; `curveNumber` indexes `graph.data` and
 *   `pointIndex` indexes that trace's points.
 * @returns new trace objects (shallow copies) with updated masks.
 */
function updateMaskFromSelection(graph: AssayValueGraph, selectedPoints: Plotly.PlotDatum[]) {
    const newData = graph.data.map((d) => ({ ...d }));

    // One freshly allocated mask per touched trace. The shallow copy above would otherwise
    // share the original mask array; writing into it would both mutate the input and leave
    // unselected-but-previously-enabled points switched on.
    const rebuiltMasks = new Map<number, boolean[]>();

    for (const { curveNumber, pointIndex } of selectedPoints) {
        const oldTrace = graph.data[curveNumber];
        const trace = newData[curveNumber];
        // Selection on a plot trace that isn't part of this graph's data.
        if (!oldTrace || !trace) continue;

        let mask = rebuiltMasks.get(curveNumber);
        if (!mask) {
            mask = new Array<boolean>(trace.x.length).fill(false);
            rebuiltMasks.set(curveNumber, mask);
            trace.mask = mask;
        }

        // skip if this value has a specified mask
        // and the mask is set to false (we don't want to
        // 'turn back on' points that were masked out in QC)
        if (oldTrace.mask && !oldTrace.mask[pointIndex]) continue;
        mask[pointIndex] = true;
    }

    return newData;
}
