"""Model manager for generation model"""

import logging
from typing import Optional

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

logger = logging.getLogger(__name__)


class GenerateModelManager:
    """Manages generation model loading and predictions.

    The tokenizer and seq2seq model are loaded once by :meth:`load_model` and
    then reused by every :meth:`generate` call.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        # Id of the currently loaded model; lets load_model() notice when it is
        # asked for a different model after one has already been loaded.
        self.model_id = None

    def load_model(
        self,
        model_id: str,
        api_key: Optional[str] = None,
        trust_remote_code: bool = True,
    ) -> None:
        """Load the tokenizer and model for *model_id* from Hugging Face.

        Calling this again after a successful load is a no-op: the first
        loaded model is kept. If a different *model_id* is requested, a
        warning is logged.

        Args:
            model_id: Repository id (or local path) passed to ``from_pretrained``.
            api_key: Optional Hugging Face access token. A falsy value (``None``
                or ``""``) means no token is sent.
            trust_remote_code: Forwarded to ``from_pretrained``. When True, Python
                code shipped inside the model repository is executed locally, so
                only enable it for repositories you trust. Defaults to True to
                preserve the original behaviour.

        Raises:
            RuntimeError: If the tokenizer or model cannot be loaded. The
                underlying exception is chained as ``__cause__``.
        """
        if self.model_loaded:
            logger.info("Generation model already loaded")
            if model_id != self.model_id:
                logger.warning(
                    "load_model(%r) ignored: model %r is already loaded",
                    model_id,
                    self.model_id,
                )
            return

        logger.info("Loading generation model from Hugging Face: %s", model_id)
        try:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info("Using device: %s", device)

            # Only send a token when a non-empty API key was supplied.
            token = api_key or None

            logger.info("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=trust_remote_code,
            )

            logger.info("Loading model...")
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=trust_remote_code,
            )
            model.to(device)
            model.eval()
        except Exception as e:
            # logger.exception keeps the traceback; `from e` keeps the cause
            # available to callers that catch the RuntimeError.
            logger.exception("Error loading generation model: %s", e)
            raise RuntimeError(f"Failed to load generation model: {str(e)}") from e

        # Commit state only once everything succeeded, so a failed load never
        # leaves a half-initialised manager (e.g. tokenizer set, model missing).
        self.device = device
        self.tokenizer = tokenizer
        self.model = model
        self.model_id = model_id
        self.model_loaded = True
        logger.info("✓ Generation model loaded successfully from Hugging Face!")

    def generate(
        self,
        input_text: str,
        max_length: int = 128,
        num_beams: int = 4,
        max_input_length: int = 512,
    ) -> str:
        """Generate text from *input_text* with the loaded model.

        Args:
            input_text: Source text to feed to the model.
            max_length: ``max_length`` passed to ``model.generate``.
            num_beams: Beam width passed to ``model.generate`` (1 = greedy).
            max_input_length: Inputs are truncated to this many tokens before
                generation (previously hard-coded to 512).

        Returns:
            The decoded output for the first sequence, with special tokens removed.

        Raises:
            RuntimeError: If :meth:`load_model` has not completed successfully.
        """
        if not self.model_loaded:
            raise RuntimeError("Generation model not loaded")

        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_length,
            padding=True,
        ).to(self.device)

        generation_kwargs = {"max_length": max_length, "num_beams": num_beams}
        # early_stopping only affects beam search; transformers warns when it
        # is set together with num_beams == 1, so pass it only when it matters.
        if num_beams > 1:
            generation_kwargs["early_stopping"] = True

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


# Module-level singleton instance shared by importers of this module.
generate_model_manager = GenerateModelManager()