"""Model manager for keypoint–argument matching model""" import os import torch from transformers import AutoTokenizer, AutoModelForSequenceClassification import logging logger = logging.getLogger(__name__) class KpaModelManager: """Manages loading and inference for keypoint matching model""" def __init__(self): self.model = None self.tokenizer = None self.device = None self.model_loaded = False self.max_length = 256 self.model_id = None def load_model(self, model_id: str, api_key: str = None): """Load complete model and tokenizer directly from Hugging Face""" if self.model_loaded: logger.info("KPA model already loaded") return try: # Debug: Vérifier les paramètres d'entrée logger.info(f"=== DEBUG KPA MODEL LOADING ===") logger.info(f"model_id reçu: {model_id}") logger.info(f"model_id type: {type(model_id)}") logger.info(f"api_key présent: {api_key is not None}") if model_id is None: raise ValueError("model_id cannot be None - check your .env file") logger.info(f"Loading KPA model from Hugging Face: {model_id}") # Determine device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Using device: {self.device}") # Store model ID self.model_id = model_id # Prepare token for authentication if API key is provided token = api_key if api_key else None # Load tokenizer and model directly from Hugging Face logger.info("Step 1: Loading tokenizer...") self.tokenizer = AutoTokenizer.from_pretrained( model_id, token=token, trust_remote_code=True ) logger.info("✓ Tokenizer loaded successfully") logger.info("Step 2: Loading model...") self.model = AutoModelForSequenceClassification.from_pretrained( model_id, token=token, trust_remote_code=True ) logger.info("✓ Model architecture loaded") self.model.to(self.device) self.model.eval() logger.info("✓ Model moved to device and set to eval mode") self.model_loaded = True logger.info("✓ KPA model loaded successfully from Hugging Face!") logger.info(f"=== KPA MODEL LOADING COMPLETE ===") except Exception as e: logger.error(f"❌ Error loading KPA model: {str(e)}") logger.error(f"❌ Model ID was: {model_id}") logger.error(f"❌ API Key present: {api_key is not None}") raise RuntimeError(f"Failed to load KPA model: {str(e)}") def predict(self, argument: str, key_point: str) -> dict: """Run a prediction for (argument, key_point)""" if not self.model_loaded: raise RuntimeError("KPA model not loaded") try: # Tokenize input encoding = self.tokenizer( argument, key_point, truncation=True, padding="max_length", max_length=self.max_length, return_tensors="pt" ).to(self.device) # Forward pass with torch.no_grad(): outputs = self.model(**encoding) logits = outputs.logits probabilities = torch.softmax(logits, dim=-1) predicted_class = torch.argmax(probabilities, dim=-1).item() confidence = probabilities[0][predicted_class].item() return { "prediction": predicted_class, "confidence": confidence, "label": "apparie" if predicted_class == 1 else "non_apparie", "probabilities": { "non_apparie": probabilities[0][0].item(), "apparie": probabilities[0][1].item(), }, } except Exception as e: logger.error(f"Error during prediction: {str(e)}") raise RuntimeError(f"KPA prediction failed: {str(e)}") def get_model_info(self): """Get model information""" if not self.model_loaded: return {"loaded": False} return { "model_name": self.model_id, "device": str(self.device), "max_length": self.max_length, "num_labels": 2, "loaded": self.model_loaded } # Initialize singleton instance kpa_model_manager = KpaModelManager()