File size: 4,877 Bytes
03d23e8
 
 
 
 
 
 
 
 
 
 
 
 
f28285b
03d23e8
 
 
 
 
 
 
 
 
 
 
f28285b
03d23e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f28285b
03d23e8
 
 
 
 
 
 
f28285b
03d23e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3afc11b
03d23e8
 
 
 
 
 
 
 
 
 
f28285b
03d23e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a40c79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03d23e8
 
 
f28285b
8a40c79
 
03d23e8
 
 
 
8a40c79
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Keypoint–Argument Matching Endpoints"""

from fastapi import APIRouter, HTTPException
from datetime import datetime
import logging

from models import (
    PredictionRequest,
    PredictionResponse,
    BatchPredictionRequest,
    BatchPredictionResponse
)

from services import kpa_model_manager

# Router holding the KPA endpoints defined below; presumably included into the
# FastAPI application elsewhere (not visible in this module) — confirm.
router = APIRouter()
# Module-level logger named after this module, used by every endpoint below.
logger = logging.getLogger(__name__)


@router.get("/model-info", tags=["KPA"])
async def get_model_info():
    """
    Return information about the loaded KPA model.

    The response contains `model_name`, `device`, `max_length`, `num_labels`
    and `loaded`, read from the model manager (each falls back to a fixed
    default when the manager does not report it), plus a `timestamp` in
    server local time (ISO 8601).

    Responds with HTTP 500 if the model information cannot be retrieved.
    """
    try:
        model_info = kpa_model_manager.get_model_info()

        return {
            "model_name": model_info.get("model_name", "unknown"),
            "device": model_info.get("device", "cpu"),
            "max_length": model_info.get("max_length", 256),
            "num_labels": model_info.get("num_labels", 2),
            "loaded": model_info.get("loaded", False),
            # NOTE(review): naive local time; switching to an aware UTC
            # timestamp would change the string clients receive — confirm first.
            "timestamp": datetime.now().isoformat(),
        }

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Model info error: %s", e)
        # Chain the original error so the root cause survives in tracebacks.
        raise HTTPException(
            status_code=500, detail=f"Failed to get model info: {e}"
        ) from e


@router.post("/predict", response_model=PredictionResponse, tags=["KPA"])
async def predict_kpa(request: PredictionRequest):
    """
    Predict keypoint-argument matching for a single pair.

    - **argument**: The argument text  
    - **key_point**: The key point to evaluate  

    Returns the predicted class (apparie / non_apparie) with probabilities.
    Responds with HTTP 500 if the prediction fails.
    """
    try:
        # The manager's result is read through the keys "prediction",
        # "confidence", "label" and "probabilities".
        result = kpa_model_manager.predict(
            argument=request.argument,
            key_point=request.key_point,
        )

        response = PredictionResponse(
            prediction=result["prediction"],
            confidence=result["confidence"],
            label=result["label"],
            probabilities=result["probabilities"],
        )

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("KPA prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e

    # Logged outside the try so a logging/formatting problem can never turn a
    # successful prediction into a 500.
    logger.info(
        "KPA Prediction: %s (conf=%.4f)", response.label, response.confidence
    )

    return response


# Sentinel `prediction` value marking a batch item whose model call failed.
_FAILED_PREDICTION = -1


def _summarize_batch(results):
    """Aggregate batch predictions into the summary dict of the batch response.

    Items whose `prediction` equals `_FAILED_PREDICTION` count as failures;
    all others count as successes. `average_confidence` is the mean confidence
    of the successful items rounded to 4 decimals, or 0.0 when none succeeded.
    """
    successful = [r for r in results if r.prediction != _FAILED_PREDICTION]
    n_ok = len(successful)
    average_confidence = (
        round(sum(r.confidence for r in successful) / n_ok, 4) if n_ok else 0.0
    )
    return {
        "total_apparie": sum(1 for r in successful if r.prediction == 1),
        "total_non_apparie": sum(1 for r in successful if r.prediction == 0),
        "average_confidence": average_confidence,
        "successful_predictions": n_ok,
        "failed_predictions": len(results) - n_ok,
    }


@router.post("/batch-predict", response_model=BatchPredictionResponse, tags=["KPA"])
async def batch_predict_kpa(request: BatchPredictionRequest):
    """
    Predict keypoint-argument matching for multiple argument/keypoint pairs.

    - **pairs**: List of items to classify

    Returns predictions for all pairs. A pair whose prediction fails is
    reported as an item with label "error" instead of failing the whole batch.
    """
    try:
        results = []

        # NOTE(review): kpa_model_manager.predict is called synchronously inside
        # an async endpoint, so each call blocks the event loop while it runs.
        for index, item in enumerate(request.pairs):
            try:
                result = kpa_model_manager.predict(
                    argument=item.argument,
                    key_point=item.key_point,
                )

                response = PredictionResponse(
                    prediction=result["prediction"],
                    confidence=result["confidence"],
                    label=result["label"],
                    probabilities=result["probabilities"],
                )

            except Exception:
                # Best-effort per item: one bad pair must not fail the batch,
                # but the failure is logged instead of silently swallowed.
                logger.exception("KPA batch item %d failed", index)
                response = PredictionResponse(
                    prediction=_FAILED_PREDICTION,
                    confidence=0.0,
                    label="error",
                    probabilities={"error": 1.0},
                )

            results.append(response)

        summary = _summarize_batch(results)

        logger.info(
            "Batch KPA prediction completed — %d items processed", len(results)
        )

        return BatchPredictionResponse(
            predictions=results,
            total_processed=len(results),
            summary=summary,
        )

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Batch KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {e}"
        ) from e