Yassine Mhirsi
Add label/KPA-related schemas to models and update documentation in label routes
3afc11b
raw
history blame
3.76 kB
"""Keypoint–Argument Matching Endpoints"""
from fastapi import APIRouter, HTTPException
from datetime import datetime
import logging
from models import (
PredictionRequest,
PredictionResponse,
BatchPredictionRequest,
BatchPredictionResponse
)
from services import kpa_model_manager
# Logger scoped to this module (named after ``__name__``).
logger: logging.Logger = logging.getLogger(__name__)

# Router that collects the KPA path operations declared in this module.
router: APIRouter = APIRouter()
@router.get("/model-info", tags=["KPA"])
async def get_model_info():
    """
    Return information about the loaded KPA model.
    \f
    Fields missing from ``kpa_model_manager.get_model_info()`` fall back to
    ``"unknown"`` (model_name), ``"cpu"`` (device), ``256`` (max_length),
    ``2`` (num_labels) and ``False`` (loaded). A ``timestamp`` is added from
    ``datetime.now().isoformat()``, i.e. naive local time with no UTC offset.

    Raises:
        HTTPException: 500 if querying or reading the model info raises.
    """
    try:
        model_info = kpa_model_manager.get_model_info()
        return {
            "model_name": model_info.get("model_name", "unknown"),
            "device": model_info.get("device", "cpu"),
            "max_length": model_info.get("max_length", 256),
            "num_labels": model_info.get("num_labels", 2),
            "loaded": model_info.get("loaded", False),
            "timestamp": datetime.now().isoformat(),
        }
    except Exception as e:
        # logger.exception logs at ERROR level *with* the traceback, which the
        # previous logger.error(f"...") call discarded.
        logger.exception("Model info error: %s", e)
        # Chain the cause so the original exception survives in tracebacks.
        raise HTTPException(
            status_code=500, detail=f"Failed to get model info: {str(e)}"
        ) from e
@router.post("/predict", response_model=PredictionResponse, tags=["KPA"])
async def predict_kpa(request: PredictionRequest):
    """
    Predict keypoint-argument matching for a single pair.

    - **argument**: The argument text
    - **key_point**: The key point to evaluate

    Returns the predicted class (apparie / non_apparie) with probabilities.
    \f
    The dict returned by ``kpa_model_manager.predict`` must provide the keys
    ``prediction``, ``confidence``, ``label`` and ``probabilities``. A missing
    key is reported as a prediction failure.

    Raises:
        HTTPException: 500 if prediction or response construction fails.
    """
    # NOTE(review): kpa_model_manager.predict is called synchronously inside an
    # ``async def``. If it runs model inference, it blocks the event loop for
    # the duration of the call. Consider a plain ``def`` endpoint or a
    # threadpool once the manager's thread-safety is confirmed.
    try:
        result = kpa_model_manager.predict(
            argument=request.argument,
            key_point=request.key_point,
        )
        response = PredictionResponse(
            prediction=result["prediction"],
            confidence=result["confidence"],
            label=result["label"],
            probabilities=result["probabilities"],
        )
    except Exception as e:
        # Keep the traceback (logger.error with an f-string dropped it).
        logger.exception("KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e

    # Lazy %-formatting: the message is only built if INFO is enabled.
    logger.info(
        "KPA Prediction: %s (conf=%.4f)", response.label, response.confidence
    )
    return response
@router.post("/batch-predict", response_model=BatchPredictionResponse, tags=["KPA"])
async def batch_predict_kpa(request: BatchPredictionRequest):
    """
    Predict keypoint-argument matching for multiple argument/keypoint pairs.

    - **pairs**: List of items to classify

    Returns predictions for all pairs.
    \f
    Pairs are predicted one at a time, in order. A pair whose prediction fails
    does not abort the batch. The failure is logged with the pair's index, and
    a placeholder takes its slot (``prediction=-1``, ``confidence=0.0``,
    ``label="error"``, ``probabilities={"error": 1.0}``). This keeps
    ``predictions`` index-aligned with ``request.pairs``. ``total_processed``
    counts every pair, including failed ones.

    Raises:
        HTTPException: 500 if anything outside the per-pair handling fails,
            e.g. building the batch response.
    """
    try:
        results = []
        failed = 0
        for index, item in enumerate(request.pairs):
            try:
                result = kpa_model_manager.predict(
                    argument=item.argument,
                    key_point=item.key_point,
                )
                response = PredictionResponse(
                    prediction=result["prediction"],
                    confidence=result["confidence"],
                    label=result["label"],
                    probabilities=result["probabilities"],
                )
            except Exception:
                # Best-effort per pair, but never silent: record which pair
                # failed and why, then substitute the error placeholder.
                failed += 1
                logger.exception("KPA prediction failed for pair %d", index)
                response = PredictionResponse(
                    prediction=-1,
                    confidence=0.0,
                    label="error",
                    probabilities={"error": 1.0},
                )
            results.append(response)

        logger.info(
            "Batch KPA prediction completed — %d items processed (%d failed)",
            len(results),
            failed,
        )
        return BatchPredictionResponse(
            predictions=results,
            total_processed=len(results),
        )
    except Exception as e:
        logger.exception("Batch KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {str(e)}"
        ) from e