malek-messaoudii committed on
Commit
4a13628
·
1 Parent(s): 218c6a3
models/audio.py CHANGED
@@ -1,49 +1,78 @@
1
- from pydantic import BaseModel, Field, ConfigDict
2
  from typing import Optional
3
 
4
- # ==============================
5
- # SPEECH TO TEXT RESPONSE
6
- # ==============================
7
  class STTResponse(BaseModel):
8
- model_config = ConfigDict(
9
- json_schema_extra={
 
 
 
 
 
 
10
  "example": {
11
  "text": "hello how are you",
12
- "model_name": "openai/whisper-large-v3",
13
  "language": "en",
14
  "duration_seconds": 3.2
15
  }
16
  }
17
- )
18
- text: str = Field(..., description="Transcribed text from the input audio")
19
- model_name: str = Field(..., description="STT model used for inference")
20
- language: Optional[str] = Field(None, description="Detected language")
21
- duration_seconds: Optional[float] = Field(
22
- None,
23
- description="Approximate audio duration in seconds"
24
- )
25
-
26
- # ==============================
27
- # TEXT TO SPEECH REQUEST / RESPONSE
28
- # ==============================
29
  class TTSRequest(BaseModel):
30
- model_config = ConfigDict(
31
- json_schema_extra={"example": {"text": "Hello, welcome to our AI system."}}
32
- )
33
  text: str = Field(..., min_length=1, max_length=500, description="Text to convert to speech")
 
 
 
 
 
 
 
 
34
 
35
  class TTSResponse(BaseModel):
36
- model_config = ConfigDict(
37
- json_schema_extra={
 
 
 
 
 
 
38
  "example": {
39
  "message": "Audio generated successfully",
40
  "audio_format": "wav",
41
- "length_seconds": 2.5,
42
- "model_name": "suno/bark"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  }
44
  }
45
- )
46
- message: str
47
- audio_format: str
48
- length_seconds: Optional[float] = None
49
- model_name: str
 
1
+ from pydantic import BaseModel, Field
2
  from typing import Optional
3
 
4
+
 
 
5
class STTResponse(BaseModel):
    """Response model for Speech-to-Text."""

    # `text` is the only required field; the rest default so service code
    # can return a minimal payload.
    text: str = Field(..., description="Transcribed text from audio")
    model_name: str = Field(default="gemini-2.5-flash", description="Model used")
    language: Optional[str] = Field(default="en", description="Detected language")
    duration_seconds: Optional[float] = Field(None, description="Audio duration")

    # Pydantic v2: plain-dict model_config replaces the deprecated
    # v1-style inner `class Config` (avoids PydanticDeprecatedSince20).
    model_config = {
        "json_schema_extra": {
            "example": {
                "text": "hello how are you",
                "model_name": "gemini-2.5-flash",
                "language": "en",
                "duration_seconds": 3.2
            }
        }
    }
21
+
22
+
 
 
 
 
 
 
 
 
 
 
23
class TTSRequest(BaseModel):
    """Request model for Text-to-Speech."""

    text: str = Field(..., min_length=1, max_length=500, description="Text to convert to speech")

    # Pydantic v2: plain-dict model_config replaces the deprecated
    # v1-style inner `class Config`.
    model_config = {
        "json_schema_extra": {
            "example": {
                "text": "Hello, welcome to our AI voice system."
            }
        }
    }
33
+
34
 
35
class TTSResponse(BaseModel):
    """Response model for Text-to-Speech."""

    message: str = Field(..., description="Status message")
    audio_format: str = Field(default="wav", description="Audio format")
    model_name: str = Field(default="gemini-2.5-flash-preview-tts", description="Model used")
    length_seconds: Optional[float] = Field(None, description="Generated audio duration")

    # Pydantic v2: plain-dict model_config replaces the deprecated
    # v1-style inner `class Config`.
    model_config = {
        "json_schema_extra": {
            "example": {
                "message": "Audio generated successfully",
                "audio_format": "wav",
                "model_name": "gemini-2.5-flash-preview-tts",
                "length_seconds": 2.5
            }
        }
    }
51
+
52
+
53
class ChatbotRequest(BaseModel):
    """Request model for Chatbot."""

    text: str = Field(..., min_length=1, max_length=500, description="User query")

    # Pydantic v2: plain-dict model_config replaces the deprecated
    # v1-style inner `class Config`.
    model_config = {
        "json_schema_extra": {
            "example": {
                "text": "What is the weather today?"
            }
        }
    }
