import type { QcHeatmapRange } from '../ColorRamp';
import { TriosMaterialFragmentShader } from '../ModelViewer/materials/scanMeshShaderMaterialChairside';
import { createScanMeshHeatmapMaterial, createScanMeshStoneMaterial } from './ScanReview.utils';
import { ScanReviewRecord, type ScanReviewRecordFactory } from './ScanReviewTypes';
import { type DcmGeometryInjector, AttributeName } from '@orthly/forceps';
import type { Jaw } from '@orthly/shared-types';
import type React from 'react';
import * as THREE from 'three';

/**
 * Step-by-step builder for the three.js objects that make up a `ScanReviewRecord` for one DCM scan.
 *
 * The constructor creates the texture, the textured material and the mesh. The optional `build*`
 * methods add a stone material, a heatmap material and baked per-vertex colors; each returns the
 * builder so calls can be chained. `complete()` packages everything into a `ScanReviewRecord`.
 */
export class ScanReviewDcmBuilder {
    private readonly scanImage: HTMLImageElement;
    private readonly scanGeometry: THREE.BufferGeometry;
    private readonly scanMesh: THREE.Mesh<THREE.BufferGeometry, THREE.Material>;
    private readonly scanMeshTexture: THREE.Texture;
    private readonly scanMeshMaterial: THREE.Material;
    private scanMeshStoneMaterial?: THREE.Material;
    private scanMeshHeatMapMaterial?: THREE.Material;
    private updateHeatmapRange?: (vMin: number, vMax: number) => void;

    /**
     * @param scanGeometry Geometry of the scan. It is used directly (not cloned) for the mesh, and
     * `buildDcmVertexColors` adds a color attribute to it.
     * @param scanImage Texture image of the scan. It is read synchronously by `buildDcmVertexColors`,
     * so it should already be loaded.
     */
    constructor(scanGeometry: THREE.BufferGeometry, scanImage: HTMLImageElement) {
        this.scanGeometry = scanGeometry;
        this.scanImage = scanImage;

        this.scanMeshTexture = new THREE.Texture(this.scanImage);
        this.scanMeshTexture.flipY = false;
        this.scanMeshTexture.needsUpdate = true;

        // This permutation matrix swaps the first and third color channels (r <-> b) to reverse the
        // order of the rgb components of the texture packed into the DCM file.
        const n1 = [0, 0, 1, 0];
        const n2 = [0, 1, 0, 0];
        const n3 = [1, 0, 0, 0];
        const n4 = [0, 0, 0, 1];
        const bgrSwap = new THREE.Matrix4().fromArray([...n1, ...n2, ...n3, ...n4]);

        this.scanMeshMaterial = new THREE.MeshPhongMaterial({
            map: this.scanMeshTexture,
            side: THREE.DoubleSide,
        });
        this.scanMeshMaterial.defines = {
            DANDY_COLOR_TRANSFORM: 1,
        };

        // Replace the stock Phong fragment shader with the custom one and provide the uniforms set here.
        this.scanMeshMaterial.onBeforeCompile = shader => {
            shader.uniforms['vertexColorLerp'] = { value: 1.0 };
            shader.uniforms['ambientStrength'] = { value: 0.1 };
            shader.uniforms['specularStrength'] = { value: 0.352 };
            shader.uniforms['dandyColorTransform'] = { value: bgrSwap };
            shader.fragmentShader = TriosMaterialFragmentShader;
        };

        this.scanMesh = new THREE.Mesh<THREE.BufferGeometry, THREE.Material>(this.scanGeometry, this.scanMeshMaterial);
    }

    /**
     * Creates the stone material that will be handed to the record.
     *
     * @returns This builder, for chaining.
     */
    buildDcmScanStoneMaterial(): ScanReviewDcmBuilder {
        this.scanMeshStoneMaterial = createScanMeshStoneMaterial();
        return this;
    }

    /**
     * Creates the heatmap material and keeps the callback that updates its value range.
     *
     * @param heatmapRange Range passed to `createScanMeshHeatmapMaterial`.
     * @returns This builder, for chaining.
     */
    buildDcmScanHeatmapMaterial(heatmapRange: QcHeatmapRange): ScanReviewDcmBuilder {
        const { scanMeshHeatMapMaterial, updateHeatmapRange } = createScanMeshHeatmapMaterial(heatmapRange);
        this.scanMeshHeatMapMaterial = scanMeshHeatMapMaterial;
        this.updateHeatmapRange = updateHeatmapRange;
        return this;
    }

