malek-messaoudii
feat: Refactor stance detection, keypoint matching, and argument generation to utilize dedicated model managers for improved reliability and response handling
76c719b
| """Routes pour exposer MCP via FastAPI pour Swagger UI""" | |
| from fastapi import APIRouter, HTTPException, UploadFile, File | |
| from fastapi.responses import FileResponse | |
| from typing import Dict, Any, Optional | |
| from pydantic import BaseModel, Field | |
| import logging | |
| import json | |
| import tempfile | |
| import os | |
| from pathlib import Path | |
| from services.mcp_service import mcp_server | |
| from services.stance_model_manager import stance_model_manager | |
| from services.label_model_manager import kpa_model_manager | |
| from services.generate_model_manager import generate_model_manager | |
| from models.mcp_models import ( | |
| ToolListResponse, | |
| ToolInfo, | |
| ToolCallRequest, | |
| ToolCallResponse, | |
| DetectStanceResponse, | |
| MatchKeypointResponse, | |
| TranscribeAudioResponse, | |
| GenerateSpeechResponse | |
| ) | |
| from models.generate import GenerateRequest, GenerateResponse | |
| from datetime import datetime | |
| router = APIRouter(prefix="/api/v1/mcp", tags=["MCP"]) | |
| logger = logging.getLogger(__name__) | |
| # ===== Models pour chaque outil MCP ===== | |
class DetectStanceRequest(BaseModel):
    """Request body for detecting the stance of an argument on a topic."""

    # Debate topic the argument is evaluated against (required).
    topic: str = Field(
        ...,
        description="Le sujet du débat",
    )
    # Argument text to analyse (required).
    argument: str = Field(
        ...,
        description="L'argument à analyser",
    )

    class Config:
        # Example payload attached to the model's JSON schema
        # (presumably what Swagger UI pre-fills — confirm against the UI).
        json_schema_extra = {
            "example": {
                "topic": "Climate change is real",
                "argument": "Rising global temperatures prove it",
            },
        }
class MatchKeypointRequest(BaseModel):
    """Request body for matching an argument against a key point."""

    # Argument text to evaluate (required).
    argument: str = Field(
        ...,
        description="L'argument à évaluer",
    )
    # Reference key point the argument is compared with (required).
    key_point: str = Field(
        ...,
        description="Le keypoint de référence",
    )

    class Config:
        # Example payload attached to the model's JSON schema
        # (presumably what Swagger UI pre-fills — confirm against the UI).
        json_schema_extra = {
            "example": {
                "argument": "Renewable energy reduces emissions",
                "key_point": "Environmental benefits",
            },
        }
class GenerateSpeechRequest(BaseModel):
    """Request body for text-to-speech generation."""

    # Text to synthesise (required).
    text: str = Field(
        ...,
        description="Texte à convertir en parole",
    )
    # Voice identifier forwarded to the speech tool.
    voice: str = Field(
        default="Aaliyah-PlayAI",
        description="Voix à utiliser",
    )
    # Requested audio container/extension, e.g. "wav" or "mp3".
    format: str = Field(
        default="wav",
        description="Format audio (wav, mp3, etc.)",
    )

    class Config:
        # Example payload attached to the model's JSON schema
        # (presumably what Swagger UI pre-fills — confirm against the UI).
        json_schema_extra = {
            "example": {
                "text": "Hello, this is a test",
                "voice": "Aaliyah-PlayAI",
                "format": "wav",
            },
        }
| # ===== Routes MCP ===== | |
async def mcp_health():
    """Report the MCP server as healthy and list the exposed tool names.

    Returns:
        A dict with ``status`` (always ``"healthy"``), ``tools`` (the list of
        tool names) and ``tool_count`` (the length of that list).
    """
    # The tool names are hard-coded on purpose (considered more reliable than
    # querying the MCP server). Building these literals cannot fail, so no
    # exception handling is needed around them.
    tool_names = [
        "detect_stance",
        "match_keypoint_argument",
        "transcribe_audio",
        "generate_speech",
        "generate_argument",
        "health_check",
    ]
    return {
        "status": "healthy",
        "tools": tool_names,
        "tool_count": len(tool_names),
    }
