import asyncio
import base64
import inspect
import json
import os
from typing import Callable, Optional, AsyncGenerator

import numpy as np
from google import genai
from google.api_core import exceptions as google_exceptions
from google.genai.types import (
    LiveConnectConfig,
    PrebuiltVoiceConfig,
    SpeechConfig,
    VoiceConfig,
)


class GeminiRealtimeService: |
|
|
"""Dịch vụ Gemini Realtime API với audio streaming thực""" |
|
|
|
|
|
def __init__(self, api_key: str = None): |
|
|
self.api_key = api_key or os.getenv("GEMINI_API_KEY") |
|
|
self.client = None |
|
|
self.session = None |
|
|
self.is_active = False |
|
|
self.callback = None |
|
|
self.voice_name = "Puck" |
|
|
self.input_queue = asyncio.Queue() |
|
|
self.output_queue = asyncio.Queue() |
|
|
self._session_task = None |
|
|
|
|
|
async def initialize(self): |
|
|
"""Khởi tạo client Gemini""" |
|
|
if not self.api_key: |
|
|
raise ValueError("Gemini API key is required") |
|
|
|
|
|
try: |
|
|
self.client = genai.Client( |
|
|
api_key=self.api_key, |
|
|
http_options={"api_version": "v1alpha"}, |
|
|
) |
|
|
return True |
|
|
except Exception as e: |
|
|
raise Exception(f"Không thể khởi tạo Gemini client: {str(e)}") |
|
|
|
|
|
def encode_audio(self, data: np.ndarray) -> str: |
|
|
"""Encode audio data to base64""" |
|
|
return base64.b64encode(data.tobytes()).decode("UTF-8") |
|
|
|
|
|
async def start_session(self, voice_name: str = "Puck", callback: Callable = None): |
|
|
"""Bắt đầu session Gemini Realtime với audio streaming""" |
|
|
try: |
|
|
if not self.client: |
|
|
await self.initialize() |
|
|
|
|
|
self.voice_name = voice_name |
|
|
self.callback = callback |
|
|
|
|
|
|
|
|
config = LiveConnectConfig( |
|
|
response_modalities=["AUDIO"], |
|
|
speech_config=SpeechConfig( |
|
|
voice_config=VoiceConfig( |
|
|
prebuilt_voice_config=PrebuiltVoiceConfig( |
|
|
voice_name=voice_name, |
|
|
) |
|
|
) |
|
|
), |
|
|
) |
|
|
|
|
|
|
|
|
self.session = self.client.aio.live.connect( |
|
|
model="gemini-2.0-flash-exp", |
|
|
config=config |
|
|
) |
|
|
|
|
|
self.is_active = True |
|
|
|
|
|
if self.callback: |
|
|
await self.callback({ |
|
|
'type': 'status', |
|
|
'message': f'✅ Đã kết nối Gemini Audio Streaming - Giọng: {voice_name}', |
|
|
'status': 'connected' |
|
|
}) |
|
|
|
|
|
print("✅ Gemini Realtime Audio session started") |
|
|
|
|
|
|
|
|
self._session_task = asyncio.create_task(self._handle_session()) |
|
|
|
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"❌ Lỗi khởi động Gemini Audio: {e}" |
|
|
if self.callback: |
|
|
await self.callback({ |
|
|
'type': 'error', |
|
|
'message': error_msg, |
|
|
'status': 'error' |
|
|
}) |
|
|
print(error_msg) |
|
|
return False |
|
|
|
|
|
async def _handle_session(self): |
|
|
"""Xử lý session realtime với async with đúng cách""" |
|
|
try: |
|
|
|
|
|
async with self.session as session: |
|
|
|
|
|
sender_task = asyncio.create_task(self._audio_sender(session)) |
|
|
receiver_task = asyncio.create_task(self._audio_receiver(session)) |
|
|
|
|
|
|
|
|
await asyncio.gather(sender_task, receiver_task) |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"❌ Lỗi trong session: {e}" |
|
|
if self.callback: |
|
|
await self.callback({ |
|
|
'type': 'error', |
|
|
'message': error_msg, |
|
|
'status': 'error' |
|
|
}) |
|
|
print(error_msg) |
|
|
|
|
|
async def _audio_sender(self, session): |
|
|
"""Gửi audio data đến Gemini""" |
|
|
try: |
|
|
async for audio_chunk in self._audio_stream_generator(): |
|
|
await session.send(audio_chunk) |
|
|
except Exception as e: |
|
|
error_msg = f"❌ Lỗi gửi audio: {e}" |
|
|
if self.callback: |
|
|
await self.callback({ |
|
|
'type': 'error', |
|
|
'message': error_msg, |
|
|
'status': 'error' |
|
|
}) |
|
|
|
|
|
async def _audio_stream_generator(self) -> AsyncGenerator[bytes, None]: |
|
|
"""Generator cho audio streaming""" |
|
|
while self.is_active: |
|
|
try: |
|
|
audio_data = await asyncio.wait_for(self.input_queue.get(), timeout=1.0) |
|
|
yield audio_data |
|
|
except asyncio.TimeoutError: |
|
|
continue |
|
|
except Exception as e: |
|
|
print(f"❌ Lỗi audio generator: {e}") |
|
|
break |
|
|
|
|
|
async def _audio_receiver(self, session): |
|
|
"""Nhận audio response từ Gemini""" |
|
|
try: |
|
|
async for response in session: |
|
|
if hasattr(response, 'data') and response.data: |
|
|
|
|
|
audio_data = np.frombuffer(response.data, dtype=np.int16) |
|
|
|
|
|
if self.callback: |
|
|
await self.callback({ |
|
|
'type': 'audio', |
|
|
'audio_data': audio_data, |
|
|
'sample_rate': 24000, |
|
|
'status': 'audio_streaming' |
|
|
}) |
|
|
|
|
|
|
|
|
self.output_queue.put_nowait((24000, audio_data)) |
|
|
|
|
|
elif hasattr(response, 'text') and response.text: |
|
|
|
|
|
if self.callback: |
|
|
await self.callback({ |
|
|
'type': 'text', |
|
|
'content': response.text, |
|
|
'role': 'assistant', |
|
|
'status': 'text_response' |
|
|
}) |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"❌ Lỗi nhận audio: {e}" |
|
|
if self.callback: |
|
|
await self.callback({ |
|
|
'type': 'error', |
|
|
'message': error_msg, |
|
|
'status': 'error' |
|
|
}) |
|
|
|
|
|
async def send_audio_chunk(self, audio_chunk: np.ndarray, sample_rate: int = 16000): |
|
|
"""Gửi audio chunk đến Gemini""" |
|
|
if not self.is_active: |
|
|
return False |
|
|
|
|
|
try: |
|
|
|
|
|
if sample_rate != 16000: |
|
|
audio_chunk = self._resample_audio(audio_chunk, sample_rate, 16000) |
|
|
|
|
|
|
|
|
audio_bytes = audio_chunk.tobytes() |
|
|
await self.input_queue.put(audio_bytes) |
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ Lỗi gửi audio chunk: {e}") |
|
|
return False |
|
|
|
|
|
async def receive_audio(self) -> tuple[int, np.ndarray] | None: |
|
|
"""Nhận audio từ Gemini""" |
|
|
try: |
|
|
return await asyncio.wait_for(self.output_queue.get(), timeout=1.0) |
|
|
except asyncio.TimeoutError: |
|
|
return None |
|
|
|
|
|
def _resample_audio(self, audio_chunk: np.ndarray, original_rate: int, target_rate: int): |
|
|
"""Resample audio chunk (đơn giản)""" |
|
|
if original_rate == target_rate: |
|
|
return audio_chunk |
|
|
|
|
|
ratio = target_rate / original_rate |
|
|
new_length = int(len(audio_chunk) * ratio) |
|
|
return np.interp( |
|
|
np.linspace(0, len(audio_chunk) - 1, new_length), |
|
|
np.arange(len(audio_chunk)), |
|
|
audio_chunk |
|
|
).astype(np.int16) |
|
|
|
|
|
async def send_text(self, text: str): |
|
|
"""Gửi text message (fallback)""" |
|
|
if not self.client: |
|
|
return None |
|
|
|
|
|
try: |
|
|
response = await self.client.aio.models.generate_content( |
|
|
model="gemini-2.0-flash-exp", |
|
|
contents=text |
|
|
) |
|
|
|
|
|
if self.callback: |
|
|
await self.callback({ |
|
|
'type': 'text', |
|
|
'content': response.text, |
|
|
'role': 'assistant', |
|
|
'status': 'text_response' |
|
|
}) |
|
|
|
|
|
return response.text |
|
|
|
|
|
except google_exceptions.ResourceExhausted: |
|
|
error_msg = "❌ Quota Gemini đã hết. Vui lòng kiểm tra billing." |
|
|
if self.callback: |
|
|
await self.callback({ |
|
|
'type': 'error', |
|
|
'message': error_msg, |
|
|
'status': 'quota_exceeded' |
|
|
}) |
|
|
return error_msg |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"❌ Lỗi gửi text: {e}" |
|
|
if self.callback: |
|
|
await self.callback({ |
|
|
'type': 'error', |
|
|
'message': error_msg, |
|
|
'status': 'error' |
|
|
}) |
|
|
return error_msg |
|
|
|
|
|
async def close(self): |
|
|
"""Đóng kết nối""" |
|
|
self.is_active = False |
|
|
|
|
|
|
|
|
if self._session_task: |
|
|
self._session_task.cancel() |
|
|
try: |
|
|
await self._session_task |
|
|
except asyncio.CancelledError: |
|
|
pass |
|
|
|
|
|
if self.callback: |
|
|
await self.callback({ |
|
|
'type': 'status', |
|
|
'message': '🛑 Đã đóng kết nối Gemini Audio', |
|
|
'status': 'disconnected' |
|
|
}) |
|
|
|
|
|
print("🛑 Gemini Audio session closed") |