import BlueNoiseImage from '../../../../assets/images/BlueNoise256.png';
import type { UniformsObject } from './ShaderBuilder';
import { makeShaderBuilder, type ShaderReplacement } from './ShaderBuilder';
import * as THREE from 'three';

// Uniforms that callers provide (see UndercutMaterialUniforms). Each entry is a GLSL declaration
// of the form "<type> <name>" handed to makeShaderBuilder; order is preserved as written.
const _publicUniformDecls = [
    'vec3 defaultColor', // base color, blended toward the ramp color by the depth fade
    'mat4 axisSpaceMatrix', // maps vertex `position` into depth-map axis space (vertex stage)
    'float vMin', // lower lateral search bound, divided by lateralScale before use
    'float vMax', // upper lateral search bound, divided by lateralScale before use
    'float lateralScale', // divisor applied to lateralBuffer, vMin and vMax
    'float depthFadeDistance', // depth fade length, divided by depthScale before use
    'float depthScale', // divisor applied to depthFadeDistance
    'float overlayOpacity', // multiplies the highlight blend factor
    'float lateralBuffer', // margin from the unit-disc edge inside which the highlight is gated off
    'vec2 depthMapSize', // divisor for the fade kernel's sample offsets
    'sampler2D depthMap', // surface depth, read through the red channel
    'sampler2D colorRamp', // highlight colors, sampled with the normalized lateral distance
] as const;

// Uniforms supplied internally by applyShader (shared blue-noise texture); not part of the public type.
const _internalUniformDecls = [
    'vec2 blueNoiseSize',
    'sampler2D blueNoise',
] as const;

// Every uniform the shader declares: the public ones first, followed by the internal ones.
const _uniformDecls = [..._publicUniformDecls, ..._internalUniformDecls] as const;

/**
 * Vertex-stage injections for makeShaderBuilder. Each entry pairs a three.js shader-chunk include
 * (the anchor) with GLSL that declares and fills `vDepthMapPos`, the vertex position mapped by
 * `axisSpaceMatrix`, which the fragment stage reads.
 */
const _vertReplacements: ShaderReplacement[] = [
    // Declaration of the varying shared with the fragment shader.
    ['#include <common>', /*glsl*/ `
varying vec4 vDepthMapPos;
`],
    // Uses the object-space `position` attribute directly.
    // NOTE(review): three.js applies morph targets / skinning to `transformed`, not `position`, so
    // such deformations are not reflected here — confirm this is intended.
    ['#include <fog_vertex>', /*glsl*/ `
vDepthMapPos = axisSpaceMatrix * vec4(position, 1.0);
`],
];

/**
 * Fragment-stage injections for makeShaderBuilder.
 *
 * The `#include <common>` entry adds helper functions that read the `depthMap`, `colorRamp` and
 * `blueNoise` uniforms. The `#include <color_fragment>` entry uses them to derive a highlight color
 * from the fragment's position in depth-map axis space (`vDepthMapPos`) and writes it into
 * `diffuseColor`.
 */
