malek-messaoudii commited on
Commit
76c719b
·
1 Parent(s): 200de02

feat: Refactor stance detection, keypoint matching, and argument generation to utilize dedicated model managers for improved reliability and response handling

Browse files
Files changed (1) hide show
  1. routes/mcp_routes.py +36 -119
routes/mcp_routes.py CHANGED
@@ -11,6 +11,9 @@ import os
11
  from pathlib import Path
12
 
13
  from services.mcp_service import mcp_server
 
 
 
14
  from models.mcp_models import (
15
  ToolListResponse,
16
  ToolInfo,
@@ -290,49 +293,19 @@ async def call_mcp_tool(request: ToolCallRequest):
290
  async def mcp_detect_stance(request: DetectStanceRequest):
291
  """Détecte si un argument est PRO ou CON pour un topic donné"""
292
  try:
293
- # Appeler directement via call_tool (async)
294
- result = await mcp_server.call_tool("detect_stance", {
295
- "topic": request.topic,
296
- "argument": request.argument
297
- })
298
-
299
- # FastMCP avec json_response=True retourne directement le dict
300
- parsed_result = None
301
- if isinstance(result, dict):
302
- # Vérifier si c'est un dict avec "result" contenant des ContentBlocks
303
- if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
304
- content_block = result["result"][0]
305
- if hasattr(content_block, 'text') and content_block.text:
306
- try:
307
- parsed_result = json.loads(content_block.text)
308
- except json.JSONDecodeError:
309
- raise HTTPException(status_code=500, detail="Invalid JSON response from MCP tool")
310
- else:
311
- # Si pas de text, utiliser le dict directement
312
- parsed_result = result
313
- else:
314
- # Dict direct retourné par le tool
315
- parsed_result = result
316
- elif isinstance(result, (list, tuple)) and len(result) > 0:
317
- if hasattr(result[0], 'text') and result[0].text:
318
- try:
319
- parsed_result = json.loads(result[0].text)
320
- except json.JSONDecodeError:
321
- raise HTTPException(status_code=500, detail="Invalid JSON response from MCP tool")
322
- else:
323
- parsed_result = result[0] if isinstance(result[0], dict) else result
324
- else:
325
- parsed_result = result
326
 
327
- if not parsed_result or not isinstance(parsed_result, dict):
328
- raise HTTPException(status_code=500, detail="Invalid response format from MCP tool")
329
 
330
- # Construire la réponse structurée
331
  response = DetectStanceResponse(
332
- predicted_stance=parsed_result["predicted_stance"],
333
- confidence=parsed_result["confidence"],
334
- probability_con=parsed_result["probability_con"],
335
- probability_pro=parsed_result["probability_pro"]
336
  )
337
 
338
  logger.info(f"Stance prediction: {response.predicted_stance} (conf={response.confidence:.4f})")
@@ -344,52 +317,26 @@ async def mcp_detect_stance(request: DetectStanceRequest):
344
  logger.error(f"Missing key in detect_stance response: {e}")
345
  raise HTTPException(status_code=500, detail=f"Invalid response format: missing {e}")
346
  except Exception as e:
347
- logger.error(f"Error in detect_stance: {e}")
348
  raise HTTPException(status_code=500, detail=f"Error executing tool detect_stance: {e}")
349
 
350
  @router.post("/tools/match-keypoint", response_model=MatchKeypointResponse, summary="Matcher un argument avec un keypoint")
351
  async def mcp_match_keypoint(request: MatchKeypointRequest):
352
  """Détermine si un argument correspond à un keypoint"""
353
  try:
354
- result = await mcp_server.call_tool("match_keypoint_argument", {
355
- "argument": request.argument,
356
- "key_point": request.key_point
357
- })
358
-
359
- # FastMCP avec json_response=True retourne directement le dict
360
- parsed_result = None
361
- if isinstance(result, dict):
362
- if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
363
- content_block = result["result"][0]
364
- if hasattr(content_block, 'text') and content_block.text:
365
- try:
366
- parsed_result = json.loads(content_block.text)
367
- except json.JSONDecodeError:
368
- raise HTTPException(status_code=500, detail="Invalid JSON response from MCP tool")
369
- else:
370
- parsed_result = result
371
- else:
372
- parsed_result = result
373
- elif isinstance(result, (list, tuple)) and len(result) > 0:
374
- if hasattr(result[0], 'text') and result[0].text:
375
- try:
376
- parsed_result = json.loads(result[0].text)
377
- except json.JSONDecodeError:
378
- raise HTTPException(status_code=500, detail="Invalid JSON response from MCP tool")
379
- else:
380
- parsed_result = result[0] if isinstance(result[0], dict) else result
381
- else:
382
- parsed_result = result
383
 
384
- if not parsed_result or not isinstance(parsed_result, dict):
385
- raise HTTPException(status_code=500, detail="Invalid response format from MCP tool")
386
 
387
- # Construire la réponse structurée
388
  response = MatchKeypointResponse(
389
- prediction=parsed_result["prediction"],
390
- label=parsed_result["label"],
391
- confidence=parsed_result["confidence"],
392
- probabilities=parsed_result["probabilities"]
393
  )
394
 
395
  logger.info(f"Keypoint matching: {response.label} (conf={response.confidence:.4f})")
@@ -401,7 +348,7 @@ async def mcp_match_keypoint(request: MatchKeypointRequest):
401
  logger.error(f"Missing key in match_keypoint response: {e}")
402
  raise HTTPException(status_code=500, detail=f"Invalid response format: missing {e}")
403
  except Exception as e:
404
- logger.error(f"Error in match_keypoint_argument: {e}")
405
  raise HTTPException(status_code=500, detail=f"Error executing tool match_keypoint_argument: {e}")
406
 
407
  @router.post("/tools/transcribe-audio", response_model=TranscribeAudioResponse, summary="Transcrire un audio en texte")
@@ -536,50 +483,20 @@ async def mcp_generate_speech(request: GenerateSpeechRequest):
536
  async def mcp_generate_argument(request: GenerateRequest):
537
  """Génère un argument de débat pour un topic et une position donnés"""
538
  try:
539
- result = await mcp_server.call_tool("generate_argument", {
540
- "topic": request.topic,
541
- "position": request.position
542
- })
543
-
544
- # FastMCP avec json_response=True retourne directement le dict
545
- parsed_result = None
546
- if isinstance(result, dict):
547
- if "result" in result and isinstance(result["result"], list) and len(result["result"]) > 0:
548
- content_block = result["result"][0]
549
- if hasattr(content_block, 'text') and content_block.text:
550
- try:
551
- parsed_result = json.loads(content_block.text)
552
- except json.JSONDecodeError:
553
- # Si ce n'est pas du JSON, c'est peut-être juste le texte
554
- parsed_result = {"argument": content_block.text}
555
- else:
556
- parsed_result = result
557
- else:
558
- parsed_result = result
559
- elif isinstance(result, (list, tuple)) and len(result) > 0:
560
- if hasattr(result[0], 'text') and result[0].text:
561
- try:
562
- parsed_result = json.loads(result[0].text)
563
- except json.JSONDecodeError:
564
- parsed_result = {"argument": result[0].text}
565
- else:
566
- parsed_result = result[0] if isinstance(result[0], dict) else result
567
- else:
568
- parsed_result = result
569
-
570
- if not parsed_result or not isinstance(parsed_result, dict):
571
- raise HTTPException(status_code=500, detail="Invalid response format from MCP tool")
572
 
573
- # Extraire l'argument (peut être dans "argument" ou directement dans le dict)
574
- argument_text = parsed_result.get("argument", "")
575
- if not argument_text:
576
- # Essayer de trouver le texte ailleurs dans la réponse
577
- argument_text = str(parsed_result)
578
 
579
  # Construire la réponse structurée
580
  response = GenerateResponse(
581
- topic=parsed_result.get("topic", request.topic),
582
- position=parsed_result.get("position", request.position),
583
  argument=argument_text,
584
  timestamp=datetime.now().isoformat()
585
  )
@@ -590,7 +507,7 @@ async def mcp_generate_argument(request: GenerateRequest):
590
  except HTTPException:
591
  raise
592
  except Exception as e:
593
- logger.error(f"Error in generate_argument: {e}")
594
  raise HTTPException(status_code=500, detail=f"Error executing tool generate_argument: {e}")
595
 
596
  @router.get("/tools/health-check", summary="Health check MCP (outil)")
 
11
  from pathlib import Path
12
 
13
  from services.mcp_service import mcp_server
14
+ from services.stance_model_manager import stance_model_manager
15
+ from services.label_model_manager import kpa_model_manager
16
+ from services.generate_model_manager import generate_model_manager
17
  from models.mcp_models import (
18
  ToolListResponse,
19
  ToolInfo,
 
293
  async def mcp_detect_stance(request: DetectStanceRequest):
294
  """Détecte si un argument est PRO ou CON pour un topic donné"""
295
  try:
296
+ # Vérifier que le modèle est chargé
297
+ if not stance_model_manager.model_loaded:
298
+ raise HTTPException(status_code=503, detail="Stance model not loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
+ # Appeler directement le modèle (plus fiable que via MCP)
301
+ result = stance_model_manager.predict(request.topic, request.argument)
302
 
303
+ # Construire la réponse structurée directement depuis le résultat du modèle
304
  response = DetectStanceResponse(
305
+ predicted_stance=result["predicted_stance"],
306
+ confidence=result["confidence"],
307
+ probability_con=result["probability_con"],
308
+ probability_pro=result["probability_pro"]
309
  )
310
 
311
  logger.info(f"Stance prediction: {response.predicted_stance} (conf={response.confidence:.4f})")
 
317
  logger.error(f"Missing key in detect_stance response: {e}")
318
  raise HTTPException(status_code=500, detail=f"Invalid response format: missing {e}")
319
  except Exception as e:
320
+ logger.error(f"Error in detect_stance: {e}", exc_info=True)
321
  raise HTTPException(status_code=500, detail=f"Error executing tool detect_stance: {e}")
322
 
323
  @router.post("/tools/match-keypoint", response_model=MatchKeypointResponse, summary="Matcher un argument avec un keypoint")
324
  async def mcp_match_keypoint(request: MatchKeypointRequest):
325
  """Détermine si un argument correspond à un keypoint"""
326
  try:
327
+ # Vérifier que le modèle est chargé
328
+ if not kpa_model_manager.model_loaded:
329
+ raise HTTPException(status_code=503, detail="KPA model not loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
+ # Appeler directement le modèle (plus fiable que via MCP)
332
+ result = kpa_model_manager.predict(request.argument, request.key_point)
333
 
334
+ # Construire la réponse structurée directement depuis le résultat du modèle
335
  response = MatchKeypointResponse(
336
+ prediction=result["prediction"],
337
+ label=result["label"],
338
+ confidence=result["confidence"],
339
+ probabilities=result["probabilities"]
340
  )
341
 
342
  logger.info(f"Keypoint matching: {response.label} (conf={response.confidence:.4f})")
 
348
  logger.error(f"Missing key in match_keypoint response: {e}")
349
  raise HTTPException(status_code=500, detail=f"Invalid response format: missing {e}")
350
  except Exception as e:
351
+ logger.error(f"Error in match_keypoint_argument: {e}", exc_info=True)
352
  raise HTTPException(status_code=500, detail=f"Error executing tool match_keypoint_argument: {e}")
353
 
354
  @router.post("/tools/transcribe-audio", response_model=TranscribeAudioResponse, summary="Transcrire un audio en texte")
 
483
  async def mcp_generate_argument(request: GenerateRequest):
484
  """Génère un argument de débat pour un topic et une position donnés"""
485
  try:
486
+ # Vérifier que le modèle est chargé
487
+ if not generate_model_manager.model_loaded:
488
+ raise HTTPException(status_code=503, detail="Generation model not loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
 
490
+ # Appeler directement le modèle (plus fiable que via MCP)
491
+ argument_text = generate_model_manager.generate(
492
+ topic=request.topic,
493
+ position=request.position
494
+ )
495
 
496
  # Construire la réponse structurée
497
  response = GenerateResponse(
498
+ topic=request.topic,
499
+ position=request.position,
500
  argument=argument_text,
501
  timestamp=datetime.now().isoformat()
502
  )
 
507
  except HTTPException:
508
  raise
509
  except Exception as e:
510
+ logger.error(f"Error in generate_argument: {e}", exc_info=True)
511
  raise HTTPException(status_code=500, detail=f"Error executing tool generate_argument: {e}")
512
 
513
  @router.get("/tools/health-check", summary="Health check MCP (outil)")