import sys
from pathlib import Path
import logging
from contextlib import asynccontextmanager
import atexit
import shutil
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
# Setup logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# --- Config ---
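# config.py is assumed to read these values from environment variables
# (for instance Hugging Face Space secrets for the API keys).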
from config import (
API_TITLE, API_DESCRIPTION, API_VERSION,
HUGGINGFACE_API_KEY, HUGGINGFACE_STANCE_MODEL_ID, HUGGINGFACE_LABEL_MODEL_ID,
HUGGINGFACE_GENERATE_MODEL_ID,
HOST, PORT, RELOAD,
CORS_ORIGINS, CORS_METHODS, CORS_HEADERS, CORS_CREDENTIALS,
PRELOAD_MODELS_ON_STARTUP, LOAD_STANCE_MODEL, LOAD_KPA_MODEL,
GROQ_API_KEY, GROQ_STT_MODEL, GROQ_TTS_MODEL, GROQ_CHAT_MODEL
)
# --- Cleanup function ---
def cleanup_temp_files():
    """Clean up temporary audio files at startup"""
    temp_dir = Path("temp_audio")
    if temp_dir.exists():
        try:
            shutil.rmtree(temp_dir)
            logger.info("✓ Temporary audio files cleaned up")
        except Exception as e:
            logger.warning(f"⚠ Could not clean up the temporary directory: {e}")
# Run at startup
cleanup_temp_files()
# Register cleanup on shutdown
@atexit.register
def cleanup_on_exit():
    temp_dir = Path("temp_audio")
    if temp_dir.exists():
        try:
            shutil.rmtree(temp_dir)
            logger.info("Final cleanup of temporary files")
        except Exception as e:
            logger.warning(f"Final cleanup failed: {e}")
# --- Import service singletons ---
stance_model_manager = None
kpa_model_manager = None
generate_model_manager = None
topic_similarity_service = None
try:
    from services.stance_model_manager import stance_model_manager
    from services.label_model_manager import kpa_model_manager
    from services.generate_model_manager import generate_model_manager
    from services.topic_similarity_service import topic_similarity_service
    logger.info("✓ Model managers imported")
except ImportError as e:
    logger.warning(f"⚠ Could not import model managers: {e}")
# --- MCP availability check ---
MCP_ENABLED = False
try:
    from services.mcp_service import init_mcp_server
    MCP_ENABLED = True
    logger.info("✓ MCP modules detected")
except ImportError as e:
    logger.warning(f"⚠ MCP not available: {e}")
# --- Lifespan / API startup ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("="*60)
    logger.info("🚀 API STARTUP - Loading models and checking APIs...")
    logger.info("="*60)
# Load stance detection model
try:
logger.info(f"Loading stance model from Hugging Face: {HUGGINGFACE_STANCE_MODEL_ID}")
stance_model_manager.load_model(HUGGINGFACE_STANCE_MODEL_ID, HUGGINGFACE_API_KEY)
except Exception as e:
logger.error(f"✗ Failed to load stance model: {str(e)}")
logger.error("⚠️ Stance detection endpoints will not work!")
# Load KPA (label) model
try:
logger.info(f"Loading KPA model from Hugging Face: {HUGGINGFACE_LABEL_MODEL_ID}")
kpa_model_manager.load_model(HUGGINGFACE_LABEL_MODEL_ID, HUGGINGFACE_API_KEY)
except Exception as e:
logger.error(f"✗ Failed to load KPA model: {str(e)}")
logger.error("⚠️ KPA/Label prediction endpoints will not work!")
# Load Generation model
try:
logger.info(f"Loading Generation model from Hugging Face: {HUGGINGFACE_GENERATE_MODEL_ID}")
generate_model_manager.load_model(HUGGINGFACE_GENERATE_MODEL_ID, HUGGINGFACE_API_KEY)
except Exception as e:
logger.error(f"✗ Failed to load Generation model: {str(e)}")
logger.error("⚠️ Generation endpoints will not work!")
# Initialize Topic Extraction service (uses Groq LLM)
if topic_similarity_service and GROQ_API_KEY:
try:
logger.info("Initializing Topic Extraction service (Groq LLM)...")
topic_similarity_service.initialize()
logger.info("✓ Topic Extraction service initialized")
except Exception as e:
logger.error(f"✗ Failed to initialize Topic Extraction service: {str(e)}")
logger.error("⚠️ Topic extraction endpoints will not work!")
elif not GROQ_API_KEY:
logger.warning("⚠ GROQ_API_KEY not configured. Topic extraction service will not be available.")
logger.info("✓ API startup complete")
logger.info("https://nlp-debater-project-fastapi-backend-models.hf.space/docs")
    # Check API keys
    if not GROQ_API_KEY:
        logger.warning("⚠ GROQ_API_KEY not configured. STT/TTS features disabled.")
    else:
        logger.info("✓ GROQ_API_KEY configured")
    if not HUGGINGFACE_API_KEY:
        logger.warning("⚠ HUGGINGFACE_API_KEY not configured. Local models disabled.")
    else:
        logger.info("✓ HUGGINGFACE_API_KEY configured")
    # Preload Hugging Face models if configured
    if PRELOAD_MODELS_ON_STARTUP:
        # Load stance model
        if LOAD_STANCE_MODEL and stance_model_manager and HUGGINGFACE_STANCE_MODEL_ID:
            try:
                stance_model_manager.load_model(HUGGINGFACE_STANCE_MODEL_ID, HUGGINGFACE_API_KEY)
                logger.info("✓ Stance detection model loaded")
            except Exception as e:
                logger.error(f"✗ Failed to load stance model: {e}")
        # Load KPA model
        if LOAD_KPA_MODEL and kpa_model_manager and HUGGINGFACE_LABEL_MODEL_ID:
            try:
                kpa_model_manager.load_model(HUGGINGFACE_LABEL_MODEL_ID, HUGGINGFACE_API_KEY)
                logger.info("✓ KPA model loaded")
            except Exception as e:
                logger.error(f"✗ Failed to load KPA model: {e}")
    # Initialize MCP if available
    if MCP_ENABLED:
        try:
            init_mcp_server(app)
            logger.info("✓ MCP server initialized")
        except Exception as e:
            logger.error(f"✗ Failed to initialize MCP: {e}")
    logger.info("="*60)
    logger.info("✓ Startup complete. API is ready to accept requests.")
    logger.info(f" STT Model: {GROQ_STT_MODEL}")
    logger.info(f" TTS Model: {GROQ_TTS_MODEL}")
    logger.info(f" Chat Model: {GROQ_CHAT_MODEL}")
    logger.info(f" Topic Extraction: {'Initialized' if (topic_similarity_service and topic_similarity_service.initialized) else 'Not initialized'}")
    logger.info(f" Voice Chat: {'Available' if GROQ_API_KEY else 'Disabled (no GROQ_API_KEY)'}")
    logger.info(f" MCP: {'Enabled' if MCP_ENABLED else 'Disabled'}")
    if MCP_ENABLED:
        logger.info(" - Tools: detect_stance, match_keypoint_argument, transcribe_audio, generate_speech, generate_argument, extract_topic, voice_chat, health_check")
    logger.info("="*60)
    yield
    logger.info("🛑 Shutting down API...")
    # Final cleanup
    cleanup_on_exit()
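    # cleanup_on_exit is also registered with atexit; running it here as well is harmless
    # because it checks whether the temp directory still exists before removing it.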
# --- FastAPI app ---
app = FastAPI(
title=API_TITLE,
description=API_DESCRIPTION,
version=API_VERSION,
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json"
)
# --- CORS ---
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ORIGINS,
allow_credentials=CORS_CREDENTIALS,
allow_methods=CORS_METHODS,
allow_headers=CORS_HEADERS,
)
# --- Routes ---
# STT routes
try:
    from routes.stt_routes import router as stt_router
    app.include_router(stt_router, prefix="/api/v1/stt", tags=["Speech To Text"])
    logger.info("✓ STT routes loaded (Groq Whisper)")
except ImportError as e:
    logger.warning(f"⚠ STT routes not found: {e}")
except Exception as e:
    logger.warning(f"⚠ Failed to load STT routes: {e}")
# TTS routes
try:
    from routes.tts_routes import router as tts_router
    app.include_router(tts_router, prefix="/api/v1/tts", tags=["Text To Speech"])
    logger.info("✓ TTS routes loaded (Groq PlayAI TTS)")
except ImportError as e:
    logger.warning(f"⚠ TTS routes not found: {e}")
except Exception as e:
    logger.warning(f"⚠ Failed to load TTS routes: {e}")
# Voice Chat routes
try:
    from routes.voice_chat_routes import router as voice_chat_router
    app.include_router(voice_chat_router, tags=["Voice Chat"])
    logger.info("✓ Voice Chat routes loaded")
except ImportError as e:
    logger.warning(f"⚠ Voice Chat routes not found: {e}")
except Exception as e:
    logger.warning(f"⚠ Failed to load Voice Chat routes: {e}")
# Main API routes (KPA, Stance, etc.)
try:
    from routes import api_router
    app.include_router(api_router, prefix="/api/v1")
    logger.info("✓ Main API routes loaded")
    # Log all routes for debugging
    for route in api_router.routes:
        if hasattr(route, 'path'):
            logger.info(f" - Route: {route.path} (methods: {getattr(route, 'methods', 'N/A')})")
except ImportError as e:
    logger.error(f"⚠ Main API routes not found: {e}")
    import traceback
    logger.error(traceback.format_exc())
    # Fallback: register the KPA and Stance routers individually
    try:
        from routes.label import router as kpa_router
        app.include_router(kpa_router, prefix="/api/v1/kpa", tags=["KPA"])
        from routes.stance import router as stance_router
        app.include_router(stance_router, prefix="/api/v1/stance", tags=["Stance Detection"])
        logger.info("✓ KPA and Stance routes loaded via fallback")
    except ImportError:
        logger.warning("⚠ KPA/Stance fallback failed")
except Exception as e:
    logger.error(f"⚠ Failed to load main API routes: {e}")
    import traceback
    logger.error(traceback.format_exc())
# MCP routes - FastAPI routes for Swagger UI + mount for MCP compatibility
if MCP_ENABLED:
    try:
        from routes.mcp_routes import router as mcp_router
        app.include_router(mcp_router)  # No prefix: the router already defines its own
        logger.info("✓ MCP FastAPI routes loaded (visible in Swagger)")
    except ImportError as e:
        logger.warning(f"⚠ MCP FastAPI routes not found: {e}")
    except Exception as e:
        logger.warning(f"⚠ Failed to load MCP FastAPI routes: {e}")
    logger.info("✓ MCP mounted via the lifespan (endpoints are self-managed)")
else:
    logger.warning("⚠ MCP disabled")
# --- Basic routes ---
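# The health check below only reports in-process state (loaded-model flags and configured
# keys); it does not call Groq or Hugging Face to verify that the credentials actually work.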
@app.get("/health", tags=["Health"])
async def health():
health_status = {
"status": "healthy",
"service": "NLP Debater + Groq Voice",
"version": API_VERSION,
"features": {
"stt": GROQ_STT_MODEL if GROQ_API_KEY else "disabled",
"tts": GROQ_TTS_MODEL if GROQ_API_KEY else "disabled",
"chat": GROQ_CHAT_MODEL if GROQ_API_KEY else "disabled",
"topic_extraction": "initialized" if (topic_similarity_service and hasattr(topic_similarity_service, 'initialized') and topic_similarity_service.initialized) else "not initialized",
"voice_chat": "available" if GROQ_API_KEY else "disabled",
"stance_model": "loaded" if (stance_model_manager and hasattr(stance_model_manager, 'model_loaded') and stance_model_manager.model_loaded) else "not loaded",
"kpa_model": "loaded" if (kpa_model_manager and hasattr(kpa_model_manager, 'model_loaded') and kpa_model_manager.model_loaded) else "not loaded",
"generate_model": "loaded" if (generate_model_manager and hasattr(generate_model_manager, 'model_loaded') and generate_model_manager.model_loaded) else "not loaded",
"mcp": "enabled" if MCP_ENABLED else "disabled"
},
"endpoints": {
"mcp": "/api/v1/mcp" if MCP_ENABLED else "disabled",
"topic_extraction": "/api/v1/topic/extract",
"voice_chat": "/voice-chat/voice or /voice-chat/text",
"mcp_tools": {
"extract_topic": "/api/v1/mcp/tools/extract-topic",
"voice_chat": "/api/v1/mcp/tools/voice-chat"
} if MCP_ENABLED else "disabled"
}
}
return health_status
@app.get("/", tags=["Root"])
async def root():
return RedirectResponse(url="/docs")
# --- Error handlers ---
@app.exception_handler(404)
async def not_found_handler(request, exc):
endpoints = {
"GET /": "Redirige vers /docs (Swagger UI)",
"GET /health": "Health check",
"POST /api/v1/stt/": "Speech to text",
"POST /api/v1/tts/": "Text to speech",
"POST /voice-chat/voice": "Voice chat (audio input)",
"POST /voice-chat/text": "Voice chat (text input)",
"POST /api/v1/topic/extract": "Extract topic from text"
}
if MCP_ENABLED:
endpoints.update({
"GET /api/v1/mcp/health": "Health check MCP",
"GET /api/v1/mcp/tools": "Liste outils MCP",
"POST /api/v1/mcp/tools/call": "Appel d'outil MCP",
"POST /api/v1/mcp/tools/extract-topic": "Extract topic (MCP tool)",
"POST /api/v1/mcp/tools/voice-chat": "Voice chat (MCP tool)",
"POST /api/v1/mcp/tools/detect-stance": "Detect stance (MCP tool)",
"POST /api/v1/mcp/tools/match-keypoint": "Match keypoint (MCP tool)",
"POST /api/v1/mcp/tools/transcribe-audio": "Transcribe audio (MCP tool)",
"POST /api/v1/mcp/tools/generate-speech": "Generate speech (MCP tool)",
"POST /api/v1/mcp/tools/generate-argument": "Generate argument (MCP tool)"
})
return {
"error": "Not Found",
"message": f"URL {request.url} non trouvée",
"available_endpoints": endpoints
}
# --- Run server ---
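# On a Hugging Face Space the app is normally launched by the Space runtime
# (e.g. `uvicorn main:app`), so this block is mainly for running the server locally.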
if __name__ == "__main__":
logger.info("="*60)
logger.info(f"Démarrage du serveur sur {HOST}:{PORT}")
logger.info(f"Mode reload: {RELOAD}")
logger.info("="*60)
uvicorn.run(
"main:app",
host=HOST,
port=PORT,
reload=RELOAD,
log_level="info"
)