import { toothRefAtom, gradientVecAtom, store, disableControlsAtom, selectedToolAtom } from '../store/store';
import { atomToObservable } from '../utils/atomToObservable';
import { FlossPalette } from '@orthly/ui-primitives';
import { fromEvent } from 'rxjs';
import { Subject, combineLatest } from 'rxjs';
import { map, takeUntil, filter, switchMap, tap } from 'rxjs/operators';
import type { Scene, Camera, WebGLRenderer } from 'three';
import { Vector3, SphereGeometry, MeshBasicMaterial, Mesh, ArrowHelper, Vector2, Box3, Raycaster, Group } from 'three';
import { Line2 } from 'three/examples/jsm/lines/Line2.js';
import { LineGeometry } from 'three/examples/jsm/lines/LineGeometry.js';
import { LineMaterial } from 'three/examples/jsm/lines/LineMaterial.js';

/** Describes one clickable axis guide line drawn inside the gradient control. */
interface AxisConfig {
    /** End point of the line (drawn from the local origin); also the gradient direction applied when the line is clicked. */
    direction: Vector3;
    /** Color passed to the line's material. */
    color: number;
    /** Assigned to the generated line object's `name`. */
    name: string;
}

/**
 * Builds the faint grey sphere of radius 1 (32x32 segments) that the control
 * scales up around its target.
 *
 * @returns A mesh with a transparent basic material at 10% opacity.
 */
function createSphere(): Mesh {
    return new Mesh(
        new SphereGeometry(1, 32, 32),
        new MeshBasicMaterial({
            transparent: true,
            opacity: 0.1,
            color: 0x888888,
        }),
    );
}

/**
 * Builds the direction arrow in its initial state: anchored at the origin,
 * pointing along +Y, with length 1.
 *
 * @returns An `ArrowHelper` colored `FlossPalette.STAR_GRASS` with a head of
 * length 0.2 and width 0.1.
 */
function createArrow(): ArrowHelper {
    const initialLength = 1;
    const headLength = 0.2;
    const headWidth = 0.1;
    const up = new Vector3(0, 1, 0);
    const start = new Vector3(0, 0, 0);
    return new ArrowHelper(up, start, initialLength, FlossPalette.STAR_GRASS, headLength, headWidth);
}

/**
 * Builds the drag handle: a radius-2 sphere (16x16 segments) whose material is
 * fully transparent, so it is never seen but can still be hit by a raycaster.
 *
 * @returns The handle mesh.
 */
function createHandle(): Mesh {
    const invisible = new MeshBasicMaterial({
        transparent: true,
        opacity: 0.0,
    });
    return new Mesh(new SphereGeometry(2, 16, 16), invisible);
}

/**
 * Interactive widget for picking a gradient direction around a target mesh.
 *
 * The widget is a group containing a translucent sphere, a direction arrow, an
 * invisible drag handle at the arrow tip, and clickable axis guide lines. It is
 * added to the scene only while the selected tool is `'Gradient'` and
 * `toothRefAtom` holds a mesh. Dragging the handle across the sphere, or
 * clicking an axis line, writes the new direction to `gradientVecAtom`.
 *
 * Call {@link dispose} when the control is no longer needed; the instance is
 * not reusable afterwards.
 */
export class GradientDirectionControl {
    private readonly container: Group = new Group();
    private readonly sphere: Mesh = createSphere();
    private readonly arrow: ArrowHelper = createArrow();
    private readonly handle: Mesh = createHandle();
    private readonly raycaster: Raycaster = new Raycaster();
    private isInScene: boolean = false;
    private isDragging: boolean = false;
    /** Emits once from `dispose()` to tear down every subscription owned by this control. */
    private readonly destroy$ = new Subject<void>();
    private axisLines: Line2[] = [];
    private readonly axes: AxisConfig[] = [
        {
            direction: new Vector3(0, 1, 0),
            color: 0x00ff00, // Green for Y
            name: 'y',
        },
    ];

    /**
     * @param scene Scene the widget is added to / removed from.
     * @param camera Camera used to cast pointer rays.
     * @param renderer Renderer whose canvas receives the mouse listeners and cursor changes.
     */
    constructor(
        private scene: Scene,
        private camera: Camera,
        private renderer: WebGLRenderer,
    ) {
        // Everything lives in one group so the whole widget can be moved as a unit.
        this.container.add(this.sphere);
        this.container.add(this.arrow);
        this.container.add(this.handle);

        this.axisLines = this.createAxisLines(this.axes);
        this.axisLines.forEach(line => this.container.add(line));

        this.setupInteraction();
        this.setupActiveSubscription();
    }

