malek-messaoudii commited on
Commit
544d113
·
1 Parent(s): 4f1c42b

Correct files

Browse files
models/audio.py CHANGED
@@ -23,7 +23,6 @@ class STTResponse(BaseModel):
23
  description="Approximate audio duration in seconds"
24
  )
25
 
26
-
27
  # ==============================
28
  # TEXT TO SPEECH REQUEST / RESPONSE
29
  # ==============================
@@ -33,7 +32,6 @@ class TTSRequest(BaseModel):
33
  )
34
  text: str = Field(..., min_length=1, max_length=500, description="Text to convert to speech")
35
 
36
-
37
  class TTSResponse(BaseModel):
38
  model_config = ConfigDict(
39
  json_schema_extra={
 
23
  description="Approximate audio duration in seconds"
24
  )
25
 
 
26
  # ==============================
27
  # TEXT TO SPEECH REQUEST / RESPONSE
28
  # ==============================
 
32
  )
33
  text: str = Field(..., min_length=1, max_length=500, description="Text to convert to speech")
34
 
 
35
  class TTSResponse(BaseModel):
36
  model_config = ConfigDict(
37
  json_schema_extra={
routes/audio.py CHANGED
@@ -1,14 +1,14 @@
1
  from fastapi import APIRouter, UploadFile, File, HTTPException
2
- from services.tts_service import generate_tts
3
- from services.stt_service import speech_to_text
4
  from fastapi.responses import StreamingResponse
5
  import io
 
 
6
 
7
  router = APIRouter(prefix="/audio", tags=["Audio"])
8
 
9
- # ======================
10
- # TEXT TO SPEECH
11
- # ======================
12
  @router.post("/tts")
13
  async def tts(text: str):
14
  try:
@@ -16,13 +16,12 @@ async def tts(text: str):
16
  except Exception as e:
17
  raise HTTPException(status_code=500, detail=str(e))
18
 
19
- # Return as streaming response without saving file
20
  return StreamingResponse(io.BytesIO(audio_bytes), media_type="audio/wav")
21
 
22
 
23
- # ======================
24
- # SPEECH TO TEXT
25
- # ======================
26
  @router.post("/stt")
27
  async def stt(file: UploadFile = File(...)):
28
  try:
@@ -30,5 +29,24 @@ async def stt(file: UploadFile = File(...)):
30
  text = await speech_to_text(audio_bytes)
31
  except Exception as e:
32
  raise HTTPException(status_code=500, detail=str(e))
33
-
34
  return {"text": text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import APIRouter, UploadFile, File, HTTPException
 
 
2
  from fastapi.responses import StreamingResponse
3
  import io
4
+ from services.tts_service import generate_tts
5
+ from services.stt_service import speech_to_text
6
 
7
  router = APIRouter(prefix="/audio", tags=["Audio"])
8
 
9
+ # ------------------------
10
+ # Text to Speech
11
+ # ------------------------
12
  @router.post("/tts")
13
  async def tts(text: str):
14
  try:
 
16
  except Exception as e:
17
  raise HTTPException(status_code=500, detail=str(e))
18
 
 
19
  return StreamingResponse(io.BytesIO(audio_bytes), media_type="audio/wav")
20
 
21
 
22
+ # ------------------------
23
+ # Speech to Text
24
+ # ------------------------
25
  @router.post("/stt")
26
  async def stt(file: UploadFile = File(...)):
27
  try:
 
29
  text = await speech_to_text(audio_bytes)
30
  except Exception as e:
31
  raise HTTPException(status_code=500, detail=str(e))
32
+
33
  return {"text": text}
34
+
35
+
36
+ # ------------------------
37
+ # Voice Chatbot: User sends voice → TTS reply
38
+ # ------------------------
39
+ @router.post("/chatbot")
40
+ async def chatbot(file: UploadFile = File(...)):
41
+ try:
42
+ audio_bytes = await file.read()
43
+ user_text = await speech_to_text(audio_bytes)
44
+
45
+ # Replace with your NLP logic or chatbot response
46
+ response_text = f"You said: {user_text}"
47
+
48
+ audio_response = await generate_tts(response_text)
49
+ except Exception as e:
50
+ raise HTTPException(status_code=500, detail=str(e))
51
+
52
+ return StreamingResponse(io.BytesIO(audio_response), media_type="audio/wav")
services/gemini_client.py CHANGED
@@ -2,7 +2,7 @@ from google.genai import Client
2
  import os
3
 
4
  def get_gemini_client():
5
- api_key = os.getenv("GOOGLE_GENAI_API_KEY") # store it in .env
6
  if not api_key:
7
  raise ValueError("Missing GOOGLE_GENAI_API_KEY environment variable")
8
  return Client(api_key=api_key)
 
2
  import os
3
 
4
  def get_gemini_client():
5
+ api_key = os.getenv("GOOGLE_GENAI_API_KEY")
6
  if not api_key:
7
  raise ValueError("Missing GOOGLE_GENAI_API_KEY environment variable")
8
  return Client(api_key=api_key)
services/stt_service.py CHANGED
@@ -4,12 +4,12 @@ from google.genai import types
4
  async def speech_to_text(audio_bytes: bytes) -> str:
5
  client = get_gemini_client()
6
 
7
- # Correctly wrap audio bytes using types.File
8
  audio_file = types.File(data=audio_bytes, mime_type="audio/wav")
9
 
10
  response = client.models.generate_content(
11
  model="gemini-2.5-flash",
12
- contents=[audio_file] # <-- pass as a list of types.File
13
  )
14
 
15
  return response.text
 
4
  async def speech_to_text(audio_bytes: bytes) -> str:
5
  client = get_gemini_client()
6
 
7
+ # Wrap audio bytes correctly
8
  audio_file = types.File(data=audio_bytes, mime_type="audio/wav")
9
 
10
  response = client.models.generate_content(
11
  model="gemini-2.5-flash",
12
+ contents=[audio_file]
13
  )
14
 
15
  return response.text
services/tts_service.py CHANGED
@@ -3,9 +3,6 @@ from google.genai import types
3
  import base64
4
 
5
  async def generate_tts(text: str) -> bytes:
6
- """
7
- Convert text to speech using Gemini API and return valid WAV bytes
8
- """
9
  client = get_gemini_client()
10
 
11
  response = client.models.generate_content(
@@ -21,7 +18,7 @@ async def generate_tts(text: str) -> bytes:
21
  ),
22
  )
23
 
24
- # Gemini returns base64 audio, decode it to bytes
25
  audio_base64 = response.candidates[0].content.parts[0].inline_data.data
26
  audio_bytes = base64.b64decode(audio_base64)
27
 
 
3
  import base64
4
 
5
  async def generate_tts(text: str) -> bytes:
 
 
 
6
  client = get_gemini_client()
7
 
8
  response = client.models.generate_content(
 
18
  ),
19
  )
20
 
21
+ # Decode base64 audio into bytes
22
  audio_base64 = response.candidates[0].content.parts[0].inline_data.data
23
  audio_bytes = base64.b64decode(audio_base64)
24