malek-messaoudii commited on
Commit
56dc677
·
1 Parent(s): 674469e

Refactor audio processing and chatbot services; enhance STT and TTS functionalities with base64 support and session management

Browse files
config.py CHANGED
@@ -17,8 +17,8 @@ PROJECT_ROOT = API_DIR.parent
17
 
18
  # ============ HUGGING FACE ============
19
  HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "")
20
- HUGGINGFACE_STANCE_MODEL_ID = os.getenv("HUGGINGFACE_STANCE_MODEL_ID", "NLP-Debater-Project/debertav3-stance-detection")
21
- HUGGINGFACE_LABEL_MODEL_ID = os.getenv("HUGGINGFACE_LABEL_MODEL_ID", "NLP-Debater-Project/distilBert-keypoint-matching")
22
 
23
  # ============ API CONFIGURATION ============
24
  API_TITLE = "NLP Debater - Voice Chatbot API"
 
17
 
18
  # ============ HUGGING FACE ============
19
  HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "")
20
+ HUGGINGFACE_STANCE_MODEL_ID = os.getenv("HUGGINGFACE_STANCE_MODEL_ID")
21
+ HUGGINGFACE_LABEL_MODEL_ID = os.getenv("HUGGINGFACE_LABEL_MODEL_ID")
22
 
23
  # ============ API CONFIGURATION ============
24
  API_TITLE = "NLP Debater - Voice Chatbot API"
main.py CHANGED
@@ -58,7 +58,7 @@ from config import (
58
  HOST, PORT, RELOAD,
59
  CORS_ORIGINS, CORS_CREDENTIALS, CORS_METHODS, CORS_HEADERS,
60
  PRELOAD_MODELS_ON_STARTUP, LOAD_STANCE_MODEL, LOAD_KPA_MODEL,
61
- LOAD_STT_MODEL, LOAD_CHATBOT_MODEL
62
  )
63
 
64
  @asynccontextmanager
@@ -92,9 +92,10 @@ async def lifespan(app: FastAPI):
92
  # Load STT Model (Speech-to-Text)
93
  if LOAD_STT_MODEL:
94
  try:
95
- logger.info("Loading STT Model (Whisper)...")
96
- from services.stt_service import load_stt_model
97
- load_stt_model()
 
98
  logger.info("✓ STT model loaded successfully")
99
  except Exception as e:
100
  logger.error(f"✗ STT model loading failed: {str(e)}")
@@ -102,9 +103,10 @@ async def lifespan(app: FastAPI):
102
  # Load Chatbot Model
103
  if LOAD_CHATBOT_MODEL:
104
  try:
105
- logger.info("Loading Chatbot Model (DialoGPT)...")
106
- from services.chatbot_service import load_chatbot_model
107
- load_chatbot_model()
 
108
  logger.info("✓ Chatbot model loaded successfully")
109
  except Exception as e:
110
  logger.error(f"✗ Chatbot model loading failed: {str(e)}")
@@ -139,9 +141,16 @@ app.add_middleware(
139
  )
140
 
141
  # Include routers
 
 
 
 
 
 
 
142
  try:
143
  from routes.audio import router as audio_router
144
- app.include_router(audio_router, prefix="/audio", tags=["Audio - Voice Chatbot"])
145
  logger.info("✓ Audio routes registered")
146
  except Exception as e:
147
  logger.warning(f"⚠️ Audio routes failed to load: {e}")
@@ -163,7 +172,8 @@ async def root():
163
  "version": API_VERSION,
164
  "docs": "/docs",
165
  "endpoints": {
166
- "audio": "/docs#/Audio%20-%20Voice%20Chatbot",
 
167
  "health": "/health",
168
  "models-status": "/models-status"
169
  }
@@ -180,18 +190,22 @@ async def models_status():
180
  status = {
181
  "stt_model": "unknown",
182
  "tts_engine": "gtts (free)",
183
- "chatbot_model": "unknown"
 
 
184
  }
185
 
186
  try:
187
- from services.stt_service import stt_pipeline
188
- status["stt_model"] = "loaded" if stt_pipeline is not None else "not loaded"
 
189
  except:
190
  status["stt_model"] = "error"
191
 
192
  try:
193
- from services.chatbot_service import chatbot_pipeline
194
- status["chatbot_model"] = "loaded" if chatbot_pipeline is not None else "not loaded"
 
195
  except:
196
  status["chatbot_model"] = "error"
197
 
@@ -220,4 +234,4 @@ if __name__ == "__main__":
220
  port=PORT,
221
  reload=RELOAD,
222
  log_level="info"
223
- )
 
58
  HOST, PORT, RELOAD,
59
  CORS_ORIGINS, CORS_CREDENTIALS, CORS_METHODS, CORS_HEADERS,
60
  PRELOAD_MODELS_ON_STARTUP, LOAD_STANCE_MODEL, LOAD_KPA_MODEL,
61
+ LOAD_STT_MODEL, LOAD_CHATBOT_MODEL, STT_MODEL_ID, CHATBOT_MODEL_ID
62
  )
63
 
64
  @asynccontextmanager
 
