import { TEXTURE_SIZE } from './TextureLayer';
import type { BufferAttribute, Mesh } from 'three';
import { BufferGeometry, Color, Vector2, Vector3 } from 'three';

/** Signed intensity computed for a single vertex by `calculateVertexIntensities`. */
interface VertexIntensity {
    /** Index of the vertex within the geometry's attributes. */
    index: number;
    /**
     * Signed intensity of the vertex. It is 0 for vertices rejected by the angle filter or
     * without neighbors; negative values are the ones `filterGrooveVertices` may keep.
     */
    intensity: number;
}

/** One vertex reached from a source vertex while expanding rings in `expandRings`. */
interface NeighborInfo {
    /** Index of the neighboring vertex. */
    index: number;
    /** Zero-based ring in which the neighbor was first reached (0 = directly adjacent). */
    kDistance: number;
    /** Straight-line distance between the neighbor's position and the source vertex's position. */
    distance: number;
}

/** Neighborhood summary stored per vertex by `calculateVertexNeighbors`. */
export interface VertexNeighborData {
    /** Normalized mean of the neighbors' normals (zero vector when there are no neighbors). */
    averageNormal: Vector3;
    /** Mean of the neighbors' positions (zero vector when there are no neighbors). */
    averagePosition: Vector3;
    /** Every neighbor collected for the vertex, in discovery order. */
    neighbors: NeighborInfo[];
}

// Angle threshold in radians (PI / 1.8 = 100 degrees). In calculateVertexIntensities a vertex
// whose normal is at least this far from the filter direction gets an intensity of 0.
const filterAngle = Math.PI / 1.8;

// Fill color used by updateOcclusalCanvas when stamping groove vertices onto the canvas.
const occlusalColor = new Color('#A67447');

/**
 * Builds a vertex adjacency map from the geometry's triangle index.
 *
 * The index is read as consecutive triples; every vertex of a triple becomes adjacent to the
 * other two. Each neighbor is listed once per vertex, in the order it is first encountered,
 * so consumers do not have to skip the repeats that shared edges would otherwise produce.
 * A trailing group of fewer than three indices is not a triangle and is ignored.
 *
 * @param geometry - Geometry whose `index` attribute describes the triangles.
 * @returns Map from vertex index to its distinct adjacent vertex indices. The map is empty
 * when the geometry has no index.
 */
function buildAdjacencyList(geometry: BufferGeometry): Map<number, number[]> {
    const adjacencyList: Map<number, number[]> = new Map();
    const idx = geometry.index;
    if (!idx) {
        return adjacencyList;
    }

    // Sets drop duplicate neighbors while keeping first-seen (insertion) order.
    const neighborSets: Map<number, Set<number>> = new Map();
    const neighborsOf = (vert: number): Set<number> => {
        let neighbors = neighborSets.get(vert);
        if (!neighbors) {
            neighbors = new Set();
            neighborSets.set(vert, neighbors);
        }
        return neighbors;
    };

    for (let i = 0; i + 2 < idx.count; i += 3) {
        const face = [idx.getX(i), idx.getX(i + 1), idx.getX(i + 2)];
        for (const vert of face) {
            // Fetch the set even if nothing is added, so degenerate faces still register the vertex.
            const adj = neighborsOf(vert);
            for (const other of face) {
                if (other !== vert) {
                    adj.add(other);
                }
            }
        }
    }

    neighborSets.forEach((neighbors, vert) => {
        adjacencyList.set(vert, Array.from(neighbors));
    });
    return adjacencyList;
}

/**
 * Collects the vertices reachable from `vertexIndex` by walking the adjacency map ring by ring.
 *
 * Ring 0 holds the directly adjacent vertices, ring 1 the unvisited vertices adjacent to ring 0,
 * and so on. Each vertex is reported once, tagged with the ring in which it was first reached
 * and its straight-line distance to the source vertex. The source vertex itself is never reported.
 *
 * NOTE(review): rings 0 through `k` inclusive are expanded, i.e. up to `k + 1` rings — confirm
 * this is the intended meaning of `k`.
 *
 * @param vertexIndex - Index of the source vertex.
 * @param adjacencyList - Map from vertex index to adjacent vertex indices.
 * @param k - Highest ring number to expand.
 * @param position - Position attribute used to measure distances.
 * @returns Neighbors in discovery order.
 */