    /**
     * Builds one fat line per axis, running from the local origin to the axis direction.
     *
     * @param axes Axis descriptions to turn into lines.
     * @returns Lines in the same order as `axes` (the order is relied on by `checkAxisIntersection`).
     */
    private createAxisLines(axes: AxisConfig[]): Line2[] {
        return axes.map(axis => {
            const geometry = new LineGeometry();
            geometry.setPositions([0, 0, 0, ...axis.direction.toArray()]);
            const material = new LineMaterial({
                color: axis.color,
                linewidth: 3,
                polygonOffset: true,
                polygonOffsetFactor: 1, // Small positive offset to render slightly behind arrow
                polygonOffsetUnits: 1,
            });
            const line = new Line2(geometry, material);
            line.name = axis.name;
            return line;
        });
    }

    /**
     * Shows, hides and repositions the widget as the selected tool and the
     * target mesh change.
     */
    private setupActiveSubscription() {
        const selectedTool$ = atomToObservable(selectedToolAtom);
        const targetMesh$ = atomToObservable(toothRefAtom);

        combineLatest([selectedTool$, targetMesh$])
            .pipe(takeUntil(this.destroy$))
            .subscribe(([selectedTool, targetMesh]) => {
                if (selectedTool === 'Gradient' && targetMesh) {
                    if (!this.isInScene) {
                        this.addToScene();
                    }
                    this.updatePosition(targetMesh);
                } else if (this.isInScene) {
                    this.removeFromScene();
                }
            });
    }

    /**
     * Centers the widget on the target's bounding box and sizes it from the
     * box's largest dimension.
     *
     * @param targetMesh Mesh the widget should surround.
     */
    private updatePosition(targetMesh: Mesh) {
        const boundingBox = new Box3().setFromObject(targetMesh);
        const center = boundingBox.getCenter(new Vector3());
        const size = boundingBox.getSize(new Vector3());
        const maxDimension = Math.max(size.x, size.y, size.z);
        // The sphere has radius 1, so its scale is also the distance to the arrow tip / handle.
        const controlRadius = maxDimension * 1.2;

        // Move the container rather than the individual children.
        this.container.position.copy(center);

        this.sphere.scale.setScalar(controlRadius);
        this.arrow.setLength(controlRadius);

        const direction = store.get(gradientVecAtom);
        this.handle.position.copy(direction.clone().multiplyScalar(controlRadius));

        // Keep the visible arrow aligned with the (invisible) handle. A zero
        // vector has no direction, so the arrow is left untouched in that case.
        if (direction.lengthSq() > 0) {
            this.arrow.setDirection(direction.clone().normalize());
        }

        this.axisLines.forEach(line => {
            line.scale.setScalar(maxDimension * 0.8);
        });
    }

    /** Adds the widget to the scene and records that it is visible. */
    private addToScene() {
        this.scene.add(this.container);
        this.isInScene = true;
    }

    /** Removes the widget from the scene and records that it is hidden. */
    private removeFromScene() {
        this.scene.remove(this.container);
        this.isInScene = false;
    }

    /**
     * Tears the control down: stops all subscriptions, ends any drag in
     * progress, removes the widget from the scene, frees the geometries and
     * materials created by this control, and resets the canvas cursor.
     */
    public dispose() {
        this.destroy$.next();
        this.destroy$.complete();

        // The mouseup handler can no longer run once subscriptions are torn
        // down, so a drag interrupted by dispose must be finished here or the
        // scene controls would stay disabled.
        if (this.isDragging) {
            this.finishDrag();
        }

        if (this.isInScene) {
            this.removeFromScene();
        }

        this.disposeMesh(this.sphere);
        this.disposeMesh(this.handle);
        this.axisLines.forEach(line => {
            line.geometry.dispose();
            line.material.dispose();
        });
        // NOTE(review): the ArrowHelper's resources are not released here; if the
        // three.js version in use exposes `ArrowHelper.dispose()`, call it too.

        // Reset cursor style
        this.renderer.domElement.style.cursor = 'default';
    }

    /** Releases the geometry and material(s) of a mesh created by this control. */
    private disposeMesh(mesh: Mesh) {
        mesh.geometry.dispose();
        const materials = Array.isArray(mesh.material) ? mesh.material : [mesh.material];
        materials.forEach(material => material.dispose());
    }

    /** Clears the dragging flag and re-enables the scene controls. */
    private finishDrag() {
        this.isDragging = false;
        store.set(disableControlsAtom, false);
    }

