File size: 2,916 Bytes
9db766f 24cfb63 9db766f 24cfb63 9db766f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
"""Stance detection endpoints"""
from fastapi import APIRouter, HTTPException
from datetime import datetime
import logging
from models import (
StanceRequest,
StanceResponse,
BatchStanceRequest,
BatchStanceResponse,
)
from services import stance_model_manager
router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/predict", response_model=StanceResponse, tags=["Stance Detection"])
async def predict_stance(request: StanceRequest):
    """
    Predict stance for a single topic-argument pair.

    - **topic**: The debate topic or statement (5-500 chars)
    - **argument**: The argument to classify (5-1000 chars)

    Returns predicted stance (PRO/CON) with confidence scores.

    Raises:
        HTTPException: 500 if the model prediction fails for any reason.
    """
    try:
        # Make prediction via the shared model manager
        result = stance_model_manager.predict(request.topic, request.argument)

        # Echo the inputs back alongside the model outputs and a timestamp
        response = StanceResponse(
            topic=request.topic,
            argument=request.argument,
            predicted_stance=result["predicted_stance"],
            confidence=result["confidence"],
            probability_con=result["probability_con"],
            probability_pro=result["probability_pro"],
            timestamp=datetime.now().isoformat()
        )

        # Lazy %-style args: formatting is skipped when INFO is disabled
        logger.info(
            "Prediction: %s (%.4f)",
            result["predicted_stance"],
            result["confidence"],
        )
        return response

    except Exception as e:
        # logger.exception preserves the full traceback in the log record;
        # "from e" chains the cause onto the HTTP error for debugging.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
@router.post("/batch-predict", response_model=BatchStanceResponse, tags=["Stance Detection"])
async def batch_predict_stance(request: BatchStanceRequest):
    """
    Predict stance for multiple topic-argument pairs.

    - **items**: List of topic-argument pairs (max 50)

    Returns predictions for all items, in input order, plus a processed count.

    Raises:
        HTTPException: 500 if any item's prediction fails (the whole batch
        is rejected; no partial results are returned).
    """
    try:
        results = []
        # Predict each item sequentially; one failure aborts the batch.
        for item in request.items:
            result = stance_model_manager.predict(item.topic, item.argument)
            results.append(
                StanceResponse(
                    topic=item.topic,
                    argument=item.argument,
                    predicted_stance=result["predicted_stance"],
                    confidence=result["confidence"],
                    probability_con=result["probability_con"],
                    probability_pro=result["probability_pro"],
                    timestamp=datetime.now().isoformat()
                )
            )

        # Lazy %-style args: formatting is skipped when INFO is disabled
        logger.info("Batch prediction completed: %d items", len(results))
        return BatchStanceResponse(
            results=results,
            total_processed=len(results)
        )

    except Exception as e:
        # logger.exception preserves the full traceback in the log record;
        # "from e" chains the cause onto the HTTP error for debugging.
        logger.exception("Batch prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {str(e)}") from e
|