import { ZAxis } from '../../utils';
import type { QcHeatmapRange } from '../ColorRamp';
import { DepthMapComputer } from '../GpuAccel/ComputeDepthMap';
import { applyShader } from '../ModelViewer/materials/UndercutDepthShader';
import { createOcclusalHeatmapShader } from '../PrepMaterials/OcclusalShader';
import { DandyShaderMaterial } from '../PrepMaterials/dandyShaderMaterial';
import type { ScanReviewInsertionAxis } from './ScanReviewDesignTypes';
import {
    INITIAL_SCAN_REVIEW_UNDERCUT_HEATMAP_RANGE,
    INITIAL_SCAN_REVIEW_BITE_ANALYSIS_HEATMAP_RANGE,
} from './ScanReviewViewTypes';
import { AttributeName } from '@orthly/forceps';
import * as THREE from 'three';

/**
 * Call signature for a factory that yields a {@link ScanReviewInsertionDepthMapGenerator}
 * for the given scan geometry and insertion axis direction.
 *
 * NOTE(review): the generator's constructor additionally takes a depth-map cache; implementations
 * of this factory presumably bind that cache themselves — confirm at the creation site.
 */
export interface ScanReviewInsertionDepthMapGeneratorFactory {
    (scanGeometry: THREE.BufferGeometry, insertionAxis: THREE.Vector3): ScanReviewInsertionDepthMapGenerator;
}

/**
 * Result of rendering a depth map along an insertion axis, as produced by
 * {@link ScanReviewInsertionDepthMapGenerator}.
 */
export interface ScanReviewDepthData {
    /** Clone of the depth map computer's output texture. */
    texture: THREE.Texture;
    /** Width and height of `texture.image`. */
    texSize: THREE.Vector2;
    /** Depth camera's projection matrix multiplied by its inverse world matrix. */
    axisSpaceMatrix: THREE.Matrix4;
    /** Side length of the square orthographic view used to render the depth map. */
    lateralScale: number;
    /** Distance between the depth camera's near and far planes. */
    depthScale: number;
}
// Margin added on every side of the oriented bounding box (via Box3.expandByScalar) before the
// depth camera is fitted to it. Millimetres, going by the name.
const DEPTH_MAP_BUFFER_MM = 1.0;

/**
 * Renders an orthographic depth map of a scan geometry as seen along an insertion axis.
 *
 * The result is computed once in the constructor and exposed as `depthData`. Results are memoised
 * in the supplied `depthMapCache`, keyed by the insertion axis components and the geometry uuid.
 *
 * A WebGL renderer is only created when the cache misses, and it is always released again
 * (context loss + dispose) even when the computation throws, because WebGL contexts are a limited
 * resource in browsers.
 */
export class ScanReviewInsertionDepthMapGenerator {
    /** Bounding box of the geometry expressed in the insertion-axis-aligned frame. */
    private readonly box = new THREE.Box3();
    /** Rotation taking the insertion axis onto the negative Z axis. */
    private readonly orientation = new THREE.Quaternion();
    public readonly depthData: ScanReviewDepthData;

    /**
     * @param geometry Scan geometry to render; its position attribute is read.
     * @param insertionAxis Direction of insertion; passed to `Quaternion.setFromUnitVectors`.
     * @param depthMapCache Cache consulted before rendering and filled after rendering.
     */
    constructor(
        private geometry: THREE.BufferGeometry,
        private insertionAxis: THREE.Vector3,
        private depthMapCache: Map<string, ScanReviewDepthData>,
    ) {
        this.depthData = this.generate();
    }

    /**
     * Returns the depth data for the current geometry and insertion axis, rendering it only when
     * the cache has no entry for them.
     *
     * @returns The cached or freshly rendered depth data.
     */
    generate(): ScanReviewDepthData {
        const directionKey = `[${this.insertionAxis.x},${this.insertionAxis.y},${this.insertionAxis.z}]`;
        const geometryKey = `[${this.geometry.uuid}]`;
        const cacheKey = `${directionKey}${geometryKey}`;
        const cachedDepthData = this.depthMapCache.get(cacheKey);
        if (cachedDepthData) {
            // Nothing to render, so no WebGL context is needed at all.
            return cachedDepthData;
        }

        this.orientation.setFromUnitVectors(this.insertionAxis, ZAxis.clone().negate());

        // Start from an empty box so a repeated call does not accumulate earlier extents.
        this.box.makeEmpty();
        this.updateOrientedBoundingBox();
        this.makeSquareXY();
        this.expandBoxZ();
        this.box.expandByScalar(DEPTH_MAP_BUFFER_MM);

        // Created as late as possible so that failures above cannot leak a context.
        const renderer = new THREE.WebGLRenderer({ antialias: false });
        try {
            const depthData = this.computeDepthMap(renderer, camera => this.setCameraParamsFromOrientedBox(camera));
            this.depthMapCache.set(cacheKey, depthData);
            return depthData;
        } finally {
            // Release the context whether or not rendering succeeded.
            renderer.forceContextLoss();
            renderer.dispose();
        }
    }

