File size: 3,329 Bytes
9db766f 17abb72 9db766f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
"""Model manager for stance detection model"""
import os
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging
logger = logging.getLogger(__name__)
class StanceModelManager:
    """Manages stance detection model loading and predictions.

    Holds a single Hugging Face sequence-classification model plus its
    tokenizer, and exposes a simple ``predict(topic, argument)`` API that
    returns PRO/CON probabilities.
    """

    def __init__(self) -> None:
        # All state is populated lazily by load_model(); None/False until then.
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False

    def load_model(self, model_id: str, api_key: "str | None" = None) -> None:
        """Load the stance model and tokenizer from Hugging Face.

        Idempotent: a second call is a no-op once loading has succeeded.

        Args:
            model_id: Hugging Face model repository id (e.g. "org/model").
            api_key: Optional Hugging Face access token for gated/private
                repositories; omitted from the request when falsy.

        Raises:
            RuntimeError: If downloading or initializing the tokenizer or
                model fails (original exception attached as the cause).
        """
        if self.model_loaded:
            logger.info("Stance model already loaded")
            return
        try:
            # Lazy %-style args: the message is only formatted if emitted.
            logger.info("Loading stance model from Hugging Face: %s", model_id)

            # Prefer GPU when available; the model is moved there below.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info("Using device: %s", self.device)

            # Normalize empty-string/None API keys to None so the hub
            # client performs an anonymous request.
            token = api_key or None

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )
            logger.info("Loading model...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )
            self.model.to(self.device)
            # Inference-only: disable dropout/batch-norm training behavior.
            self.model.eval()
            self.model_loaded = True
            logger.info("✓ Stance model loaded successfully from Hugging Face!")
        except Exception as e:
            logger.error("Error loading stance model: %s", e)
            # Chain the original exception so the real cause is visible.
            raise RuntimeError(f"Failed to load stance model: {str(e)}") from e

    def predict(self, topic: str, argument: str) -> dict:
        """Make a single stance prediction for an (topic, argument) pair.

        Args:
            topic: The debate topic text.
            argument: The argument whose stance toward the topic is scored.

        Returns:
            Dict with keys "predicted_stance" ("PRO"/"CON"), "confidence",
            "probability_con", and "probability_pro".

        Raises:
            RuntimeError: If load_model() has not completed successfully.
        """
        if not self.model_loaded:
            raise RuntimeError("Stance model not loaded")

        # Input format the model was trained on: topic and argument joined
        # by a literal "[SEP]" inside one string.
        text = f"Topic: {topic} [SEP] Argument: {argument}"

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()

        # Label convention per the code below: index 0 = CON, index 1 = PRO
        # — assumed to match the checkpoint's training labels; verify against
        # the model card if in doubt.
        prob_con = probabilities[0][0].item()
        prob_pro = probabilities[0][1].item()
        stance = "PRO" if predicted_class == 1 else "CON"

        return {
            "predicted_stance": stance,
            "confidence": probabilities[0][predicted_class].item(),
            "probability_con": prob_con,
            "probability_pro": prob_pro,
        }
# Module-level singleton: importers share this one manager instance so the
# model is loaded at most once per process. Note __init__ does no I/O;
# load_model() must still be called before predict().
stance_model_manager = StanceModelManager()
|