import { type MainViewCameraControlsRef, type ICameraControls, updateRaycaster } from '../ModelViewer';
import {
    type AdjacencyMatrix,
    buildMeshAdjacency,
    ensureMeshIndex,
    getNeighbors,
    AttributeName,
} from '@orthly/forceps';
import { Jaw } from '@orthly/shared-types';
import type React from 'react';
import * as THREE from 'three';
import type { MeshBVH } from 'three-mesh-bvh';

/**
 * Last recorded camera pose for each review panel, keyed by panel type.
 * A `null` entry means no pose has been recorded for that panel yet.
 */
export type ScanReviewViewState = Record<ScanReviewPanelType, ScanReviewPanelCameraState | null>;

/** Identifies a scan-review view mode. */
export enum ScanReviewViewType {
    Single = 'single',
    SideBySide = 'side-by-side',
    Complete = 'complete',
}

/**
 * Identifies one scan-review panel. Used as the key of `ScanReviewViewState` and to select a partial scene in
 * `ScanReviewCompositeScene.getPartialSceneForPanelType`.
 */
export enum ScanReviewPanelType {
    Upper = 'upper',
    Lower = 'lower',
    Left = 'left',
    Right = 'right',
    Front = 'front',
}

/**
 * How the jaw meshes are rendered.
 * NOTE(review): these values presumably map onto `ScanReviewScene.setScanDisplay` / `setStoneModelDisplay` /
 * `setHeatMapDisplay`. That mapping is not enforced here; confirm with callers.
 */
export enum ScanReviewDisplayType {
    Scan = 'scan',
    StoneModel = 'stone-model',
    BiteAnalysis = 'bite-analysis',
}

/**
 * Produces a set of upper/lower jaw records. `ScanReviewCompositeScene` invokes it to populate its panel scenes.
 */
export interface ScanReviewRecordsFactory {
    (): ScanReviewRecords;
}

/** Upper and lower jaw records; either one may be `null`. */
export interface ScanReviewRecords {
    lowerJaw: ScanReviewRecord | null;
    upperJaw: ScanReviewRecord | null;
}

/** Produces a single jaw record, or `null`. */
export interface ScanReviewRecordFactory {
    (): ScanReviewRecord | null;
}

/**
 * A single jaw's scan mesh bundled with the materials it can be rendered with.
 *
 * The `set*Display` methods swap `scanMesh.material`. Optional materials that were not supplied
 * (stone model, heat map) leave the mesh's current material untouched.
 */
export class ScanReviewRecord {
    constructor(
        public readonly scanMesh: THREE.Mesh<THREE.BufferGeometry, THREE.Material>,
        public readonly scanMeshTexture: THREE.Texture,
        public readonly scanMeshMaterial: THREE.Material,
        public readonly scanMeshStoneMaterial?: THREE.Material,
        public readonly scanMeshHeatMapMaterial?: THREE.Material,
        public readonly updateHeatMapRange?: (vMin: number, vMax: number) => void,
    ) {}

    /** Shows (`true`) or hides (`false`) the scan mesh. */
    setVisible(visible: boolean) {
        const mesh = this.scanMesh;
        mesh.visible = visible;
    }

    /** Renders the mesh with its regular scan material. */
    setScanDisplay() {
        const { scanMesh, scanMeshMaterial } = this;
        scanMesh.material = scanMeshMaterial;
    }

    /** Renders the mesh with the stone-model material; no-op when none was provided. */
    setStoneModelDisplay() {
        const stoneMaterial = this.scanMeshStoneMaterial;
        if (!stoneMaterial) {
            return;
        }
        this.scanMesh.material = stoneMaterial;
    }

    /** Renders the mesh with the heat-map material; no-op when none was provided. */
    setHeatMapDisplay() {
        const heatMapMaterial = this.scanMeshHeatMapMaterial;
        if (!heatMapMaterial) {
            return;
        }
        this.scanMesh.material = heatMapMaterial;
    }
}

/**
 * Common operations of a scan-review scene: toggling per-jaw visibility and switching the material used to
 * display the jaw meshes. Implemented by both `ScanReviewPartialScene` (one THREE.Scene) and
 * `ScanReviewCompositeScene` (fans each call out to several partial scenes).
 */
