File size: 3,173 Bytes
03d23e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""Pydantic schemas for key-point matching prediction endpoints"""

from pydantic import BaseModel, Field, ConfigDict
from typing import List, Optional, Dict


class PredictionRequest(BaseModel):
    """Input payload for scoring one argument against one key point.

    Both text fields are required, and Pydantic enforces their length bounds
    when the request is validated.
    """

    # Argument text: between 5 and 1000 characters.
    argument: str = Field(
        description="The argument text to evaluate",
        min_length=5,
        max_length=1000,
    )
    # Key point text: between 5 and 500 characters.
    key_point: str = Field(
        description="The key point used for comparison",
        min_length=5,
        max_length=500,
    )

    # Example payload surfaced in the generated JSON schema / API docs.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "argument": "Climate change is accelerating due to industrial emissions.",
                "key_point": "Human industry contributes significantly to global warming.",
            },
        },
    )


class PredictionResponse(BaseModel):
    """Response model for a single argument/key-point prediction.

    Validation on construction guarantees that ``prediction`` is 0 or 1 and
    that ``confidence`` lies in the closed interval [0, 1].
    """
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "prediction": 1,
                "confidence": 0.874,
                "label": "MATCH",
                "probabilities": {
                    "match": 0.874,
                    "no_match": 0.126
                }
            }
        }
    )

    # Binary class id. The bounds enforce the documented 0/1 contract and also
    # appear as minimum/maximum in the generated JSON schema.
    prediction: int = Field(..., ge=0, le=1,
                            description="1 = match, 0 = no match")
    confidence: float = Field(..., ge=0.0, le=1.0,
                              description="Confidence score of the prediction")
    label: str = Field(..., description="MATCH or NO_MATCH")
    # Keys are class names (the example uses "match" / "no_match").
    probabilities: Dict[str, float] = Field(
        ..., description="Dictionary of class probabilities"
    )


class BatchPredictionRequest(BaseModel):
    """Request body carrying several argument/key-point pairs at once.

    Every entry is validated as a ``PredictionRequest``; the list may contain
    at most 100 entries.
    """

    # NOTE(review): there is no lower bound, so an empty list is accepted —
    # confirm that callers are expected to handle an empty batch.
    pairs: List[PredictionRequest] = Field(
        description="List of argument-keypoint pairs (max 100)",
        max_length=100,
    )

    # Example payload surfaced in the generated JSON schema / API docs.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "pairs": [
                    {
                        "argument": "Schools should implement AI tools to support learning.",
                        "key_point": "AI can improve student engagement.",
                    },
                    {
                        "argument": "Governments must reduce plastic usage.",
                        "key_point": "Plastic waste harms the environment.",
                    },
                ],
            },
        },
    )


class BatchPredictionResponse(BaseModel):
    """Response model for batch key-point predictions.

    ``predictions`` holds the individual ``PredictionResponse`` results and
    ``total_processed`` reports how many items were handled.
    """
    # Example payload, matching the documentation style of the other models.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "predictions": [
                    {
                        "prediction": 1,
                        "confidence": 0.874,
                        "label": "MATCH",
                        "probabilities": {
                            "match": 0.874,
                            "no_match": 0.126
                        }
                    }
                ],
                "total_processed": 1
            }
        }
    )

    predictions: List[PredictionResponse] = Field(
        ..., description="Individual prediction results"
    )
    # A count can never be negative; ge=0 also documents that in the schema.
    total_processed: int = Field(..., ge=0,
                                 description="Number of processed items")


class HealthResponse(BaseModel):
    """Health check model for the API.

    Reports the service status string, whether the model is loaded, and the
    device used for inference.
    """
    model_config = ConfigDict(
        # ``model_loaded`` begins with ``model_``, which Pydantic 2.0–2.9
        # treats as a protected namespace and warns about when the class is
        # created. None of these fields shadows a real BaseModel attribute,
        # so opt out of the namespace check.
        protected_namespaces=(),
        json_schema_extra={
            "example": {
                "status": "ok",
                "model_loaded": True,
                "device": "cuda"
            }
        }
    )

    status: str = Field(..., description="API health status")
    model_loaded: bool = Field(..., description="Whether the model is loaded")
    device: str = Field(..., description="Device used for inference (cpu/cuda)")