S01Nour
style: remove redundant comment from kpa_model_manager import.
25481b5
raw
history blame
10.7 kB
import sys
from pathlib import Path
import logging
from contextlib import asynccontextmanager
import atexit
import shutil
# --- Config ---
from config import (
API_TITLE, API_DESCRIPTION, API_VERSION,
HUGGINGFACE_API_KEY, HUGGINGFACE_STANCE_MODEL_ID, HUGGINGFACE_LABEL_MODEL_ID,
HOST, PORT, RELOAD,
CORS_ORIGINS, CORS_METHODS, CORS_HEADERS, CORS_CREDENTIALS,
PRELOAD_MODELS_ON_STARTUP, LOAD_STANCE_MODEL, LOAD_KPA_MODEL,
GROQ_API_KEY, GROQ_STT_MODEL, GROQ_TTS_MODEL, GROQ_CHAT_MODEL
)
# --- Startup cleanup of temporary audio files ---
def cleanup_temp_files():
    """Delete the leftover ``temp_audio`` directory (relative to the CWD), if any.

    Intended to run once at startup. A failure to remove the directory is
    logged as a warning instead of being raised.
    """
    audio_dir = Path("temp_audio")
    if not audio_dir.exists():
        # Nothing left over from a previous run.
        return
    try:
        shutil.rmtree(audio_dir)
        logger.info("✓ Fichiers temporaires audio nettoyés")
    except Exception as err:
        logger.warning(f"⚠ Impossible de nettoyer le répertoire temporaire: {err}")


# Purge once, immediately, when this module is imported (application startup).
cleanup_temp_files()
# Register the final cleanup to run when the interpreter exits.
@atexit.register
def cleanup_on_exit():
    """Remove the ``temp_audio`` directory (relative to the CWD) at shutdown.

    Registered with :mod:`atexit`; it is also called explicitly from the
    FastAPI lifespan shutdown, so it must tolerate the directory being
    already gone. Failures are logged as warnings, never raised.
    """
    temp_dir = Path("temp_audio")
    if not temp_dir.exists():
        return
    try:
        shutil.rmtree(temp_dir)
    except FileNotFoundError:
        # Removed by a concurrent/earlier cleanup between exists() and rmtree().
        return
    except Exception as exc:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit are
        # no longer swallowed, and the actual cause is now logged.
        logger.warning("Échec du nettoyage final: %s", exc)
    else:
        logger.info("Nettoyage final des fichiers temporaires")
# --- Service singletons ---
# Each model manager is imported on its own so that a missing or broken
# module does not also disable the other one. A manager left as ``None`` is
# reported as "not loaded" by the /health endpoint.
stance_model_manager = None
kpa_model_manager = None
try:
    from services.stance_model_manager import stance_model_manager
    logger.info("✓ Gestionnaire de modèle stance importé")
except ImportError as e:
    logger.warning(f"⚠ Impossible d'importer stance_model_manager: {e}")
try:
    from services.label_model_manager import kpa_model_manager
    logger.info("✓ Gestionnaire de modèle KPA importé")
except ImportError as e:
    logger.warning(f"⚠ Impossible d'importer kpa_model_manager: {e}")
# --- MCP availability probe ---
# MCP support is optional: the flag becomes True only when the service
# module imports cleanly.
MCP_ENABLED = False
try:
    from services.mcp_service import init_mcp_server
except ImportError as exc:
    logger.warning(f"⚠ MCP non disponible: {exc}")
else:
    MCP_ENABLED = True
    logger.info("✓ Modules MCP détectés")
