malek-messaoudii
Update label files
8a40c79
raw
history blame
4.88 kB
"""Keypoint–Argument Matching Endpoints"""
from fastapi import APIRouter, HTTPException
from datetime import datetime
import logging
from models import (
PredictionRequest,
PredictionResponse,
BatchPredictionRequest,
BatchPredictionResponse
)
from services import kpa_model_manager
# Router on which the KPA endpoints below register via their decorators.
router = APIRouter()
# Module-scoped logger shared by every endpoint in this file.
logger = logging.getLogger(__name__)
@router.get("/model-info", tags=["KPA"])
async def get_model_info():
    """
    Return information about the loaded KPA model.

    Fields missing from the model manager's report fall back to defaults:
    model_name="unknown", device="cpu", max_length=256, num_labels=2,
    loaded=False. Responds with HTTP 500 if querying the manager fails.
    """
    try:
        model_info = kpa_model_manager.get_model_info()
        return {
            "model_name": model_info.get("model_name", "unknown"),
            "device": model_info.get("device", "cpu"),
            "max_length": model_info.get("max_length", 256),
            "num_labels": model_info.get("num_labels", 2),
            "loaded": model_info.get("loaded", False),
            # NOTE(review): naive local time with no UTC offset — confirm
            # that clients do not need a timezone-aware timestamp.
            "timestamp": datetime.now().isoformat(),
        }
    except Exception as e:
        # logger.exception records the traceback; logging only str(e) lost it.
        logger.exception("Model info error: %s", e)
        # NOTE(review): the raw exception text is returned to the client;
        # consider a generic message if internals should not leak.
        raise HTTPException(
            status_code=500, detail=f"Failed to get model info: {str(e)}"
        ) from e
@router.post("/predict", response_model=PredictionResponse, tags=["KPA"])
async def predict_kpa(request: PredictionRequest):
    """
    Predict keypoint-argument matching for a single pair.

    - **argument**: The argument text
    - **key_point**: The key point to evaluate

    Returns the predicted class (apparie / non_apparie) with probabilities.
    Any failure during inference is reported as HTTP 500.
    """
    try:
        # NOTE(review): predict() is synchronous and runs directly on the event
        # loop inside this async endpoint, so a slow inference stalls every
        # other request; consider fastapi.concurrency.run_in_threadpool once
        # the model manager is confirmed thread-safe.
        result = kpa_model_manager.predict(
            argument=request.argument,
            key_point=request.key_point,
        )
        response = PredictionResponse(
            prediction=result["prediction"],
            confidence=result["confidence"],
            label=result["label"],
            probabilities=result["probabilities"],
        )
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info(
            "KPA Prediction: %s (conf=%.4f)", response.label, response.confidence
        )
        return response
    except Exception as e:
        # Keep the traceback in the logs and chain the original cause.
        logger.exception("KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e
def _summarize_batch(results):
    """Aggregate per-item batch results into the response summary dict.

    Items whose ``prediction`` is -1 are the error placeholders produced by
    ``batch_predict_kpa`` and count as failures. Among the remaining items,
    prediction 1 counts as "apparie" and 0 as "non_apparie".
    ``average_confidence`` is the mean confidence of successful items rounded
    to 4 decimals, or 0.0 when there are none (so an empty batch is safe).
    """
    successful = [r for r in results if r.prediction != -1]
    if successful:
        average_confidence = round(
            sum(r.confidence for r in successful) / len(successful), 4
        )
    else:
        average_confidence = 0.0
    return {
        "total_apparie": sum(1 for r in successful if r.prediction == 1),
        "total_non_apparie": sum(1 for r in successful if r.prediction == 0),
        "average_confidence": average_confidence,
        "successful_predictions": len(successful),
        "failed_predictions": len(results) - len(successful),
    }


@router.post("/batch-predict", response_model=BatchPredictionResponse, tags=["KPA"])
async def batch_predict_kpa(request: BatchPredictionRequest):
    """
    Predict keypoint-argument matching for multiple argument/keypoint pairs.

    - **pairs**: List of items to classify

    Returns predictions for all pairs, in input order. A pair that fails is
    returned as an error entry (prediction=-1, label="error") instead of
    aborting the batch; the summary counts such entries as failures.
    """
    try:
        results = []
        for index, item in enumerate(request.pairs):
            try:
                result = kpa_model_manager.predict(
                    argument=item.argument,
                    key_point=item.key_point,
                )
                results.append(
                    PredictionResponse(
                        prediction=result["prediction"],
                        confidence=result["confidence"],
                        label=result["label"],
                        probabilities=result["probabilities"],
                    )
                )
            except Exception:
                # One bad pair must not fail the whole batch, but the error
                # is logged (it was previously swallowed without a trace).
                logger.warning("Batch KPA item %d failed", index, exc_info=True)
                results.append(
                    PredictionResponse(
                        prediction=-1,
                        confidence=0.0,
                        label="error",
                        probabilities={"error": 1.0},
                    )
                )

        summary = _summarize_batch(results)

        logger.info(
            "Batch KPA prediction completed — %d items processed", len(results)
        )
        return BatchPredictionResponse(
            predictions=results,
            total_processed=len(results),
            summary=summary,
        )
    except Exception as e:
        # Keep the traceback in the logs and chain the original cause.
        logger.exception("Batch KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {str(e)}"
        ) from e