import os
import subprocess
import torch
import numpy as np
import onnxruntime
from app.interfaces import IVoiceActivityEngine
from app.logger_config import (
    logger as logging,
    DEBUG
)

class VoiceActivityDetection:
    """Streaming wrapper around the Silero VAD ONNX model (v5.0 release)."""

    def __init__(self, force_onnx_cpu=True):
        logging.info("Initializing VoiceActivityDetection...")
        path = self.download()

        opts = onnxruntime.SessionOptions()
        opts.log_severity_level = 3  # Suppress ONNX runtime's own logs

        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1

        try:
            if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
                self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
                logging.info("ONNX VAD session created with CPUExecutionProvider.")
            else:
                self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts)
                logging.info("ONNX VAD session created with CUDAExecutionProvider.")
        except Exception as e:
            logging.critical(f"Failed to create ONNX InferenceSession: {e}", exc_info=True)
            raise

        self.reset_states()
        if '16k' in path:
            logging.warning('This VAD model supports only 16000 sampling rate!')
            self.sample_rates = [16000]
        else:
            logging.info("VAD model supports 8000Hz and 16000Hz.")
            self.sample_rates = [8000, 16000]

    def _validate_input(self, x, sr: int):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if x.dim() > 2:
            logging.error(f"Too many dimensions for input audio chunk: {x.dim()}")
            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")

        if sr != 16000 and (sr % 16000 == 0):
            # Naive decimation (no anti-aliasing filter) down to 16 kHz.
            step = sr // 16000
            x = x[:, ::step]
            logging.debug(f"Downsampled input audio from {sr}Hz to 16000Hz.")
            sr = 16000

        if sr not in self.sample_rates:
            logging.error(f"Unsupported sampling rate: {sr}. Supported: {self.sample_rates}")
            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
        
        return x, sr

    def reset_states(self, batch_size=1):
        logging.debug(f"Resetting VAD states for batch_size: {batch_size}")
        self._state = torch.zeros((2, batch_size, 128)).float()
        self._context = torch.zeros(0)
        self._last_sr = 0
        self._last_batch_size = 0

    def __call__(self, x, sr: int):

        x, sr = self._validate_input(x, sr)
        num_samples = 512 if sr == 16000 else 256

        if x.shape[-1] != num_samples:
            logging.error(f"Invalid audio chunk size: {x.shape[-1]}. Expected {num_samples} for {sr}Hz.")
            raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")

        batch_size = x.shape[0]
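        # The model expects each chunk prefixed with the tail of the previous
        # one: 64 context samples at 16 kHz (32 at 8 kHz) are prepended before
        # inference and refreshed afterwards.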
        context_size = 64 if sr == 16000 else 32

        if not self._last_batch_size:
            logging.debug("First call, resetting states.")
            self.reset_states(batch_size)
        if (self._last_sr) and (self._last_sr != sr):
            logging.warning(f"Sample rate changed ({self._last_sr}Hz -> {sr}Hz). Resetting states.")
            self.reset_states(batch_size)
        if (self._last_batch_size) and (self._last_batch_size != batch_size):
            logging.warning(f"Batch size changed ({self._last_batch_size} -> {batch_size}). Resetting states.")
            self.reset_states(batch_size)

        if not len(self._context):
            self._context = torch.zeros(batch_size, context_size)

        x = torch.cat([self._context, x], dim=1)
        if sr in [8000, 16000]:
            ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
            ort_outs = self.session.run(None, ort_inputs)
            out, state = ort_outs
            self._state = torch.from_numpy(state)
        else:
            # This should be caught by _validate_input, but as a safeguard:
            logging.critical(f"Unexpected sample rate in VAD __call__: {sr}")
            raise ValueError(f"Unexpected sample rate in VAD inference: {sr}")

        self._context = x[..., -context_size:]
        self._last_sr = sr
        self._last_batch_size = batch_size

        out = torch.from_numpy(out)
        return out

    def audio_forward(self, x, sr: int):
        outs = []
        x, sr = self._validate_input(x, sr)
        self.reset_states()
        num_samples = 512 if sr == 16000 else 256

        if x.shape[1] % num_samples:
            pad_num = num_samples - (x.shape[1] % num_samples)
            logging.debug(f"Padding audio input with {pad_num} samples.")
            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

        for i in range(0, x.shape[1], num_samples):
            wavs_batch = x[:, i:i+num_samples]
            out_chunk = self.__call__(wavs_batch, sr)
            outs.append(out_chunk)

        stacked = torch.cat(outs, dim=1)
        return stacked.cpu()

    @staticmethod
    def download(model_url="https://github.com/snakers4/silero-vad/raw/v5.0/files/silero_vad.onnx"):
        """Download the Silero VAD ONNX model via wget if not cached; return the local path."""
        target_dir = os.path.expanduser("~/.cache/silero_vad/")
        os.makedirs(target_dir, exist_ok=True)
        model_filename = os.path.join(target_dir, "silero_vad.onnx")

        if not os.path.exists(model_filename):
            logging.info(f"Downloading VAD model to {model_filename}...")
            try:
                subprocess.run(["wget", "-O", model_filename, model_url], check=True)
                logging.info("VAD model downloaded successfully.")
            except subprocess.CalledProcessError as e:
                logging.critical(f"Failed to download the model using wget: {e}")
                raise
        else:
            logging.info(f"VAD model already exists at {model_filename}.")
        return model_filename
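
# Minimal streaming sketch (illustrative only, not part of the engine API):
# feed fixed-size chunks to VoiceActivityDetection and collect per-chunk speech
# probabilities. Chunk sizes are fixed by the model: 512 samples at 16 kHz,
# 256 at 8 kHz.
def _example_streaming_vad():  # hypothetical helper, kept out of the public API
    vad = VoiceActivityDetection()
    audio = torch.zeros(512 * 4)  # placeholder: four 32 ms frames of silence at 16 kHz
    probs = []
    for i in range(0, audio.numel(), 512):
        chunk = audio[i:i + 512]
        probs.append(vad(chunk, 16000).item())  # per-chunk speech probability
    vad.reset_states()  # reset between independent audio streams
    return probs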


class Silero_Vad_Engine(IVoiceActivityEngine):
    def __init__(self, threshold: float = 0.5, frame_rate: int = 16000):
        """
        Initializes the Silero_Vad_Engine with a voice activity detection model and a threshold.

        Args:
            threshold (float, optional): The probability threshold for detecting voice activity. Defaults to 0.5.
            frame_rate (int, optional): Sampling rate of the incoming audio frames in Hz. Defaults to 16000.
        """
        logging.info(f"Initializing Silero_Vad_Engine with threshold: {threshold} and frame_rate: {frame_rate}Hz.")
        self.model = VoiceActivityDetection()
        self.threshold = threshold
        self.frame_rate = frame_rate

    def __call__(self, audio_frame):
        """
        Determines if the given audio frame contains speech by comparing the detected speech probability against
        the threshold.

        Args:
            audio_frame (np.ndarray): The audio frame to be analyzed for voice activity. It is expected to be a
                                      NumPy array of audio samples.

        Returns:
            bool: True if the speech probability exceeds the threshold, indicating the presence of voice activity;
                  False otherwise.
        """
        # Convert frame to tensor
        audio_tensor = torch.from_numpy(audio_frame.copy())
        
        # Get speech probabilities
        speech_probs = self.model.audio_forward(audio_tensor, self.frame_rate)[0]
        
        # Check against threshold
        is_speech = torch.any(speech_probs > self.threshold).item()
        
        logging.debug(f"VAD check result: {is_speech} (Max prob: {torch.max(speech_probs).item():.4f})")
        
        return is_speech
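
# Usage sketch (assumptions: a mono float32 frame at 16 kHz; the faint random
# noise below stands in for real microphone audio, so the printed result is
# typically False).
if __name__ == "__main__":
    engine = Silero_Vad_Engine(threshold=0.5, frame_rate=16000)
    frame = (np.random.randn(16000) * 0.01).astype(np.float32)  # ~1 s of quiet noise
    print("speech detected:", engine(frame))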