# Yassine Mhirsi
# Merge branch 'main' of https://huggingface.co/spaces/NLP-Debater-Project/FastAPI-Backend-Models
# 7f4de42
| import sys | |
| from pathlib import Path | |
| import logging | |
| from contextlib import asynccontextmanager | |
| import atexit | |
| import shutil | |
| import uvicorn | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import RedirectResponse | |
# Logging: INFO-level records formatted as "time - logger - level - message",
# plus the module-level logger used throughout this file.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
| # --- Config --- | |
| from config import ( | |
| API_TITLE, API_DESCRIPTION, API_VERSION, | |
| HUGGINGFACE_API_KEY, HUGGINGFACE_STANCE_MODEL_ID, HUGGINGFACE_LABEL_MODEL_ID, | |
| HUGGINGFACE_GENERATE_MODEL_ID, | |
| HOST, PORT, RELOAD, | |
| CORS_ORIGINS, CORS_METHODS, CORS_HEADERS, CORS_CREDENTIALS, | |
| PRELOAD_MODELS_ON_STARTUP, LOAD_STANCE_MODEL, LOAD_KPA_MODEL, | |
| GROQ_API_KEY, GROQ_STT_MODEL, GROQ_TTS_MODEL, GROQ_CHAT_MODEL | |
| ) | |
# --- Startup cleanup ---
def cleanup_temp_files():
    """Delete the ``temp_audio`` directory (relative to the working directory).

    Best-effort: if the directory is absent nothing happens; any exception
    raised while removing it is logged as a warning and not propagated.
    """
    audio_tmp = Path("temp_audio")
    if not audio_tmp.exists():
        return
    try:
        shutil.rmtree(audio_tmp)
        logger.info("✓ Fichiers temporaires audio nettoyés")
    except Exception as err:
        logger.warning(f"⚠ Impossible de nettoyer le répertoire temporaire: {err}")


# Run the cleanup once, when this module is imported (application start-up).
cleanup_temp_files()
# --- Shutdown cleanup ---
def cleanup_on_exit():
    """Remove the ``temp_audio`` directory at shutdown (called after the lifespan ``yield``).

    Best-effort: if the directory is absent nothing happens. An ordinary
    failure (any ``Exception``) is logged as a warning together with its cause.
    ``KeyboardInterrupt``/``SystemExit`` are deliberately not caught so that
    an interrupt during shutdown is not silently swallowed.
    """
    temp_dir = Path("temp_audio")
    if not temp_dir.exists():
        return
    try:
        shutil.rmtree(temp_dir)
    except Exception as e:
        logger.warning(f"Échec du nettoyage final: {e}")
    else:
        logger.info("Nettoyage final des fichiers temporaires")
# --- Service singletons (optional) ---
# Each service module is imported on its own, so one missing or broken module
# no longer leaves every later service as None. A service whose import fails
# stays None; code using these names must tolerate that.
import importlib


