File size: 1,599 Bytes
0d13811 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
"""Generation specific endpoints"""
from fastapi import APIRouter, HTTPException, pydantic
from datetime import datetime
import logging
from typing import Optional
from services import generate_model_manager
router = APIRouter()
logger = logging.getLogger(__name__)
class GenerateRequest(pydantic.BaseModel):
    """Request body accepted by the ``/generate`` endpoint."""

    # NOTE(review): `pydantic` is bound by `from fastapi import ..., pydantic`
    # at the top of this file. FastAPI does not appear to re-export pydantic,
    # so that import probably fails. Confirm; `import pydantic` is likely
    # what was intended.

    # Text handed to the generator as its first positional argument.
    input_text: str
    # Forwarded to the generator as `max_length`. No bounds are enforced:
    # null, zero and negative values all pass validation.
    max_length: Optional[int] = pydantic.Field(default=128)
    # Forwarded to the generator as `num_beams` (beam-search width).
    # Like `max_length`, it may be null and is not range-checked.
    num_beams: Optional[int] = pydantic.Field(default=4)
class GenerateResponse(pydantic.BaseModel):
    """Response body returned by the ``/generate`` endpoint."""

    # All three fields are required (no defaults). Declaration order is
    # kept because it fixes the key order of the serialized JSON.

    # The input text from the request, echoed back.
    input_text: str = pydantic.Field(...)
    # Text produced by the generation model manager.
    generated_text: str = pydantic.Field(...)
    # Creation time as a string. In this module it is filled from
    # `datetime.now().isoformat()`.
    timestamp: str = pydantic.Field(...)
@router.post("/generate", response_model=GenerateResponse, tags=["Text Generation"])
async def generate_text(request: GenerateRequest):
    """
    Generate text using the T5 model

    - **input_text**: The input text for generation
    - **max_length**: Maximum length of generated text (default: 128)
    - **num_beams**: Number of beams for beam search (default: 4)

    Returns the input text, the generated text and a creation timestamp.
    Responds with HTTP 500 if generation or response construction fails.
    """
    # NOTE(review): generate_model_manager.generate is called synchronously
    # inside an `async def`, so the event loop is blocked until it returns.
    # If generation is compute-bound (likely for model inference; confirm),
    # consider offloading it, e.g. fastapi.concurrency.run_in_threadpool.
    # Only do that after confirming the manager is safe to call from
    # several threads at once.
    try:
        result = generate_model_manager.generate(
            request.input_text,
            max_length=request.max_length,
            num_beams=request.num_beams,
        )
        response = GenerateResponse(
            input_text=request.input_text,
            generated_text=result,
            # NOTE(review): naive local time with no UTC offset; confirm intended.
            timestamp=datetime.now().isoformat(),
        )
    except Exception as e:
        # API boundary: record the full traceback server-side (logger.exception
        # attaches it), then surface a 500 with the same detail text as before.
        logger.exception("Generation error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Generation failed: {e}"
        ) from e

    # Lazy %-formatting: the message is only built if INFO is enabled.
    logger.info("Generated text: %s", result)
    return response
|