malek-messaoudii
Add response models for new MCP tools: DetectStance, MatchKeypoint, TranscribeAudio, GenerateSpeech, and GenerateArgument
45e145b
| """Routes pour exposer MCP via FastAPI pour Swagger UI""" | |
| from fastapi import APIRouter, HTTPException | |
| from typing import Dict, Any, Optional | |
| from pydantic import BaseModel, Field | |
| import logging | |
| import json | |
| from services.mcp_service import mcp_server | |
| from models.mcp_models import ( | |
| ToolListResponse, | |
| ToolInfo, | |
| ToolCallRequest, | |
| ToolCallResponse, | |
| DetectStanceResponse, | |
| MatchKeypointResponse, | |
| TranscribeAudioResponse, | |
| GenerateSpeechResponse, | |
| GenerateArgumentResponse | |
| ) | |
# Every MCP endpoint is grouped under /api/v1/mcp and tagged "MCP" in Swagger UI.
router = APIRouter(tags=["MCP"], prefix="/api/v1/mcp")
# Logger named after this module.
logger = logging.getLogger(name=__name__)
| # ===== Models pour chaque outil MCP ===== | |
class DetectStanceRequest(BaseModel):
    """Request body for the ``detect_stance`` tool: a debate topic and an argument to classify."""

    topic: str = Field(..., description="Le sujet du débat")
    argument: str = Field(..., description="L'argument à analyser")

    # Pydantic v2 configuration; replaces the class-based ``Config``, which v2 deprecates.
    model_config = {
        "json_schema_extra": {
            "example": {
                "topic": "Climate change is real",
                "argument": "Rising global temperatures prove it",
            }
        }
    }
class MatchKeypointRequest(BaseModel):
    """Request body for ``match_keypoint_argument``: an argument and the key point it is compared against."""

    argument: str = Field(..., description="L'argument à évaluer")
    key_point: str = Field(..., description="Le keypoint de référence")

    # Pydantic v2 configuration; replaces the class-based ``Config``, which v2 deprecates.
    model_config = {
        "json_schema_extra": {
            "example": {
                "argument": "Renewable energy reduces emissions",
                "key_point": "Environmental benefits",
            }
        }
    }
class TranscribeAudioRequest(BaseModel):
    """Request body for ``transcribe_audio``: the path of the audio file to transcribe."""

    # NOTE(review): this path is supplied by the client and forwarded to the tool
    # unchanged; confirm the tool restricts it to an allowed directory.
    audio_path: str = Field(..., description="Chemin vers le fichier audio")

    # Pydantic v2 configuration; replaces the class-based ``Config``, which v2 deprecates.
    model_config = {
        "json_schema_extra": {
            "example": {
                "audio_path": "/path/to/audio.wav",
            }
        }
    }
class GenerateSpeechRequest(BaseModel):
    """Request body for ``generate_speech``: text to synthesize plus optional voice and audio format."""

    text: str = Field(..., description="Texte à convertir en parole")
    voice: str = Field(default="Aaliyah-PlayAI", description="Voix à utiliser")
    format: str = Field(default="wav", description="Format audio (wav, mp3, etc.)")

    # Pydantic v2 configuration; replaces the class-based ``Config``, which v2 deprecates.
    model_config = {
        "json_schema_extra": {
            "example": {
                "text": "Hello, this is a test",
                "voice": "Aaliyah-PlayAI",
                "format": "wav",
            }
        }
    }
class GenerateArgumentRequest(BaseModel):
    """Request body for ``generate_argument``: the user's prompt and an optional conversation id."""

    user_input: str = Field(..., description="Input utilisateur pour générer l'argument")
    conversation_id: Optional[str] = Field(default=None, description="ID de conversation (optionnel)")

    # Pydantic v2 configuration; replaces the class-based ``Config``, which v2 deprecates.
    model_config = {
        "json_schema_extra": {
            "example": {
                "user_input": "Generate an argument about climate change",
                "conversation_id": "conv_123",
            }
        }
    }
