"""Stance detection endpoints"""

from fastapi import APIRouter, HTTPException
from datetime import datetime
import logging

from models import (
    StanceRequest,
    StanceResponse,
    BatchStanceRequest,
    BatchStanceResponse,
)
from services import stance_model_manager

router = APIRouter()
logger = logging.getLogger(__name__)


def _build_response(topic: str, argument: str, result: dict) -> StanceResponse:
    """Wrap one raw model prediction in a ``StanceResponse``.

    Shared by the single and batch endpoints so both build responses the
    same way.

    Args:
        topic: The topic string that was sent to the model.
        argument: The argument string that was sent to the model.
        result: The value returned by ``stance_model_manager.predict``; it must
            provide the ``predicted_stance``, ``confidence``,
            ``probability_con`` and ``probability_pro`` keys.

    Returns:
        The populated ``StanceResponse``, stamped with the current time.

    Raises:
        KeyError: If ``result`` is missing one of the keys above.
    """
    return StanceResponse(
        topic=topic,
        argument=argument,
        predicted_stance=result["predicted_stance"],
        confidence=result["confidence"],
        probability_con=result["probability_con"],
        probability_pro=result["probability_pro"],
        # NOTE(review): naive local time (no UTC offset in the ISO string);
        # kept as-is so the response format does not change for clients.
        timestamp=datetime.now().isoformat(),
    )


@router.post("/predict", response_model=StanceResponse, tags=["Stance Detection"])
async def predict_stance(request: StanceRequest):
    """
    Predict stance for a single topic-argument pair

    - **topic**: The debate topic or statement (5-500 chars)
    - **argument**: The argument to classify (5-1000 chars)

    Returns predicted stance (PRO/CON) with confidence scores
    """
    # The docstring above is published by FastAPI as the OpenAPI description,
    # so its text is left unchanged.
    # NOTE(review): stance_model_manager.predict is called synchronously inside
    # an async handler, which blocks the event loop for the duration of the
    # inference. Moving it to a threadpool is only safe if the model manager is
    # thread-safe -- confirm before changing.
    try:
        result = stance_model_manager.predict(request.topic, request.argument)
        response = _build_response(request.topic, request.argument, result)
        logger.info(
            "Prediction: %s (%.4f)", result["predicted_stance"], result["confidence"]
        )
        return response
    except HTTPException:
        # Let deliberate HTTP errors (e.g. from the model layer) keep their
        # status code instead of being rewritten to a generic 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Prediction error: %s", e)
        # NOTE(review): this exposes the internal error text to the client.
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e


@router.post("/batch-predict", response_model=BatchStanceResponse, tags=["Stance Detection"])
async def batch_predict_stance(request: BatchStanceRequest):
    """
    Predict stance for multiple topic-argument pairs

    - **items**: List of topic-argument pairs (max 50)

    Returns predictions for all items
    """
    # Items are processed sequentially. The first failing item aborts the whole
    # batch with a 500; no partial results are returned.
    try:
        results = [
            _build_response(
                item.topic,
                item.argument,
                stance_model_manager.predict(item.topic, item.argument),
            )
            for item in request.items
        ]
        logger.info("Batch prediction completed: %d items", len(results))
        return BatchStanceResponse(
            results=results,
            total_processed=len(results),
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Batch prediction error: %s", e)
        # NOTE(review): this exposes the internal error text to the client.
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {str(e)}"
        ) from e