"""Model manager for generation model"""
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import logging
logger = logging.getLogger(__name__)
class GenerateModelManager:
"""Manages generation model loading and predictions"""
def __init__(self):
self.model = None
self.tokenizer = None
self.device = None
self.model_loaded = False
    def load_model(self, model_id: str, api_key: Optional[str] = None):
        """Load the model and tokenizer from Hugging Face."""
        if self.model_loaded:
            logger.info("Generation model already loaded")
            return

        try:
            logger.info(f"Loading generation model from Hugging Face: {model_id}")

            # Determine device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")

            # Use the API key as an access token if one is provided
            token = api_key if api_key else None

            # Load tokenizer and model from Hugging Face
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )

            logger.info("Loading model...")
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )

            self.model.to(self.device)
            self.model.eval()
            self.model_loaded = True
            logger.info("✓ Generation model loaded successfully from Hugging Face!")
        except Exception as e:
            logger.error(f"Error loading generation model: {e}")
            raise RuntimeError(f"Failed to load generation model: {e}") from e
    def _format_input(self, topic: str, position: str) -> str:
        """Format the input prompt for the model, e.g. 'topic: climate change stance: pro'."""
        return f"topic: {topic} stance: {position}"
    def generate(self, topic: str, position: str, max_length: int = 128, num_beams: int = 4) -> str:
        """Generate an argument for a given topic and position."""
        if not self.model_loaded:
            raise RuntimeError("Generation model not loaded")

        input_text = self._format_input(topic, position)

        # Tokenize
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        ).to(self.device)

        # Generate with beam search; the n-gram and repetition penalties discourage repetitive output
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                no_repeat_ngram_size=3,
                repetition_penalty=2.5,
                length_penalty=1.0,
            )

        # Decode
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text
    def batch_generate(self, items: list[dict], max_length: int = 128, num_beams: int = 4) -> list[str]:
        """Generate arguments for a batch of items, each a dict with 'topic' and 'position' keys."""
        if not self.model_loaded:
            raise RuntimeError("Generation model not loaded")

        # Prepare inputs
        input_texts = [self._format_input(item["topic"], item["position"]) for item in items]

        # Tokenize batch
        inputs = self.tokenizer(
            input_texts,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        ).to(self.device)

        # Generate batch with the same decoding settings as `generate`
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                no_repeat_ngram_size=3,
                repetition_penalty=2.5,
                length_penalty=1.0,
            )

        # Decode batch
        generated_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return generated_texts
# Initialize singleton instance
generate_model_manager = GenerateModelManager()
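
# Usage sketch (illustrative only): the singleton above would typically be loaded once
# at application startup and reused by request handlers. The import path, model id,
# route, and request schema below are assumptions for illustration, not definitions
# taken from this repository.
#
#   from fastapi import FastAPI
#   from pydantic import BaseModel
#   from services.generate_model_manager import generate_model_manager  # assumed package path
#
#   app = FastAPI()
#
#   class GenerateRequest(BaseModel):  # hypothetical request schema
#       topic: str
#       position: str
#
#   @app.on_event("startup")
#   def load_generation_model() -> None:
#       # "some-org/argument-generator" is a placeholder model id
#       generate_model_manager.load_model("some-org/argument-generator")
#
#   @app.post("/generate")  # hypothetical route
#   def generate_argument(request: GenerateRequest) -> dict:
#       text = generate_model_manager.generate(request.topic, request.position)
#       return {"argument": text}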