| # ===== Routes MCP ===== | |
# NOTE(review): no @router decorator is visible on this or the other handlers in
# this view; confirm they are registered on `router`, otherwise none are served.
async def mcp_health():
    """Report the MCP layer as healthy, together with the names of its tools.

    The tool names come from a fixed, hand-maintained list; the MCP server
    itself is not contacted.
    """
    try:
        known_tools = (
            "detect_stance",
            "match_keypoint_argument",
            "transcribe_audio",
            "generate_speech",
            "generate_argument",
            "health_check",
        )
        report = {"status": "healthy", "tools": list(known_tools)}
        report["tool_count"] = len(report["tools"])
        return report
    except Exception as e:
        logger.error(f"MCP health check error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def list_mcp_tools():
    """Return the catalogue of MCP tools with their JSON input schemas.

    The catalogue is declared by hand here rather than read from the MCP
    server, so it must be kept in sync with the tools the server provides.
    """

    def _object_schema(properties, required):
        # Every tool takes a JSON object; only its properties and required keys differ.
        return {"type": "object", "properties": properties, "required": required}

    try:
        catalogue = [
            (
                "detect_stance",
                "Détecte si un argument est PRO ou CON pour un topic donné",
                {
                    "topic": {"type": "string", "description": "Le sujet du débat"},
                    "argument": {"type": "string", "description": "L'argument à analyser"},
                },
                ["topic", "argument"],
            ),
            (
                "match_keypoint_argument",
                "Détermine si un argument correspond à un keypoint",
                {
                    "argument": {"type": "string", "description": "L'argument à évaluer"},
                    "key_point": {"type": "string", "description": "Le keypoint de référence"},
                },
                ["argument", "key_point"],
            ),
            (
                "transcribe_audio",
                "Convertit un fichier audio en texte",
                {
                    "audio_path": {"type": "string", "description": "Chemin vers le fichier audio"},
                },
                ["audio_path"],
            ),
            (
                "generate_speech",
                "Convertit du texte en fichier audio",
                {
                    "text": {"type": "string", "description": "Texte à convertir en parole"},
                    "voice": {"type": "string", "description": "Voix à utiliser", "default": "Aaliyah-PlayAI"},
                    "format": {"type": "string", "description": "Format audio", "default": "wav"},
                },
                ["text"],
            ),
            (
                "generate_argument",
                "Génère un argument de débat à partir d'un input utilisateur",
                {
                    "user_input": {"type": "string", "description": "Input utilisateur pour générer l'argument"},
                    "conversation_id": {"type": "string", "description": "ID de conversation (optionnel)"},
                },
                ["user_input"],
            ),
            (
                "health_check",
                "Health check pour le serveur MCP",
                {},
                [],
            ),
        ]
        tools = [
            ToolInfo(name=tool_name, description=summary, input_schema=_object_schema(props, required))
            for tool_name, summary, props, required in catalogue
        ]
        return ToolListResponse(tools=tools, count=len(tools))
    except Exception as e:
        logger.error(f"Error listing MCP tools: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def call_mcp_tool(request: ToolCallRequest):
    """Call any MCP tool by name and wrap its output in a ``ToolCallResponse``.

    The raw tool output is normalised as follows:

    * a dict whose ``"result"`` is a non-empty list starting with a block that
      has non-empty ``text`` -> that text, JSON-decoded (``{"text": ...}`` if
      it is not JSON);
    * any other dict -> returned unchanged;
    * a non-empty list/tuple whose first item has non-empty ``text`` -> decoded
      the same way;
    * any other non-empty list/tuple -> ``{"result": first_item}``;
    * anything else -> ``{"result": raw_output}``.

    No exception escapes: failures (including while building the success
    response) are logged and reported as ``success=False`` with the exception
    text in ``error``.
    """

    def _decode_text_block(block):
        # Text blocks usually carry JSON; keep non-JSON text under a "text" key.
        try:
            return json.loads(block.text)
        except json.JSONDecodeError:
            return {"text": block.text}

    try:
        raw = await mcp_server.call_tool(request.tool_name, request.arguments)

        if isinstance(raw, dict):
            blocks = raw["result"] if "result" in raw else None
            if isinstance(blocks, list) and blocks and getattr(blocks[0], "text", None):
                payload = _decode_text_block(blocks[0])
            else:
                payload = raw
        elif isinstance(raw, (list, tuple)) and raw:
            head = raw[0]
            payload = _decode_text_block(head) if getattr(head, "text", None) else {"result": head}
        else:
            payload = {"result": raw}

        return ToolCallResponse(success=True, result=payload, tool_name=request.tool_name)
    except Exception as e:
        logger.error(f"Error calling MCP tool {request.tool_name}: {e}")
        return ToolCallResponse(success=False, error=str(e), tool_name=request.tool_name)
| # ===== Routes individuelles pour chaque outil (pour Swagger) ===== | |
async def mcp_detect_stance(request: DetectStanceRequest):
    """Classify ``request.argument`` as PRO or CON with respect to ``request.topic``.

    Calls the ``detect_stance`` MCP tool, JSON-decodes the text of the first
    content block it returns, and maps the payload onto ``DetectStanceResponse``.

    Raises:
        HTTPException: 500 when the tool call fails, its text is not valid JSON,
            or the payload is empty, not a JSON object, or lacks an expected field.
    """
    expected_fields = ("predicted_stance", "confidence", "probability_con", "probability_pro")
    try:
        result = await mcp_server.call_tool("detect_stance", {
            "topic": request.topic,
            "argument": request.argument,
        })

        # Find the first content block, whether it is wrapped as {"result": [...]}
        # or returned as a bare sequence; any other shape is the payload itself.
        first_block = None
        parsed_result = None
        if isinstance(result, dict):
            items = result["result"] if "result" in result else None
            if isinstance(items, list) and items:
                first_block = items[0]
            else:
                parsed_result = result
        elif isinstance(result, (list, tuple)) and result:
            first_block = result[0]
        else:
            parsed_result = result

        block_text = getattr(first_block, "text", None)
        if block_text:
            try:
                parsed_result = json.loads(block_text)
            except json.JSONDecodeError as exc:
                raise HTTPException(status_code=500, detail="Invalid JSON response from MCP tool") from exc

        if not parsed_result:
            raise HTTPException(status_code=500, detail="Empty response from MCP tool")
        # Reject malformed payloads explicitly rather than letting a TypeError or
        # KeyError surface as an opaque 500 detail such as "'predicted_stance'".
        if not isinstance(parsed_result, dict):
            raise HTTPException(status_code=500, detail="Unexpected response format from MCP tool")
        missing = [field for field in expected_fields if field not in parsed_result]
        if missing:
            raise HTTPException(
                status_code=500,
                detail=f"MCP tool response is missing fields: {', '.join(missing)}",
            )

        response = DetectStanceResponse(
            predicted_stance=parsed_result["predicted_stance"],
            confidence=parsed_result["confidence"],
            probability_con=parsed_result["probability_con"],
            probability_pro=parsed_result["probability_pro"]
        )
        logger.info(f"Stance prediction: {response.predicted_stance} (conf={response.confidence:.4f})")
        return response
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Error in detect_stance: %s", e)
        raise HTTPException(status_code=500, detail=f"Error executing tool detect_stance: {e}") from e
async def mcp_match_keypoint(request: MatchKeypointRequest):
    """Decide whether ``request.argument`` matches ``request.key_point``.

    Calls the ``match_keypoint_argument`` MCP tool, JSON-decodes the text of the
    first content block it returns, and maps the payload onto
    ``MatchKeypointResponse``.

    Raises:
        HTTPException: 500 when the tool call fails, its text is not valid JSON,
            or the payload is empty, not a JSON object, or lacks an expected field.
    """
    expected_fields = ("prediction", "label", "confidence", "probabilities")
    try:
        result = await mcp_server.call_tool("match_keypoint_argument", {
            "argument": request.argument,
            "key_point": request.key_point,
        })

        # Find the first content block, whether it is wrapped as {"result": [...]}
        # or returned as a bare sequence; any other shape is the payload itself.
        first_block = None
        parsed_result = None
        if isinstance(result, dict):
            items = result["result"] if "result" in result else None
            if isinstance(items, list) and items:
                first_block = items[0]
            else:
                parsed_result = result
        elif isinstance(result, (list, tuple)) and result:
            first_block = result[0]
        else:
            parsed_result = result

        block_text = getattr(first_block, "text", None)
        if block_text:
            try:
                parsed_result = json.loads(block_text)
            except json.JSONDecodeError as exc:
                raise HTTPException(status_code=500, detail="Invalid JSON response from MCP tool") from exc

        if not parsed_result:
            raise HTTPException(status_code=500, detail="Empty response from MCP tool")
        # Reject malformed payloads explicitly rather than letting a TypeError or
        # KeyError surface as an opaque 500 detail such as "'prediction'".
        if not isinstance(parsed_result, dict):
            raise HTTPException(status_code=500, detail="Unexpected response format from MCP tool")
        missing = [field for field in expected_fields if field not in parsed_result]
        if missing:
            raise HTTPException(
                status_code=500,
                detail=f"MCP tool response is missing fields: {', '.join(missing)}",
            )

        response = MatchKeypointResponse(
            prediction=parsed_result["prediction"],
            label=parsed_result["label"],
            confidence=parsed_result["confidence"],
            probabilities=parsed_result["probabilities"]
        )
        logger.info(f"Keypoint matching: {response.label} (conf={response.confidence:.4f})")
        return response
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Error in match_keypoint_argument: %s", e)
        raise HTTPException(status_code=500, detail=f"Error executing tool match_keypoint_argument: {e}") from e
async def mcp_transcribe_audio(request: TranscribeAudioRequest):
    """Transcribe the audio file at ``request.audio_path`` into text.

    The transcription is taken from, in order: the first content block's
    ``text`` (for a ``{"result": [...]}`` dict or a bare sequence), a ``"text"``
    key of a dict result, a plain string result, or the string form of any
    other non-``None`` result.

    Raises:
        HTTPException: 500 when the tool call fails or yields no text.
    """
    # NOTE(review): audio_path comes straight from the client and is forwarded
    # unchanged; confirm the tool confines it to an allowed directory.
    try:
        result = await mcp_server.call_tool("transcribe_audio", {
            "audio_path": request.audio_path
        })

        transcribed_text = None
        if isinstance(result, dict):
            if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
                content_block = result["result"][0]
                if hasattr(content_block, 'text'):
                    transcribed_text = content_block.text
            elif "text" in result:
                transcribed_text = result["text"]
        elif isinstance(result, str):
            transcribed_text = result
        elif isinstance(result, (list, tuple)):
            # An empty sequence means no output; previously it was stringified
            # to "[]"/"()" and returned as if it were a transcription.
            if result:
                first_item = result[0]
                if hasattr(first_item, 'text'):
                    transcribed_text = first_item.text
                elif first_item is not None:
                    transcribed_text = str(first_item)
        elif result is not None:
            # Skip None so it is not reported as the literal transcription "None".
            transcribed_text = str(result)

        if not transcribed_text:
            raise HTTPException(status_code=500, detail="Empty transcription result from MCP tool")

        response = TranscribeAudioResponse(text=transcribed_text)
        logger.info(f"Audio transcribed: {len(transcribed_text)} characters")
        return response
    except FileNotFoundError as e:
        logger.error(f"File not found in transcribe_audio: {e}")
        raise HTTPException(status_code=500, detail=f"Error executing tool transcribe_audio: {e}") from e
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Error in transcribe_audio: %s", e)
        raise HTTPException(status_code=500, detail=f"Error executing tool transcribe_audio: {e}") from e
async def mcp_generate_speech(request: GenerateSpeechRequest):
    """Synthesize ``request.text`` to audio and return the path the tool reports.

    Forwards ``text``, ``voice`` and ``format`` to the ``generate_speech`` MCP
    tool. The audio path is taken from, in order: the first content block's
    ``text``, an ``"audio_path"`` key of a dict result, a plain string result,
    or the string form of any other non-``None`` result.

    Raises:
        HTTPException: 500 when the tool call fails or yields no path.
    """
    try:
        result = await mcp_server.call_tool("generate_speech", {
            "text": request.text,
            "voice": request.voice,
            "format": request.format
        })

        audio_path = None
        if isinstance(result, dict):
            if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
                content_block = result["result"][0]
                if hasattr(content_block, 'text'):
                    audio_path = content_block.text
            elif "audio_path" in result:
                audio_path = result["audio_path"]
        elif isinstance(result, str):
            audio_path = result
        elif isinstance(result, (list, tuple)):
            # An empty sequence carries no path; previously it was stringified
            # to "[]"/"()" and returned as a bogus audio path.
            if result:
                first_item = result[0]
                if hasattr(first_item, 'text'):
                    audio_path = first_item.text
                elif first_item is not None:
                    audio_path = str(first_item)
        elif result is not None:
            # Skip None so it is not reported as the literal path "None".
            audio_path = str(result)

        if not audio_path:
            raise HTTPException(status_code=500, detail="Empty audio path from MCP tool")

        response = GenerateSpeechResponse(audio_path=audio_path)
        logger.info(f"Speech generated: {audio_path}")
        return response
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Error in generate_speech: %s", e)
        raise HTTPException(status_code=500, detail=f"Error executing tool generate_speech: {e}") from e
async def mcp_generate_argument(request: GenerateArgumentRequest):
    """Generate a debate argument from ``request.user_input``.

    Calls the ``generate_argument`` MCP tool and forwards ``conversation_id``
    only when the client supplied one. The argument text is taken from, in
    order: the first content block's ``text``, an ``"argument"`` key of a dict
    result, a plain string result, or the string form of any other non-``None``
    result.

    Raises:
        HTTPException: 500 when the tool call fails or yields no argument.
    """
    try:
        tool_arguments = {"user_input": request.user_input}
        # The generate_argument schema published by list_mcp_tools types
        # conversation_id as an optional string; omit it instead of sending null.
        if request.conversation_id is not None:
            tool_arguments["conversation_id"] = request.conversation_id
        result = await mcp_server.call_tool("generate_argument", tool_arguments)

        generated_argument = None
        if isinstance(result, dict):
            if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
                content_block = result["result"][0]
                if hasattr(content_block, 'text'):
                    generated_argument = content_block.text
            elif "argument" in result:
                generated_argument = result["argument"]
        elif isinstance(result, str):
            generated_argument = result
        elif isinstance(result, (list, tuple)):
            # An empty sequence means no output; previously it was stringified
            # to "[]"/"()" and returned as the generated argument.
            if result:
                first_item = result[0]
                if hasattr(first_item, 'text'):
                    generated_argument = first_item.text
                elif first_item is not None:
                    generated_argument = str(first_item)
        elif result is not None:
            # Skip None so it is not reported as the literal argument "None".
            generated_argument = str(result)

        if not generated_argument:
            raise HTTPException(status_code=500, detail="Empty argument from MCP tool")

        response = GenerateArgumentResponse(argument=generated_argument)
        logger.info(f"Argument generated: {len(generated_argument)} characters")
        return response
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Error in generate_argument: %s", e)
        raise HTTPException(status_code=500, detail=f"Error executing tool generate_argument: {e}") from e
async def mcp_tool_health_check() -> Dict[str, Any]:
    """Run the MCP ``health_check`` tool and return its output as a dict.

    Text from the first content block is JSON-decoded (or wrapped as
    ``{"text": ...}`` when it is not JSON). A dict result without a usable text
    block is returned unchanged; any other result is wrapped under ``"result"``.

    Raises:
        HTTPException: 500 when the tool call or the conversion fails.
    """

    def _decode(text: str) -> Any:
        # Health payloads are normally JSON; keep non-JSON text under "text".
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return {"text": text}

    try:
        result = await mcp_server.call_tool("health_check", {})

        # Uses the module-level json import; the old function-local
        # `import json` was redundant.
        if isinstance(result, dict):
            if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
                content_block = result["result"][0]
                if hasattr(content_block, 'text') and content_block.text:
                    return _decode(content_block.text)
            return result
        if isinstance(result, (list, tuple)) and len(result) > 0:
            first_item = result[0]
            if hasattr(first_item, 'text') and first_item.text:
                return _decode(first_item.text)
            # The sequence is non-empty here, so the old `if result else {}` was dead code.
            return {"result": first_item}
        return {"result": result}
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Error in health_check tool: %s", e)
        raise HTTPException(status_code=500, detail=f"Error executing tool health_check: {e}") from e