92
  # Load STT Model (Speech-to-Text)
93
  if LOAD_STT_MODEL:
94
  try:
95
+ logger.info(f"Loading STT Model: {STT_MODEL_ID}")
96
+ from services.stt_service import STTService
97
+ stt_service = STTService()
98
+ await stt_service.initialize()
99
  logger.info("✓ STT model loaded successfully")
100
  except Exception as e:
101
  logger.error(f"✗ STT model loading failed: {str(e)}")
 
103
  # Load Chatbot Model
104
  if LOAD_CHATBOT_MODEL:
105
  try:
106
+ logger.info(f"Loading Chatbot Model: {CHATBOT_MODEL_ID}")
107
+ from services.chatbot_service import ChatbotService
108
+ chatbot_service = ChatbotService()
109
+ await chatbot_service.initialize()
110
  logger.info("✓ Chatbot model loaded successfully")
111
  except Exception as e:
112
  logger.error(f"✗ Chatbot model loading failed: {str(e)}")
 
141
  )
142
 
143
  # Include routers
144
+ try:
145
+ from routes.audio import router as chatbot_router
146
+ app.include_router(chatbot_router, prefix="/api/v1", tags=["Voice Chatbot"])
147
+ logger.info("✓ Chatbot routes registered")
148
+ except Exception as e:
149
+ logger.warning(f"⚠️ Chatbot routes failed to load: {e}")
150
+
151
  try:
152
  from routes.audio import router as audio_router
153
+ app.include_router(audio_router, prefix="/audio", tags=["Audio Processing"])
154
  logger.info("✓ Audio routes registered")
155
  except Exception as e:
156
  logger.warning(f"⚠️ Audio routes failed to load: {e}")
 
172
  "version": API_VERSION,
173
  "docs": "/docs",
174
  "endpoints": {
175
+ "voice_chatbot": "/api/v1/chat/message",
176
+ "audio_processing": "/docs#/Audio%20Processing",
177
  "health": "/health",
178
  "models-status": "/models-status"
179
  }
 
190
  status = {
191
  "stt_model": "unknown",
192
  "tts_engine": "gtts (free)",
193
+ "chatbot_model": "unknown",
194
+ "stance_model": "unknown",
195
+ "kpa_model": "unknown"
196
  }
197
 
198
  try:
199
+ from services.stt_service import STTService
200
+ stt_service = STTService()
201
+ status["stt_model"] = "loaded" if hasattr(stt_service, 'initialized') and stt_service.initialized else "not loaded"
202
  except:
203
  status["stt_model"] = "error"
204
 
205
  try:
206
+ from services.chatbot_service import ChatbotService
207
+ chatbot_service = ChatbotService()
208
+ status["chatbot_model"] = "loaded" if hasattr(chatbot_service, 'initialized') and chatbot_service.initialized else "not loaded"
209
  except:
210
  status["chatbot_model"] = "error"
211
 
 
234
  port=PORT,
235
  reload=RELOAD,
236
  log_level="info"
237
+ )
models/audio.py CHANGED
@@ -1,44 +1,30 @@
1
  from pydantic import BaseModel, Field
2
- from typing import Optional
 
 
3
 
4
- class STTResponse(BaseModel):
5
- text: str = Field(..., description="Transcribed text")
6
- model_name: str = Field(default="whisper-base", description="Model used")
7
- language: Optional[str] = Field(default="en", description="Language detected")
8
- duration_seconds: Optional[float] = Field(None, description="Audio duration")
9
-
10
- class Config:
11
- json_schema_extra = {
12
- "example": {
13
- "text": "hello how are you",
14
- "model_name": "whisper-base",
15
- "language": "en",
16
- "duration_seconds": 3.2
17
- }
18
- }
19
 
20
- class TTSRequest(BaseModel):
21
- text: str = Field(..., min_length=1, max_length=500, description="Text to convert")
22
-
23
- class Config:
24
- json_schema_extra = {"example": {"text": "Hello world"}}
25
-
26
- class ChatbotRequest(BaseModel):
27
- text: str = Field(..., min_length=1, max_length=500, description="User input")
28
-
29
- class Config:
30
- json_schema_extra = {"example": {"text": "What is AI?"}}
31
 
32
  class ChatbotResponse(BaseModel):
33
- user_input: str
34
- bot_response: str
35
- model_name: str = Field(default="DialoGPT-medium")
36
-
37
- class Config:
38
- json_schema_extra = {
39
- "example": {
40
- "user_input": "Hello",
41
- "bot_response": "Hi there! How can I help?",
42
- "model_name": "DialoGPT-medium"
43
- }
44
- }
 
 
1
  from pydantic import BaseModel, Field
2
+ from typing import Optional, List, Dict, Any
3
+ from enum import Enum
4
+ from datetime import datetime
5
 
6
+ class MessageType(str, Enum):
7
+ TEXT = "text"
8
+ AUDIO = "audio"
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ class UserMessage(BaseModel):
11
+ message_id: str = Field(..., description="Unique message ID")
12
+ content: str = Field(..., description="Text content or audio base64")
13
+ message_type: MessageType = Field(..., description="Message type")
14
+ session_id: str = Field(..., description="User session ID")
15
+ timestamp: datetime = Field(default_factory=datetime.now)
 
 
 
 
 
