datbkpro commited on
Commit
7719f9b
·
verified ·
1 Parent(s): 4988cc6

Update core/silero_vad.py

Browse files
Files changed (1) hide show
  1. core/silero_vad.py +47 -77
core/silero_vad.py CHANGED
@@ -240,14 +240,16 @@
240
  # return 0.0import torch
241
  import torch
242
  import numpy as np
243
- from typing import Optional, Callable
244
  from config.settings import settings
245
  import os
246
  import time
247
 
 
248
  class SileroVAD:
249
  def __init__(self):
250
  self.model = None
 
251
  self.sample_rate = 16000
252
  self.is_streaming = False
253
  self.speech_callback = None
@@ -260,37 +262,38 @@ class SileroVAD:
260
  """Khởi tạo Silero VAD model"""
261
  try:
262
  print("🔄 Đang tải Silero VAD model...")
263
-
264
- # Sử dụng torch.hub
265
- self.model = torch.hub.load(
266
  repo_or_dir='snakers4/silero-vad',
267
  model='silero_vad',
268
  force_reload=False,
269
  trust_repo=True
270
  )
271
-
272
  self.model.eval()
273
  print("✅ Đã tải Silero VAD model thành công")
274
-
275
  except Exception as e:
276
  print(f"❌ Lỗi tải Silero VAD model: {e}")
277
  self._initialize_model_fallback()
278
 
279
  def _initialize_model_fallback(self):
280
- """Fallback method"""
281
  try:
282
- # Tạo model trực tiếp
283
  model_dir = torch.hub.get_dir()
284
- model_path = os.path.join(model_dir, 'snakers4_silero-vad_master', 'files', 'silero_vad.jit')
285
-
 
 
286
  if os.path.exists(model_path):
287
  self.model = torch.jit.load(model_path)
288
  self.model.eval()
289
  print("✅ Đã tải Silero VAD model thành công (fallback)")
290
  else:
291
- print("❌ Không tìm thấy model file")
292
  self.model = None
293
-
294
  except Exception as e:
295
  print(f"❌ Lỗi tải Silero VAD model fallback: {e}")
296
  self.model = None
@@ -300,7 +303,7 @@ class SileroVAD:
300
  if self.model is None:
301
  print("❌ Silero VAD model chưa được khởi tạo")
302
  return False
303
-
304
  self.is_streaming = True
305
  self.speech_callback = speech_callback
306
  self.audio_buffer = []
@@ -317,7 +320,7 @@ class SileroVAD:
317
  print("🛑 Đã dừng Silero VAD streaming")
318
 
319
  def process_stream(self, audio_chunk: np.ndarray, sample_rate: int):
320
- """Xử lý audio chunk với Silero VAD cải tiến"""
321
  if not self.is_streaming or self.model is None:
322
  return
323
 
@@ -329,7 +332,7 @@ class SileroVAD:
329
  # Thêm vào buffer
330
  self.audio_buffer.extend(audio_chunk)
331
 
332
- # Xử lý khi buffer đủ lớn (0.5 giây)
333
  buffer_duration = len(self.audio_buffer) / self.sample_rate
334
  if buffer_duration >= 0.5:
335
  self._process_buffer()
@@ -338,55 +341,45 @@ class SileroVAD:
338
  print(f"❌ Lỗi xử lý Silero VAD: {e}")
339
 
340
  def _process_buffer(self):
341
- """Xử lý buffer audio với Silero VAD cải tiến"""
342
  try:
343
- chunk_size = int(self.sample_rate * 0.5) # 0.5 giây
344
  if len(self.audio_buffer) < chunk_size:
345
  return
346
 
347
  # Lấy chunk
348
  audio_chunk = np.array(self.audio_buffer[:chunk_size])
349
-
350
- # Chuẩn hóa audio
351
  audio_chunk = self._normalize_audio(audio_chunk)
352
-
353
- # Phát hiện speech
354
  speech_prob = self._get_speech_probability(audio_chunk)
355
-
356
  print(f"🎯 Silero VAD speech probability: {speech_prob:.3f}")
357
-
358
- # Ngưỡng phát hiện speech
359
  if speech_prob > settings.VAD_THRESHOLD:
