Yassine Mhirsi
first test
9db766f
raw
history blame
2.92 kB
"""Stance detection endpoints"""
from fastapi import APIRouter, HTTPException
from datetime import datetime
import logging
from models import (
StanceRequest,
StanceResponse,
BatchStanceRequest,
BatchStanceResponse,
)
from services import stance_model_manager
# Module-scoped logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
# Router that collects the stance-detection endpoints declared in this module;
# presumably mounted by the application via include_router — verify against the caller.
router = APIRouter()
@router.post("/predict", response_model=StanceResponse, tags=["Stance Detection"])
async def predict_stance(request: StanceRequest):
    """
    Predict stance for a single topic-argument pair

    - **topic**: The debate topic or statement (5-500 chars)
    - **argument**: The argument to classify (5-1000 chars)

    Returns predicted stance (PRO/CON) with confidence scores.
    On any failure the endpoint responds with HTTP 500 and the error message in `detail`.
    """
    # NOTE(review): stance_model_manager.predict is called synchronously inside an
    # async handler, so a slow prediction blocks the event loop for all other
    # requests. Consider a plain `def` endpoint or run_in_threadpool once the model
    # manager is confirmed safe to call from multiple threads.
    try:
        result = stance_model_manager.predict(request.topic, request.argument)
        response = StanceResponse(
            topic=request.topic,
            argument=request.argument,
            predicted_stance=result["predicted_stance"],
            confidence=result["confidence"],
            probability_con=result["probability_con"],
            probability_pro=result["probability_pro"],
            timestamp=datetime.now().isoformat(),
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logging only str(e) discarded.
        logger.exception("Prediction error: %s", e)
        # Chain the cause so the original exception stays attached for debugging.
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e

    # Logged outside the try: a logging problem must not turn a successful
    # prediction into an HTTP 500.
    logger.info(
        "Prediction: %s (%.4f)", result["predicted_stance"], result["confidence"]
    )
    return response
@router.post("/batch-predict", response_model=BatchStanceResponse, tags=["Stance Detection"])
async def batch_predict_stance(request: BatchStanceRequest):
    """
    Predict stance for multiple topic-argument pairs

    - **items**: List of topic-argument pairs (max 50)

    Returns predictions for all items.
    Items are processed one after another; if any item fails, the whole batch
    responds with HTTP 500 and no partial results are returned.
    """
    # NOTE(review): each predict() call runs synchronously inside this async
    # handler, so a large batch blocks the event loop until it finishes.
    try:
        results = []
        for item in request.items:
            result = stance_model_manager.predict(item.topic, item.argument)
            results.append(
                StanceResponse(
                    topic=item.topic,
                    argument=item.argument,
                    predicted_stance=result["predicted_stance"],
                    confidence=result["confidence"],
                    probability_con=result["probability_con"],
                    probability_pro=result["probability_pro"],
                    timestamp=datetime.now().isoformat(),
                )
            )
        # Built inside the try so a validation failure still maps to the same 500 detail.
        response = BatchStanceResponse(
            results=results,
            total_processed=len(results),
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logging only str(e) discarded.
        logger.exception("Batch prediction error: %s", e)
        # Chain the cause so the original exception stays attached for debugging.
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {str(e)}"
        ) from e

    logger.info("Batch prediction completed: %d items", len(results))
    return response