def _import_optional(module_name, attr_name):
    """Return ``attr_name`` from ``module_name``, or None if it cannot be imported.

    Equivalent to ``from module_name import attr_name`` except that an
    ``ImportError`` (or a missing attribute) is logged as a warning instead of
    being raised, so the API can still start with the remaining services.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError as e:
        logger.warning(f"⚠ Impossible d'importer les gestionnaires de modèles: {e}")
        return None
    service = getattr(module, attr_name, None)
    if service is None:
        logger.warning(
            "⚠ Impossible d'importer les gestionnaires de modèles: "
            f"cannot import name '{attr_name}' from '{module_name}'"
        )
    return service


stance_model_manager = _import_optional("services.stance_model_manager", "stance_model_manager")
kpa_model_manager = _import_optional("services.label_model_manager", "kpa_model_manager")
generate_model_manager = _import_optional("services.generate_model_manager", "generate_model_manager")
topic_similarity_service = _import_optional("services.topic_similarity_service", "topic_similarity_service")

if all(
    service is not None
    for service in (
        stance_model_manager,
        kpa_model_manager,
        generate_model_manager,
        topic_similarity_service,
    )
):
    logger.info("✓ Gestionnaires de modèles importés")
# --- MCP availability check ---
# MCP support is optional: if services.mcp_service cannot be imported the flag
# stays False and the MCP-specific steps guarded by MCP_ENABLED are skipped.
MCP_ENABLED = False
try:
    from services.mcp_service import init_mcp_server
except ImportError as err:
    logger.warning(f"⚠ MCP non disponible: {err}")
else:
    MCP_ENABLED = True
    logger.info("✓ Modules MCP détectés")
# --- Lifespan / startup API ---
def _load_model(label, manager, model_id, impact):
    """Load one Hugging Face model into its manager, logging instead of raising.

    Args:
        label: Name used in log messages ("stance", "KPA", "Generation").
        manager: Model-manager singleton exposing ``load_model(model_id, api_key)``;
            None when its service module could not be imported.
        model_id: Hugging Face model identifier passed to ``load_model``.
        impact: Message logged when the model is unavailable, naming the
            endpoints that will not work.

    Returns:
        True when the load succeeded or the manager already reports
        ``model_loaded``; False when the manager is missing or loading failed.
    """
    if manager is None:
        logger.error(f"✗ Failed to load {label} model: manager not available (import failed)")
        logger.error(impact)
        return False
    if getattr(manager, "model_loaded", False):
        # Already loaded by an earlier call: skip the redundant (and costly) reload.
        return True
    try:
        logger.info(f"Loading {label} model from Hugging Face: {model_id}")
        manager.load_model(model_id, HUGGINGFACE_API_KEY)
    except Exception as e:
        logger.error(f"✗ Failed to load {label} model: {str(e)}")
        logger.error(impact)
        return False
    logger.info(f"✓ {label} model loaded")
    return True


_STANCE_IMPACT = "⚠️ Stance detection endpoints will not work!"
_KPA_IMPACT = "⚠️ KPA/Label prediction endpoints will not work!"
_GENERATION_IMPACT = "⚠️ Generation endpoints will not work!"


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Do start-up work before the app serves requests, and cleanup at shutdown.

    Start-up loads the Hugging Face models and initialises the Groq-backed
    topic extraction service. It also initialises the MCP server when it is
    available, and logs a summary. Every step is best-effort: failures are
    logged so the API still starts with whatever features work. Shutdown
    cleanup runs in ``finally``, so it also happens when the server stops on
    an error.
    """
    logger.info("=" * 60)
    logger.info("🚀 DÉMARRAGE API - Chargement des modèles et vérification des APIs...")
    logger.info("=" * 60)

    # Load every model once (stance, KPA/label, generation).
    _load_model("stance", stance_model_manager, HUGGINGFACE_STANCE_MODEL_ID, _STANCE_IMPACT)
    _load_model("KPA", kpa_model_manager, HUGGINGFACE_LABEL_MODEL_ID, _KPA_IMPACT)
    _load_model("Generation", generate_model_manager, HUGGINGFACE_GENERATE_MODEL_ID, _GENERATION_IMPACT)

    # Initialize Topic Extraction service (uses Groq LLM)
    if topic_similarity_service and GROQ_API_KEY:
        try:
            logger.info("Initializing Topic Extraction service (Groq LLM)...")
            topic_similarity_service.initialize()
            logger.info("✓ Topic Extraction service initialized")
        except Exception as e:
            logger.error(f"✗ Failed to initialize Topic Extraction service: {str(e)}")
            logger.error("⚠️ Topic extraction endpoints will not work!")
    elif not GROQ_API_KEY:
        logger.warning("⚠ GROQ_API_KEY not configured. Topic extraction service will not be available.")
    else:
        # Key is set but the service module failed to import.
        logger.warning("⚠ Topic extraction service not imported. Topic extraction endpoints will not work!")

    logger.info("✓ API startup complete")
    logger.info("https://nlp-debater-project-fastapi-backend-models.hf.space/docs")

    # Check API keys
    if not GROQ_API_KEY:
        logger.warning("⚠ GROQ_API_KEY non configurée. Fonctions STT/TTS désactivées.")
    else:
        logger.info("✓ GROQ_API_KEY configurée")
    if not HUGGINGFACE_API_KEY:
        logger.warning("⚠ HUGGINGFACE_API_KEY non configurée. Modèles locaux désactivés.")
    else:
        logger.info("✓ HUGGINGFACE_API_KEY configurée")

    # Config-driven pre-load pass. It used to reload models already loaded
    # above; _load_model now skips managers reporting `model_loaded`, so this
    # pass only retries loads that failed. NOTE(review): this assumes managers
    # expose `model_loaded` (as /health expects). Otherwise it reloads, as before.
    if PRELOAD_MODELS_ON_STARTUP:
        if LOAD_STANCE_MODEL and stance_model_manager and HUGGINGFACE_STANCE_MODEL_ID:
            _load_model("stance", stance_model_manager, HUGGINGFACE_STANCE_MODEL_ID, _STANCE_IMPACT)
        if LOAD_KPA_MODEL and kpa_model_manager and HUGGINGFACE_LABEL_MODEL_ID:
            _load_model("KPA", kpa_model_manager, HUGGINGFACE_LABEL_MODEL_ID, _KPA_IMPACT)

    # Initialise MCP if available
    if MCP_ENABLED:
        try:
            init_mcp_server(app)
            logger.info("✓ Serveur MCP initialisé")
        except Exception as e:
            logger.error(f"✗ Échec initialisation MCP: {e}")

    # getattr: the service may be None (failed import) or lack the attribute.
    topic_ready = bool(getattr(topic_similarity_service, "initialized", False))
    logger.info("=" * 60)
    logger.info("✓ Démarrage terminé. API prête à recevoir des requêtes.")
    logger.info(f" STT Model: {GROQ_STT_MODEL}")
    logger.info(f" TTS Model: {GROQ_TTS_MODEL}")
    logger.info(f" Chat Model: {GROQ_CHAT_MODEL}")
    logger.info(f" Topic Extraction: {'Initialized' if topic_ready else 'Not initialized'}")
    logger.info(f" Voice Chat: {'Available' if GROQ_API_KEY else 'Disabled (no GROQ_API_KEY)'}")
    logger.info(f" MCP: {'Activé' if MCP_ENABLED else 'Désactivé'}")
    if MCP_ENABLED:
        logger.info(" - Tools: detect_stance, match_keypoint_argument, transcribe_audio, generate_speech, generate_argument, extract_topic, voice_chat, health_check")
    logger.info("=" * 60)

    try:
        yield
    finally:
        logger.info("🛑 Arrêt de l'API...")
        # Final cleanup of temporary audio files
        cleanup_on_exit()