16
 
17
  class ChatbotResponse(BaseModel):
18
+ response_id: str = Field(..., description="Unique response ID")
19
+ text_response: str = Field(..., description="Chatbot text response")
20
+ audio_response: Optional[str] = Field(None, description="Audio response in base64")
21
+ audio_url: Optional[str] = Field(None, description="Generated audio URL")
22
+ session_id: str = Field(..., description="User session ID")
23
+ timestamp: datetime = Field(default_factory=datetime.now)
24
+
25
+ class ChatSession(BaseModel):
26
+ session_id: str = Field(..., description="Session ID")
27
+ user_id: Optional[str] = Field(None, description="User ID")
28
+ created_at: datetime = Field(default_factory=datetime.now)
29
+ last_activity: datetime = Field(default_factory=datetime.now)
30
+ conversation_history: List[Dict[str, Any]] = Field(default_factory=list)
requirements.txt CHANGED
@@ -9,14 +9,14 @@ pydantic==2.5.0
9
  python-dotenv==1.0.0
10
  torch>=2.0.0
11
  transformers>=4.35.0
12
- accelerate>=0.24.0
13
  protobuf>=3.20.0
14
  huggingface_hub>=0.19.0
15
  python-multipart
16
  google-genai>=0.4.0
17
- gtts==2.5.1
18
  requests==2.31.0
19
- ffmpeg-python==0.2.0
20
- librosa==0.10.1
21
  soundfile==0.12.1
22
- librosa==0.10.0
 
 
 
 
 
9
  python-dotenv==1.0.0
10
  torch>=2.0.0
11
  transformers>=4.35.0
 
12
  protobuf>=3.20.0
13
  huggingface_hub>=0.19.0
14
  python-multipart
15
  google-genai>=0.4.0
 
16
  requests==2.31.0
 
 
17
  soundfile==0.12.1
18
+ gtts==2.3.2
19
+ SpeechRecognition==3.10.0
20
+ pyttsx3==2.90
21
+ accelerate>=0.20.0
22
+ coqui-tts==0.21.0
routes/audio.py CHANGED
@@ -1,170 +1,86 @@
1
- from fastapi import APIRouter, UploadFile, File, HTTPException
2
- from fastapi.responses import StreamingResponse
3
- import io
4
- import logging
5
- from config import ALLOWED_AUDIO_TYPES, MAX_AUDIO_SIZE
6
- from services.stt_service import speech_to_text, load_stt_model
7
- from services.tts_service import generate_tts
8
- from services.chatbot_service import get_chatbot_response, load_chatbot_model
9
- from models.audio import STTResponse, TTSRequest, ChatbotRequest, ChatbotResponse
10
 
11
- logger = logging.getLogger(__name__)
12
- router = APIRouter(prefix="/audio", tags=["Audio"])
13
 
14
- @router.on_event("startup")
15
- async def startup_models():
16
- """Load models on startup"""
17
- logger.info("🚀 Loading models...")
 
 
 
18
  try:
19
- load_stt_model()
20
- load_chatbot_model()
21
- logger.info(" All models loaded successfully")
22
- except Exception as e:
23
- logger.error(f"⚠️ Model loading issues: {str(e)}")
24
-
25
- @router.post("/tts")
26
- async def tts(request: TTSRequest):
27
- """
28
- Convert text to speech.
29
- Returns MP3 audio file.
30
-
31
- Example:
32
- ```
33
- POST /audio/tts
34
- {
35
- "text": "Hello, welcome to voice chatbot"
36
- }
37
- ```
38
- """
39
- try:
40
- logger.info(f"TTS Request: '{request.text}'")
41
- audio_bytes = await generate_tts(request.text)
42
- return StreamingResponse(
43
- io.BytesIO(audio_bytes),
44
- media_type="audio/mpeg",
45
- headers={"Content-Disposition": "attachment; filename=output.mp3"}
46
- )
47
- except Exception as e:
48
- logger.error(f"TTS Error: {str(e)}")
49
- raise HTTPException(status_code=500, detail=str(e))
50
-
51
- @router.post("/stt", response_model=STTResponse)
52
- async def stt(file: UploadFile = File(...)):
53
- """
54
- Convert audio file to text.
55
- Supports: WAV, MP3, M4A
56
-
57
- Example:
58
- ```
59
- POST /audio/stt
60
- File: audio.mp3
61
- ```
62
- """
63
- if file.content_type not in ALLOWED_AUDIO_TYPES:
64
- raise HTTPException(
65
- status_code=400,
66
- detail=f"Unsupported format. Allowed: {', '.join(ALLOWED_AUDIO_TYPES)}"
67
- )
68
-
69
- try:
70
- logger.info(f"STT Request: {file.filename}")
71
- audio_bytes = await file.read()
72
 
73
- if len(audio_bytes) > MAX_AUDIO_SIZE:
74
- raise HTTPException(
75
- status_code=400,
76
- detail=f"File too large. Max: {MAX_AUDIO_SIZE / 1024 / 1024}MB"
77
- )
78
 