const _fragmentReplacements: ShaderReplacement[] = [
    [
        '#include <common>',
        /*glsl*/ `
varying vec4 vDepthMapPos;

// Color-ramp lookup at horizontal coordinate x. Samples the vertical texel center (v = 0.5) rather
// than the texture edge so vertical wrapping/filtering cannot pull in neighbouring rows.
vec3 sampleColorRamp(float x) {
    return texture(colorRamp, vec2(x, 0.5)).xyz;
}

vec3 sampleBlueNoise() {
    // gl_FragCoord is in pixels, dividing by the blue noise texture size will result in the blue noise
    // tiling over the entire viewport
    return texture(blueNoise, gl_FragCoord.xy / blueNoiseSize).xyz;
}

// Returns depthCoords.z minus the surface depth (1.0 - red channel of depthMap) at depthCoords.xy.
// For inputs in [0, 1] the result lies in [-1.0, 1.0].
// Positive is "below surface", negative is "above surface".
float queryDepthBelow(vec3 depthCoords) {
    float undercutDepth = 1.0 - texture(depthMap, depthCoords.xy).r;
    return depthCoords.z - undercutDepth;
}

// Soft "below the surface" factor in [0, 1]. Combines a 3x3 kernel of depth samples spaced
// 1 / depthMapSize apart (presumably one depth-map texel) and rotated per pixel by blue noise.
// Each sample ramps from 0 to 1 as its depth below goes from lowDiff to lowDiff + fadeDistance.
float getDepthFade(vec3 depthCoords, float fadeDistance) {
    // Tolerance of 1/255 (one step of an 8-bit channel; presumably the depth map precision).
    float lowDiff = 1.0 / 255.0;
    // Clamp away from zero: normalize() of a zero vector is undefined, and a noise texel whose first
    // two channels are both 0 would otherwise yield NaN sample offsets.
    vec2 u = normalize(max(sampleBlueNoise().xy, vec2(1.0 / 512.0)));
    vec2 v = vec2(-u.y, u.x);
    float fade = 0.0;
    for (int y = -1; y <= 1; y++) {
        for (int x = -1; x <= 1; x++) {
            vec2 offset = (float(x) * u + float(y) * v) / depthMapSize;
            fade += smoothstep(0.0, fadeDistance, queryDepthBelow(depthCoords + vec3(offset, 0.0)) - lowDiff);
        }
    }
    return smoothstep(0.0, 9.0, fade);
}

#define M_PI 3.141592653589793238462

const int NUM_ANGULAR_SAMPLES = 8;
const int NUM_RADIAL_SAMPLES = 8;
// In the below method a sort of variant on a binary search is used to determine escape distance.
// At each iteration NUM_ANGULAR_SAMPLES are sampled in a circle around the current query point at
// the search range midpoint distance. Depending on the number of samples that are above the jaw
// scan depth (meaning not in undercut), the search range is adjusted to be higher or lower. These
// matrices perform the range adjustment. Given our search range as a vec3(lo, mid, hi), if we're
// "above", then mid is too large and range should become vec3(lo, 0.5 * (lo + mid), mid). Likewise
// if "below", then mid is too small and range should become vec3(mid, 0.5 * (mid + hi), hi). I.e.
// range = aboveMat * range; in the "above" case
// range = belowMat * range; in the "below" case
// Further, "above" and "below" are fuzzy concepts. This is encoded by the following constants which
// define a "definitely below" and "definitely above" threshold for depth comparisons.
const float DEFINITELY_BELOW_THRESH = -0.01;
const float DEFINITELY_ABOVE_THRESH = 0.01;

// GLSL has matrices initialized in column-major order
//            [1.0, 0.0, 0.0]
// aboveMat = [0.5, 0.5, 0.0]
//            [0.0, 1.0, 0.0]
const mat3 aboveMat = mat3(1.0, 0.5, 0.0, 0.0, 0.5, 1.0, 0.0, 0.0, 0.0);
//            [0.0, 1.0, 0.0]
// belowMat = [0.0, 0.5, 0.5]
//            [0.0, 0.0, 1.0]
const mat3 belowMat = mat3(0.0, 0.0, 0.0, 1.0, 0.5, 0.0, 0.0, 0.5, 1.0);

// Searches for the lateral escape distance around depthCoords within [minDist, maxDist].
// The range shrinks as soon as roughly one angular sample is above the surface
// (smoothstep(0, 1, aboveCount)), so the search converges toward the smallest probed radius at which
// some sampled direction is above.
// Returns the result normalized to [0, 1] across [minDist, maxDist] (0 => minDist, 1 => maxDist).
float queryLateralDistance(vec3 depthCoords, float minDist, float maxDist) {
    // Search range (lo, mid, hi), normalized to [0, 1]; mapped onto [minDist, maxDist] when sampling.
    vec3 range = vec3(0.0, 0.5, 1.0);
    // Grab a blue noise sample for the current fragment and use it to rotate our samples by an
    // arbitrary pseudo-random amount to avoid aliasing artifacts.
    float theta0 = sampleBlueNoise().z;
    for (int i = 0; i < NUM_RADIAL_SAMPLES; i++) {
        // Probe radius for this iteration. Interpolate between the bounds so the normalized result
        // spans exactly [minDist, maxDist] (and lines up with the color ramp).
        float radius = mix(minDist, maxDist, range.y);
        // A fuzzy number of samples that are above the reference surface, in [0, NUM_ANGULAR_SAMPLES].
        float aboveCount = 0.0;
        for (int j = 0; j < NUM_ANGULAR_SAMPLES; j++) {
            float t = 2.0 * M_PI * (theta0 + float(j) / float(NUM_ANGULAR_SAMPLES));
            vec3 offset = radius * vec3(cos(t), sin(t), 0.0);
            aboveCount += smoothstep(DEFINITELY_BELOW_THRESH, DEFINITELY_ABOVE_THRESH, -queryDepthBelow(depthCoords + offset));
        }
        // To further avoid branches, compute both possible updated ranges and interpolate based
        // on the fuzzy count.
        range = mix(belowMat * range, aboveMat * range, smoothstep(0.0, 1.0, aboveCount));
    }
    // Return the midpoint.
    return range.y;
}
`,
    ],
    [
        '#include <color_fragment>',
        /*glsl*/ `
    vec3 depthMapCoords = vDepthMapPos.xyz / vDepthMapPos.w;
    // 1.0 inside the unit disc shrunk by lateralBuffer / lateralScale, 0.0 outside. Evaluated in the
    // [-1, 1] axis space, before the remap below.
    // NOTE(review): the lateral and depth distances further down are applied after the remap to
    // [0, 1], which halves lengths; confirm lateralScale/depthScale are defined for the space each
    // is used in.
    float gateScalar = step(lateralBuffer / lateralScale, 1.0 - length(depthMapCoords.xy));
    // Normalize coordinates to [0, 1]
    depthMapCoords = 0.5 * depthMapCoords + 0.5;
    float lateralDist = queryLateralDistance(depthMapCoords, vMin / lateralScale, vMax / lateralScale);
    vec3 highlightColor = sampleColorRamp(lateralDist);
    float depthFade = gateScalar * overlayOpacity * getDepthFade(depthMapCoords, depthFadeDistance / depthScale);
    vec3 highlight = mix(defaultColor, highlightColor, depthFade);
    // Darken gated-off fragments to 80%.
    highlight *= 0.8 + 0.2 * gateScalar;
#if defined(USE_MAP)
    // diffuseColor is from the color map. Multiply with
    // highlight for similar behavior to vertex colors.
    diffuseColor.rgb *= highlight;
#else
    // diffuseColor will be multiplied by vertex color. See color_fragment.glsl.js in three.js.
    diffuseColor.rgb = highlight;
#endif // USE_MAP
`,
    ],
];

