"""Model manager for the keypoint–argument matching (KPA) model."""

import logging
import os
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

logger = logging.getLogger(__name__)


class KpaModelManager:
    """Manage loading and inference for the keypoint matching model.

    The model is a two-label sequence-pair classifier: label index 0 is
    reported as "non_apparie" (not matched) and index 1 as "apparie"
    (matched). ``load_model`` must be called before ``predict``.
    """

    # Number of output classes of the classification head.
    NUM_LABELS = 2

    def __init__(
        self,
        model_name: str = "distilbert-base-uncased",
        max_length: int = 256,
    ):
        """Create an unloaded manager.

        Args:
            model_name: Hugging Face identifier used for both the tokenizer
                and the base architecture onto which the fine-tuned weights
                are loaded. It must match the architecture the checkpoint
                was trained with.
            max_length: Token length each (argument, key point) pair is
                truncated/padded to.
        """
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        # Stored on the instance so get_model_info() can report it.
        self.model_name = model_name
        self.max_length = max_length

    def load_model(self, model_path: Optional[str] = None):
        """Load the tokenizer, base architecture and fine-tuned weights.

        Does nothing if the model is already loaded.

        Args:
            model_path: Path to the fine-tuned checkpoint. When None, defaults
                to ``models/modele_appariement_rapide.pth`` under the parent
                directory of the directory containing this file.

        Raises:
            RuntimeError: If any loading step fails (including a missing
                checkpoint file); the original exception is chained as the
                cause.
        """
        if self.model_loaded:
            logger.info("KPA model already loaded")
            return

        try:
            # Prefer the GPU when one is available.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")

            if model_path is None:
                base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
                model_path = os.path.join(base_dir, "models", "modele_appariement_rapide.pth")

            logger.info(f"Loading KPA model weights from: {model_path}")
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            logger.info("Loading base architecture...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name, num_labels=self.NUM_LABELS
            )

            # NOTE(security): torch.load unpickles the file and can execute
            # arbitrary code embedded in it -- only load trusted checkpoints.
            checkpoint = torch.load(model_path, map_location=self.device)
            # Accept either a training checkpoint that wraps the weights under
            # "model_state_dict" or a bare state_dict.
            if "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
            else:
                state_dict = checkpoint
            self.model.load_state_dict(state_dict)

            self.model.to(self.device)
            self.model.eval()  # disable dropout for inference
            self.model_loaded = True
            logger.info("✓ KPA model loaded successfully!")

        except Exception as e:
            logger.exception(f"Error loading KPA model: {str(e)}")
            raise RuntimeError(f"Failed to load KPA model: {str(e)}") from e

    def predict(self, argument: str, key_point: str) -> dict:
        """Classify whether ``argument`` matches ``key_point``.

        Args:
            argument: The argument text (first sequence of the pair).
            key_point: The key point text (second sequence of the pair).

        Returns:
            A dict with ``prediction`` (int class index), ``confidence``
            (probability of the predicted class), ``label`` ("apparie" for
            class 1, otherwise "non_apparie") and ``probabilities`` (per-label
            softmax probabilities).

        Raises:
            RuntimeError: If the model is not loaded or inference fails; in the
                latter case the original exception is chained as the cause.
        """
        if not self.model_loaded:
            raise RuntimeError("KPA model not loaded")

        try:
            encoding = self.tokenizer(
                argument,
                key_point,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)

            with torch.no_grad():
                logits = self.model(**encoding).logits

            # A single pair was tokenized, so row 0 holds its class scores.
            probabilities = torch.softmax(logits, dim=-1)[0]
            predicted_class = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[predicted_class].item()
            prob_non_apparie = probabilities[0].item()
            prob_apparie = probabilities[1].item()
        except Exception as e:
            logger.exception(f"Error during prediction: {str(e)}")
            raise RuntimeError(f"KPA prediction failed: {str(e)}") from e

        return {
            "prediction": predicted_class,
            "confidence": confidence,
            "label": "apparie" if predicted_class == 1 else "non_apparie",
            "probabilities": {
                "non_apparie": prob_non_apparie,
                "apparie": prob_apparie,
            },
        }

    def get_model_info(self) -> dict:
        """Return a summary of the manager's configuration and load state."""
        return {
            "model_name": self.model_name,
            "device": str(self.device),
            "max_length": self.max_length,
            "num_labels": self.NUM_LABELS,
            "loaded": self.model_loaded,
        }


# Shared module-level instance; the model itself is loaded lazily via load_model().
kpa_model_manager = KpaModelManager()