import { Annotations, AxisType, LayoutAxis, PlotData, PlotType, Shape } from 'plotly.js-dist';
import { BehaviorSubject } from 'rxjs';
import { DefaultFigureLayout, PlotlyFigure } from '../../components/Plot';
import { BaseColors } from '../../lib/services/theme';
import { isGaussianUncertaintyValue, isUncertaintyValue, tryGetAssayValueGuess } from '../../lib/assays/util';
import type { AssayGraph, AssayValueCreate, SigmoidFitDetails } from '../../lib/assays/models';
import { arrayMinMax } from '../../lib/util/misc';
import { PALETTE, PLOT_COLORS } from '../../lib/util/plot';
import { roundValueDigits } from '../../lib/util/roundValues';
import { potensToNM } from '../../lib/util/units';
import { AssayDetail } from './assay-api';

/**
 * Evaluates a four-parameter sigmoid at `x`.
 *
 * `x` is compared against `log10(fit_details.value)`, so it is expected on a log10 scale.
 * The curve tends to `min` on one side and `max` on the other, with `slope` controlling
 * steepness, and equals the average of `min` and `max` at `x === log10(value)`.
 *
 * @param x - Position at which to evaluate the curve (log10 scale).
 * @param fit_details - Fit parameters (`value`, `slope`, `min`, `max`).
 * @returns The curve's y value at `x`.
 */
function sigmoid(x: number, fit_details: SigmoidFitDetails) {
    const { value, slope, min: lowerPlateau, max: upperPlateau } = fit_details;
    const inflection = Math.log10(value);
    const denominator = 1 + 10 ** (-slope * (x - inflection));
    return (upperPlateau - lowerPlateau) / denominator + lowerPlateau;
}

/**
 * Scales a value by the given conversion factor.
 *
 * @param toNMConstant - Multiplier that converts the source unit to nM (1 when no conversion is needed).
 * @param v - Value in the source unit.
 * @returns `toNMConstant * v`.
 */
function convertToNM(toNMConstant: number, v: number) {
    const scaled = toNMConstant * v;
    return scaled;
}

/**
 * Suggests a small signed offset for markers that sit close to another marker.
 *
 * All points of all series are pooled and ordered by x, then y. Each point is compared with
 * its successor in that ordering; when the Euclidean distance between the two (in the raw
 * x/y values) is below 5, the earlier point receives an offset of +0.1 or -0.1, alternating
 * with its position in the ordered list. The last point in the ordering is never offset.
 *
 * @param all - Series to inspect; `x[i]` / `y[i]` form the i-th point of a series.
 * @returns A map from series index to a (possibly sparse) array indexed by point index.
 *   Series without any offset have no entry; points without an offset are left unset.
 */
function approximateOffsets(all: AssayGraph[]): { [series: number]: (number | undefined)[] } {
    const nearDistance = 5;
    const offsetMagnitude = 0.1;

    // Pool every point as [x, y, pointIdx, seriesIdx].
    const records: number[][] = [];
    for (let series = 0; series < all.length; series += 1) {
        const { x: xValues, y: yValues } = all[series];
        for (let point = 0; point < xValues.length; point += 1) {
            records.push([xValues[point], yValues[point], point, series]);
        }
    }

    // Order by x first and y second so that list neighbours are plot neighbours.
    records.sort((p, q) => (p[0] !== q[0] ? p[0] - q[0] : p[1] !== q[1] ? p[1] - q[1] : 0));

    const result: { [series: number]: (number | undefined)[] } = {};
    records.slice(0, -1).forEach((current, position) => {
        const [currentX, currentY, pointIdx, seriesIdx] = current;
        const [nextX, nextY] = records[position + 1];
        const deltaX = currentX - nextX;
        const deltaY = currentY - nextY;
        const gap = Math.sqrt(deltaX * deltaX + deltaY * deltaY);

        // Written as a negated "<" so a NaN distance is skipped as well.
        if (!(gap < nearDistance)) return;

        const signedOffset = position % 2 === 0 ? offsetMagnitude : -offsetMagnitude;
        const seriesOffsets = result[seriesIdx] ?? (result[seriesIdx] = []);
        seriesOffsets[pointIdx] = signedOffset;
    });

    return result;
}

