"""Model manager for keypoint–argument matching model"""

import logging
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

logger = logging.getLogger(__name__)


class KpaModelManager:
    """Manages loading and inference for keypoint matching model"""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.max_length = 256
        self.model_id = None

    def load_model(self, model_id: str, api_key: Optional[str] = None):
        """Load complete model and tokenizer directly from Hugging Face"""
        if self.model_loaded:
            logger.info("KPA model already loaded")
            return

        try:
            # Debug: check the input parameters
            logger.info("=== DEBUG KPA MODEL LOADING ===")
            logger.info(f"model_id received: {model_id}")
            logger.info(f"model_id type: {type(model_id)}")
            logger.info(f"api_key present: {api_key is not None}")
            
            if model_id is None:
                raise ValueError("model_id cannot be None - check your .env file")
                
            logger.info(f"Loading KPA model from Hugging Face: {model_id}")
            
            # Determine device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")
            
            # Store model ID
            self.model_id = model_id
            
            # Prepare token for authentication if API key is provided
            token = api_key if api_key else None
            
            # Load tokenizer and model directly from Hugging Face
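            # Note: trust_remote_code=True executes custom model code shipped
            # with the checkpoint; keep it enabled only for repos you trust.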
            logger.info("Step 1: Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True
            )
            logger.info("✓ Tokenizer loaded successfully")
            
            logger.info("Step 2: Loading model...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True
            )
            logger.info("✓ Model architecture loaded")
            
            self.model.to(self.device)
            self.model.eval()
            logger.info("✓ Model moved to device and set to eval mode")

            self.model_loaded = True
            logger.info("✓ KPA model loaded successfully from Hugging Face!")
            logger.info(f"=== KPA MODEL LOADING COMPLETE ===")

        except Exception as e:
            logger.error(f"❌ Error loading KPA model: {e}")
            logger.error(f"❌ Model ID was: {model_id}")
            logger.error(f"❌ API key present: {api_key is not None}")
            raise RuntimeError(f"Failed to load KPA model: {e}") from e

    def predict(self, argument: str, key_point: str) -> dict:
        """Run a prediction for (argument, key_point)"""
        if not self.model_loaded:
            raise RuntimeError("KPA model not loaded")

        try:
            # Tokenize input
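            # For BERT-style tokenizers, passing the (argument, key_point) pair
            # produces a single "[CLS] argument [SEP] key_point [SEP]" sequence,
            # which is what a cross-encoder sequence classifier expects.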
            encoding = self.tokenizer(
                argument,
                key_point,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt"
            ).to(self.device)

            # Forward pass
            with torch.no_grad():
                outputs = self.model(**encoding)
                logits = outputs.logits
                probabilities = torch.softmax(logits, dim=-1)

                predicted_class = torch.argmax(probabilities, dim=-1).item()
                confidence = probabilities[0][predicted_class].item()

            return {
                "prediction": predicted_class,
                "confidence": confidence,
                "label": "apparie" if predicted_class == 1 else "non_apparie",
                "probabilities": {
                    "non_apparie": probabilities[0][0].item(),
                    "apparie": probabilities[0][1].item(),
                },
            }

        except Exception as e:
            logger.error(f"Error during prediction: {e}")
            raise RuntimeError(f"KPA prediction failed: {e}") from e
        
    def get_model_info(self):
        """Get model information"""
        if not self.model_loaded:
            return {"loaded": False}
            
        return {
            "model_name": self.model_id,
            "device": str(self.device),
            "max_length": self.max_length,
            "num_labels": 2,
            "loaded": self.model_loaded
        }


# Initialize singleton instance
kpa_model_manager = KpaModelManager()
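

# A minimal usage sketch. Assumptions: KPA_MODEL_ID and HF_API_KEY are
# hypothetical environment variable names standing in for the values your
# .env file actually provides; substitute your own checkpoint ID and token.
if __name__ == "__main__":
    import os

    kpa_model_manager.load_model(
        model_id=os.environ["KPA_MODEL_ID"],   # hypothetical env var
        api_key=os.environ.get("HF_API_KEY"),  # hypothetical env var
    )
    result = kpa_model_manager.predict(
        argument="School uniforms reduce bullying by removing visible status symbols.",
        key_point="Uniforms reduce bullying.",
    )
    print(result["label"], result["confidence"])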