63
+
64
+
65
class ChatbotResponse(BaseModel):
    """Response model for Chatbot."""

    user_input: str = Field(..., description="User input text")
    bot_response: str = Field(..., description="Bot response text")
    model_name: str = Field(default="gemini-2.5-flash", description="Model used")

    # Pydantic v2: plain-dict model_config replaces the deprecated
    # v1-style inner `class Config`.
    model_config = {
        "json_schema_extra": {
            "example": {
                "user_input": "Hello",
                "bot_response": "Hi there! How can I help you?",
                "model_name": "gemini-2.5-flash"
            }
        }
    }
 
 
 
 
 
routes/audio.py CHANGED
@@ -1,91 +1,134 @@
1
  from fastapi import APIRouter, UploadFile, File, HTTPException
2
  from fastapi.responses import StreamingResponse
3
  import io
4
- from services.tts_service import generate_tts
 
5
  from services.stt_service import speech_to_text
 
 
 
 
 
6
 
7
  router = APIRouter(prefix="/audio", tags=["Audio"])
8
 
9
- # Allowed MIME types
10
- ALLOWED_AUDIO_TYPES = {
11
- "audio/wav",
12
- "audio/x-wav",
13
- "audio/mpeg", # mp3
14
- "audio/mp3", # mp3
15
- "audio/mp4", # sometimes m4a
16
- "audio/m4a" # m4a
17
- }
18
 
19
- # ------------------------
20
- # Text to Speech
21
- # ------------------------
22
  @router.post("/tts")
23
- async def tts(text: str):
24
- """
25
- Convert text to speech and return audio.
26
  """
27
- if not text or len(text) > 500:
28
- raise HTTPException(
29
- status_code=400,
30
- detail="Text must be between 1 and 500 characters"
31
- )
32
 
 
 
 
 
 
33
  try:
34
- audio_bytes = await generate_tts(text)
 
 
35
  except Exception as e:
 
36
  raise HTTPException(status_code=500, detail=str(e))
37
 
38
- return StreamingResponse(io.BytesIO(audio_bytes), media_type="audio/wav")
39
-
40
 
41
- # ------------------------
42
- # Speech to Text
43
- # ------------------------
44
- @router.post("/stt")
45
  async def stt(file: UploadFile = File(...)):
46
  """
47
- Accepts an uploaded audio file (wav, mp3, m4a) and returns the transcribed text.
 
 
 
 
 
48
  """
49
- # Validate MIME type
50
  if file.content_type not in ALLOWED_AUDIO_TYPES:
51
  raise HTTPException(
52
  status_code=400,
53
- detail=f"Unsupported audio format: {file.content_type}. Supported: WAV, MP3, M4A"
54
  )
55
-
56
  try:
 
57
  audio_bytes = await file.read()
58
  text = await speech_to_text(audio_bytes, file.filename)
 
 
 
 
 
 
 
59
  except Exception as e:
 
60
  raise HTTPException(status_code=500, detail=str(e))
61
-
62
- return {"text": text}
63
 
64
 
65
- # ------------------------
66
- # Voice Chatbot: User sends voice → TTS reply
67
- # ------------------------
68
  @router.post("/chatbot")
69
- async def chatbot(file: UploadFile = File(...)):
70
  """
71
- User sends an audio file, the system converts to text, generates response, and returns TTS audio.
 
 
 
 
 
 
 
 
 
 
72
  """
73
- # Validate MIME type
74
  if file.content_type not in ALLOWED_AUDIO_TYPES:
75
  raise HTTPException(
76
  status_code=400,
77
- detail=f"Unsupported audio format: {file.content_type}. Supported: WAV, MP3, M4A"
78
  )
79
-
80
  try:
 
 
 
81
  audio_bytes = await file.read()
82
  user_text = await speech_to_text(audio_bytes, file.filename)
83
-
84
- # Replace this with your NLP or chatbot logic
85
- response_text = f"You said: {user_text}"
86
-
 
 
 
87
  audio_response = await generate_tts(response_text)
 
 
 
 
88
  except Exception as e:
 
89
  raise HTTPException(status_code=500, detail=str(e))
90
 
91
- return StreamingResponse(io.BytesIO(audio_response), media_type="audio/wav")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import APIRouter, UploadFile, File, HTTPException
2
  from fastapi.responses import StreamingResponse