    /**
     * Bakes the scan texture into a per-vertex color attribute on the scan geometry.
     *
     * For every vertex, the texel addressed by its uv coordinate is looked up (nearest texel, no
     * filtering) and stored as an rgb triple in the range [0, 1]. If the image has no pixels, the
     * geometry has no texture coordinates, or no 2d canvas context is available, a warning is logged
     * and the geometry is left untouched.
     *
     * @returns This builder, for chaining.
     */
    buildDcmVertexColors(): ScanReviewDcmBuilder {
        const { width, height } = this.scanImage;
        if (width <= 0 || height <= 0) {
            console.warn('Warning, the DCM texture image has no pixels, skipping baking DCM vertex colors.');
            return this;
        }

        const uvAttribute = this.scanGeometry.getAttribute(AttributeName.TexCoord);
        if (!uvAttribute) {
            console.warn('Warning, the DCM geometry has no texture coordinates, skipping baking DCM vertex colors.');
            return this;
        }

        // Draw the image onto an offscreen canvas so that its raw pixel bytes can be read back.
        const canvas = document.createElement('canvas');
        canvas.width = width;
        canvas.height = height;
        const ctx = canvas.getContext('2d');
        if (!ctx) {
            console.warn('Warning, could not obtain html canvas context for baking DCM vertex colors.');
            return this;
        }
        ctx.drawImage(this.scanImage, 0, 0, width, height);
        const pixels = ctx.getImageData(0, 0, width, height).data;

        // Each vertex is "mapped" to a corresponding point on the texture via uv coordinates.
        // In this scheme, u and v are both numbers between [0, 1], which refer to an x, y position in the texture image.
        // Since there are two values per vertex, we can find the vertex count by dividing by 2.
        // We will create a color map of an r, g, and b value for each vertex, hence needing vertexCount * 3.
        const uvs = uvAttribute.array;
        const vertexCount = Math.floor(uvs.length / 2);
        const colorMap = new Float32Array(vertexCount * 3);

        // Calculate the vertex color for each vertex.
        // vertexCount is generally about 100,000-200,000 for most of the scans we get.
        // This would be another good candidate for optimization if we run into speed issues, though in testing it ran pretty fast.
        for (let i = 0; i < vertexCount; i++) {
            const u = uvs[2 * i] ?? 0;
            const v = uvs[2 * i + 1] ?? 0;

            // A NaN or infinite coordinate cannot be mapped to a texel; leave that vertex black.
            if (!Number.isFinite(u) || !Number.isFinite(v)) {
                continue;
            }

            // u = 0 implies the left side of the image, u = 1 implies the right side.
            // v = 0 implies the top of the image, v = 1 implies the bottom edge.
            // Coordinates on the far edge (exactly 1) or outside [0, 1] are clamped to the nearest edge
            // texel, which is also how the texture itself is sampled with three.js' default clamp-to-edge
            // wrapping. Skipping them instead would leave those vertices black.
            const xPos = ScanReviewDcmBuilder.clampToPixel(width * u, width);
            const yPos = ScanReviewDcmBuilder.clampToPixel(height * v, height);
            const idx = (width * yPos + xPos) * 4;

            // The color channels are packed in reverse (b, g, r) order, so the red value is read from
            // the third byte and the blue value from the first.
            colorMap[i * 3] = (pixels[idx + 2] ?? 0) / 255;
            colorMap[i * 3 + 1] = (pixels[idx + 1] ?? 0) / 255;
            colorMap[i * 3 + 2] = (pixels[idx] ?? 0) / 255;
        }
        this.scanGeometry.setAttribute(AttributeName.Color, new THREE.Float32BufferAttribute(colorMap, 3));

        return this;
    }

    /**
     * Packages everything that has been built so far into a `ScanReviewRecord`.
     *
     * The mesh, texture and textured material are always created by the constructor, so a record is
     * always returned; the `undefined` in the return type is kept for compatibility with callers.
     * The stone material, heatmap material and heatmap range updater are `undefined` unless the
     * corresponding `build*` method was called.
     */
    complete(): ScanReviewRecord | undefined {
        return new ScanReviewRecord(
            this.scanMesh,
            this.scanMeshTexture,
            this.scanMeshMaterial,
            this.scanMeshStoneMaterial,
            this.scanMeshHeatMapMaterial,
            this.updateHeatmapRange,
        );
    }

    /** Floors a scaled texture coordinate and clamps it to a valid pixel index in `[0, size - 1]`. */
    private static clampToPixel(scaled: number, size: number): number {
        return Math.min(Math.max(Math.floor(scaled), 0), size - 1);
    }
}

