"""Omnilingual ASR transcription demo."""
import torch
import torchaudio
import numpy as np
import logging
import tempfile
import os
import threading
from typing import List, Tuple, Dict, Optional, Any
import silero_vad
import soundfile as sf
import librosa
logger = logging.getLogger(__name__)
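# Chunking defaults: target and minimum chunk lengths in seconds, and the sample rate (Hz) expected downstream.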
TARGET_CHUNK_DURATION = 30.0
MIN_CHUNK_DURATION = 5.0
SAMPLE_RATE = 16000
class AudioChunker:
"""
Handles audio chunking with different strategies:
- 'none': Single chunk (no chunking)
- 'vad': VAD-based intelligent chunking
- 'static': Fixed-duration time-based chunking
"""
_instance = None
_instance_lock = threading.Lock()
vad_model: Optional[Any]
def __new__(cls):
if cls._instance is None:
with cls._instance_lock:
# Check again after acquiring lock as the value could have been set
if cls._instance is None:
cls._instance = super().__new__(cls)
# Only load VAD model here since this only runs once
cls._instance.vad_model = cls.load_vad_model()
return cls._instance
@staticmethod
def load_vad_model():
"""Load silero VAD model with error handling."""
try:
logger.info("Loading Silero VAD model...")
vad_model = silero_vad.load_silero_vad()
logger.info("βœ“ VAD model loaded successfully")
return vad_model
except Exception as e:
logger.error(f"Failed to load VAD model: {e}")
logger.warning("VAD chunking will fall back to time-based chunking")
return None
@torch.inference_mode()
    def chunk_audio(
        self,
        audio_tensor: torch.Tensor,
        sample_rate: int = SAMPLE_RATE,
        mode: str = "vad",
        chunk_duration: float = TARGET_CHUNK_DURATION,
    ) -> List[Dict]:
"""
Chunk audio tensor using specified strategy.
Args:
audio_tensor: Audio tensor (1D waveform)
sample_rate: Sample rate of the audio tensor
mode: Chunking mode - 'none', 'vad', or 'static'
chunk_duration: Target duration for static chunking (seconds)
Returns:
List of chunk info dicts with uniform format:
- start_time: Start time in seconds
- end_time: End time in seconds
- duration: Duration in seconds
- audio_data: Audio tensor for this chunk
- sample_rate: Sample rate
- chunk_index: Index of this chunk
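        Example (illustrative; assumes ``waveform`` is a mono 16 kHz torch.Tensor):
            chunks = AudioChunker().chunk_audio(waveform, SAMPLE_RATE, mode="vad")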
"""
logger.info(f"Chunking audio tensor: {audio_tensor.shape} at {sample_rate}Hz (mode: {mode})")
try:
# Assert tensor is already 1D (should be preprocessed by MediaTranscriptionProcessor)
assert len(audio_tensor.shape) == 1, f"Expected 1D audio tensor, got shape {audio_tensor.shape}"
# Assert sample rate is already 16kHz (should be preprocessed by MediaTranscriptionProcessor)
assert sample_rate == SAMPLE_RATE, f"Expected {SAMPLE_RATE}Hz sample rate, got {sample_rate}Hz"
# Route to appropriate chunking strategy
if mode == "none":
return self._create_single_chunk(audio_tensor, sample_rate)
elif mode == "vad":
if self.vad_model is not None:
return self._chunk_with_vad(audio_tensor)
else:
logger.warning("VAD model not available, falling back to static chunking")
return self._chunk_static(audio_tensor, chunk_duration)
elif mode == "static":
return self._chunk_static(audio_tensor, chunk_duration)
else:
raise ValueError(f"Unknown chunking mode: {mode}")
except Exception as e:
logger.error(f"Error chunking audio tensor: {e}")
# Ultimate fallback to single chunk
return self._create_single_chunk(audio_tensor, sample_rate)
def _create_single_chunk(self, waveform: torch.Tensor, sample_rate: int = SAMPLE_RATE) -> List[Dict]:
"""Create a single chunk containing the entire audio."""
duration = len(waveform) / sample_rate
return [{
"start_time": 0.0,
"end_time": duration,
"duration": duration,
"audio_data": waveform,
"sample_rate": sample_rate,
"chunk_index": 0,
}]
def _chunk_static(self, waveform: torch.Tensor, chunk_duration: float) -> List[Dict]:
"""Create fixed-duration chunks."""
chunks = []
total_samples = len(waveform)
target_samples = int(chunk_duration * SAMPLE_RATE)
start_sample = 0
chunk_idx = 0
while start_sample < total_samples:
end_sample = min(start_sample + target_samples, total_samples)
chunk_audio = waveform[start_sample:end_sample]
duration = len(chunk_audio) / SAMPLE_RATE
# Only add chunk if it meets minimum duration
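            # (chunks shorter than MIN_CHUNK_DURATION - normally only the trailing remainder - are silently dropped)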
if duration >= MIN_CHUNK_DURATION:
chunks.append({
"start_time": start_sample / SAMPLE_RATE,
"end_time": end_sample / SAMPLE_RATE,
"duration": duration,
"audio_data": chunk_audio,
"sample_rate": SAMPLE_RATE,
"chunk_index": chunk_idx,
})
chunk_idx += 1
start_sample = end_sample
logger.info(f"Created {len(chunks)} static chunks of ~{chunk_duration}s each")
return chunks
def _chunk_fallback(self, audio_path: str) -> List[Dict]:
"""Ultimate fallback - create single chunk using librosa (for file-based legacy method)."""
try:
logger.warning("Using librosa fallback for chunking")
data, sr = librosa.load(audio_path, sr=SAMPLE_RATE)
waveform = torch.from_numpy(data)
return self._create_single_chunk(waveform, SAMPLE_RATE)
except Exception as e:
logger.error(f"All chunking methods failed: {e}")
return []
def _chunk_with_vad(self, waveform: torch.Tensor) -> List[Dict]:
"""Chunk audio using VAD for speech detection with uniform return format."""
try:
# VAD model expects tensor on CPU
vad_waveform = waveform.cpu() if waveform.is_cuda else waveform
# Get speech timestamps using VAD
speech_timestamps = silero_vad.get_speech_timestamps(
vad_waveform,
self.vad_model,
sampling_rate=SAMPLE_RATE,
min_speech_duration_ms=500, # Minimum speech segment
min_silence_duration_ms=300, # Minimum silence to split
window_size_samples=1536,
speech_pad_ms=100, # Padding around speech
)
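            # Each entry is a {"start": ..., "end": ...} dict expressed in samples at SAMPLE_RATE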
logger.info(f"Found {len(speech_timestamps)} speech segments")
# Create chunks based on speech segments and target duration
# Pass original waveform (with device preserved) to chunk creation
chunks = self._create_chunks_from_speech_segments(
waveform, speech_timestamps
)
logger.info(f"Created {len(chunks)} audio chunks using VAD")
return chunks
except Exception as e:
logger.error(f"VAD chunking failed: {e}")
return self._chunk_static(waveform, TARGET_CHUNK_DURATION)
def _create_chunks_from_speech_segments(
self, waveform: torch.Tensor, speech_segments: List[Dict]
) -> List[Dict]:
"""Create chunks that respect speech boundaries and target duration with uniform format."""
if not speech_segments:
logger.warning(
"No speech segments found, falling back to static chunking"
)
return self._chunk_static(waveform, TARGET_CHUNK_DURATION)
chunks = []
current_chunk_start = 0
target_samples = int(TARGET_CHUNK_DURATION * SAMPLE_RATE)
total_samples = len(waveform)
chunk_idx = 0
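        # Positions below are tracked in samples; they are converted to seconds only when building chunk dicts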
while current_chunk_start < total_samples:
# Calculate target end for this chunk
target_chunk_end = current_chunk_start + target_samples
            # If this would be the last chunk, or less than 30% of a target chunk would remain after it, take the rest
if target_chunk_end >= total_samples or (
total_samples - target_chunk_end
) < (target_samples * 0.3):
chunk_end = total_samples
else:
# Find the best place to end this chunk using VAD, but ensure continuous coverage
chunk_end = self._find_best_chunk_end_continuous(
speech_segments,
current_chunk_start,
target_chunk_end,
total_samples,
)
# Create chunk with uniform format
chunk_audio = waveform[current_chunk_start:chunk_end]
duration = len(chunk_audio) / SAMPLE_RATE
chunks.append({
"start_time": current_chunk_start / SAMPLE_RATE,
"end_time": chunk_end / SAMPLE_RATE,
"duration": duration,
"audio_data": chunk_audio,
"sample_rate": SAMPLE_RATE,
"chunk_index": chunk_idx,
})
logger.info(
f"Created chunk {chunk_idx + 1}: {current_chunk_start/SAMPLE_RATE:.2f}s - {chunk_end/SAMPLE_RATE:.2f}s ({duration:.2f}s)"
)
chunk_idx += 1
# Move to next chunk - IMPORTANT: start exactly where this chunk ended
current_chunk_start = chunk_end
# Verify total coverage
total_audio_duration = len(waveform) / SAMPLE_RATE
total_chunks_duration = sum(chunk["duration"] for chunk in chunks)
logger.info(
f"Audio chunking complete: {len(chunks)} chunks covering {total_chunks_duration:.2f}s of {total_audio_duration:.2f}s total audio"
)
if (
abs(total_chunks_duration - total_audio_duration) > 0.01
): # Allow 10ms tolerance
logger.error(
f"Duration mismatch: chunks={total_chunks_duration:.2f}s, original={total_audio_duration:.2f}s"
)
else:
logger.info("βœ“ Perfect audio coverage achieved")
return chunks
def _find_best_chunk_end_continuous(
self,
speech_segments: List[Dict],
chunk_start: int,
target_end: int,
total_samples: int,
) -> int:
"""Find the best place to end a chunk while ensuring continuous coverage."""
# Don't go beyond the audio
target_end = min(target_end, total_samples)
# Look for a good break point within a reasonable window around target
search_window = int(SAMPLE_RATE * 3) # 3 second window
search_start = max(chunk_start, target_end - search_window)
search_end = min(total_samples, target_end + search_window)
best_end = target_end
best_score = 0
# Look for speech segment boundaries within the search window
for segment in speech_segments:
segment_start = segment["start"]
segment_end = segment["end"]
# Check if segment end is in our search window
if search_start <= segment_end <= search_end:
# Score based on how close to target and if it's a good break point
distance_score = 1.0 - abs(segment_end - target_end) / search_window
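                # distance_score is 1.0 when the segment end lands exactly on target_end
                # and decays linearly to 0.0 at the edge of the search window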
# Prefer segment ends (natural pauses)
boundary_score = 1.0
total_score = distance_score * boundary_score
if total_score > best_score:
best_score = total_score
best_end = segment_end
# Ensure we don't go beyond audio bounds
best_end = min(int(best_end), total_samples)
# Ensure we make progress (don't end before we started)
if best_end <= chunk_start:
best_end = min(target_end, total_samples)
return best_end
def _find_best_chunk_end(
self,
speech_segments: List[Dict],
start_idx: int,
chunk_start: int,
target_end: int,
) -> int:
"""Find the best place to end a chunk (at silence, near target duration)."""
best_end = target_end
# Look for speech segments that could provide good break points
for i in range(start_idx, len(speech_segments)):
segment = speech_segments[i]
segment_start = segment["start"]
segment_end = segment["end"]
# If segment starts after our target end, use the gap before it
if segment_start > target_end:
best_end = min(target_end, segment_start)
break
# If segment ends near our target, use the end of the segment
if abs(segment_end - target_end) < SAMPLE_RATE * 5: # Within 5 seconds
best_end = segment_end
break
# If segment extends way past target, look for a good break point
if segment_end > target_end + SAMPLE_RATE * 10: # 10+ seconds past
# Try to find a silence gap within the segment or use target
best_end = target_end
break
return int(best_end)
def save_chunk_to_file(self, chunk: Dict, output_path: str) -> str:
"""Save a chunk to a temporary audio file."""
try:
# Convert tensor to numpy if needed
audio_data = chunk["audio_data"]
if isinstance(audio_data, torch.Tensor):
# Move to CPU first if on GPU, then convert to numpy
audio_data = audio_data.cpu().numpy()
# Save to file
sf.write(output_path, audio_data, chunk["sample_rate"])
return output_path
except Exception as e:
logger.error(f"Failed to save chunk to file: {e}")
raise
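

# Minimal usage sketch (not part of the demo's pipeline): assumes a hypothetical
# 16 kHz-compatible file "example.wav" next to this script. librosa resamples and
# downmixes to mono, which satisfies chunk_audio's 1D / 16 kHz assertions.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    data, _ = librosa.load("example.wav", sr=SAMPLE_RATE, mono=True)
    chunker = AudioChunker()
    chunks = chunker.chunk_audio(torch.from_numpy(data), SAMPLE_RATE, mode="vad")
    for chunk in chunks:
        print(
            f"chunk {chunk['chunk_index']}: {chunk['start_time']:.2f}s - "
            f"{chunk['end_time']:.2f}s ({chunk['duration']:.2f}s)"
        )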