S01Nour
feat: Add `GenerateModelManager` for loading Hugging Face generation models and performing text generation.
2ed0ab3
| """Model manager for generation model""" | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import logging | |
| logger = logging.getLogger(__name__) | |


class GenerateModelManager:
    """Manages generation model loading and predictions"""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False

    def load_model(self, model_id: str, api_key: str | None = None):
        """Load model and tokenizer from Hugging Face"""
        if self.model_loaded:
            logger.info("Generation model already loaded")
            return

        try:
            logger.info(f"Loading generation model from Hugging Face: {model_id}")

            # Determine device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")

            # Use the API key as the auth token if one was provided
            token = api_key or None

            # Load tokenizer and model from Hugging Face
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )

            logger.info("Loading model...")
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )
            self.model.to(self.device)
            self.model.eval()

            self.model_loaded = True
            logger.info("✓ Generation model loaded successfully from Hugging Face!")
        except Exception as e:
            logger.error(f"Error loading generation model: {e}")
            raise RuntimeError(f"Failed to load generation model: {e}") from e

    def _format_input(self, topic: str, position: str) -> str:
        """Format input for the model"""
        return f"topic: {topic} stance: {position}"

    def generate(self, topic: str, position: str, max_length: int = 128, num_beams: int = 4) -> str:
        """Generate an argument for a topic and position"""
        if not self.model_loaded:
            raise RuntimeError("Generation model not loaded")

        input_text = self._format_input(topic, position)

        # Tokenize (inputs longer than 512 tokens are truncated)
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        ).to(self.device)

        # Generate with beam search; n-gram blocking and the repetition
        # penalty discourage repeated phrases in the output
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                no_repeat_ngram_size=3,
                repetition_penalty=2.5,
                length_penalty=1.0,
            )

        # Decode
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text

    def batch_generate(self, items: list[dict], max_length: int = 128, num_beams: int = 4) -> list[str]:
        """Batch generate arguments"""
        if not self.model_loaded:
            raise RuntimeError("Generation model not loaded")

        # Prepare inputs
        input_texts = [self._format_input(item["topic"], item["position"]) for item in items]

        # Tokenize batch
        inputs = self.tokenizer(
            input_texts,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        ).to(self.device)

        # Generate batch
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                no_repeat_ngram_size=3,
                repetition_penalty=2.5,
                length_penalty=1.0,
            )

        # Decode batch
        generated_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return generated_texts
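        # Example call (hypothetical items):
        #   batch_generate([{"topic": "homework bans", "position": "PRO"},
        #                   {"topic": "homework bans", "position": "CON"}])
        # returns one generated argument per item, in input order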


# Initialize singleton instance
generate_model_manager = GenerateModelManager()
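

# Usage sketch (illustrative, not part of the module's import-time behavior).
# The model id below is a placeholder; substitute the actual seq2seq checkpoint
# this service is configured with, and pass api_key=... for gated or private
# checkpoints.
if __name__ == "__main__":
    generate_model_manager.load_model("google/flan-t5-base")  # placeholder model id
    print(generate_model_manager.generate(topic="school uniforms", position="PRO"))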