Yassine Mhirsi
Refactor import statements in health.py and stance.py to use absolute paths for improved clarity and consistency
24cfb63
| """Stance detection endpoints""" | |
| from fastapi import APIRouter, HTTPException | |
| from datetime import datetime | |
| import logging | |
| from models import ( | |
| StanceRequest, | |
| StanceResponse, | |
| BatchStanceRequest, | |
| BatchStanceResponse, | |
| ) | |
| from services import stance_model_manager | |
# Router object for this module's stance endpoints.
# NOTE(review): no route registration is visible in this excerpt; presumably
# the handlers below are attached to this router via decorators - confirm.
router = APIRouter()
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
async def predict_stance(request: StanceRequest):
    """
    Predict stance for a single topic-argument pair

    - **topic**: The debate topic or statement (5-500 chars)
    - **argument**: The argument to classify (5-1000 chars)

    Returns predicted stance (PRO/CON) with confidence scores.
    Responds with HTTP 500 if the prediction or response construction fails.
    """
    try:
        # NOTE(review): predict() is invoked synchronously inside an async
        # handler; if it is CPU-bound it will hold the event loop for the
        # duration of the call - confirm whether offloading is needed.
        result = stance_model_manager.predict(request.topic, request.argument)

        # Echo the inputs back alongside the model output so the caller can
        # correlate the response with its request.
        response = StanceResponse(
            topic=request.topic,
            argument=request.argument,
            predicted_stance=result["predicted_stance"],
            confidence=result["confidence"],
            probability_con=result["probability_con"],
            probability_pro=result["probability_pro"],
            # Naive local time, ISO-8601 formatted (no timezone offset).
            timestamp=datetime.now().isoformat(),
        )

        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info(
            "Prediction: %s (%.4f)",
            result["predicted_stance"],
            result["confidence"],
        )
        return response
    except Exception as e:
        # logger.exception logs at ERROR level *and* records the traceback,
        # which the previous logger.error(str(e)) call discarded.
        logger.exception("Prediction error: %s", e)
        # Chain the original error so the root cause stays attached to the
        # HTTPException for debugging and error-reporting middleware.
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {e}"
        ) from e
async def batch_predict_stance(request: BatchStanceRequest):
    """
    Predict stance for multiple topic-argument pairs

    - **items**: List of topic-argument pairs (max 50)

    Returns predictions for all items.
    Responds with HTTP 500 if any item fails; no partial results are returned.
    """
    try:
        results = []

        # Items are processed sequentially, one model call per pair; each
        # result carries its own timestamp taken when that item finished.
        for item in request.items:
            result = stance_model_manager.predict(item.topic, item.argument)
            results.append(
                StanceResponse(
                    topic=item.topic,
                    argument=item.argument,
                    predicted_stance=result["predicted_stance"],
                    confidence=result["confidence"],
                    probability_con=result["probability_con"],
                    probability_pro=result["probability_pro"],
                    # Naive local time, ISO-8601 formatted (no timezone offset).
                    timestamp=datetime.now().isoformat(),
                )
            )

        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info("Batch prediction completed: %d items", len(results))

        return BatchStanceResponse(
            results=results,
            total_processed=len(results),
        )
    except Exception as e:
        # logger.exception logs at ERROR level *and* records the traceback,
        # which the previous logger.error(str(e)) call discarded.
        logger.exception("Batch prediction error: %s", e)
        # Chain the original error so the root cause stays attached to the
        # HTTPException for debugging and error-reporting middleware.
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {e}"
        ) from e