function expandRings(
    vertexIndex: number,
    adjacencyList: Map<number, number[]>,
    k: number,
    position: BufferAttribute,
): NeighborInfo[] {
    const neighborInfo: NeighborInfo[] = [];
    // Seeding the visited set with the source keeps it out of its own neighborhood.
    const visited: Set<number> = new Set([vertexIndex]);
    const originX = position.getX(vertexIndex);
    const originY = position.getY(vertexIndex);
    const originZ = position.getZ(vertexIndex);
    let frontier: number[] = [vertexIndex];

    // An empty frontier can never produce further rings, so stop as soon as that happens.
    for (let ring = 0; ring <= k && frontier.length > 0; ring++) {
        const nextFrontier: number[] = [];
        for (const current of frontier) {
            const adjacent = adjacencyList.get(current);
            if (!adjacent) {
                continue;
            }
            for (const neighbor of adjacent) {
                if (visited.has(neighbor)) {
                    continue;
                }
                visited.add(neighbor);

                // Plain arithmetic avoids allocating a vector for every visited neighbor.
                const dx = position.getX(neighbor) - originX;
                const dy = position.getY(neighbor) - originY;
                const dz = position.getZ(neighbor) - originZ;
                neighborInfo.push({
                    index: neighbor,
                    kDistance: ring,
                    distance: Math.sqrt(dx * dx + dy * dy + dz * dz),
                });
                nextFrontier.push(neighbor);
            }
        }
        frontier = nextFrontier;
    }
    return neighborInfo;
}

/**
 * Computes, for every vertex of the geometry, its ring neighborhood together with the
 * neighborhood's average position and average normal.
 *
 * @param geometry - Geometry providing the index plus `position` and `normal` attributes.
 * @param k - Ring parameter forwarded to `expandRings`.
 * @returns Map from vertex index to its neighborhood data, with one entry per position vertex.
 */
function calculateVertexNeighbors(geometry: BufferGeometry, k: number): Map<number, VertexNeighborData> {
    const vertexCount = geometry?.attributes?.position?.count ?? 0;
    const adjacency = buildAdjacencyList(geometry);
    const positions = geometry.getAttribute('position') as BufferAttribute;
    const normals = geometry.getAttribute('normal') as BufferAttribute;

    const neighborhoods = new Map<number, VertexNeighborData>();
    for (let vertex = 0; vertex < vertexCount; vertex += 1) {
        const neighbors = expandRings(vertex, adjacency, k, positions);
        const averages = computeAverages(neighbors, positions, normals);
        neighborhoods.set(vertex, {
            averageNormal: averages.avgNormal,
            averagePosition: averages.avgPosition,
            neighbors,
        });
    }
    return neighborhoods;
}

/**
 * Averages the positions and normals of the given neighbors.
 *
 * @param neighborInfo - Neighbors whose attributes are averaged.
 * @param position - Position attribute to read from.
 * @param normal - Normal attribute to read from.
 * @returns The mean position and the normalized mean normal. Both stay at the zero vector
 * when `neighborInfo` is empty.
 */
function computeAverages(
    neighborInfo: NeighborInfo[],
    position: BufferAttribute,
    normal: BufferAttribute,
): { avgPosition: Vector3; avgNormal: Vector3 } {
    const readVector = (attribute: BufferAttribute, at: number): Vector3 =>
        new Vector3(attribute.getX(at), attribute.getY(at), attribute.getZ(at));

    const avgPosition = new Vector3();
    const avgNormal = new Vector3();

    for (const { index } of neighborInfo) {
        const neighborPosition = readVector(position, index);
        const neighborNormal = readVector(normal, index);
        avgPosition.add(neighborPosition);
        avgNormal.add(neighborNormal);
    }

    const count = neighborInfo.length;
    if (count > 0) {
        avgPosition.divideScalar(count);
        avgNormal.divideScalar(count);
        avgNormal.normalize();
    }

    return { avgPosition, avgNormal };
}

/**
 * Computes a signed intensity for every vertex of the geometry.
 *
 * A vertex gets intensity 0 when its normal is at least `filterAngle` away from `filterVec`,
 * or when it has no neighbor data. Otherwise the magnitude is the mean of
 * `1 - dot(averageNormal, neighborNormal)` over its neighbors. The sign is negative when the
 * offset from the neighbors' average position to the vertex points against the average normal.
 *
 * @param geometry - Geometry providing the `position` and `normal` attributes.
 * @param vertexNeighbors - Per-vertex neighborhood data keyed by vertex index.
 * @param filterVec - Direction that vertex normals are compared against.
 * @returns One entry per position vertex, in vertex order. The list is empty when the geometry
 * has no position attribute, and all zeros (with a warning) when it has no normal attribute.
 */
