FastAPI-Backend-Models / services /stance_model_manager.py
Yassine Mhirsi
Fix formatting in stance_model_manager.py by adjusting whitespace in tokenizer loading code
17abb72
raw
history blame
3.33 kB
"""Model manager for stance detection model"""
import os
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging
logger = logging.getLogger(__name__)
class StanceModelManager:
"""Manages stance detection model loading and predictions"""
def __init__(self):
self.model = None
self.tokenizer = None
self.device = None
self.model_loaded = False
def load_model(self, model_id: str, api_key: str = None):
"""Load model and tokenizer from Hugging Face"""
if self.model_loaded:
logger.info("Stance model already loaded")
return
try:
logger.info(f"Loading stance model from Hugging Face: {model_id}")
# Determine device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {self.device}")
# Prepare token for authentication if API key is provided
token = api_key if api_key else None
# Load tokenizer and model from Hugging Face
logger.info("Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(
model_id,
token=token,
trust_remote_code=True
)
logger.info("Loading model...")
self.model = AutoModelForSequenceClassification.from_pretrained(
model_id,
token=token,
trust_remote_code=True
)
self.model.to(self.device)
self.model.eval()
self.model_loaded = True
logger.info("✓ Stance model loaded successfully from Hugging Face!")
except Exception as e:
logger.error(f"Error loading stance model: {str(e)}")
raise RuntimeError(f"Failed to load stance model: {str(e)}")
def predict(self, topic: str, argument: str) -> dict:
"""Make a single stance prediction"""
if not self.model_loaded:
raise RuntimeError("Stance model not loaded")
# Format input
text = f"Topic: {topic} [SEP] Argument: {argument}"
# Tokenize
inputs = self.tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
).to(self.device)
# Predict
with torch.no_grad():
outputs = self.model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(probabilities, dim=-1).item()
# Extract probabilities
prob_con = probabilities[0][0].item()
prob_pro = probabilities[0][1].item()
# Determine stance
stance = "PRO" if predicted_class == 1 else "CON"
confidence = probabilities[0][predicted_class].item()
return {
"predicted_stance": stance,
"confidence": confidence,
"probability_con": prob_con,
"probability_pro": prob_pro
}
# Initialize singleton instance
stance_model_manager = StanceModelManager()