File size: 4,123 Bytes
03d23e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Model manager for keypoint–argument matching model"""

import os
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class KpaModelManager:
    """Load the keypoint-matching classifier and run inference with it.

    The manager starts empty; ``load_model`` must be called once before
    ``predict``. Attributes:

    * ``model_name`` -- identifier passed to ``from_pretrained`` for both the
      tokenizer and the base architecture.
    * ``num_labels`` -- number of output classes requested from the
      classification head.
    * ``max_length`` -- token length every input pair is truncated/padded to.
    * ``device`` -- ``torch.device`` chosen in ``load_model`` (``None`` before).
    * ``model_loaded`` -- ``True`` once ``load_model`` has completed.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.max_length = 256
        # Kept on the instance so get_model_info() can report them even
        # before load_model() has run.
        self.model_name = "distilbert-base-uncased"
        self.num_labels = 2

    def load_model(self, model_path: str = None) -> None:
        """Load the tokenizer, base architecture and fine-tuned weights.

        Does nothing (beyond logging) when the model is already loaded.

        Args:
            model_path: Path to the checkpoint file. When ``None``, defaults
                to ``models/modele_appariement_rapide.pth`` under the
                directory two levels above this file.

        Raises:
            RuntimeError: If anything fails while loading, including a
                missing checkpoint file. The original exception is attached
                as ``__cause__``.
        """
        if self.model_loaded:
            logger.info("KPA model already loaded")
            return

        try:
            # Detect device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info("Using device: %s", self.device)

            # Resolve model path
            if model_path is None:
                base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
                model_path = os.path.join(base_dir, "models", "modele_appariement_rapide.pth")

            logger.info("Loading KPA model weights from: %s", model_path)

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            # Load tokenizer + architecture
            logger.info("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            logger.info("Loading base architecture...")
            model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=self.num_labels,
            )

            # Load trained weights.
            # NOTE(review): torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            checkpoint = torch.load(model_path, map_location=self.device)

            # The checkpoint is either a wrapper dict holding the weights
            # under "model_state_dict" or the state dict itself.
            if "model_state_dict" in checkpoint:
                model.load_state_dict(checkpoint["model_state_dict"])
            else:
                model.load_state_dict(checkpoint)

            model.to(self.device)
            model.eval()

            # Publish only after everything succeeded, so a failed load does
            # not leave a tokenizer/model pair with untrained weights behind.
            self.tokenizer = tokenizer
            self.model = model
            self.model_loaded = True
            logger.info("✓ KPA model loaded successfully!")

        except Exception as e:
            logger.error("Error loading KPA model: %s", e)
            raise RuntimeError(f"Failed to load KPA model: {str(e)}") from e

    def predict(self, argument: str, key_point: str) -> dict:
        """Classify whether ``argument`` matches ``key_point``.

        Args:
            argument: First text of the pair passed to the tokenizer.
            key_point: Second text of the pair passed to the tokenizer.

        Returns:
            A dict with:

            * ``"prediction"`` -- index of the highest-probability class.
            * ``"confidence"`` -- softmax probability of that class.
            * ``"label"`` -- ``"apparie"`` for class 1, else ``"non_apparie"``.
            * ``"probabilities"`` -- softmax probability of class 0
              (``"non_apparie"``) and class 1 (``"apparie"``).

        Raises:
            RuntimeError: If the model has not been loaded, or if
                tokenization/inference fails (original exception attached as
                ``__cause__``).
        """
        if not self.model_loaded:
            raise RuntimeError("KPA model not loaded")

        try:
            # Tokenize input
            encoding = self.tokenizer(
                argument,
                key_point,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)

            # Forward pass without gradient tracking
            with torch.no_grad():
                outputs = self.model(**encoding)
                probabilities = torch.softmax(outputs.logits, dim=-1)

                predicted_class = torch.argmax(probabilities, dim=-1).item()
                confidence = probabilities[0][predicted_class].item()

            return {
                "prediction": predicted_class,
                "confidence": confidence,
                "label": "apparie" if predicted_class == 1 else "non_apparie",
                "probabilities": {
                    "non_apparie": probabilities[0][0].item(),
                    "apparie": probabilities[0][1].item(),
                },
            }

        except Exception as e:
            logger.error("Error during prediction: %s", e)
            raise RuntimeError(f"KPA prediction failed: {str(e)}") from e

    def get_model_info(self) -> dict:
        """Return a summary of the manager's configuration and load state.

        Safe to call before ``load_model``; ``"device"`` is then the string
        ``"None"`` and ``"loaded"`` is ``False``.
        """
        return {
            "model_name": self.model_name,
            "device": str(self.device),
            "max_length": self.max_length,
            "num_labels": self.num_labels,
            "loaded": self.model_loaded,
        }



# Shared module-level instance. Constructing it loads nothing; call
# load_model() on it before predict().
kpa_model_manager = KpaModelManager()