|
|
"""Model manager for keypoint–argument matching model""" |
|
|
|
|
|
import logging
import os
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
|
|
|
# Module-level logger named after this module (configured by the application).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class KpaModelManager:
    """Manages loading and inference for keypoint matching model.

    Wraps a Hugging Face sequence-classification checkpoint that scores
    whether an (argument, key point) pair matches. Binary output:
    class 1 = "apparie" (matched), class 0 = "non_apparie".
    """

    def __init__(self) -> None:
        # All state is populated lazily by load_model().
        self.model = None          # AutoModelForSequenceClassification, once loaded
        self.tokenizer = None      # AutoTokenizer, once loaded
        self.device = None         # torch.device selected at load time
        self.model_loaded = False  # guards predict() / get_model_info()
        self.max_length = 256      # tokenizer truncation/padding length
        self.model_id = None       # Hugging Face repo id, set in load_model()

    def load_model(self, model_id: str, api_key: Optional[str] = None) -> None:
        """Load complete model and tokenizer directly from Hugging Face.

        Idempotent: returns immediately if a model is already loaded.

        Args:
            model_id: Hugging Face model repository id; must not be None.
            api_key: Optional Hugging Face access token (for private repos).

        Raises:
            RuntimeError: if loading fails; the original exception is
                chained as ``__cause__``.
        """
        if self.model_loaded:
            logger.info("KPA model already loaded")
            return

        try:
            logger.info("=== DEBUG KPA MODEL LOADING ===")
            logger.info(f"model_id reçu: {model_id}")
            logger.info(f"model_id type: {type(model_id)}")
            logger.info(f"api_key présent: {api_key is not None}")

            if model_id is None:
                raise ValueError("model_id cannot be None - check your .env file")

            logger.info(f"Loading KPA model from Hugging Face: {model_id}")

            # Prefer GPU when available; inference-only, so a single device suffices.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")

            self.model_id = model_id

            # Empty-string keys are treated the same as "no token".
            token = api_key if api_key else None

            logger.info("Step 1: Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True
            )
            logger.info("✓ Tokenizer loaded successfully")

            logger.info("Step 2: Loading model...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True
            )
            logger.info("✓ Model architecture loaded")

            self.model.to(self.device)
            # eval() disables dropout/batch-norm updates for deterministic inference.
            self.model.eval()
            logger.info("✓ Model moved to device and set to eval mode")

            self.model_loaded = True
            logger.info("✓ KPA model loaded successfully from Hugging Face!")
            logger.info("=== KPA MODEL LOADING COMPLETE ===")

        except Exception as e:
            logger.error(f"❌ Error loading KPA model: {str(e)}")
            logger.error(f"❌ Model ID was: {model_id}")
            logger.error(f"❌ API Key present: {api_key is not None}")
            # Chain the original exception so the root cause is not lost.
            raise RuntimeError(f"Failed to load KPA model: {str(e)}") from e

    def predict(self, argument: str, key_point: str) -> dict:
        """Run a prediction for (argument, key_point).

        Args:
            argument: argument text (first sequence of the pair).
            key_point: key point text (second sequence of the pair).

        Returns:
            dict with keys "prediction" (0 or 1), "confidence" (float),
            "label" ("apparie"/"non_apparie") and "probabilities"
            (per-class floats).

        Raises:
            RuntimeError: if the model is not loaded, or if inference fails
                (original exception chained as ``__cause__``).
        """
        if not self.model_loaded:
            raise RuntimeError("KPA model not loaded")

        try:
            # Encode as a sentence pair, padded/truncated to max_length.
            encoding = self.tokenizer(
                argument,
                key_point,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt"
            ).to(self.device)

            # Inference only: skip gradient bookkeeping.
            with torch.no_grad():
                outputs = self.model(**encoding)
                logits = outputs.logits
                probabilities = torch.softmax(logits, dim=-1)

            predicted_class = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[0][predicted_class].item()

            return {
                "prediction": predicted_class,
                "confidence": confidence,
                "label": "apparie" if predicted_class == 1 else "non_apparie",
                "probabilities": {
                    "non_apparie": probabilities[0][0].item(),
                    "apparie": probabilities[0][1].item(),
                },
            }

        except Exception as e:
            logger.error(f"Error during prediction: {str(e)}")
            raise RuntimeError(f"KPA prediction failed: {str(e)}") from e

    def get_model_info(self) -> dict:
        """Get model information.

        Returns:
            ``{"loaded": False}`` if no model is loaded, otherwise a dict
            with model name, device, max_length, num_labels and loaded flag.
        """
        if not self.model_loaded:
            return {"loaded": False}

        return {
            "model_name": self.model_id,
            "device": str(self.device),
            "max_length": self.max_length,
            "num_labels": 2,
            "loaded": self.model_loaded
        }
|
|
|
|
|
|
|
|
|
|
|
# Module-level singleton instance; created unloaded (call load_model() before predict()).
kpa_model_manager = KpaModelManager()