export interface ScanReviewScene {
    /** Shows or hides the upper jaw mesh(es). */
    setUpperJawVisibility(visible: boolean): void;
    /** Shows or hides the lower jaw mesh(es). */
    setLowerJawVisibility(visible: boolean): void;
    /** Switches jaw meshes to their stone-model material (where one is available). */
    setStoneModelDisplay(): void;
    /** Switches jaw meshes to their regular scan material. */
    setScanDisplay(): void;
    /** Switches jaw meshes to their heat-map material (where one is available). */
    setHeatMapDisplay(): void;
    /** Forwards a new `[min, max]` range to each jaw record's heat-map range updater, if it has one. */
    updateHeatMapRange(min: number, max: number): void;
}

/**
 * Snapshot of a panel camera's pose: position, rotation, up vector and zoom factor.
 * Stored per panel in `ScanReviewViewState` so the pose can be restored later.
 */
export class ScanReviewPanelCameraState {
    public position: THREE.Vector3;
    public rotation: THREE.Euler;
    public up: THREE.Vector3;
    public zoom: number;

    constructor(position: THREE.Vector3, rotation: THREE.Euler, up: THREE.Vector3, zoom: number) {
        this.position = position;
        this.rotation = rotation;
        this.up = up;
        this.zoom = zoom;
    }
}

/**
 * Holds the canvas/camera/controls refs of a scan-review viewer. It also keeps `viewState` up to date with the
 * latest camera pose of each panel, so a panel's pose can be restored when its controls are (re)initialized.
 */
export class ScanReviewViewManager {
    constructor(
        public canvasRef: React.MutableRefObject<HTMLCanvasElement | null>,
        public cameraRef: React.MutableRefObject<THREE.OrthographicCamera | null>,
        public cameraControlsRef: MainViewCameraControlsRef,
        public viewState: ScanReviewViewState,
    ) {}

    /**
     * Makes `controls` the active camera controls for `panelType`. The method:
     * 1. stores `controls` in `cameraControlsRef`;
     * 2. if `viewState[panelType]` holds a saved pose, resets the controls and applies that pose to the camera;
     * 3. subscribes to `change`/`start`/`end`/`zoomChange` so that every later camera change is written back into
     *    `viewState[panelType]`.
     *
     * NOTE(review): listeners are never removed, so calling this more than once for the same controls instance
     * registers duplicate listeners.
     */
    initializeViewState(controls: ICameraControls, panelType: ScanReviewPanelType) {
        // Defensive guard: the parameter is typed non-nullable, but a falsy value is still tolerated at runtime.
        if (!controls) {
            return;
        }
        this.cameraControlsRef.current = controls;

        const savedState = this.viewState[panelType];
        if (savedState !== null) {
            controls.reset();
            const camera = controls.object;
            camera.position.copy(savedState.position);
            camera.rotation.copy(savedState.rotation);
            camera.up.copy(savedState.up);
            camera.zoom = savedState.zoom;
            // Zoom only takes effect once the projection matrix is rebuilt.
            camera.updateProjectionMatrix();
            controls.update();
        }

        ['change', 'start', 'end', 'zoomChange'].forEach(eventType => {
            controls.addEventListener(eventType, event => {
                // Snapshot (clone) the pose so later camera mutations don't alias the stored state.
                const { object: camera } = event.target as ICameraControls;
                this.viewState[panelType] = new ScanReviewPanelCameraState(
                    camera.position.clone(),
                    camera.rotation.clone(),
                    camera.up.clone(),
                    camera.zoom,
                );
            });
        });
    }

    /** Current canvas element, or `null` if not mounted. */
    get canvas() {
        return this.canvasRef.current;
    }

    /** Current orthographic camera, or `null` if not created. */
    get camera() {
        return this.cameraRef.current;
    }

    /** Currently registered camera controls. */
    get controls() {
        return this.cameraControlsRef.current;
    }
}

/**
 * One THREE.Scene holding the scan meshes of up to two jaws. Visibility and display-mode calls are forwarded to
 * whichever jaw records are present; a `null` jaw is silently skipped.
 */
export class ScanReviewPartialScene implements ScanReviewScene {
    scene: THREE.Scene;