/** Plotly marker symbols, picked by series index (modulo the list length) when drawing marker traces. */
const plotSymbols = [
    'circle',
    'diamond',
    'x',
    'square',
    'triangle-up',
    'cross',
];

/**
 * Builds the Plotly figure (traces, shapes, annotations and axis layout) for the given assay
 * values, publishes it on `figure` when one is supplied, and returns it.
 *
 * Assay values without a `graph` are skipped. For every remaining value one marker trace is
 * added per data series, followed by one line trace per usable sigmoid fit. When the assay
 * measures IC50, vertical marker lines, uncertainty bands and y-axis labels are added as well.
 *
 * @param assay - Assay whose `property.environment` selects the static y range and whose
 *   `property.measurement` enables the IC50 decorations.
 * @param assayValues - Values to plot; several values share the same axes and get distinct colors.
 * @param figure - Optional subject that receives the new figure; its current `dragmode` is carried over.
 * @param options - `offsetCloseValues` nudges near-coincident markers apart, `showAlternateFit`
 *   also draws the non-primary fit, and `customFits[idx]` replaces the default fits of `assayValues[idx]`.
 * @param ic90Values - Optional lookup keyed by `batch_identifier`; when a usable value is found an
 *   IC90 marker is drawn next to the IC50 marker.
 * @param initialSelection - Optional Plotly points to pre-select. A point applies to the marker
 *   trace whose index in the returned `data` array equals the point's `curveNumber`.
 * @returns The figure that was built.
 */