async def list_mcp_tools():
    """List every available MCP tool together with its input schema.

    The catalogue is defined by hand in this function rather than fetched
    from the MCP server.

    Returns:
        ToolListResponse holding one ToolInfo per tool and the tool count.

    Raises:
        HTTPException: 500 if building the response fails.
    """

    def _string_prop(description, **extra):
        # JSON-schema fragment for a string property; extras (e.g. "default")
        # are appended after the description.
        return {"type": "string", "description": description, **extra}

    def _object_schema(properties, required):
        # JSON-schema wrapper shared by every tool's input.
        return {"type": "object", "properties": properties, "required": required}

    # (name, description, properties, required) for each tool, in display order.
    tool_specs = (
        (
            "detect_stance",
            "Détecte si un argument est PRO ou CON pour un topic donné",
            {
                "topic": _string_prop("Le sujet du débat"),
                "argument": _string_prop("L'argument à analyser"),
            },
            ["topic", "argument"],
        ),
        (
            "match_keypoint_argument",
            "Détermine si un argument correspond à un keypoint",
            {
                "argument": _string_prop("L'argument à évaluer"),
                "key_point": _string_prop("Le keypoint de référence"),
            },
            ["argument", "key_point"],
        ),
        (
            "transcribe_audio",
            "Convertit un fichier audio en texte",
            {
                "audio_path": _string_prop("Chemin vers le fichier audio"),
            },
            ["audio_path"],
        ),
        (
            "generate_speech",
            "Convertit du texte en fichier audio",
            {
                "text": _string_prop("Texte à convertir en parole"),
                "voice": _string_prop("Voix à utiliser", default="Aaliyah-PlayAI"),
                "format": _string_prop("Format audio", default="wav"),
            },
            ["text"],
        ),
        (
            "generate_argument",
            "Génère un argument de débat pour un topic et une position donnés",
            {
                "topic": _string_prop("Le sujet du débat"),
                "position": _string_prop("La position à prendre (positive ou negative)"),
            },
            ["topic", "position"],
        ),
        (
            "health_check",
            "Health check pour le serveur MCP",
            {},
            [],
        ),
    )

    try:
        tool_list = [
            ToolInfo(
                name=tool_name,
                description=summary,
                input_schema=_object_schema(properties, required),
            )
            for tool_name, summary, properties, required in tool_specs
        ]
        return ToolListResponse(tools=tool_list, count=len(tool_list))
    except Exception as e:
        logger.error(f"Error listing MCP tools: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def call_mcp_tool(request: ToolCallRequest):
    """
    Call an MCP tool by name with the given arguments.

    Failures are not raised: any exception is logged and reported through a
    ``ToolCallResponse`` with ``success=False`` and the error message.

    **Usage examples:**

    1. **detect_stance** - Detect the stance of an argument:
    ```json
    {
        "tool_name": "detect_stance",
        "arguments": {
            "topic": "Climate change is real",
            "argument": "Rising global temperatures prove it"
        }
    }
    ```

    2. **match_keypoint_argument** - Match an argument against a key point:
    ```json
    {
        "tool_name": "match_keypoint_argument",
        "arguments": {
            "argument": "Renewable energy reduces emissions",
            "key_point": "Environmental benefits"
        }
    }
    ```

    3. **generate_argument** - Generate an argument:
    ```json
    {
        "tool_name": "generate_argument",
        "arguments": {
            "topic": "Assisted suicide should be legal",
            "position": "positive"
        }
    }
    ```

    4. **transcribe_audio** - Transcribe an audio file:
    ```json
    {
        "tool_name": "transcribe_audio",
        "arguments": {
            "audio_path": "/path/to/audio.wav"
        }
    }
    ```

    5. **generate_speech** - Generate speech:
    ```json
    {
        "tool_name": "generate_speech",
        "arguments": {
            "text": "Hello, this is a test",
            "voice": "Aaliyah-PlayAI",
            "format": "wav"
        }
    }
    ```
    """
    try:
        result = await mcp_server.call_tool(request.tool_name, request.arguments)
        return ToolCallResponse(
            success=True,
            result=_extract_tool_payload(result),
            tool_name=request.tool_name,
        )
    except Exception as e:
        # Include the traceback, as the per-tool handlers in this module do.
        logger.error(f"Error calling MCP tool {request.tool_name}: {e}", exc_info=True)
        return ToolCallResponse(
            success=False,
            error=str(e),
            tool_name=request.tool_name,
        )


def _parse_block_text(text):
    """Decode a content block's text as JSON, falling back to ``{"text": text}``."""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return {"text": text}


def _extract_tool_payload(result):
    """Normalise the shapes ``mcp_server.call_tool`` may return into one payload.

    Handled shapes:
      * a dict whose ``"result"`` is a non-empty list: the first item's
        non-empty ``text`` is decoded; otherwise the dict is returned as is;
      * a non-empty list/tuple: the first item's non-empty ``text`` is
        decoded; otherwise that item is wrapped as ``{"result": item}``;
      * anything else (including empty sequences): ``{"result": result}``.
    """
    if isinstance(result, dict):
        blocks = result.get("result")
        if isinstance(blocks, list) and blocks:
            text = getattr(blocks[0], "text", None)
            if text:
                return _parse_block_text(text)
        return result

    if isinstance(result, (list, tuple)) and result:
        text = getattr(result[0], "text", None)
        if text:
            return _parse_block_text(text)
        return {"result": result[0]}

    return {"result": result}
| # ===== Routes individuelles pour chaque outil (pour Swagger) ===== | |
async def mcp_detect_stance(request: DetectStanceRequest):
    """Detect whether an argument is PRO or CON for a given topic.

    The stance model manager is called directly instead of going through the
    MCP server (considered more reliable).

    Args:
        request: Topic and argument to classify.

    Returns:
        DetectStanceResponse built from the model manager's prediction.

    Raises:
        HTTPException: 503 when the stance model is not loaded; 500 when the
            prediction lacks an expected key or any other error occurs.
    """
    # Keys copied from the prediction into the response, in lookup order.
    response_fields = (
        "predicted_stance",
        "confidence",
        "probability_con",
        "probability_pro",
    )
    try:
        if not stance_model_manager.model_loaded:
            raise HTTPException(status_code=503, detail="Stance model not loaded")

        prediction = stance_model_manager.predict(request.topic, request.argument)

        # A missing key raises KeyError here and is reported as a 500 below.
        response = DetectStanceResponse(
            **{field: prediction[field] for field in response_fields}
        )
        logger.info(f"Stance prediction: {response.predicted_stance} (conf={response.confidence:.4f})")
        return response
    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 503 above) untouched.
        raise
    except KeyError as e:
        logger.error(f"Missing key in detect_stance response: {e}")
        raise HTTPException(status_code=500, detail=f"Invalid response format: missing {e}")
    except Exception as e:
        logger.error(f"Error in detect_stance: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error executing tool detect_stance: {e}")