/** Uniforms a caller must supply; the blue-noise uniforms are added internally by `applyShader`. */
export type UndercutMaterialUniforms = UniformsObject<typeof _publicUniformDecls>;

// One builder for the 'undercut-depth-shader' patch, shared by every material this module creates.
const _factory = makeShaderBuilder(
    'undercut-depth-shader',
    _uniformDecls,
    _vertReplacements,
    _fragmentReplacements,
);

/**
 * Pixel size of the blue-noise tile. Starts at the asset's expected 256x256 and is updated from the
 * decoded image once it loads. Every material's `blueNoiseSize` uniform references this same
 * Vector2, so the update reaches all of them.
 */
const BLUE_NOISE_SIZE = new THREE.Vector2(256, 256);

/**
 * Shared tiling blue-noise texture. The shader samples it at gl_FragCoord / BLUE_NOISE_SIZE, i.e.
 * one texel per screen pixel, to randomize per-pixel sample patterns.
 */
const BLUE_NOISE_TEXTURE = (() => {
    const tex = new THREE.TextureLoader().load(BlueNoiseImage, (loaded) => {
        // Keep the tiling size in sync with the actual asset rather than trusting the default above.
        const image = loaded.image as { width?: number; height?: number } | null | undefined;
        const width = image?.width ?? 0;
        const height = image?.height ?? 0;
        if (width > 0 && height > 0) {
            BLUE_NOISE_SIZE.set(width, height);
        }
    });
    tex.wrapS = THREE.RepeatWrapping;
    tex.wrapT = THREE.RepeatWrapping;
    // Noise must be read texel-exact. Filtering or a lower mip level would average neighbouring
    // texels and wash out the blue-noise pattern, and the mip chain itself is never needed.
    tex.magFilter = THREE.NearestFilter;
    tex.minFilter = THREE.NearestFilter;
    tex.generateMipmaps = false;
    return tex;
})();

/**
 * Patches `material` in place with the undercut shader.
 *
 * @param material - Any three.js material; it is handed to the shared shader builder and returned.
 * @param uniforms - Caller-supplied public uniforms. The internal blue-noise texture and size
 *   uniforms are appended here, each call getting fresh `Uniform` wrappers around the shared
 *   module-level texture and size objects.
 * @returns The same `material` instance, for chaining.
 */
export function applyShader<Material extends THREE.Material>(material: Material, uniforms: UndercutMaterialUniforms) {
    const combinedUniforms = {
        ...uniforms,
        blueNoise: new THREE.Uniform(BLUE_NOISE_TEXTURE),
        blueNoiseSize: new THREE.Uniform(BLUE_NOISE_SIZE),
    };
    _factory.applyShader(material, combinedUniforms);
    return material;
}

/**
 * Creates a `MeshPhongMaterial` and patches it with the undercut shader.
 *
 * @param phongMaterialParams - Forwarded unchanged to the `MeshPhongMaterial` constructor.
 * @param uniforms - Caller-supplied undercut uniforms, forwarded to `applyShader`.
 * @returns The newly created, already-patched material.
 */
export function createDesignMeshUndercutMaterial(
    phongMaterialParams: THREE.MeshPhongMaterialParameters,
    uniforms: UndercutMaterialUniforms,
) {
    return applyShader(new THREE.MeshPhongMaterial(phongMaterialParams), uniforms);
}
