malek-messaoudii committed on
Commit
03d23e8
·
1 Parent(s): e1f71bd

Add Kpa classification model

Browse files
Files changed (3) hide show
  1. models/label.py +98 -1
  2. routes/label.py +123 -0
  3. services/label_model_manage.py +122 -0
models/label.py CHANGED
@@ -1 +1,98 @@
1
- #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic schemas for key-point matching prediction endpoints"""
2
+
3
+ from pydantic import BaseModel, Field, ConfigDict
4
+ from typing import List, Optional, Dict
5
+
6
+
7
class PredictionRequest(BaseModel):
    """Request model for single key-point/argument prediction"""

    # Both texts are required; pydantic enforces the length bounds at
    # validation time (5..1000 characters for the argument, 5..500 for the
    # key point).
    argument: str = Field(
        ...,
        min_length=5,
        max_length=1000,
        description="The argument text to evaluate",
    )
    key_point: str = Field(
        ...,
        min_length=5,
        max_length=500,
        description="The key point used for comparison",
    )

    # Sample payload attached to the generated JSON schema.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "argument": "Climate change is accelerating due to industrial emissions.",
                "key_point": "Human industry contributes significantly to global warming.",
            },
        },
    )
26
+
27
+
28
class PredictionResponse(BaseModel):
    """Response model for single prediction"""

    # The example and field descriptions use the label strings that the
    # prediction service actually returns ("apparie" / "non_apparie"), so the
    # generated API documentation matches real responses.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "prediction": 1,
                "confidence": 0.874,
                "label": "apparie",
                "probabilities": {
                    "non_apparie": 0.126,
                    "apparie": 0.874
                }
            }
        }
    )

    # Predicted class index. The batch endpoint uses -1 as a placeholder for
    # a pair whose prediction raised an exception.
    prediction: int = Field(
        ...,
        description="1 = match, 0 = no match, -1 = failed item in a batch"
    )
    # Probability of the predicted class, bounded to [0, 1].
    confidence: float = Field(..., ge=0.0, le=1.0,
                              description="Confidence score of the prediction")
    # Human-readable class name matching `prediction`.
    label: str = Field(
        ...,
        description="'apparie' (match) or 'non_apparie' (no match); "
                    "'error' for a failed batch item"
    )
    # Mapping from class name to probability.
    probabilities: Dict[str, float] = Field(
        ..., description="Dictionary of class probabilities"
    )
51
+
52
+
53
class BatchPredictionRequest(BaseModel):
    """Request model for batch predictions"""

    # Each entry is validated as a PredictionRequest; the list itself is
    # capped at 100 entries.
    pairs: List[PredictionRequest] = Field(
        ...,
        max_length=100,
        description="List of argument-keypoint pairs (max 100)",
    )

    # Sample payload attached to the generated JSON schema.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "pairs": [
                    {
                        "argument": "Schools should implement AI tools to support learning.",
                        "key_point": "AI can improve student engagement.",
                    },
                    {
                        "argument": "Governments must reduce plastic usage.",
                        "key_point": "Plastic waste harms the environment.",
                    },
                ],
            },
        },
    )
76
+
77
+
78
class BatchPredictionResponse(BaseModel):
    """Response model for batch key-point predictions"""

    # One PredictionResponse per submitted pair.
    predictions: List[PredictionResponse]
    # Required count; the caller building the response must supply it.
    total_processed: int = Field(
        ...,
        description="Number of processed items",
    )
82
+
83
+
84
class HealthResponse(BaseModel):
    """Health check model for the API"""

    # All three fields are required.
    status: str = Field(
        ...,
        description="API health status",
    )
    model_loaded: bool = Field(
        ...,
        description="Whether the model is loaded",
    )
    device: str = Field(
        ...,
        description="Device used for inference (cpu/cuda)",
    )

    # Sample payload attached to the generated JSON schema.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "status": "ok",
                "model_loaded": True,
                "device": "cuda",
            },
        },
    )