function calculateVertexIntensities(
    geometry: BufferGeometry,
    vertexNeighbors: Map<number, VertexNeighborData>,
    filterVec: Vector3,
): VertexIntensity[] {
    const position = geometry.attributes.position;
    if (!position) {
        return [];
    }

    const vertices = position.count;
    const intensities: VertexIntensity[] = [];

    const normal = geometry.attributes.normal;
    if (!normal) {
        // Without normals every vertex would be scored from zero vectors, so report "no grooves".
        console.warn('Geometry has no normal attribute, cannot compute groove intensities.');
        for (let v = 0; v < vertices; v++) {
            intensities.push({ index: v, intensity: 0 });
        }
        return intensities;
    }

    // Scratch vectors are reused across iterations to avoid per-vertex and per-neighbor allocations.
    const vertexNormal = new Vector3();
    const neighborNormal = new Vector3();
    const vertexVector = new Vector3();

    for (let v = 0; v < vertices; v++) {
        vertexNormal.set(normal.getX(v), normal.getY(v), normal.getZ(v));
        if (vertexNormal.angleTo(filterVec) >= filterAngle) {
            intensities.push({ index: v, intensity: 0 });
            continue;
        }

        const neighborData = vertexNeighbors.get(v);
        if (!neighborData || neighborData.neighbors.length === 0) {
            intensities.push({ index: v, intensity: 0 });
            continue;
        }
        const { averageNormal, averagePosition, neighbors } = neighborData;

        // Magnitude: how far the neighbors' normals deviate from their average, on average.
        let intensity = 0;
        for (const neighbor of neighbors) {
            neighborNormal.set(
                normal.getX(neighbor.index),
                normal.getY(neighbor.index),
                normal.getZ(neighbor.index),
            );
            intensity += 1 - averageNormal.dot(neighborNormal);
        }
        intensity /= neighbors.length;

        // Sign: negative when the vertex lies behind its neighbors' average position
        // as seen along the average normal.
        vertexVector.set(position.getX(v), position.getY(v), position.getZ(v)).sub(averagePosition);
        if (vertexVector.dot(averageNormal) < 0) {
            intensity = -intensity;
        }

        intensities.push({ index: v, intensity });
    }

    return intensities;
}

/** A vertex selected as part of a groove by `filterGrooveVertices`. */
export interface GrooveVertex {
    /** Index of the vertex within the geometry's attributes. */
    index: number;
    /** The vertex's intensity; always negative for selected vertices. */
    intensity: number;
    /** Neighbors recorded for the vertex, or an empty array when none were recorded. */
    neighbors: NeighborInfo[];
}

/**
 * Selects the vertices whose intensity is negative and whose magnitude exceeds the threshold.
 *
 * @param intensities - Per-vertex signed intensities.
 * @param vertexNeighbors - Per-vertex neighborhood data used to attach each vertex's neighbors.
 * @param grooveThreshold - Magnitude an intensity has to exceed (strictly) to be selected.
 * @returns The selected vertices in input order, each with its neighbors (empty when unknown).
 */
function filterGrooveVertices(
    intensities: VertexIntensity[],
    vertexNeighbors: Map<number, VertexNeighborData>,
    grooveThreshold: number,
): GrooveVertex[] {
    const grooves: GrooveVertex[] = [];

    for (const { index, intensity } of intensities) {
        const exceedsThreshold = Math.abs(intensity) > grooveThreshold;
        const isNegative = intensity < 0;
        if (!exceedsThreshold || !isNegative) {
            continue;
        }

        const known = vertexNeighbors.get(index);
        grooves.push({
            index,
            intensity,
            neighbors: known?.neighbors || [],
        });
    }

    return grooves;
}

/**
 * Prepares the per-vertex neighborhood data used for groove staining.
 *
 * Calls `computeVertexNormals()` on the mesh's geometry (which updates that geometry in place)
 * and then gathers every vertex's ring neighborhood.
 *
 * @param mesh - Mesh whose geometry is analyzed.
 * @param k - Ring parameter forwarded to `calculateVertexNeighbors`. Defaults to 3, the value
 * that was previously hard-coded.
 * @returns Map from vertex index to its neighborhood data.
 */
export function detectGrooves(mesh: Mesh, k = 3): Map<number, VertexNeighborData> {
    const geometry = mesh.geometry;
    geometry.computeVertexNormals();

    return calculateVertexNeighbors(geometry, k);
}

