Yassine Mhirsi
Fix formatting in stance_model_manager.py by adjusting whitespace in tokenizer loading code
17abb72
| """Model manager for stance detection model""" | |
| import os | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class StanceModelManager:
    """Manages stance detection model loading and predictions"""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False

    def load_model(self, model_id: str, api_key: Optional[str] = None):
        """Load model and tokenizer from Hugging Face"""
        if self.model_loaded:
            logger.info("Stance model already loaded")
            return

        try:
            logger.info(f"Loading stance model from Hugging Face: {model_id}")

            # Determine device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")

            # Prepare token for authentication if API key is provided
            token = api_key if api_key else None

            # Load tokenizer and model from Hugging Face
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True
            )

            logger.info("Loading model...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True
            )

            self.model.to(self.device)
            self.model.eval()
            self.model_loaded = True
            logger.info("✓ Stance model loaded successfully from Hugging Face!")
        except Exception as e:
            logger.error(f"Error loading stance model: {str(e)}")
            raise RuntimeError(f"Failed to load stance model: {str(e)}") from e

    def predict(self, topic: str, argument: str) -> dict:
        """Make a single stance prediction"""
        if not self.model_loaded:
            raise RuntimeError("Stance model not loaded")

        # Format input
        text = f"Topic: {topic} [SEP] Argument: {argument}"

        # Tokenize
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        ).to(self.device)

        # Predict
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()

        # Extract probabilities
        prob_con = probabilities[0][0].item()
        prob_pro = probabilities[0][1].item()

        # Determine stance
        stance = "PRO" if predicted_class == 1 else "CON"
        confidence = probabilities[0][predicted_class].item()

        return {
            "predicted_stance": stance,
            "confidence": confidence,
            "probability_con": prob_con,
            "probability_pro": prob_pro
        }

# Initialize singleton instance
stance_model_manager = StanceModelManager()
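

# Example usage -- a minimal sketch of driving the singleton end to end.
# The model id and topic/argument strings below are placeholder assumptions,
# not values taken from this repository.
if __name__ == "__main__":
    # Hypothetical model id; substitute the real Hugging Face repo in use
    stance_model_manager.load_model("example-org/stance-detection-model")
    result = stance_model_manager.predict(
        topic="School uniforms should be mandatory",
        argument="Uniforms reduce peer pressure around clothing choices.",
    )
    # result holds predicted_stance, confidence, probability_con, probability_pro
    print(result)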