export function updateAssayPlot(
    assay: AssayDetail,
    assayValues: AssayValueCreate[],
    figure?: BehaviorSubject<PlotlyFigure>,
    options?: {
        offsetCloseValues?: boolean;
        showAlternateFit?: boolean;
        customFits?: { fit: SigmoidFitDetails; width: number; label: string; isMain?: boolean; color?: string }[][];
    },
    ic90Values?: Record<string, number>,
    initialSelection?: Plotly.PlotDatum[]
) {
    const data: Partial<PlotData>[] = [];
    const shapes: Partial<Shape>[] = [];
    const annotations: Partial<Annotations>[] = [];
    // y values that already carry their own label; regular ticks too close to them are left unlabeled
    const extraTickChecks: number[] = [];
    const [staticYRangeMin, staticYRangeMax] =
        assay.property.environment === 'In vitro - cell' ? [-50, 150] : [-20, 120];
    let [xRangeMin, xRangeMax]: number[] = [Infinity, -Infinity];
    // Running extremes start inverted (min at the static max, max at the static min)
    // so that observed / fitted values can only tighten them.
    let [gObsMin, gObsMax] = [staticYRangeMax, staticYRangeMin];
    let [gFitMin, gFitMax] = [staticYRangeMax, staticYRangeMin];
    let showFitBounds = false;
    let hasGraph = false;
    let xAxisTitle = '';
    let yAxisTitle = '';
    let xAxisType: AxisType | undefined;

    // Text label anchored to the bottom of the plot area at data position `x`.
    const xLabel = (x: number, text: string, color: any) => {
        annotations.push({
            xref: 'x',
            yref: 'paper',
            x,
            xanchor: 'left',
            y: 0,
            yanchor: 'bottom',
            text,
            font: { color, size: 14 },
            showarrow: false,
        });
    };

    // Text label placed just left of the plot area at data position `y`.
    const yLabel = (y: number, text: string, color: any) => {
        annotations.push({
            xref: 'paper',
            yref: 'y',
            x: 0,
            xanchor: 'right',
            y,
            yanchor: 'middle',
            text,
            font: { color, size: 14 },
            showarrow: false,
        });
    };

    // Dotted horizontal line spanning the full plot width at data position `v`.
    const hLine = (v: number, color: any) =>
        shapes.push({
            type: 'line',
            x0: 0,
            y0: v,
            x1: 1,
            y1: v,
            xref: 'paper',
            line: {
                color,
                width: 0.75,
                dash: 'dot',
            },
        });

    for (let idx = 0; idx < assayValues.length; idx++) {
        const assayValue = assayValues[idx];
        if (!assayValue.graph) continue;
        hasGraph = true;
        // A single value uses the fixed pink/info colors; several values each take a palette entry.
        const color = assayValues.length > 1 ? PALETTE[(3 * (idx + 1)) % PALETTE.length] : PLOT_COLORS.pink;
        const infoColor = assayValues.length > 1 ? color : PLOT_COLORS.info;

        // Molar units are displayed as nM; toNMConstant is the factor from the source unit to nM.
        let toNMConstant = 1;
        let xUnits = assayValue.graph.x_units;
        if (assayValue.graph?.x_units === 'M') {
            toNMConstant = 1e9;
            xUnits = 'nM';
        } else if (assayValue.graph?.x_units === 'mM') {
            toNMConstant = 1e6;
            xUnits = 'nM';
        } else if (assayValue.graph?.x_units === 'uM') {
            toNMConstant = 1e3;
            xUnits = 'nM';
        }
        if (xUnits === 'nM') xAxisType = 'log';
        // Axis titles come from the first value that has a graph.
        if (!xAxisTitle) xAxisTitle = `${assayValue.graph.x_axis} ${xUnits ? `(${xUnits})` : xUnits}`;
        if (!yAxisTitle) yAxisTitle = `${assayValue.graph.y_axis} (${assayValue.graph.y_units})`;
        const xBounds: number[] = [];
        const yBounds: number[] = [];
        const yBoundsSelected: number[] = [];
        const xOffsets = approximateOffsets(assayValue.graph.data);
        let gI = 0;
        for (const { name, x: xs, y: ys, mask } of assayValue.graph.data) {
            // 0 = point is kept, 1 = point is masked out; without a mask every point is kept
            const selData = mask?.map((v) => (v ? 0 : 1)) ?? xs.map(() => 0);
            const allSelected = selData.every((v) => v === 0);
            // Plotly reports `curveNumber` as the trace's index in the figure's `data` array, so use
            // the index this trace is about to occupy rather than the per-value series counter
            // (the two only coincide while no other trace has been added before this value's series).
            const curveNumber = data.length;
            const traceSelectedPoints = initialSelection?.filter((p) => p.curveNumber === curveNumber);

            const xdata = xUnits === 'nM' ? xs.map((v) => convertToNM(toNMConstant, v)) : xs;
            // Bounds are tracked in the source unit; the fit curves below are sampled in that unit too.
            xBounds.push(...arrayMinMax(xs));
            yBounds.push(...arrayMinMax(ys));
            yBoundsSelected.push(...arrayMinMax(ys.filter((_, i) => selData[i] === 0)));

            const displayName = `${assayValue.batch_identifier}: ${name}`;
            const text: string[] = xs.map(
                (x, i) =>
                    `${displayName}<br>${roundValueDigits(
                        3,
                        xUnits === 'nM' ? convertToNM(toNMConstant, x) : x
                    )} ${xUnits}: ${ys[i].toFixed(1)}${
                        assayValue.graph?.y_units !== 'potens' ? assayValue.graph?.y_units : ''
                    }`
            );

            const seriesOffsets = xOffsets[gI];
            data.push({
                // visually offset x value to reduce overlap
                // uses multiplication because of log scale
                x: options?.offsetCloseValues ? xdata.map((v, i) => v + (seriesOffsets?.[i] ?? 0) * v) : xdata,
                y: ys,
                selectedpoints: traceSelectedPoints?.map((p) => p.pointNumber),
                name: assayValue.batch_identifier,
                showlegend: gI === 0, // NOTE: For now we will show only one entry in the legend
                type: 'scatter' as PlotType,
                mode: 'markers',
                text,
                marker: {
                    size: 8,
                    // masked points are mapped through the colorscale to the dark fill
                    color: allSelected ? color : selData,
                    symbol: plotSymbols[gI % plotSymbols.length],
                    colorscale: [
                        [0, color],
                        [1, '#121212'],
                    ],
                    line: {
                        // outline only the masked points (width 1), in the series color
                        width: selData,
                        color,
                    },
                },
                hoverinfo: 'text',
            });

            gI++;
        }

        const valueSource = assayValue.details.insight_value_source ?? assayValue.value_details.source;

        const [yMin, yMax] = arrayMinMax(yBounds);
        // CRO-sourced values use every point for the plateau estimate, otherwise only unmasked points.
        let [bottom, top] = valueSource === 'cro' ? [yMin, yMax] : arrayMinMax(yBoundsSelected);
        gObsMin = Math.min(gObsMin, yMin);
        gObsMax = Math.max(gObsMax, yMax);

        // Without a main fit the midpoint is halfway between the observed plateaus.
        let midPoint = bottom + (top - bottom) / 2;
        let midPointIc90 = 0;

        let [xMin, xMax] = arrayMinMax(xBounds);
        if (xAxisType === 'log') {
            xMin = Math.log10(xMin);
            xMax = Math.log10(xMax);
        }
        if (xMin < xRangeMin) xRangeMin = xMin;
        if (xMax > xRangeMax) xRangeMax = xMax;

        const fits: { fit: SigmoidFitDetails; width: number; label: string; isMain?: boolean; color?: string }[] =
            options?.customFits?.[idx] ?? [];

        // Default fits are only derived when the caller did not supply custom ones for this value.
        if (!options?.customFits?.[idx]) {
            let mainFit: SigmoidFitDetails | undefined;
            let alternateFit: SigmoidFitDetails | undefined;

            if (valueSource === 'sigmoid_fit') {
                mainFit = assayValue.value_details.sigmoid_fit_details;
                alternateFit = assayValue.value_details.cro_sigmoid_fit_details;
            } else if (valueSource === 'cro') {
                alternateFit = assayValue.value_details.sigmoid_fit_details;
                mainFit = assayValue.value_details.cro_sigmoid_fit_details;
            }

            if (mainFit) {
                fits.push({ fit: mainFit, width: 1.5, label: 'Main Fit', isMain: true });
            }
            if (options?.showAlternateFit && alternateFit) {
                fits.push({ fit: alternateFit, width: 1, label: 'Alt Fit' });
            }
        }

        // Sampling window (log10 scale) shared by all fit curves of this value; taken from the
        // main fit when one is reached, otherwise from the last usable fit.
        let fitMin = 0;
        let fitMax = 0;
        for (const { fit, isMain } of fits) {
            if (!fit || fit.value <= 0) continue;

            const logValue = Math.log10(fit.value);
            // The +/- 2 in the log scale ensures the entire sigmoid is visible outside the bounds
            // of the available measured data points
            fitMin = Math.min(xMin, Math.max(1e-6, logValue - 2));
            fitMax = Math.max(xMax, logValue + 2);
            if (isMain) break;
        }
        // At least two samples: with a single sample the step below would divide by zero
        // and produce NaN coordinates.
        const pointCount = Math.max(2, Math.ceil((fitMax - fitMin) * 15));
        const ic90valuePotens = ic90Values ? tryGetAssayValueGuess(ic90Values[assayValue.batch_identifier]) : undefined;
        const ic90value = typeof ic90valuePotens === 'number' ? potensToNM(ic90valuePotens, false) : undefined;
        for (const { fit, width, label, isMain, color: fitColor } of fits) {
            if (!fit || fit.value <= 0) continue;

            const xs = new Array<number>(pointCount);
            const ys = new Array<number>(pointCount);

            // Sample evenly in log10 space, then convert x back to the displayed unit.
            for (let i = 0; i < pointCount; i++) {
                const x = fitMin + ((fitMax - fitMin) * i) / (pointCount - 1);
                xs[i] = convertToNM(toNMConstant, 10 ** x);
                ys[i] = sigmoid(x, fit);
            }

            if (isMain) {
                // The main fit overrides the observed plateaus and midpoint.
                bottom = fit.min;
                top = fit.max;
                midPoint = sigmoid(Math.log10(fit.value), fit);
                // ic90value is divided by toNMConstant to get back to the fit's source unit
                midPointIc90 = ic90value ? sigmoid(Math.log10(ic90value / toNMConstant), fit) : 0;

                gFitMin = Math.min(gFitMin, bottom);
                gFitMax = Math.max(gFitMax, top);
                showFitBounds = true;
            }

            data.push({
                x: xs,
                y: ys,
                name: `${assayValue.batch_identifier}: ${label}`,
                showlegend: false,
                type: 'scatter' as PlotType,
                mode: 'lines',
                line: {
                    color: fitColor ?? color,
                    dash: isMain ? 'solid' : 'dot',
                    width,
                },
            });
        }

        if (assay.property.measurement === 'IC50') {
            let ic50value = tryGetAssayValueGuess(assayValue.value);
            if (ic50value !== undefined && Number.isFinite(ic50value)) {
                ic50value = convertToNM(toNMConstant, ic50value);

                if (ic90value) {
                    // vertical IC90 marker spanning the full plot height
                    shapes.push({
                        type: 'line',
                        x0: ic90value,
                        y0: 0,
                        x1: ic90value,
                        yref: 'paper',
                        y1: 1,
                        line: {
                            color: infoColor,
                            width: 1.5,
                            dash: 'solid',
                        },
                    });
                    // the annotation position is passed as log10 of the value, unlike the shape above
                    xLabel(Math.log10(ic90value), 'IC<sub>90</sub>', infoColor);

                    hLine(midPointIc90, infoColor);
                    extraTickChecks.push(midPointIc90);

                    yLabel(midPointIc90, midPointIc90.toFixed(0), infoColor);
                }
                // vertical IC50 marker spanning the full plot height
                shapes.push({
                    type: 'line',
                    x0: ic50value,
                    y0: 0,
                    x1: ic50value,
                    yref: 'paper',
                    y1: 1,
                    line: {
                        color: infoColor,
                        width: 1.5,
                        dash: 'solid',
                    },
                });

                if (isGaussianUncertaintyValue(assayValue.value)) {
                    // single shaded band between the lower and upper bound
                    shapes.push({
                        type: 'rect',
                        xref: 'x',
                        x0: convertToNM(toNMConstant, assayValue.value.lower_bound),
                        x1: convertToNM(toNMConstant, assayValue.value.upper_bound),
                        yref: 'paper',
                        y0: 0,
                        y1: 1,
                        layer: 'below',
                        fillcolor: infoColor,
                        opacity: 0.3,
                    });
                }

                if (isUncertaintyValue(assayValue.value)) {
                    // one lighter band per lower/upper bound pair
                    for (let i = 0; i < assayValue.value.lower_bounds.length; i++) {
                        const lower = assayValue.value.lower_bounds[i];
                        const upper = assayValue.value.upper_bounds[i];

                        shapes.push({
                            type: 'rect',
                            xref: 'x',
                            x0: convertToNM(toNMConstant, lower),
                            x1: convertToNM(toNMConstant, upper),
                            yref: 'paper',
                            y0: 0,
                            y1: 1,
                            layer: 'below',
                            fillcolor: infoColor,
                            opacity: 0.1,
                        });
                    }
                }

                xLabel(Math.log10(ic50value), 'IC<sub>50</sub>', infoColor);
            }
            hLine(midPoint, infoColor);
            extraTickChecks.push(midPoint);

            // label midpoint
            yLabel(midPoint, midPoint.toFixed(0), infoColor);
        }
    }

    if (showFitBounds) {
        hLine(gFitMin, PLOT_COLORS.yellow);
        hLine(gFitMax, PLOT_COLORS.yellow);
    }

    // Show min/max observed values on Y axis
    if (assayValues.length > 0 && hasGraph && assay.property.measurement === 'IC50') {
        extraTickChecks.push(gObsMin, gObsMax);
        hLine(gObsMin, PLOT_COLORS.pink);
        yLabel(gObsMin, Math.round(gObsMin).toString(), PLOT_COLORS.pink);
        hLine(gObsMax, PLOT_COLORS.pink);
        yLabel(gObsMax, Math.round(gObsMax).toString(), PLOT_COLORS.pink);
    }

    const xAxis = getXAxisRange(assay, xRangeMin, xRangeMax, xAxisType);
    const yAxis = getYAxisRange(assay, staticYRangeMin, staticYRangeMax, extraTickChecks, xAxisType);

    const newFigure: PlotlyFigure = {
        layout: {
            ...DefaultFigureLayout,
            font: { color: BaseColors.body },
            showlegend: assayValues.length > 1,
            // keep the drag mode of the figure currently held by the subject, if any
            dragmode: figure?.value.layout.dragmode ?? 'turntable',
            clickmode: 'event',
            margin: { t: 20, b: 50, l: 50, r: 10 },
            shapes,
            annotations,
            yaxis: {
                ...DefaultFigureLayout.yaxis,
                ...yAxis,
                title: { text: yAxisTitle, font: { size: 14 } },
                gridcolor: PLOT_COLORS.grid,
                zeroline: false,
                tickfont: { size: 14 },
            },
            xaxis: {
                ...DefaultFigureLayout.xaxis,
                ...xAxis,
                title: { text: xAxisTitle, font: { size: 14 } },
                gridcolor: PLOT_COLORS.grid,
                zeroline: false,
                type: xAxisType,
                tickfont: { size: 14 },
            },
        },
        data,
    };

    figure?.next(newFigure);
    return newFigure;
}