    constructor(
        public upperJaw: ScanReviewRecord | null,
        public lowerJaw: ScanReviewRecord | null,
    ) {
        const scene = new THREE.Scene();
        // Upper jaw is added before lower jaw.
        for (const jaw of [upperJaw, lowerJaw]) {
            if (jaw) {
                scene.add(jaw.scanMesh);
            }
        }
        this.scene = scene;
    }

    /** Shows or hides the upper jaw mesh, if present. */
    setUpperJawVisibility(visible: boolean) {
        if (this.upperJaw) {
            this.upperJaw.setVisible(visible);
        }
    }

    /** Shows or hides the lower jaw mesh, if present. */
    setLowerJawVisibility(visible: boolean) {
        if (this.lowerJaw) {
            this.lowerJaw.setVisible(visible);
        }
    }

    /** Switches every present jaw to its stone-model material (upper first). */
    setStoneModelDisplay() {
        for (const jaw of [this.upperJaw, this.lowerJaw]) {
            jaw?.setStoneModelDisplay();
        }
    }

    /** Switches every present jaw to its regular scan material (upper first). */
    setScanDisplay(): void {
        for (const jaw of [this.upperJaw, this.lowerJaw]) {
            jaw?.setScanDisplay();
        }
    }

    /** Switches every present jaw to its heat-map material (upper first). */
    setHeatMapDisplay() {
        for (const jaw of [this.upperJaw, this.lowerJaw]) {
            jaw?.setHeatMapDisplay();
        }
    }

    /** Forwards the range to each present jaw that has a heat-map range updater. */
    updateHeatMapRange(min: number, max: number) {
        for (const jaw of [this.upperJaw, this.lowerJaw]) {
            // Called as a method so the updater keeps `jaw` as its `this`.
            jaw?.updateHeatMapRange?.(min, max);
        }
    }
}

/**
 * Scene set for the multi-panel scan-review layout: one `ScanReviewPartialScene` per panel type.
 * Every `ScanReviewScene` operation is fanned out to all partial scenes.
 */
export class ScanReviewCompositeScene implements ScanReviewScene {
    public upperView: ScanReviewPartialScene;
    public lowerView: ScanReviewPartialScene;
    public leftView: ScanReviewPartialScene;
    public rightView: ScanReviewPartialScene;
    public frontView: ScanReviewPartialScene;

    /**
     * @param recordsFactory Invoked exactly once per panel (five times in total). Each panel needs its own meshes:
     *   a three.js object can only have one parent, so adding the same mesh to a second scene removes it from the
     *   first. NOTE(review): this assumes the factory returns freshly created records on each call; confirm.
     */
    constructor(recordsFactory: ScanReviewRecordsFactory) {
        // Single-jaw panels only use one of the records the factory produces.
        const upperRecords = recordsFactory();
        this.upperView = new ScanReviewPartialScene(upperRecords.upperJaw, null);
        const lowerRecords = recordsFactory();
        this.lowerView = new ScanReviewPartialScene(null, lowerRecords.lowerJaw);

        // Two-jaw panels take both jaws from a single factory call, so no record set is built and then thrown away.
        const createBothJawsView = (): ScanReviewPartialScene => {
            const { upperJaw, lowerJaw } = recordsFactory();
            return new ScanReviewPartialScene(upperJaw, lowerJaw);
        };
        this.leftView = createBothJawsView();
        this.rightView = createBothJawsView();
        this.frontView = createBothJawsView();
    }

    /** Yields every partial scene in a fixed order: upper, lower, left, right, front. */
    private *partialScenes() {
        for (const partialScene of [this.upperView, this.lowerView, this.leftView, this.rightView, this.frontView]) {
            yield partialScene;
        }
    }

    /** Returns the partial scene rendered in the panel identified by `panelType`. */
    getPartialSceneForPanelType(panelType: ScanReviewPanelType): ScanReviewPartialScene {
        switch (panelType) {
            case ScanReviewPanelType.Upper: {
                return this.upperView;
            }
            case ScanReviewPanelType.Lower: {
                return this.lowerView;
            }
            case ScanReviewPanelType.Left: {
                return this.leftView;
            }
            case ScanReviewPanelType.Right: {
                return this.rightView;
            }
            case ScanReviewPanelType.Front: {
                return this.frontView;
            }
        }
    }

