FastAPI-Backend-Models / services /generate_model_manager.py
S01Nour
feat: Add text generation API endpoint with T5 model integration.
0d13811
raw
history blame
2.92 kB
"""Model manager for generation model"""
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import logging
logger = logging.getLogger(__name__)
class GenerateModelManager:
"""Manages generation model loading and predictions"""
def __init__(self):
self.model = None
self.tokenizer = None
self.device = None
self.model_loaded = False
def load_model(self, model_id: str, api_key: str = None):
"""Load model and tokenizer from Hugging Face"""
if self.model_loaded:
logger.info("Generation model already loaded")
return
try:
logger.info(f"Loading generation model from Hugging Face: {model_id}")
# Determine device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {self.device}")
# Prepare token for authentication if API key is provided
token = api_key if api_key else None
# Load tokenizer and model from Hugging Face
logger.info("Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(
model_id,
token=token,
trust_remote_code=True
)
logger.info("Loading model...")
self.model = AutoModelForSeq2SeqLM.from_pretrained(
model_id,
token=token,
trust_remote_code=True
)
self.model.to(self.device)
self.model.eval()
self.model_loaded = True
logger.info("✓ Generation model loaded successfully from Hugging Face!")
except Exception as e:
logger.error(f"Error loading generation model: {str(e)}")
raise RuntimeError(f"Failed to load generation model: {str(e)}")
def generate(self, input_text: str, max_length: int = 128, num_beams: int = 4) -> str:
"""Generate text from input"""
if not self.model_loaded:
raise RuntimeError("Generation model not loaded")
# Tokenize
inputs = self.tokenizer(
input_text,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
).to(self.device)
# Generate
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_length=max_length,
num_beams=num_beams,
early_stopping=True
)
# Decode
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
return generated_text
# Initialize singleton instance
generate_model_manager = GenerateModelManager()