"""Keypoint–Argument Matching Endpoints.

Exposes model metadata, single-pair prediction and batch prediction routes,
all backed by ``kpa_model_manager``.
"""

import logging
from datetime import datetime
from typing import Any, Dict, List

from fastapi import APIRouter, HTTPException

from models import (
    PredictionRequest,
    PredictionResponse,
    BatchPredictionRequest,
    BatchPredictionResponse,
)
from services import kpa_model_manager

router = APIRouter()
logger = logging.getLogger(__name__)

# Values recorded for a batch item whose prediction raised. The batch summary
# treats any result carrying this ``prediction`` value as a failure.
_ERROR_PREDICTION = -1
_ERROR_LABEL = "error"


def _to_prediction_response(result: Dict[str, Any]) -> PredictionResponse:
    """Convert the dict returned by ``kpa_model_manager.predict`` into a response.

    Raises:
        KeyError: if ``result`` lacks one of ``prediction``, ``confidence``,
            ``label`` or ``probabilities``.
    """
    return PredictionResponse(
        prediction=result["prediction"],
        confidence=result["confidence"],
        label=result["label"],
        probabilities=result["probabilities"],
    )


def _error_response() -> PredictionResponse:
    """Return the placeholder response used for a failed batch item."""
    return PredictionResponse(
        prediction=_ERROR_PREDICTION,
        confidence=0.0,
        label=_ERROR_LABEL,
        probabilities={"error": 1.0},
    )


def _build_batch_summary(results: List[PredictionResponse]) -> Dict[str, Any]:
    """Aggregate batch results into per-class counts and mean confidence.

    A ``prediction`` of 1 is counted as ``total_apparie`` and 0 as
    ``total_non_apparie``. Failed items (``prediction == -1``) are excluded
    from every statistic except ``failed_predictions``.
    """
    successful = [r for r in results if r.prediction != _ERROR_PREDICTION]
    failed = len(results) - len(successful)

    if not successful:
        # Nothing to average over; report zeros rather than dividing by zero.
        return {
            "total_apparie": 0,
            "total_non_apparie": 0,
            "average_confidence": 0.0,
            "successful_predictions": 0,
            "failed_predictions": failed,
        }

    total_apparie = sum(1 for r in successful if r.prediction == 1)
    total_non_apparie = sum(1 for r in successful if r.prediction == 0)
    average_confidence = sum(r.confidence for r in successful) / len(successful)
    return {
        "total_apparie": total_apparie,
        "total_non_apparie": total_non_apparie,
        "average_confidence": round(average_confidence, 4),
        "successful_predictions": len(successful),
        "failed_predictions": failed,
    }


@router.get("/model-info", tags=["KPA"])
async def get_model_info():
    """Return information about the loaded KPA model.

    Fields missing from the manager's info dict fall back to defaults:
    ``"unknown"`` model name, ``"cpu"`` device, ``max_length`` 256,
    ``num_labels`` 2 and ``loaded=False``. ``timestamp`` is the server's
    local (naive) time in ISO-8601 format.

    Raises:
        HTTPException: re-raised unchanged if the manager raises one;
            otherwise a 500 wrapping any other error.
    """
    try:
        model_info = kpa_model_manager.get_model_info()
        return {
            "model_name": model_info.get("model_name", "unknown"),
            "device": model_info.get("device", "cpu"),
            "max_length": model_info.get("max_length", 256),
            "num_labels": model_info.get("num_labels", 2),
            "loaded": model_info.get("loaded", False),
            "timestamp": datetime.now().isoformat(),
        }
    except HTTPException:
        # Keep any status code deliberately chosen upstream instead of
        # flattening it into a generic 500.
        raise
    except Exception as e:
        logger.exception("Model info error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Failed to get model info: {e}"
        ) from e


@router.post("/predict", response_model=PredictionResponse, tags=["KPA"])
async def predict_kpa(request: PredictionRequest):
    """
    Predict keypoint-argument matching for a single pair.

    - **argument**: The argument text
    - **key_point**: The key point to evaluate

    Returns the predicted class (apparie / non_apparie) with probabilities.
    """
    # NOTE(review): kpa_model_manager.predict is called synchronously inside an
    # async route, so the event loop is blocked while it runs. If inference is
    # slow and the manager is thread-safe, consider a plain `def` route or
    # fastapi.concurrency.run_in_threadpool — confirm before changing.
    try:
        result = kpa_model_manager.predict(
            argument=request.argument,
            key_point=request.key_point,
        )
        response = _to_prediction_response(result)
        logger.info(
            "KPA Prediction: %s (conf=%.4f)", response.label, response.confidence
        )
        return response
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("KPA prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e


@router.post("/batch-predict", response_model=BatchPredictionResponse, tags=["KPA"])
async def batch_predict_kpa(request: BatchPredictionRequest):
    """
    Predict keypoint-argument matching for multiple argument/keypoint pairs.

    - **pairs**: List of items to classify

    Returns predictions for all pairs, in input order, plus a summary. A pair
    whose prediction raises is reported with ``prediction=-1`` and
    ``label="error"`` instead of failing the whole batch.
    """
    try:
        results: List[PredictionResponse] = []
        for index, item in enumerate(request.pairs):
            try:
                result = kpa_model_manager.predict(
                    argument=item.argument,
                    key_point=item.key_point,
                )
                results.append(_to_prediction_response(result))
            except Exception:
                # Best-effort: one bad pair must not sink the batch, but the
                # cause is logged so failures are diagnosable.
                logger.warning("KPA batch item %d failed", index, exc_info=True)
                results.append(_error_response())

        summary = _build_batch_summary(results)

        logger.info(
            "Batch KPA prediction completed — %d items processed", len(results)
        )
        return BatchPredictionResponse(
            predictions=results,
            total_processed=len(results),
            summary=summary,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Batch KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {e}"
        ) from e