# --- FastAPI app ---
# Swagger UI is served at /docs, ReDoc at /redoc and the OpenAPI schema at
# /openapi.json; start-up/shutdown work is delegated to `lifespan`.
app = FastAPI(
    title=API_TITLE,
    version=API_VERSION,
    description=API_DESCRIPTION,
    openapi_url="/openapi.json",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)

# --- CORS ---
# Origins, methods, headers and credential policy all come from config.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=CORS_CREDENTIALS,
    allow_headers=CORS_HEADERS,
    allow_methods=CORS_METHODS,
    allow_origins=CORS_ORIGINS,
)
# --- Routes ---
# The voice routers are optional: a missing module, or a failure while
# mounting one, is logged as a warning and the rest of the API keeps loading.

# Speech To Text (Groq Whisper)
try:
    from routes.stt_routes import router as stt_router
    app.include_router(stt_router, prefix="/api/v1/stt", tags=["Speech To Text"])
except ImportError as err:
    logger.warning(f"⚠ Route STT non trouvée: {err}")
except Exception as err:
    logger.warning(f"⚠ Échec chargement route STT: {err}")
else:
    logger.info("✓ Route STT chargée (Groq Whisper)")

# Text To Speech (Groq PlayAI TTS)
try:
    from routes.tts_routes import router as tts_router
    app.include_router(tts_router, prefix="/api/v1/tts", tags=["Text To Speech"])
