File size: 4,877 Bytes
03d23e8 f28285b 03d23e8 f28285b 03d23e8 f28285b 03d23e8 f28285b 03d23e8 3afc11b 03d23e8 f28285b 03d23e8 8a40c79 03d23e8 f28285b 8a40c79 03d23e8 8a40c79 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
"""Keypoint–Argument Matching Endpoints"""
from fastapi import APIRouter, HTTPException
from datetime import datetime
import logging
from models import (
PredictionRequest,
PredictionResponse,
BatchPredictionRequest,
BatchPredictionResponse
)
from services import kpa_model_manager
# Router that groups every KPA endpoint declared in this module; presumably
# mounted on the application elsewhere via include_router — verify.
router: APIRouter = APIRouter()
# Module-scoped logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
@router.get("/model-info", tags=["KPA"])
async def get_model_info():
    """
    Return information about the loaded KPA model.

    Keys missing from the model manager's info dict fall back to defaults
    ("unknown", "cpu", 256, 2, False). ``timestamp`` is the server's local
    time as a naive ISO-8601 string (no UTC offset).
    """
    try:
        model_info = kpa_model_manager.get_model_info()
        return {
            "model_name": model_info.get("model_name", "unknown"),
            "device": model_info.get("device", "cpu"),
            "max_length": model_info.get("max_length", 256),
            "num_labels": model_info.get("num_labels", 2),
            "loaded": model_info.get("loaded", False),
            "timestamp": datetime.now().isoformat(),
        }
    except Exception as e:
        # Log with traceback; chain the cause so it is not lost behind the 500.
        logger.exception("Model info error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Failed to get model info: {e}"
        ) from e
@router.post("/predict", response_model=PredictionResponse, tags=["KPA"])
async def predict_kpa(request: PredictionRequest):
    """
    Predict keypoint-argument matching for a single pair.

    - **argument**: The argument text
    - **key_point**: The key point to evaluate

    Returns the predicted class (apparie / non_apparie) with probabilities.
    """
    # NOTE(review): kpa_model_manager.predict is called synchronously inside an
    # async handler, so any time spent in it blocks the event loop. Consider a
    # plain `def` handler or a threadpool once the model manager is confirmed
    # to be thread-safe.
    try:
        result = kpa_model_manager.predict(
            argument=request.argument,
            key_point=request.key_point,
        )
        response = PredictionResponse(
            prediction=result["prediction"],
            confidence=result["confidence"],
            label=result["label"],
            probabilities=result["probabilities"],
        )
    except Exception as e:
        # Log with traceback; chain the cause so it is not lost behind the 500.
        logger.exception("KPA prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e

    # Lazy %-args: a formatting problem here cannot turn a good prediction
    # into an HTTP 500.
    logger.info("KPA Prediction: %s (conf=%.4f)", response.label, response.confidence)
    return response
@router.post("/batch-predict", response_model=BatchPredictionResponse, tags=["KPA"])
async def batch_predict_kpa(request: BatchPredictionRequest):
    """
    Predict keypoint-argument matching for multiple argument/keypoint pairs.

    - **pairs**: List of items to classify

    Returns predictions for all pairs, in input order. A pair whose prediction
    fails does not abort the batch: it is returned as a placeholder
    (prediction -1, label "error") and counted in the summary's
    failed_predictions.
    """
    try:
        results = [
            _predict_pair_or_placeholder(index, item)
            for index, item in enumerate(request.pairs)
        ]
        summary = _summarize_batch_predictions(results)
        logger.info("Batch KPA prediction completed — %d items processed", len(results))
        return BatchPredictionResponse(
            predictions=results,
            total_processed=len(results),
            summary=summary,
        )
    except Exception as e:
        # Log with traceback; chain the cause so it is not lost behind the 500.
        logger.exception("Batch KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {e}"
        ) from e


# Sentinel `prediction` value marking a batch item whose inference failed.
_FAILED_PREDICTION = -1


def _predict_pair_or_placeholder(index, item):
    """Run the KPA model on one batch item.

    Returns a PredictionResponse built from the model result. If prediction
    (or building the response) raises, returns an "error" placeholder instead,
    so one bad pair does not fail the whole batch. The failure is logged with
    the item's index.
    """
    try:
        result = kpa_model_manager.predict(
            argument=item.argument,
            key_point=item.key_point,
        )
        return PredictionResponse(
            prediction=result["prediction"],
            confidence=result["confidence"],
            label=result["label"],
            probabilities=result["probabilities"],
        )
    except Exception:
        # Deliberate per-item best effort, but record why the item failed
        # instead of swallowing the error silently.
        logger.exception("KPA batch item %d prediction failed", index)
        return PredictionResponse(
            prediction=_FAILED_PREDICTION,
            confidence=0.0,
            label="error",
            probabilities={"error": 1.0},
        )


def _summarize_batch_predictions(results):
    """Aggregate per-class counts and mean confidence over a batch.

    Placeholder results (prediction == -1) are excluded from the counts and
    the average, and are reported under "failed_predictions". Prediction 1 is
    counted as "apparie" and prediction 0 as "non_apparie". With no successful
    predictions, the average confidence is 0.0.
    """
    successful = [r for r in results if r.prediction != _FAILED_PREDICTION]
    average_confidence = (
        round(sum(r.confidence for r in successful) / len(successful), 4)
        if successful
        else 0.0
    )
    return {
        "total_apparie": sum(1 for r in successful if r.prediction == 1),
        "total_non_apparie": sum(1 for r in successful if r.prediction == 0),
        "average_confidence": average_confidence,
        "successful_predictions": len(successful),
        "failed_predictions": len(results) - len(successful),
    }