import torch
import logging
import threading
from typing import List, Dict, Optional, Any

import silero_vad
import soundfile as sf
import librosa

logger = logging.getLogger(__name__)
TARGET_CHUNK_DURATION = 30.0  # seconds
MIN_CHUNK_DURATION = 5.0  # seconds; shorter trailing chunks are dropped
SAMPLE_RATE = 16000  # Hz
class AudioChunker:
    """
    Handles audio chunking with different strategies:
    - 'none': Single chunk (no chunking)
    - 'vad': VAD-based intelligent chunking
    - 'static': Fixed-duration time-based chunking
    """

    _instance = None
    _instance_lock = threading.Lock()
    vad_model: Optional[Any]

    def __new__(cls):
        if cls._instance is None:
            with cls._instance_lock:
                # Check again after acquiring the lock, as another thread
                # may have set the value in the meantime
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    # Load the VAD model here since this branch runs only once
                    cls._instance.vad_model = cls.load_vad_model()
        return cls._instance
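
    # Singleton semantics: every construction returns the same object, so the
    # VAD model is loaded at most once per process, e.g.:
    #   AudioChunker() is AudioChunker()  # -> True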

    @staticmethod
    def load_vad_model() -> Optional[Any]:
        """Load the Silero VAD model with error handling."""
        try:
            logger.info("Loading Silero VAD model...")
            vad_model = silero_vad.load_silero_vad()
            logger.info("✓ VAD model loaded successfully")
            return vad_model
        except Exception as e:
            logger.error(f"Failed to load VAD model: {e}")
            logger.warning("VAD chunking will fall back to time-based chunking")
            return None

    def chunk_audio(
        self,
        audio_tensor: torch.Tensor,
        sample_rate: int = SAMPLE_RATE,
        mode: str = "vad",
        chunk_duration: float = 30.0,
    ) -> List[Dict]:
        """
        Chunk an audio tensor using the specified strategy.

        Args:
            audio_tensor: Audio tensor (1D waveform)
            sample_rate: Sample rate of the audio tensor
            mode: Chunking mode - 'none', 'vad', or 'static'
            chunk_duration: Target duration for static chunking (seconds)

        Returns:
            List of chunk info dicts with uniform format:
            - start_time: Start time in seconds
            - end_time: End time in seconds
            - duration: Duration in seconds
            - audio_data: Audio tensor for this chunk
            - sample_rate: Sample rate
            - chunk_index: Index of this chunk
        """
        logger.info(f"Chunking audio tensor: {audio_tensor.shape} at {sample_rate}Hz (mode: {mode})")
        try:
            # The tensor must already be 1D (preprocessed by MediaTranscriptionProcessor)
            assert len(audio_tensor.shape) == 1, f"Expected 1D audio tensor, got shape {audio_tensor.shape}"
            # The sample rate must already be 16 kHz (preprocessed by MediaTranscriptionProcessor)
            assert sample_rate == SAMPLE_RATE, f"Expected {SAMPLE_RATE}Hz sample rate, got {sample_rate}Hz"
            # Route to the appropriate chunking strategy
            if mode == "none":
                return self._create_single_chunk(audio_tensor, sample_rate)
            elif mode == "vad":
                if self.vad_model is not None:
                    return self._chunk_with_vad(audio_tensor)
                else:
                    logger.warning("VAD model not available, falling back to static chunking")
                    return self._chunk_static(audio_tensor, chunk_duration)
            elif mode == "static":
                return self._chunk_static(audio_tensor, chunk_duration)
            else:
                raise ValueError(f"Unknown chunking mode: {mode}")
        except Exception as e:
            logger.error(f"Error chunking audio tensor: {e}")
            # Ultimate fallback: a single chunk
            return self._create_single_chunk(audio_tensor, sample_rate)
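
    # Illustrative call (hypothetical caller; transcribe() is not part of
    # this module):
    #   chunker = AudioChunker()
    #   for chunk in chunker.chunk_audio(waveform, SAMPLE_RATE, mode="vad"):
    #       text = transcribe(chunk["audio_data"], chunk["sample_rate"])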

    def _create_single_chunk(self, waveform: torch.Tensor, sample_rate: int = SAMPLE_RATE) -> List[Dict]:
        """Create a single chunk containing the entire audio."""
        duration = len(waveform) / sample_rate
        return [{
            "start_time": 0.0,
            "end_time": duration,
            "duration": duration,
            "audio_data": waveform,
            "sample_rate": sample_rate,
            "chunk_index": 0,
        }]

    def _chunk_static(self, waveform: torch.Tensor, chunk_duration: float) -> List[Dict]:
        """Create fixed-duration chunks."""
        chunks = []
        total_samples = len(waveform)
        target_samples = int(chunk_duration * SAMPLE_RATE)
        start_sample = 0
        chunk_idx = 0
        while start_sample < total_samples:
            end_sample = min(start_sample + target_samples, total_samples)
            chunk_audio = waveform[start_sample:end_sample]
            duration = len(chunk_audio) / SAMPLE_RATE
            # Only add the chunk if it meets the minimum duration
            if duration >= MIN_CHUNK_DURATION:
                chunks.append({
                    "start_time": start_sample / SAMPLE_RATE,
                    "end_time": end_sample / SAMPLE_RATE,
                    "duration": duration,
                    "audio_data": chunk_audio,
                    "sample_rate": SAMPLE_RATE,
                    "chunk_index": chunk_idx,
                })
                chunk_idx += 1
            start_sample = end_sample
        logger.info(f"Created {len(chunks)} static chunks of ~{chunk_duration}s each")
        return chunks
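
    # Worked example: 75 s of 16 kHz audio with chunk_duration=30.0 yields
    # chunks [0, 30), [30, 60), [60, 75); the 15 s tail is kept because
    # 15 >= MIN_CHUNK_DURATION. A 63 s input yields only two chunks, since
    # its 3 s tail falls below the 5 s minimum and is dropped.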

    def _chunk_fallback(self, audio_path: str) -> List[Dict]:
        """Ultimate fallback: create a single chunk using librosa (for the file-based legacy method)."""
        try:
            logger.warning("Using librosa fallback for chunking")
            data, _ = librosa.load(audio_path, sr=SAMPLE_RATE)
            waveform = torch.from_numpy(data)
            return self._create_single_chunk(waveform, SAMPLE_RATE)
        except Exception as e:
            logger.error(f"All chunking methods failed: {e}")
            return []

    def _chunk_with_vad(self, waveform: torch.Tensor) -> List[Dict]:
        """Chunk audio using VAD for speech detection, with the uniform return format."""
        try:
            # The VAD model expects the tensor on CPU
            vad_waveform = waveform.cpu() if waveform.is_cuda else waveform
            # Get speech timestamps using VAD; the returned dicts hold sample
            # indices, e.g. {'start': 12345, 'end': 67890}
            speech_timestamps = silero_vad.get_speech_timestamps(
                vad_waveform,
                self.vad_model,
                sampling_rate=SAMPLE_RATE,
                min_speech_duration_ms=500,  # minimum speech segment
                min_silence_duration_ms=300,  # minimum silence to split on
                window_size_samples=1536,
                speech_pad_ms=100,  # padding around speech
            )
            logger.info(f"Found {len(speech_timestamps)} speech segments")
            # Create chunks based on speech segments and the target duration;
            # pass the original waveform (device preserved) to chunk creation
            chunks = self._create_chunks_from_speech_segments(
                waveform, speech_timestamps
            )
            logger.info(f"Created {len(chunks)} audio chunks using VAD")
            return chunks
        except Exception as e:
            logger.error(f"VAD chunking failed: {e}")
            return self._chunk_static(waveform, TARGET_CHUNK_DURATION)

    def _create_chunks_from_speech_segments(
        self, waveform: torch.Tensor, speech_segments: List[Dict]
    ) -> List[Dict]:
        """Create chunks that respect speech boundaries and target duration, with the uniform format."""
        if not speech_segments:
            logger.warning("No speech segments found, falling back to static chunking")
            return self._chunk_static(waveform, TARGET_CHUNK_DURATION)
        chunks = []
        current_chunk_start = 0
        target_samples = int(TARGET_CHUNK_DURATION * SAMPLE_RATE)
        total_samples = len(waveform)
        chunk_idx = 0
        while current_chunk_start < total_samples:
            # Calculate the target end for this chunk
            target_chunk_end = current_chunk_start + target_samples
            # If this would be the last chunk, or close to it, just take the rest
            if target_chunk_end >= total_samples or (
                total_samples - target_chunk_end
            ) < (target_samples * 0.3):
                chunk_end = total_samples
            else:
                # Find the best place to end this chunk using VAD, while
                # ensuring continuous coverage
                chunk_end = self._find_best_chunk_end_continuous(
                    speech_segments,
                    current_chunk_start,
                    target_chunk_end,
                    total_samples,
                )
            # Create the chunk in the uniform format
            chunk_audio = waveform[current_chunk_start:chunk_end]
            duration = len(chunk_audio) / SAMPLE_RATE
            chunks.append({
                "start_time": current_chunk_start / SAMPLE_RATE,
                "end_time": chunk_end / SAMPLE_RATE,
                "duration": duration,
                "audio_data": chunk_audio,
                "sample_rate": SAMPLE_RATE,
                "chunk_index": chunk_idx,
            })
            logger.info(
                f"Created chunk {chunk_idx + 1}: {current_chunk_start/SAMPLE_RATE:.2f}s - {chunk_end/SAMPLE_RATE:.2f}s ({duration:.2f}s)"
            )
            chunk_idx += 1
            # Move to the next chunk - start exactly where this chunk ended
            current_chunk_start = chunk_end
        # Verify total coverage
        total_audio_duration = len(waveform) / SAMPLE_RATE
        total_chunks_duration = sum(chunk["duration"] for chunk in chunks)
        logger.info(
            f"Audio chunking complete: {len(chunks)} chunks covering {total_chunks_duration:.2f}s of {total_audio_duration:.2f}s total audio"
        )
        # Allow 10 ms tolerance
        if abs(total_chunks_duration - total_audio_duration) > 0.01:
            logger.error(
                f"Duration mismatch: chunks={total_chunks_duration:.2f}s, original={total_audio_duration:.2f}s"
            )
        else:
            logger.info("✓ Perfect audio coverage achieved")
        return chunks

    def _find_best_chunk_end_continuous(
        self,
        speech_segments: List[Dict],
        chunk_start: int,
        target_end: int,
        total_samples: int,
    ) -> int:
        """Find the best place to end a chunk while ensuring continuous coverage."""
        # Don't go beyond the audio
        target_end = min(target_end, total_samples)
        # Look for a good break point within a reasonable window around the target
        search_window = int(SAMPLE_RATE * 3)  # 3 second window
        search_start = max(chunk_start, target_end - search_window)
        search_end = min(total_samples, target_end + search_window)
        best_end = target_end
        best_score = 0
        # Look for speech segment boundaries within the search window
        for segment in speech_segments:
            segment_end = segment["end"]
            # Check whether the segment end falls in our search window
            if search_start <= segment_end <= search_end:
                # Score by closeness to the target; segment ends are natural
                # pauses, so each gets the full boundary score
                distance_score = 1.0 - abs(segment_end - target_end) / search_window
                boundary_score = 1.0
                total_score = distance_score * boundary_score
                if total_score > best_score:
                    best_score = total_score
                    best_end = segment_end
        # Ensure we don't go beyond the audio bounds
        best_end = min(int(best_end), total_samples)
        # Ensure we make progress (don't end before we started)
        if best_end <= chunk_start:
            best_end = min(target_end, total_samples)
        return best_end
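
    # Worked example: with the target end at 30.0 s and a 3 s search window,
    # a speech segment ending at 29.2 s scores 1.0 - |29.2 - 30.0| / 3.0 ≈ 0.73
    # and beats one ending at 27.5 s (≈ 0.17), so the boundary snaps to 29.2 s.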

    def _find_best_chunk_end(
        self,
        speech_segments: List[Dict],
        start_idx: int,
        chunk_start: int,
        target_end: int,
    ) -> int:
        """Find the best place to end a chunk (at silence, near the target duration).

        Note: not currently called within this module; kept alongside the
        continuous-coverage variant above.
        """
        best_end = target_end
        # Look for speech segments that could provide good break points
        for i in range(start_idx, len(speech_segments)):
            segment = speech_segments[i]
            segment_start = segment["start"]
            segment_end = segment["end"]
            # If the segment starts after our target end, use the gap before it
            if segment_start > target_end:
                best_end = min(target_end, segment_start)
                break
            # If the segment ends near our target, use the end of the segment
            if abs(segment_end - target_end) < SAMPLE_RATE * 5:  # within 5 seconds
                best_end = segment_end
                break
            # If the segment extends well past the target, fall back to the target
            if segment_end > target_end + SAMPLE_RATE * 10:  # 10+ seconds past
                best_end = target_end
                break
        return int(best_end)

    def save_chunk_to_file(self, chunk: Dict, output_path: str) -> str:
        """Save a chunk to a temporary audio file."""
        try:
            # Convert the tensor to numpy if needed
            audio_data = chunk["audio_data"]
            if isinstance(audio_data, torch.Tensor):
                # Move to CPU first if on GPU, then convert to numpy
                audio_data = audio_data.cpu().numpy()
            # Write the file
            sf.write(output_path, audio_data, chunk["sample_rate"])
            return output_path
        except Exception as e:
            logger.error(f"Failed to save chunk to file: {e}")
            raise
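

# Minimal usage sketch (illustrative, not part of the original module):
# exercises static chunking on one minute of synthetic sine-wave audio, then
# writes the first chunk to disk. Static mode is used so the demo still runs
# if the Silero VAD model fails to load; VAD mode would need real speech audio.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # 60 s of a 440 Hz tone at 16 kHz as a synthetic stand-in for real audio
    t = torch.arange(60 * SAMPLE_RATE, dtype=torch.float32) / SAMPLE_RATE
    demo_waveform = 0.1 * torch.sin(2 * torch.pi * 440.0 * t)
    chunker = AudioChunker()
    demo_chunks = chunker.chunk_audio(
        demo_waveform, SAMPLE_RATE, mode="static", chunk_duration=30.0
    )
    for c in demo_chunks:
        print(f"chunk {c['chunk_index']}: {c['start_time']:.1f}s - {c['end_time']:.1f}s")
    if demo_chunks:
        chunker.save_chunk_to_file(demo_chunks[0], "chunk_0.wav")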