import { MeshPhysicalMaterial } from 'three';
import type { WebGLProgramParametersWithUniforms, MeshPhysicalMaterialParameters } from 'three';

/**
 * A `MeshPhysicalMaterial` whose diffuse color can be switched to greyscale at runtime.
 *
 * The switch is a GLSL `bool` uniform named `isGreyScale` that is injected into the
 * stock physical fragment shader via `onBeforeCompile`. Toggling it only changes the
 * uniform's value, so no shader recompilation is needed.
 */
class MeshPainterMaterial extends MeshPhysicalMaterial {
    /**
     * The single uniform object handed to every program compiled for this material.
     * Sharing one object (rather than remembering the uniforms of the last compile)
     * means every shader variant of this material observes changes to `.value`, and
     * the flag can be set before the first compile without any extra bookkeeping.
     */
    private readonly greyScaleUniform: { value: boolean } = { value: false };

    /**
     * @param parameters Standard `MeshPhysicalMaterial` parameters. Optional because
     *   three.js `Material.clone()` constructs the subclass with no arguments.
     */
    constructor(parameters?: MeshPhysicalMaterialParameters) {
        super(parameters);

        this.onBeforeCompile = (shader: WebGLProgramParametersWithUniforms) => {
            // Hand the shared uniform object to this program; the renderer reads `.value`
            // when uploading uniforms, so later `setGreyScale` calls take effect directly.
            shader.uniforms.isGreyScale = this.greyScaleUniform;

            // Declare the uniform just before `main` (replace() swaps the first match only).
            shader.fragmentShader = shader.fragmentShader.replace(
                'void main() {',
                `
                uniform bool isGreyScale;
    
                void main() {`,
            );

            // Desaturate right after the color chunk so `diffuseColor` already holds the
            // chunk's contribution. The weights are the Rec. 601 luma coefficients.
            shader.fragmentShader = shader.fragmentShader.replace(
                '#include <color_fragment>',
                `
                #include <color_fragment>
                if (isGreyScale) {
                    float gray = dot(diffuseColor.rgb, vec3(0.299, 0.587, 0.114));
                    diffuseColor.rgb = vec3(gray);
                }
                `,
            );
        };
    }

    /**
     * Enables or disables greyscale rendering of the diffuse color.
     *
     * Only a uniform value changes, so `needsUpdate` is intentionally not set: there is
     * nothing to recompile, and setting it would force the renderer to re-evaluate
     * program parameters on every toggle.
     *
     * @param value `true` to render in greyscale, `false` for normal color.
     */
    public setGreyScale(value: boolean): void {
        this.greyScaleUniform.value = value;
    }

    /**
     * Copies material state from `source`. When `source` is also a
     * `MeshPainterMaterial`, the greyscale flag is copied too, so `clone()`
     * preserves it.
     *
     * @param source Material to copy from.
     * @returns This material, for chaining.
     */
    public override copy(source: MeshPhysicalMaterial): this {
        super.copy(source);
        if (source instanceof MeshPainterMaterial) {
            this.greyScaleUniform.value = source.greyScaleUniform.value;
        }
        return this;
    }
}

export default MeshPainterMaterial;