    /** Grows `box` to contain every vertex after rotating it by `orientation`. */
    private updateOrientedBoundingBox() {
        const posAttr = this.geometry.getAttribute(AttributeName.Position);
        const vec = new THREE.Vector3();
        for (let vIdx = 0; vIdx < posAttr.count; vIdx += 1) {
            vec.fromBufferAttribute(posAttr, vIdx).applyQuaternion(this.orientation);
            this.box.expandByPoint(vec);
        }
    }

    /** Symmetrically widens the shorter of the box's X/Y extents so that both are equal. */
    private makeSquareXY() {
        const w = this.box.max.x - this.box.min.x;
        const h = this.box.max.y - this.box.min.y;
        const halfDiff = (w - h) / 2;
        if (halfDiff > 0) {
            // Wider than tall: grow Y by the difference.
            this.box.max.y += halfDiff;
            this.box.min.y -= halfDiff;
        } else {
            // Taller than wide (halfDiff <= 0): subtracting a non-positive value grows X.
            this.box.max.x -= halfDiff;
            this.box.min.x += halfDiff;
        }
    }

    /**
     * Extends the box's Z range to cover every rotated vertex that falls inside its XY footprint.
     *
     * NOTE(review): when the box was built by `updateOrientedBoundingBox` it already contains
     * every vertex, so this pass leaves the Z range unchanged in that flow.
     */
    private expandBoxZ() {
        const posAttr = this.geometry.getAttribute(AttributeName.Position);
        const vec = new THREE.Vector3();
        for (let vIdx = 0; vIdx < posAttr.count; vIdx += 1) {
            vec.fromBufferAttribute(posAttr, vIdx).applyQuaternion(this.orientation);
            if (vec.x < this.box.min.x || this.box.max.x < vec.x || vec.y < this.box.min.y || this.box.max.y < vec.y) {
                continue;
            }
            this.box.min.z = Math.min(this.box.min.z, vec.z);
            this.box.max.z = Math.max(this.box.max.z, vec.z);
        }
    }

    /**
     * Positions and sizes the orthographic camera so its view volume matches the oriented box.
     *
     * @param camera Camera to configure in place.
     * @returns The side length of the square view (`lateralScale`) and the near-to-far distance
     * (`depthScale`).
     */
    private setCameraParamsFromOrientedBox(camera: THREE.OrthographicCamera) {
        // The camera uses the inverse of the rotation that was applied to the vertices.
        camera.quaternion.copy(this.orientation).invert();
        // Set camera position to center of max z face
        this.box.getCenter(camera.position).setZ(this.box.max.z).applyQuaternion(camera.quaternion);
        const size = Math.max(this.box.max.x - this.box.min.x, this.box.max.y - this.box.min.y);
        camera.left = -size / 2;
        camera.right = size / 2;
        camera.bottom = -size / 2;
        camera.top = size / 2;
        camera.near = 0;
        camera.far = this.box.max.z - this.box.min.z;
        camera.updateProjectionMatrix();
        camera.updateMatrixWorld();
        return { lateralScale: size, depthScale: camera.far - camera.near };
    }