async def mcp_match_keypoint(request: MatchKeypointRequest):
    """Determine whether an argument matches a key point.

    The KPA model manager is called directly instead of going through the
    MCP server (considered more reliable).

    Args:
        request: Argument and key point to compare.

    Returns:
        MatchKeypointResponse built from the model manager's prediction.

    Raises:
        HTTPException: 503 when the KPA model is not loaded; 500 when the
            prediction lacks an expected key or any other error occurs.
    """
    # Keys copied from the prediction into the response, in lookup order.
    response_fields = ("prediction", "label", "confidence", "probabilities")
    try:
        if not kpa_model_manager.model_loaded:
            raise HTTPException(status_code=503, detail="KPA model not loaded")

        prediction = kpa_model_manager.predict(request.argument, request.key_point)

        # A missing key raises KeyError here and is reported as a 500 below.
        response = MatchKeypointResponse(
            **{field: prediction[field] for field in response_fields}
        )
        logger.info(f"Keypoint matching: {response.label} (conf={response.confidence:.4f})")
        return response
    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 503 above) untouched.
        raise
    except KeyError as e:
        logger.error(f"Missing key in match_keypoint response: {e}")
        raise HTTPException(status_code=500, detail=f"Invalid response format: missing {e}")
    except Exception as e:
        logger.error(f"Error in match_keypoint_argument: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error executing tool match_keypoint_argument: {e}")