79
- text = await speech_to_text(audio_bytes, file.filename)
 
 
 
 
 
 
80
 
81
- return STTResponse(
82
- text=text,
83
- model_name="whisper-base",
84
- language="en"
85
- )
86
- except HTTPException:
87
- raise
88
- except Exception as e:
89
- logger.error(f"STT Error: {str(e)}")
90
- raise HTTPException(status_code=500, detail=str(e))
91
-
92
- @router.post("/chatbot")
93
- async def chatbot_voice(file: UploadFile = File(...)):
94
- """
95
- Full voice chatbot flow: Audio → Text → Response → Audio
96
-
97
- Example:
98
- ```
99
- POST /audio/chatbot
100
- File: user_voice.mp3
101
- Returns: Response MP3 audio
102
- ```
103
- """
104
- if file.content_type not in ALLOWED_AUDIO_TYPES:
105
- raise HTTPException(
106
- status_code=400,
107
- detail=f"Unsupported format. Allowed: {', '.join(ALLOWED_AUDIO_TYPES)}"
108
  )
109
-
110
- try:
111
- logger.info(f"Voice Chatbot: Processing {file.filename}")
112
-
113
- audio_bytes = await file.read()
114
- if len(audio_bytes) > MAX_AUDIO_SIZE:
115
- raise HTTPException(
116
- status_code=400,
117
- detail=f"File too large. Max: {MAX_AUDIO_SIZE / 1024 / 1024}MB"
118
- )
119
-
120
- # Step 1: STT
121
- logger.info("Step 1/3: Converting speech to text...")
122
- user_text = await speech_to_text(audio_bytes, file.filename)
123
-
124
- # Step 2: Chatbot response
125
- logger.info("Step 2/3: Generating response...")
126
- response_text = await get_chatbot_response(user_text)
127
-
128
- # Step 3: TTS
129
- logger.info("Step 3/3: Converting response to speech...")
130
- audio_response = await generate_tts(response_text)
131
 
132
- logger.info("✓ Voice chatbot complete")
 
133
 
134
- return StreamingResponse(
135
- io.BytesIO(audio_response),
136
- media_type="audio/mpeg",
137
- headers={"Content-Disposition": "attachment; filename=response.mp3"}
138
- )
139
 
140
- except HTTPException:
141
- raise
142
  except Exception as e:
143
- logger.error(f"Voice Chatbot Error: {str(e)}")
144
- raise HTTPException(status_code=500, detail=str(e))
145
 
146
- @router.post("/chatbot-text", response_model=ChatbotResponse)
147
- async def chatbot_text(request: ChatbotRequest):
148
- """
149
- Text-only chatbot (no audio).
150
-
151
- Example:
152
- ```
153
- POST /audio/chatbot-text
154
- {
155
- "text": "What is artificial intelligence?"
156
- }
157
- ```
158
- """
159
  try:
160
- logger.info(f"Text Chatbot: '{request.text}'")
161
- bot_response = await get_chatbot_response(request.text)
162
 
163
- return ChatbotResponse(
164
- user_input=request.text,
165
- bot_response=bot_response,
166
- model_name="DialoGPT-medium"
 
167
  )
 
 
 
 
168
  except Exception as e:
169
- logger.error(f"Text Chatbot Error: {str(e)}")
170
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException, UploadFile, File, Form
2
+ from fastapi.responses import JSONResponse
3
+ import uuid
4
+ import base64
5
+ from models.audio import UserMessage, ChatbotResponse, MessageType
6
+ from services.chatbot_service import ChatbotService
 
 
 
7
 
8
+ router = APIRouter()
9
+ chatbot_service = ChatbotService()
10
 
11
+ @router.post("/chat/message", response_model=ChatbotResponse)
12
+ async def send_chat_message(
13
+ session_id: str = Form(...),
14
+ message_type: str = Form(...),
15
+ message: str = Form(None),
16
+ audio_file: UploadFile = File(None)
17
+ ):
18
  try:
19
+ # Validate input
20
+ if not message and not audio_file:
21
+ raise HTTPException(status_code=400, detail="Either message or audio file must be provided")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ if message_type == "audio" and not audio_file:
24
+ raise HTTPException(status_code=400, detail="Audio file required for audio messages")
 
 
 
25
 
26
+ # Process audio file if provided
27
+ content = ""
28
+ if audio_file:
29
+ audio_data = await audio_file.read()
30
+ content = base64.b64encode(audio_data).decode('utf-8')
31
+ else:
32
+ content = message
33
 
