import { ClosestPointComputer } from '../GpuAccel';
import type { MainViewCameraControlsRef } from '../ModelViewer';
import { initializeDistanceAttribute } from './ScanReview.utils';
import {
    type ScanReviewRecordFactory,
    type ScanReviewRecordsFactory,
    type ScanReviewViewState,
    ScanReviewViewManager,
    ScanReviewCompositeScene,
} from './ScanReviewTypes';
import { AttributeName, type VertexResult } from '@orthly/forceps';
import React from 'react';
import type * as THREE from 'three';

/**
 * Creates a `ScanReviewViewManager` wired to this component's canvas, camera and camera-controls refs.
 *
 * The three refs are created once per component instance. The manager is rebuilt only when the
 * identity of `viewState` changes.
 *
 * @param viewState - view state passed to the `ScanReviewViewManager` constructor.
 * @returns the memoized view manager.
 */
export function useViewManager(viewState: ScanReviewViewState) {
    const canvasHandle = React.useRef<HTMLCanvasElement | null>(null);
    const cameraHandle = React.useRef<THREE.OrthographicCamera | null>(null);
    const controlsHandle: MainViewCameraControlsRef = React.useRef(null);

    return React.useMemo(
        () => new ScanReviewViewManager(canvasHandle, cameraHandle, controlsHandle, viewState),
        [viewState],
    );
}

/**
 * Runs `computer.compute()` and always calls `computer.dispose()` afterwards, even when `compute` throws.
 */
function computeAndDispose<T>(computer: { compute(): T; dispose(): void }): T {
    try {
        return computer.compute();
    } finally {
        computer.dispose();
    }
}

/**
 * Returns a memoized factory that builds the lower/upper jaw scan-review records.
 *
 * On every call the factory:
 *  - invokes the supplied jaw factories (a `null` factory yields a `null` jaw);
 *  - calls `initializeDistanceAttribute` on each jaw's scan geometry;
 *  - when both jaws are present, fills each geometry's `AttributeName.OcclusalDistance` attribute with the
 *    `signedDistance` of every vertex result computed against the opposing jaw.
 *
 * Distance results are cached per `(query geometry uuid, reference geometry uuid)` pair for the lifetime of
 * the component, so repeated calls on the same geometries skip `ClosestPointComputer`.
 *
 * @param lowerJawFactory - builds the lower-jaw record, or `null` when there is none.
 * @param upperJawFactory - builds the upper-jaw record, or `null` when there is none.
 * @returns a factory producing `{ lowerJaw, upperJaw }`; its identity changes only when an input factory changes.
 */
export function useScanReviewRecordsFactory(
    lowerJawFactory: ScanReviewRecordFactory | null,
    upperJawFactory: ScanReviewRecordFactory | null,
): ScanReviewRecordsFactory {
    // Lazy useState init keeps one Map for the component's lifetime. useMemo is only a performance hint and
    // React may discard its value, which would silently drop the cache.
    // NOTE(review): entries are never evicted; if the jaw factories keep producing new geometries this grows
    // without bound — confirm whether that matters in practice.
    const [distancesCache] = React.useState(() => new Map<string, VertexResult[]>());

    return React.useCallback(() => {
        const lowerJaw = lowerJawFactory?.() ?? null;
        const upperJaw = upperJawFactory?.() ?? null;

        initializeDistanceAttribute(lowerJaw?.scanMesh?.geometry);
        initializeDistanceAttribute(upperJaw?.scanMesh?.geometry);

        if (lowerJaw && upperJaw) {
            // Each jaw is measured against the other, so both directions are evaluated.
            const pairs = [
                { queryMesh: lowerJaw.scanMesh.geometry, referenceMesh: upperJaw.scanMesh.geometry },
                { queryMesh: upperJaw.scanMesh.geometry, referenceMesh: lowerJaw.scanMesh.geometry },
            ];
            for (const { queryMesh, referenceMesh } of pairs) {
                const cacheKey = `${queryMesh.uuid}|${referenceMesh.uuid}`;
                const cachedDistances = distancesCache.get(cacheKey);
                const distancesAttribute = queryMesh.getAttribute(AttributeName.OcclusalDistance);

                // Only build a ClosestPointComputer on a cache miss; it is disposed even if compute() throws.
                const distances =
                    cachedDistances ??
                    computeAndDispose(new ClosestPointComputer(referenceMesh, queryMesh, undefined, true));
                if (!distances) {
                    // No result: leave the attribute as initializeDistanceAttribute left it, and cache nothing.
                    continue;
                }

                (distancesAttribute.array as Float32Array).set(distances.map(el => el.signedDistance));
                distancesAttribute.needsUpdate = true;

                if (!cachedDistances) {
                    distancesCache.set(cacheKey, distances);
                }
            }
        }
        return { lowerJaw, upperJaw };
    }, [distancesCache, lowerJawFactory, upperJawFactory]);
}

/**
 * Builds a `ScanReviewCompositeScene` from the given jaw record factories.
 *
 * The factories are combined via `useScanReviewRecordsFactory`. The scene is recreated only when that
 * combined factory changes identity.
 *
 * @param lowerJawFactory - builds the lower-jaw record, or `null` when there is none.
 * @param upperJawFactory - builds the upper-jaw record, or `null` when there is none.
 * @returns the memoized composite scene.
 */
export function useCompositeScene(
    lowerJawFactory: ScanReviewRecordFactory | null,
    upperJawFactory: ScanReviewRecordFactory | null,
) {
    const recordsFactory = useScanReviewRecordsFactory(lowerJawFactory, upperJawFactory);
    return React.useMemo(() => new ScanReviewCompositeScene(recordsFactory), [recordsFactory]);
}
