Yassine Mhirsi
Refactor import statements in health.py and stance.py to use absolute paths for improved clarity and consistency
24cfb63
raw
history blame
2.92 kB
"""Stance detection endpoints"""
from fastapi import APIRouter, HTTPException
from datetime import datetime
import logging
from models import (
StanceRequest,
StanceResponse,
BatchStanceRequest,
BatchStanceResponse,
)
from services import stance_model_manager
router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/predict", response_model=StanceResponse, tags=["Stance Detection"])
async def predict_stance(request: StanceRequest):
    """
    Predict stance for a single topic-argument pair

    - **topic**: The debate topic or statement (5-500 chars)
    - **argument**: The argument to classify (5-1000 chars)

    Returns predicted stance (PRO/CON) with confidence scores.

    Raises:
        HTTPException: 500 if the underlying model prediction fails.
    """
    try:
        # Run the model on this topic/argument pair.
        result = stance_model_manager.predict(request.topic, request.argument)

        # Echo the inputs back alongside the model outputs so the
        # response is self-describing.
        response = StanceResponse(
            topic=request.topic,
            argument=request.argument,
            predicted_stance=result["predicted_stance"],
            confidence=result["confidence"],
            probability_con=result["probability_con"],
            probability_pro=result["probability_pro"],
            timestamp=datetime.now().isoformat(),
        )

        # Lazy %-args: the message is only formatted if INFO is enabled.
        logger.info(
            "Prediction: %s (%.4f)",
            result["predicted_stance"],
            result["confidence"],
        )
        return response
    except Exception as e:
        # logger.exception records the full traceback, not just the message;
        # `from e` preserves the original cause on the raised HTTPException.
        logger.exception("Prediction error")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
@router.post("/batch-predict", response_model=BatchStanceResponse, tags=["Stance Detection"])
async def batch_predict_stance(request: BatchStanceRequest):
    """
    Predict stance for multiple topic-argument pairs

    - **items**: List of topic-argument pairs (max 50)

    Returns predictions for all items.

    Raises:
        HTTPException: 500 if any item's prediction fails (the whole
            batch is rejected; no partial results are returned).
    """
    try:
        results = []
        # Predict each pair sequentially; one failure aborts the batch.
        for item in request.items:
            result = stance_model_manager.predict(item.topic, item.argument)
            results.append(
                StanceResponse(
                    topic=item.topic,
                    argument=item.argument,
                    predicted_stance=result["predicted_stance"],
                    confidence=result["confidence"],
                    probability_con=result["probability_con"],
                    probability_pro=result["probability_pro"],
                    timestamp=datetime.now().isoformat(),
                )
            )

        # Lazy %-args: the message is only formatted if INFO is enabled.
        logger.info("Batch prediction completed: %d items", len(results))
        return BatchStanceResponse(
            results=results,
            total_processed=len(results),
        )
    except Exception as e:
        # logger.exception records the full traceback, not just the message;
        # `from e` preserves the original cause on the raised HTTPException.
        logger.exception("Batch prediction error")
        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {str(e)}") from e