async def mcp_transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to text via the MCP ``transcribe_audio`` tool.

    The upload is written to a temporary file whose path is passed to the
    tool; that temporary file is removed before the function exits, whether
    it succeeds or fails.

    Args:
        file: Uploaded file; its content type must start with ``audio/``.

    Returns:
        TranscribeAudioResponse carrying the transcribed text.

    Raises:
        HTTPException: 400 when the upload is not audio or is empty; 500 when
            the tool call fails or yields no text.
    """
    if not file.content_type or not file.content_type.startswith('audio/'):
        raise HTTPException(status_code=400, detail="File must be an audio file")

    # Read and validate the payload BEFORE creating the temporary file, so an
    # empty or unreadable upload never leaves a file behind and nothing has to
    # be unlinked while it is still open.
    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="Audio file is empty")

    temp_path = None
    try:
        # NOTE(review): the suffix is always ".wav" whatever the uploaded
        # format is — confirm the transcription tool does not rely on it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_path = temp_file.name
            temp_file.write(content)

        result = await mcp_server.call_tool("transcribe_audio", {
            "audio_path": temp_path
        })

        transcribed_text = _extract_transcription_text(result)
        if not transcribed_text:
            raise HTTPException(status_code=500, detail="Empty transcription result from MCP tool")

        response = TranscribeAudioResponse(text=transcribed_text)
        logger.info(f"Audio transcribed: {len(transcribed_text)} characters")
        return response
    except FileNotFoundError as e:
        logger.error(f"File not found in transcribe_audio: {e}")
        raise HTTPException(status_code=500, detail=f"Error executing tool transcribe_audio: {e}")
    except HTTPException:
        # Propagate deliberate HTTP errors untouched.
        raise
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error executing tool transcribe_audio: {e}")
    finally:
        # Always clean up the temporary file, including on early failures.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)


def _extract_transcription_text(result):
    """Pull the transcription out of the shapes the MCP tool call may return.

    Returns ``None`` when a dict result carries neither a usable content
    block nor a ``"text"`` key; the caller treats any falsy value as empty.
    """
    if isinstance(result, dict):
        blocks = result.get("result")
        if isinstance(blocks, list) and blocks:
            return getattr(blocks[0], "text", None)
        return result.get("text")

    if isinstance(result, str):
        return result

    if isinstance(result, (list, tuple)) and result:
        first = result[0]
        return first.text if hasattr(first, "text") else str(first)

    return str(result)
async def mcp_generate_speech(request: GenerateSpeechRequest):
    """Convert text to speech and return the audio as a downloadable file.

    Args:
        request: Text to synthesise plus the voice and audio format to use.

    Returns:
        FileResponse streaming the generated file as an attachment named
        ``speech.<format>``.

    Raises:
        HTTPException: 500 when the tool fails, returns no path, or the
            returned path does not exist.
    """
    try:
        result = await mcp_server.call_tool("generate_speech", {
            "text": request.text,
            "voice": request.voice,
            "format": request.format
        })

        audio_path = _extract_audio_path(result)
        if not audio_path:
            raise HTTPException(status_code=500, detail="Empty audio path from MCP tool")

        if not Path(audio_path).exists():
            raise HTTPException(status_code=500, detail=f"Audio file not found: {audio_path}")

        logger.info(f"Speech generated: {audio_path}")
        # `filename` makes FileResponse emit a quoted/escaped
        # Content-Disposition attachment header itself. The client-supplied
        # format is therefore no longer interpolated into a raw header value.
        return FileResponse(
            path=audio_path,
            filename=f"speech.{request.format}",
            media_type=_media_type_for(request.format),
        )
    except HTTPException:
        # Propagate deliberate HTTP errors untouched.
        raise
    except Exception as e:
        logger.error(f"Error in generate_speech: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error executing tool generate_speech: {e}")


def _media_type_for(audio_format):
    """Map an audio format/extension to a MIME type.

    Unknown formats fall back to ``audio/mpeg``, which is what every non-wav
    format was previously labelled as.
    """
    known_types = {
        "wav": "audio/wav",
        "mp3": "audio/mpeg",
        "ogg": "audio/ogg",
        "flac": "audio/flac",
        "aac": "audio/aac",
        "m4a": "audio/mp4",
    }
    return known_types.get(audio_format.lower(), "audio/mpeg")


def _extract_audio_path(result):
    """Pull the generated file path out of the shapes the MCP tool call may return.

    If the extracted string is the repr of a content object (it contains
    ``text='…'``), the quoted value is used instead. This clean-up applies to
    every audio format, not only ``.wav`` paths.
    """
    import re

    if isinstance(result, dict):
        blocks = result.get("result")
        if isinstance(blocks, list) and blocks:
            audio_path = getattr(blocks[0], "text", None)
        else:
            audio_path = result.get("audio_path")
    elif isinstance(result, str):
        audio_path = result
    elif isinstance(result, (list, tuple)) and result:
        first = result[0]
        audio_path = first.text if hasattr(first, "text") else str(first)
    else:
        audio_path = str(result)

    if audio_path and isinstance(audio_path, str) and "text='" in audio_path:
        match = re.search(r"text='([^']+)'", audio_path)
        if match:
            audio_path = match.group(1)

    return audio_path
async def mcp_generate_argument(request: GenerateRequest):
    """Generate a debate argument for a given topic and position.

    The generation model manager is called directly instead of going through
    the MCP server (considered more reliable).

    Args:
        request: Topic and position to argue.

    Returns:
        GenerateResponse echoing the topic and position, with the generated
        argument and an ISO-format timestamp taken from ``datetime.now()``.

    Raises:
        HTTPException: 503 when the generation model is not loaded; 500 on
            any other error.
    """
    try:
        if not generate_model_manager.model_loaded:
            raise HTTPException(status_code=503, detail="Generation model not loaded")

        generated_text = generate_model_manager.generate(
            topic=request.topic,
            position=request.position,
        )

        # Timestamp is taken after generation completes.
        generated_at = datetime.now().isoformat()
        response = GenerateResponse(
            topic=request.topic,
            position=request.position,
            argument=generated_text,
            timestamp=generated_at,
        )
        logger.info(f"Argument generated for topic '{request.topic}' with position '{request.position}': {len(response.argument)} characters")
        return response
    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 503 above) untouched.
        raise
    except Exception as e:
        logger.error(f"Error in generate_argument: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error executing tool generate_argument: {e}")
async def mcp_tool_health_check() -> Dict[str, Any]:
    """Run the MCP ``health_check`` tool and return its decoded payload.

    Returns:
        The JSON-decoded text of the first content block when there is one
        (or ``{"text": ...}`` if it is not valid JSON); otherwise the raw
        dict result, or the raw value wrapped as ``{"result": ...}``.

    Raises:
        HTTPException: 500 if calling or decoding the tool result fails.
    """
    try:
        outcome = await mcp_server.call_tool("health_check", {})

        if isinstance(outcome, dict):
            # Dict results may carry a list of content blocks under "result".
            blocks = outcome.get("result")
            if isinstance(blocks, list) and blocks:
                text = getattr(blocks[0], "text", None)
                if text:
                    try:
                        return json.loads(text)
                    except json.JSONDecodeError:
                        return {"text": text}
            return outcome

        if isinstance(outcome, (list, tuple)) and outcome:
            # Bare sequence of content blocks: only the first one is used.
            text = getattr(outcome[0], "text", None)
            if text:
                try:
                    return json.loads(text)
                except json.JSONDecodeError:
                    return {"text": text}
            return {"result": outcome[0]}

        return {"result": outcome}
    except Exception as e:
        logger.error(f"Error in health_check tool: {e}")
        raise HTTPException(status_code=500, detail=f"Error executing tool health_check: {e}")