FastAPI-Backend-Models / services /stance_model_manager.py
malek-messaoudii
Update training script and model files
d9d9974
raw
history blame
3.33 kB
"""Model manager for stance detection model"""
import logging
import os
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

logger = logging.getLogger(__name__)
class StanceModelManager:
    """Load a Hugging Face sequence-classification model and score argument stance.

    Construction is cheap and performs no I/O; call :meth:`load_model` once
    before calling :meth:`predict`.

    Attributes (all ``None`` / ``False`` until :meth:`load_model` succeeds):
        model: The ``AutoModelForSequenceClassification`` instance, moved to
            ``device`` and switched to eval mode.
        tokenizer: The matching ``AutoTokenizer`` instance.
        device: ``torch.device("cuda")`` when CUDA is available, else CPU.
        model_loaded: ``True`` once :meth:`load_model` has completed.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False

    def load_model(
        self,
        model_id: str,
        api_key: Optional[str] = None,
        *,
        trust_remote_code: bool = True,
    ) -> None:
        """Load the tokenizer and model for *model_id* from Hugging Face.

        Calling this again after a successful load is a no-op.

        Args:
            model_id: Hugging Face Hub repository id (or local path) passed to
                ``from_pretrained``.
            api_key: Optional Hub access token. Falsy values (``None``, ``""``)
                mean "no token".
            trust_remote_code: Forwarded to ``from_pretrained``. Defaults to
                ``True`` to preserve the previous behaviour. It allows the
                model repository to run its own Python code while loading, so
                pass ``False`` when the repository does not need custom code.

        Raises:
            RuntimeError: If the tokenizer or model cannot be loaded. The
                underlying exception is chained as ``__cause__``. On failure the
                manager's attributes are left unchanged, so the call can be retried.
        """
        if self.model_loaded:
            logger.info("Stance model already loaded")
            return

        try:
            logger.info("Loading stance model from Hugging Face: %s", model_id)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info("Using device: %s", device)

            # An empty string is treated the same as "no token".
            token = api_key or None

            logger.info("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=trust_remote_code,
            )

            logger.info("Loading model...")
            model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=trust_remote_code,
            )
            model.to(device)
            model.eval()
        except Exception as e:
            # Boundary handler: log with traceback, then re-raise with the
            # original error chained so callers can still inspect it.
            logger.exception("Error loading stance model: %s", e)
            raise RuntimeError(f"Failed to load stance model: {e}") from e

        # Commit state only after every step succeeded, so a partial failure
        # never leaves a half-initialised manager behind.
        self.device = device
        self.tokenizer = tokenizer
        self.model = model
        self.model_loaded = True
        logger.info("✓ Stance model loaded successfully from Hugging Face!")

    def predict(self, topic: str, argument: str, *, max_length: int = 512) -> dict:
        """Predict whether *argument* is PRO or CON with respect to *topic*.

        Args:
            topic: The topic the argument refers to.
            argument: The argument text to classify.
            max_length: Maximum number of tokens passed to the model. Longer
                inputs are truncated. Defaults to 512, the previous fixed value.

        Returns:
            A dict with keys ``"predicted_stance"`` (``"PRO"`` or ``"CON"``),
            ``"confidence"`` (softmax probability of the predicted class), and
            ``"probability_con"`` / ``"probability_pro"`` (floats).

        Raises:
            RuntimeError: If :meth:`load_model` has not completed successfully.
        """
        if not self.model_loaded:
            raise RuntimeError("Stance model not loaded")

        # NOTE(review): presumably this is the template the model was fine-tuned
        # on — confirm against the training script before changing it.
        text = f"Topic: {topic} [SEP] Argument: {argument}"

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=True,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = int(torch.argmax(probabilities, dim=-1).item())

        # Single device->host transfer instead of one .item() per value.
        probs = probabilities[0].tolist()

        # NOTE(review): assumes label index 0 is CON and index 1 is PRO —
        # confirm against model.config.id2label.
        return {
            "predicted_stance": "PRO" if predicted_class == 1 else "CON",
            "confidence": probs[predicted_class],
            "probability_con": probs[0],
            "probability_pro": probs[1],
        }
# Module-level shared instance. Constructing it performs no I/O; the model is
# only loaded when stance_model_manager.load_model(...) is called, and predict()
# raises RuntimeError until then. Repeated load_model() calls return early once
# the model is loaded (presumably called once at application startup — confirm
# against the caller).
stance_model_manager: StanceModelManager = StanceModelManager()