| """ | |
| Gradio demo application for the GigaAM-v3 speech recognition models. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import tempfile | |
| import threading | |
| import time | |
| from contextlib import contextmanager | |
| from typing import Callable, Dict, List, Optional, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| import torchaudio | |
| from transformers import AutoModel | |

REPO_ID = "ai-sage/GigaAM-v3"

MODEL_VARIANTS: Dict[str, str] = {
    "e2e_rnnt": "End-to-end RNN-T • punctuation + normalization (best quality)",
    "e2e_ctc": "End-to-end CTC • punctuation + normalization (faster)",
    "rnnt": "RNN-T decoder • raw text without normalization",
    "ctc": "CTC decoder • fastest baseline",
}
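# Each key above doubles as the git revision of REPO_ID that load_model()
# requests via `revision=variant`, so variant names must match branches in
# the Hugging Face model repository.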
DEFAULT_VARIANT = "e2e_rnnt"

MAX_SHORT_SECONDS = float(os.getenv("MAX_AUDIO_DURATION_SECONDS", 150))
MAX_LONG_SECONDS = float(os.getenv("MAX_LONGFORM_DURATION_SECONDS", 600))
SHORTFORM_MODEL_LIMIT_SECONDS = float(os.getenv("SHORTFORM_MODEL_LIMIT_SECONDS", 25.0))
TARGET_SAMPLE_RATE = int(os.getenv("TARGET_SAMPLE_RATE", 16_000))

OUTPUT_MODES = {
    f"Short clip (<={int(SHORTFORM_MODEL_LIMIT_SECONDS)} s)": {
        "id": "short",
        "longform": False,
        "max_duration": MAX_SHORT_SECONDS,
        "limit_msg": (
            f"Запись длиннее {int(MAX_SHORT_SECONDS)} секунд. "
            "Выберите режим 'Segmented long-form' для более длинных файлов."
        ),
        "description": "Single call to `model.transcribe`; best latency for concise utterances.",
        "requires_token": False,
    },
    f"Segmented long-form (<={int(MAX_LONG_SECONDS)} s)": {
        "id": "longform",
        "longform": True,
        "max_duration": MAX_LONG_SECONDS,
        "limit_msg": (
            f"Длина аудио превышает {int(MAX_LONG_SECONDS)} секунд. "
            "Сократите запись для сегментированного режима."
        ),
        "description": "Calls `model.transcribe_longform` to obtain timestamped segments.",
        "requires_token": True,
    },
}

DEFAULT_MODE_LABEL = next(iter(OUTPUT_MODES))
DEFAULT_HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_CACHE: Dict[str, AutoModel] = {}
MODEL_LOCKS = {variant: threading.Lock() for variant in MODEL_VARIANTS}


def _format_seconds(value: float) -> str:
    return f"{value:.2f}s"


def _prepare_audio(audio_path: str) -> Tuple[str, float, int, Callable[[], None]]:
    """
    Convert the incoming audio to mono 16 kHz PCM WAV that GigaAM expects.

    Returns a tuple of (normalized_path, duration_seconds, sample_rate, cleanup_fn).
    """
    data, sample_rate = sf.read(audio_path, dtype="float32", always_2d=False)
    if data.ndim > 1:
        # Down-mix multichannel audio to mono.
        data = data.mean(axis=1)
    duration = len(data) / float(sample_rate)
    waveform = torch.from_numpy(np.copy(data))
    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(
            waveform.unsqueeze(0),
            orig_freq=sample_rate,
            new_freq=TARGET_SAMPLE_RATE,
        ).squeeze(0)
        sample_rate = TARGET_SAMPLE_RATE
    normalized = waveform.clamp(min=-1.0, max=1.0).numpy()
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    sf.write(tmp.name, normalized, sample_rate, subtype="PCM_16")

    def cleanup() -> None:
        try:
            os.remove(tmp.name)
        except OSError:
            pass

    return tmp.name, duration, sample_rate, cleanup
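
# Usage sketch (hypothetical input name): _prepare_audio("clip.ogg") writes a
# temporary mono 16 kHz PCM_16 WAV, returns its path alongside the clip's
# duration and final sample rate, and hands back a cleanup() callback that the
# caller must invoke once transcription finishes (see the finally block in
# transcribe_audio).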


def _extract_token_from_oauth(oauth_token: gr.OAuthToken | None) -> Optional[str]:
    """Extract the access token from a Gradio OAuthToken."""
    if not oauth_token:
        return None
    return oauth_token.token


def _resolve_access_token(oauth_token: gr.OAuthToken | None) -> Optional[str]:
    """Prefer the OAuth-issued token, fall back to the space-level secret."""
    user_token = _extract_token_from_oauth(oauth_token)
    if user_token:
        return user_token
    return DEFAULT_HF_TOKEN


@contextmanager
def _temporary_token(token: Optional[str]):
    """Temporarily expose `token` via the HF environment variables."""
    if not token:
        yield
        return
    previous_hf = os.environ.get("HF_TOKEN")
    previous_hub = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    os.environ["HF_TOKEN"] = token
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = token
    try:
        yield
    finally:
        if previous_hf is None:
            os.environ.pop("HF_TOKEN", None)
        else:
            os.environ["HF_TOKEN"] = previous_hf
        if previous_hub is None:
            os.environ.pop("HUGGINGFACEHUB_API_TOKEN", None)
        else:
            os.environ["HUGGINGFACEHUB_API_TOKEN"] = previous_hub


def _describe_token_source(
    profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
) -> str:
    """Describe where the HF token came from."""
    if _extract_token_from_oauth(oauth_token):
        username = profile.username if profile else "user"
        return f"{username} (OAuth)"
    if DEFAULT_HF_TOKEN:
        return "space secret"
    return "not configured"


def _run_longform(model: AutoModel, audio_path: str, token: Optional[str]) -> Tuple[str, List[List[float | str]]]:
    if not token:
        raise gr.Error(
            "Для сегментированного режима требуется авторизация через Hugging Face OAuth "
            "или переменная окружения HF_TOKEN с доступом к 'pyannote/segmentation-3.0'."
        )
    with _temporary_token(token):
        utterances = model.transcribe_longform(audio_path)
    segments: List[List[float | str]] = []
    assembled_text_parts: List[str] = []
    for utt in utterances:
        text = _normalize_text(utt)
        if isinstance(utt, dict):
            boundaries = utt.get("boundaries") or utt.get("timestamps")
        else:
            boundaries = None
        if not boundaries:
            boundaries = (0.0, 0.0)
        start, end = boundaries
        segments.append([round(float(start), 2), round(float(end), 2), text])
        assembled_text_parts.append(text)
    transcription_text = "\n".join(filter(None, assembled_text_parts)).strip()
    return transcription_text, segments
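
# Each segment is a [start_seconds, end_seconds, text] row, matching the
# column headers of the gr.Dataframe declared in build_interface().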


def _normalize_text(text: object) -> str:
    if text is None:
        return ""
    if isinstance(text, str):
        return text.strip()
    if isinstance(text, dict):
        for key in ("transcription", "text"):
            if key in text and isinstance(text[key], str):
                return text[key].strip()
    return str(text)


def load_model(variant: str) -> AutoModel:
    if variant not in MODEL_VARIANTS:
        raise gr.Error(f"Вариант модели '{variant}' не поддерживается.")
    if variant in MODEL_CACHE:
        return MODEL_CACHE[variant]
    lock = MODEL_LOCKS[variant]
    with lock:
        if variant in MODEL_CACHE:
            return MODEL_CACHE[variant]
        load_kwargs = dict(revision=variant, trust_remote_code=True)
        if DEFAULT_HF_TOKEN:
            load_kwargs["token"] = DEFAULT_HF_TOKEN
        model = AutoModel.from_pretrained(REPO_ID, **load_kwargs)
        try:
            model.to(DEVICE)
        except Exception:
            # Some remote implementations manage their own device placement.
            pass
        MODEL_CACHE[variant] = model
        return model
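
# load_model() uses double-checked locking: the unlocked cache hit above
# serves the hot path, while the per-variant lock guarantees each checkpoint
# is downloaded and moved to DEVICE at most once even when Gradio's queue
# dispatches requests from multiple threads.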


def transcribe_audio(
    audio_path: Optional[str],
    variant: str,
    mode_label: str,
    profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    progress: gr.Progress = gr.Progress(track_tqdm=False),
) -> Tuple[str, List[List[float | str]], str]:
    if not audio_path or not os.path.exists(audio_path):
        raise gr.Error("Загрузите или запишите аудиофайл, чтобы начать распознавание.")
    if mode_label not in OUTPUT_MODES:
        raise gr.Error("Выберите режим транскрипции.")
    mode_cfg = OUTPUT_MODES[mode_label]
    prepared_path, duration, sample_rate, cleanup = _prepare_audio(audio_path)
    if duration < 0.3:
        raise gr.Error("Запись слишком короткая (<300 мс).")
    if duration > mode_cfg["max_duration"]:
        raise gr.Error(mode_cfg["limit_msg"])
    effective_token = _resolve_access_token(oauth_token)
    if mode_cfg["requires_token"] and not effective_token:
        raise gr.Error(
            "Для сегментированного режима требуется авторизация через Hugging Face OAuth "
            "или переменная окружения HF_TOKEN с доступом к модели 'pyannote/segmentation-3.0'."
        )
    progress(0.1, desc="Загрузка модели")
    model = load_model(variant)
    start_ts = time.perf_counter()
    progress(0.55, desc="Распознавание речи")
    auto_switched = False
    try:
        if mode_cfg["longform"]:
            transcription_text, segments = _run_longform(model, prepared_path, effective_token)
        elif duration > SHORTFORM_MODEL_LIMIT_SECONDS:
            if not effective_token:
                raise gr.Error(
                    "Аудио длиннее лимита короткого режима (~25 секунд). "
                    "Авторизуйтесь через Hugging Face OAuth или добавьте HF_TOKEN, "
                    "чтобы использовать сегментированное распознавание."
                )
            auto_switched = True
            transcription_text, segments = _run_longform(model, prepared_path, effective_token)
        else:
            try:
                result = model.transcribe(prepared_path)
                transcription_text = _normalize_text(result)
                segments = []
            except ValueError as exc:
                if "too long" not in str(exc).lower():
                    raise
                if not effective_token:
                    raise gr.Error(
                        "GigaAM потребовала режим transcribe_longform. "
                        "Войдите через OAuth или добавьте HF_TOKEN и повторите попытку."
                    ) from exc
                auto_switched = True
                transcription_text, segments = _run_longform(model, prepared_path, effective_token)
    finally:
        cleanup()
    latency = time.perf_counter() - start_ts
    progress(1.0, desc="Готово")
    mode_description = mode_cfg["description"]
    if auto_switched:
        mode_description += " · auto-switched to segmented"
    metadata_lines = [
        f"- **Model variant:** {MODEL_VARIANTS[variant]}",
        f"- **Transcription mode:** {mode_description}",
        f"- **Audio duration:** {_format_seconds(duration)} @ {sample_rate} Hz",
        f"- **Latency:** {_format_seconds(latency)} on `{DEVICE}`",
        f"- **Token source:** {_describe_token_source(profile, oauth_token)}",
    ]
    return transcription_text, segments, "\n".join(metadata_lines)


DESCRIPTION_MD = """
# GigaAM-v3 · Russian ASR demo

This Space showcases the [`ai-sage/GigaAM-v3`](https://huggingface.co/ai-sage/GigaAM-v3) Conformer-based models.

- Upload or record Russian audio (WAV/MP3/FLAC, mono preferred).
- Pick the model variant and transcription mode that matches your latency/quality needs.
- Clips are resampled to mono 16 kHz automatically for best compatibility.
- Sign in with Hugging Face OAuth to unlock segmented long-form transcription (requires access to `pyannote/segmentation-3.0`).
"""

FOOTER_MD = f"""
**Tips**

- Short clips (<{int(SHORTFORM_MODEL_LIMIT_SECONDS)} s) work best with the E2E variants (they include punctuation and normalization).
- Long recordings can take several minutes on CPU-only Spaces; switch to GPU hardware if available.
- `model.transcribe` is limited to ~25 s internally; longer clips auto-switch to segmented mode when a token is available.
- Source: [salute-developers/GigaAM](https://github.com/salute-developers/GigaAM)
"""


def build_interface() -> gr.Blocks:
    # The theme belongs on the Blocks constructor, not launch().
    with gr.Blocks(title="GigaAM-v3 ASR demo", theme=gr.themes.Ocean()) as demo:
        gr.Markdown(DESCRIPTION_MD)
        with gr.Row():
            login_button = gr.LoginButton(min_width=200)
        with gr.Row(equal_height=True):
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Russian audio",
                waveform_options=gr.WaveformOptions(
                    waveform_color="#f97316",
                    skip_length=2,
                ),
            )
            with gr.Column():
                variant_dropdown = gr.Dropdown(
                    choices=list(MODEL_VARIANTS.keys()),
                    value=DEFAULT_VARIANT,
                    label="Model variant",
                    info="End-to-end variants add punctuation; base CTC/RNNT are lighter but raw.",
                )
                mode_radio = gr.Radio(
                    choices=list(OUTPUT_MODES.keys()),
                    value=DEFAULT_MODE_LABEL,
                    label="Transcription mode",
                    info=f"Select segmented mode for clips longer than {int(SHORTFORM_MODEL_LIMIT_SECONDS)} s (requires an HF token).",
                )
                transcribe_btn = gr.Button("Transcribe", variant="primary")
        transcript_output = gr.Textbox(
            label="Transcript",
            placeholder="Model output will appear here…",
            lines=8,
        )
        segments_output = gr.Dataframe(
            headers=["Start (s)", "End (s)", "Utterance"],
            datatype=["number", "number", "str"],
            label="Segments (long-form mode)",
            interactive=False,
        )
        metadata_output = gr.Markdown()
        gr.Markdown(FOOTER_MD)
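
        # Gradio fills the gr.OAuthProfile / gr.OAuthToken parameters of
        # transcribe_audio automatically from the session based on their type
        # hints, and gr.Progress from its keyword default, so only the three
        # UI components appear in `inputs` below.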
        transcribe_btn.click(
            fn=transcribe_audio,
            inputs=[audio_input, variant_dropdown, mode_radio],
            outputs=[transcript_output, segments_output, metadata_output],
        )
    return demo


demo = build_interface()


def _launch_app() -> None:
    """Launch the Gradio app with sensible defaults for HF Spaces and local runs."""
    is_space = bool(os.getenv("SPACE_ID"))
    launch_kwargs = {
        "server_name": os.getenv("GRADIO_SERVER_NAME", "0.0.0.0" if is_space else "127.0.0.1"),
        "server_port": int(os.getenv("GRADIO_SERVER_PORT", "7860")),
    }
    if not is_space and os.getenv("GRADIO_SHARE", "0") == "1":
        launch_kwargs["share"] = True
    enable_queue = os.getenv("ENABLE_GRADIO_QUEUE", "1") != "0"
    app = demo.queue(max_size=8) if enable_queue else demo
    app.launch(**launch_kwargs)


if __name__ == "__main__":
    _launch_app()
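
# Local run sketch (assumes the Space's requirements are installed):
#   python app.py                 # serves on 127.0.0.1:7860
#   GRADIO_SHARE=1 python app.py  # additionally requests a public share link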