360
  current_time = time.time()
361
-
362
  if self.speech_start_time == 0:
363
  self.speech_start_time = current_time
364
  print("🎯 Bắt đầu phát hiện speech")
365
-
366
- # Gọi callback nếu đủ thời gian speech
367
  speech_duration = current_time - self.speech_start_time
368
  if speech_duration >= self.min_speech_duration:
369
  if self.speech_callback:
370
- # Lấy toàn bộ audio từ buffer
371
  full_audio = np.array(self.audio_buffer)
372
  full_audio = self._normalize_audio(full_audio)
373
  self.speech_callback(full_audio, self.sample_rate)
374
-
375
- # Xóa buffer sau khi xử lý
376
  self.audio_buffer = []
377
  self.speech_start_time = 0
378
  else:
379
- # Reset nếu không phải speech
380
  if self.speech_start_time > 0:
381
  print("🔇 Kết thúc speech segment")
 
382
  self.speech_start_time = 0
383
-
384
- # Giữ lại 0.2 giây cuối để overlap
385
  keep_samples = int(self.sample_rate * 0.2)
386
- if len(self.audio_buffer) > keep_samples:
387
- self.audio_buffer = self.audio_buffer[-keep_samples:]
388
- else:
389
- self.audio_buffer = []
390
 
391
  except Exception as e:
392
  print(f"❌ Lỗi xử lý Silero VAD buffer: {e}")
@@ -394,89 +387,66 @@ class SileroVAD:
394
  self.speech_start_time = 0
395
 
396
  def _normalize_audio(self, audio: np.ndarray) -> np.ndarray:
397
- """Chuẩn hóa audio cho Silero VAD"""
398
  if audio.dtype != np.float32:
399
  audio = audio.astype(np.float32)
400
  if np.max(np.abs(audio)) > 1.0:
401
- audio = audio / 32768.0 # Normalize từ int16
402
-
403
  return np.clip(audio, -1.0, 1.0)
404
 
405
  def _get_speech_probability(self, audio_chunk: np.ndarray) -> float:
406
- """Lấy xác suất speech từ audio chunk"""
407
  try:
408
- # Đảm bảo độ dài phù hợp
409
  if len(audio_chunk) < 512:
410
  padding = np.zeros(512 - len(audio_chunk), dtype=np.float32)
411
  audio_chunk = np.concatenate([audio_chunk, padding])
412
-
413
- # Chuyển thành tensor
414
  audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0)
415
-
416
- # Phát hiện speech
417
  with torch.no_grad():
418
  return self.model(audio_tensor, self.sample_rate).item()
419
-
420
  except Exception as e:
421
  print(f"❌ Lỗi lấy speech probability: {e}")
422
  return 0.0
423
 
424
  def _resample_audio(self, audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
425
- """Resample audio"""
426
  if orig_sr == target_sr:
427
  return audio
428
-
429
  try:
430
- # Simple resampling
431
- orig_length = len(audio)
432
- new_length = int(orig_length * target_sr / orig_sr)
433
-
434
- x_old = np.linspace(0, 1, orig_length)
435
- x_new = np.linspace(0, 1, new_length)
436
- resampled_audio = np.interp(x_new, x_old, audio)
437
-
438
- return resampled_audio
439
  except Exception as e:
440
  print(f"⚠️ Lỗi resample: {e}")
441
  return audio
442
 
443
  def is_speech(self, audio_chunk: np.ndarray, sample_rate: int) -> bool:
444
- """Kiểm tra xem audio chunk có phải speech không"""
445
  if self.model is None:
446
  return True
447
-
448
  try:
449
- # Resample nếu cần
450
  if sample_rate != self.sample_rate:
451
  audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)
452
-
453
- # Chuẩn hóa audio
454
  audio_chunk = self._normalize_audio(audio_chunk)
455
-
456
- # Lấy xác suất speech
457
- speech_prob = self._get_speech_probability(audio_chunk)
458
-
459
- return speech_prob > settings.VAD_THRESHOLD
460
-
461
  except Exception as e:
462
- print(f"❌ Lỗi kiểm tra speech với Silero: {e}")
463
  return True