    /** Shows or hides the upper jaw in every partial scene. */
    setUpperJawVisibility(visible: boolean): void {
        for (const partialScene of this.partialScenes()) {
            partialScene.setUpperJawVisibility(visible);
        }
    }
    /** Shows or hides the lower jaw in every partial scene. */
    setLowerJawVisibility(visible: boolean): void {
        for (const partialScene of this.partialScenes()) {
            partialScene.setLowerJawVisibility(visible);
        }
    }
    /** Switches every partial scene to the stone-model material. */
    setStoneModelDisplay() {
        for (const partialScene of this.partialScenes()) {
            partialScene.setStoneModelDisplay();
        }
    }
    /** Switches every partial scene to the regular scan material. */
    setScanDisplay() {
        for (const partialScene of this.partialScenes()) {
            partialScene.setScanDisplay();
        }
    }
    /** Switches every partial scene to the heat-map material. */
    setHeatMapDisplay() {
        for (const partialScene of this.partialScenes()) {
            partialScene.setHeatMapDisplay();
        }
    }
    /** Forwards the heat-map range to every partial scene. */
    updateHeatMapRange(min: number, max: number): void {
        for (const partialScene of this.partialScenes()) {
            partialScene.updateHeatMapRange(min, max);
        }
    }
}

/**
 * Result of a shade pick: the averaged color around the picked point and the picked point itself.
 */
export interface ScanReviewShadePick {
    /**
     * RGB value with each component in range of 0-255, ordered [red, green, blue].
     */
    color: [number, number, number];
    /**
     * The center of the sampling area on the surface of the mesh.
     */
    center: THREE.Vector3;
}

/**
 * Precomputed lookup structures for one jaw record's scan geometry: a BVH for ray casting and a
 * face adjacency matrix for neighborhood queries.
 */
export class ScanReviewRecordAccelerationData {
    scanRecord: ScanReviewRecord;
    adjacencyMatrix: AdjacencyMatrix;
    bvhIndex: MeshBVH;

    constructor(record: ScanReviewRecord) {
        this.scanRecord = record;
        const { geometry } = record.scanMesh;
        // Order matters: building the BVH via ensureMeshIndex may reorder the geometry's index attribute,
        // which would invalidate any face adjacency computed beforehand. So index first, adjacency second.
        this.bvhIndex = ensureMeshIndex(geometry);
        this.adjacencyMatrix = buildMeshAdjacency(geometry);
    }
}

/**
 * Picks an averaged shade from a jaw's vertex colors around the point under the mouse.
 *
 * Typical flow: select a jaw with `setCurrentJawType`, feed mouse events to `respondToMouseEvent` (which
 * updates the internal ray), then call `pickShadeFromVertexColors`.
 */
export class ScanReviewShadeMatchingPicker {
    private readonly lowerJawAccelerationData: ScanReviewRecordAccelerationData | null;
    private readonly upperJawAccelerationData: ScanReviewRecordAccelerationData | null;
    /** Acceleration data of the jaw currently being picked from; `null` disables picking. */
    private jawAccelerationData: ScanReviewRecordAccelerationData | null = null;

    /** Reused ray caster: its ray is refreshed by `respondToMouseEvent` and read by `pickShadeFromVertexColors`. */
    private readonly rayCaster: THREE.Raycaster = new THREE.Raycaster();

    /** Eagerly builds acceleration data (BVH + adjacency) for each jaw present in `scene`. */
    constructor(
        public scene: ScanReviewPartialScene,
        public viewManager: ScanReviewViewManager,
    ) {
        this.lowerJawAccelerationData = scene.lowerJaw ? new ScanReviewRecordAccelerationData(scene.lowerJaw) : null;
        this.upperJawAccelerationData = scene.upperJaw ? new ScanReviewRecordAccelerationData(scene.upperJaw) : null;
    }

