File size: 3,329 Bytes
9db766f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17abb72
9db766f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Model manager for stance detection model"""

import logging
import os
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

logger = logging.getLogger(__name__)


class StanceModelManager:
    """Manages stance detection model loading and predictions.

    Wraps a Hugging Face sequence-classification model: ``load_model()``
    downloads the tokenizer/model and moves the model to GPU when available;
    ``predict()`` classifies a (topic, argument) pair as PRO or CON.
    """

    def __init__(self):
        # All state is populated lazily by load_model(); predict() guards on
        # model_loaded so it fails fast with a clear error before that.
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False

    def load_model(self, model_id: str, api_key: Optional[str] = None) -> None:
        """Load model and tokenizer from Hugging Face.

        Idempotent: a second call after a successful load is a no-op.

        Args:
            model_id: Hugging Face model repository id.
            api_key: Optional access token for private/gated repositories.

        Raises:
            RuntimeError: If the tokenizer or model fails to load (the
                original exception is attached as ``__cause__``).
        """
        if self.model_loaded:
            logger.info("Stance model already loaded")
            return

        try:
            logger.info("Loading stance model from Hugging Face: %s", model_id)

            # Prefer GPU when available; fall back to CPU.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info("Using device: %s", self.device)

            # Normalize empty string / None to None so transformers skips auth.
            token = api_key or None

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )

            logger.info("Loading model...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )
            self.model.to(self.device)
            self.model.eval()  # inference mode: disables dropout etc.

            # Only flip the flag once everything above succeeded.
            self.model_loaded = True
            logger.info("✓ Stance model loaded successfully from Hugging Face!")

        except Exception as e:
            # logger.exception records the full traceback; chaining with
            # `from e` preserves the root cause for callers.
            logger.exception("Error loading stance model")
            raise RuntimeError(f"Failed to load stance model: {str(e)}") from e

    def predict(self, topic: str, argument: str, max_length: int = 512) -> dict:
        """Make a single stance prediction.

        Args:
            topic: The debate topic text.
            argument: The argument whose stance toward the topic is classified.
            max_length: Tokenizer truncation length (default 512, the previous
                hard-coded value).

        Returns:
            dict with keys ``predicted_stance`` ("PRO" or "CON"),
            ``confidence``, ``probability_con`` and ``probability_pro``.

        Raises:
            RuntimeError: If load_model() has not completed successfully.
        """
        if not self.model_loaded:
            raise RuntimeError("Stance model not loaded")

        # Input template the classifier expects.
        # NOTE(review): assumes a binary head with logit index 0 = CON and
        # 1 = PRO — confirm against the model's id2label config.
        text = f"Topic: {topic} [SEP] Argument: {argument}"

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=True,
        ).to(self.device)

        # No gradients needed for inference.
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()

        prob_con = probabilities[0][0].item()
        prob_pro = probabilities[0][1].item()

        stance = "PRO" if predicted_class == 1 else "CON"
        confidence = probabilities[0][predicted_class].item()

        return {
            "predicted_stance": stance,
            "confidence": confidence,
            "probability_con": prob_con,
            "probability_pro": prob_pro,
        }


# Module-level singleton shared by importers of this module; callers must
# invoke load_model() on it before predict() will work.
stance_model_manager = StanceModelManager()