# --- Lifespan / startup API ---
def _load_model(manager, model_id, label, impact):
    """Best-effort load of one Hugging Face model through its manager.

    Args:
        manager: Model-manager singleton exposing
            ``load_model(model_id, api_key)``; ``None`` when it could not be
            imported.
        model_id: Hugging Face model identifier; a falsy value means the
            model is not configured.
        label: Model name used in log messages ("stance", "KPA", ...).
        impact: Message logged when the model ends up unavailable.

    Returns:
        True when the load succeeded (or was already done), False otherwise.
        Errors are logged, never raised, so one broken model cannot abort
        application startup.
    """
    if manager is None or not model_id:
        logger.warning(f"⚠ {label} model not loaded (manager or model id missing)")
        logger.warning(impact)
        return False
    # Skip managers that already report a loaded model so startup does not
    # load the same weights twice. NOTE(review): relies on the manager
    # exposing ``model_loaded`` (as /health assumes) — confirm in services.
    if getattr(manager, "model_loaded", False):
        return True
    try:
        logger.info(f"Loading {label} model from Hugging Face: {model_id}")
        manager.load_model(model_id, HUGGINGFACE_API_KEY)
    except Exception as e:
        logger.error(f"✗ Failed to load {label} model: {e}")
        logger.error(impact)
        return False
    logger.info(f"✓ {label} model loaded")
    return True


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load models, check API keys, start MCP, then clean up.

    Everything before ``yield`` runs once at startup; the ``finally`` clause
    runs at shutdown even if the application exits with an error.
    """
    logger.info("="*60)
    logger.info("🚀 DÉMARRAGE API - Chargement des modèles et vérification des APIs...")
    logger.info("="*60)

    stance_impact = "⚠️ Stance detection endpoints will not work!"
    kpa_impact = "⚠️ KPA/Label prediction endpoints will not work!"

    # Core models: always attempted once at startup.
    _load_model(stance_model_manager, HUGGINGFACE_STANCE_MODEL_ID, "stance", stance_impact)
    _load_model(kpa_model_manager, HUGGINGFACE_LABEL_MODEL_ID, "KPA", kpa_impact)
    # The generation manager and its model id are not imported in this
    # module's visible imports; look them up defensively instead of failing
    # with a NameError that gets logged as a model-load failure.
    _load_model(
        globals().get("generate_model_manager"),
        globals().get("HUGGINGFACE_GENERATE_MODEL_ID"),
        "Generation",
        "⚠️ Generation endpoints will not work!",
    )

    logger.info("✓ API startup complete")
    logger.info("https://nlp-debater-project-fastapi-backend-models.hf.space/docs")

    # Check the API keys.
    if not GROQ_API_KEY:
        logger.warning("⚠ GROQ_API_KEY non configurée. Fonctions STT/TTS désactivées.")
    else:
        logger.info("✓ GROQ_API_KEY configurée")

    if not HUGGINGFACE_API_KEY:
        logger.warning("⚠ HUGGINGFACE_API_KEY non configurée. Modèles locaux désactivés.")
    else:
        logger.info("✓ HUGGINGFACE_API_KEY configurée")

    # Optional flag-driven preload pass. Models already loaded above are
    # skipped by _load_model, so this only retries loads that failed instead
    # of loading the same weights a second time.
    if PRELOAD_MODELS_ON_STARTUP:
        if LOAD_STANCE_MODEL and stance_model_manager and HUGGINGFACE_STANCE_MODEL_ID:
            _load_model(stance_model_manager, HUGGINGFACE_STANCE_MODEL_ID, "stance", stance_impact)
        if LOAD_KPA_MODEL and kpa_model_manager and HUGGINGFACE_LABEL_MODEL_ID:
            _load_model(kpa_model_manager, HUGGINGFACE_LABEL_MODEL_ID, "KPA", kpa_impact)

    # Initialise the MCP server when its module was importable.
    if MCP_ENABLED:
        try:
            init_mcp_server(app)
            logger.info("✓ Serveur MCP initialisé")
        except Exception as e:
            logger.error(f"✗ Échec initialisation MCP: {e}")

    logger.info("="*60)
    logger.info("✓ Démarrage terminé. API prête à recevoir des requêtes.")
    logger.info(f" STT Model: {GROQ_STT_MODEL}")
    logger.info(f" TTS Model: {GROQ_TTS_MODEL}")
    logger.info(f" Chat Model: {GROQ_CHAT_MODEL}")
    logger.info(f" MCP: {'Activé' if MCP_ENABLED else 'Désactivé'}")
    logger.info("="*60)

    try:
        yield
    finally:
        # Shutdown: runs even when the app exits with an error.
        logger.info("🛑 Arrêt de l'API...")
        cleanup_on_exit()
# --- FastAPI application ---
# Metadata comes from config; interactive docs are served at /docs and /redoc.
app = FastAPI(
    lifespan=lifespan,
    title=API_TITLE,
    version=API_VERSION,
    description=API_DESCRIPTION,
    openapi_url="/openapi.json",
    docs_url="/docs",
    redoc_url="/redoc",
)

# --- CORS ---
# Origins, methods, headers and credential policy are all driven by config.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_methods=CORS_METHODS,
    allow_headers=CORS_HEADERS,
    allow_credentials=CORS_CREDENTIALS,
)
# --- Routes ---
# Speech-to-text router (optional: failures are logged, not raised).
try:
    from routes.stt_routes import router as stt_router
    app.include_router(stt_router, prefix="/api/v1/stt", tags=["Speech To Text"])
    logger.info("✓ Route STT chargée (Groq Whisper)")
except Exception as err:
    if isinstance(err, ImportError):
        logger.warning(f"⚠ Route STT non trouvée: {err}")
    else:
        logger.warning(f"⚠ Échec chargement route STT: {err}")
# Text-to-speech router (optional: failures are logged, not raised).
try:
    from routes.tts_routes import router as tts_router
    app.include_router(tts_router, prefix="/api/v1/tts", tags=["Text To Speech"])
    logger.info("✓ Route TTS chargée (Groq PlayAI TTS)")
except Exception as err:
    if isinstance(err, ImportError):
        logger.warning(f"⚠ Route TTS non trouvée: {err}")
    else:
        logger.warning(f"⚠ Échec chargement route TTS: {err}")
# Voice chat router; mounted without a prefix (optional: failures are logged).
try:
    from routes.voice_chat_routes import router as voice_chat_router
    app.include_router(voice_chat_router, tags=["Voice Chat"])
    logger.info("✓ Route Voice Chat chargée")
except Exception as err:
    if isinstance(err, ImportError):
        logger.warning(f"⚠ Route Voice Chat non trouvée: {err}")
    else:
        logger.warning(f"⚠ Échec chargement route Voice Chat: {err}")
# Main API routes (KPA, Stance, etc.), with a per-router fallback when the
# aggregated ``api_router`` cannot be imported.
try:
    from routes import api_router
    app.include_router(api_router, prefix="/api/v1")
    logger.info("✓ Routes API principales chargées")
except ImportError as e:
    logger.warning(f"⚠ Routes API principales non trouvées: {e}")
    # Fallback: mount the KPA and Stance routers one by one. Each is isolated
    # and catches every exception, because an error raised inside this
    # ``except ImportError`` handler would NOT be caught by the sibling
    # ``except Exception`` below and would crash application startup.
    try:
        from routes.label import router as kpa_router
        app.include_router(kpa_router, prefix="/api/v1/kpa", tags=["KPA"])
        logger.info("✓ Route KPA chargée en fallback")
    except Exception as fallback_err:
        logger.warning(f"⚠ Fallback pour KPA échoué: {fallback_err}")
    try:
        from routes.stance import router as stance_router
        app.include_router(stance_router, prefix="/api/v1/stance", tags=["Stance Detection"])
        logger.info("✓ Route Stance chargée en fallback")
    except Exception as fallback_err:
        logger.warning(f"⚠ Fallback pour Stance échoué: {fallback_err}")
except Exception as e:
    logger.warning(f"⚠ Échec chargement routes API principales: {e}")
# MCP routes: FastAPI routes so the MCP endpoints show up in Swagger UI. The
# MCP server itself is initialised in the lifespan hook.
if not MCP_ENABLED:
    logger.warning("⚠ MCP désactivé")
else:
    try:
        from routes.mcp_routes import router as mcp_router
        # No prefix here: the router already carries its own.
        app.include_router(mcp_router)
        logger.info("✓ Routes MCP FastAPI chargées (visibles dans Swagger)")
    except Exception as err:
        if isinstance(err, ImportError):
            logger.warning(f"⚠ Routes MCP FastAPI non trouvées: {err}")
        else:
            logger.warning(f"⚠ Échec chargement routes MCP FastAPI: {err}")
    logger.info("✓ MCP monté via lifespan (endpoints auto-gérés)")
# --- Basic routes ---
def _model_status(manager):
    """Return "loaded" when *manager* exists and reports ``model_loaded``."""
    if manager and getattr(manager, "model_loaded", False):
        return "loaded"
    return "not loaded"


@app.get("/health", tags=["Health"])
async def health():
    """Report service status, enabled features and model availability."""
    groq_ready = bool(GROQ_API_KEY)
    features = {
        "stt": GROQ_STT_MODEL if groq_ready else "disabled",
        "tts": GROQ_TTS_MODEL if groq_ready else "disabled",
        "chat": GROQ_CHAT_MODEL if groq_ready else "disabled",
        "stance_model": _model_status(stance_model_manager),
        "kpa_model": _model_status(kpa_model_manager),
        "mcp": "enabled" if MCP_ENABLED else "disabled",
    }
    return {
        "status": "healthy",
        "service": "NLP Debater + Groq Voice",
        "version": API_VERSION,
        "features": features,
        "endpoints": {"mcp": "/api/v1/mcp" if MCP_ENABLED else "disabled"},
    }


@app.get("/", tags=["Root"])
async def root():
    """Send visitors of the bare root URL to the Swagger UI."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