3
  import io
4
+ import logging
5
+ from config import ALLOWED_AUDIO_TYPES, MAX_TEXT_LENGTH, MIN_TEXT_LENGTH
6
  from services.stt_service import speech_to_text
7
+ from services.tts_service import generate_tts
8
+ from services.chatbot_service import get_chatbot_response
9
+ from models.audio import STTResponse, TTSRequest, TTSResponse, ChatbotRequest, ChatbotResponse
10
+
11
+ logger = logging.getLogger(__name__)
12
 
13
  router = APIRouter(prefix="/audio", tags=["Audio"])
14
 
 
 
 
 
 
 
 
 
 
15
 
 
 
 
16
@router.post("/tts")
async def tts(request: TTSRequest):
    """
    Convert text to speech and return audio file.

    Example:
    - POST /audio/tts
    - Body: {"text": "Hello, welcome to our system"}
    - Returns: WAV audio file
    """
    try:
        # Lazy %-style args: the message is only built if the level is enabled.
        logger.info("TTS request received for text: '%s'", request.text)
        audio_bytes = await generate_tts(request.text)
        return StreamingResponse(io.BytesIO(audio_bytes), media_type="audio/wav")
    except Exception as e:
        # logger.exception records the full traceback, not just str(e).
        logger.exception("TTS error")
        raise HTTPException(status_code=500, detail=str(e)) from e
33
 
 
 
34
 
35
@router.post("/stt", response_model=STTResponse)
async def stt(file: UploadFile = File(...)):
    """
    Convert audio file to text.

    Example:
    - POST /audio/stt
    - File: audio.mp3 (or .wav, .m4a)
    - Returns: {"text": "transcribed text", "model_name": "gemini-2.5-flash", ...}
    """
    # Validate file type before reading any bytes.
    if file.content_type not in ALLOWED_AUDIO_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {file.content_type}. Supported: WAV, MP3, M4A"
        )

    try:
        # Lazy %-style args: the message is only built if the level is enabled.
        logger.info("STT request received for file: %s", file.filename)
        audio_bytes = await file.read()
        text = await speech_to_text(audio_bytes, file.filename)

        return STTResponse(
            text=text,
            model_name="gemini-2.5-flash",
            language="en",
            duration_seconds=None
        )
    except Exception as e:
        # logger.exception records the full traceback, not just str(e).
        logger.exception("STT error")
        raise HTTPException(status_code=500, detail=str(e)) from e
 
 
66
 
67
 
 
 
 
68
@router.post("/chatbot")
async def chatbot_voice(file: UploadFile = File(...)):
    """
    Full voice chatbot flow (Audio -> Text -> Response -> Audio).

    Example:
    - POST /audio/chatbot
    - File: user_voice.mp3
    - Returns: Response audio file (WAV)

    Process:
    1. Converts user's audio to text (STT)
    2. Generates chatbot response to user's text
    3. Converts response back to audio (TTS)
    """
    # Validate file type before reading any bytes.
    if file.content_type not in ALLOWED_AUDIO_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {file.content_type}. Supported: WAV, MP3, M4A"
        )

    try:
        # Lazy %-style args: the message is only built if the level is enabled.
        logger.info("Voice chatbot request received for file: %s", file.filename)

        # Step 1: Convert audio to text
        audio_bytes = await file.read()
        user_text = await speech_to_text(audio_bytes, file.filename)
        logger.info("Step 1 - STT: %s", user_text)

        # Step 2: Generate chatbot response
        response_text = await get_chatbot_response(user_text)
        logger.info("Step 2 - Response: %s", response_text)

        # Step 3: Convert response to audio
        audio_response = await generate_tts(response_text)
        logger.info("Step 3 - TTS: Complete")

        return StreamingResponse(io.BytesIO(audio_response), media_type="audio/wav")

    except Exception as e:
        # logger.exception records the full traceback, not just str(e).
        logger.exception("Voice chatbot error")
        raise HTTPException(status_code=500, detail=str(e)) from e
111
 
