File size: 2,916 Bytes
9db766f
 
 
 
 
 
24cfb63
9db766f
 
 
 
 
24cfb63
9db766f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Stance detection endpoints"""

from fastapi import APIRouter, HTTPException
from datetime import datetime
import logging

from models import (
    StanceRequest,
    StanceResponse,
    BatchStanceRequest,
    BatchStanceResponse,
)
from services import stance_model_manager

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
# Router the endpoints below register on; presumably mounted by the app via
# include_router (TODO confirm the prefix used there).
router = APIRouter()


@router.post("/predict", response_model=StanceResponse, tags=["Stance Detection"])
async def predict_stance(request: StanceRequest) -> StanceResponse:
    """
    Predict stance for a single topic-argument pair
    
    - **topic**: The debate topic or statement (5-500 chars)
    - **argument**: The argument to classify (5-1000 chars)
    
    Returns predicted stance (PRO/CON) with confidence scores
    """
    # The docstring above is kept verbatim: FastAPI publishes it as the
    # endpoint's OpenAPI description.
    #
    # NOTE(review): stance_model_manager.predict is called synchronously inside
    # an `async def` handler, so it runs on the event-loop thread and blocks
    # other requests while it executes. Offloading it to a thread pool is only
    # safe if the manager (and any tokenizer it wraps) is thread-safe, so it
    # is deliberately left as-is here; confirm before changing.
    try:
        # predict() is expected to return a mapping with the four keys read
        # below; a missing key surfaces as a 500 via the handler.
        result = stance_model_manager.predict(request.topic, request.argument)

        response = StanceResponse(
            topic=request.topic,
            argument=request.argument,
            predicted_stance=result["predicted_stance"],
            confidence=result["confidence"],
            probability_con=result["probability_con"],
            probability_pro=result["probability_pro"],
            # NOTE(review): naive local time with no UTC offset; consider a
            # timezone-aware value if clients compare timestamps across hosts.
            timestamp=datetime.now().isoformat(),
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Prediction error: %s", e)
        # NOTE(review): str(e) is echoed to the client and may expose internal
        # details; kept for API compatibility.
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e

    logger.info(
        "Prediction: %s (%.4f)", result["predicted_stance"], result["confidence"]
    )
    return response


@router.post("/batch-predict", response_model=BatchStanceResponse, tags=["Stance Detection"])
async def batch_predict_stance(request: BatchStanceRequest) -> BatchStanceResponse:
    """
    Predict stance for multiple topic-argument pairs
    
    - **items**: List of topic-argument pairs (max 50)
    
    Returns predictions for all items
    """
    # The docstring above is kept verbatim: FastAPI publishes it as the
    # endpoint's OpenAPI description.
    #
    # Behavior: items are predicted sequentially, and any failure aborts the
    # whole batch with a 500 (no partial results are returned).
    # NOTE(review): predict() runs synchronously on the event-loop thread; a
    # large batch blocks other requests for its full duration.

    # Defined outside the try so the error path can report how far it got.
    results = []
    try:
        for item in request.items:
            result = stance_model_manager.predict(item.topic, item.argument)
            results.append(
                StanceResponse(
                    topic=item.topic,
                    argument=item.argument,
                    predicted_stance=result["predicted_stance"],
                    confidence=result["confidence"],
                    probability_con=result["probability_con"],
                    probability_pro=result["probability_pro"],
                    # Per-item timestamp, naive local time (see single-item endpoint).
                    timestamp=datetime.now().isoformat(),
                )
            )

        response = BatchStanceResponse(
            results=results,
            total_processed=len(results),
        )
    except Exception as e:
        # logger.exception keeps the traceback; the count locates the failure.
        logger.exception(
            "Batch prediction error after %d successful items: %s",
            len(results),
            e,
        )
        # NOTE(review): str(e) is echoed to the client and may expose internal
        # details; kept for API compatibility.
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {str(e)}"
        ) from e

    # Logged only once the response has been built successfully.
    logger.info("Batch prediction completed: %d items", len(results))
    return response