    /**
     * Runs the depth map computer for the geometry and packages its output.
     *
     * @param renderer Renderer handed to the depth map computer; owned and disposed by the caller.
     * @param setCamera Callback that configures the computer's camera and reports the scales.
     * @returns Depth data holding a clone of the computed texture and the camera's
     * projection-times-inverse-world matrix.
     */
    private computeDepthMap(
        renderer: THREE.WebGLRenderer,
        setCamera: (camera: THREE.OrthographicCamera) => { lateralScale: number; depthScale: number },
    ): ScanReviewDepthData {
        const depthMapComputer = new DepthMapComputer([new THREE.Mesh(this.geometry)], {
            renderer,
            resolution: 4096,
        });
        try {
            const { lateralScale, depthScale } = setCamera(depthMapComputer.camera);
            depthMapComputer.compute();
            // Clone so the result does not reference the computer's own texture object.
            const texture = depthMapComputer.texture.clone();
            texture.needsUpdate = true;
            const matrix = depthMapComputer.camera.projectionMatrix
                .clone()
                .multiply(depthMapComputer.camera.matrixWorldInverse);
            return {
                texture,
                axisSpaceMatrix: matrix,
                lateralScale,
                depthScale,
                texSize: new THREE.Vector2(texture.image.width, texture.image.height),
            };
        } finally {
            depthMapComputer.dispose();
        }
    }
}

/**
 * Common plumbing for heatmap material managers: holds the material, the two uniforms that carry
 * the heatmap bounds, and the range those uniforms currently represent.
 */
abstract class ScanReviewHeatmapMaterialManagerBase {
    /**
     * @param _material Material whose shader consumes the range uniforms.
     * @param _vMinUniform Uniform receiving the lower bound of the range.
     * @param _vMaxUniform Uniform receiving the upper bound of the range.
     * @param _heatMapRange Range the uniforms were initialised with.
     */
    protected constructor(
        protected readonly _material: THREE.Material,
        protected readonly _vMinUniform: THREE.Uniform,
        protected readonly _vMaxUniform: THREE.Uniform,
        protected _heatMapRange: QcHeatmapRange,
    ) {}

    /** Applies a new range to the uniforms and remembers it so the getter reflects the change. */
    set heatmapRange(newHeatmapRange: QcHeatmapRange) {
        // Keep a private copy so later mutation of the caller's object cannot desync the getter
        // from the uniform values.
        this._heatMapRange = { ...newHeatmapRange };
        this._vMinUniform.value = newHeatmapRange.min;
        this._vMaxUniform.value = newHeatmapRange.max;
        this._material.needsUpdate = true;
    }

    /** The range most recently applied (the constructor's range until the setter is used). */
    get heatmapRange() {
        return this._heatMapRange;
    }

    /** The managed material. */
    get material() {
        return this._material;
    }
}

/**
 * Heatmap material manager backed by the occlusal heatmap shader, starting at the initial
 * bite-analysis range.
 */
export class ScanReviewHeatmapMaterialManager extends ScanReviewHeatmapMaterialManagerBase {
    constructor() {
        const { min: rangeMin, max: rangeMax } = INITIAL_SCAN_REVIEW_BITE_ANALYSIS_HEATMAP_RANGE;
        const minUniform = new THREE.Uniform(rangeMin);
        const maxUniform = new THREE.Uniform(rangeMax);

        const shader = createOcclusalHeatmapShader({ showHeatmap: true });
        // Clone the shader's stock uniforms, then layer the base colour, roughness and the two
        // range uniforms (shared with the base class) on top.
        shader.uniforms = {
            ...THREE.UniformsUtils.clone(shader.uniforms),
            diffuse: { value: new THREE.Color(151 / 255, 145 / 255, 122 / 255) },
            roughness: { value: 0.5 },
            vMin: minUniform,
            vMax: maxUniform,
        };

        super(new DandyShaderMaterial({ ...shader, side: THREE.DoubleSide }), minUniform, maxUniform, {
            min: rangeMin,
            max: rangeMax,
        });
    }
}

/**
 * Applies the undercut depth shader to a caller-supplied material and keeps its uniforms in step
 * with the current insertion axis.
 */
export class ScanReviewUndercutMaterialManager extends ScanReviewHeatmapMaterialManagerBase {
    private readonly axisSpaceMatrixUniform: THREE.Uniform;
    private readonly lateralScaleUniform: THREE.Uniform;
    private readonly depthScaleUniform: THREE.Uniform;
    private readonly depthMapSizeUniform: THREE.Uniform;
    private readonly depthMapUniform: THREE.Uniform;
    private readonly centerUniform: THREE.Uniform;
    private readonly radiusSquaredUniform: THREE.Uniform;
    private readonly enableCenterOfInterestUniform: THREE.Uniform;

