import type { InitialCameraProps } from '../ModelViewer';
import type { ModelPayload } from '../ModelViewer';
import { TriosMaterialFragmentShader } from '../ModelViewer/materials/scanMeshShaderMaterialChairside';
import type { ScanSnapshotGeometries, ScanSnapshotViewingVolumePlanes, ScanSnapshotModels } from './ScanSnapshot.types';
import { CameraPoseType } from './ScanSnapshot.types';
import { ScanSnapshotCameraPropsConstructor } from './ScanSnapshotCameraPropsConstructor';
import * as THREE from 'three';

/** Bytes per pixel of the RGBA / unsigned-byte render target that `ScanSnapshotComputer` reads back. */
const BYTES_PER_PIXEL = 4;
/** Default snapshot width in pixels; the height is derived from the camera's aspect ratio. */
const DEFAULT_WIDTH = 400;

/**
 * Returns a copy of an image buffer with its rows in reverse order.
 *
 * WebGL's `readPixels` (used via `readRenderTargetPixels`) returns rows starting from the bottom
 * of the image, while `ImageData` / a 2D canvas expects rows starting from the top, so the
 * read-back buffer must be flipped before it is drawn.
 *
 * @param pixels - Tightly packed pixel data, nominally `width * height * bytesPerPixel` bytes.
 * @param width - Image width in pixels.
 * @param height - Image height in pixels.
 * @param bytesPerPixel - Number of bytes per pixel (4 for RGBA8).
 * @returns A new buffer of exactly `width * height * bytesPerPixel` bytes; `pixels` is not
 *   modified. Bytes that a too-short input does not provide are left as 0, and any extra
 *   trailing input bytes are ignored.
 */
function flipRowsVertically(
    pixels: Uint8ClampedArray,
    width: number,
    height: number,
    bytesPerPixel: number,
): Uint8ClampedArray {
    const rowBytes = width * bytesPerPixel;
    const flipped = new Uint8ClampedArray(rowBytes * height);
    for (let row = 0; row < height; row++) {
        const sourceStart = (height - 1 - row) * rowBytes;
        // subarray() clamps to the end of the input, so a short input leaves the tail zero-filled.
        flipped.set(pixels.subarray(sourceStart, sourceStart + rowBytes), row * rowBytes);
    }
    return flipped;
}

/**
 * This class is responsible for computing a scan snapshot image from a set of geometries
 * given a particular scan viewing angle.
 *
 * The scene, orthographic camera and off-screen render target are built once in the
 * constructor; each call to {@link ScanSnapshotComputer.compute} renders them and returns the
 * result as an image data URL. Call {@link ScanSnapshotComputer.dispose} when finished to release
 * the WebGL context, the render target and the materials this instance created.
 */
export class ScanSnapshotComputer {
    private readonly _target: THREE.WebGLRenderTarget;
    private readonly _camera: THREE.OrthographicCamera;
    private readonly _renderWidth: number;
    private readonly _renderHeight: number;
    private readonly _scene = new THREE.Scene();
    // Materials created by this instance. Geometries and color maps come from the caller's
    // payloads and are intentionally not disposed here.
    private readonly _materials: THREE.MeshPhongMaterial[] = [];
    private _renderer: THREE.WebGLRenderer;
    private _disposed = false;

    /**
     * @param poseType - Which arch(es) to render: `CameraPoseType.Upper` omits the lower arch and
     *   `CameraPoseType.Lower` omits the upper arch; any other value renders both.
     * @param models - Arch payloads; either arch may be absent.
     * @param renderWidth - Output image width in pixels; the height follows the camera's
     *   height/width ratio (rounded up).
     */
    constructor(poseType: CameraPoseType, models: ScanSnapshotModels, renderWidth: number = DEFAULT_WIDTH) {
        const cameraProps = this.getCameraProps(
            {
                upperArch: models.upperModel?.model.geometry,
                lowerArch: models.lowerModel?.model.geometry,
            },
            poseType,
        );
        this.addArchesToScene(models, poseType);
        this._camera = new THREE.OrthographicCamera(
            -cameraProps.width / 2.0,
            cameraProps.width / 2.0,
            cameraProps.height / 2.0,
            -cameraProps.height / 2.0,
            cameraProps.near,
            cameraProps.far,
        );
        this._scene.add(this._camera);
        this.setupCamera(cameraProps);

        this._renderWidth = renderWidth;
        const aspectRatio = cameraProps.height / cameraProps.width;
        this._renderHeight = Math.ceil(renderWidth * aspectRatio);
        this._target = new THREE.WebGLRenderTarget(this._renderWidth, this._renderHeight, {
            format: THREE.RGBAFormat,
            type: THREE.UnsignedByteType,
            minFilter: THREE.NearestFilter,
            magFilter: THREE.NearestFilter,
        });

        // Acquire the WebGL context last so that an exception in the setup above cannot leak it.
        this._renderer = new THREE.WebGLRenderer();
    }