/**
 * Produces the x-axis layout overrides for the assay plot.
 *
 * Clearance measurements (`CL`, `CLL`) get no overrides. Otherwise a linear axis gets an explicit
 * range padded by 1 on each side with unit ticks starting at 0, while a log axis leaves
 * `range`, `dtick` and `tick0` explicitly `undefined`.
 *
 * @param assay - Assay whose `property.measurement` decides whether overrides apply.
 * @param xRangeMin - Smallest x value seen across the plotted data.
 * @param xRangeMax - Largest x value seen across the plotted data.
 * @param xAxisType - Axis type chosen for the plot; only `'log'` is treated specially.
 * @returns Partial axis settings to spread into the figure's `xaxis`.
 */
function getXAxisRange(
    assay: AssayDetail,
    xRangeMin: number,
    xRangeMax: number,
    xAxisType?: AxisType
): Partial<LayoutAxis> {
    const measurement = assay.property.measurement;
    const isClearance = measurement === 'CL' || measurement === 'CLL';
    if (isClearance) {
        // clearance assays rely on the default axis range settings
        return {} as Partial<LayoutAxis>;
    }

    const isLog = xAxisType === 'log';
    // The keys stay present (as undefined) on a log axis, matching the linear shape of the result.
    const range = isLog ? undefined : [xRangeMin - 1, xRangeMax + 1];
    const dtick = isLog ? undefined : 1;
    const tick0 = isLog ? undefined : 0;

    return {
        rangemode: 'nonnegative',
        range,
        dtick,
        tick0,
        tickcolor: '#b4b7b9',
    } as Partial<LayoutAxis>;
}