except ImportError as err:
    logger.warning(f"⚠ Route TTS non trouvée: {err}")
except Exception as err:
    logger.warning(f"⚠ Échec chargement route TTS: {err}")
else:
    logger.info("✓ Route TTS chargée (Groq PlayAI TTS)")

# Voice Chat (mounted without a prefix)
try:
    from routes.voice_chat_routes import router as voice_chat_router
    app.include_router(voice_chat_router, tags=["Voice Chat"])
except ImportError as err:
    logger.warning(f"⚠ Route Voice Chat non trouvée: {err}")
except Exception as err:
    logger.warning(f"⚠ Échec chargement route Voice Chat: {err}")
else:
    logger.info("✓ Route Voice Chat chargée")
# Main API Routes (KPA, Stance, etc.)
# Preferred path: the aggregated `routes.api_router`, mounted under /api/v1.
# If it cannot be imported, the KPA and Stance routers are mounted on their own
# as a fallback. Each fallback is attempted independently, so a failing KPA
# import no longer prevents the Stance router from being mounted.
try:
    from routes import api_router
    app.include_router(api_router, prefix="/api/v1")
except ImportError as e:
    # logger.exception logs at ERROR level and appends the current traceback.
    logger.exception(f"⚠ Routes API principales non trouvées: {e}")
    # Fallback
    try:
        from routes.label import router as kpa_router
        app.include_router(kpa_router, prefix="/api/v1/kpa", tags=["KPA"])
        logger.info("✓ Route KPA chargée en fallback")
    except ImportError as fallback_error:
        logger.warning(f"⚠ Fallback pour KPA échoué: {fallback_error}")
    try:
        from routes.stance import router as stance_router
        app.include_router(stance_router, prefix="/api/v1/stance", tags=["Stance Detection"])
        logger.info("✓ Route Stance chargée en fallback")
    except ImportError as fallback_error:
        logger.warning(f"⚠ Fallback pour Stance échoué: {fallback_error}")
except Exception as e:
    logger.exception(f"⚠ Échec chargement routes API principales: {e}")
else:
    logger.info("✓ Routes API principales chargées")
    # Log all routes for debugging. This runs outside the try, so a problem here
    # is never misreported as a route-loading failure.
    for route in api_router.routes:
        if hasattr(route, "path"):
            logger.info(f" - Route: {route.path} (methods: {getattr(route, 'methods', 'N/A')})")
# MCP Routes: FastAPI routes that make the MCP tools visible in Swagger UI.
# The MCP server itself is initialised by the lifespan hook (init_mcp_server).
if not MCP_ENABLED:
    logger.warning("⚠ MCP désactivé")
else:
    try:
        from routes.mcp_routes import router as mcp_router
        app.include_router(mcp_router)  # No prefix: the router already carries its paths
    except ImportError as err:
        logger.warning(f"⚠ Routes MCP FastAPI non trouvées: {err}")
    except Exception as err:
        logger.warning(f"⚠ Échec chargement routes MCP FastAPI: {err}")
    else:
        logger.info("✓ Routes MCP FastAPI chargées (visibles dans Swagger)")
    logger.info("✓ MCP monté via lifespan (endpoints auto-gérés)")