    /**
     * Points the arrow and handle along `direction` and publishes it to the store.
     *
     * @param direction Unit direction in the widget's local space.
     * @param length Distance from the widget center at which to place the handle.
     */
    private setGradientDirection(direction: Vector3, length: number) {
        this.arrow.setDirection(direction);
        this.handle.position.copy(direction.clone().multiplyScalar(length));
        store.set(gradientVecAtom, direction.clone());
    }

    /**
     * Finds the axis whose guide line is the nearest hit of `raycaster`.
     *
     * @param raycaster Raycaster already aimed at the pointer position.
     * @returns The direction of the hit axis (+Y if the hit line cannot be
     * matched to an axis), or `undefined` when no axis line is hit.
     */
    private checkAxisIntersection(raycaster: Raycaster): Vector3 | undefined {
        const hitObject = raycaster.intersectObjects(this.axisLines)[0]?.object;
        if (!(hitObject instanceof Line2)) {
            return undefined;
        }
        const axisIndex = this.axisLines.findIndex(line => line.uuid === hitObject.uuid);
        return this.axes[axisIndex]?.direction ?? new Vector3(0, 1, 0);
    }

    /**
     * Aims the shared raycaster through the pointer position of a mouse event,
     * converting client coordinates to normalized device coordinates.
     */
    private aimRaycaster(event: MouseEvent) {
        const rect = this.renderer.domElement.getBoundingClientRect();
        const x = ((event.clientX - rect.left) / rect.width) * 2 - 1;
        const y = -((event.clientY - rect.top) / rect.height) * 2 + 1;
        this.raycaster.setFromCamera(new Vector2(x, y), this.camera);
    }

    /**
     * Wires up the mouse behavior: hover cursor feedback, click-on-axis to
     * snap the direction, and drag-the-handle to steer it across the sphere.
     */
    private setupInteraction() {
        const canvas = this.renderer.domElement;

        // Use capture phase for events
        const mouseDown$ = fromEvent<MouseEvent>(canvas, 'mousedown', { capture: true });
        const mouseMove$ = fromEvent<MouseEvent>(canvas, 'mousemove', { capture: true });
        // Listen on the document so a button release outside the canvas still ends the drag.
        const mouseUp$ = fromEvent<MouseEvent>(canvas.ownerDocument, 'mouseup', { capture: true });

        // Handle cursor style on hover
        mouseMove$
            .pipe(
                filter(() => this.isInScene && !this.isDragging),
                takeUntil(this.destroy$),
            )
            .subscribe(event => {
                this.aimRaycaster(event);
                const overHandle = this.raycaster.intersectObject(this.handle).length > 0;
                const overAxis = this.raycaster.intersectObjects(this.axisLines).length > 0;

                if (overHandle || overAxis) {
                    event.stopPropagation();
                    event.preventDefault();
                    canvas.style.cursor = 'pointer';
                } else {
                    canvas.style.cursor = 'default';
                }
            });

        // Handle dragging
        mouseDown$
            .pipe(
                filter(() => this.isInScene),
                map(event => {
                    this.aimRaycaster(event);

                    // The handle takes priority: pressing it starts a drag.
                    const handleIntersects = this.raycaster.intersectObject(this.handle);
                    if (handleIntersects.length > 0) {
                        return handleIntersects[0];
                    }

                    // Otherwise a press on an axis line snaps the direction to
                    // that axis immediately; no drag follows.
                    const axisDirection = this.checkAxisIntersection(this.raycaster);
                    if (axisDirection) {
                        this.setGradientDirection(axisDirection, this.sphere.scale.x);
                    }

                    return null;
                }),
                filter(intersection => !!intersection),
                switchMap(() => {
                    this.isDragging = true;
                    canvas.style.cursor = 'grabbing';
                    store.set(disableControlsAtom, true);

                    return mouseMove$.pipe(
                        takeUntil(
                            mouseUp$.pipe(
                                tap(() => {
                                    this.finishDrag();
                                    canvas.style.cursor = 'pointer';
                                }),
                            ),
                        ),
                        tap(event => {
                            // Stop propagation for all move events while dragging
                            event.stopPropagation();
                            event.preventDefault();
                        }),
                        map(event => {
                            this.aimRaycaster(event);
                            return this.raycaster.intersectObject(this.sphere)[0];
                        }),
                        filter(intersection => !!intersection),
                    );
                }),
                // Placed last so that dispose() also unsubscribes an in-flight drag.
                takeUntil(this.destroy$),
            )
            .subscribe(intersection => {
                if (!intersection) {
                    return;
                }
                // Convert intersection point to local space
                const localPoint = this.container.worldToLocal(intersection.point.clone());
                const direction = localPoint.normalize();
                this.setGradientDirection(direction, this.sphere.scale.x);
            });
    }
}
