File size: 3,163 Bytes
24b9788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import { useState, useRef, useCallback, useEffect } from "react";

/**
 * React hook that owns a dedicated module Web Worker (`../worker.js`) and
 * exposes its model-loading / audio-generation lifecycle as React state.
 *
 * Worker -> hook messages handled (by `e.data.type`):
 *   - "status":   `{ message }`  -> stored in `message`
 *   - "progress": any payload    -> stored as-is (minus `type`) in `progress`
 *   - "loaded":   -              -> `isLoaded = true`, `status = "ready"`, progress cleared
 *   - "audio":    `{ wavBuffer, duration, diffusionTime, totalTime, filenameStamp? }`
 *                 -> WAV blob exposed as an object URL in `audioUrl`, metadata in `audioInfo`
 *   - "error":    `{ message, stack? }` -> `error`, `status = "error"`, progress cleared
 *
 * Uncaught failures inside the worker (including the worker script failing to
 * load) are also surfaced as `status = "error"`, so the UI never stays stuck in
 * "loading" / "generating".
 *
 * The hook owns at most one object URL at a time and revokes it when it is
 * replaced, when a new generation starts, and on unmount.
 *
 * @returns {{
 *   status: string,          // "idle" | "loading" | "generating" | "ready" | "error"
 *   message: string,
 *   progress: object|null,
 *   audioUrl: string|null,
 *   audioInfo: object|null,
 *   error: string|null,
 *   isLoaded: boolean,
 *   loadModel: () => void,
 *   generate: (params: {caption, lyrics, duration, shift, numSteps}) => void,
 * }}
 */
export function useModel() {
  const workerRef = useRef(null);
  const audioUrlRef = useRef(null);
  const [status, setStatus] = useState("idle");
  const [message, setMessage] = useState("");
  const [progress, setProgress] = useState(null);
  const [audioUrl, setAudioUrl] = useState(null);
  const [audioInfo, setAudioInfo] = useState(null);
  const [error, setError] = useState(null);
  const [isLoaded, setIsLoaded] = useState(false);

  // Revoke the object URL currently owned by this hook (if any) and forget it.
  const revokeCurrentAudioUrl = useCallback(() => {
    if (audioUrlRef.current) {
      URL.revokeObjectURL(audioUrlRef.current);
      audioUrlRef.current = null;
    }
  }, []);

  useEffect(() => {
    const worker = new Worker(new URL("../worker.js", import.meta.url), {
      type: "module",
    });

    // Shared failure path for worker-reported errors and uncaught worker errors.
    const fail = (errorMessage) => {
      setError(errorMessage);
      setStatus("error");
      // Drop any in-flight progress so the UI does not show a frozen progress bar.
      setProgress(null);
    };

    worker.onmessage = (e) => {
      const { type, ...data } = e.data;
      switch (type) {
        case "status":
          setMessage(data.message);
          break;
        case "progress":
          setProgress(data);
          break;
        case "loaded":
          setIsLoaded(true);
          setStatus("ready");
          setProgress(null);
          break;
        case "audio": {
          // Release the previous blob before replacing it so it is not leaked.
          revokeCurrentAudioUrl();
          const blob = new Blob([data.wavBuffer], { type: "audio/wav" });
          const url = URL.createObjectURL(blob);
          audioUrlRef.current = url;
          setAudioUrl(url);
          setAudioInfo({
            duration: data.duration,
            diffusionTime: data.diffusionTime,
            totalTime: data.totalTime,
            filename: `ace-step-${data.filenameStamp || Date.now()}.wav`,
          });
          setStatus("ready");
          setMessage("Generation complete!");
          break;
        }
        case "error":
          fail(data.message);
          console.error("Worker error:", data.message, data.stack);
          break;
        default:
          // Unknown message types are ignored.
          break;
      }
    };

    // Fired for uncaught exceptions inside the worker and when the worker
    // script cannot be loaded. Without this the hook would never leave
    // "loading" / "generating" in those cases.
    worker.onerror = (event) => {
      fail(event?.message || "Worker failed to load or crashed");
      console.error("Worker error:", event);
    };

    workerRef.current = worker;
    return () => {
      worker.terminate();
      // Forget the terminated worker so later calls don't post into the void.
      workerRef.current = null;
      revokeCurrentAudioUrl();
    };
  }, [revokeCurrentAudioUrl]);

  /** Ask the worker to load the model; `status` becomes "loading" until "loaded"/"error". */
  const loadModel = useCallback(() => {
    setStatus("loading");
    setError(null);
    workerRef.current?.postMessage({ type: "load" });
  }, []);

  /**
   * Start a generation run with the given parameters. Any previously generated
   * audio is released and cleared from state before the request is posted.
   */
  const generate = useCallback(({ caption, lyrics, duration, shift, numSteps }) => {
    setStatus("generating");
    setError(null);
    // Free the previous blob now; the matching `audioUrl` state is cleared
    // just below in the same update batch.
    revokeCurrentAudioUrl();
    setAudioUrl(null);
    setAudioInfo(null);
    workerRef.current?.postMessage({
      type: "generate",
      caption,
      lyrics,
      duration,
      shift,
      numSteps,
    });
  }, [revokeCurrentAudioUrl]);

  return {
    status,
    message,
    progress,
    audioUrl,
    audioInfo,
    error,
    isLoaded,
    loadModel,
    generate,
  };
}