464
 
465
  def get_speech_probability(self, audio_chunk: np.ndarray, sample_rate: int) -> float:
466
  """Lấy xác suất speech"""
467
  if self.model is None:
468
  return 0.0
469
-
470
  try:
471
- # Resample nếu cần
472
  if sample_rate != self.sample_rate:
473
  audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)
474
-
475
- # Chuẩn hóa audio
476
  audio_chunk = self._normalize_audio(audio_chunk)
477
-
478
  return self._get_speech_probability(audio_chunk)
479
-
480
  except Exception as e:
481
  print(f"❌ Lỗi lấy speech probability: {e}")
482
- return 0.0
 
240
  # return 0.0import torch
241
  import torch
242
  import numpy as np
243
+ from typing import Callable
244
  from config.settings import settings
245
  import os
246
  import time
247
 
248
+
249
  class SileroVAD:
250
  def __init__(self):
251
  self.model = None
252
+ self.utils = None # giữ các hàm tiện ích
253
  self.sample_rate = 16000
254
  self.is_streaming = False
255
  self.speech_callback = None
 
262
  """Khởi tạo Silero VAD model"""
263
  try:
264
  print("🔄 Đang tải Silero VAD model...")
265
+
266
+ # Cách tải đúng (model, utils)
267
+ self.model, self.utils = torch.hub.load(
268
  repo_or_dir='snakers4/silero-vad',
269
  model='silero_vad',
270
  force_reload=False,
271
  trust_repo=True
272
  )
273
+
274
  self.model.eval()
275
  print("✅ Đã tải Silero VAD model thành công")
276
+
277
  except Exception as e:
278
  print(f"❌ Lỗi tải Silero VAD model: {e}")
279
  self._initialize_model_fallback()
280
 
281
  def _initialize_model_fallback(self):
282
+ """Fallback nếu torch.hub.load thất bại"""
283
  try:
 
284
  model_dir = torch.hub.get_dir()
285
+ model_path = os.path.join(
286
+ model_dir, 'snakers4_silero-vad_master', 'files', 'silero_vad.jit'
287
+ )
288
+
289
  if os.path.exists(model_path):
290
  self.model = torch.jit.load(model_path)
291
  self.model.eval()
292
  print("✅ Đã tải Silero VAD model thành công (fallback)")
293
  else:
294
+ print("❌ Không tìm thấy model file (fallback thất bại)")
295
  self.model = None
296
+
297
  except Exception as e:
298
  print(f"❌ Lỗi tải Silero VAD model fallback: {e}")
299
  self.model = None
 
303
  if self.model is None:
304
  print("❌ Silero VAD model chưa được khởi tạo")
305
  return False
306
+
307
  self.is_streaming = True
308
  self.speech_callback = speech_callback
309
  self.audio_buffer = []
 
320
  print("🛑 Đã dừng Silero VAD streaming")
321
 
322
  def process_stream(self, audio_chunk: np.ndarray, sample_rate: int):
323
+ """Xử lý audio chunk với Silero VAD"""
324
  if not self.is_streaming or self.model is None:
325
  return
326
 
 
332
  # Thêm vào buffer
333
  self.audio_buffer.extend(audio_chunk)
334
 
335
+ # Xử lý khi buffer đủ 0.5 giây
336
  buffer_duration = len(self.audio_buffer) / self.sample_rate
337
  if buffer_duration >= 0.5:
338
  self._process_buffer()
 
341
  print(f"❌ Lỗi xử lý Silero VAD: {e}")
342
 
343
  def _process_buffer(self):
344
+ """Xử lý buffer audio với Silero VAD"""
345
  try:
346
+ chunk_size = int(self.sample_rate * 0.5)
347
  if len(self.audio_buffer) < chunk_size:
348
  return
349
 
350
  # Lấy chunk
351
  audio_chunk = np.array(self.audio_buffer[:chunk_size])
 
 
352
  audio_chunk = self._normalize_audio(audio_chunk)
353
+
354
+ # Dự đoán xác suất speech
355
  speech_prob = self._get_speech_probability(audio_chunk)
 
356
  print(f"🎯 Silero VAD speech probability: {speech_prob:.3f}")
