File size: 3,329 Bytes
d9d9974
9db766f
 
 
 
 
 
 
 
 
d9d9974
 
 
9db766f
 
 
 
 
d9d9974
9db766f
d9d9974
9db766f
d9d9974
9db766f
d9d9974
9db766f
d9d9974
9db766f
 
 
 
 
 
 
 
d9d9974
9db766f
d9d9974
9db766f
 
 
 
 
 
 
 
 
 
 
 
 
d9d9974
9db766f
d9d9974
 
9db766f
d9d9974
 
 
 
 
9db766f
d9d9974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9db766f
 
d9d9974
 
 
 
9db766f
 
 
 
d9d9974
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Model manager for stance detection model"""

import logging
import os
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

logger = logging.getLogger(__name__)


class StanceModelManager:
    """Manages stance detection model loading and predictions.

    Wraps a Hugging Face sequence-classification model and its tokenizer.
    ``load_model`` must be called once (successfully) before ``predict``;
    until then ``predict`` raises ``RuntimeError``.
    """

    def __init__(self):
        # All state is populated by load_model(); predict() checks
        # model_loaded before touching model/tokenizer/device.
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False

    def load_model(self, model_id: str, api_key: Optional[str] = None) -> None:
        """Load model and tokenizer from Hugging Face.

        Idempotent: if a model is already loaded this is a no-op.

        Args:
            model_id: Hugging Face repository id of the stance model.
            api_key: Optional HF access token for gated/private repos.

        Raises:
            RuntimeError: if downloading or initializing the model fails
                (the original exception is attached as ``__cause__``).
        """
        if self.model_loaded:
            logger.info("Stance model already loaded")
            return

        try:
            logger.info("Loading stance model from Hugging Face: %s", model_id)

            # Prefer GPU when available.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info("Using device: %s", self.device)

            # Normalize falsy values (None or "") to None so transformers
            # skips authentication entirely.
            token = api_key or None

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )

            logger.info("Loading model...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )
            self.model.to(self.device)
            self.model.eval()  # inference mode: disable dropout etc.

            self.model_loaded = True
            logger.info("✓ Stance model loaded successfully from Hugging Face!")

        except Exception as e:
            # logger.exception records the full traceback, not just str(e).
            logger.exception("Error loading stance model: %s", e)
            # Chain the cause so callers can inspect the underlying error.
            raise RuntimeError(f"Failed to load stance model: {str(e)}") from e

    def predict(self, topic: str, argument: str) -> dict:
        """Make a single stance prediction.

        Args:
            topic: Debate topic text.
            argument: Argument whose stance toward the topic is classified.

        Returns:
            dict with keys ``predicted_stance`` ("PRO" or "CON"),
            ``confidence``, ``probability_con`` and ``probability_pro``.

        Raises:
            RuntimeError: if load_model() has not completed successfully.
        """
        if not self.model_loaded:
            raise RuntimeError("Stance model not loaded")

        # Input format the model expects: topic and argument joined by a
        # literal "[SEP]" marker inside one string.
        text = f"Topic: {topic} [SEP] Argument: {argument}"

        # Tokenize and move tensors to the model's device.
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        ).to(self.device)

        # Forward pass without gradient tracking (inference only).
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()

        # Label order assumed: index 0 = CON, index 1 = PRO (consistent
        # with the stance mapping below) — confirm against model config.
        prob_con = probabilities[0][0].item()
        prob_pro = probabilities[0][1].item()

        stance = "PRO" if predicted_class == 1 else "CON"
        confidence = probabilities[0][predicted_class].item()

        return {
            "predicted_stance": stance,
            "confidence": confidence,
            "probability_con": prob_con,
            "probability_pro": prob_pro,
        }


# Module-level singleton shared by all importers; call
# stance_model_manager.load_model(...) once at startup before predict().
stance_model_manager = StanceModelManager()