112
+
113
@router.post("/chatbot-text", response_model=ChatbotResponse)
async def chatbot_text(request: ChatbotRequest):
    """
    Chatbot interaction with text input/output (no audio).

    Example:
    - POST /audio/chatbot-text
    - Body: {"text": "What is the capital of France?"}
    - Returns: {"user_input": "What is...", "bot_response": "The capital...", ...}
    """
    try:
        # Lazy %-style args: the message is only built if the level is enabled.
        logger.info("Text chatbot request: %s", request.text)
        response_text = await get_chatbot_response(request.text)

        return ChatbotResponse(
            user_input=request.text,
            bot_response=response_text,
            model_name="gemini-2.5-flash"
        )
    except Exception as e:
        # logger.exception records the full traceback, not just str(e).
        logger.exception("Text chatbot error")
        raise HTTPException(status_code=500, detail=str(e)) from e
services/chatbot_service.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from services.gemini_client import get_gemini_client
2
+ import logging
3
+
4
+ logger = logging.getLogger(__name__)
5
+
6
+
7
async def get_chatbot_response(user_text: str) -> str:
    """
    Generate chatbot response using Gemini API.

    Args:
        user_text: User input text

    Returns:
        Chatbot response text; on any failure a canned echo fallback
        is returned instead of raising, so the voice flow stays alive.
    """
    try:
        client = get_gemini_client()

        logger.info("Generating chatbot response for: '%s'", user_text)

        # Guidance for the model. NOTE(review): sent as a leading user turn;
        # confirm whether the SDK's system_instruction config would be preferable.
        system_prompt = """You are a helpful, friendly AI assistant.
        Respond concisely and naturally to user queries.
        Keep responses brief (1-2 sentences) for voice interaction."""

        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=[
                {"role": "user", "parts": [{"text": system_prompt}]},
                {"role": "user", "parts": [{"text": user_text}]}
            ]
        )

        response_text = response.text
        # response.text can be None/empty when no text part is returned
        # (e.g. safety block); treat that as a failure so the fallback fires
        # instead of silently handing None to the TTS step.
        if not response_text:
            raise ValueError("Gemini returned an empty response")

        logger.info("✓ Response generated: '%s'", response_text)

        return response_text

    except Exception:
        # Deliberate best-effort fallback; log the full traceback.
        logger.exception("✗ Chatbot response failed")
        return f"I understood you said: '{user_text}'. Could you tell me more?"
services/gemini_client.py CHANGED
@@ -1,8 +1,18 @@
1
  from google.genai import Client
2
- import os
 
 
 
 
3
 
4
  def get_gemini_client():
5
- api_key = os.getenv("GOOGLE_GENAI_API_KEY")
6
- if not api_key:
7
- raise ValueError("Missing GOOGLE_GENAI_API_KEY environment variable")
8
- return Client(api_key=api_key)
 
 
 
 
 
 
 
1
  from google.genai import Client
2
+ from config import GOOGLE_GENAI_API_KEY
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
 
8
def get_gemini_client():
    """
    Initialize and return a Gemini client.

    Raises:
        ValueError: If the API key is missing or client construction fails.
    """
    # Fail fast with a clear message when the key is absent; Client() may
    # defer key validation until the first request, which would surface the
    # problem far from its cause. (Restores the pre-refactor explicit check.)
    if not GOOGLE_GENAI_API_KEY:
        raise ValueError("Missing GOOGLE_GENAI_API_KEY configuration")
    try:
        client = Client(api_key=GOOGLE_GENAI_API_KEY)
        logger.info("✓ Gemini client initialized successfully")
        return client
    except Exception as e:
        logger.exception("✗ Failed to initialize Gemini client")
        # Chain the original error so the root cause stays in the traceback.
        raise ValueError(f"Gemini client initialization failed: {str(e)}") from e
services/stt_service.py CHANGED
@@ -1,23 +1,49 @@
1
  from services.gemini_client import get_gemini_client
2
  from google.genai import types
3
- import mimetypes # <- Add this
 
 
 
 
4
 
5
  async def speech_to_text(audio_bytes: bytes, filename: str) -> str:
6
  """
7
- Convert audio bytes to text using Gemini API. Supports WAV, MP3, M4A.
 
 
 
 
 
 
 
 
 
 
8
  """
9
- client = get_gemini_client()
10
-
11
- # Detect MIME type from filename
12
- mime_type, _ = mimetypes.guess_type(filename)
13
- if mime_type is None:
14
- mime_type = "audio/wav" # fallback
15
-
16
- audio_file = types.File(data=audio_bytes, mime_type=mime_type)
17
-
18
- response = client.models.generate_content(
19
- model="gemini-2.5-flash",
20
- contents=[audio_file]
21
- )
22
-
23
- return response.text
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from services.gemini_client import get_gemini_client
2
  from google.genai import types
