|
|
"""Stance detection endpoints""" |
|
|
|
|
|
from fastapi import APIRouter, HTTPException |
|
|
from datetime import datetime |
|
|
import logging |
|
|
|
|
|
from models import ( |
|
|
StanceRequest, |
|
|
StanceResponse, |
|
|
BatchStanceRequest, |
|
|
BatchStanceResponse, |
|
|
) |
|
|
from services import stance_model_manager |
|
|
|
|
|
# Module-scoped logger; records are tagged with this module's dotted name.
logger = logging.getLogger(__name__)

# Router holding the stance-detection endpoints below. Presumably included
# into the FastAPI app elsewhere (possibly with a prefix) — not visible here.
router = APIRouter()
|
|
|
|
|
|
|
|
@router.post("/predict", response_model=StanceResponse, tags=["Stance Detection"])
async def predict_stance(request: StanceRequest):
    """
    Predict stance for a single topic-argument pair

    - **topic**: The debate topic or statement (5-500 chars)
    - **argument**: The argument to classify (5-1000 chars)

    Returns predicted stance (PRO/CON) with confidence scores.
    Responds with HTTP 500 if the model call or response construction fails.
    """
    # NOTE(review): `stance_model_manager.predict` is a synchronous call made
    # from an `async def` endpoint, so it runs on the event-loop thread. If it
    # performs CPU-bound inference, all other requests stall while it runs.
    # Consider a plain `def` endpoint or `run_in_threadpool` once the model
    # manager is confirmed safe to call from multiple threads.
    try:
        result = stance_model_manager.predict(request.topic, request.argument)

        response = StanceResponse(
            topic=request.topic,
            argument=request.argument,
            predicted_stance=result["predicted_stance"],
            confidence=result["confidence"],
            probability_con=result["probability_con"],
            probability_pro=result["probability_pro"],
            # Naive local time (no tz info), ISO-8601 formatted.
            timestamp=datetime.now().isoformat(),
        )

        logger.info(
            "Prediction: %s (%.4f)", result["predicted_stance"], result["confidence"]
        )
        return response

    except Exception as e:
        # logger.exception records the traceback; the previous
        # logger.error(str(e)) kept only the message, hiding where it failed.
        logger.exception("Prediction error: %s", e)
        # Chain the cause so the original error stays attached to the 500.
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e
|
|
|
|
|
|
|
|
@router.post("/batch-predict", response_model=BatchStanceResponse, tags=["Stance Detection"])
async def batch_predict_stance(request: BatchStanceRequest):
    """
    Predict stance for multiple topic-argument pairs

    - **items**: List of topic-argument pairs (max 50)

    Returns predictions for all items, in the same order as the request.
    The batch is all-or-nothing: if any item fails, the request responds
    with HTTP 500 and no partial results are returned.
    """
    # NOTE(review): each item is predicted sequentially via a synchronous call
    # inside an `async def` endpoint, so the event loop is blocked for the
    # whole batch if inference is CPU-bound. Consider a plain `def` endpoint or
    # `run_in_threadpool` once the model manager's thread-safety is confirmed.
    try:
        results = []
        for item in request.items:
            result = stance_model_manager.predict(item.topic, item.argument)
            results.append(
                StanceResponse(
                    topic=item.topic,
                    argument=item.argument,
                    predicted_stance=result["predicted_stance"],
                    confidence=result["confidence"],
                    probability_con=result["probability_con"],
                    probability_pro=result["probability_pro"],
                    # Stamped per item (naive local time), so timestamps
                    # within one batch may differ slightly.
                    timestamp=datetime.now().isoformat(),
                )
            )

        logger.info("Batch prediction completed: %d items", len(results))

        return BatchStanceResponse(
            results=results,
            total_processed=len(results),
        )

    except Exception as e:
        # Keep the traceback in the logs; the message alone rarely pinpoints
        # which step (model call, key lookup, validation) failed.
        logger.exception("Batch prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}") from e
|
|
|
|
|
|