    /**
     * Renders the scene into the off-screen target and returns the image as a data URL.
     *
     * If the renderer's WebGL context has been lost, a fresh renderer is created first. An
     * exception while reading the pixels back is logged rather than thrown, in which case the
     * returned image may be blank.
     *
     * @throws Error if called after {@link ScanSnapshotComputer.dispose}.
     */
    compute(): string {
        if (this._disposed) {
            throw new Error('Attempted to compute on a disposed instance.');
        }

        if (this._renderer.getContext().isContextLost()) {
            console.warn('Renderer context was lost. Reinitializing new renderer.');
            this._renderer.dispose();

            this._renderer = new THREE.WebGLRenderer();
        }

        this._renderer.setRenderTarget(this._target);
        this._renderer.render(this._scene, this._camera);

        const pixelData = new Uint8ClampedArray(this._renderWidth * this._renderHeight * BYTES_PER_PIXEL);

        try {
            this._renderer.readRenderTargetPixels(this._target, 0, 0, this._renderWidth, this._renderHeight, pixelData);
        } catch (error) {
            console.error('computeScanSnapshot readRenderPixels error', error);
        }

        const flippedPixelData = this.getFlippedPixelData(pixelData);
        return this.getImageUrl(flippedPixelData);
    }

    /**
     * Releases the WebGL context, the render target and the materials created by this instance.
     * Geometries and textures supplied in the model payloads are left for the caller to manage.
     * Calling this again after it has completed is a no-op.
     */
    dispose(): void {
        if (this._disposed) {
            return;
        }
        this._renderer.forceContextLoss();
        this._renderer.dispose();
        this._target.dispose();
        for (const material of this._materials) {
            material.dispose();
        }
        this._materials.length = 0;
        this._disposed = true;
    }

    /** Positions and orients the camera from the computed props, then refreshes its projection matrix. */
    private setupCamera(cameraProps: InitialCameraProps): void {
        this._camera.position.set(cameraProps.position.x, cameraProps.position.y, cameraProps.position.z);
        this._camera.setRotationFromEuler(cameraProps.rotation);
        this._camera.updateProjectionMatrix();
    }

    /**
     * Adds a mesh for each arch that is present in `models` and visible for `poseType`
     * (the lower arch is skipped for `CameraPoseType.Upper`, the upper arch for `CameraPoseType.Lower`).
     */
    private addArchesToScene(models: ScanSnapshotModels, poseType: CameraPoseType): void {
        if (poseType !== CameraPoseType.Upper && models.lowerModel) {
            this._scene.add(this.createArchMesh(models.lowerModel));
        }
        if (poseType !== CameraPoseType.Lower && models.upperModel) {
            this._scene.add(this.createArchMesh(models.upperModel));
        }
    }

    /**
     * Builds a double-sided mesh for one arch whose fragment shader is replaced by
     * `TriosMaterialFragmentShader`.
     *
     * The payload's color map is used as the texture when present; otherwise vertex colors are
     * enabled. Note: this recomputes vertex normals on the caller's geometry in place. The new
     * material is recorded so that {@link ScanSnapshotComputer.dispose} can free it.
     */
    private createArchMesh(payload: ModelPayload) {
        payload.model.geometry.computeVertexNormals();

        const material = new THREE.MeshPhongMaterial({
            vertexColors: !payload.colorMap,
            map: payload.colorMap,
            side: THREE.DoubleSide,
        });

        material.onBeforeCompile = shader => {
            // NOTE(review): these uniforms are presumably read by TriosMaterialFragmentShader,
            // which replaces the stock Phong fragment shader below — confirm against the shader source.
            shader.uniforms['vertexColorLerp'] = { value: 1.0 };
            shader.uniforms['ambientStrength'] = { value: 0.1 };
            shader.uniforms['specularStrength'] = { value: 0.352 };
            shader.fragmentShader = TriosMaterialFragmentShader;
        };
        this._materials.push(material);

        return new THREE.Mesh(payload.model.geometry, material);
    }

    /** Computes the orthographic viewing volume (size, clip planes, pose) that frames the given arches. */
    private getCameraProps(
        geometries: ScanSnapshotGeometries,
        poseType: CameraPoseType,
    ): ScanSnapshotViewingVolumePlanes {
        return new ScanSnapshotCameraPropsConstructor(poseType, geometries).construct();
    }

    /** Converts the bottom-up WebGL read-back into top-down row order for `ImageData`. */
    private getFlippedPixelData(pixelData: Uint8ClampedArray): Uint8ClampedArray {
        return flipRowsVertically(pixelData, this._renderWidth, this._renderHeight, BYTES_PER_PIXEL);
    }

    /**
     * Draws the pixel buffer onto a fresh canvas of the render size and returns `canvas.toDataURL()`.
     * If a 2D context cannot be obtained, the canvas is left blank and its (blank) data URL is returned.
     */
    private getImageUrl(pixelData: Uint8ClampedArray): string {
        const imageData = new ImageData(pixelData, this._renderWidth, this._renderHeight);

        const canvas = document.createElement('canvas');
        canvas.width = this._renderWidth;
        canvas.height = this._renderHeight;
        const ctx = canvas.getContext('2d');
        ctx?.putImageData(imageData, 0, 0);
        return canvas.toDataURL();
    }
}
