|
|
"""Keypoint–Argument Matching Endpoints""" |
|
|
|
|
|
from fastapi import APIRouter, HTTPException |
|
|
from datetime import datetime |
|
|
import logging |
|
|
|
|
|
from models import ( |
|
|
PredictionRequest, |
|
|
PredictionResponse, |
|
|
BatchPredictionRequest, |
|
|
BatchPredictionResponse |
|
|
) |
|
|
|
|
|
from services import kpa_model_manager |
|
|
|
|
|
# Router on which the KPA endpoints defined in this module are registered.
router = APIRouter()

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@router.get("/model-info", tags=["KPA"])
async def get_model_info():
    """
    Return information about the loaded KPA model.

    The response contains `model_name`, `device`, `max_length`, `num_labels`
    and `loaded` as reported by the model manager (a fallback value is used
    for any field the manager does not report), plus a `timestamp` taken
    when the response is built.

    Responds with HTTP 500 if the model information cannot be retrieved.
    """
    try:
        model_info = kpa_model_manager.get_model_info()

        return {
            "model_name": model_info.get("model_name", "unknown"),
            "device": model_info.get("device", "cpu"),
            "max_length": model_info.get("max_length", 256),
            "num_labels": model_info.get("num_labels", 2),
            "loaded": model_info.get("loaded", False),
            # Naive local time in ISO 8601 format (no timezone offset).
            "timestamp": datetime.now().isoformat(),
        }

    except Exception as e:
        # logger.exception records the traceback along with the message, so
        # the root cause is not lost once the error is turned into a 500.
        logger.exception("Model info error: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get model info: {str(e)}",
        ) from e
|
|
|
|
|
|
|
|
@router.post("/predict", response_model=PredictionResponse, tags=["KPA"])
async def predict_kpa(request: PredictionRequest):
    """
    Predict keypoint-argument matching for a single pair.

    - **argument**: The argument text
    - **key_point**: The key point to evaluate

    Returns the predicted class (apparie / non_apparie) with probabilities.

    Responds with HTTP 500 if the prediction fails.
    """
    try:
        result = kpa_model_manager.predict(
            argument=request.argument,
            key_point=request.key_point,
        )

        # Copy the fields of the model manager's result into the API schema.
        response = PredictionResponse(
            prediction=result["prediction"],
            confidence=result["confidence"],
            label=result["label"],
            probabilities=result["probabilities"],
        )

        logger.info(
            "KPA Prediction: %s (conf=%.4f)",
            response.label,
            response.confidence,
        )

        return response

    except Exception as e:
        # logger.exception records the traceback along with the message, so
        # the root cause is not lost once the error is turned into a 500.
        logger.exception("KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}",
        ) from e
|
|
|
|
|
|
|
|
# Value stored in `prediction` for batch items whose prediction raised.
_FAILED_PREDICTION = -1


def _summarize_batch(results):
    """Build the summary dict (class counts, mean confidence, success/failure counts) for a batch."""
    successful = [r for r in results if r.prediction != _FAILED_PREDICTION]

    # Guard the division: with no successful item the average is reported as 0.0.
    if successful:
        average_confidence = sum(r.confidence for r in successful) / len(successful)
    else:
        average_confidence = 0.0

    return {
        "total_apparie": sum(1 for r in successful if r.prediction == 1),
        "total_non_apparie": sum(1 for r in successful if r.prediction == 0),
        "average_confidence": round(average_confidence, 4),
        "successful_predictions": len(successful),
        "failed_predictions": len(results) - len(successful),
    }


@router.post("/batch-predict", response_model=BatchPredictionResponse, tags=["KPA"])
async def batch_predict_kpa(request: BatchPredictionRequest):
    """
    Predict keypoint-argument matching for multiple argument/keypoint pairs.

    - **pairs**: List of items to classify

    Returns predictions for all pairs, in request order, with a summary.
    A pair whose prediction fails does not abort the batch: it is returned
    with `prediction` -1 and `label` "error", and is counted under
    `failed_predictions` in the summary.

    Responds with HTTP 500 if the batch as a whole cannot be processed.
    """
    try:
        results = []

        for index, item in enumerate(request.pairs):
            try:
                result = kpa_model_manager.predict(
                    argument=item.argument,
                    key_point=item.key_point,
                )

                results.append(
                    PredictionResponse(
                        prediction=result["prediction"],
                        confidence=result["confidence"],
                        label=result["label"],
                        probabilities=result["probabilities"],
                    )
                )

            except Exception:
                # Best effort per item: keep processing the rest of the batch,
                # but log which item failed and why instead of hiding it.
                logger.exception("KPA prediction failed for batch item %d", index)
                results.append(
                    PredictionResponse(
                        prediction=_FAILED_PREDICTION,
                        confidence=0.0,
                        label="error",
                        probabilities={"error": 1.0},
                    )
                )

        summary = _summarize_batch(results)

        logger.info(
            "Batch KPA prediction completed — %d items processed",
            len(results),
        )

        return BatchPredictionResponse(
            predictions=results,
            total_processed=len(results),
            summary=summary,
        )

    except Exception as e:
        # logger.exception records the traceback along with the message, so
        # the root cause is not lost once the error is turned into a 500.
        logger.exception("Batch KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Batch prediction failed: {str(e)}",
        ) from e