Merge branch 'main' of https://huggingface.co/spaces/NLP-Debater-Project/FastAPI-Backend-Models
- config.py +59 -14
- kpa-t5-improved.ipynb +0 -0
- main.py +248 -22
- models/__init__.py +27 -0
- models/mcp_models.py +113 -0
- models/stt.py +11 -0
- models/tts.py +15 -0
- models/voice_chat.py +20 -0
- requirements.txt +21 -8
- routes/__init__.py +10 -0
- routes/mcp_routes.py +502 -0
- routes/stt_routes.py +42 -0
- routes/tts_routes.py +40 -0
- routes/voice_chat_routes.py +220 -0
- services/__init__.py +11 -0
- services/chat_service.py +138 -0
- services/mcp_service.py +89 -0
- services/stt_service.py +38 -0
- services/tts_service.py +57 -0
config.py
CHANGED
@@ -3,38 +3,83 @@
 import os
 from pathlib import Path
 from dotenv import load_dotenv
+import logging
 
-
+logger = logging.getLogger(__name__)
+
+# Load .env variables
 load_dotenv()
 
-#
+# ============ DIRECTORIES ============
 API_DIR = Path(__file__).parent
 PROJECT_ROOT = API_DIR.parent
 
-#
+# ============ HUGGING FACE MODELS ============
 HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "")
-
 HUGGINGFACE_STANCE_MODEL_ID = os.getenv("HUGGINGFACE_STANCE_MODEL_ID")
 HUGGINGFACE_LABEL_MODEL_ID = os.getenv("HUGGINGFACE_LABEL_MODEL_ID")
 HUGGINGFACE_GENERATE_MODEL_ID = os.getenv("HUGGINGFACE_GENERATE_MODEL_ID", "YOUR_ORG/kpa-t5-improved")
 
+<<<<<<< HEAD
 # Use Hugging Face model ID instead of local path
 STANCE_MODEL_ID = HUGGINGFACE_STANCE_MODEL_ID
 LABEL_MODEL_ID = HUGGINGFACE_LABEL_MODEL_ID
 GENERATE_MODEL_ID = HUGGINGFACE_GENERATE_MODEL_ID
+=======
+# ============ GROQ MODELS ============
+GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
+>>>>>>> 45e145b23965e35eefb7990d09b535073040e40a
+
+# **Speech-to-Text**
+GROQ_STT_MODEL = "whisper-large-v3-turbo"
+
+# **Text-to-Speech**
+GROQ_TTS_MODEL = "playai-tts"
+GROQ_TTS_VOICE = "Aaliyah-PlayAI"
+GROQ_TTS_FORMAT = "wav"
 
+# **Chat Model**
+GROQ_CHAT_MODEL = "llama3-70b-8192"
 
-#
-
-API_DESCRIPTION = "API for various NLP models including stance detection and more"
-API_VERSION = "1.0.0"
+# ============ API META ============
+API_TITLE = "NLP Debater - Voice Chatbot"
+API_DESCRIPTION = "NLP stance detection, KPA, and Groq STT/TTS chatbot"
+API_VERSION = "2.0.0"
 
-#
-
-
-
-
-#
-
+# ============ SERVER ============
+HOST = os.getenv("HOST", "0.0.0.0")
+PORT = int(os.getenv("PORT", 7860))
+RELOAD = os.getenv("RELOAD", "false").lower() == "true"
+
+# ============ CORS ============
+CORS_ORIGINS = ["*"]
 CORS_CREDENTIALS = True
 CORS_METHODS = ["*"]
 CORS_HEADERS = ["*"]
+
+# ============ AUDIO SETTINGS ============
+MAX_AUDIO_SIZE = 10 * 1024 * 1024  # 10MB
+AUDIO_SAMPLE_RATE = 16000
+AUDIO_DURATION_LIMIT = 120  # seconds
+ALLOWED_AUDIO_TYPES = {
+    "audio/wav", "audio/x-wav",
+    "audio/mpeg", "audio/mp3",
+    "audio/mp4", "audio/m4a"
+}
+
+# ============ MODEL PRELOADING ============
+PRELOAD_MODELS_ON_STARTUP = True
+LOAD_STANCE_MODEL = True
+LOAD_KPA_MODEL = True
+LOAD_STT_MODEL = False      # Groq STT = no preload
+LOAD_CHATBOT_MODEL = False  # Groq Chat = no preload
+LOAD_TTS_MODEL = False      # Groq TTS = no preload
+
+logger.info("="*60)
+logger.info("✓ Configuration loaded successfully")
+logger.info(f"  HF Stance Model : {HUGGINGFACE_STANCE_MODEL_ID}")
+logger.info(f"  HF Label Model  : {HUGGINGFACE_LABEL_MODEL_ID}")
+logger.info(f"  GROQ STT Model  : {GROQ_STT_MODEL}")
+logger.info(f"  GROQ TTS Model  : {GROQ_TTS_MODEL}")
+logger.info(f"  GROQ Chat Model : {GROQ_CHAT_MODEL}")
+logger.info("="*60)
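
The audio limits above are plain module constants; nothing in config.py enforces them on its own. A route that consumes them might look like the sketch below (not part of this commit; validate_upload is a hypothetical helper):

    from fastapi import HTTPException, UploadFile

    from config import ALLOWED_AUDIO_TYPES, MAX_AUDIO_SIZE

    async def validate_upload(file: UploadFile) -> bytes:
        # Hypothetical helper: reject uploads outside the configured limits
        if file.content_type not in ALLOWED_AUDIO_TYPES:
            raise HTTPException(status_code=400, detail="Unsupported audio type")
        content = await file.read()
        if len(content) > MAX_AUDIO_SIZE:
            raise HTTPException(status_code=413, detail="Audio file exceeds the 10MB limit")
        return content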
kpa-t5-improved.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
main.py
CHANGED
@@ -1,19 +1,16 @@
-"""Main FastAPI application entry point"""
-
 import sys
 from pathlib import Path
-
-# Add the app directory to Python path to ensure imports work
-app_dir = Path(__file__).parent
-if str(app_dir) not in sys.path:
-    sys.path.insert(0, str(app_dir))
-
+import logging
 from contextlib import asynccontextmanager
+import atexit
+import shutil
+
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import RedirectResponse
 import uvicorn
-import logging
 
+<<<<<<< HEAD
 from config import (
     API_TITLE,
     API_DESCRIPTION,
@@ -37,15 +34,81 @@ from routes import api_router
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
+=======
+# --- Logging ---
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+>>>>>>> 45e145b23965e35eefb7990d09b535073040e40a
 logger = logging.getLogger(__name__)
 
+# --- Add the app dir to the PATH ---
+app_dir = Path(__file__).parent
+sys.path.insert(0, str(app_dir))
+
+# --- Config ---
+from config import (
+    API_TITLE, API_DESCRIPTION, API_VERSION,
+    HUGGINGFACE_API_KEY, HUGGINGFACE_STANCE_MODEL_ID, HUGGINGFACE_LABEL_MODEL_ID,
+    HOST, PORT, RELOAD,
+    CORS_ORIGINS, CORS_METHODS, CORS_HEADERS, CORS_CREDENTIALS,
+    PRELOAD_MODELS_ON_STARTUP, LOAD_STANCE_MODEL, LOAD_KPA_MODEL,
+    GROQ_API_KEY, GROQ_STT_MODEL, GROQ_TTS_MODEL, GROQ_CHAT_MODEL
+)
+
+# --- Cleanup function ---
+def cleanup_temp_files():
+    """Clean up temporary audio files at startup"""
+    temp_dir = Path("temp_audio")
+    if temp_dir.exists():
+        try:
+            shutil.rmtree(temp_dir)
+            logger.info("✓ Temporary audio files cleaned up")
+        except Exception as e:
+            logger.warning(f"⚠ Could not clean the temporary directory: {e}")
+
+# Call at startup
+cleanup_temp_files()
 
+# Configure cleanup at shutdown
+@atexit.register
+def cleanup_on_exit():
+    temp_dir = Path("temp_audio")
+    if temp_dir.exists():
+        try:
+            shutil.rmtree(temp_dir)
+            logger.info("Final cleanup of temporary files")
+        except:
+            logger.warning("Final cleanup failed")
+
+# --- Import the service singletons ---
+stance_model_manager = None
+kpa_model_manager = None
+try:
+    from services.stance_model_manager import stance_model_manager
+    from services.label_model_manager import kpa_model_manager  # Fixed: kpa_model_manager everywhere
+    logger.info("✓ Model managers imported")
+except ImportError as e:
+    logger.warning(f"⚠ Could not import the model managers: {e}")
+
+# --- MCP availability check ---
+MCP_ENABLED = False
+try:
+    from services.mcp_service import init_mcp_server
+    MCP_ENABLED = True
+    logger.info("✓ MCP modules detected")
+except ImportError as e:
+    logger.warning(f"⚠ MCP not available: {e}")
+
+# --- Lifespan / API startup ---
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    ""
-
-    logger.info("
+    logger.info("="*60)
+    logger.info("🚀 API STARTUP - Loading models and checking APIs...")
+    logger.info("="*60)
 
+<<<<<<< HEAD
     # Load stance detection model
     try:
         logger.info(f"Loading stance model from Hugging Face: {HUGGINGFACE_STANCE_MODEL_ID}")
@@ -77,18 +140,72 @@ async def lifespan(app: FastAPI):
 
     # Shutdown: Cleanup (if needed)
     # Currently no cleanup needed, but you can add it here if necessary
+=======
+    # Check the API keys
+    if not GROQ_API_KEY:
+        logger.warning("⚠ GROQ_API_KEY not configured. STT/TTS features disabled.")
+    else:
+        logger.info("✓ GROQ_API_KEY configured")
+
+    if not HUGGINGFACE_API_KEY:
+        logger.warning("⚠ HUGGINGFACE_API_KEY not configured. Local models disabled.")
+    else:
+        logger.info("✓ HUGGINGFACE_API_KEY configured")
+>>>>>>> 45e145b23965e35eefb7990d09b535073040e40a
+
+    # Preload the Hugging Face models if configured
+    if PRELOAD_MODELS_ON_STARTUP:
+
+        # Load the stance model
+        if LOAD_STANCE_MODEL and stance_model_manager and HUGGINGFACE_STANCE_MODEL_ID:
+            try:
+                stance_model_manager.load_model(HUGGINGFACE_STANCE_MODEL_ID, HUGGINGFACE_API_KEY)
+                logger.info("✓ Stance detection model loaded")
+            except Exception as e:
+                logger.error(f"✗ Failed to load the stance model: {e}")
 
+        # Load the KPA model
+        if LOAD_KPA_MODEL and kpa_model_manager and HUGGINGFACE_LABEL_MODEL_ID:
+            try:
+                kpa_model_manager.load_model(HUGGINGFACE_LABEL_MODEL_ID, HUGGINGFACE_API_KEY)
+                logger.info("✓ KPA model loaded")
+            except Exception as e:
+                logger.error(f"✗ Failed to load the KPA model: {e}")
+
+    # Initialize MCP if available
+    if MCP_ENABLED:
+        try:
+            init_mcp_server(app)
+            logger.info("✓ MCP server initialized")
+        except Exception as e:
+            logger.error(f"✗ MCP initialization failed: {e}")
+
+    logger.info("="*60)
+    logger.info("✓ Startup complete. API ready to receive requests.")
+    logger.info(f"  STT Model: {GROQ_STT_MODEL}")
+    logger.info(f"  TTS Model: {GROQ_TTS_MODEL}")
+    logger.info(f"  Chat Model: {GROQ_CHAT_MODEL}")
+    logger.info(f"  MCP: {'Enabled' if MCP_ENABLED else 'Disabled'}")
+    logger.info("="*60)
+
+    yield
+
+    logger.info("🛑 Shutting down the API...")
+    # Final cleanup
+    cleanup_on_exit()
+
-#
+# --- FastAPI app ---
 app = FastAPI(
     title=API_TITLE,
     description=API_DESCRIPTION,
     version=API_VERSION,
+    lifespan=lifespan,
    docs_url="/docs",
     redoc_url="/redoc",
-
+    openapi_url="/openapi.json"
 )
 
-#
+# --- CORS ---
 app.add_middleware(
     CORSMiddleware,
     allow_origins=CORS_ORIGINS,
@@ -97,20 +214,129 @@ app.add_middleware(
     allow_headers=CORS_HEADERS,
 )
 
-#
-
+# --- Routes ---
+# STT Routes
+try:
+    from routes.stt_routes import router as stt_router
+    app.include_router(stt_router, prefix="/api/v1/stt", tags=["Speech To Text"])
+    logger.info("✓ STT route loaded (Groq Whisper)")
+except ImportError as e:
+    logger.warning(f"⚠ STT route not found: {e}")
+except Exception as e:
+    logger.warning(f"⚠ Failed to load the STT route: {e}")
+
+# TTS Routes
+try:
+    from routes.tts_routes import router as tts_router
+    app.include_router(tts_router, prefix="/api/v1/tts", tags=["Text To Speech"])
+    logger.info("✓ TTS route loaded (Groq PlayAI TTS)")
+except ImportError as e:
+    logger.warning(f"⚠ TTS route not found: {e}")
+except Exception as e:
+    logger.warning(f"⚠ Failed to load the TTS route: {e}")
+
+# Voice Chat Routes
+try:
+    from routes.voice_chat_routes import router as voice_chat_router
+    app.include_router(voice_chat_router, tags=["Voice Chat"])
+    logger.info("✓ Voice Chat route loaded")
+except ImportError as e:
+    logger.warning(f"⚠ Voice Chat route not found: {e}")
+except Exception as e:
+    logger.warning(f"⚠ Failed to load the Voice Chat route: {e}")
+
+# Main API Routes (KPA, Stance, etc.)
+try:
+    from routes import api_router
+    app.include_router(api_router, prefix="/api/v1")
+    logger.info("✓ Main API routes loaded")
+except ImportError as e:
+    logger.warning(f"⚠ Main API routes not found: {e}")
+    # Fallback
+    try:
+        from routes.label import router as kpa_router
+        app.include_router(kpa_router, prefix="/api/v1/kpa", tags=["KPA"])
+        from routes.stance import router as stance_router
+        app.include_router(stance_router, prefix="/api/v1/stance", tags=["Stance Detection"])
+        logger.info("✓ KPA and Stance routes loaded as a fallback")
+    except ImportError:
+        logger.warning("⚠ KPA/Stance fallback failed")
+except Exception as e:
+    logger.warning(f"⚠ Failed to load the main API routes: {e}")
+
+# MCP Routes - FastAPI routes for Swagger UI + mount for MCP compatibility
+if MCP_ENABLED:
+    try:
+        from routes.mcp_routes import router as mcp_router
+        app.include_router(mcp_router)  # No prefix: already set in the router
+        logger.info("✓ MCP FastAPI routes loaded (visible in Swagger)")
+    except ImportError as e:
+        logger.warning(f"⚠ MCP FastAPI routes not found: {e}")
+    except Exception as e:
+        logger.warning(f"⚠ Failed to load the MCP FastAPI routes: {e}")
+
+    logger.info("✓ MCP mounted via lifespan (endpoints auto-managed)")
+else:
+    logger.warning("⚠ MCP disabled")
+
+# --- Basic routes ---
+@app.get("/health", tags=["Health"])
+async def health():
+    health_status = {
+        "status": "healthy",
+        "service": "NLP Debater + Groq Voice",
+        "version": API_VERSION,
+        "features": {
+            "stt": GROQ_STT_MODEL if GROQ_API_KEY else "disabled",
+            "tts": GROQ_TTS_MODEL if GROQ_API_KEY else "disabled",
+            "chat": GROQ_CHAT_MODEL if GROQ_API_KEY else "disabled",
+            "stance_model": "loaded" if (stance_model_manager and hasattr(stance_model_manager, 'model_loaded') and stance_model_manager.model_loaded) else "not loaded",
+            "kpa_model": "loaded" if (kpa_model_manager and hasattr(kpa_model_manager, 'model_loaded') and kpa_model_manager.model_loaded) else "not loaded",
+            "mcp": "enabled" if MCP_ENABLED else "disabled"
+        },
+        "endpoints": {
+            "mcp": "/api/v1/mcp" if MCP_ENABLED else "disabled"
+        }
+    }
+    return health_status
+
+@app.get("/", tags=["Root"])
+async def root():
+    return RedirectResponse(url="/docs")
 
+# --- Error handlers ---
+@app.exception_handler(404)
+async def not_found_handler(request, exc):
+    endpoints = {
+        "GET /": "Redirects to /docs (Swagger UI)",
+        "GET /health": "Health check",
+        "POST /api/v1/stt/": "Speech to text",
+        "POST /api/v1/tts/": "Text to speech",
+        "POST /voice-chat/voice": "Voice chat"
+    }
+    if MCP_ENABLED:
+        endpoints.update({
+            "GET /api/v1/mcp/health": "MCP health check",
+            "GET /api/v1/mcp/tools": "List of MCP tools",
+            "POST /api/v1/mcp/tools/call": "Call an MCP tool"
+        })
+    return {
+        "error": "Not Found",
+        "message": f"URL {request.url} not found",
+        "available_endpoints": endpoints
+    }
 
+# --- Run server ---
 if __name__ == "__main__":
-
-
-
+    logger.info("="*60)
+    logger.info(f"Starting the server on {HOST}:{PORT}")
+    logger.info(f"Reload mode: {RELOAD}")
+    logger.info("="*60)
 
-    # Run the API server
     uvicorn.run(
         "main:app",
         host=HOST,
         port=PORT,
         reload=RELOAD,
         log_level="info"
-    )
+    )
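
With these changes, /health reports which features are live before the heavier endpoints are exercised. A quick smoke test, assuming the server is running on the default HOST/PORT from config.py and the requests package is installed:

    import requests

    # Probe the new health endpoint (default port 7860 per config.py)
    resp = requests.get("http://localhost:7860/health", timeout=5)
    resp.raise_for_status()
    info = resp.json()
    print(info["status"], info["features"])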
models/__init__.py
CHANGED
@@ -22,6 +22,21 @@ from .health import (
     HealthResponse,
 )
 
+# Import MCP-related schemas
+from .mcp_models import (
+    ToolCallRequest,
+    ToolCallResponse,
+    ToolInfo,
+    ToolListResponse,
+    ResourceInfo,
+    ResourceListResponse,
+    DetectStanceResponse,
+    MatchKeypointResponse,
+    TranscribeAudioResponse,
+    GenerateSpeechResponse,
+    GenerateArgumentResponse,
+)
+
 __all__ = [
     # Stance schemas
     "StanceRequest",
@@ -36,4 +51,16 @@ __all__ = [
     "KPAHealthResponse",
     # Health schemas
     "HealthResponse",
+    # MCP schemas
+    "ToolCallRequest",
+    "ToolCallResponse",
+    "ToolInfo",
+    "ToolListResponse",
+    "ResourceInfo",
+    "ResourceListResponse",
+    "DetectStanceResponse",
+    "MatchKeypointResponse",
+    "TranscribeAudioResponse",
+    "GenerateSpeechResponse",
+    "GenerateArgumentResponse",
 ]
models/mcp_models.py
ADDED
@@ -0,0 +1,113 @@
from pydantic import BaseModel, Field, ConfigDict
from typing import Any, Dict, List, Optional

class ToolCallRequest(BaseModel):
    """Request for calling an MCP tool"""
    tool_name: str
    arguments: Dict[str, Any] = {}

class ToolCallResponse(BaseModel):
    """Response from MCP tool call"""
    success: bool
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    tool_name: str

# Response models for individual MCP tools
class DetectStanceResponse(BaseModel):
    """Response model for stance detection"""
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "predicted_stance": "PRO",
                "confidence": 0.9598,
                "probability_con": 0.0402,
                "probability_pro": 0.9598
            }
        }
    )

    predicted_stance: str = Field(..., description="PRO or CON")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    probability_con: float = Field(..., ge=0.0, le=1.0)
    probability_pro: float = Field(..., ge=0.0, le=1.0)

class MatchKeypointResponse(BaseModel):
    """Response model for keypoint matching"""
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "prediction": 1,
                "label": "apparie",
                "confidence": 0.8157,
                "probabilities": {
                    "non_apparie": 0.1843,
                    "apparie": 0.8157
                }
            }
        }
    )

    prediction: int = Field(..., description="1 = apparie, 0 = non_apparie")
    label: str = Field(..., description="apparie or non_apparie")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    probabilities: Dict[str, float] = Field(..., description="Dictionary of class probabilities")

class TranscribeAudioResponse(BaseModel):
    """Response model for audio transcription"""
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "text": "Hello, this is the transcribed text from the audio file."
            }
        }
    )

    text: str = Field(..., description="Transcribed text from audio")

class GenerateSpeechResponse(BaseModel):
    """Response model for speech generation"""
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "audio_path": "temp_audio/tts_e9b78164.wav"
            }
        }
    )

    audio_path: str = Field(..., description="Path to generated audio file")

class GenerateArgumentResponse(BaseModel):
    """Response model for argument generation"""
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "argument": "Climate change is a pressing issue that requires immediate action..."
            }
        }
    )

    argument: str = Field(..., description="Generated debate argument")

class ResourceInfo(BaseModel):
    """Information about an MCP resource"""
    uri: str
    name: str
    description: Optional[str] = None
    mime_type: str

class ToolInfo(BaseModel):
    """Information about an MCP tool"""
    name: str
    description: str
    input_schema: Dict[str, Any]

class ResourceListResponse(BaseModel):
    """Response for listing resources"""
    resources: List[ResourceInfo]
    count: int

class ToolListResponse(BaseModel):
    """Response for listing tools"""
    tools: List[ToolInfo]
    count: int
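
These Pydantic models double as Swagger documentation through ConfigDict(json_schema_extra=...). A short sanity check of the request/response shapes (illustrative only, assuming Pydantic v2 as pinned in requirements.txt):

    from models.mcp_models import DetectStanceResponse, ToolCallRequest

    # A tool-call request as the /tools/call endpoint would receive it
    req = ToolCallRequest(
        tool_name="detect_stance",
        arguments={"topic": "Climate change is real",
                   "argument": "Rising global temperatures prove it"},
    )
    print(req.model_dump())

    # The documented example payload round-trips through the response model
    resp = DetectStanceResponse(predicted_stance="PRO", confidence=0.9598,
                                probability_con=0.0402, probability_pro=0.9598)
    print(resp.model_dump_json())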
models/stt.py
ADDED
@@ -0,0 +1,11 @@
from pydantic import BaseModel

class STTResponse(BaseModel):
    text: str

    class Config:
        json_schema_extra = {
            "example": {
                "text": "Hello, how are you today?"
            }
        }
models/tts.py
ADDED
@@ -0,0 +1,15 @@
from pydantic import BaseModel, Field

class TTSRequest(BaseModel):
    text: str = Field(..., min_length=1, max_length=5000)
    voice: str = Field(default="Aaliyah-PlayAI")
    format: str = Field(default="wav", pattern="^(wav|mp3)$")

    class Config:
        json_schema_extra = {
            "example": {
                "text": "Hello, this is a test of text-to-speech.",
                "voice": "Aaliyah-PlayAI",
                "format": "wav"
            }
        }
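
The pattern="^(wav|mp3)$" constraint rejects unsupported formats at validation time, before the request ever reaches the TTS service. For example (illustrative, Pydantic v2):

    from pydantic import ValidationError

    from models.tts import TTSRequest

    TTSRequest(text="Hello")                    # ok: voice/format defaults apply
    try:
        TTSRequest(text="Hello", format="ogg")  # rejected by the pattern
    except ValidationError as e:
        print(e.errors()[0]["loc"])             # ('format',)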
models/voice_chat.py
ADDED
@@ -0,0 +1,20 @@
from pydantic import BaseModel
from typing import Optional

class TextChatRequest(BaseModel):
    text: str
    conversation_id: Optional[str] = None

class VoiceChatResponse(BaseModel):
    text_response: str
    audio_url: Optional[str] = None
    conversation_id: str

    class Config:
        json_schema_extra = {
            "example": {
                "text_response": "Hello! How can I help you today?",
                "audio_url": "/voice-chat/audio/123e4567",
                "conversation_id": "123e4567"
            }
        }
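
In use, a text request goes in and the response optionally carries a URL to the synthesized audio (None when TTS is unavailable). Illustrative only:

    from models.voice_chat import TextChatRequest, VoiceChatResponse

    req = TextChatRequest(text="What is key point analysis?")
    resp = VoiceChatResponse(
        text_response="Key point analysis summarizes arguments into key points.",
        audio_url="/voice-chat/audio/123e4567",  # None when TTS is disabled
        conversation_id="123e4567",
    )
    print(resp.model_dump())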
requirements.txt
CHANGED
@@ -1,10 +1,23 @@
-fastapi
-uvicorn[standard]
-
-python-dotenv
-
+fastapi>=0.115.0
+uvicorn[standard]>=0.29.0
+python-multipart>=0.0.6
+python-dotenv>=1.0.0
+pydantic>=2.5.0
+
+# API Clients
+requests>=2.31.0
+groq>=0.9.0
+
+# Audio processing (optional, only if local processing is needed)
+soundfile>=0.12.1
+
+# Hugging Face
 transformers>=4.35.0
-
-
-
+torch>=2.0.1
+
+# Other dependencies
+numpy>=1.26.4
 
+mcp>=1.0.0
+# Note: fastapi-mcp may not officially exist;
+# you will probably need to build your own integration
routes/__init__.py
CHANGED
@@ -1,8 +1,13 @@
 """API route handlers"""
 
 from fastapi import APIRouter
+<<<<<<< HEAD
 from . import root, health, stance, label, generate
 
+=======
+from . import root, health, stance, label
+from routes.tts_routes import router as audio_router
+>>>>>>> 45e145b23965e35eefb7990d09b535073040e40a
 # Create main router
 api_router = APIRouter()
 
@@ -11,7 +16,12 @@ api_router.include_router(root.router)
 api_router.include_router(health.router)
 api_router.include_router(stance.router, prefix="/stance")
 api_router.include_router(label.router, prefix="/label")
+<<<<<<< HEAD
 api_router.include_router(generate.router, prefix="/generate")
+=======
+api_router.include_router(audio_router)
+>>>>>>> 45e145b23965e35eefb7990d09b535073040e40a
 
 __all__ = ["api_router"]
 
+
routes/mcp_routes.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Routes pour exposer MCP via FastAPI pour Swagger UI"""
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, HTTPException
|
| 4 |
+
from typing import Dict, Any, Optional
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
import logging
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
from services.mcp_service import mcp_server
|
| 10 |
+
from models.mcp_models import (
|
| 11 |
+
ToolListResponse,
|
| 12 |
+
ToolInfo,
|
| 13 |
+
ToolCallRequest,
|
| 14 |
+
ToolCallResponse,
|
| 15 |
+
DetectStanceResponse,
|
| 16 |
+
MatchKeypointResponse,
|
| 17 |
+
TranscribeAudioResponse,
|
| 18 |
+
GenerateSpeechResponse,
|
| 19 |
+
GenerateArgumentResponse
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
router = APIRouter(prefix="/api/v1/mcp", tags=["MCP"])
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ===== Models pour chaque outil MCP =====
|
| 27 |
+
|
| 28 |
+
class DetectStanceRequest(BaseModel):
|
| 29 |
+
"""Request pour détecter la stance d'un argument"""
|
| 30 |
+
topic: str = Field(..., description="Le sujet du débat")
|
| 31 |
+
argument: str = Field(..., description="L'argument à analyser")
|
| 32 |
+
|
| 33 |
+
class Config:
|
| 34 |
+
json_schema_extra = {
|
| 35 |
+
"example": {
|
| 36 |
+
"topic": "Climate change is real",
|
| 37 |
+
"argument": "Rising global temperatures prove it"
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
class MatchKeypointRequest(BaseModel):
|
| 42 |
+
"""Request pour matcher un argument avec un keypoint"""
|
| 43 |
+
argument: str = Field(..., description="L'argument à évaluer")
|
| 44 |
+
key_point: str = Field(..., description="Le keypoint de référence")
|
| 45 |
+
|
| 46 |
+
class Config:
|
| 47 |
+
json_schema_extra = {
|
| 48 |
+
"example": {
|
| 49 |
+
"argument": "Renewable energy reduces emissions",
|
| 50 |
+
"key_point": "Environmental benefits"
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
class TranscribeAudioRequest(BaseModel):
|
| 55 |
+
"""Request pour transcrire un audio"""
|
| 56 |
+
audio_path: str = Field(..., description="Chemin vers le fichier audio")
|
| 57 |
+
|
| 58 |
+
class Config:
|
| 59 |
+
json_schema_extra = {
|
| 60 |
+
"example": {
|
| 61 |
+
"audio_path": "/path/to/audio.wav"
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
class GenerateSpeechRequest(BaseModel):
|
| 66 |
+
"""Request pour générer de la parole"""
|
| 67 |
+
text: str = Field(..., description="Texte à convertir en parole")
|
| 68 |
+
voice: str = Field(default="Aaliyah-PlayAI", description="Voix à utiliser")
|
| 69 |
+
format: str = Field(default="wav", description="Format audio (wav, mp3, etc.)")
|
| 70 |
+
|
| 71 |
+
class Config:
|
| 72 |
+
json_schema_extra = {
|
| 73 |
+
"example": {
|
| 74 |
+
"text": "Hello, this is a test",
|
| 75 |
+
"voice": "Aaliyah-PlayAI",
|
| 76 |
+
"format": "wav"
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
class GenerateArgumentRequest(BaseModel):
|
| 81 |
+
"""Request pour générer un argument"""
|
| 82 |
+
user_input: str = Field(..., description="Input utilisateur pour générer l'argument")
|
| 83 |
+
conversation_id: Optional[str] = Field(default=None, description="ID de conversation (optionnel)")
|
| 84 |
+
|
| 85 |
+
class Config:
|
| 86 |
+
json_schema_extra = {
|
| 87 |
+
"example": {
|
| 88 |
+
"user_input": "Generate an argument about climate change",
|
| 89 |
+
"conversation_id": "conv_123"
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# ===== Routes MCP =====
|
| 95 |
+
|
| 96 |
+
@router.get("/health", summary="Health Check MCP")
|
| 97 |
+
async def mcp_health():
|
| 98 |
+
"""Health check pour le serveur MCP"""
|
| 99 |
+
try:
|
| 100 |
+
# Liste hardcodée des outils disponibles (plus fiable)
|
| 101 |
+
tool_names = [
|
| 102 |
+
"detect_stance",
|
| 103 |
+
"match_keypoint_argument",
|
| 104 |
+
"transcribe_audio",
|
| 105 |
+
"generate_speech",
|
| 106 |
+
"generate_argument",
|
| 107 |
+
"health_check"
|
| 108 |
+
]
|
| 109 |
+
return {
|
| 110 |
+
"status": "healthy",
|
| 111 |
+
"tools": tool_names,
|
| 112 |
+
"tool_count": len(tool_names)
|
| 113 |
+
}
|
| 114 |
+
except Exception as e:
|
| 115 |
+
logger.error(f"MCP health check error: {e}")
|
| 116 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 117 |
+
|
| 118 |
+
@router.get("/tools", response_model=ToolListResponse, summary="Liste des outils MCP")
|
| 119 |
+
async def list_mcp_tools():
|
| 120 |
+
"""Liste tous les outils MCP disponibles"""
|
| 121 |
+
try:
|
| 122 |
+
# Définir manuellement les outils avec leurs schémas
|
| 123 |
+
tool_list = [
|
| 124 |
+
ToolInfo(
|
| 125 |
+
name="detect_stance",
|
| 126 |
+
description="Détecte si un argument est PRO ou CON pour un topic donné",
|
| 127 |
+
input_schema={
|
| 128 |
+
"type": "object",
|
| 129 |
+
"properties": {
|
| 130 |
+
"topic": {"type": "string", "description": "Le sujet du débat"},
|
| 131 |
+
"argument": {"type": "string", "description": "L'argument à analyser"}
|
| 132 |
+
},
|
| 133 |
+
"required": ["topic", "argument"]
|
| 134 |
+
}
|
| 135 |
+
),
|
| 136 |
+
ToolInfo(
|
| 137 |
+
name="match_keypoint_argument",
|
| 138 |
+
description="Détermine si un argument correspond à un keypoint",
|
| 139 |
+
input_schema={
|
| 140 |
+
"type": "object",
|
| 141 |
+
"properties": {
|
| 142 |
+
"argument": {"type": "string", "description": "L'argument à évaluer"},
|
| 143 |
+
"key_point": {"type": "string", "description": "Le keypoint de référence"}
|
| 144 |
+
},
|
| 145 |
+
"required": ["argument", "key_point"]
|
| 146 |
+
}
|
| 147 |
+
),
|
| 148 |
+
ToolInfo(
|
| 149 |
+
name="transcribe_audio",
|
| 150 |
+
description="Convertit un fichier audio en texte",
|
| 151 |
+
input_schema={
|
| 152 |
+
"type": "object",
|
| 153 |
+
"properties": {
|
| 154 |
+
"audio_path": {"type": "string", "description": "Chemin vers le fichier audio"}
|
| 155 |
+
},
|
| 156 |
+
"required": ["audio_path"]
|
| 157 |
+
}
|
| 158 |
+
),
|
| 159 |
+
ToolInfo(
|
| 160 |
+
name="generate_speech",
|
| 161 |
+
description="Convertit du texte en fichier audio",
|
| 162 |
+
input_schema={
|
| 163 |
+
"type": "object",
|
| 164 |
+
"properties": {
|
| 165 |
+
"text": {"type": "string", "description": "Texte à convertir en parole"},
|
| 166 |
+
"voice": {"type": "string", "description": "Voix à utiliser", "default": "Aaliyah-PlayAI"},
|
| 167 |
+
"format": {"type": "string", "description": "Format audio", "default": "wav"}
|
| 168 |
+
},
|
| 169 |
+
"required": ["text"]
|
| 170 |
+
}
|
| 171 |
+
),
|
| 172 |
+
ToolInfo(
|
| 173 |
+
name="generate_argument",
|
| 174 |
+
description="Génère un argument de débat à partir d'un input utilisateur",
|
| 175 |
+
input_schema={
|
| 176 |
+
"type": "object",
|
| 177 |
+
"properties": {
|
| 178 |
+
"user_input": {"type": "string", "description": "Input utilisateur pour générer l'argument"},
|
| 179 |
+
"conversation_id": {"type": "string", "description": "ID de conversation (optionnel)"}
|
| 180 |
+
},
|
| 181 |
+
"required": ["user_input"]
|
| 182 |
+
}
|
| 183 |
+
),
|
| 184 |
+
ToolInfo(
|
| 185 |
+
name="health_check",
|
| 186 |
+
description="Health check pour le serveur MCP",
|
| 187 |
+
input_schema={
|
| 188 |
+
"type": "object",
|
| 189 |
+
"properties": {},
|
| 190 |
+
"required": []
|
| 191 |
+
}
|
| 192 |
+
)
|
| 193 |
+
]
|
| 194 |
+
|
| 195 |
+
return ToolListResponse(tools=tool_list, count=len(tool_list))
|
| 196 |
+
except Exception as e:
|
| 197 |
+
logger.error(f"Error listing MCP tools: {e}")
|
| 198 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 199 |
+
|
| 200 |
+
@router.post("/tools/call", response_model=ToolCallResponse, summary="Appeler un outil MCP")
|
| 201 |
+
async def call_mcp_tool(request: ToolCallRequest):
|
| 202 |
+
"""Appelle un outil MCP par son nom avec des arguments"""
|
| 203 |
+
try:
|
| 204 |
+
result = await mcp_server.call_tool(request.tool_name, request.arguments)
|
| 205 |
+
# Gérer différents types de retours MCP
|
| 206 |
+
if isinstance(result, dict):
|
| 207 |
+
# Si le résultat contient une clé "result" avec une liste de ContentBlock
|
| 208 |
+
if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
|
| 209 |
+
content_block = result["result"][0]
|
| 210 |
+
if hasattr(content_block, 'text') and content_block.text:
|
| 211 |
+
try:
|
| 212 |
+
final_result = json.loads(content_block.text)
|
| 213 |
+
except json.JSONDecodeError:
|
| 214 |
+
final_result = {"text": content_block.text}
|
| 215 |
+
else:
|
| 216 |
+
final_result = result
|
| 217 |
+
else:
|
| 218 |
+
final_result = result
|
| 219 |
+
elif isinstance(result, (list, tuple)) and len(result) > 0:
|
| 220 |
+
# Si c'est une liste de ContentBlock, extraire le contenu
|
| 221 |
+
if hasattr(result[0], 'text') and result[0].text:
|
| 222 |
+
try:
|
| 223 |
+
final_result = json.loads(result[0].text)
|
| 224 |
+
except json.JSONDecodeError:
|
| 225 |
+
final_result = {"text": result[0].text}
|
| 226 |
+
else:
|
| 227 |
+
final_result = {"result": result[0] if result else {}}
|
| 228 |
+
else:
|
| 229 |
+
final_result = {"result": result}
|
| 230 |
+
|
| 231 |
+
return ToolCallResponse(
|
| 232 |
+
success=True,
|
| 233 |
+
result=final_result,
|
| 234 |
+
tool_name=request.tool_name
|
| 235 |
+
)
|
| 236 |
+
except Exception as e:
|
| 237 |
+
logger.error(f"Error calling MCP tool {request.tool_name}: {e}")
|
| 238 |
+
return ToolCallResponse(
|
| 239 |
+
success=False,
|
| 240 |
+
error=str(e),
|
| 241 |
+
tool_name=request.tool_name
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
# ===== Routes individuelles pour chaque outil (pour Swagger) =====
|
| 246 |
+
|
| 247 |
+
@router.post("/tools/detect-stance", response_model=DetectStanceResponse, summary="Détecter la stance d'un argument")
|
| 248 |
+
async def mcp_detect_stance(request: DetectStanceRequest):
|
| 249 |
+
"""Détecte si un argument est PRO ou CON pour un topic donné"""
|
| 250 |
+
try:
|
| 251 |
+
# Appeler directement via call_tool (async)
|
| 252 |
+
result = await mcp_server.call_tool("detect_stance", {
|
| 253 |
+
"topic": request.topic,
|
| 254 |
+
"argument": request.argument
|
| 255 |
+
})
|
| 256 |
+
|
| 257 |
+
# Extraire les données du résultat MCP
|
| 258 |
+
parsed_result = None
|
| 259 |
+
if isinstance(result, dict):
|
| 260 |
+
# Si le résultat contient une clé "result" avec une liste de ContentBlock
|
| 261 |
+
if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
|
| 262 |
+
content_block = result["result"][0]
|
| 263 |
+
if hasattr(content_block, 'text') and content_block.text:
|
| 264 |
+
try:
|
| 265 |
+
parsed_result = json.loads(content_block.text)
|
| 266 |
+
except json.JSONDecodeError:
|
| 267 |
+
raise HTTPException(status_code=500, detail="Invalid JSON response from MCP tool")
|
| 268 |
+
else:
|
| 269 |
+
parsed_result = result
|
| 270 |
+
elif isinstance(result, (list, tuple)) and len(result) > 0:
|
| 271 |
+
if hasattr(result[0], 'text') and result[0].text:
|
| 272 |
+
try:
|
| 273 |
+
parsed_result = json.loads(result[0].text)
|
| 274 |
+
except json.JSONDecodeError:
|
| 275 |
+
raise HTTPException(status_code=500, detail="Invalid JSON response from MCP tool")
|
| 276 |
+
else:
|
| 277 |
+
parsed_result = result
|
| 278 |
+
|
| 279 |
+
if not parsed_result:
|
| 280 |
+
raise HTTPException(status_code=500, detail="Empty response from MCP tool")
|
| 281 |
+
|
| 282 |
+
# Construire la réponse structurée
|
| 283 |
+
response = DetectStanceResponse(
|
| 284 |
+
predicted_stance=parsed_result["predicted_stance"],
|
| 285 |
+
confidence=parsed_result["confidence"],
|
| 286 |
+
probability_con=parsed_result["probability_con"],
|
| 287 |
+
probability_pro=parsed_result["probability_pro"]
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
logger.info(f"Stance prediction: {response.predicted_stance} (conf={response.confidence:.4f})")
|
| 291 |
+
return response
|
| 292 |
+
|
| 293 |
+
except HTTPException:
|
| 294 |
+
raise
|
| 295 |
+
except Exception as e:
|
| 296 |
+
logger.error(f"Error in detect_stance: {e}")
|
| 297 |
+
raise HTTPException(status_code=500, detail=f"Error executing tool detect_stance: {e}")
|
| 298 |
+
|
| 299 |
+
@router.post("/tools/match-keypoint", response_model=MatchKeypointResponse, summary="Matcher un argument avec un keypoint")
|
| 300 |
+
async def mcp_match_keypoint(request: MatchKeypointRequest):
|
| 301 |
+
"""Détermine si un argument correspond à un keypoint"""
|
| 302 |
+
try:
|
| 303 |
+
result = await mcp_server.call_tool("match_keypoint_argument", {
|
| 304 |
+
"argument": request.argument,
|
| 305 |
+
"key_point": request.key_point
|
| 306 |
+
})
|
| 307 |
+
|
| 308 |
+
# Extraire les données du résultat MCP
|
| 309 |
+
parsed_result = None
|
| 310 |
+
if isinstance(result, dict):
|
| 311 |
+
if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
|
| 312 |
+
content_block = result["result"][0]
|
| 313 |
+
if hasattr(content_block, 'text') and content_block.text:
|
| 314 |
+
try:
|
| 315 |
+
parsed_result = json.loads(content_block.text)
|
| 316 |
+
except json.JSONDecodeError:
|
| 317 |
+
raise HTTPException(status_code=500, detail="Invalid JSON response from MCP tool")
|
| 318 |
+
else:
|
| 319 |
+
parsed_result = result
|
| 320 |
+
elif isinstance(result, (list, tuple)) and len(result) > 0:
|
| 321 |
+
if hasattr(result[0], 'text') and result[0].text:
|
| 322 |
+
try:
|
| 323 |
+
parsed_result = json.loads(result[0].text)
|
| 324 |
+
except json.JSONDecodeError:
|
| 325 |
+
raise HTTPException(status_code=500, detail="Invalid JSON response from MCP tool")
|
| 326 |
+
else:
|
| 327 |
+
parsed_result = result
|
| 328 |
+
|
| 329 |
+
if not parsed_result:
|
| 330 |
+
raise HTTPException(status_code=500, detail="Empty response from MCP tool")
|
| 331 |
+
|
| 332 |
+
# Construire la réponse structurée
|
| 333 |
+
response = MatchKeypointResponse(
|
| 334 |
+
prediction=parsed_result["prediction"],
|
| 335 |
+
label=parsed_result["label"],
|
| 336 |
+
confidence=parsed_result["confidence"],
|
| 337 |
+
probabilities=parsed_result["probabilities"]
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
logger.info(f"Keypoint matching: {response.label} (conf={response.confidence:.4f})")
|
| 341 |
+
return response
|
| 342 |
+
|
| 343 |
+
except HTTPException:
|
| 344 |
+
raise
|
| 345 |
+
except Exception as e:
|
| 346 |
+
logger.error(f"Error in match_keypoint_argument: {e}")
|
| 347 |
+
raise HTTPException(status_code=500, detail=f"Error executing tool match_keypoint_argument: {e}")
|
| 348 |
+
|
| 349 |
+
@router.post("/tools/transcribe-audio", response_model=TranscribeAudioResponse, summary="Transcrire un audio en texte")
|
| 350 |
+
async def mcp_transcribe_audio(request: TranscribeAudioRequest):
|
| 351 |
+
"""Convertit un fichier audio en texte"""
|
| 352 |
+
try:
|
| 353 |
+
result = await mcp_server.call_tool("transcribe_audio", {
|
| 354 |
+
"audio_path": request.audio_path
|
| 355 |
+
})
|
| 356 |
+
|
| 357 |
+
# Extraire le texte du résultat MCP
|
| 358 |
+
transcribed_text = None
|
| 359 |
+
if isinstance(result, dict):
|
| 360 |
+
if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
|
| 361 |
+
content_block = result["result"][0]
|
| 362 |
+
if hasattr(content_block, 'text'):
|
| 363 |
+
transcribed_text = content_block.text
|
| 364 |
+
elif "text" in result:
|
| 365 |
+
transcribed_text = result["text"]
|
| 366 |
+
elif isinstance(result, str):
|
| 367 |
+
transcribed_text = result
|
| 368 |
+
elif isinstance(result, (list, tuple)) and len(result) > 0:
|
| 369 |
+
if hasattr(result[0], 'text'):
|
| 370 |
+
transcribed_text = result[0].text
|
| 371 |
+
else:
|
| 372 |
+
transcribed_text = str(result[0])
|
| 373 |
+
else:
|
| 374 |
+
transcribed_text = str(result)
|
| 375 |
+
|
| 376 |
+
if not transcribed_text:
|
| 377 |
+
raise HTTPException(status_code=500, detail="Empty transcription result from MCP tool")
|
| 378 |
+
|
| 379 |
+
response = TranscribeAudioResponse(text=transcribed_text)
|
| 380 |
+
logger.info(f"Audio transcribed: {len(transcribed_text)} characters")
|
| 381 |
+
return response
|
| 382 |
+
|
| 383 |
+
except FileNotFoundError as e:
|
| 384 |
+
logger.error(f"File not found in transcribe_audio: {e}")
|
| 385 |
+
raise HTTPException(status_code=500, detail=f"Error executing tool transcribe_audio: {e}")
|
| 386 |
+
except HTTPException:
|
| 387 |
+
raise
|
| 388 |
+
except Exception as e:
|
| 389 |
+
logger.error(f"Error in transcribe_audio: {e}")
|
| 390 |
+
raise HTTPException(status_code=500, detail=f"Error executing tool transcribe_audio: {e}")
|
| 391 |
+
|
| 392 |
+
@router.post("/tools/generate-speech", response_model=GenerateSpeechResponse, summary="Générer de la parole à partir de texte")
|
| 393 |
+
async def mcp_generate_speech(request: GenerateSpeechRequest):
|
| 394 |
+
"""Convertit du texte en fichier audio"""
|
| 395 |
+
try:
|
| 396 |
+
result = await mcp_server.call_tool("generate_speech", {
|
| 397 |
+
"text": request.text,
|
| 398 |
+
"voice": request.voice,
|
| 399 |
+
"format": request.format
|
| 400 |
+
})
|
| 401 |
+
|
| 402 |
+
# Extraire le chemin audio du résultat MCP
|
| 403 |
+
audio_path = None
|
| 404 |
+
if isinstance(result, dict):
|
| 405 |
+
if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
|
| 406 |
+
content_block = result["result"][0]
|
| 407 |
+
if hasattr(content_block, 'text'):
|
| 408 |
+
audio_path = content_block.text
|
| 409 |
+
elif "audio_path" in result:
|
| 410 |
+
audio_path = result["audio_path"]
|
| 411 |
+
elif isinstance(result, str):
|
| 412 |
+
audio_path = result
|
| 413 |
+
elif isinstance(result, (list, tuple)) and len(result) > 0:
|
| 414 |
+
if hasattr(result[0], 'text'):
|
| 415 |
+
audio_path = result[0].text
|
| 416 |
+
else:
|
| 417 |
+
audio_path = str(result[0])
|
| 418 |
+
else:
|
| 419 |
+
audio_path = str(result)
|
| 420 |
+
|
| 421 |
+
if not audio_path:
|
| 422 |
+
raise HTTPException(status_code=500, detail="Empty audio path from MCP tool")
|
| 423 |
+
|
| 424 |
+        response = GenerateSpeechResponse(audio_path=audio_path)
+        logger.info(f"Speech generated: {audio_path}")
+        return response
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error in generate_speech: {e}")
+        raise HTTPException(status_code=500, detail=f"Error executing tool generate_speech: {e}")
+
+@router.post("/tools/generate-argument", response_model=GenerateArgumentResponse, summary="Generate a debate argument")
+async def mcp_generate_argument(request: GenerateArgumentRequest):
+    """Generate a debate argument from user input"""
+    try:
+        result = await mcp_server.call_tool("generate_argument", {
+            "user_input": request.user_input,
+            "conversation_id": request.conversation_id
+        })
+
+        # Extract the argument from the MCP result
+        generated_argument = None
+        if isinstance(result, dict):
+            if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
+                content_block = result["result"][0]
+                if hasattr(content_block, 'text'):
+                    generated_argument = content_block.text
+            elif "argument" in result:
+                generated_argument = result["argument"]
+        elif isinstance(result, str):
+            generated_argument = result
+        elif isinstance(result, (list, tuple)) and len(result) > 0:
+            if hasattr(result[0], 'text'):
+                generated_argument = result[0].text
+            else:
+                generated_argument = str(result[0])
+        else:
+            generated_argument = str(result)
+
+        if not generated_argument:
+            raise HTTPException(status_code=500, detail="Empty argument from MCP tool")
+
+        response = GenerateArgumentResponse(argument=generated_argument)
+        logger.info(f"Argument generated: {len(generated_argument)} characters")
+        return response
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error in generate_argument: {e}")
+        raise HTTPException(status_code=500, detail=f"Error executing tool generate_argument: {e}")
+
+@router.get("/tools/health-check", summary="MCP health check (tool)")
+async def mcp_tool_health_check() -> Dict[str, Any]:
+    """Health check via the MCP tool"""
+    try:
+        result = await mcp_server.call_tool("health_check", {})
+        # Handle the different MCP return types
+        import json
+        if isinstance(result, dict):
+            # If the result holds a "result" key with a list of ContentBlock
+            if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
+                content_block = result["result"][0]
+                if hasattr(content_block, 'text') and content_block.text:
+                    try:
+                        return json.loads(content_block.text)
+                    except json.JSONDecodeError:
+                        return {"text": content_block.text}
+            return result
+        elif isinstance(result, (list, tuple)) and len(result) > 0:
+            if hasattr(result[0], 'text') and result[0].text:
+                try:
+                    return json.loads(result[0].text)
+                except json.JSONDecodeError:
+                    return {"text": result[0].text}
+            return {"result": result[0] if result else {}}
+        return {"result": result}
+    except Exception as e:
+        logger.error(f"Error in health_check tool: {e}")
+        raise HTTPException(status_code=500, detail=f"Error executing tool health_check: {e}")
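A quick way to exercise the endpoint above once the app is running — a minimal sketch, assuming a local server on port 8000 and that this router ends up reachable under /api/v1/mcp (both the base URL and the prefix are assumptions; neither appears in this hunk):

import requests

# POST a debate prompt; the response body is a GenerateArgumentResponse
# serialized as {"argument": "..."}.
resp = requests.post(
    "http://localhost:8000/api/v1/mcp/tools/generate-argument",  # assumed base URL + prefix
    json={"user_input": "Should homework be banned?", "conversation_id": None},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["argument"])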
routes/stt_routes.py
ADDED
@@ -0,0 +1,42 @@
+from fastapi import APIRouter, UploadFile, File, HTTPException
+from services.stt_service import speech_to_text
+from models.stt import STTResponse
+import tempfile
+import os
+
+router = APIRouter(prefix="/stt", tags=["Speech To Text"])
+
+@router.post("/", response_model=STTResponse)
+async def convert_speech_to_text(file: UploadFile = File(...)):
+    """
+    Convert uploaded audio file to text (English only)
+    """
+    # Check file type
+    if not file.content_type or not file.content_type.startswith('audio/'):
+        raise HTTPException(status_code=400, detail="File must be an audio file")
+
+    # Create temporary file
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
+        temp_path = temp_file.name
+        content = await file.read()
+
+        if len(content) == 0:
+            os.unlink(temp_path)
+            raise HTTPException(status_code=400, detail="Audio file is empty")
+
+        temp_file.write(content)
+
+    try:
+        # Convert audio to text
+        text = speech_to_text(temp_path)
+
+        # Clean up
+        os.unlink(temp_path)
+
+        return STTResponse(text=text)
+
+    except Exception as e:
+        # Clean up on error
+        if os.path.exists(temp_path):
+            os.unlink(temp_path)
+        raise HTTPException(status_code=500, detail=str(e))
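For reference, a client-side sketch for this route — assuming a server on localhost:8000 and a local file sample.wav (both hypothetical):

import requests

# Upload a WAV file; the route rejects anything whose content type does
# not start with "audio/" and returns {"text": "..."} on success.
with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/stt/",
        files={"file": ("sample.wav", f, "audio/wav")},
        timeout=120,
    )
resp.raise_for_status()
print(resp.json()["text"])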
routes/tts_routes.py
ADDED
@@ -0,0 +1,42 @@
+from fastapi import APIRouter, HTTPException
+from fastapi.responses import FileResponse
+from models.tts import TTSRequest
+from services.tts_service import text_to_speech
+from pathlib import Path
+
+router = APIRouter(prefix="/tts", tags=["Text To Speech"])
+
+@router.post("/")
+async def generate_tts(request: TTSRequest):
+    """
+    Convert text to speech (English only)
+    """
+    try:
+        # Generate audio
+        audio_path = text_to_speech(
+            text=request.text,
+            voice=request.voice,
+            fmt=request.format
+        )
+
+        # Verify file exists
+        if not Path(audio_path).exists():
+            raise HTTPException(status_code=500, detail="Audio file generation failed")
+
+        # Determine MIME type
+        media_type = "audio/wav" if request.format == "wav" else "audio/mpeg"
+
+        # Return audio file
+        return FileResponse(
+            path=audio_path,
+            filename=f"speech.{request.format}",
+            media_type=media_type,
+            headers={
+                "Content-Disposition": f"attachment; filename=speech.{request.format}"
+            }
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
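And the matching client sketch, under the same localhost assumption; the voice and format values mirror the defaults used elsewhere in this commit:

import requests

# Request speech and save the returned file bytes to disk.
resp = requests.post(
    "http://localhost:8000/tts/",
    json={"text": "Hello from the debater!", "voice": "Aaliyah-PlayAI", "format": "wav"},
    timeout=120,
)
resp.raise_for_status()
with open("speech.wav", "wb") as out:
    out.write(resp.content)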
routes/voice_chat_routes.py
ADDED
@@ -0,0 +1,220 @@
+from fastapi import APIRouter, UploadFile, File, HTTPException, Query
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+from typing import Optional
+import tempfile
+import os
+from pathlib import Path
+import uuid
+import io
+
+from services.stt_service import speech_to_text
+from services.tts_service import text_to_speech
+from services.chat_service import generate_chat_response
+from models.voice_chat import TextChatRequest, VoiceChatResponse
+
+router = APIRouter(prefix="/voice-chat", tags=["Voice Chat"])
+
+# Temporary audio cache
+audio_cache = {}
+
+@router.post("/voice", response_model=VoiceChatResponse)
+async def voice_chat_endpoint(
+    file: UploadFile = File(...),
+    conversation_id: Optional[str] = Query(None)
+):
+    """
+    Complete voice chat endpoint (English only):
+    1. STT: Audio → Text
+    2. Chatbot: Text → Response
+    3. TTS: Response → Audio
+    """
+    # 1. Check audio file
+    if not file.content_type or not file.content_type.startswith('audio/'):
+        raise HTTPException(
+            status_code=400,
+            detail=f"File must be an audio file. Received: {file.content_type}"
+        )
+
+    # 2. Create conversation ID if not provided
+    if not conversation_id:
+        conversation_id = str(uuid.uuid4())
+
+    # 3. Save audio temporarily
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
+        temp_path = temp_file.name
+        content = await file.read()
+
+        if len(content) == 0:
+            os.unlink(temp_path)
+            raise HTTPException(status_code=400, detail="Audio file is empty")
+
+        temp_file.write(content)
+
+    try:
+        # 4. STT: Audio → Text (English)
+        user_text = speech_to_text(temp_path)
+
+        if not user_text or user_text.strip() == "":
+            raise HTTPException(
+                status_code=400,
+                detail="No speech detected in audio."
+            )
+
+        print(f"🎤 STT Result: {user_text}")
+
+        # 5. Generate chatbot response (English)
+        chatbot_response = generate_chat_response(
+            user_input=user_text,
+            conversation_id=conversation_id
+        )
+
+        print(f"🤖 Chatbot Response: {chatbot_response}")
+
+        # 6. TTS: Response text → Audio (English voice)
+        audio_path = text_to_speech(
+            text=chatbot_response,
+            voice="Aaliyah-PlayAI",  # English voice
+            fmt="wav"
+        )
+
+        # 7. Read and store audio
+        with open(audio_path, "rb") as audio_file:
+            audio_data = audio_file.read()
+
+        audio_cache[conversation_id] = {
+            "audio": audio_data,
+            "text": chatbot_response
+        }
+
+        # 8. Clean up temporary files
+        os.unlink(temp_path)
+        if Path(audio_path).exists():
+            os.unlink(audio_path)
+
+        # 9. Return response
+        return VoiceChatResponse(
+            text_response=chatbot_response,
+            audio_url=f"/voice-chat/audio/{conversation_id}",
+            conversation_id=conversation_id
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        # Clean up on error
+        if os.path.exists(temp_path):
+            os.unlink(temp_path)
+
+        import traceback
+        error_details = traceback.format_exc()
+        print(f"❌ Error in voice_chat_endpoint: {error_details}")
+
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error during voice processing: {str(e)}"
+        )
+
+@router.post("/text", response_model=VoiceChatResponse)
+async def text_chat_endpoint(request: TextChatRequest):
+    """
+    Text chat with audio response (English only).
+    For users who prefer to type but hear the response.
+    """
+    try:
+        # 1. Create conversation ID if not provided
+        if not request.conversation_id:
+            conversation_id = str(uuid.uuid4())
+        else:
+            conversation_id = request.conversation_id
+
+        # 2. Validate text
+        if not request.text or request.text.strip() == "":
+            raise HTTPException(status_code=400, detail="Text cannot be empty")
+
+        print(f"📝 Text received: {request.text}")
+
+        # 3. Generate chatbot response
+        chatbot_response = generate_chat_response(
+            user_input=request.text,
+            conversation_id=conversation_id
+        )
+
+        print(f"🤖 Chatbot Response: {chatbot_response}")
+
+        # 4. TTS with English voice
+        audio_path = text_to_speech(
+            text=chatbot_response,
+            voice="Aaliyah-PlayAI",
+            fmt="wav"
+        )
+
+        # 5. Read and store audio
+        with open(audio_path, "rb") as audio_file:
+            audio_data = audio_file.read()
+
+        audio_cache[conversation_id] = {
+            "audio": audio_data,
+            "text": chatbot_response
+        }
+
+        # 6. Clean up
+        if Path(audio_path).exists():
+            os.unlink(audio_path)
+
+        # 7. Return response
+        return VoiceChatResponse(
+            text_response=chatbot_response,
+            audio_url=f"/voice-chat/audio/{conversation_id}",
+            conversation_id=conversation_id
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        import traceback
+        error_details = traceback.format_exc()
+        print(f"❌ Error in text_chat_endpoint: {error_details}")
+
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error during chat: {str(e)}"
+        )
+
+@router.get("/audio/{conversation_id}")
+async def get_audio_stream(conversation_id: str):
+    """
+    Stream audio of the last response
+    """
+    if conversation_id not in audio_cache:
+        raise HTTPException(
+            status_code=404,
+            detail=f"No audio found for conversation {conversation_id}"
+        )
+
+    audio_data = audio_cache[conversation_id]["audio"]
+
+    return StreamingResponse(
+        io.BytesIO(audio_data),
+        media_type="audio/wav",
+        headers={
+            "Content-Disposition": f"attachment; filename=response_{conversation_id[:8]}.wav"
+        }
+    )
+
+@router.get("/test")
+async def test_endpoint():
+    """
+    Test endpoint to verify the API is working
+    """
+    return {
+        "status": "ok",
+        "message": "Voice Chat API is working (English only)",
+        "endpoints": {
+            "POST /voice-chat/voice": "Voice input → Voice response",
+            "POST /voice-chat/text": "Text input → Voice response",
+            "GET /voice-chat/audio/{id}": "Get audio response",
+            "POST /stt/": "Speech to text",
+            "POST /tts/": "Text to speech"
+        }
+    }
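Note that audio_cache is a plain in-process dict: cached replies disappear on restart, are never evicted, and are not shared between workers, so GET /voice-chat/audio/{id} only works against the same process that produced the reply. A minimal end-to-end client for the text path — same localhost assumption as above:

import requests

BASE = "http://localhost:8000"  # assumed

# 1. Send text, get the reply text plus a URL for the spoken version.
resp = requests.post(
    f"{BASE}/voice-chat/text",
    json={"text": "Tell me a fun fact.", "conversation_id": None},
    timeout=120,
)
resp.raise_for_status()
data = resp.json()
print(data["text_response"])

# 2. Fetch the cached WAV for this conversation.
audio = requests.get(BASE + data["audio_url"], timeout=60)
audio.raise_for_status()
with open("reply.wav", "wb") as f:
    f.write(audio.content)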
services/__init__.py
CHANGED
@@ -4,11 +4,19 @@ from .stance_model_manager import StanceModelManager, stance_model_manager
 from .label_model_manager import KpaModelManager, kpa_model_manager
 from .generate_model_manager import GenerateModelManager, generate_model_manager
 
+# NEW imports
+from .stt_service import speech_to_text
+from .tts_service import text_to_speech
+
 __all__ = [
     "StanceModelManager",
     "stance_model_manager",
     "KpaModelManager",
     "kpa_model_manager",
     "GenerateModelManager",
     "generate_model_manager",
+
+    # NEW exports
+    "speech_to_text",
+    "text_to_speech",
 ]
services/chat_service.py
ADDED
@@ -0,0 +1,138 @@
+import uuid
+from typing import Dict, List, Optional
+from datetime import datetime
+import requests
+import json
+
+from config import GROQ_API_KEY, GROQ_CHAT_MODEL
+
+# In-memory conversation storage
+conversation_store: Dict[str, List[Dict]] = {}
+
+def generate_chat_response(
+    user_input: str,
+    conversation_id: Optional[str] = None,
+    system_prompt: Optional[str] = None
+) -> str:
+    """
+    Generate chatbot response for user input (English only)
+    """
+    try:
+        # 1. Validate input
+        if not user_input or not isinstance(user_input, str):
+            raise ValueError("User input must be a non-empty string")
+
+        user_input = user_input.strip()
+        if len(user_input) == 0:
+            return "I didn't hear what you said. Can you repeat?"
+
+        # 2. Handle conversation
+        if not conversation_id:
+            conversation_id = str(uuid.uuid4())
+
+        # Initialize conversation if it doesn't exist
+        if conversation_id not in conversation_store:
+            conversation_store[conversation_id] = []
+
+        # 3. Add user message to history
+        conversation_store[conversation_id].append({
+            "role": "user",
+            "content": user_input,
+            "timestamp": datetime.now().isoformat()
+        })
+
+        # 4. Prepare system prompt (English only)
+        if not system_prompt:
+            system_prompt = """You are a friendly and helpful English voice assistant.
+            Respond in English only. Keep responses concise (2-3 sentences max),
+            natural for speech, and helpful. Be polite and engaging."""
+
+        # 5. Prepare messages for Groq API
+        messages = [{"role": "system", "content": system_prompt}]
+
+        # Add conversation history (last 6 messages)
+        history = conversation_store[conversation_id][-6:]
+        for msg in history:
+            messages.append({"role": msg["role"], "content": msg["content"]})
+
+        # 6. Call Groq Chat API
+        if not GROQ_API_KEY:
+            # Fallback if no API key
+            response_text = f"Hello! You said: '{user_input}'. I'm a voice assistant configured to respond in English."
+        else:
+            try:
+                response_text = call_groq_chat_api(messages)
+            except Exception as api_error:
+                print(f"Groq API error: {api_error}")
+                response_text = f"I understand you said: {user_input}. How can I help you today?"
+
+        # 7. Add response to history
+        conversation_store[conversation_id].append({
+            "role": "assistant",
+            "content": response_text,
+            "timestamp": datetime.now().isoformat()
+        })
+
+        # Limit history size
+        if len(conversation_store[conversation_id]) > 20:
+            conversation_store[conversation_id] = conversation_store[conversation_id][-10:]
+
+        return response_text
+
+    except Exception as e:
+        print(f"Error in generate_chat_response: {e}")
+        return "Sorry, an error occurred. Can you please repeat?"
+
+def call_groq_chat_api(messages: List[Dict]) -> str:
+    """
+    Call Groq Chat API
+    """
+    if not GROQ_API_KEY:
+        raise RuntimeError("GROQ_API_KEY is not configured")
+
+    url = "https://api.groq.com/openai/v1/chat/completions"
+
+    headers = {
+        "Authorization": f"Bearer {GROQ_API_KEY}",
+        "Content-Type": "application/json"
+    }
+
+    payload = {
+        "model": GROQ_CHAT_MODEL,
+        "messages": messages,
+        "temperature": 0.7,
+        "max_tokens": 300,
+        "top_p": 0.9,
+        "stream": False
+    }
+
+    try:
+        response = requests.post(url, headers=headers, json=payload, timeout=30)
+        response.raise_for_status()
+
+        result = response.json()
+
+        if "choices" not in result or len(result["choices"]) == 0:
+            raise ValueError("Invalid response from Groq API")
+
+        return result["choices"][0]["message"]["content"]
+
+    except requests.exceptions.RequestException as e:
+        raise Exception(f"Groq API connection error: {str(e)}")
+    except KeyError as e:
+        raise Exception(f"Invalid response format: {str(e)}")
+
+def get_conversation_history(conversation_id: str) -> List[Dict]:
+    """
+    Get conversation history
+    """
+    return conversation_store.get(conversation_id, [])
+
+def clear_conversation(conversation_id: str) -> bool:
+    """
+    Clear a conversation
+    """
+    if conversation_id in conversation_store:
+        del conversation_store[conversation_id]
+        return True
+    return False
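One behavior worth noting: when no conversation_id is passed, generate_chat_response creates one internally but only returns the reply text, so the caller can never reference that conversation again. Multi-turn use therefore means supplying your own id, as in this sketch (with GROQ_API_KEY unset, the canned fallback reply is returned instead):

from services.chat_service import generate_chat_response, get_conversation_history

# Thread two turns through the same stored history.
cid = "demo-session"
print(generate_chat_response("What is key point analysis?", conversation_id=cid))
print(generate_chat_response("Can you give an example?", conversation_id=cid))

# The in-memory store now holds both user turns and both replies.
print(len(get_conversation_history(cid)))  # 4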
services/mcp_service.py
ADDED
@@ -0,0 +1,89 @@
+"""Service that initializes the MCP server with FastMCP"""
+
+from mcp.server.fastmcp import FastMCP
+from typing import Dict, Any, Optional
+import logging
+
+from fastapi import FastAPI
+
+from services.stance_model_manager import stance_model_manager
+from services.label_model_manager import kpa_model_manager
+from services.stt_service import speech_to_text
+from services.tts_service import text_to_speech
+from services.chat_service import generate_chat_response
+
+logger = logging.getLogger(__name__)
+
+# Create the FastMCP instance
+mcp_server = FastMCP("NLP-Debater-MCP", json_response=True, stateless_http=False)  # Stateful, for sessions
+
+# Tools (unchanged)
+@mcp_server.tool()
+def detect_stance(topic: str, argument: str) -> Dict[str, Any]:
+    if not stance_model_manager.model_loaded:
+        raise ValueError("Stance model not loaded")
+    result = stance_model_manager.predict(topic, argument)
+    return {
+        "predicted_stance": result["predicted_stance"],
+        "confidence": result["confidence"],
+        "probability_con": result["probability_con"],
+        "probability_pro": result["probability_pro"]
+    }
+
+@mcp_server.tool()
+def match_keypoint_argument(argument: str, key_point: str) -> Dict[str, Any]:
+    if not kpa_model_manager.model_loaded:
+        raise ValueError("KPA model not loaded")
+    result = kpa_model_manager.predict(argument, key_point)
+    return {
+        "prediction": result["prediction"],
+        "label": result["label"],
+        "confidence": result["confidence"],
+        "probabilities": result["probabilities"]
+    }
+
+@mcp_server.tool()
+def transcribe_audio(audio_path: str) -> str:
+    return speech_to_text(audio_path)
+
+@mcp_server.tool()
+def generate_speech(text: str, voice: str = "Aaliyah-PlayAI", format: str = "wav") -> str:
+    return text_to_speech(text, voice, format)
+
+@mcp_server.tool()
+def generate_argument(user_input: str, conversation_id: Optional[str] = None) -> str:
+    return generate_chat_response(user_input, conversation_id)
+
+@mcp_server.resource("debate://prompt")
+def get_debate_prompt() -> str:
+    return "You are a debate expert. Generate 3 PRO arguments for the given topic. Be concise and persuasive."
+
+# Health tool (registered before initialization)
+@mcp_server.tool()
+def health_check() -> Dict[str, Any]:
+    """Health check for the MCP server"""
+    try:
+        # Hard-coded list to avoid issues with list_tools()
+        tool_names = [
+            "detect_stance",
+            "match_keypoint_argument",
+            "transcribe_audio",
+            "generate_speech",
+            "generate_argument",
+            "health_check"
+        ]
+    except Exception:
+        tool_names = []
+    return {"status": "healthy", "tools": tool_names}
+
+def init_mcp_server(app: FastAPI) -> None:
+    """
+    Initialize and mount the MCP server on the FastAPI app.
+    """
+    # FIXED: use streamable_http_app(), which returns the ASGI app
+    mcp_app = mcp_server.streamable_http_app()  # ASGI app to mount (natively handles /health, /tools, etc.)
+
+    # Mount at /api/v1/mcp - FastAPI handles the lifespan automatically
+    app.mount("/api/v1/mcp", mcp_app)
+
+    logger.info("✓ MCP server mounted at /api/v1/mcp with NLP/STT/TTS tools")
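Wiring this into the application is a one-liner; presumably main.py (changed in this commit but not shown here) does something equivalent to this sketch:

from fastapi import FastAPI
from services.mcp_service import init_mcp_server

app = FastAPI()
init_mcp_server(app)  # MCP tools are now served under /api/v1/mcp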
services/stt_service.py
ADDED
@@ -0,0 +1,38 @@
+import requests
+from config import GROQ_API_KEY, GROQ_STT_MODEL
+
+def speech_to_text(audio_file: str) -> str:
+    """
+    Convert audio file to text using Groq's Whisper API (English only)
+    """
+    if not GROQ_API_KEY:
+        raise RuntimeError("GROQ_API_KEY is not set in config")
+
+    url = "https://api.groq.com/openai/v1/audio/transcriptions"
+
+    headers = {
+        "Authorization": f"Bearer {GROQ_API_KEY}"
+    }
+
+    with open(audio_file, "rb") as audio_data:
+        files = {
+            "file": (audio_file, audio_data, "audio/wav")
+        }
+        data = {
+            "model": GROQ_STT_MODEL,
+            "language": "en",  # Force English
+            "temperature": 0,
+            "response_format": "json"
+        }
+
+        try:
+            response = requests.post(url, headers=headers, files=files, data=data, timeout=30)
+            response.raise_for_status()
+
+            result = response.json()
+            return result.get("text", "")
+
+        except requests.exceptions.RequestException as e:
+            raise Exception(f"Groq STT API error: {str(e)}")
+        except Exception as e:
+            raise Exception(f"Unexpected error in speech_to_text: {str(e)}")
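Direct usage is a single call — a sketch assuming a valid GROQ_API_KEY and a local recording (the filename is hypothetical):

from services.stt_service import speech_to_text

transcript = speech_to_text("recording.wav")
print(transcript)  # "" if Groq returned no "text" field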
services/tts_service.py
ADDED
@@ -0,0 +1,57 @@
+import requests
+import uuid
+import os
+from pathlib import Path
+from config import GROQ_API_KEY, GROQ_TTS_MODEL
+
+def text_to_speech(
+    text: str,
+    voice: str = "Aaliyah-PlayAI",
+    fmt: str = "wav"
+) -> str:
+    """
+    Convert text to speech using Groq's TTS API (English only)
+    """
+    if not GROQ_API_KEY:
+        raise RuntimeError("GROQ_API_KEY is not set in config")
+
+    if not text or not text.strip():
+        raise ValueError("Text cannot be empty")
+
+    url = "https://api.groq.com/openai/v1/audio/speech"
+
+    headers = {
+        "Authorization": f"Bearer {GROQ_API_KEY}",
+        "Content-Type": "application/json"
+    }
+
+    payload = {
+        "model": GROQ_TTS_MODEL,
+        "input": text.strip(),
+        "voice": voice,
+        "response_format": fmt
+    }
+
+    try:
+        # Create temp directory for audio files
+        temp_dir = Path("temp_audio")
+        temp_dir.mkdir(exist_ok=True)
+
+        # Unique filename
+        output_filename = f"tts_{uuid.uuid4().hex[:8]}.{fmt}"
+        output_path = temp_dir / output_filename
+
+        # Call Groq API
+        response = requests.post(url, headers=headers, json=payload, timeout=30)
+        response.raise_for_status()
+
+        # Save audio file
+        with open(output_path, "wb") as f:
+            f.write(response.content)
+
+        return str(output_path)
+
+    except requests.exceptions.RequestException as e:
+        raise Exception(f"Groq TTS API error: {str(e)}")
+    except Exception as e:
+        raise Exception(f"Unexpected error in text_to_speech: {str(e)}")
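The service writes into temp_audio/ and never deletes its output; the routes above unlink the file after streaming it, so any direct caller owns the cleanup. A minimal sketch (again assuming GROQ_API_KEY is set):

import os
from services.tts_service import text_to_speech

path = text_to_speech("Hello, world!")  # defaults: voice="Aaliyah-PlayAI", fmt="wav"
print(path)  # e.g. temp_audio/tts_1a2b3c4d.wav (hex suffix varies)
os.remove(path)  # caller-side cleanup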