/**
 * Replaces the content of the context's canvas with a blurred copy of itself.
 *
 * The canvas is drawn through a `blur(...)` filter onto a temporary canvas of the same size,
 * then cleared and painted with that result. Nothing happens when `blurRadius` is zero or
 * negative, or when no 2D context can be obtained for the temporary canvas.
 *
 * @param ctx - 2D context of the canvas to blur.
 * @param blurRadius - Blur radius in pixels.
 */
function blurCanvas(ctx: CanvasRenderingContext2D, blurRadius: number) {
    if (blurRadius <= 0) {
        return;
    }

    const { canvas: target } = ctx;
    const { width, height } = target;

    const scratch = document.createElement('canvas');
    scratch.width = width;
    scratch.height = height;

    const scratchCtx = scratch.getContext('2d');
    if (scratchCtx) {
        // The filter has to be set before drawing so it applies to the copied pixels.
        scratchCtx.filter = `blur(${blurRadius}px)`;
        scratchCtx.drawImage(target, 0, 0);

        ctx.clearRect(0, 0, width, height);
        ctx.drawImage(scratch, 0, 0);
    }
}

/**
 * Redraws the occlusal stain canvas from the given groove vertices.
 *
 * The canvas is cleared, every groove vertex is stamped at its UV location as a 5x5 pixel
 * square (top-left corner at the UV pixel) in the occlusal color, and the result is blurred.
 * Nothing is drawn when the geometry has no UV attribute or the canvas has no 2D context.
 *
 * @param geometry - Geometry providing the `uv` attribute.
 * @param grooveVertices - Vertices to stamp onto the canvas.
 * @param fadeDistance - Blur radius in pixels applied after stamping.
 * @param canvas - Canvas that receives the stain texture.
 */
function updateOcclusalCanvas(
    geometry: BufferGeometry,
    grooveVertices: GrooveVertex[],
    fadeDistance: number,
    canvas: HTMLCanvasElement,
) {
    const uvAttribute = geometry.attributes.uv;
    if (!uvAttribute) {
        console.warn('Geometry has no UV attributes, cannot create groove texture.');
        return;
    }

    const ctx = canvas.getContext('2d');
    if (!ctx) {
        return;
    }

    // Clear by the canvas's own size: the same dimensions drive the UV-to-pixel mapping below
    // and the blur pass, so the whole drawable area is reset whatever size the canvas has.
    ctx.clearRect(0, 0, canvas.width, canvas.height);

    ctx.fillStyle = occlusalColor.getStyle();
    for (const vert of grooveVertices) {
        const u = uvAttribute.getX(vert.index);
        const v = uvAttribute.getY(vert.index);
        const pixelX = Math.floor(u * canvas.width);
        // UV space has v pointing up while canvas rows grow downwards, hence the flip.
        const pixelY = Math.floor((1 - v) * canvas.height);
        ctx.fillRect(pixelX, pixelY, 5, 5);
    }

    blurCanvas(ctx, fadeDistance);
}

/**
 * Recomputes the groove stain for a mesh and paints it onto the given canvas.
 *
 * Scores every vertex, keeps those that qualify as grooves, and redraws the canvas from them.
 * When the mesh's geometry is not a `BufferGeometry`, an error is logged and nothing is drawn.
 *
 * @param mesh - Mesh whose geometry is stained.
 * @param gradientDirection - Direction that vertex normals are compared against when scoring.
 * @param vertexNeighbors - Per-vertex neighborhood data, as returned by `detectGrooves`.
 * @param grooveThreshold - Magnitude a negative intensity has to exceed to count as a groove.
 * @param fadeDistance - Blur radius in pixels applied to the painted stain.
 * @param canvas - Canvas that receives the stain texture.
 */
export function applyOcclusalStaining(
    mesh: Mesh,
    gradientDirection: Vector3,
    vertexNeighbors: Map<number, VertexNeighborData>,
    grooveThreshold: number,
    fadeDistance: number,
    canvas: HTMLCanvasElement,
): void {
    const { geometry } = mesh;

    if (geometry instanceof BufferGeometry) {
        const grooveVertices = filterGrooveVertices(
            calculateVertexIntensities(geometry, vertexNeighbors, gradientDirection),
            vertexNeighbors,
            grooveThreshold,
        );
        updateOcclusalCanvas(geometry, grooveVertices, fadeDistance, canvas);
        return;
    }

    console.error('Geometry is not BufferGeometry, groove staining might not work correctly.');
}
