"""Model manager for generation model""" import torch from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import logging logger = logging.getLogger(__name__) class GenerateModelManager: """Manages generation model loading and predictions""" def __init__(self): self.model = None self.tokenizer = None self.device = None self.model_loaded = False def load_model(self, model_id: str, api_key: str = None): """Load model and tokenizer from Hugging Face""" if self.model_loaded: logger.info("Generation model already loaded") return try: logger.info(f"Loading generation model from Hugging Face: {model_id}") # Determine device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Using device: {self.device}") # Prepare token for authentication if API key is provided token = api_key if api_key else None # Load tokenizer and model from Hugging Face logger.info("Loading tokenizer...") self.tokenizer = AutoTokenizer.from_pretrained( model_id, token=token, trust_remote_code=True ) logger.info("Loading model...") self.model = AutoModelForSeq2SeqLM.from_pretrained( model_id, token=token, trust_remote_code=True ) self.model.to(self.device) self.model.eval() self.model_loaded = True logger.info("✓ Generation model loaded successfully from Hugging Face!") except Exception as e: logger.error(f"Error loading generation model: {str(e)}") raise RuntimeError(f"Failed to load generation model: {str(e)}") def _format_input(self, topic: str, position: str) -> str: """Format input for the model""" return f"topic: {topic} stance: {position}" def generate(self, topic: str, position: str, max_length: int = 128, num_beams: int = 4) -> str: """Generate argument for a topic and position""" if not self.model_loaded: raise RuntimeError("Generation model not loaded") input_text = self._format_input(topic, position) # Tokenize inputs = self.tokenizer( input_text, return_tensors="pt", truncation=True, max_length=512, padding=True ).to(self.device) # Generate with torch.no_grad(): outputs = self.model.generate( **inputs, max_length=max_length, num_beams=num_beams, early_stopping=True, no_repeat_ngram_size=3, repetition_penalty=2.5, length_penalty=1.0 ) # Decode generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) return generated_text def batch_generate(self, items: list[dict], max_length: int = 128, num_beams: int = 4) -> list[str]: """Batch generate arguments""" if not self.model_loaded: raise RuntimeError("Generation model not loaded") # Prepare inputs input_texts = [self._format_input(item["topic"], item["position"]) for item in items] # Tokenize batch inputs = self.tokenizer( input_texts, return_tensors="pt", truncation=True, max_length=512, padding=True ).to(self.device) # Generate batch with torch.no_grad(): outputs = self.model.generate( **inputs, max_length=max_length, num_beams=num_beams, early_stopping=True, no_repeat_ngram_size=3, repetition_penalty=2.5, length_penalty=1.0 ) # Decode batch generated_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) return generated_texts # Initialize singleton instance generate_model_manager = GenerateModelManager()