# --- Error handlers ---
@app.exception_handler(404)
async def not_found_handler(request, exc):
    """Return a JSON 404 body listing the main endpoints of this API.

    Starlette sends whatever an exception handler returns as an ASGI
    response, so the handler must return a ``Response``. The previous plain
    ``dict`` raised ``TypeError: 'dict' object is not callable`` when
    Starlette tried to send it. A ``JSONResponse`` with status 404 keeps the
    same body and the correct status code.
    """
    # Local import: fastapi is already a dependency of this module.
    from fastapi.responses import JSONResponse

    endpoints = {
        "GET /": "Redirige vers /docs (Swagger UI)",
        "GET /health": "Health check",
        "POST /api/v1/stt/": "Speech to text",
        "POST /api/v1/tts/": "Text to speech",
        "POST /voice-chat/voice": "Voice chat"
    }
    if MCP_ENABLED:
        endpoints.update({
            "GET /api/v1/mcp/health": "Health check MCP",
            "GET /api/v1/mcp/tools": "Liste outils MCP",
            "POST /api/v1/mcp/tools/call": "Appel d'outil MCP"
        })
    return JSONResponse(
        status_code=404,
        content={
            "error": "Not Found",
            "message": f"URL {request.url} non trouvée",
            "available_endpoints": endpoints,
        },
    )
# --- Entry point (python main.py) ---
if __name__ == "__main__":
    rule = "=" * 60
    startup_banner = (
        rule,
        f"Démarrage du serveur sur {HOST}:{PORT}",
        f"Mode reload: {RELOAD}",
        rule,
    )
    for banner_line in startup_banner:
        logger.info(banner_line)
    uvicorn.run("main:app", host=HOST, port=PORT, reload=RELOAD, log_level="info")