// Main worker: DiT + encoders + VAE on WebGPU. Spawns a dedicated LM worker
// (isolated WASM heap) for autoregressive generation.
import { AutoTokenizer } from "@huggingface/transformers";
import * as ort from "onnxruntime-web/webgpu";

const MODEL_REPO = "shreyask/ACE-Step-v1.5-ONNX";
const MODEL_REVISION = "bdabfb5684fd70fcc76f98cbb51bb9ebc47ee342";
const ONNX_BASE = `https://huggingface.co/${MODEL_REPO}/resolve/${MODEL_REVISION}/onnx`;
const TEXT_TOKENIZER_REPO = "Qwen/Qwen3-Embedding-0.6B";

const SAMPLE_RATE = 48000;
const LATENT_RATE = 25;
const LATENT_CHANNELS = 64;
const HIDDEN_SIZE = 2048;
const POOL_WINDOW = 5;
const FSQ_DIM = 6;
const NUM_CODES = 64000;

// 8-step turbo schedules (from ACE-Step)
const SHIFT_TIMESTEPS_8 = {
  1.0: [1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125],
  2.0: [1.0, 0.9333, 0.8571, 0.7692, 0.6667, 0.5455, 0.4, 0.2222],
  3.0: [1.0, 0.9545, 0.9, 0.8333, 0.75, 0.6429, 0.5, 0.3],
};

// Generate an N-step shifted schedule matching the MLX port:
//   timesteps = linspace(1.0, 0.001, N)
//   sigmas = shift * t / (1 + (shift - 1) * t)
function buildSchedule(numSteps, shift) {
  if (numSteps === 8 && SHIFT_TIMESTEPS_8[shift]) return SHIFT_TIMESTEPS_8[shift];
  const sigmaMax = 1.0;
  const sigmaMin = 0.001;
  const schedule = [];
  for (let i = 0; i < numSteps; i++) {
    // linspace inclusive of both endpoints
    const t = sigmaMax + (sigmaMin - sigmaMax) * (i / (numSteps - 1));
    const tShifted = (shift * t) / (1.0 + (shift - 1.0) * t);
    schedule.push(tShifted);
  }
  return schedule;
}
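
// Worked example (illustrative): buildSchedule(4, 3.0)
//   t       = [1.0, 0.667, 0.334, 0.001]   (linspace)
//   shifted = [1.0, 0.857, 0.601, 0.003]   (3t / (1 + 2t))
// The shift front-loads the schedule toward high noise levels.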

const CACHE_NAME = "ace-step-onnx-v12";

let textTokenizer = null;
let sessions = {};
let silenceLatent = null;
let fsqCodebooks = null;
let fsqScales = null;
let fsqProjectOutW = null;
let fsqProjectOutB = null;
let lmWorker = null;
let lmLoaded = false;

function post(type, data = {}, transfer = []) {
  // Accept a transfer list so callers (e.g. the final "audio" message) can
  // hand off ArrayBuffers instead of copying them.
  self.postMessage({ type, ...data }, transfer);
}

async function fetchBuffer(url, label) {
  const cache = await caches.open(CACHE_NAME);
  const cached = await cache.match(url);
  if (cached) {
    post("progress", { label, loaded: 1, total: 1, percent: 100 });
    return await cached.arrayBuffer();
  }
  const response = await fetch(url);
  if (!response.ok) throw new Error(`${label}: HTTP ${response.status} for ${url}`);
  const total = parseInt(response.headers.get("content-length") || "0", 10);
  const reader = response.body.getReader();
  const chunks = [];
  let loaded = 0;
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    chunks.push(value);
    loaded += value.length;
    if (total > 0) post("progress", { label, loaded, total, percent: (loaded / total) * 100 });
  }
  const buffer = new Uint8Array(loaded);
  let offset = 0;
  for (const chunk of chunks) { buffer.set(chunk, offset); offset += chunk.length; }
  try {
    await cache.put(url, new Response(buffer.buffer.slice(0), {
      headers: { "Content-Type": "application/octet-stream" },
    }));
  } catch (_) {}
  return buffer.buffer;
}
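
// Cache behavior: the first call for a URL streams the download (emitting
// progress events) and writes it to Cache Storage; later calls are served from
// cache and report 100% immediately. Bumping CACHE_NAME invalidates everything.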

async function loadSession(name, filename, useUrlData = false, providers = ["webgpu"]) {
  post("status", { message: `Loading ${name}...` });
  try {
    const modelBuffer = await fetchBuffer(`${ONNX_BASE}/${filename}`, `${name} graph`);
    if (useUrlData) {
      // Let ORT fetch the external weights directly from the URL.
      return await ort.InferenceSession.create(modelBuffer, {
        executionProviders: providers,
        externalData: [{ path: `${filename}.data`, data: `${ONNX_BASE}/${filename}.data` }],
      });
    }
    const weightsBuffer = await fetchBuffer(`${ONNX_BASE}/${filename}.data`, `${name} weights`);
    return await ort.InferenceSession.create(modelBuffer, {
      executionProviders: providers,
      externalData: [{ path: `${filename}.data`, data: weightsBuffer }],
    });
  } catch (err) {
    throw new Error(`Failed loading ${name}: ${err.message}`);
  }
}

function tensor(data, dims, type = "float32") {
  return new ort.Tensor(type, data, dims);
}

function tensorStats(name, data) {
  const arr = data instanceof Float32Array ? data : new Float32Array(data);
  let min = Infinity, max = -Infinity, sum = 0;
  for (let i = 0; i < arr.length; i++) {
    if (arr[i] < min) min = arr[i];
    if (arr[i] > max) max = arr[i];
    sum += arr[i];
  }
  console.log(`[stats] ${name}: len=${arr.length} min=${min.toFixed(4)} max=${max.toFixed(4)} mean=${(sum / arr.length).toFixed(4)}`);
}

// Standard-normal noise via the Box-Muller transform.
function randn(shape) {
  const size = shape.reduce((a, b) => a * b, 1);
  const data = new Float32Array(size);
  for (let i = 0; i < size; i += 2) {
    // 1 - Math.random() keeps u1 in (0, 1], avoiding log(0) = -Infinity.
    const u1 = 1 - Math.random();
    const u2 = Math.random();
    const r = Math.sqrt(-2 * Math.log(u1));
    data[i] = r * Math.cos(2 * Math.PI * u2);
    if (i + 1 < size) data[i + 1] = r * Math.sin(2 * Math.PI * u2);
  }
  return data;
}
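
// Sanity sketch (illustrative, not wired into the pipeline): randn samples
// should have mean ≈ 0 and standard deviation ≈ 1.
// const z = randn([100000]);
// const m = z.reduce((a, b) => a + b, 0) / z.length;
// console.log(m, Math.sqrt(z.reduce((a, b) => a + (b - m) ** 2, 0) / z.length));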

// Concatenate two padded sequences per batch row, sorting valid (mask=1)
// entries ahead of padding so the packed sequence is left-aligned.
// Currently unused here: the exported condition_encoder graph performs this
// packing internally (see loadModels / generateAudio).
function packSequences(hidden1, mask1, hidden2, mask2, batchSize, dim) {
  const l1 = hidden1.length / (batchSize * dim);
  const l2 = hidden2.length / (batchSize * dim);
  const totalLen = l1 + l2;
  const packedHidden = new Float32Array(batchSize * totalLen * dim);
  const packedMask = new Float32Array(batchSize * totalLen);
  for (let b = 0; b < batchSize; b++) {
    const indices = [];
    for (let i = 0; i < l1; i++) indices.push({ src: 1, idx: i, mask: mask1[b * l1 + i] });
    for (let i = 0; i < l2; i++) indices.push({ src: 2, idx: i, mask: mask2[b * l2 + i] });
    // Stable sort: valid tokens first, original order preserved within each group.
    indices.sort((a, c) => c.mask - a.mask);
    for (let pos = 0; pos < totalLen; pos++) {
      const entry = indices[pos];
      const srcArray = entry.src === 1 ? hidden1 : hidden2;
      const srcLen = entry.src === 1 ? l1 : l2;
      const srcOffset = (b * srcLen + entry.idx) * dim;
      const dstOffset = (b * totalLen + pos) * dim;
      packedHidden.set(srcArray.slice(srcOffset, srcOffset + dim), dstOffset);
      packedMask[b * totalLen + pos] = entry.mask > 0 ? 1 : 0;
    }
  }
  return { hidden: packedHidden, mask: packedMask, seqLen: totalLen };
}
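
// Worked example (illustrative): with mask1 = [1, 0] and mask2 = [1, 1], the
// packed order per row is [h1[0], h2[0], h2[1], h1[1]] with mask [1, 1, 1, 0] —
// valid tokens from both sources are pulled ahead of the padding slot.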

// Map LM code indices to HIDDEN_SIZE-dim vectors: FSQ codebook lookup
// (FSQ_DIM entries per code), per-dimension scaling, then the learned
// project_out affine (out = W · (code ⊙ scale) + b) applied per timestep.
function fsqLookup(indices, batchSize, seqLen) {
  const out = new Float32Array(batchSize * seqLen * HIDDEN_SIZE);
  for (let b = 0; b < batchSize; b++) {
    for (let t = 0; t < seqLen; t++) {
      const idx = indices[b * seqLen + t];
      const codeOffset = idx * FSQ_DIM;
      const scaledCode = new Float32Array(FSQ_DIM);
      for (let d = 0; d < FSQ_DIM; d++) scaledCode[d] = fsqCodebooks[codeOffset + d] * fsqScales[d];
      const outOffset = (b * seqLen + t) * HIDDEN_SIZE;
      for (let h = 0; h < HIDDEN_SIZE; h++) {
        let val = fsqProjectOutB[h];
        for (let d = 0; d < FSQ_DIM; d++) val += scaledCode[d] * fsqProjectOutW[h * FSQ_DIM + d];
        out[outOffset + h] = val;
      }
    }
  }
  return out;
}
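
// Cost sketch (illustrative): a 30 s clip yields roughly 150 codes at 5 Hz, so
// fsqLookup produces a [1, 150, 2048] buffer via a 6 → 2048 matvec per code —
// about 150 × 2048 × 6 ≈ 1.8M multiply-adds, small enough to run in plain JS.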

// Spawn the LM worker and forward its status/progress messages up to the main thread.
function spawnLMWorker() {
  const worker = new Worker(new URL("./lm-worker.js", import.meta.url), { type: "module" });
  worker.onmessage = (e) => {
    const { type } = e.data;
    if (type === "status" || type === "progress" || type === "error") {
      self.postMessage(e.data); // forward as-is
    }
    // "loaded" and "audio_codes" are handled by the promise-based callers below.
  };
  return worker;
}

function loadLMWorker() {
  return new Promise((resolve, reject) => {
    if (!lmWorker) lmWorker = spawnLMWorker();
    const onMsg = (e) => {
      if (e.data.type === "loaded") {
        lmWorker.removeEventListener("message", onMsg);
        lmLoaded = true;
        resolve();
      } else if (e.data.type === "error") {
        lmWorker.removeEventListener("message", onMsg);
        reject(new Error(e.data.message));
      }
    };
    lmWorker.addEventListener("message", onMsg);
    lmWorker.postMessage({ type: "load" });
  });
}

function generateAudioCodesViaLM({ caption, lyrics, duration, numLatentFrames }) {
  return new Promise((resolve, reject) => {
    const onMsg = (e) => {
      if (e.data.type === "audio_codes") {
        lmWorker.removeEventListener("message", onMsg);
        resolve(e.data);
      } else if (e.data.type === "error") {
        lmWorker.removeEventListener("message", onMsg);
        reject(new Error(e.data.message));
      }
    };
    lmWorker.addEventListener("message", onMsg);
    lmWorker.postMessage({ type: "generate", caption, lyrics, duration, numLatentFrames });
  });
}

async function loadModels() {
  ort.env.wasm.numThreads = 1;
  ort.env.wasm.simd = true;
  ort.env.wasm.proxy = false;
  console.log(`[models] ONNX revision ${MODEL_REVISION}`);
  post("status", { message: `Using ONNX revision ${MODEL_REVISION.slice(0, 7)}` });
  post("status", { message: "Spawning LM worker..." });
  // Kick off LM loading in parallel with the main-worker model loads.
  const lmLoadPromise = loadLMWorker();
  post("status", { message: "Loading text tokenizer..." });
  textTokenizer = await AutoTokenizer.from_pretrained(TEXT_TOKENIZER_REPO);
  sessions.embedTokens = await loadSession("Embed Tokens", "text_embed_tokens_fp16.onnx");
  sessions.detokenizer = await loadSession("Detokenizer", "detokenizer.onnx");
  // VAE on WASM — WebGPU produces constant output past ~1.5s for the conv1d upsample chain.
  sessions.vaeDecoder = await loadSession("VAE Decoder (CPU)", "vae_decoder_fp16.onnx", false, ["wasm"]);
  sessions.textEncoder = await loadSession("Text Encoder", "text_encoder_fp16.onnx", true);
  // FP32 condition_encoder — q4v2 had max_diff=13.92 vs PyTorch with real inputs,
  // degrading conditioning so badly that DiT output was garbled. FP32 is 2.4GB via URL.
  sessions.conditionEncoder = await loadSession("Condition Encoder (fp32)", "condition_encoder.onnx", true);
  // DEBUG: dit_decoder_fp16_v2 is the quality baseline (max_diff=0.021 per step).
  // dit_cached trades quality for speed (max_diff=0.074). Reverting while we diagnose
  // the ONNX-vs-MLX spectral gap — compounded drift over 8 steps matters here.
  sessions.ditDecoder = await loadSession("DiT Decoder (uncached)", "dit_decoder_fp16_v2.onnx", true);
  post("status", { message: "Loading auxiliary data..." });
  const [cbBuf, scBuf, powBuf, pobBuf, silBuf] = await Promise.all([
    fetchBuffer(`${ONNX_BASE}/fsq_codebooks.bin`, "codebooks"),
    fetchBuffer(`${ONNX_BASE}/fsq_scales.bin`, "scales"),
    fetchBuffer(`${ONNX_BASE}/fsq_project_out_weight.bin`, "proj_out_w"),
    fetchBuffer(`${ONNX_BASE}/fsq_project_out_bias.bin`, "proj_out_b"),
    fetchBuffer("/silence_latent.bin", "silence latent"),
  ]);
  fsqCodebooks = new Float32Array(cbBuf);
  fsqScales = new Float32Array(scBuf);
  fsqProjectOutW = new Float32Array(powBuf);
  fsqProjectOutB = new Float32Array(pobBuf);
  silenceLatent = new Float32Array(silBuf);
  post("status", { message: "Waiting for LM worker..." });
  await lmLoadPromise;
  post("status", { message: "All models loaded!" });
  post("loaded");
}

function buildSFTPrompt(caption, metas) {
  const instruction = "Fill the audio semantic mask based on the given conditions:";
  return `# Instruction\n${instruction}\n\n# Caption\n${caption}\n\n# Metas\n${metas}<|endoftext|>`;
}
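
// Example (illustrative): buildSFTPrompt("lo-fi hip hop beat", "duration: 30s")
// yields:
//   # Instruction
//   Fill the audio semantic mask based on the given conditions:
//
//   # Caption
//   lo-fi hip hop beat
//
//   # Metas
//   duration: 30s<|endoftext|>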

// NOTE: encodeText / encodeLyrics / encodeTimbre are legacy standalone helpers.
// They reference sessions (textProjector, lyricEncoder, timbreEncoder) that
// loadModels never creates, and nothing in the current generateAudio path calls
// them — the consolidated condition_encoder graph covers these steps.
async function encodeText(caption, metas) {
  const prompt = buildSFTPrompt(caption, metas);
  const encoded = textTokenizer(prompt, { padding: "max_length", max_length: 256, truncation: true });
  const idsRaw = encoded.input_ids.data;
  const inputIds = idsRaw instanceof BigInt64Array ? idsRaw : new BigInt64Array(Array.from(idsRaw, BigInt));
  const result = await sessions.textEncoder.run({ input_ids: tensor(inputIds, [1, 256], "int64") });
  const projected = await sessions.textProjector.run({ text_hidden_states: result.hidden_states });
  const maskRaw = encoded.attention_mask.data;
  const attentionMask = new Float32Array(maskRaw.length);
  for (let i = 0; i < maskRaw.length; i++) attentionMask[i] = Number(maskRaw[i]);
  return { hidden: projected.projected.data, mask: attentionMask, seqLen: 256 };
}

async function encodeLyrics(lyrics, language = "en") {
  const fullText = `# Languages\n${language}\n\n# Lyric\n${lyrics}`;
  // max_length=2048 matches the original handler (conditioning_text.py)
  const encoded = textTokenizer(fullText, { padding: "max_length", max_length: 2048, truncation: true });
  const idsRaw = encoded.input_ids.data;
  const inputIds = idsRaw instanceof BigInt64Array ? idsRaw : new BigInt64Array(Array.from(idsRaw, BigInt));
  const seqLen = inputIds.length;
  const embedResult = await sessions.embedTokens.run({ input_ids: tensor(inputIds, [1, seqLen], "int64") });
  const maskRaw = encoded.attention_mask.data;
  const attentionMask = new Float32Array(maskRaw.length);
  for (let i = 0; i < maskRaw.length; i++) attentionMask[i] = Number(maskRaw[i]);
  const lyricResult = await sessions.lyricEncoder.run({
    inputs_embeds: embedResult.hidden_states,
    attention_mask: tensor(attentionMask, [1, seqLen]),
  });
  return { hidden: lyricResult.hidden_states.data, mask: attentionMask, seqLen };
}

async function encodeTimbre() {
  const silenceRef = silenceLatent.slice(0, 750 * LATENT_CHANNELS);
  const result = await sessions.timbreEncoder.run({
    refer_audio: tensor(silenceRef, [1, 750, LATENT_CHANNELS]),
  });
  const timbreHidden = new Float32Array(HIDDEN_SIZE);
  timbreHidden.set(result.timbre_embedding.data);
  return { hidden: timbreHidden, mask: new Float32Array([1.0]), seqLen: 1 };
}

async function generateLMHints(caption, lyrics, numLatentFrames, duration) {
  const { codes, elapsed, tokenCount } = await generateAudioCodesViaLM({ caption, lyrics, duration, numLatentFrames });
  post("status", { message: `LM: ${codes.length} codes from ${tokenCount} tokens in ${elapsed}s` });
  if (codes.length === 0) {
    console.warn("[lm] No audio codes generated, returning silence");
    return new Float32Array(numLatentFrames * LATENT_CHANNELS);
  }
  const numCodes5Hz = codes.length;
  post("status", { message: "FSQ codebook lookup..." });
  const lmHints5Hz = fsqLookup(codes, 1, numCodes5Hz);
  tensorStats("lm_hints_5hz", lmHints5Hz);
  post("status", { message: "Detokenizing 5Hz → 25Hz..." });
  const detokResult = await sessions.detokenizer.run({
    quantized: tensor(lmHints5Hz, [1, numCodes5Hz, HIDDEN_SIZE]),
  });
  const lmHints25HzRaw = detokResult.lm_hints_25hz.data;
  const rawLen = lmHints25HzRaw.length / LATENT_CHANNELS;
  tensorStats("lm_hints_25hz_raw", lmHints25HzRaw);
  // Pad with the last frame (MLX port behavior) or truncate.
  const lmHints25Hz = new Float32Array(numLatentFrames * LATENT_CHANNELS);
  if (rawLen >= numLatentFrames) {
    lmHints25Hz.set(lmHints25HzRaw.slice(0, numLatentFrames * LATENT_CHANNELS));
  } else {
    lmHints25Hz.set(lmHints25HzRaw);
    // Repeat the last frame to fill the remainder.
    const lastFrameStart = (rawLen - 1) * LATENT_CHANNELS;
    const lastFrame = lmHints25HzRaw.slice(lastFrameStart, lastFrameStart + LATENT_CHANNELS);
    for (let t = rawLen; t < numLatentFrames; t++) {
      lmHints25Hz.set(lastFrame, t * LATENT_CHANNELS);
    }
    console.log(`[hints] padded ${rawLen} → ${numLatentFrames} frames with last-frame replication`);
  }
  tensorStats("lm_hints_25hz_final", lmHints25Hz);
  return lmHints25Hz;
}
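
// Frame-rate bookkeeping (illustrative, assuming POOL_WINDOW is the 5 Hz → 25 Hz
// pooling factor): duration = 30 s needs 30 × LATENT_RATE = 750 latent frames,
// and the detokenizer upsamples each 5 Hz code 5×, so 150 LM codes cover the
// clip exactly; shorter LM outputs are extended by last-frame replication above.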

async function generateAudio({ caption, lyrics, duration, shift, numSteps = 8 }) {
  const totalStartTime = performance.now();
  const filenameStamp = Date.now();
  const batchSize = 1;
  const numLatentFrames = Math.round(duration * LATENT_RATE);
  const tSchedule = buildSchedule(numSteps, shift);
  const metas = `duration: ${duration}s`;

  // 1. Text → Qwen3 embedding (1024-dim hidden states, BEFORE projection)
  post("status", { message: "Encoding text..." });
  const sftPrompt = buildSFTPrompt(caption, metas);
  const textEnc = textTokenizer(sftPrompt, { padding: "max_length", max_length: 256, truncation: true });
  const textIdsRaw = textEnc.input_ids.data;
  const textIds = textIdsRaw instanceof BigInt64Array ? textIdsRaw : new BigInt64Array(Array.from(textIdsRaw, BigInt));
  const textHiddenRes = await sessions.textEncoder.run({ input_ids: tensor(textIds, [1, 256], "int64") });
  const textHidden = textHiddenRes.hidden_states;
  const textMaskRaw = textEnc.attention_mask.data;
  const textMask = new Float32Array(textMaskRaw.length);
  for (let i = 0; i < textMaskRaw.length; i++) textMask[i] = Number(textMaskRaw[i]);

  // 2. Lyric tokens → embed_tokens (1024-dim, passed into condition_encoder's lyric_encoder)
  post("status", { message: "Embedding lyrics..." });
  const lyricFullText = `# Languages\nen\n\n# Lyric\n${lyrics}`;
  const lyricEnc = textTokenizer(lyricFullText, { padding: "max_length", max_length: 2048, truncation: true });
  const lyricIdsRaw = lyricEnc.input_ids.data;
  const lyricIds = lyricIdsRaw instanceof BigInt64Array ? lyricIdsRaw : new BigInt64Array(Array.from(lyricIdsRaw, BigInt));
  const lyricEmbRes = await sessions.embedTokens.run({ input_ids: tensor(lyricIds, [1, 2048], "int64") });
  const lyricEmb = lyricEmbRes.hidden_states;
  const lyricMaskRaw = lyricEnc.attention_mask.data;
  const lyricMask = new Float32Array(lyricMaskRaw.length);
  for (let i = 0; i < lyricMaskRaw.length; i++) lyricMask[i] = Number(lyricMaskRaw[i]);

  // 3. LM hints (mandatory for the turbo model)
  const lmHints25Hz = await generateLMHints(caption, lyrics, numLatentFrames, duration);

  // 4. Silence for ref audio (timbre) and src_latents
  const silenceRef = silenceLatent.slice(0, 750 * LATENT_CHANNELS);
  const srcLatents = new Float32Array(numLatentFrames * LATENT_CHANNELS);
  const chunkMasks = new Float32Array(numLatentFrames * LATENT_CHANNELS).fill(1.0);
  const isCovers = new Float32Array([1.0]); // force use of LM hints

  // 5. condition_encoder: text_projector + lyric_encoder + timbre_encoder + pack_sequences + context_latents
  post("status", { message: "Running condition encoder..." });
  const condResult = await sessions.conditionEncoder.run({
    text_hidden_states: textHidden,
    text_attention_mask: tensor(textMask, [1, 256]),
    lyric_hidden_states: lyricEmb,
    lyric_attention_mask: tensor(lyricMask, [1, 2048]),
    refer_audio_acoustic_hidden_states_packed: tensor(silenceRef, [1, 750, LATENT_CHANNELS]),
    refer_audio_order_mask: tensor(new BigInt64Array([0n]), [1], "int64"),
    src_latents: tensor(srcLatents, [1, numLatentFrames, LATENT_CHANNELS]),
    chunk_masks: tensor(chunkMasks, [1, numLatentFrames, LATENT_CHANNELS]),
    is_covers: tensor(isCovers, [1]),
    precomputed_lm_hints_25hz: tensor(lmHints25Hz, [1, numLatentFrames, LATENT_CHANNELS]),
  });
  const encoderHiddenStates = condResult.encoder_hidden_states;
  const contextLatentsTensor = condResult.context_latents;
  tensorStats("encoder_hidden_states", encoderHiddenStates.data);
  tensorStats("context_latents", contextLatentsTensor.data);

  // 6. Rectified-flow denoising: Euler steps along the shifted sigma schedule.
  post("status", { message: "Starting denoising..." });
  let xt = randn([batchSize, numLatentFrames, LATENT_CHANNELS]);
  const startTime = performance.now();
  for (let step = 0; step < tSchedule.length; step++) {
    const tCurr = tSchedule[step];
    post("status", { message: `Denoising step ${step + 1}/${tSchedule.length}...` });
    const timestepData = new Float32Array(batchSize).fill(tCurr);
    const result = await sessions.ditDecoder.run({
      hidden_states: tensor(xt, [batchSize, numLatentFrames, LATENT_CHANNELS]),
      timestep: tensor(timestepData, [batchSize]),
      encoder_hidden_states: encoderHiddenStates,
      context_latents: contextLatentsTensor,
    });
    const vt = result.velocity.data;
    if (step === tSchedule.length - 1) {
      // Final step: integrate the remaining tCurr all the way down to t = 0.
      for (let i = 0; i < xt.length; i++) xt[i] = xt[i] - vt[i] * tCurr;
    } else {
      const dt = tCurr - tSchedule[step + 1];
      for (let i = 0; i < xt.length; i++) xt[i] = xt[i] - vt[i] * dt;
    }
  }
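
  // Step-size sketch (illustrative): with the 8-step shift=3.0 schedule the
  // Euler increments are [0.0455, 0.0545, 0.0667, 0.0833, 0.1071, 0.1429, 0.2,
  // 0.3] (the last entry integrates the remaining t = 0.3 down to 0), so the
  // sampler takes small steps near pure noise and its largest strides last.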

  const diffusionTime = ((performance.now() - startTime) / 1000).toFixed(2);
  tensorStats("final_latent", xt);

  // Per-frame variance check — detects whether later frames have collapsed to a constant.
  const perFrameVariance = new Float32Array(numLatentFrames);
  for (let t = 0; t < numLatentFrames; t++) {
    let mean = 0;
    for (let c = 0; c < LATENT_CHANNELS; c++) mean += xt[t * LATENT_CHANNELS + c];
    mean /= LATENT_CHANNELS;
    let varSum = 0;
    for (let c = 0; c < LATENT_CHANNELS; c++) {
      const d = xt[t * LATENT_CHANNELS + c] - mean;
      varSum += d * d;
    }
    perFrameVariance[t] = varSum / LATENT_CHANNELS;
  }
  console.log("[perframe] variance samples:", Array.from(perFrameVariance.filter((_, i) => i % 25 === 0)).map(v => v.toFixed(3)));

  // Same check on the LM hints, to separate DiT collapse from bad conditioning.
  const hintsVar = new Float32Array(numLatentFrames);
  for (let t = 0; t < numLatentFrames; t++) {
    let mean = 0;
    for (let c = 0; c < LATENT_CHANNELS; c++) mean += lmHints25Hz[t * LATENT_CHANNELS + c];
    mean /= LATENT_CHANNELS;
    let varSum = 0;
    for (let c = 0; c < LATENT_CHANNELS; c++) {
      const d = lmHints25Hz[t * LATENT_CHANNELS + c] - mean;
      varSum += d * d;
    }
    hintsVar[t] = varSum / LATENT_CHANNELS;
  }
  console.log("[hints var] samples:", Array.from(hintsVar.filter((_, i) => i % 25 === 0)).map(v => v.toFixed(3)));

  post("status", { message: "Decoding audio..." });
  // Transpose [frames, channels] → [channels, frames] for the VAE input layout.
  const latentsForVae = new Float32Array(batchSize * LATENT_CHANNELS * numLatentFrames);
  for (let t = 0; t < numLatentFrames; t++) {
    for (let c = 0; c < LATENT_CHANNELS; c++) {
      latentsForVae[c * numLatentFrames + t] = xt[t * LATENT_CHANNELS + c];
    }
  }
  const vaeResult = await sessions.vaeDecoder.run({
    latents: tensor(latentsForVae, [batchSize, LATENT_CHANNELS, numLatentFrames]),
  });
  const waveform = vaeResult.waveform.data;
  tensorStats("waveform", waveform);
  masterWaveform(waveform, SAMPLE_RATE, 2);
  const wavBuffer = float32ToWav(waveform, SAMPLE_RATE, 2);
  // totalTime measures the whole pipeline (LM + encoders + diffusion + VAE),
  // not just the diffusion loop; diffusionTime is reported separately.
  const totalTime = ((performance.now() - totalStartTime) / 1000).toFixed(2);
  post("audio", { wavBuffer, duration, diffusionTime, totalTime, filenameStamp }, [wavBuffer]);
}

function measureAudio(samples) {
  let peak = 0;
  let sumSq = 0;
  for (let i = 0; i < samples.length; i++) {
    const v = samples[i];
    const abs = Math.abs(v);
    if (abs > peak) peak = abs;
    sumSq += v * v;
  }
  return { peak, rms: Math.sqrt(sumSq / Math.max(1, samples.length)) };
}

// Goertzel algorithm: power at a single frequency without computing a full FFT.
function goertzelPower(data, sampleRate, freq) {
  const omega = 2 * Math.PI * freq / sampleRate;
  const coeff = 2 * Math.cos(omega);
  let s0 = 0, s1 = 0, s2 = 0;
  for (let i = 0; i < data.length; i++) {
    s0 = data[i] + coeff * s1 - s2;
    s2 = s1;
    s1 = s0;
  }
  return s1 * s1 + s2 * s2 - coeff * s1 * s2;
}
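
// Sanity sketch (illustrative, not wired in): a pure tone should concentrate
// Goertzel power at its own frequency and read near-zero off-frequency.
// const sr = 4000, f = 440;
// const tone = Float32Array.from({ length: sr }, (_, i) => Math.sin(2 * Math.PI * f * i / sr));
// console.log(goertzelPower(tone, sr, 440) > 1e5, goertzelPower(tone, sr, 200) < 1e3);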

// Scan 250–950 Hz in 12.5 Hz bins for narrowband "drone" peaks: decimate to
// ~4 kHz mono, remove DC, then score each bin's Goertzel power against the
// median bin. Returns up to two peaks at least 50 Hz apart with score ≥ 12×.
function detectDronePeaks(samples, sampleRate, channels) {
  const numSamples = samples.length / channels;
  const step = Math.max(1, Math.floor(sampleRate / 4000));
  const downsampleRate = sampleRate / step;
  const downsampledLength = Math.floor(numSamples / step);
  if (downsampledLength < 1024) return [];
  const mono = new Float32Array(downsampledLength);
  let mean = 0;
  for (let i = 0; i < downsampledLength; i++) {
    const src = i * step;
    let v = 0;
    for (let ch = 0; ch < channels; ch++) v += samples[ch * numSamples + src];
    v /= channels;
    mono[i] = v;
    mean += v;
  }
  mean /= downsampledLength;
  for (let i = 0; i < mono.length; i++) mono[i] -= mean;
  const bins = [];
  for (let freq = 250; freq <= 950; freq += 12.5) {
    bins.push({ freq, power: goertzelPower(mono, downsampleRate, freq) });
  }
  const sortedPowers = bins.map((bin) => bin.power).sort((a, b) => a - b);
  const median = sortedPowers[Math.floor(sortedPowers.length / 2)] + 1e-12;
  bins.sort((a, b) => b.power - a.power);
  const peaks = [];
  for (const bin of bins) {
    const score = bin.power / median;
    if (score < 12) break;
    if (peaks.every((peak) => Math.abs(peak.freq - bin.freq) >= 50)) {
      peaks.push({ freq: bin.freq, score });
      if (peaks.length >= 2) break;
    }
  }
  return peaks;
}

// Biquad notch (RBJ Audio EQ Cookbook coefficients), blended dry/wet so the
// notch attenuates the offending band rather than removing it entirely.
function applyNotch(samples, sampleRate, channels, freq, q = 20, depth = 0.45) {
  const numSamples = samples.length / channels;
  const w0 = 2 * Math.PI * freq / sampleRate;
  const cos = Math.cos(w0);
  const alpha = Math.sin(w0) / (2 * q);
  const a0 = 1 + alpha;
  const b0 = 1 / a0;
  const b1 = (-2 * cos) / a0;
  const b2 = 1 / a0;
  const a1 = (-2 * cos) / a0;
  const a2 = (1 - alpha) / a0;
  for (let ch = 0; ch < channels; ch++) {
    const offset = ch * numSamples;
    let x1 = 0, x2 = 0, y1 = 0, y2 = 0;
    for (let i = 0; i < numSamples; i++) {
      const x0 = samples[offset + i];
      const y0 = b0 * x0 + b1 * x1 + b2 * x2 - a1 * y1 - a2 * y2;
      // Dry/wet mix: depth = 1 would be a full notch, 0.45 a partial dip.
      samples[offset + i] = x0 * (1 - depth) + y0 * depth;
      x2 = x1; x1 = x0;
      y2 = y1; y1 = y0;
    }
  }
}
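
// Depth math (illustrative): at the notch center the wet path's gain is ~0, so
// the mixed output is ≈ (1 - depth) · x = 0.55 · x, roughly a 5.2 dB dip
// (20 · log10(0.55) ≈ -5.19 dB) rather than a hard hole in the spectrum.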

function masterWaveform(samples, sampleRate, channels) {
  const before = measureAudio(samples);
  if (before.peak <= 0.001) return;
  const dronePeaks = detectDronePeaks(samples, sampleRate, channels);
  for (const peak of dronePeaks) applyNotch(samples, sampleRate, channels, peak.freq);
  const afterEq = measureAudio(samples);
  // Loudness normalization: aim for targetRms without exceeding maxPeak or maxGain.
  const targetRms = 0.085;
  const maxPeak = 0.891;
  const maxGain = 12.0;
  const gain = Math.min(
    maxGain,
    targetRms / Math.max(afterEq.rms, 1e-6),
    maxPeak / Math.max(afterEq.peak, 1e-6),
  );
  for (let i = 0; i < samples.length; i++) samples[i] *= gain;
  const after = measureAudio(samples);
  const peakText = dronePeaks.map((peak) => `${peak.freq.toFixed(1)}Hz/${peak.score.toFixed(0)}x`).join(", ") || "none";
  console.log(
    `[master] rawPeak=${before.peak.toFixed(4)} rawRms=${before.rms.toFixed(4)} ` +
    `dronePeaks=${peakText} gain=${gain.toFixed(2)}x peak=${after.peak.toFixed(4)} rms=${after.rms.toFixed(4)}`,
  );
}
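
// Gain example (illustrative): a quiet take with rms = 0.02 and peak = 0.30
// gets gain = min(12, 0.085 / 0.02, 0.891 / 0.30) = min(12, 4.25, 2.97) = 2.97,
// i.e. the peak ceiling (0.891 ≈ -1 dBFS) binds before the RMS target is reached.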

// Interleave planar float samples into a 16-bit PCM WAV file (44-byte header).
function float32ToWav(samples, sampleRate, channels = 2) {
  const numSamples = samples.length / channels;
  const bitsPerSample = 16;
  const blockAlign = channels * (bitsPerSample / 8);
  const byteRate = sampleRate * blockAlign;
  const dataSize = numSamples * blockAlign;
  const buffer = new ArrayBuffer(44 + dataSize);
  const view = new DataView(buffer);
  const w = (o, s) => { for (let i = 0; i < s.length; i++) view.setUint8(o + i, s.charCodeAt(i)); };
  w(0, "RIFF"); view.setUint32(4, 36 + dataSize, true);
  w(8, "WAVE"); w(12, "fmt "); view.setUint32(16, 16, true);
  view.setUint16(20, 1, true); view.setUint16(22, channels, true);
  view.setUint32(24, sampleRate, true); view.setUint32(28, byteRate, true);
  view.setUint16(32, blockAlign, true); view.setUint16(34, bitsPerSample, true);
  w(36, "data"); view.setUint32(40, dataSize, true);
  let offset = 44;
  for (let i = 0; i < numSamples; i++) {
    for (let ch = 0; ch < channels; ch++) {
      // Planar input (all of channel 0, then channel 1) → interleaved PCM frames.
      const sample = Math.max(-1, Math.min(1, samples[ch * numSamples + i]));
      view.setInt16(offset, sample * 32767, true);
      offset += 2;
    }
  }
  return buffer;
}
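
// Size check (illustrative): a 30 s stereo clip at 48 kHz is 30 × 48000 frames
// × 4 bytes (blockAlign) = 5,760,000 data bytes, plus the 44-byte header.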

self.onmessage = async (e) => {
  const { type, ...data } = e.data;
  try {
    if (type === "load") await loadModels();
    else if (type === "generate") await generateAudio(data);
  } catch (err) {
    post("error", { message: err.message, stack: err.stack });
  }
};