// ace-step-webgpu / _source/src/worker.js
// Initial deploy: built app at root + source under _source/ (commit 24b9788)
// Main worker: DiT + encoders + VAE on WebGPU. Spawns a dedicated LM worker
// (isolated WASM heap) for autoregressive generation.
import { AutoTokenizer } from "@huggingface/transformers";
import * as ort from "onnxruntime-web/webgpu";
// Model repo pinned to an exact revision so cached artifacts stay consistent.
const MODEL_REPO = "shreyask/ACE-Step-v1.5-ONNX";
const MODEL_REVISION = "bdabfb5684fd70fcc76f98cbb51bb9ebc47ee342";
const ONNX_BASE = `https://huggingface.co/${MODEL_REPO}/resolve/${MODEL_REVISION}/onnx`;
// Tokenizer shared with the Qwen3 embedding model used for text conditioning.
const TEXT_TOKENIZER_REPO = "Qwen/Qwen3-Embedding-0.6B";
const SAMPLE_RATE = 48000; // output audio sample rate (Hz)
const LATENT_RATE = 25; // latent frames per second of audio
const LATENT_CHANNELS = 64; // channels per latent frame
const HIDDEN_SIZE = 2048; // hidden dim fed to the detokenizer / FSQ project-out
const POOL_WINDOW = 5; // NOTE(review): unused in this file — confirm before removing
const FSQ_DIM = 6; // FSQ code dimensionality
const NUM_CODES = 64000; // NOTE(review): unused here — presumably the LM code vocab size
// 8-step turbo schedules (from ACE-Step)
const SHIFT_TIMESTEPS_8 = {
  1.0: [1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125],
  2.0: [1.0, 0.9333, 0.8571, 0.7692, 0.6667, 0.5455, 0.4, 0.2222],
  3.0: [1.0, 0.9545, 0.9, 0.8333, 0.75, 0.6429, 0.5, 0.3],
};
// Generate N-step shifted schedule matching MLX port:
// timesteps = linspace(1.0, 0.001, N)
// sigmas = shift * t / (1 + (shift-1) * t)
// For the common 8-step case with a known shift, return the pretabulated
// turbo schedule; otherwise compute the shifted linspace.
function buildSchedule(numSteps, shift) {
  if (numSteps === 8 && SHIFT_TIMESTEPS_8[shift]) return SHIFT_TIMESTEPS_8[shift];
  const sigmaMax = 1.0;
  const sigmaMin = 0.001;
  // Guard numSteps === 1: the linspace below would divide by (numSteps - 1) = 0
  // and yield NaN. A single step integrates from sigmaMax straight to zero.
  if (numSteps === 1) {
    return [(shift * sigmaMax) / (1.0 + (shift - 1.0) * sigmaMax)];
  }
  const schedule = [];
  for (let i = 0; i < numSteps; i++) {
    // linspace inclusive of both endpoints
    const t = sigmaMax + (sigmaMin - sigmaMax) * (i / (numSteps - 1));
    const tShifted = (shift * t) / (1.0 + (shift - 1.0) * t);
    schedule.push(tShifted);
  }
  return schedule;
}
// Bump this name to invalidate previously cached model artifacts.
const CACHE_NAME = "ace-step-onnx-v12";
// Lazily-initialized module state, populated by loadModels().
let textTokenizer = null;
let sessions = {}; // name -> ort.InferenceSession
let silenceLatent = null; // Float32Array reference latent ("silence") for timbre/src inputs
let fsqCodebooks = null; // Float32Array, FSQ_DIM floats per code (indexed idx * FSQ_DIM)
let fsqScales = null; // Float32Array of FSQ_DIM per-dimension scales
let fsqProjectOutW = null; // Float32Array, row-major [HIDDEN_SIZE x FSQ_DIM]
let fsqProjectOutB = null; // Float32Array of HIDDEN_SIZE biases
let lmWorker = null; // dedicated LM worker (isolated WASM heap)
let lmLoaded = false;
// Post a typed message to the main thread. `transfer` is an optional list of
// Transferable objects (e.g. ArrayBuffers) to move instead of clone.
// Fix: generateAudio passes a transfer list as a third argument, which this
// function previously ignored — the WAV buffer was structured-cloned instead
// of transferred.
function post(type, data = {}, transfer = []) {
  self.postMessage({ type, ...data }, transfer);
}
// Download `url` as an ArrayBuffer with streamed progress events, memoized in
// the Cache Storage API under CACHE_NAME. Cache hits report 100% immediately.
async function fetchBuffer(url, label) {
  const cache = await caches.open(CACHE_NAME);
  const cached = await cache.match(url);
  if (cached) {
    post("progress", { label, loaded: 1, total: 1, percent: 100 });
    return await cached.arrayBuffer();
  }
  const response = await fetch(url);
  // Fix: without this check an HTTP error body (e.g. an HTML 404 page) would
  // be streamed, returned, and cached as if it were model weights.
  if (!response.ok) {
    throw new Error(`Download failed (HTTP ${response.status}) for ${label}: ${url}`);
  }
  const total = parseInt(response.headers.get("content-length") || "0", 10);
  const reader = response.body.getReader();
  const chunks = [];
  let loaded = 0;
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    chunks.push(value);
    loaded += value.length;
    if (total > 0) post("progress", { label, loaded, total, percent: (loaded / total) * 100 });
  }
  const buffer = new Uint8Array(loaded);
  let offset = 0;
  for (const chunk of chunks) { buffer.set(chunk, offset); offset += chunk.length; }
  try {
    // Best-effort write-back: a copy of the bytes goes into the cache so
    // reloads skip the network. Failure (e.g. quota exceeded) is non-fatal,
    // so the error is deliberately swallowed.
    await cache.put(url, new Response(buffer.buffer.slice(0), {
      headers: { "Content-Type": "application/octet-stream" },
    }));
  } catch (_) {}
  return buffer.buffer;
}
// Fetch an ONNX graph (and its external-data weights) and create an
// InferenceSession. With useUrlData=true the .data file is handed to ORT as a
// URL so it streams the weights itself (used for the multi-GB models) instead
// of prefetching the whole buffer into memory.
// Fix: the `$(unknown)` placeholders were corrupted `${filename}` template
// interpolations — the `filename` parameter was otherwise never used.
async function loadSession(name, filename, useUrlData = false, providers = ["webgpu"]) {
  post("status", { message: `Loading ${name}...` });
  try {
    const modelBuffer = await fetchBuffer(`${ONNX_BASE}/${filename}`, `${name} graph`);
    if (useUrlData) {
      return await ort.InferenceSession.create(modelBuffer, {
        executionProviders: providers,
        externalData: [{ path: `${filename}.data`, data: `${ONNX_BASE}/${filename}.data` }],
      });
    }
    const weightsBuffer = await fetchBuffer(`${ONNX_BASE}/${filename}.data`, `${name} weights`);
    return await ort.InferenceSession.create(modelBuffer, {
      executionProviders: providers,
      externalData: [{ path: `${filename}.data`, data: weightsBuffer }],
    });
  } catch (err) {
    // Wrap with the session name for the UI; keep the original for debugging.
    throw new Error(`Failed loading ${name}: ${err.message}`, { cause: err });
  }
}
// Convenience constructor for an ort.Tensor; defaults to float32.
function tensor(data, dims, type = "float32") {
  const t = new ort.Tensor(type, data, dims);
  return t;
}
// Debug aid: log length, min, max, and mean of a numeric buffer.
function tensorStats(name, data) {
  const values = data instanceof Float32Array ? data : new Float32Array(data);
  let lo = Infinity;
  let hi = -Infinity;
  let total = 0;
  for (const v of values) {
    if (v < lo) lo = v;
    if (v > hi) hi = v;
    total += v;
  }
  const mean = total / values.length;
  console.log(`[stats] ${name}: len=${values.length} min=${lo.toFixed(4)} max=${hi.toFixed(4)} mean=${mean.toFixed(4)}`);
}
// Fill a Float32Array of prod(shape) elements with standard-normal samples
// via the Box-Muller transform (two samples per iteration).
function randn(shape) {
  const size = shape.reduce((a, b) => a * b, 1);
  const data = new Float32Array(size);
  for (let i = 0; i < size; i += 2) {
    // Fix: Math.random() can return exactly 0, making log(u1) = -Infinity and
    // poisoning the output with Infinity/NaN. Using 1 - Math.random() keeps
    // u1 in the half-open interval (0, 1].
    const u1 = 1 - Math.random();
    const u2 = Math.random();
    const r = Math.sqrt(-2 * Math.log(u1));
    data[i] = r * Math.cos(2 * Math.PI * u2);
    if (i + 1 < size) data[i + 1] = r * Math.sin(2 * Math.PI * u2);
  }
  return data;
}
// Concatenate two padded sequences ([batch, l1, dim] and [batch, l2, dim],
// flattened) along time, moving masked-out entries (mask <= 0) to the tail so
// active tokens are packed at the front. Relies on the stability of
// Array.prototype.sort (guaranteed since ES2019) so that, among equal mask
// values, sequence-1 entries keep their order ahead of sequence-2 entries.
// Returns { hidden, mask (binarized to 0/1), seqLen: l1 + l2 }.
// NOTE(review): not called anywhere in this file — the ONNX condition_encoder
// appears to pack sequences internally; confirm before removing.
function packSequences(hidden1, mask1, hidden2, mask2, batchSize, dim) {
  const l1 = hidden1.length / (batchSize * dim);
  const l2 = hidden2.length / (batchSize * dim);
  const totalLen = l1 + l2;
  const packedHidden = new Float32Array(batchSize * totalLen * dim);
  const packedMask = new Float32Array(batchSize * totalLen);
  for (let b = 0; b < batchSize; b++) {
    // Tag each time step with its source sequence and mask value...
    const indices = [];
    for (let i = 0; i < l1; i++) indices.push({ src: 1, idx: i, mask: mask1[b * l1 + i] });
    for (let i = 0; i < l2; i++) indices.push({ src: 2, idx: i, mask: mask2[b * l2 + i] });
    // ...then stably sort by mask descending (active entries first).
    indices.sort((a, c) => c.mask - a.mask);
    for (let pos = 0; pos < totalLen; pos++) {
      const entry = indices[pos];
      const srcArray = entry.src === 1 ? hidden1 : hidden2;
      const srcLen = entry.src === 1 ? l1 : l2;
      const srcOffset = (b * srcLen + entry.idx) * dim;
      const dstOffset = (b * totalLen + pos) * dim;
      packedHidden.set(srcArray.slice(srcOffset, srcOffset + dim), dstOffset);
      packedMask[b * totalLen + pos] = entry.mask > 0 ? 1 : 0;
    }
  }
  return { hidden: packedHidden, mask: packedMask, seqLen: totalLen };
}
// Dequantize FSQ code indices into HIDDEN_SIZE-dim vectors: look up each
// code's FSQ_DIM entries, scale per dimension, then apply the project-out
// affine map (fsqProjectOutW is row-major [HIDDEN_SIZE x FSQ_DIM]).
function fsqLookup(indices, batchSize, seqLen) {
  const out = new Float32Array(batchSize * seqLen * HIDDEN_SIZE);
  const scaled = new Float32Array(FSQ_DIM);
  for (let b = 0; b < batchSize; b++) {
    for (let t = 0; t < seqLen; t++) {
      const flat = b * seqLen + t;
      const codeBase = indices[flat] * FSQ_DIM;
      for (let d = 0; d < FSQ_DIM; d++) {
        scaled[d] = fsqCodebooks[codeBase + d] * fsqScales[d];
      }
      const dstBase = flat * HIDDEN_SIZE;
      for (let h = 0; h < HIDDEN_SIZE; h++) {
        const wBase = h * FSQ_DIM;
        let acc = fsqProjectOutB[h];
        for (let d = 0; d < FSQ_DIM; d++) acc += scaled[d] * fsqProjectOutW[wBase + d];
        out[dstBase + h] = acc;
      }
    }
  }
  return out;
}
// Create the dedicated LM worker and relay its status/progress/error messages
// up to the main thread unchanged. "loaded" and "audio_codes" messages are
// NOT forwarded: the promise wrappers (loadLMWorker / generateAudioCodesViaLM)
// listen for those themselves.
function spawnLMWorker() {
  const lm = new Worker(new URL("./lm-worker.js", import.meta.url), { type: "module" });
  const forwarded = new Set(["status", "progress", "error"]);
  lm.onmessage = (event) => {
    if (forwarded.has(event.data.type)) {
      self.postMessage(event.data); // forward as-is
    }
  };
  return lm;
}
// Ask the LM worker to load its model; resolves on "loaded", rejects on
// "error". Spawns the worker lazily on first use and marks lmLoaded.
function loadLMWorker() {
  if (!lmWorker) lmWorker = spawnLMWorker();
  return new Promise((resolve, reject) => {
    const handleMessage = (event) => {
      const { type, message } = event.data;
      if (type === "loaded") {
        lmWorker.removeEventListener("message", handleMessage);
        lmLoaded = true;
        resolve();
      } else if (type === "error") {
        lmWorker.removeEventListener("message", handleMessage);
        reject(new Error(message));
      }
    };
    lmWorker.addEventListener("message", handleMessage);
    lmWorker.postMessage({ type: "load" });
  });
}
// Request autoregressive audio-code generation from the LM worker; resolves
// with the full "audio_codes" message payload, rejects on "error".
function generateAudioCodesViaLM({ caption, lyrics, duration, numLatentFrames }) {
  return new Promise((resolve, reject) => {
    const handleMessage = (event) => {
      const { type, message } = event.data;
      if (type === "audio_codes") {
        lmWorker.removeEventListener("message", handleMessage);
        resolve(event.data);
      } else if (type === "error") {
        lmWorker.removeEventListener("message", handleMessage);
        reject(new Error(message));
      }
    };
    lmWorker.addEventListener("message", handleMessage);
    lmWorker.postMessage({ type: "generate", caption, lyrics, duration, numLatentFrames });
  });
}
// Load every inference session plus auxiliary binary blobs, posting status and
// progress to the main thread. The LM worker loads concurrently and is
// awaited last.
async function loadModels() {
  // NOTE(review): single-threaded, non-proxied WASM — presumably to avoid
  // SharedArrayBuffer / cross-origin-isolation requirements; confirm.
  ort.env.wasm.numThreads = 1;
  ort.env.wasm.simd = true;
  ort.env.wasm.proxy = false;
  console.log(`[models] ONNX revision ${MODEL_REVISION}`);
  post("status", { message: `Using ONNX revision ${MODEL_REVISION.slice(0, 7)}` });
  post("status", { message: "Spawning LM worker..." });
  // Kick off LM loading in parallel with main-worker model loads
  const lmLoadPromise = loadLMWorker();
  post("status", { message: "Loading text tokenizer..." });
  textTokenizer = await AutoTokenizer.from_pretrained(TEXT_TOKENIZER_REPO);
  sessions.embedTokens = await loadSession("Embed Tokens", "text_embed_tokens_fp16.onnx");
  sessions.detokenizer = await loadSession("Detokenizer", "detokenizer.onnx");
  // VAE on WASM — WebGPU produces constant output past ~1.5s for conv1d upsample chain
  sessions.vaeDecoder = await loadSession("VAE Decoder (CPU)", "vae_decoder_fp16.onnx", false, ["wasm"]);
  sessions.textEncoder = await loadSession("Text Encoder", "text_encoder_fp16.onnx", true);
  // FP32 condition_encoder — q4v2 had max_diff=13.92 vs PyTorch with real inputs,
  // degrading conditioning so badly that DiT output was garbled. FP32 is 2.4GB via URL.
  sessions.conditionEncoder = await loadSession("Condition Encoder (fp32)", "condition_encoder.onnx", true);
  // DEBUG: dit_decoder_fp16_v2 is the quality baseline (max_diff=0.021 per step).
  // dit_cached trades quality for speed (max_diff=0.074). Reverting while we diagnose
  // the ONNX-vs-MLX spectral gap — compounded drift over 8 steps matters here.
  sessions.ditDecoder = await loadSession("DiT Decoder (uncached)", "dit_decoder_fp16_v2.onnx", true);
  post("status", { message: "Loading auxiliary data..." });
  // FSQ codebook / projection weights plus the reference "silence" latent.
  const [cbBuf, scBuf, powBuf, pobBuf, silBuf] = await Promise.all([
    fetchBuffer(`${ONNX_BASE}/fsq_codebooks.bin`, "codebooks"),
    fetchBuffer(`${ONNX_BASE}/fsq_scales.bin`, "scales"),
    fetchBuffer(`${ONNX_BASE}/fsq_project_out_weight.bin`, "proj_out_w"),
    fetchBuffer(`${ONNX_BASE}/fsq_project_out_bias.bin`, "proj_out_b"),
    fetchBuffer("/silence_latent.bin", "silence latent"), // served from the app origin, not the HF repo
  ]);
  fsqCodebooks = new Float32Array(cbBuf);
  fsqScales = new Float32Array(scBuf);
  fsqProjectOutW = new Float32Array(powBuf);
  fsqProjectOutB = new Float32Array(pobBuf);
  silenceLatent = new Float32Array(silBuf);
  post("status", { message: "Waiting for LM worker..." });
  await lmLoadPromise;
  post("status", { message: "All models loaded!" });
  post("loaded");
}
// Assemble the SFT-style conditioning prompt from its three sections,
// terminated with the <|endoftext|> sentinel.
function buildSFTPrompt(caption, metas) {
  const sections = [
    "# Instruction\nFill the audio semantic mask based on the given conditions:",
    `# Caption\n${caption}`,
    `# Metas\n${metas}`,
  ];
  return `${sections.join("\n\n")}<|endoftext|>`;
}
// Encode caption+metas into projected text hidden states via the SFT prompt.
// NOTE(review): sessions.textProjector is never created in loadModels(), so
// calling this would throw; it (and encodeLyrics/encodeTimbre) looks like an
// older path superseded by the fused condition_encoder used in generateAudio —
// confirm before deleting.
async function encodeText(caption, metas) {
  const prompt = buildSFTPrompt(caption, metas);
  const encoded = textTokenizer(prompt, { padding: "max_length", max_length: 256, truncation: true });
  const idsRaw = encoded.input_ids.data;
  // Tokenizer output may already be BigInt64Array; coerce element-wise otherwise.
  const inputIds = idsRaw instanceof BigInt64Array ? idsRaw : new BigInt64Array(Array.from(idsRaw, BigInt));
  const result = await sessions.textEncoder.run({ input_ids: tensor(inputIds, [1, 256], "int64") });
  const projected = await sessions.textProjector.run({ text_hidden_states: result.hidden_states });
  const maskRaw = encoded.attention_mask.data;
  // Convert (possibly BigInt) mask values to float32 for downstream tensors.
  const attentionMask = new Float32Array(maskRaw.length);
  for (let i = 0; i < maskRaw.length; i++) attentionMask[i] = Number(maskRaw[i]);
  return { hidden: projected.projected.data, mask: attentionMask, seqLen: 256 };
}
// Tokenize lyrics (with a language header), embed the tokens, and run them
// through the lyric encoder.
// NOTE(review): sessions.lyricEncoder is never created in loadModels(); like
// encodeText/encodeTimbre this appears superseded by the fused
// condition_encoder path in generateAudio — confirm before deleting.
async function encodeLyrics(lyrics, language = "en") {
  const fullText = `# Languages\n${language}\n\n# Lyric\n${lyrics}`;
  // max_length=2048 matches the original handler (conditioning_text.py)
  const encoded = textTokenizer(fullText, { padding: "max_length", max_length: 2048, truncation: true });
  const idsRaw = encoded.input_ids.data;
  const inputIds = idsRaw instanceof BigInt64Array ? idsRaw : new BigInt64Array(Array.from(idsRaw, BigInt));
  const seqLen = inputIds.length;
  const embedResult = await sessions.embedTokens.run({ input_ids: tensor(inputIds, [1, seqLen], "int64") });
  const maskRaw = encoded.attention_mask.data;
  // Convert (possibly BigInt) mask values to float32 for the tensor input.
  const attentionMask = new Float32Array(maskRaw.length);
  for (let i = 0; i < maskRaw.length; i++) attentionMask[i] = Number(maskRaw[i]);
  const lyricResult = await sessions.lyricEncoder.run({
    inputs_embeds: embedResult.hidden_states,
    attention_mask: tensor(attentionMask, [1, seqLen]),
  });
  return { hidden: lyricResult.hidden_states.data, mask: attentionMask, seqLen };
}
// Produce a single timbre embedding from 750 frames (30s at LATENT_RATE=25Hz)
// of the silence latent used as reference audio.
// NOTE(review): sessions.timbreEncoder is never created in loadModels();
// appears superseded by the fused condition_encoder — confirm before deleting.
async function encodeTimbre() {
  const silenceRef = silenceLatent.slice(0, 750 * LATENT_CHANNELS);
  const result = await sessions.timbreEncoder.run({
    refer_audio: tensor(silenceRef, [1, 750, LATENT_CHANNELS]),
  });
  const timbreHidden = new Float32Array(HIDDEN_SIZE);
  timbreHidden.set(result.timbre_embedding.data);
  // Single "token" with an always-on mask.
  return { hidden: timbreHidden, mask: new Float32Array([1.0]), seqLen: 1 };
}
// Run the LM worker to produce 5 Hz semantic codes, dequantize them through
// the FSQ codebook (fsqLookup), then detokenize to 25 Hz latent-space hints.
// The result is truncated or padded (last-frame replication, matching the MLX
// port) to exactly numLatentFrames frames of LATENT_CHANNELS channels.
async function generateLMHints(caption, lyrics, numLatentFrames, duration) {
  const { codes, elapsed, tokenCount } = await generateAudioCodesViaLM({ caption, lyrics, duration, numLatentFrames });
  post("status", { message: `LM: ${codes.length} codes from ${tokenCount} tokens in ${elapsed}s` });
  if (codes.length === 0) {
    // Fall back to all-zero hints so the rest of the pipeline can still run.
    console.warn("[lm] No audio codes generated, returning silence");
    return new Float32Array(numLatentFrames * LATENT_CHANNELS);
  }
  const numCodes5Hz = codes.length;
  post("status", { message: "FSQ codebook lookup..." });
  const lmHints5Hz = fsqLookup(codes, 1, numCodes5Hz);
  tensorStats("lm_hints_5hz", lmHints5Hz);
  post("status", { message: "Detokenizing 5Hz → 25Hz..." });
  const detokResult = await sessions.detokenizer.run({
    quantized: tensor(lmHints5Hz, [1, numCodes5Hz, HIDDEN_SIZE]),
  });
  const lmHints25HzRaw = detokResult.lm_hints_25hz.data;
  const rawLen = lmHints25HzRaw.length / LATENT_CHANNELS;
  tensorStats("lm_hints_25hz_raw", lmHints25HzRaw);
  // Pad with last frame (MLX port behavior) or truncate
  const lmHints25Hz = new Float32Array(numLatentFrames * LATENT_CHANNELS);
  if (rawLen >= numLatentFrames) {
    lmHints25Hz.set(lmHints25HzRaw.slice(0, numLatentFrames * LATENT_CHANNELS));
  } else {
    lmHints25Hz.set(lmHints25HzRaw);
    // Repeat last frame to fill remaining
    const lastFrameStart = (rawLen - 1) * LATENT_CHANNELS;
    const lastFrame = lmHints25HzRaw.slice(lastFrameStart, lastFrameStart + LATENT_CHANNELS);
    for (let t = rawLen; t < numLatentFrames; t++) {
      lmHints25Hz.set(lastFrame, t * LATENT_CHANNELS);
    }
    // Fixed log message: the two frame counts were previously fused together
    // (`${rawLen}${numLatentFrames}` with no separator).
    console.log(`[hints] padded ${rawLen} → ${numLatentFrames} frames with last-frame replication`);
  }
  tensorStats("lm_hints_25hz_final", lmHints25Hz);
  return lmHints25Hz;
}
// End-to-end generation: conditioning encoders → LM hints → rectified-flow
// denoising with the DiT → VAE decode → mastering → WAV posted to the UI.
// `shift` warps the timestep schedule (see buildSchedule); numSteps defaults
// to the 8-step turbo configuration.
async function generateAudio({ caption, lyrics, duration, shift, numSteps = 8 }) {
  const totalStartTime = performance.now();
  const filenameStamp = Date.now();
  const batchSize = 1;
  const numLatentFrames = Math.round(duration * LATENT_RATE);
  const tSchedule = buildSchedule(numSteps, shift);
  const metas = `duration: ${duration}s`;
  // 1. Text → Qwen3 embedding (1024-dim hidden states, BEFORE projection)
  post("status", { message: "Encoding text..." });
  const sftPrompt = buildSFTPrompt(caption, metas);
  const textEnc = textTokenizer(sftPrompt, { padding: "max_length", max_length: 256, truncation: true });
  const textIdsRaw = textEnc.input_ids.data;
  // Tokenizer output may already be BigInt64Array; coerce otherwise.
  const textIds = textIdsRaw instanceof BigInt64Array ? textIdsRaw : new BigInt64Array(Array.from(textIdsRaw, BigInt));
  const textHiddenRes = await sessions.textEncoder.run({ input_ids: tensor(textIds, [1, 256], "int64") });
  const textHidden = textHiddenRes.hidden_states;
  const textMaskRaw = textEnc.attention_mask.data;
  const textMask = new Float32Array(textMaskRaw.length);
  for (let i = 0; i < textMaskRaw.length; i++) textMask[i] = Number(textMaskRaw[i]);
  // 2. Lyric tokens → embed_tokens (1024-dim, passed into condition_encoder's lyric_encoder)
  post("status", { message: "Embedding lyrics..." });
  const lyricFullText = `# Languages\nen\n\n# Lyric\n${lyrics}`;
  const lyricEnc = textTokenizer(lyricFullText, { padding: "max_length", max_length: 2048, truncation: true });
  const lyricIdsRaw = lyricEnc.input_ids.data;
  const lyricIds = lyricIdsRaw instanceof BigInt64Array ? lyricIdsRaw : new BigInt64Array(Array.from(lyricIdsRaw, BigInt));
  const lyricEmbRes = await sessions.embedTokens.run({ input_ids: tensor(lyricIds, [1, 2048], "int64") });
  const lyricEmb = lyricEmbRes.hidden_states;
  const lyricMaskRaw = lyricEnc.attention_mask.data;
  const lyricMask = new Float32Array(lyricMaskRaw.length);
  for (let i = 0; i < lyricMaskRaw.length; i++) lyricMask[i] = Number(lyricMaskRaw[i]);
  // 3. LM hints (mandatory for turbo model)
  const lmHints25Hz = await generateLMHints(caption, lyrics, numLatentFrames, duration);
  // 4. Silence for ref audio (timbre) and src_latents
  const silenceRef = silenceLatent.slice(0, 750 * LATENT_CHANNELS);
  const srcLatents = new Float32Array(numLatentFrames * LATENT_CHANNELS);
  const chunkMasks = new Float32Array(numLatentFrames * LATENT_CHANNELS).fill(1.0);
  const isCovers = new Float32Array([1.0]); // force use of LM hints
  // 5. condition_encoder: does text_projector + lyric_encoder + timbre_encoder + pack_sequences + context_latents
  post("status", { message: "Running condition encoder..." });
  const condResult = await sessions.conditionEncoder.run({
    text_hidden_states: textHidden,
    text_attention_mask: tensor(textMask, [1, 256]),
    lyric_hidden_states: lyricEmb,
    lyric_attention_mask: tensor(lyricMask, [1, 2048]),
    refer_audio_acoustic_hidden_states_packed: tensor(silenceRef, [1, 750, LATENT_CHANNELS]),
    refer_audio_order_mask: tensor(new BigInt64Array([0n]), [1], "int64"),
    src_latents: tensor(srcLatents, [1, numLatentFrames, LATENT_CHANNELS]),
    chunk_masks: tensor(chunkMasks, [1, numLatentFrames, LATENT_CHANNELS]),
    is_covers: tensor(isCovers, [1]),
    precomputed_lm_hints_25hz: tensor(lmHints25Hz, [1, numLatentFrames, LATENT_CHANNELS]),
  });
  const encoderHiddenStates = condResult.encoder_hidden_states;
  const contextLatentsTensor = condResult.context_latents;
  tensorStats("encoder_hidden_states", encoderHiddenStates.data);
  tensorStats("context_latents", contextLatentsTensor.data);
  // 6. Euler integration of the flow ODE from Gaussian noise:
  // x <- x - v * dt between consecutive schedule points; the last step uses
  // dt = tCurr, i.e. integrates the rest of the way to t = 0.
  post("status", { message: "Starting denoising..." });
  let xt = randn([batchSize, numLatentFrames, LATENT_CHANNELS]);
  const startTime = performance.now();
  for (let step = 0; step < tSchedule.length; step++) {
    const tCurr = tSchedule[step];
    post("status", { message: `Denoising step ${step + 1}/${tSchedule.length}...` });
    const timestepData = new Float32Array(batchSize).fill(tCurr);
    const result = await sessions.ditDecoder.run({
      hidden_states: tensor(xt, [batchSize, numLatentFrames, LATENT_CHANNELS]),
      timestep: tensor(timestepData, [batchSize]),
      encoder_hidden_states: encoderHiddenStates,
      context_latents: contextLatentsTensor,
    });
    const vt = result.velocity.data;
    if (step === tSchedule.length - 1) {
      for (let i = 0; i < xt.length; i++) xt[i] = xt[i] - vt[i] * tCurr;
    } else {
      const dt = tCurr - tSchedule[step + 1];
      for (let i = 0; i < xt.length; i++) xt[i] = xt[i] - vt[i] * dt;
    }
  }
  const diffusionTime = ((performance.now() - startTime) / 1000).toFixed(2);
  tensorStats("final_latent", xt);
  // Per-frame variance check — detects if later frames are constant
  const perFrameVariance = new Float32Array(numLatentFrames);
  for (let t = 0; t < numLatentFrames; t++) {
    let mean = 0;
    for (let c = 0; c < LATENT_CHANNELS; c++) mean += xt[t * LATENT_CHANNELS + c];
    mean /= LATENT_CHANNELS;
    let varSum = 0;
    for (let c = 0; c < LATENT_CHANNELS; c++) {
      const d = xt[t * LATENT_CHANNELS + c] - mean;
      varSum += d * d;
    }
    perFrameVariance[t] = varSum / LATENT_CHANNELS;
  }
  // Sample one frame per second (every 25th at 25 Hz) for the console.
  console.log("[perframe] variance samples:", Array.from(perFrameVariance.filter((_, i) => i % 25 === 0)).map(v => v.toFixed(3)));
  // Also check LM hints per-frame variance
  const hintsVar = new Float32Array(numLatentFrames);
  for (let t = 0; t < numLatentFrames; t++) {
    let mean = 0;
    for (let c = 0; c < LATENT_CHANNELS; c++) mean += lmHints25Hz[t * LATENT_CHANNELS + c];
    mean /= LATENT_CHANNELS;
    let varSum = 0;
    for (let c = 0; c < LATENT_CHANNELS; c++) {
      const d = lmHints25Hz[t * LATENT_CHANNELS + c] - mean;
      varSum += d * d;
    }
    hintsVar[t] = varSum / LATENT_CHANNELS;
  }
  console.log("[hints var] samples:", Array.from(hintsVar.filter((_, i) => i % 25 === 0)).map(v => v.toFixed(3)));
  post("status", { message: "Decoding audio..." });
  // Transpose [frames, channels] → [channels, frames] for the VAE input layout.
  const latentsForVae = new Float32Array(batchSize * LATENT_CHANNELS * numLatentFrames);
  for (let t = 0; t < numLatentFrames; t++) {
    for (let c = 0; c < LATENT_CHANNELS; c++) {
      latentsForVae[c * numLatentFrames + t] = xt[t * LATENT_CHANNELS + c];
    }
  }
  const vaeResult = await sessions.vaeDecoder.run({
    latents: tensor(latentsForVae, [batchSize, LATENT_CHANNELS, numLatentFrames]),
  });
  const waveform = vaeResult.waveform.data;
  tensorStats("waveform", waveform);
  // Mastering mutates `waveform` in place before serialization.
  masterWaveform(waveform, SAMPLE_RATE, 2);
  const wavBuffer = float32ToWav(waveform, SAMPLE_RATE, 2);
  // totalTime measures the whole pipeline (LM + encoders + diffusion + VAE),
  // not just the diffusion loop. diffusionTime is reported separately below.
  const totalTime = ((performance.now() - totalStartTime) / 1000).toFixed(2);
  // The trailing [wavBuffer] is intended as a postMessage transfer list.
  post("audio", { wavBuffer, duration, diffusionTime, totalTime, filenameStamp }, [wavBuffer]);
}
// Return { peak, rms }: peak absolute amplitude and root-mean-square level of
// a sample buffer (empty input yields 0/0).
function measureAudio(samples) {
  let peak = 0;
  let energy = 0;
  for (const v of samples) {
    const mag = Math.abs(v);
    if (mag > peak) peak = mag;
    energy += v * v;
  }
  const rms = Math.sqrt(energy / Math.max(1, samples.length));
  return { peak, rms };
}
// Power of a single frequency bin computed with the Goertzel recurrence
// (cheaper than a full FFT when only a handful of bins are needed).
function goertzelPower(data, sampleRate, freq) {
  const omega = 2 * Math.PI * freq / sampleRate;
  const coeff = 2 * Math.cos(omega);
  let prev = 0; // s[n-1]
  let prev2 = 0; // s[n-2]
  for (const x of data) {
    const s = x + coeff * prev - prev2;
    prev2 = prev;
    prev = s;
  }
  return prev * prev + prev2 * prev2 - coeff * prev * prev2;
}
// Detect up to two narrowband "drone" tones between 250 and 950 Hz.
// Strategy: decimate the planar waveform (layout samples[ch * numSamples + i])
// to roughly 4 kHz mono, remove DC, then scan 12.5 Hz-spaced Goertzel bins;
// a bin counts as a peak when its power is at least 12x the median bin power
// and it lies at least 50 Hz from any stronger peak already kept.
function detectDronePeaks(samples, sampleRate, channels) {
  const numSamples = samples.length / channels;
  // Naive decimation (no anti-alias filter) — acceptable for this heuristic.
  const step = Math.max(1, Math.floor(sampleRate / 4000));
  const downsampleRate = sampleRate / step;
  const downsampledLength = Math.floor(numSamples / step);
  if (downsampledLength < 1024) return []; // too short to measure reliably
  const mono = new Float32Array(downsampledLength);
  let mean = 0;
  for (let i = 0; i < downsampledLength; i++) {
    const src = i * step;
    let v = 0;
    for (let ch = 0; ch < channels; ch++) v += samples[ch * numSamples + src];
    v /= channels;
    mono[i] = v;
    mean += v;
  }
  mean /= downsampledLength;
  // Remove DC so the low bins are not polluted by offset.
  for (let i = 0; i < mono.length; i++) mono[i] -= mean;
  const bins = [];
  for (let freq = 250; freq <= 950; freq += 12.5) {
    bins.push({ freq, power: goertzelPower(mono, downsampleRate, freq) });
  }
  const sortedPowers = bins.map((bin) => bin.power).sort((a, b) => a - b);
  const median = sortedPowers[Math.floor(sortedPowers.length / 2)] + 1e-12; // epsilon avoids divide-by-zero
  bins.sort((a, b) => b.power - a.power); // strongest bins first
  const peaks = [];
  for (const bin of bins) {
    const score = bin.power / median;
    if (score < 12) break; // bins are sorted; everything after is weaker
    if (peaks.every((peak) => Math.abs(peak.freq - bin.freq) >= 50)) {
      peaks.push({ freq: bin.freq, score });
      if (peaks.length >= 2) break;
    }
  }
  return peaks;
}
// In-place biquad notch filter at `freq`, blended with the dry signal by
// `depth` (0 = dry only, 1 = fully notched). Operates independently on each
// channel of planar audio (samples[ch * frames + i]).
function applyNotch(samples, sampleRate, channels, freq, q = 20, depth = 0.45) {
  const frames = samples.length / channels;
  const w0 = 2 * Math.PI * freq / sampleRate;
  const cosW0 = Math.cos(w0);
  const alpha = Math.sin(w0) / (2 * q);
  // RBJ audio-EQ-cookbook notch coefficients, pre-normalized by a0 = 1 + alpha.
  const norm = 1 + alpha;
  const b0 = 1 / norm;
  const b1 = (-2 * cosW0) / norm;
  const b2 = 1 / norm;
  const a1 = (-2 * cosW0) / norm;
  const a2 = (1 - alpha) / norm;
  for (let ch = 0; ch < channels; ch++) {
    const base = ch * frames;
    let xPrev = 0, xPrev2 = 0, yPrev = 0, yPrev2 = 0;
    for (let i = 0; i < frames; i++) {
      const x = samples[base + i];
      const y = b0 * x + b1 * xPrev + b2 * xPrev2 - a1 * yPrev - a2 * yPrev2;
      samples[base + i] = x * (1 - depth) + y * depth;
      xPrev2 = xPrev;
      xPrev = x;
      yPrev2 = yPrev;
      yPrev = y;
    }
  }
}
// Post-process the decoded waveform in place: notch out any detected drone
// tones, then apply one global gain chosen to reach a target RMS without the
// peak exceeding a ceiling, capped at 12x. Near-silent audio is left alone.
function masterWaveform(samples, sampleRate, channels) {
  const raw = measureAudio(samples);
  if (raw.peak <= 0.001) return; // effectively silent — nothing to master
  const dronePeaks = detectDronePeaks(samples, sampleRate, channels);
  for (const peak of dronePeaks) {
    applyNotch(samples, sampleRate, channels, peak.freq);
  }
  const afterEq = measureAudio(samples);
  const targetRms = 0.085;
  const maxPeak = 0.891;
  const maxGain = 12.0;
  const rmsGain = targetRms / Math.max(afterEq.rms, 1e-6);
  const peakGain = maxPeak / Math.max(afterEq.peak, 1e-6);
  const gain = Math.min(maxGain, rmsGain, peakGain);
  for (let i = 0; i < samples.length; i++) samples[i] *= gain;
  const after = measureAudio(samples);
  const peakText = dronePeaks.map((peak) => `${peak.freq.toFixed(1)}Hz/${peak.score.toFixed(0)}x`).join(", ") || "none";
  console.log(
    `[master] rawPeak=${raw.peak.toFixed(4)} rawRms=${raw.rms.toFixed(4)} ` +
    `dronePeaks=${peakText} gain=${gain.toFixed(2)}x peak=${after.peak.toFixed(4)} rms=${after.rms.toFixed(4)}`,
  );
}
// Serialize planar float samples (channel-major: samples[ch * numFrames + i])
// into an interleaved 16-bit PCM RIFF/WAVE file; returns the ArrayBuffer.
function float32ToWav(samples, sampleRate, channels = 2) {
  const numFrames = samples.length / channels;
  const bitsPerSample = 16;
  const blockAlign = channels * (bitsPerSample / 8);
  const byteRate = sampleRate * blockAlign;
  const dataSize = numFrames * blockAlign;
  const buffer = new ArrayBuffer(44 + dataSize);
  const view = new DataView(buffer);
  const writeAscii = (offset, text) => {
    for (let i = 0; i < text.length; i++) view.setUint8(offset + i, text.charCodeAt(i));
  };
  // RIFF header
  writeAscii(0, "RIFF");
  view.setUint32(4, 36 + dataSize, true);
  writeAscii(8, "WAVE");
  // fmt chunk (uncompressed PCM)
  writeAscii(12, "fmt ");
  view.setUint32(16, 16, true); // fmt chunk size
  view.setUint16(20, 1, true); // format tag: PCM
  view.setUint16(22, channels, true);
  view.setUint32(24, sampleRate, true);
  view.setUint32(28, byteRate, true);
  view.setUint16(32, blockAlign, true);
  view.setUint16(34, bitsPerSample, true);
  // data chunk: interleave channels, clamp to [-1, 1], scale to int16
  writeAscii(36, "data");
  view.setUint32(40, dataSize, true);
  let cursor = 44;
  for (let frame = 0; frame < numFrames; frame++) {
    for (let ch = 0; ch < channels; ch++) {
      const clamped = Math.max(-1, Math.min(1, samples[ch * numFrames + frame]));
      view.setInt16(cursor, clamped * 32767, true);
      cursor += 2;
    }
  }
  return buffer;
}
// Message dispatcher: the main thread sends {type:"load"} once, then
// {type:"generate", caption, lyrics, duration, shift, ...} per request.
// Any thrown error is serialized back as {type:"error"} so the UI can show it.
self.onmessage = async (e) => {
  const { type, ...data } = e.data;
  try {
    if (type === "load") await loadModels();
    else if (type === "generate") await generateAudio(data);
  } catch (err) {
    post("error", { message: err.message, stack: err.stack });
  }
};