routes/label.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Keypoint–Argument Matching Endpoints"""
2
+
3
from datetime import datetime
import logging

from fastapi import APIRouter, HTTPException

from models import (
    PredictionRequest,
    PredictionResponse,
    BatchPredictionRequest,
    BatchPredictionResponse
)

from services import label_model_manage  # module holding the clean ModelManager
from services.label_model_manage import kpa_model_manager
15
+
16
+ router = APIRouter()
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
@router.get("/model-info", tags=["KPA"])
async def get_model_info():
    """
    Return information about the loaded KPA model.

    The response contains `model_name`, `device`, `max_length`, `num_labels`,
    `loaded` and an ISO-formatted `timestamp`. Any failure while collecting
    the information is reported as HTTP 500.
    """
    try:
        # Ask the manager singleton: `get_model_info` is a method of the
        # manager instance, not a function of the `label_model_manage` module.
        model_info = kpa_model_manager.get_model_info()

        return {
            "model_name": model_info.get("model_name", "unknown"),
            "device": model_info.get("device", "cpu"),
            "max_length": model_info.get("max_length", 256),
            "num_labels": model_info.get("num_labels", 2),
            "loaded": model_info.get("loaded", False),
            "timestamp": datetime.now().isoformat()
        }

    except Exception as e:
        # logger.exception keeps the traceback; `from e` preserves the cause.
        logger.exception("Model info error: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get model info: {str(e)}"
        ) from e
40
+
41
+
42
@router.post("/predict", response_model=PredictionResponse, tags=["KPA"])
async def predict_kpa(request: PredictionRequest):
    """
    Predict keypoint–argument matching for a single pair.

    - **argument**: The argument text
    - **key_point**: The key point to evaluate

    Returns the predicted class (apparie / non_apparie) with probabilities.
    Any failure during prediction is reported as HTTP 500.
    """
    try:
        # Run inference through the manager singleton: `predict` is a method
        # of the manager instance, not a function of the module.
        result = kpa_model_manager.predict(
            argument=request.argument,
            key_point=request.key_point
        )

        response = PredictionResponse(
            prediction=result["prediction"],
            confidence=result["confidence"],
            label=result["label"],
            probabilities=result["probabilities"]
        )

        logger.info(
            f"KPA Prediction: {response.label} "
            f"(conf={response.confidence:.4f})"
        )

        return response

    except Exception as e:
        # logger.exception keeps the traceback; `from e` preserves the cause.
        logger.exception("KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}"
        ) from e
75
+
76
+
77
@router.post("/batch-predict", response_model=BatchPredictionResponse, tags=["KPA"])
async def batch_predict_kpa(request: BatchPredictionRequest):
    """
    Predict keypoint–argument matching for multiple argument/keypoint pairs.

    - **pairs**: List of items to classify

    Returns predictions for all pairs. A pair whose prediction fails is
    reported in place with `prediction=-1` and `label="error"` instead of
    failing the whole batch.
    """
    try:
        results = []

        for index, item in enumerate(request.pairs):
            try:
                # `predict` is a method of the manager singleton, not a
                # function of the `label_model_manage` module.
                result = kpa_model_manager.predict(
                    argument=item.argument,
                    key_point=item.key_point
                )

                results.append(
                    PredictionResponse(
                        prediction=result["prediction"],
                        confidence=result["confidence"],
                        label=result["label"],
                        probabilities=result["probabilities"]
                    )
                )

            except Exception as item_error:
                # Best-effort batch: keep going, but record why the item
                # failed instead of swallowing the error silently.
                logger.warning(
                    "Batch KPA prediction failed for item %d: %s",
                    index, item_error
                )
                results.append(
                    PredictionResponse(
                        prediction=-1,
                        confidence=0.0,
                        label="error",
                        probabilities={"error": 1.0}
                    )
                )

        logger.info(f"Batch KPA prediction completed — {len(results)} items processed")

        # `total_processed` is a required field of BatchPredictionResponse;
        # omitting it makes response construction fail validation.
        return BatchPredictionResponse(
            predictions=results,
            total_processed=len(results)
        )

    except Exception as e:
        # logger.exception keeps the traceback; `from e` preserves the cause.
        logger.exception("Batch KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Batch prediction failed: {str(e)}"
        ) from e
services/label_model_manage.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model manager for keypoint–argument matching model"""
2
+
3
+ import os
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
class KpaModelManager:
    """Manages loading and inference for keypoint matching model"""

    def __init__(self):
        """Create an empty manager; nothing is loaded until `load_model`."""
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.max_length = 256
        # Base architecture/tokenizer identifier. Stored on the instance so
        # `get_model_info` can report it even before the model is loaded.
        self.model_name = "distilbert-base-uncased"

    def load_model(self, model_path: str = None):
        """Load fine-tuned model weights and tokenizer.

        Args:
            model_path: Path to the checkpoint file. When None, defaults to
                `models/modele_appariement_rapide.pth` under the parent of
                this file's directory.

        Raises:
            RuntimeError: If anything goes wrong while loading, including a
                missing checkpoint file. The original error is chained.
        """
        if self.model_loaded:
            logger.info("KPA model already loaded")
            return

        try:
            # Detect device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")

            # Resolve model path
            if model_path is None:
                base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
                model_path = os.path.join(base_dir, "models", "modele_appariement_rapide.pth")

            logger.info(f"Loading KPA model weights from: {model_path}")

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            # Load tokenizer + architecture
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            logger.info("Loading base architecture...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=2
            )

            # Load trained weights.
            # NOTE(security): torch.load unpickles the file, so only load
            # checkpoints from a trusted source. If the checkpoint contains
            # only tensors and plain containers, consider passing
            # weights_only=True.
            checkpoint = torch.load(model_path, map_location=self.device)

            # Accept either a wrapped checkpoint or a bare state dict.
            if "model_state_dict" in checkpoint:
                self.model.load_state_dict(checkpoint["model_state_dict"])
            else:
                self.model.load_state_dict(checkpoint)

            self.model.to(self.device)
            self.model.eval()

            self.model_loaded = True
            logger.info("✓ KPA model loaded successfully!")

        except Exception as e:
            logger.error(f"Error loading KPA model: {str(e)}")
            raise RuntimeError(f"Failed to load KPA model: {str(e)}") from e

    def predict(self, argument: str, key_point: str) -> dict:
        """Run a prediction for (argument, key_point).

        Args:
            argument: The argument text.
            key_point: The key point text, encoded as the second sequence.

        Returns:
            A dict with `prediction` (class index), `confidence` (probability
            of the predicted class), `label` ("apparie" for class 1, otherwise
            "non_apparie") and `probabilities` for both classes.

        Raises:
            RuntimeError: If the model is not loaded or inference fails.
        """
        if not self.model_loaded:
            raise RuntimeError("KPA model not loaded")

        try:
            # Tokenize input
            encoding = self.tokenizer(
                argument,
                key_point,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt"
            ).to(self.device)

            # Forward pass
            with torch.no_grad():
                outputs = self.model(**encoding)
                logits = outputs.logits
                probabilities = torch.softmax(logits, dim=-1)

            predicted_class = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[0][predicted_class].item()

            return {
                "prediction": predicted_class,
                "confidence": confidence,
                "label": "apparie" if predicted_class == 1 else "non_apparie",
                "probabilities": {
                    "non_apparie": probabilities[0][0].item(),
                    "apparie": probabilities[0][1].item(),
                },
            }

        except Exception as e:
            logger.error(f"Error during prediction: {str(e)}")
            raise RuntimeError(f"KPA prediction failed: {str(e)}") from e

    def get_model_info(self):
        """Return a dict describing the manager's current state.

        Safe to call before `load_model`: `loaded` is then False and `device`
        is the string form of the unset device.
        """
        return {
            "model_name": self.model_name,
            "device": str(self.device),
            "max_length": self.max_length,
            "num_labels": 2,
            "loaded": self.model_loaded
        }
118
+
119
+
120
+
121
+ # Singleton instance
122
+ kpa_model_manager = KpaModelManager()