# --- Basic routes ---
@app.get("/health")
async def health():
    """Return a health report: service identity, feature status and key endpoints.

    Groq-backed features (STT, TTS, chat, voice chat) are reported as
    "disabled" when GROQ_API_KEY is not set. Model and service status is read
    with ``getattr(..., False)``. A service left as None after a failed import,
    or one without the status attribute, therefore reads as "not loaded" /
    "not initialized" instead of raising.
    """

    def is_set(service, attr):
        # None (failed import) and a missing attribute both count as "not set".
        return bool(getattr(service, attr, False))

    mcp_tools = {
        "extract_topic": "/api/v1/mcp/tools/extract-topic",
        "voice_chat": "/api/v1/mcp/tools/voice-chat",
    }
    return {
        "status": "healthy",
        "service": "NLP Debater + Groq Voice",
        "version": API_VERSION,
        "features": {
            "stt": GROQ_STT_MODEL if GROQ_API_KEY else "disabled",
            "tts": GROQ_TTS_MODEL if GROQ_API_KEY else "disabled",
            "chat": GROQ_CHAT_MODEL if GROQ_API_KEY else "disabled",
            "topic_extraction": "initialized" if is_set(topic_similarity_service, "initialized") else "not initialized",
            "voice_chat": "available" if GROQ_API_KEY else "disabled",
            "stance_model": "loaded" if is_set(stance_model_manager, "model_loaded") else "not loaded",
            "kpa_model": "loaded" if is_set(kpa_model_manager, "model_loaded") else "not loaded",
            "generate_model": "loaded" if is_set(generate_model_manager, "model_loaded") else "not loaded",
            "mcp": "enabled" if MCP_ENABLED else "disabled",
        },
        "endpoints": {
            "mcp": "/api/v1/mcp" if MCP_ENABLED else "disabled",
            "topic_extraction": "/api/v1/topic/extract",
            "voice_chat": "/voice-chat/voice or /voice-chat/text",
            "mcp_tools": mcp_tools if MCP_ENABLED else "disabled",
        },
    }
@app.get("/")
async def root():
    """Redirect the bare root URL to the Swagger UI at /docs."""
    return RedirectResponse(url="/docs")
# --- Error handlers ---
@app.exception_handler(404)
async def not_found_handler(request, exc):
    """Answer a 404 with a JSON body listing the endpoints this API exposes.

    Args:
        request: The incoming request; its URL is echoed in the message.
        exc: The exception that produced the 404 (not inspected).

    Returns:
        A ``JSONResponse`` with status 404. Starlette requires exception
        handlers to return a Response object; unlike route return values, a
        plain dict is not converted for them.

    NOTE(review): registered for every 404, so a route raising
    ``HTTPException(404, detail=...)`` also gets this body, not its own detail.
    """
    from fastapi.responses import JSONResponse

    endpoints = {
        "GET /": "Redirige vers /docs (Swagger UI)",
        "GET /health": "Health check",
        "POST /api/v1/stt/": "Speech to text",
        "POST /api/v1/tts/": "Text to speech",
        "POST /voice-chat/voice": "Voice chat (audio input)",
        "POST /voice-chat/text": "Voice chat (text input)",
        "POST /api/v1/topic/extract": "Extract topic from text",
    }
    if MCP_ENABLED:
        endpoints.update({
            "GET /api/v1/mcp/health": "Health check MCP",
            "GET /api/v1/mcp/tools": "Liste outils MCP",
            "POST /api/v1/mcp/tools/call": "Appel d'outil MCP",
            "POST /api/v1/mcp/tools/extract-topic": "Extract topic (MCP tool)",
            "POST /api/v1/mcp/tools/voice-chat": "Voice chat (MCP tool)",
            "POST /api/v1/mcp/tools/detect-stance": "Detect stance (MCP tool)",
            "POST /api/v1/mcp/tools/match-keypoint": "Match keypoint (MCP tool)",
            "POST /api/v1/mcp/tools/transcribe-audio": "Transcribe audio (MCP tool)",
            "POST /api/v1/mcp/tools/generate-speech": "Generate speech (MCP tool)",
            "POST /api/v1/mcp/tools/generate-argument": "Generate argument (MCP tool)",
        })
    return JSONResponse(
        status_code=404,
        content={
            "error": "Not Found",
            "message": f"URL {request.url} non trouvée",
            "available_endpoints": endpoints,
        },
    )
# --- Run server ---
if __name__ == "__main__":
    separator = "=" * 60
    for line in (
        separator,
        f"Démarrage du serveur sur {HOST}:{PORT}",
        f"Mode reload: {RELOAD}",
        separator,
    ):
        logger.info(line)
    # "main:app" is an import string ("module:attribute"), not the app object.
    # NOTE(review): uvicorn needs this form for reload, and it assumes this file
    # is importable as `main` — update it if the file is renamed.
    uvicorn.run("main:app", host=HOST, port=PORT, reload=RELOAD, log_level="info")