import type { GradientFunction } from './Gradient';
import type { Mesh } from 'three';
import { Vector3, Quaternion } from 'three';

// Module-level scratch objects shared by every ApplyShade call; they are
// overwritten on each call, so never keep a reference to them afterwards.
const worldPosition = new Vector3();
const quat = new Quaternion();
const direction = new Vector3();
const vertex = new Vector3();
const distanceFromCenterVector = new Vector3();

/** Plain colour triple stored in ApplyShade's per-call gradient cache. */
interface RGB {
    r: number;
    g: number;
    b: number;
}

/**
 * Writes a linear colour gradient along `dir` into the mesh geometry's `color` attribute.
 *
 * Each vertex is rotated by the mesh's world quaternion and offset by its world position,
 * then projected onto the normalised `dir` relative to the mesh's world position (used as
 * a cheap stand-in for the geometry's true centre). Projections are normalised to [0, 1]
 * across all vertices and mapped to a gradient position in [0, 100]: the vertex furthest
 * along `dir` receives position 0, the vertex furthest against it receives position 100.
 * Positions are quantised to steps of 0.1, and each distinct step is evaluated once per call.
 *
 * The mesh's scale is not applied; only its world rotation and translation are.
 *
 * @param mesh - Mesh whose geometry colours are overwritten in place.
 * @param dir - Gradient direction in world orientation; it does not need to be normalised.
 * @param gradientFunction - Called with a gradient position in [0, 100]; only the
 *   `r`, `g` and `b` fields of its result are read.
 * @throws Error if the geometry has no `position` or no `color` attribute.
 */
export const ApplyShade = (mesh: Mesh, dir: Vector3, gradientFunction: GradientFunction) => {
    const geometry = mesh.geometry;
    const position = geometry.attributes.position;
    if (!position) {
        throw new Error('Geometry does not have position attribute');
    }
    const colorAttribute = geometry.getAttribute('color');
    if (!colorAttribute) {
        // Fail up front with a clear message instead of a TypeError inside the loop.
        throw new Error('Geometry does not have color attribute');
    }

    // Get the mesh world position & quaternion
    mesh.getWorldPosition(worldPosition);
    mesh.getWorldQuaternion(quat);

    direction.copy(dir).normalize();

    const vertexCount = position.count;
    // getZ on a 2-component attribute would read the next vertex's x, so default z to 0.
    const hasZ = position.itemSize >= 3;

    // 1. First pass: project every vertex once, remembering the result and its extent.
    const projections = new Float64Array(vertexCount);
    let minDistance = Infinity;
    let maxDistance = -Infinity;

    for (let i = 0; i < vertexCount; i++) {
        // getX/getY/getZ honour interleaved layouts and normalised attributes,
        // which raw indexing into `position.array` with a stride of 3 does not.
        vertex.set(position.getX(i), position.getY(i), hasZ ? position.getZ(i) : 0);
        vertex.applyQuaternion(quat).add(worldPosition);

        // Assume the center is the mesh world position because it is cheaper than computing it
        distanceFromCenterVector.subVectors(vertex, worldPosition);
        const projectedDistance = distanceFromCenterVector.dot(direction);

        projections[i] = projectedDistance;
        minDistance = Math.min(minDistance, projectedDistance);
        maxDistance = Math.max(maxDistance, projectedDistance);
    }

    // 2. Second pass: normalise the cached projections and apply the gradient.
    const range = maxDistance - minDistance;
    const gradientCache = new Map<number, RGB>();

    for (let i = 0; i < vertexCount; i++) {
        // With a zero range (one vertex, or all vertices equidistant along `dir`) the
        // division would be 0/0 = NaN; every vertex is then at the minimum, so use 0.
        const normalizedDistance = range > 0 ? ((projections[i] ?? 0) - minDistance) / range : 0;

        // Quantise the gradient position to 0.1 steps so repeated positions hit the cache.
        const key = Math.floor((100 - normalizedDistance * 100) * 10);
        let color = gradientCache.get(key);
        if (color === undefined) {
            const sampled = gradientFunction(key / 10);
            color = { r: sampled.r, g: sampled.g, b: sampled.b };
            gradientCache.set(key, color);
        }

        colorAttribute.setXYZ(i, color.r, color.g, color.b);
    }

    colorAttribute.needsUpdate = true;
};