    /**
     * @param scanGeometry Geometry handed to the depth map generator factory.
     * @param material Material the undercut shader is applied to.
     * @param insertionAxis Initial insertion axis (direction, optional position, max distance).
     * @param defaultColor Value of the shader's `defaultColor` uniform.
     * @param colorRampTexture Value of the shader's `colorRamp` uniform; may be undefined.
     * @param insertionDepthMapGeneratorFactory Produces the depth map generator for an axis.
     */
    constructor(
        public readonly scanGeometry: THREE.BufferGeometry,
        material: THREE.Material,
        insertionAxis: ScanReviewInsertionAxis,
        private readonly defaultColor: THREE.Color,
        private readonly colorRampTexture: THREE.Texture | undefined,
        private readonly insertionDepthMapGeneratorFactory: ScanReviewInsertionDepthMapGeneratorFactory,
    ) {
        const { min: initialMin, max: initialMax } = INITIAL_SCAN_REVIEW_UNDERCUT_HEATMAP_RANGE;
        super(material, new THREE.Uniform(initialMin), new THREE.Uniform(initialMax), {
            min: initialMin,
            max: initialMax,
        });

        const { depthData: initialDepth } = this.insertionDepthMapGeneratorFactory(
            this.scanGeometry,
            insertionAxis.direction,
        );
        const { position: centerOfInterest, maxDistance: reach } = insertionAxis;

        // Uniforms that the insertionAxis setter rewrites later are kept as fields.
        this.axisSpaceMatrixUniform = new THREE.Uniform(initialDepth.axisSpaceMatrix);
        this.lateralScaleUniform = new THREE.Uniform(initialDepth.lateralScale);
        this.depthScaleUniform = new THREE.Uniform(initialDepth.depthScale);
        this.depthMapSizeUniform = new THREE.Uniform(initialDepth.texSize);
        this.depthMapUniform = new THREE.Uniform(initialDepth.texture);
        // Without a position the centre falls back to the origin and the feature is switched off.
        this.centerUniform = new THREE.Uniform(centerOfInterest ?? new THREE.Vector3());
        this.radiusSquaredUniform = new THREE.Uniform(reach * reach);
        this.enableCenterOfInterestUniform = new THREE.Uniform(centerOfInterest !== undefined);

        applyShader(this.material, {
            defaultColor: new THREE.Uniform(this.defaultColor),
            axisSpaceMatrix: this.axisSpaceMatrixUniform,
            lateralScale: this.lateralScaleUniform,
            lateralBuffer: new THREE.Uniform(0.25),
            disableGating: new THREE.Uniform(false),
            vMin: this._vMinUniform,
            vMax: this._vMaxUniform,
            depthScale: this.depthScaleUniform,
            depthFadeDistance: new THREE.Uniform(0.4),
            depthMapSampleSize: new THREE.Uniform(2),
            overlayOpacity: new THREE.Uniform(0.6),
            depthMapSize: this.depthMapSizeUniform,
            depthMap: this.depthMapUniform,
            colorRamp: new THREE.Uniform(this.colorRampTexture),
            enableCenterOfInterest: this.enableCenterOfInterestUniform,
            center: this.centerUniform,
            radiusSquared: this.radiusSquaredUniform,
        });
    }

    /**
     * Switches to a new insertion axis: obtains depth data for its direction from the factory and
     * rewrites every axis-dependent uniform, then flags the material for update.
     */
    set insertionAxis(newInsertionAxis: ScanReviewInsertionAxis) {
        const { direction, position: centerOfInterest, maxDistance: reach } = newInsertionAxis;
        const { depthData: nextDepth } = this.insertionDepthMapGeneratorFactory(this.scanGeometry, direction);

        // Depth-map-derived uniforms.
        this.axisSpaceMatrixUniform.value = nextDepth.axisSpaceMatrix;
        this.lateralScaleUniform.value = nextDepth.lateralScale;
        this.depthScaleUniform.value = nextDepth.depthScale;
        this.depthMapSizeUniform.value = nextDepth.texSize;
        this.depthMapUniform.value = nextDepth.texture;

        // Centre-of-interest uniforms.
        this.enableCenterOfInterestUniform.value = centerOfInterest !== undefined;
        this.centerUniform.value = centerOfInterest ?? new THREE.Vector3();
        this.radiusSquaredUniform.value = reach * reach;

        this._material.needsUpdate = true;
    }
}