/**
 * A DCM scan file together with the jaw it belongs to.
 */
export interface ScanReviewDcmFile {
    /** Name associated with this DCM file (presumably the file name; confirm against the callers). */
    name: string;
    /** Injector used by this module to build the scan geometry and to parse the file's texture images. */
    injector: DcmGeometryInjector;
    /** Jaw this file belongs to; `extractScanReviewDcmFileData` picks a file by comparing this value. */
    jawType: Jaw;
}

/**
 * Geometry and texture payload extracted from one DCM file. This is the shape returned by
 * `extractScanReviewDcmFileData` and accepted by `loadScanReviewDcmFileData`.
 */
export interface ScanReviewDcmFileData {
    /** The DCM file the data was extracted from. */
    dcmFile: ScanReviewDcmFile;
    /** Scan geometry built from the file's injector. `loadScanReviewDcmFileData` does nothing when this is null. */
    geometry: THREE.BufferGeometry | null;
    /**
     * Base64 payload (`b64Data`) of the file's first texture image, or null when the file has no
     * texture image. `loadScanReviewDcmFileData` embeds it in a `data:image/jpeg;base64,` URL and
     * does nothing when it is null or empty.
     */
    textureData: string | null;
}

/**
 * Builds the scan geometry from a DCM injector, with texture coordinates applied, and computes
 * its vertex normals before handing it back.
 *
 * @param injector Injector of the DCM file to build the geometry from.
 * @returns The geometry returned by the injector, with vertex normals computed on it.
 */
function getGeometryFromDcm(injector: DcmGeometryInjector): THREE.BufferGeometry {
    const dcmGeometry = injector.buildGeometry({ applyTextureCoords: true });
    dcmGeometry.computeVertexNormals();
    return dcmGeometry;
}

/**
 * Finds the first DCM file for the requested jaw and extracts its geometry and texture payload.
 *
 * @param jawType Jaw to look for.
 * @param dcmFiles Candidate DCM files; the first one whose `jawType` matches is used.
 * @returns The matching file, its geometry and the base64 data of its first texture image
 * (null when the file has no texture image), or null when no file matches the jaw.
 */
export function extractScanReviewDcmFileData(
    jawType: Jaw,
    dcmFiles: ScanReviewDcmFile[],
): ScanReviewDcmFileData | null {
    const matchingFile = dcmFiles.find(candidate => candidate.jawType === jawType);
    if (!matchingFile) {
        return null;
    }

    const dcmInjector = matchingFile.injector;
    const geometry = getGeometryFromDcm(dcmInjector);

    // Only the first texture image of the file is used.
    const firstTexture = dcmInjector.parseTextureImages()[0] ?? null;
    let textureData: string | null = null;
    if (firstTexture) {
        textureData = firstTexture.b64Data;
    }

    return { dcmFile: matchingFile, geometry, textureData };
}

/**
 * Decodes the texture of an extracted DCM file and, once it has loaded, publishes a scan review
 * record factory through `factorySetter`.
 *
 * Nothing happens when `jawData` has no geometry or no texture data. Loading is asynchronous:
 * the function returns immediately and `factorySetter` is called later from the image's load
 * handler. If the image cannot be loaded, a warning is logged and `factorySetter` is not called.
 *
 * @param jawData Extracted geometry and base64 texture data of one DCM file.
 * @param onLoadCallback Creates the factory from a clone of the geometry and the loaded image.
 * @param factorySetter State setter that receives the factory.
 */
export function loadScanReviewDcmFileData(
    jawData: ScanReviewDcmFileData,
    onLoadCallback: (geometry: THREE.BufferGeometry, image: HTMLImageElement) => ScanReviewRecordFactory,
    factorySetter: (value: React.SetStateAction<ScanReviewRecordFactory | null>) => void,
): void {
    if (!jawData.geometry || !jawData.textureData) {
        return;
    }
    const scanImage = new Image();
    // Work on a clone so that the geometry held by jawData is not modified by whatever the factory does with it.
    const scanGeometry = jawData.geometry.clone();
    scanImage.onload = () => {
        // The setter is given a function that returns the factory (the updater form of a state
        // action), so the factory itself is what ends up stored.
        factorySetter(() => onLoadCallback(scanGeometry, scanImage));
    };
    scanImage.onerror = () => {
        // Without this handler a corrupt or undecodable texture would fail silently.
        console.warn('Warning, could not load the DCM texture image, no scan review record factory was created.');
    };
    // Handlers are attached before src is assigned so that neither event can be missed.
    scanImage.src = `data:image/jpeg;base64,${jawData.textureData}`;
}
