Yassine Mhirsi
Add label/KPA-related schemas to models and update documentation in label routes
3afc11b
| """Keypoint–Argument Matching Endpoints""" | |
| from fastapi import APIRouter, HTTPException | |
| from datetime import datetime | |
| import logging | |
| from models import ( | |
| PredictionRequest, | |
| PredictionResponse, | |
| BatchPredictionRequest, | |
| BatchPredictionResponse | |
| ) | |
| from services import kpa_model_manager | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router grouping the keypoint–argument matching endpoints.
# NOTE(review): no @router.<method>(...) decorator is visible on the handlers in
# this file, so nothing appears to be registered on this router — confirm the
# decorators were not lost, otherwise these endpoints are unreachable.
router = APIRouter()
async def get_model_info():
    """
    Return information about the loaded KPA model.

    - **model_name**, **device**, **max_length**, **num_labels**, **loaded**:
      values reported by the model manager (defaults: "unknown", "cpu", 256, 2, false)
    - **timestamp**: server-local time of the request, ISO 8601 without timezone offset

    Responds with HTTP 500 if the model information cannot be retrieved.
    """
    try:
        model_info = kpa_model_manager.get_model_info()
        return {
            "model_name": model_info.get("model_name", "unknown"),
            "device": model_info.get("device", "cpu"),
            "max_length": model_info.get("max_length", 256),
            "num_labels": model_info.get("num_labels", 2),
            "loaded": model_info.get("loaded", False),
            # Naive local time (no tz offset); format unchanged for existing clients.
            "timestamp": datetime.now().isoformat(),
        }
    except Exception as e:
        # logger.exception logs at ERROR level *with* the traceback, which the
        # previous logger.error(str(e)) discarded.
        logger.exception("Model info error: %s", e)
        # Chain the original error so it stays visible as __cause__.
        raise HTTPException(
            status_code=500, detail=f"Failed to get model info: {e}"
        ) from e
async def predict_kpa(request: PredictionRequest):
    """
    Predict keypoint-argument matching for a single pair.
    - **argument**: The argument text
    - **key_point**: The key point to evaluate
    Returns the predicted class (apparie / non_apparie) with probabilities.
    Responds with HTTP 500 if the prediction fails.
    """
    # (Internal notes are kept in # comments: FastAPI publishes the docstring
    # above as the endpoint description in the OpenAPI docs.)
    try:
        # NOTE(review): predict() is called synchronously inside an async handler,
        # so inference runs on the event-loop thread and no other request on this
        # loop progresses meanwhile. Offloading it (e.g. run_in_threadpool) would
        # require the model manager to be safe for concurrent use — confirm first.
        result = kpa_model_manager.predict(
            argument=request.argument,
            key_point=request.key_point,
        )
        response = PredictionResponse(
            prediction=result["prediction"],
            confidence=result["confidence"],
            label=result["label"],
            probabilities=result["probabilities"],
        )
        logger.info(
            f"KPA Prediction: {response.label} "
            f"(conf={response.confidence:.4f})"
        )
        return response
    except Exception as e:
        # Keep the traceback in the logs (logger.error(str(e)) dropped it).
        logger.exception("KPA prediction error: %s", e)
        # Chain the original error so it stays visible as __cause__.
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e
async def batch_predict_kpa(request: BatchPredictionRequest):
    """
    Predict keypoint-argument matching for multiple argument/keypoint pairs.
    - **pairs**: List of argument/key-point items to classify
    Returns predictions for all pairs, in request order. A pair whose prediction
    fails does not fail the batch: it is returned as a placeholder with
    `prediction=-1`, `confidence=0.0`, `label="error"` and
    `probabilities={"error": 1.0}`. `total_processed` counts every pair,
    including failed ones.
    """
    # (Internal notes are kept in # comments: FastAPI publishes the docstring
    # above as the endpoint description in the OpenAPI docs.)
    try:
        results = []
        failed = 0
        for index, item in enumerate(request.pairs):
            try:
                # NOTE(review): synchronous call on the event-loop thread; a large
                # batch blocks other requests on this loop until it finishes.
                result = kpa_model_manager.predict(
                    argument=item.argument,
                    key_point=item.key_point,
                )
                response = PredictionResponse(
                    prediction=result["prediction"],
                    confidence=result["confidence"],
                    label=result["label"],
                    probabilities=result["probabilities"],
                )
            except Exception:
                # Deliberate best-effort: one bad pair must not fail the whole
                # batch. Previously the error was swallowed without any trace;
                # now it is logged with its position and traceback.
                failed += 1
                logger.exception("Batch KPA prediction failed for pair #%d", index)
                response = PredictionResponse(
                    prediction=-1,
                    confidence=0.0,
                    label="error",
                    probabilities={"error": 1.0},
                )
            results.append(response)

        logger.info(f"Batch KPA prediction completed — {len(results)} items processed")
        if failed:
            logger.warning(
                "Batch KPA prediction: %d of %d pairs failed", failed, len(results)
            )
        return BatchPredictionResponse(
            predictions=results,
            total_processed=len(results),
        )
    except Exception as e:
        # Keep the traceback in the logs (logger.error(str(e)) dropped it).
        logger.exception("Batch KPA prediction error: %s", e)
        # Chain the original error so it stays visible as __cause__.
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {e}"
        ) from e