"""Model manager for stance detection model"""
import logging
import os
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

logger = logging.getLogger(__name__)


class StanceModelManager:
    """Manages stance detection model loading and predictions.

    One instance holds a single tokenizer/model pair. ``load_model`` must
    complete successfully before ``predict`` can be called.
    """

    def __init__(self) -> None:
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False

    def load_model(
        self,
        model_id: str,
        api_key: Optional[str] = None,
        *,
        trust_remote_code: bool = True,
    ) -> None:
        """Load model and tokenizer from Hugging Face.

        If a model is already loaded this only logs and returns. That happens
        even when ``model_id`` differs from the one loaded earlier.

        Args:
            model_id: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
            api_key: Optional access token. Falsy values (``None``, ``""``)
                mean "no token".
            trust_remote_code: Forwarded to ``from_pretrained``. Defaults to
                True to preserve the previous behaviour. NOTE(review): True
                allows the model repo to run its own Python code on this
                machine; only keep it for repos you trust.

        Raises:
            RuntimeError: If loading fails. The underlying exception is
                chained as ``__cause__``.
        """
        if self.model_loaded:
            logger.info("Stance model already loaded")
            return

        try:
            logger.info("Loading stance model from Hugging Face: %s", model_id)

            # Prefer the GPU when one is available.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info("Using device: %s", device)

            # An empty key is treated the same as no key at all.
            token = api_key if api_key else None

            logger.info("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_id, token=token, trust_remote_code=trust_remote_code
            )

            logger.info("Loading model...")
            model = AutoModelForSequenceClassification.from_pretrained(
                model_id, token=token, trust_remote_code=trust_remote_code
            )
            model.to(device)
            model.eval()  # evaluation mode for inference
        except Exception as e:
            # logger.exception keeps the traceback; `from e` keeps the cause.
            logger.exception("Error loading stance model: %s", e)
            raise RuntimeError(f"Failed to load stance model: {str(e)}") from e

        # Commit state only once everything succeeded, so a failed load never
        # leaves a half-initialised manager (e.g. a tokenizer but no model).
        self.device = device
        self.tokenizer = tokenizer
        self.model = model
        self.model_loaded = True
        logger.info("✓ Stance model loaded successfully from Hugging Face!")

    def predict(self, topic: str, argument: str) -> dict:
        """Make a single stance prediction.

        Args:
            topic: The topic the argument refers to.
            argument: The argument text to classify.

        Returns:
            A dict with these keys:
            ``predicted_stance``: "PRO" or "CON".
            ``confidence``: softmax probability of the predicted class.
            ``probability_con``: probability of class index 0.
            ``probability_pro``: probability of class index 1.

        Raises:
            RuntimeError: If ``load_model`` has not completed successfully.
        """
        if not self.model_loaded:
            raise RuntimeError("Stance model not loaded")

        # NOTE(review): "[SEP]" is embedded as literal text. Whether the
        # tokenizer maps it to its separator token depends on the tokenizer;
        # confirm this matches the format the model was fine-tuned on.
        text = f"Topic: {topic} [SEP] Argument: {argument}"

        # Single example; anything beyond 512 tokens is truncated.
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()

        # Row 0 is the only example in the batch; index 0 is read as CON and
        # index 1 as PRO.
        prob_con = probabilities[0][0].item()
        prob_pro = probabilities[0][1].item()

        # NOTE(review): every class other than 1 maps to "CON". If the model
        # has more than two labels this is likely wrong; check
        # model.config.id2label.
        stance = "PRO" if predicted_class == 1 else "CON"
        confidence = probabilities[0][predicted_class].item()

        return {
            "predicted_stance": stance,
            "confidence": confidence,
            "probability_con": prob_con,
            "probability_pro": prob_pro,
        }


# Module-level shared instance used by importers of this module.
stance_model_manager = StanceModelManager()