34
+ # Create user message
35
+ user_message = UserMessage(
36
+ message_id=str(uuid.uuid4()),
37
+ content=content,
38
+ message_type=MessageType(message_type),
39
+ session_id=session_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ # Process through chatbot service
43
+ response = await chatbot_service.process_user_message(user_message)
44
 
45
+ return response
 
 
 
 
46
 
 
 
47
  except Exception as e:
48
+ raise HTTPException(status_code=500, detail=f"Error processing message: {str(e)}")
 
49
 
50
+ @router.post("/chat/audio")
51
+ async def send_audio_message(
52
+ session_id: str = Form(...),
53
+ audio_file: UploadFile = File(...)
54
+ ):
55
+ """Endpoint specifically for audio messages"""
 
 
 
 
 
 
 
56
  try:
57
+ audio_data = await audio_file.read()
58
+ audio_base64 = base64.b64encode(audio_data).decode('utf-8')
59
 
60
+ user_message = UserMessage(
61
+ message_id=str(uuid.uuid4()),
62
+ content=audio_base64,
63
+ message_type=MessageType.AUDIO,
64
+ session_id=session_id
65
  )
66
+
67
+ response = await chatbot_service.process_user_message(user_message)
68
+ return response
69
+
70
  except Exception as e:
71
+ raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
72
+
73
+ @router.get("/session/{session_id}/history")
74
+ async def get_session_history(session_id: str):
75
+ """Get conversation history for a session"""
76
+ history = chatbot_service.get_session_history(session_id)
77
+ if not history:
78
+ raise HTTPException(status_code=404, detail="Session not found")
79
+ return history
80
+
81
+ @router.post("/session/new")
82
+ async def create_new_session():
83
+ """Create a new chat session"""
84
+ session_id = str(uuid.uuid4())
85
+ chatbot_service._get_or_create_session(session_id)
86
+ return {"session_id": session_id, "message": "New session created"}
services/chatbot_service.py CHANGED
@@ -1,69 +1,99 @@
1
- import logging
2
- from transformers import pipeline, Conversation
3
- import random
 
 
 
 
4
 
5
- logger = logging.getLogger(__name__)
6
- chatbot_pipeline = None
7
- conversation_history = {}
8
-
9
- def load_chatbot_model():
10
- global chatbot_pipeline
11
- try:
12
- logger.info("Loading DialoGPT chatbot model...")
13
- chatbot_pipeline = pipeline(
14
- "conversational",
15
- model="microsoft/DialoGPT-medium",
16
- device="cpu" # Use "cuda" if GPU available
17
- )
18
- logger.info("✓ Chatbot model loaded successfully")
19
- except Exception as e:
20
- logger.error(f"✗ Failed to load chatbot model: {str(e)}")
21
- chatbot_pipeline = None
22
-
23
- async def get_chatbot_response(user_text: str, user_id: str = "default") -> str:
24
- """
25
- Generate chatbot response using DialoGPT.
26
- Maintains conversation history per user.
27
- """
28
- global chatbot_pipeline, conversation_history
29
 
30
- try:
31
- if chatbot_pipeline is None:
32
- load_chatbot_model()
33
- if chatbot_pipeline is None:
34
- return get_fallback_response(user_text)
 
 
 
 
35
 
36
- logger.info(f"Chatbot: Processing '{user_text}'")
 
 
 
 
 
 
 
37
 
38
- # Initialize conversation for this user if needed
39
- if user_id not in conversation_history:
40
- conversation_history[user_id] = Conversation()
 
 
 
41
 
42
- # Add user input to conversation
43
- conversation = conversation_history[user_id]
44
- conversation.add_user_input(user_text)
45
 
46
- # Generate response
47
- response = chatbot_pipeline(conversation)
48
- bot_response = response.generated_responses[-1].strip()
 
 
 
49
 
50
- if not bot_response:
51
- bot_response = get_fallback_response(user_text)
 
 
 
 
 
52
 
53
- logger.info(f"✓ Chatbot Response: '{bot_response}'")
54
- return bot_response
 
 
 
 
 
55
 
56
- except Exception as e:
57
- logger.error(f"✗ Chatbot Error: {str(e)}")
58
- return get_fallback_response(user_text)
59
-
60
- def get_fallback_response(user_text: str) -> str:
61
- """Fallback responses when model fails"""
62
- responses = [
63
- f"I understand: '{user_text}'. How can I assist?",
64
- f"Interesting point about '{user_text}'. Tell me more?",
65
- f"Regarding '{user_text}', what would you like to know?",
66
- "I'm listening. Please continue.",
67
- f"That's a great question about '{user_text}'!"
68
- ]
69
- return random.choice(responses)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import uuid
3
+ from typing import Optional, Dict, Any
4
+ from datetime import datetime
5
+ from models.audio import ChatbotResponse, UserMessage
6
+ from services.tts_service import SimpleTTSService # Use simple version
7
+ from services.stt_service import STTService # Use basic version
8
 
9
+ class ChatbotService:
10
+ def __init__(self):
11
+ self.tts_service = SimpleTTSService() # Use simple TTS
12
+ self.stt_service = STTService() # Use basic STT
13
+ self.sessions: Dict[str, Dict[str, Any]] = {}
14
+ self.initialized = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ async def initialize(self):
17
+ """Initialize the chatbot service"""
18
+ await self.stt_service.initialize()
19
+ self.initialized = True
20
+ print("✓ Chatbot Service initialized")
21
+
22
+ async def process_user_message(self, user_message: UserMessage) -> ChatbotResponse:
23
+ # Update session
24
+ session = self._get_or_create_session(user_message.session_id)
25
 
26
+ # Process message based on type
27
+ if user_message.message_type == "audio":
28
+ # STT: Convert audio to text
29
+ text_input = await self.stt_service.transcribe_audio_base64(
30
+ user_message.content
31
+ )
32
+ else:
33
+ text_input = user_message.content
34
 
35
+ # Add to conversation history
36
+ session["conversation_history"].append({
37
+ "role": "user",
38
+ "content": text_input,
39
+ "timestamp": user_message.timestamp
40
+ })
41
 
42
+ # Generate chatbot response
43
+ chatbot_text = await self._generate_chatbot_response(text_input, session)
 
44
 
45
+ # TTS: Convert response to audio
46
+ try:
47
+ audio_base64 = await self.tts_service.text_to_speech_base64(chatbot_text)
48
+ except Exception as e:
49
+ print(f"TTS error: {e}")
50
+ audio_base64 = None
51
 
52
+ # Create response
53
+ response = ChatbotResponse(
54
+ response_id=str(uuid.uuid4()),
55
+ text_response=chatbot_text,
56
+ audio_response=audio_base64,
57
+ session_id=user_message.session_id
58
+ )
59
 
60
+ # Add response to history
61
+ session["conversation_history"].append({
62
+ "role": "assistant",
63
+ "content": chatbot_text,
64
+ "audio_response": audio_base64,
65
+ "timestamp": response.timestamp
66
+ })
67
 
68
+ return response
69
+
70
+ async def _generate_chatbot_response(self, user_input: str, session: Dict[str, Any]) -> str:
71
+ """Chatbot response generation logic"""
72
+ # Simple response logic - replace with your actual chatbot model
73
+ user_input_lower = user_input.lower()
74
+
75
+ if "hello" in user_input_lower or "hi" in user_input_lower:
76
+ return "Hello! How can I assist you today?"
77
+ elif "time" in user_input_lower:
78
+ return f"The current time is {datetime.now().strftime('%H:%M')}"
79
+ elif "help" in user_input_lower:
80
+ return "I'm here to help you. You can ask me questions or request assistance."
81
+ elif "audio" in user_input_lower or "voice" in user_input_lower:
82
+ return "I can process both text and voice messages. Try sending me a voice message!"
83
+ else:
84
+ return f"I received your message: '{user_input}'. How can I assist you further?"
85
+
86
+ def _get_or_create_session(self, session_id: str) -> Dict[str, Any]:
87
+ if session_id not in self.sessions:
88
+ self.sessions[session_id] = {
89
+ "conversation_history": [],
90
+ "created_at": datetime.now(),
91
+ "last_activity": datetime.now()
92
+ }
93
+ else:
94
+ self.sessions[session_id]["last_activity"] = datetime.now()
95
+
96
+ return self.sessions[session_id]
97
+
98
+ def get_session_history(self, session_id: str) -> Optional[Dict[str, Any]]:
99
+ return self.sessions.get(session_id)
services/stt_service.py CHANGED
@@ -1,66 +1,94 @@
1
- import logging
 
2
  import tempfile
3
  import os
4
- from transformers import pipeline
5
- import librosa
6
- import numpy as np
7
 
8
- logger = logging.getLogger(__name__)
9
- stt_pipeline = None
10
-
11
- def load_stt_model():
12
- global stt_pipeline
13
- try:
14
- logger.info("Loading Whisper-base STT model...")
15
- stt_pipeline = pipeline(
16
- "automatic-speech-recognition",
17
- model="openai/whisper-base",
18
- device="cpu", # Use "cuda" if GPU available
19
- chunk_length_s=30,
20
- )
21
- logger.info("✓ Whisper STT model loaded successfully")
22
- except Exception as e:
23
- logger.error(f"✗ Failed to load STT model: {str(e)}")
24
- stt_pipeline = None
25
-
26
- async def speech_to_text(audio_bytes: bytes, filename: str) -> str:
27
- """
28
- Convert audio bytes to text using Whisper.
29
- Handles WAV, MP3, M4A formats automatically.
30
- """
31
- global stt_pipeline
32
 
33
- try:
34
- if stt_pipeline is None:
35
- load_stt_model()
36
- if stt_pipeline is None:
37
- raise Exception("STT model not loaded")
38
-
39
- logger.info(f"STT: Converting audio file '{filename}'")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # Save to temporary file
42
- with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp:
43
- tmp.write(audio_bytes)
44
- tmp_path = tmp.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  try:
47
- # Load and resample audio to 16kHz
48
- audio, sr = librosa.load(tmp_path, sr=16000)
49
 
50
- # Transcribe
51
- result = stt_pipeline(audio, generate_kwargs={"language": "english"})
52
- text = result["text"].strip()
53
 
54
- if not text:
55
- text = "[Silent audio or unrecognizable speech]"
56
 
57
- logger.info(f"✓ STT Success: '{text}'")
58
- return text
59
 
60
- finally:
61
- if os.path.exists(tmp_path):
62
- os.unlink(tmp_path)
63
-
64
- except Exception as e:
65
- logger.error(f"✗ STT Error: {str(e)}")
66
- raise Exception(f"STT failed: {str(e)}")
 
1
+ import base64
2
+ import io
3
  import tempfile
4
  import os
5
+ import wave
6
+ import audioop
 
7
 
8
+ class STTService:
9
+ def __init__(self):
10
+ self.initialized = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ async def initialize(self):
13
+ """Initialize STT service"""
14
+ # For now, we'll use a simple approach without external dependencies
15
+ self.initialized = True
16
+ print("STT Service initialized (basic mode)")
17
+
18
+ async def transcribe_audio_base64(self, audio_base64: str, language: str = "en-US") -> str:
19
+ """Transcribe base64 audio to text - SIMPLIFIED VERSION"""
20
+ try:
21
+ # Decode audio
22
+ audio_data = base64.b64decode(audio_base64)
23
+
24
+ # For now, return a placeholder since we don't have STT models configured
25
+ # In a real implementation, you would use Whisper, Vosk, or other STT models here
26
+
27
+ audio_info = await self._get_audio_info(audio_data)
28
+ return f"[Audio received: {audio_info}. STT service needs model configuration.]"
29
+
30
+ except Exception as e:
31
+ print(f"Transcription error: {e}")
32
+ return "Sorry, I couldn't process the audio message."
33
+
34
+ async def _get_audio_info(self, audio_data: bytes) -> str:
35
+ """Get basic information about the audio file"""
36
+ try:
37
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
38
+ temp_path = temp_file.name
39
+ temp_file.write(audio_data)
40
+
41
+ try:
42
+ with wave.open(temp_path, 'rb') as wav_file:
43
+ frames = wav_file.getnframes()
44
+ rate = wav_file.getframerate()
45
+ duration = frames / float(rate)
46
+ return f"Duration: {duration:.2f}s, Sample Rate: {rate}Hz"
47
+ except:
48
+ return f"Size: {len(audio_data)} bytes"
49
 
50
+ finally:
51
+ if os.path.exists(temp_path):
52
+ os.unlink(temp_path)
53
+
54
+ # Alternative STT service using Whisper if available
55
+ class WhisperSTTService:
56
+ def __init__(self):
57
+ self.model = None
58
+ self.initialized = False
59
+
60
+ async def initialize(self):
61
+ """Initialize Whisper STT service"""
62
+ try:
63
+ import whisper
64
+ self.model = whisper.load_model("medium")
65
+ self.initialized = True
66
+ print("✓ Whisper STT Service initialized")
67
+ except ImportError:
68
+ print("⚠️ Whisper not available. Install with: pip install openai-whisper")
69
+ self.initialized = False
70
+ except Exception as e:
71
+ print(f"⚠️ Whisper initialization failed: {e}")
72
+ self.initialized = False
73
+
74
+ async def transcribe_audio_base64(self, audio_base64: str, language: str = "en") -> str:
75
+ """Transcribe using Whisper"""
76
+ if not self.initialized:
77
+ return "STT service not available. Please install Whisper."
78
 
79
  try:
80
+ audio_data = base64.b64decode(audio_base64)
 
81
 
82
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
83
+ temp_path = temp_file.name
84
+ temp_file.write(audio_data)
85
 
86
+ result = self.model.transcribe(temp_path, language=language)
87
+ transcription = result["text"]
88
 
89
+ os.unlink(temp_path)
90
+ return transcription
91
 
92
+ except Exception as e:
93
+ print(f"Whisper transcription error: {e}")
94
+ return "Sorry, I couldn't transcribe the audio."
 
 
 
 
services/tts_service.py CHANGED
@@ -1,49 +1,121 @@
1
- import logging
2
  import io
 
 
3
  from gtts import gTTS
4
- import numpy as np
5
- import wave
6
 
7
- logger = logging.getLogger(__name__)
8
-
9
- async def generate_tts(text: str) -> bytes:
10
- """
11
- Convert text to speech using Google Text-to-Speech (gTTS).
12
- Returns MP3 audio bytes.
13
- """
14
- try:
15
- if not text or len(text) > 500:
16
- raise ValueError("Text must be between 1-500 characters")
17
 
18
- logger.info(f"TTS: Generating audio for '{text[:50]}...'")
 
 
 
 
 
 
19
 
20
- # Generate audio using gTTS
21
- tts = gTTS(text=text, lang='en', slow=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # Save to bytes buffer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  audio_buffer = io.BytesIO()
25
  tts.write_to_fp(audio_buffer)
26
- audio_bytes = audio_buffer.getvalue()
 
 
 
 
 
 
 
27
 
28
- logger.info(f"✓ TTS Success: {len(audio_bytes)} bytes generated")
29
- return audio_bytes
30
 
31
- except Exception as e:
32
- logger.error(f"✗ TTS Error: {str(e)}")
33
- # Return fallback silent audio
34
- return generate_silent_wav()
35
-
36
- def generate_silent_wav() -> bytes:
37
- """Generate 1-second silent WAV file as fallback"""
38
- sample_rate = 22050
39
- duration = 1.0
40
- silence = np.zeros(int(sample_rate * duration), dtype=np.int16)
41
 
42
- buffer = io.BytesIO()
43
- with wave.open(buffer, 'wb') as wav:
44
- wav.setnchannels(1)
45
- wav.setsampwidth(2)
46
- wav.setframerate(sample_rate)
47
- wav.writeframes(silence.tobytes())
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- return buffer.getvalue()
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
  import io
3
+ import tempfile
4
+ import os
5
  from gtts import gTTS
6
+ import pyttsx3
 
7
 
8
+ class TTSService:
9
+ def __init__(self):
10
+ self.models = {}
11
+ self._initialize_models()
12
+
13
+ def _initialize_models(self):
14
+ """Initialize TTS models"""
15
+ # gTTS is our primary method (always available)
16
+ self.models["gtts"] = True
 
17
 
18
+ # Try to initialize pyttsx3 as fallback
19
+ try:
20
+ self.models["pyttsx3"] = pyttsx3.init()
21
+ print("✓ pyttsx3 TTS initialized")
22
+ except:
23
+ print("⚠️ pyttsx3 not available")
24
+ self.models["pyttsx3"] = None
25
 
26
+ # Coqui TTS is optional
27
+ self.models["coqui"] = self._initialize_coqui_tts()
28
+
29
+ def _initialize_coqui_tts(self):
30
+ """Initialize Coqui TTS if available"""
31
+ try:
32
+ from TTS.api import TTS
33
+ tts_model = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False)
34
+ print("✓ Coqui TTS initialized")
35
+ return tts_model
36
+ except ImportError:
37
+ print("⚠️ Coqui TTS not available. Install with: pip install TTS")
38
+ return None
39
+ except Exception as e:
40
+ print(f"⚠️ Coqui TTS initialization failed: {e}")
41
+ return None
42
+
43
+ async def text_to_speech_base64(self, text: str, language: str = "en") -> str:
44
+ """Convert text to base64 audio"""
45
+ # Try gTTS first (most reliable and free)
46
+ try:
47
+ return await self._gtts_to_base64(text, language)
48
+ except Exception as e:
49
+ print(f"gTTS error: {e}")
50
 
51
+ # Fallback to pyttsx3
52
+ try:
53
+ if self.models.get("pyttsx3"):
54
+ return await self._pyttsx3_to_base64(text)
55
+ except Exception as e:
56
+ print(f"pyttsx3 error: {e}")
57
+
58
+ # Final fallback to Coqui TTS
59
+ try:
60
+ if self.models.get("coqui"):
61
+ return await self._coqui_to_base64(text)
62
+ except Exception as e:
63
+ print(f"Coqui TTS error: {e}")
64
+
65
+ raise Exception("All TTS services failed")
66
+
67
+ async def _gtts_to_base64(self, text: str, language: str) -> str:
68
+ """Convert using gTTS"""
69
+ tts = gTTS(text=text, lang=language, slow=False)
70
  audio_buffer = io.BytesIO()
71
  tts.write_to_fp(audio_buffer)
72
+ audio_buffer.seek(0)
73
+ return base64.b64encode(audio_buffer.getvalue()).decode('utf-8')
74
+
75
+ async def _pyttsx3_to_base64(self, text: str) -> str:
76
+ """Convert using pyttsx3"""
77
+ engine = self.models["pyttsx3"]
78
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
79
+ temp_path = temp_file.name
80
 
81
+ engine.save_to_file(text, temp_path)
82
+ engine.runAndWait()
83
 
84
+ with open(temp_path, 'rb') as audio_file:
85
+ audio_base64 = base64.b64encode(audio_file.read()).decode('utf-8')
86
+
87
+ # Cleanup
88
+ os.unlink(temp_path)
89
+ return audio_base64
 
 
 
 
90
 
91
+ async def _coqui_to_base64(self, text: str) -> str:
92
+ """Convert using Coqui TTS"""
93
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
94
+ temp_path = temp_file.name
95
+
96
+ self.models["coqui"].tts_to_file(text=text, file_path=temp_path)
97
+
98
+ with open(temp_path, 'rb') as audio_file:
99
+ audio_base64 = base64.b64encode(audio_file.read()).decode('utf-8')
100
+
101
+ # Cleanup
102
+ os.unlink(temp_path)
103
+ return audio_base64
104
+
105
+ # Simple TTS service that only uses gTTS (minimal dependencies)
106
+ class SimpleTTSService:
107
+ def __init__(self):
108
+ pass
109
 
110
+ async def text_to_speech_base64(self, text: str, language: str = "en") -> str:
111
+ """Convert text to base64 audio using only gTTS"""
112
+ try:
113
+ tts = gTTS(text=text, lang=language, slow=False)
114
+ audio_buffer = io.BytesIO()
115
+ tts.write_to_fp(audio_buffer)
116
+ audio_buffer.seek(0)
117
+ return base64.b64encode(audio_buffer.getvalue()).decode('utf-8')
118
+ except Exception as e:
119
+ print(f"gTTS error: {e}")
120
+ # Return a placeholder audio or error message
121
+ return "TTS_ERROR_PLACEHOLDER"