|
|
"""Model manager for keypoint–argument matching model""" |
|
|
|
|
|
import logging
import os
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
|
|
|
# Module-level logger named after this module; used by KpaModelManager for
# progress and error messages.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class KpaModelManager:
    """Manage loading and inference for the keypoint-argument matching model.

    A new manager holds no model. ``load_model`` builds the tokenizer and the
    sequence-classification model, loads the fine-tuned weights into it and
    marks the manager as loaded; ``predict`` then scores one
    (argument, key point) pair at a time.
    """

    def __init__(self):
        """Create an unloaded manager; nothing is read from disk here."""
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        # Maximum token length used when encoding an (argument, key point) pair.
        self.max_length = 256
        # Base checkpoint name and label count, kept on the instance so that
        # load_model() and get_model_info() share a single definition.
        self.model_name = "distilbert-base-uncased"
        self.num_labels = 2

    def load_model(self, model_path: Optional[str] = None):
        """Load the tokenizer, the base architecture and the fine-tuned weights.

        Returns immediately if the model is already loaded.

        Args:
            model_path: Path to the weights file. When ``None``, defaults to
                ``models/modele_appariement_rapide.pth`` under the parent of
                the directory that contains this file.

        Raises:
            RuntimeError: If any step fails (including a missing weights
                file). The original exception is chained as ``__cause__``.
        """
        if self.model_loaded:
            logger.info("KPA model already loaded")
            return

        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")

            if model_path is None:
                base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
                model_path = os.path.join(base_dir, "models", "modele_appariement_rapide.pth")

            logger.info(f"Loading KPA model weights from: {model_path}")

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            # Build everything in locals and publish to the instance only once
            # every step has succeeded, so a failed load never leaves a
            # half-initialised model/tokenizer behind.
            logger.info("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            logger.info("Loading base architecture...")
            model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=self.num_labels,
            )

            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from a trusted source, or
            # pass weights_only=True if the checkpoint format allows it.
            checkpoint = torch.load(model_path, map_location=self.device)

            # Accept both a wrapped training checkpoint and a bare state dict.
            if "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
            else:
                state_dict = checkpoint
            model.load_state_dict(state_dict)

            model.to(self.device)
            model.eval()

            self.tokenizer = tokenizer
            self.model = model
            self.model_loaded = True
            logger.info("✓ KPA model loaded successfully!")

        except Exception as e:
            logger.error(f"Error loading KPA model: {str(e)}")
            raise RuntimeError(f"Failed to load KPA model: {str(e)}") from e

    def predict(self, argument: str, key_point: str) -> dict:
        """Classify whether ``argument`` matches ``key_point``.

        Args:
            argument: Argument text (first sequence of the encoded pair).
            key_point: Key point text (second sequence of the encoded pair).

        Returns:
            A dict with ``prediction`` (class index with the highest
            probability), ``confidence`` (that class's softmax probability),
            ``label`` (``"apparie"`` for class 1, otherwise
            ``"non_apparie"``) and ``probabilities`` (softmax value per
            label).

        Raises:
            RuntimeError: If the model is not loaded, or if tokenization or
                inference fails (original exception chained as ``__cause__``).
        """
        if not self.model_loaded:
            raise RuntimeError("KPA model not loaded")

        try:
            encoding = self.tokenizer(
                argument,
                key_point,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)

            # Inference only: skip gradient tracking.
            with torch.no_grad():
                outputs = self.model(**encoding)
                logits = outputs.logits
                probabilities = torch.softmax(logits, dim=-1)

            predicted_class = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[0][predicted_class].item()

            return {
                "prediction": predicted_class,
                "confidence": confidence,
                "label": "apparie" if predicted_class == 1 else "non_apparie",
                "probabilities": {
                    "non_apparie": probabilities[0][0].item(),
                    "apparie": probabilities[0][1].item(),
                },
            }

        except Exception as e:
            logger.error(f"Error during prediction: {str(e)}")
            raise RuntimeError(f"KPA prediction failed: {str(e)}") from e

    def get_model_info(self) -> dict:
        """Return a summary of the manager's configuration and load state.

        Safe to call before ``load_model``; ``device`` is then the string
        ``"None"`` and ``loaded`` is ``False``.
        """
        return {
            "model_name": self.model_name,
            "device": str(self.device),
            "max_length": self.max_length,
            "num_labels": self.num_labels,
            "loaded": self.model_loaded,
        }
|
|
|
|
|
|
|
|
|
|
|
# Module-level manager instance, created (unloaded) at import time.
# NOTE(review): presumably imported elsewhere as a shared singleton — confirm.
kpa_model_manager = KpaModelManager()
|
|
|