| """Pydantic schemas for key-point matching prediction endpoints""" | |
| from pydantic import BaseModel, Field, ConfigDict | |
| from typing import List, Optional, Dict | |
class PredictionRequest(BaseModel):
    """Input schema for one argument / key-point matching call.

    Both texts are required; the length bounds keep payloads within the
    sizes the downstream model was trained to handle.
    """

    # The argument side of the pair (5-1000 chars).
    argument: str = Field(
        ...,
        min_length=5,
        max_length=1000,
        description="The argument text to evaluate",
    )
    # The candidate key point it is compared against (5-500 chars).
    key_point: str = Field(
        ...,
        min_length=5,
        max_length=500,
        description="The key point used for comparison",
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "argument": "Climate change is accelerating due to industrial emissions.",
                "key_point": "Human industry contributes significantly to global warming."
            }
        }
    )
class PredictionResponse(BaseModel):
    """Output schema for one matching prediction.

    ``prediction`` is the hard class (1 = apparie, 0 = non_apparie),
    ``confidence`` the probability of the predicted class, and
    ``probabilities`` the full per-class distribution.
    """

    # Hard class decision taken by the model.
    prediction: int = Field(..., description="1 = apparie, 0 = non_apparie")
    # Probability assigned to the predicted class, bounded to [0, 1].
    confidence: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Confidence score of the prediction",
    )
    # Human-readable class name matching ``prediction``.
    label: str = Field(..., description="apparie or non_apparie")
    # Probability for each class, keyed by class name.
    probabilities: Dict[str, float] = Field(
        ...,
        description="Dictionary of class probabilities",
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "prediction": 1,
                "confidence": 0.874,
                "label": "apparie",
                "probabilities": {
                    "non_apparie": 0.126,
                    "apparie": 0.874
                }
            }
        }
    )
class BatchPredictionRequest(BaseModel):
    """Input schema for scoring several argument / key-point pairs at once.

    Each element is validated as a full ``PredictionRequest``; the batch is
    capped at 100 pairs.
    """

    # Pairs to score; Pydantic enforces the per-pair constraints and the cap.
    pairs: List[PredictionRequest] = Field(
        ...,
        max_length=100,
        description="List of argument-keypoint pairs (max 100)",
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "pairs": [
                    {
                        "argument": "School uniforms limit students' self-expression and creativity",
                        "key_point": "Uniforms restrict personal freedom and individuality"
                    },
                    {
                        "argument": "Renewable energy creates more jobs than fossil fuel industries",
                        "key_point": "Green energy generates employment opportunities"
                    },
                    {
                        "argument": "We should invest more in renewable energy to combat climate change",
                        "key_point": "Capital punishment violates human rights"
                    },
                    {
                        "argument": "Online education provides flexibility and accessibility for all students",
                        "key_point": "Digital learning offers convenient and inclusive education"
                    },
                    {
                        "argument": "Vaccinations are essential for public health and disease prevention",
                        "key_point": "Space exploration leads to scientific discoveries"
                    }
                ]
            }
        }
    )
class BatchPredictionResponse(BaseModel):
    """Output schema for a batch prediction call.

    Mirrors the request order in ``predictions`` and adds aggregate
    statistics in ``summary`` (counts per class, average confidence).
    """

    # One result per input pair, in the same order as the request.
    predictions: List[PredictionResponse]
    # Echo of how many pairs were actually scored.
    total_processed: int = Field(..., description="Number of processed items")
    # Aggregate stats; defaults to an empty dict when the caller omits it.
    summary: Dict[str, float] = Field(
        default_factory=dict,
        description="Summary statistics of the batch prediction",
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "predictions": [
                    {
                        "prediction": 1,
                        "confidence": 0.956,
                        "label": "apparie",
                        "probabilities": {
                            "non_apparie": 0.044,
                            "apparie": 0.956
                        }
                    },
                    {
                        "prediction": 1,
                        "confidence": 0.892,
                        "label": "apparie",
                        "probabilities": {
                            "non_apparie": 0.108,
                            "apparie": 0.892
                        }
                    },
                    {
                        "prediction": 0,
                        "confidence": 0.934,
                        "label": "non_apparie",
                        "probabilities": {
                            "non_apparie": 0.934,
                            "apparie": 0.066
                        }
                    },
                    {
                        "prediction": 1,
                        "confidence": 0.878,
                        "label": "apparie",
                        "probabilities": {
                            "non_apparie": 0.122,
                            "apparie": 0.878
                        }
                    },
                    {
                        "prediction": 0,
                        "confidence": 0.967,
                        "label": "non_apparie",
                        "probabilities": {
                            "non_apparie": 0.967,
                            "apparie": 0.033
                        }
                    }
                ],
                "total_processed": 5,
                "summary": {
                    "total_apparie": 3,
                    "total_non_apparie": 2,
                    "average_confidence": 0.9254
                }
            }
        }
    )
class HealthResponse(BaseModel):
    """Health check model for the API.

    Reports service status, whether the inference model is loaded, the
    device in use, and the check timestamp.
    """
    model_config = ConfigDict(
        # Fields named ``model_loaded`` / ``model_name`` collide with
        # Pydantic v2's protected ``model_`` namespace and would emit a
        # UserWarning at class-definition time; disabling the protected
        # namespaces silences it without changing the JSON schema.
        protected_namespaces=(),
        json_schema_extra={
            "example": {
                "status": "healthy",
                "model_loaded": True,
                "device": "cpu",
                "model_name": "NLP-Debater-Project/destlibert-keypoint-matching",
                "timestamp": "2024-01-01T12:00:00Z"
            }
        }
    )
    status: str = Field(..., description="API health status")
    model_loaded: bool = Field(..., description="Whether the model is loaded")
    device: str = Field(..., description="Device used for inference (cpu/cuda)")
    model_name: Optional[str] = Field(None, description="Name of the loaded model")
    # ISO-8601 string, e.g. "2024-01-01T12:00:00Z" — presumably UTC; confirm with caller.
    timestamp: str = Field(..., description="Timestamp of the health check")
class ModelInfoResponse(BaseModel):
    """Detailed model information response.

    Exposes model identity, inference configuration, load state, and the
    evaluation metrics reported for the checkpoint.
    """
    model_config = ConfigDict(
        # ``model_name`` collides with Pydantic v2's protected ``model_``
        # namespace and would emit a UserWarning at class-definition time;
        # disabling protected namespaces silences it without changing the
        # JSON schema.
        protected_namespaces=(),
        json_schema_extra={
            "example": {
                "model_name": "NLP-Debater-Project/destlibert-keypoint-matching",
                "device": "cpu",
                "max_length": 256,
                "num_labels": 2,
                "loaded": True,
                "performance": {
                    "accuracy": 0.9285,
                    "f1_score": 0.8836,
                    "f1_apparie": 0.8113,
                    "f1_non_apparie": 0.9559
                },
                "description": "DistilBERT model for key point - argument semantic matching"
            }
        }
    )
    # Hugging Face-style model identifier.
    model_name: str
    # Inference device ("cpu" / "cuda").
    device: str
    # Tokenizer truncation length used at inference.
    max_length: int
    # Number of output classes (2: apparie / non_apparie).
    num_labels: int
    # Whether the checkpoint is currently loaded in memory.
    loaded: bool
    performance: Dict[str, float] = Field(
        ..., description="Model performance metrics"
    )
    description: str = Field(..., description="Model description")