File size: 4,863 Bytes
03d23e8 f28285b 03d23e8 f28285b d9d9974 03d23e8 cfbe56e f28285b 03d23e8 f28285b d9d9974 cfbe56e d9d9974 d289997 cfbe56e d289997 cfbe56e d9d9974 03d23e8 cfbe56e 03d23e8 cfbe56e 03d23e8 f28285b cfbe56e 03d23e8 cfbe56e 03d23e8 d9d9974 03d23e8 f28285b 03d23e8 f28285b 03d23e8 d9d9974 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
"""Model manager for keypoint–argument matching model"""
import os
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging
logger = logging.getLogger(__name__)
class KpaModelManager:
    """Manages loading and inference for the keypoint-argument matching model.

    Wraps a Hugging Face sequence-classification checkpoint:
    ``load_model`` downloads the tokenizer and weights, ``predict`` scores a
    single (argument, key_point) pair, and ``get_model_info`` reports the
    current state. Intended to be used through a module-level singleton.
    """

    def __init__(self):
        self.model = None          # AutoModelForSequenceClassification, set by load_model()
        self.tokenizer = None      # AutoTokenizer, set by load_model()
        self.device = None         # torch.device chosen at load time (cuda if available)
        self.model_loaded = False  # guards predict() and prevents double-loading
        self.max_length = 256      # tokenizer truncation/padding length
        self.model_id = None       # Hugging Face repo id of the loaded model

    def load_model(self, model_id: str, api_key: str = None) -> None:
        """Load tokenizer and model directly from Hugging Face.

        Args:
            model_id: Hugging Face repository id; must not be None.
            api_key: Optional HF access token for private repositories.

        Raises:
            RuntimeError: if any step of the download/initialisation fails
                (the original exception is chained as the cause).
        """
        if self.model_loaded:
            logger.info("KPA model already loaded")
            return
        try:
            # Debug: verify the input parameters before doing any work.
            logger.info("=== DEBUG KPA MODEL LOADING ===")
            logger.info("model_id reçu: %s", model_id)
            logger.info("model_id type: %s", type(model_id))
            logger.info("api_key présent: %s", api_key is not None)
            if model_id is None:
                raise ValueError("model_id cannot be None - check your .env file")
            logger.info("Loading KPA model from Hugging Face: %s", model_id)
            # Prefer GPU when available.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info("Using device: %s", self.device)
            # Store model ID for get_model_info().
            self.model_id = model_id
            # Normalise falsy keys (e.g. "") to None so anonymous access is used.
            token = api_key or None
            logger.info("Step 1: Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )
            logger.info("✓ Tokenizer loaded successfully")
            logger.info("Step 2: Loading model...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )
            logger.info("✓ Model architecture loaded")
            self.model.to(self.device)
            self.model.eval()  # disable dropout etc. for deterministic inference
            logger.info("✓ Model moved to device and set to eval mode")
            self.model_loaded = True
            logger.info("✓ KPA model loaded successfully from Hugging Face!")
            logger.info("=== KPA MODEL LOADING COMPLETE ===")
        except Exception as e:
            logger.error("❌ Error loading KPA model: %s", e)
            logger.error("❌ Model ID was: %s", model_id)
            logger.error("❌ API Key present: %s", api_key is not None)
            # Chain the original exception so the root cause is preserved.
            raise RuntimeError(f"Failed to load KPA model: {str(e)}") from e

    def predict(self, argument: str, key_point: str) -> dict:
        """Score one (argument, key_point) pair with the loaded classifier.

        Args:
            argument: The argument text (first sentence of the pair).
            key_point: The key point text (second sentence of the pair).

        Returns:
            dict with keys ``prediction`` (0/1), ``confidence`` (float),
            ``label`` ("apparie" if class 1 else "non_apparie") and
            ``probabilities`` (per-class floats). Assumes a 2-label head.

        Raises:
            RuntimeError: if the model is not loaded or inference fails
                (the original exception is chained as the cause).
        """
        if not self.model_loaded:
            raise RuntimeError("KPA model not loaded")
        try:
            # Tokenize as a sentence pair, padded/truncated to max_length.
            encoding = self.tokenizer(
                argument,
                key_point,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)
            # Forward pass without gradient tracking.
            with torch.no_grad():
                outputs = self.model(**encoding)
                logits = outputs.logits
                probabilities = torch.softmax(logits, dim=-1)
                predicted_class = torch.argmax(probabilities, dim=-1).item()
                confidence = probabilities[0][predicted_class].item()
            return {
                "prediction": predicted_class,
                "confidence": confidence,
                "label": "apparie" if predicted_class == 1 else "non_apparie",
                "probabilities": {
                    "non_apparie": probabilities[0][0].item(),
                    "apparie": probabilities[0][1].item(),
                },
            }
        except Exception as e:
            logger.error("Error during prediction: %s", e)
            # Chain the original exception so the root cause is preserved.
            raise RuntimeError(f"KPA prediction failed: {str(e)}") from e

    def get_model_info(self) -> dict:
        """Return a summary dict of the manager state.

        Returns ``{"loaded": False}`` when no model is loaded; otherwise the
        model id, device, max sequence length and label count.
        """
        if not self.model_loaded:
            return {"loaded": False}
        return {
            "model_name": self.model_id,
            "device": str(self.device),
            "max_length": self.max_length,
            "num_labels": 2,
            "loaded": self.model_loaded,
        }
# Module-level singleton shared by all importers of this module.
# The model itself is loaded lazily via kpa_model_manager.load_model(...).
kpa_model_manager = KpaModelManager()