File size: 2,521 Bytes
0d13811 66ccc9d 0d13811 306b243 0d13811 306b243 0d13811 306b243 0d13811 306b243 0d13811 306b243 0d13811 306b243 0d13811 306b243 0d13811 306b243 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
"""Generation specific endpoints"""
from fastapi import APIRouter, HTTPException
from datetime import datetime
import logging
from services import generate_model_manager
from models.generate import GenerateRequest, GenerateResponse, BatchGenerateRequest, BatchGenerateResponse
# Router that the generation endpoints defined in this module are registered on.
router = APIRouter()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@router.post("/predict", response_model=GenerateResponse, tags=["Text Generation"])
async def generate_argument(request: GenerateRequest):
    """
    Generate an argument for a given topic and position

    - **topic**: The debate topic
    - **position**: The stance (e.g. "positive", "negative")
    \f
    Implementation notes (the form feed above keeps everything below out of
    the generated OpenAPI description).

    Returns:
        A GenerateResponse echoing the request's topic and position, with the
        generated text in ``argument`` and the current local time in ISO 8601
        form in ``timestamp``.

    Raises:
        HTTPException: status 500 when generation or response construction
            fails; ``detail`` carries the underlying error text.
    """
    try:
        argument = generate_model_manager.generate(
            topic=request.topic,
            position=request.position,
        )
        response = GenerateResponse(
            topic=request.topic,
            position=request.position,
            argument=argument,
            timestamp=datetime.now().isoformat(),
        )
        # Only the first 50 characters are logged to keep log lines short.
        logger.info("Generated argument: %s...", argument[:50])
        return response
    except HTTPException:
        # Already a well-formed HTTP error: let it through instead of
        # rewrapping it as a generic 500.
        raise
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback,
        # which a plain logger.error call would drop.
        logger.exception("Generation error: %s", e)
        # NOTE(review): the raw exception text is returned to the client in
        # `detail`; confirm this is acceptable for public deployments.
        raise HTTPException(
            status_code=500, detail=f"Generation failed: {str(e)}"
        ) from e
@router.post("/batch-predict", response_model=BatchGenerateResponse, tags=["Text Generation"])
async def batch_generate_argument(request: BatchGenerateRequest):
    """
    Generate arguments for multiple topic-position pairs
    \f
    Implementation notes (the form feed above keeps everything below out of
    the generated OpenAPI description).

    Each entry of ``request.items`` is converted to a ``{"topic", "position"}``
    dict and the whole list is handed to the model manager in one call. The
    returned results are paired with the request items by position.

    Returns:
        A BatchGenerateResponse whose ``results`` holds one GenerateResponse
        per request item; every item and the batch itself share a single
        timestamp taken after generation completes.

    Raises:
        HTTPException: status 500 when generation fails, when the number of
            results does not match the number of requested items, or when
            response construction fails; ``detail`` carries the error text.
    """
    try:
        items_data = [
            {"topic": item.topic, "position": item.position}
            for item in request.items
        ]
        results = generate_model_manager.batch_generate(items=items_data)

        # Results are matched to inputs by position, so a count mismatch would
        # either fail with an opaque IndexError or silently mis-pair/drop
        # entries. Fail with an explicit message instead.
        if len(results) != len(items_data):
            raise ValueError(
                f"expected {len(items_data)} results, got {len(results)}"
            )

        # One timestamp is shared by every item and by the batch envelope.
        timestamp = datetime.now().isoformat()
        response_items = [
            GenerateResponse(
                topic=item.topic,
                position=item.position,
                argument=argument,
                timestamp=timestamp,
            )
            for item, argument in zip(request.items, results)
        ]
        return BatchGenerateResponse(
            results=response_items,
            timestamp=timestamp,
        )
    except HTTPException:
        # Already a well-formed HTTP error: let it through instead of
        # rewrapping it as a generic 500.
        raise
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback,
        # which a plain logger.error call would drop.
        logger.exception("Batch generation error: %s", e)
        # NOTE(review): the raw exception text is returned to the client in
        # `detail`; confirm this is acceptable for public deployments.
        raise HTTPException(
            status_code=500, detail=f"Batch generation failed: {str(e)}"
        ) from e
|