/**
 * Produces the y-axis range and explicit tick marks for the assay plot.
 *
 * Clearance measurements (`CL`, `CLL`) get no overrides. When the x axis is logarithmic the
 * y axis spans the supplied static range (padded by 10) with ticks at the range ends and at
 * 0, 50 and 100; a tick is left unlabeled when it lies within 10 of an earlier tick or of any
 * value in `extraTickChecks`. Otherwise the y axis is fixed to 0..14 (padded by 1) with ticks
 * at 0, 5, 10 and 14. Each tick value is emitted once, even when a range end coincides with
 * one of the fixed ticks.
 *
 * @param assay - Assay whose `property.measurement` decides whether overrides apply.
 * @param staticYRangeMin - Lower end of the y range used with a log x axis.
 * @param staticYRangeMax - Upper end of the y range used with a log x axis.
 * @param extraTickChecks - y values labeled elsewhere on the plot; nearby ticks lose their label.
 * @param xAxisType - Axis type chosen for the x axis; only `'log'` is treated specially.
 * @returns Partial axis settings to spread into the figure's `yaxis`.
 */
function getYAxisRange(
    assay: AssayDetail,
    staticYRangeMin: number,
    staticYRangeMax: number,
    extraTickChecks: number[],
    xAxisType?: AxisType
): Partial<LayoutAxis> {
    // for clearance assays we will use default axis range settings
    if (['CL', 'CLL'].includes(assay.property.measurement)) return {} as Partial<LayoutAxis>;

    const isLog = xAxisType === 'log';
    const [yRangeMin, yRangeMax] = isLog ? [staticYRangeMin, staticYRangeMax] : [0, 14];
    const ticks: [number, string][] = [];
    const minDist = 10;

    // Label a tick only when it is far enough from every tick added so far and from every
    // value that is labeled separately; otherwise keep the tick mark but leave it blank.
    const tickLabel = (tv: number) =>
        ticks.every(([v]) => Math.abs(v - tv) > minDist) && extraTickChecks.every((v) => Math.abs(v - tv) > minDist)
            ? tv.toString()
            : '';

    // Skip values that already have a tick so no position is emitted (and labeled) twice.
    const addTick = (value: number, label: string) => {
        if (ticks.some(([v]) => v === value)) return;
        ticks.push([value, label]);
    };

    if (isLog) {
        [yRangeMin, 0, 50, 100, yRangeMax].forEach((t) => addTick(t, tickLabel(t)));
    } else {
        [yRangeMin, 0, 5, 10, yRangeMax].forEach((t) => addTick(t, `${t}`));
    }

    ticks.sort((a, b) => a[0] - b[0]);

    const padding = isLog ? 10 : 1;
    const yAxisRangeWithPadding = [yRangeMin - padding, yRangeMax + padding];

    return {
        range: yAxisRangeWithPadding,
        tickmode: 'array',
        tickvals: ticks.map((v) => v[0]),
        ticktext: ticks.map((v) => v[1]),
        tickcolor: '#b4b7b9',
    } as Partial<LayoutAxis>;
}

/**
 * Returns the smallest and largest y value across all graph series of an assay value.
 *
 * @param assayValue - Assay value whose `graph.data` series are inspected.
 * @returns `[min, max]` as computed by `arrayMinMax` over the per-series extremes, or
 *   `[+Infinity, -Infinity]` when the value has no graph.
 */
export function getAssayPlotYMinMaxValue(assayValue: AssayValueCreate): [number, number] {
    const { graph } = assayValue;
    if (graph) {
        // Collect each series' [min, max] pair, then reduce the pairs to one overall pair.
        const seriesExtremes = graph.data.reduce<number[]>(
            (collected, series) => collected.concat(arrayMinMax(series.y)),
            []
        );
        return arrayMinMax(seriesExtremes);
    }
    return [Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY];
}