    /**
     * Shows only the selected jaw and makes it the pick target. A falsy `jawType` hides both jaws and disables
     * picking; any value other than `Jaw.UPPER` selects the lower jaw.
     */
    setCurrentJawType(jawType: Jaw | null) {
        if (!jawType) {
            this.scene.setUpperJawVisibility(false);
            this.scene.setLowerJawVisibility(false);
            this.jawAccelerationData = null;
            return;
        }
        if (jawType === Jaw.UPPER) {
            this.scene.setUpperJawVisibility(true);
            this.scene.setLowerJawVisibility(false);
            this.jawAccelerationData = this.upperJawAccelerationData;
        } else {
            this.scene.setUpperJawVisibility(false);
            this.scene.setLowerJawVisibility(true);
            this.jawAccelerationData = this.lowerJawAccelerationData;
        }
    }

    /** Points the internal ray at the mouse position; a no-op until both canvas and camera are available. */
    respondToMouseEvent(evt: MouseEvent) {
        if (!this.viewManager.canvas || !this.viewManager.camera) {
            return;
        }
        updateRaycaster(this.rayCaster, this.viewManager.canvas, this.viewManager.camera, evt);
    }

    /**
     * Averages the vertex colors within `maxRadiusMm` of the point hit by the current ray.
     *
     * @param maxRadiusMm Sampling radius passed to `getNeighbors`; a falsy value (0 or NaN) falls back to 1.
     * @returns The averaged color (each channel clamped to 0-255) and the hit point. Returns `null` when no jaw is
     *   selected, the ray misses the mesh, or the sampling region contains no vertices.
     */
    pickShadeFromVertexColors(maxRadiusMm: number): ScanReviewShadePick | null {
        const accelerationData = this.jawAccelerationData;
        if (!accelerationData) {
            return null;
        }

        // NOTE(review): three-mesh-bvh expects the ray in the BVH's (geometry-local) frame. The ray is passed here
        // without a world-to-local transform, which is only correct if the scan mesh has an identity world matrix.
        // Confirm.
        const intersection = accelerationData.bvhIndex.raycastFirst(this.rayCaster.ray, THREE.FrontSide);

        // we did not click on the mesh
        if (!intersection || !intersection.face) {
            return null;
        }

        const { adjacencyMatrix, scanRecord } = accelerationData;
        const geometry = scanRecord.scanMesh.geometry;
        const neighbors = getNeighbors({
            adjacencyMatrix,
            // The sampling region is grown from the first vertex of the hit triangle.
            mainHandle: intersection.face.a,
            // `||` rather than `??` on purpose: 0 and NaN also fall back to 1.
            maxRadiusMm: maxRadiusMm || 1,
            geometry,
        });

        // With no sampled vertices the average below would be 0 / 0 = NaN; report "no pick" instead.
        if (neighbors.length === 0) {
            return null;
        }

        // Look the color attribute up once rather than three times per vertex.
        // A geometry without colors reads every channel as 0 (black).
        const colorAttribute = geometry.getAttribute(AttributeName.Color);

        // Find the sum of each of the red, green, and blue channels.
        // These can then be averaged to find the average color within the selected region.
        // We don't do any multiplication or floor until the end to avoid floating point math errors.
        // We intentionally square the r, g, and b values before adding them to the sum.
        // This is to approximately reverse the compression done to store rgb.
        // For more information, see: https://graphicdesign.stackexchange.com/questions/113884/calculating-average-of-two-rgb-values
        const rgbSums = neighbors.reduce<[number, number, number]>(
            (sums, vert) => {
                const r = colorAttribute?.getX(vert) ?? 0;
                const g = colorAttribute?.getY(vert) ?? 0;
                const b = colorAttribute?.getZ(vert) ?? 0;

                return [sums[0] + r * r, sums[1] + g * g, sums[2] + b * b];
            },
            [0, 0, 0],
        );

        // Root-mean-square of one channel, scaled to 0-255 and clamped to the documented range.
        const toChannel = (sumOfSquares: number): number =>
            Math.min(255, Math.max(0, Math.floor(Math.sqrt(sumOfSquares / neighbors.length) * 255)));
        const rgb: [number, number, number] = [toChannel(rgbSums[0]), toChannel(rgbSums[1]), toChannel(rgbSums[2])];

        return {
            color: rgb,
            center: intersection.point,
        };
    }
}