3
+ import mimetypes
4
+ import logging
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
 
9
async def speech_to_text(audio_bytes: bytes, filename: str) -> str:
    """
    Convert audio bytes to text using Gemini API.

    Args:
        audio_bytes: Raw audio file bytes
        filename: Name of the audio file (used to detect format)

    Returns:
        Transcribed text

    Raises:
        RuntimeError: If transcription fails (subclass of Exception, so
            existing `except Exception` callers still catch it).
    """
    try:
        client = get_gemini_client()

        # Detect MIME type from filename
        mime_type, _ = mimetypes.guess_type(filename)
        if mime_type is None:
            mime_type = "audio/wav"  # fallback

        logger.info("Converting audio to text (format: %s)", mime_type)

        # Inline audio must be sent as a Part; types.File describes an
        # already-uploaded file and is not valid inline content.
        audio_part = types.Part.from_bytes(data=audio_bytes, mime_type=mime_type)

        # An explicit instruction makes the model transcribe rather than
        # describe the audio.
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=["Transcribe this audio verbatim.", audio_part]
        )

        transcribed_text = response.text
        logger.info("✓ STT successful: '%s'", transcribed_text)

        return transcribed_text

    except Exception as e:
        logger.exception("✗ STT failed")
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Speech-to-text conversion failed: {str(e)}") from e
services/tts_service.py CHANGED
@@ -1,25 +1,53 @@
1
  from services.gemini_client import get_gemini_client
2
  from google.genai import types
3
  import base64
 
4
 
5
- async def generate_tts(text: str) -> bytes:
6
- client = get_gemini_client()
7
-
8
- response = client.models.generate_content(
9
- model="gemini-2.5-flash-preview-tts",
10
- contents=text,
11
- config=types.GenerateContentConfig(
12
- response_modalities=["AUDIO"],
13
- speech_config=types.SpeechConfig(
14
- voice_config=types.VoiceConfig(
15
- prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="Kore")
16
- )
17
- ),
18
- ),
19
- )
20
 
21
- # Decode base64 audio into bytes
22
- audio_base64 = response.candidates[0].content.parts[0].inline_data.data
23
- audio_bytes = base64.b64decode(audio_base64)
24
 
25
- return audio_bytes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from services.gemini_client import get_gemini_client
2
  from google.genai import types
3
  import base64
4
+ import logging
5
 
6
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
 
 
 
8
 
9
async def generate_tts(text: str) -> bytes:
    """
    Convert text to speech using Gemini API.

    Args:
        text: Text to convert to speech

    Returns:
        Audio bytes as a complete WAV file (header + PCM frames)

    Raises:
        RuntimeError: If TTS generation fails (subclass of Exception, so
            existing `except Exception` callers still catch it).
    """
    # Local stdlib imports keep this module's top-level deps unchanged.
    import io
    import wave

    try:
        client = get_gemini_client()

        logger.info("Generating speech for: '%s'", text)

        # Call Gemini TTS API
        response = client.models.generate_content(
            model="gemini-2.5-flash-preview-tts",
            contents=text,
            config=types.GenerateContentConfig(
                response_modalities=["AUDIO"],
                speech_config=types.SpeechConfig(
                    voice_config=types.VoiceConfig(
                        prebuilt_voice_config=types.PrebuiltVoiceConfig(
                            voice_name="Kore"  # Options: Kore, Peri, Charon, Fenrir, Orbit
                        )
                    )
                ),
            ),
        )

        data = response.candidates[0].content.parts[0].inline_data.data
        # The Python SDK returns inline_data.data as raw bytes; only decode
        # when the payload arrives base64-encoded as a string.
        pcm_bytes = base64.b64decode(data) if isinstance(data, str) else data

        # Gemini TTS emits raw 16-bit mono PCM at 24 kHz; wrap it in a WAV
        # container so the audio/wav StreamingResponse is actually playable.
        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)      # 16-bit samples
            wav_file.setframerate(24000)  # per Gemini TTS output spec
            wav_file.writeframes(pcm_bytes)
        audio_bytes = buffer.getvalue()

        logger.info("✓ TTS successful: %d bytes generated", len(audio_bytes))

        return audio_bytes

    except Exception as e:
        logger.exception("✗ TTS failed")
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Text-to-speech generation failed: {str(e)}") from e