|
|
"""Generation-specific API endpoints."""
|
|
|
|
|
from fastapi import APIRouter, HTTPException, pydantic |
|
|
from datetime import datetime |
|
|
import logging |
|
|
from typing import Optional |
|
|
|
|
|
from services import generate_model_manager |
|
|
|
|
|
# Router object that the endpoint(s) defined in this module are registered on.
router = APIRouter()

# Module-scoped logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class GenerateRequest(pydantic.BaseModel):
    # Request body for the text-generation endpoint.
    #
    # NOTE(review): ``pydantic`` is brought into scope by the file's
    # ``from fastapi import ...`` line; confirm that import actually resolves
    # in the deployed FastAPI version.

    # Required: the text handed to the generator as its input.
    input_text: str

    # Forwarded to the generator as ``max_length``; defaults to 128.
    # Typed Optional, so an explicit null is accepted and passed through.
    max_length: Optional[int] = 128

    # Forwarded to the generator as ``num_beams``; defaults to 4.
    # Typed Optional, so an explicit null is accepted and passed through.
    num_beams: Optional[int] = 4
|
|
|
|
|
|
|
|
class GenerateResponse(pydantic.BaseModel):
    # Response body returned by the text-generation endpoint.

    # Echo of the input text taken from the request.
    input_text: str

    # Output produced by the generator for ``input_text``.
    generated_text: str

    # ISO-8601 string built from ``datetime.now()`` when the response is
    # assembled (no timezone is attached by the handler).
    timestamp: str
|
|
|
|
|
|
|
|
@router.post("/generate", response_model=GenerateResponse, tags=["Text Generation"])
async def generate_text(request: GenerateRequest) -> GenerateResponse:
    """
    Generate text using the T5 model

    - **input_text**: The input text for generation
    - **max_length**: Maximum length of generated text (default: 128)
    - **num_beams**: Number of beams for beam search (default: 4)

    Returns the original input text, the generated text and an ISO-8601
    timestamp.

    Responds with HTTP 500 ("Generation failed: ...") if generation or
    response construction raises.
    """
    try:
        # NOTE(review): this is a synchronous call inside an async handler; if
        # it performs model inference it will hold the event loop for the
        # duration -- confirm whether it should be offloaded to a worker thread.
        result = generate_model_manager.generate(
            request.input_text,
            max_length=request.max_length,
            num_beams=request.num_beams,
        )

        # Built inside the ``try`` so that a result which does not validate
        # against the response schema is also reported as a 500.
        response = GenerateResponse(
            input_text=request.input_text,
            generated_text=result,
            timestamp=datetime.now().isoformat(),
        )
    except Exception as e:
        # ``logger.exception`` logs at ERROR level and appends the traceback,
        # which a plain ``logger.error`` of ``str(e)`` would lose.
        logger.exception("Generation error: %s", e)
        # Chain the cause so the original failure stays attached to the
        # HTTP error for debugging.
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}") from e
    else:
        # Lazy %-formatting: the message is only rendered if INFO is enabled.
        logger.info("Generated text: %s", result)
        return response
|
|
|