357
+
358
+ # Nếu vượt ngưỡng, xác nhận là speech
359
  if speech_prob > settings.VAD_THRESHOLD:
360
  current_time = time.time()
361
+
362
  if self.speech_start_time == 0:
363
  self.speech_start_time = current_time
364
  print("🎯 Bắt đầu phát hiện speech")
365
+
 
366
  speech_duration = current_time - self.speech_start_time
367
  if speech_duration >= self.min_speech_duration:
368
  if self.speech_callback:
 
369
  full_audio = np.array(self.audio_buffer)
370
  full_audio = self._normalize_audio(full_audio)
371
  self.speech_callback(full_audio, self.sample_rate)
372
+
 
373
  self.audio_buffer = []
374
  self.speech_start_time = 0
375
  else:
 
376
  if self.speech_start_time > 0:
377
  print("🔇 Kết thúc speech segment")
378
+
379
  self.speech_start_time = 0
380
+ # Giữ lại 0.2 giây overlap
 
381
  keep_samples = int(self.sample_rate * 0.2)
382
+ self.audio_buffer = self.audio_buffer[-keep_samples:]
 
 
 
383
 
384
  except Exception as e:
385
  print(f"❌ Lỗi xử lý Silero VAD buffer: {e}")
 
387
  self.speech_start_time = 0
388
 
389
  def _normalize_audio(self, audio: np.ndarray) -> np.ndarray:
390
+ """Chuẩn hóa audio"""
391
  if audio.dtype != np.float32:
392
  audio = audio.astype(np.float32)
393
  if np.max(np.abs(audio)) > 1.0:
394
+ audio = audio / 32768.0
 
395
  return np.clip(audio, -1.0, 1.0)
396
 
397
  def _get_speech_probability(self, audio_chunk: np.ndarray) -> float:
398
+ """Trả về xác suất speech"""
399
  try:
 
400
  if len(audio_chunk) < 512:
401
  padding = np.zeros(512 - len(audio_chunk), dtype=np.float32)
402
  audio_chunk = np.concatenate([audio_chunk, padding])
403
+
 
404
  audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0)
405
+
 
406
  with torch.no_grad():
407
  return self.model(audio_tensor, self.sample_rate).item()
408
+
409
  except Exception as e:
410
  print(f"❌ Lỗi lấy speech probability: {e}")
411
  return 0.0
412
 
413
  def _resample_audio(self, audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
414
+ """Resample đơn giản"""
415
  if orig_sr == target_sr:
416
  return audio
 
417
  try:
418
+ orig_len = len(audio)
419
+ new_len = int(orig_len * target_sr / orig_sr)
420
+ x_old = np.linspace(0, 1, orig_len)
421
+ x_new = np.linspace(0, 1, new_len)
422
+ return np.interp(x_new, x_old, audio)
 
 
 
 
423
  except Exception as e:
424
  print(f"⚠️ Lỗi resample: {e}")
425
  return audio
426
 
427
  def is_speech(self, audio_chunk: np.ndarray, sample_rate: int) -> bool:
428
+ """Kiểm tra chunk có phải speech không"""
429
  if self.model is None:
430
  return True
 
431
  try:
 
432
  if sample_rate != self.sample_rate:
433
  audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)
 
 
434
  audio_chunk = self._normalize_audio(audio_chunk)
435
+ prob = self._get_speech_probability(audio_chunk)
436
+ return prob > settings.VAD_THRESHOLD
 
 
 
 
437
  except Exception as e:
438
+ print(f"❌ Lỗi kiểm tra speech: {e}")
439
  return True
440
 
441
  def get_speech_probability(self, audio_chunk: np.ndarray, sample_rate: int) -> float:
442
  """Lấy xác suất speech"""
443
  if self.model is None:
444
  return 0.0
 
445
  try:
 
446
  if sample_rate != self.sample_rate:
447
  audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)
 
 
448
  audio_chunk = self._normalize_audio(audio_chunk)
 
449
  return self._get_speech_probability(audio_chunk)
 
450
  except Exception as e:
451
  print(f"❌ Lỗi lấy speech probability: {e}")
452
+ return 0.0