File size: 2,521 Bytes
0d13811
 
66ccc9d
0d13811
 
 
 
306b243
0d13811
 
 
 
 
306b243
 
0d13811
306b243
0d13811
306b243
 
0d13811
 
 
 
306b243
 
0d13811
 
 
 
306b243
 
 
0d13811
 
 
306b243
0d13811
 
 
 
 
306b243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Generation specific endpoints"""

from fastapi import APIRouter, HTTPException
from datetime import datetime
import logging

from services import generate_model_manager
from models.generate import GenerateRequest, GenerateResponse, BatchGenerateRequest, BatchGenerateResponse

# Module-scoped logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Router that the text-generation endpoints below are registered on;
# presumably mounted by the application with include_router — confirm.
router = APIRouter()


@router.post("/predict", response_model=GenerateResponse, tags=["Text Generation"])
async def generate_argument(request: GenerateRequest):
    """
    Generate an argument for a given topic and position

    - **topic**: The debate topic
    - **position**: The stance (e.g. "positive", "negative")

    Returns a `GenerateResponse` echoing `topic` and `position`, plus the
    generated `argument` and an ISO-8601 `timestamp` from `datetime.now()`.

    Raises HTTPException (500) if generation or response construction fails;
    the detail includes the underlying error message.
    """
    try:
        # NOTE(review): generate() is a synchronous call inside an async
        # endpoint, so it blocks the event loop for the whole generation.
        # Consider fastapi.concurrency.run_in_threadpool once the model
        # manager is confirmed to be thread-safe.
        result = generate_model_manager.generate(
            topic=request.topic,
            position=request.position,
        )

        response = GenerateResponse(
            topic=request.topic,
            position=request.position,
            argument=result,
            timestamp=datetime.now().isoformat(),
        )

        # %-style args defer message formatting until the record is emitted.
        logger.info("Generated argument: %s...", result[:50])
        return response

    except Exception as e:
        # logger.exception records the traceback at ERROR level; the previous
        # logger.error(str(e)) kept only the message.
        logger.exception("Generation error: %s", e)
        # NOTE(review): the raw exception text is returned to the client;
        # consider a generic detail if internal errors must not leak.
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}") from e


@router.post("/batch-predict", response_model=BatchGenerateResponse, tags=["Text Generation"])
async def batch_generate_argument(request: BatchGenerateRequest):
    """
    Generate arguments for multiple topic-position pairs

    Generated arguments are paired with `request.items` by position, and every
    entry shares a single timestamp taken after the batch call returns.

    Raises HTTPException (500) if batch generation fails or if the model
    manager returns a different number of results than items were requested.
    """
    try:
        items_data = [{"topic": item.topic, "position": item.position} for item in request.items]

        # NOTE(review): batch_generate() is a synchronous call inside an async
        # endpoint, so it blocks the event loop for the whole batch. Consider
        # fastapi.concurrency.run_in_threadpool once the model manager is
        # confirmed to be thread-safe.
        results = generate_model_manager.batch_generate(
            items=items_data
        )

        # Results are matched to requests by position. Without this check a
        # short result list surfaced as an opaque IndexError, and a long one
        # was silently truncated.
        if len(results) != len(request.items):
            raise RuntimeError(
                f"expected {len(request.items)} results from batch_generate, "
                f"got {len(results)}"
            )

        timestamp = datetime.now().isoformat()
        response_items = [
            GenerateResponse(
                topic=item.topic,
                position=item.position,
                argument=argument,
                timestamp=timestamp,
            )
            for item, argument in zip(request.items, results)
        ]

        return BatchGenerateResponse(
            results=response_items,
            timestamp=timestamp,
        )

    except Exception as e:
        # logger.exception records the traceback at ERROR level; the previous
        # logger.error(str(e)) kept only the message.
        logger.exception("Batch generation error: %s", e)
        raise HTTPException(status_code=500, detail=f"Batch generation failed: {e}") from e