Yassine Mhirsi committed · 430b54f
Parent(s): 6e8d513

feat: Add analysis endpoints and service for processing arguments, extracting topics, and predicting stance

Browse files:
- models/__init__.py            +15 -0
- models/analysis.py            +100 -0
- routes/__init__.py            +2 -1
- routes/analysis.py            +189 -0
- services/__init__.py          +3 -0
- services/analysis_service.py  +165 -0
models/__init__.py CHANGED

@@ -44,6 +44,15 @@ from .user import (
     UserGetRequest,
 )
 
+# Import analysis-related schemas
+from .analysis import (
+    AnalysisRequest,
+    AnalysisResponse,
+    AnalysisResult,
+    GetAnalysisRequest,
+    GetAnalysisResponse,
+)
+
 # Import MCP-related schemas
 from .mcp_models import (
     ToolCallRequest,
@@ -85,6 +94,12 @@ __all__ = [
     "UserResponse",
     "UserUpdateNameRequest",
     "UserGetRequest",
+    # Analysis schemas
+    "AnalysisRequest",
+    "AnalysisResponse",
+    "AnalysisResult",
+    "GetAnalysisRequest",
+    "GetAnalysisResponse",
     # MCP schemas
     "ToolCallRequest",
     "ToolCallResponse",
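With these package-level re-exports in place, the new schemas can be imported from `models` directly. A minimal sketch (not part of the commit):

from models import AnalysisRequest, AnalysisResponse

req = AnalysisRequest(arguments=["We must invest in renewable energy to combat climate change."])
print(req.arguments)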
models/analysis.py ADDED
@@ -0,0 +1,100 @@
"""Pydantic models for analysis endpoints"""

from pydantic import BaseModel, Field, ConfigDict
from typing import List, Optional


class AnalysisRequest(BaseModel):
    """Request model for analysis with arguments"""
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "arguments": [
                    "Social media companies must NOT be allowed to track people across websites.",
                    "I don't think universal basic income is a good idea — it'll disincentivize work.",
                    "We must invest in renewable energy to combat climate change."
                ]
            }
        }
    )

    arguments: List[str] = Field(
        ..., min_length=1, max_length=100,
        description="List of argument texts to analyze (max 100)"
    )


class AnalysisResult(BaseModel):
    """Model for a single analysis result"""
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "id": "123e4567-e89b-12d3-a456-426614174000",
                "user_id": "123e4567-e89b-12d3-a456-426614174000",
                "argument": "Social media companies must NOT be allowed to track people across websites.",
                "topic": "social media tracking and cross-website user privacy",
                "predicted_stance": "CON",
                "confidence": 0.9234,
                "probability_con": 0.9234,
                "probability_pro": 0.0766,
                "created_at": "2024-01-01T12:00:00Z",
                "updated_at": "2024-01-01T12:00:00Z"
            }
        }
    )

    id: str = Field(..., description="Analysis result UUID")
    user_id: str = Field(..., description="User UUID")
    argument: str = Field(..., description="The argument text")
    topic: str = Field(..., description="Extracted topic")
    predicted_stance: str = Field(..., description="PRO or CON")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    probability_con: float = Field(..., ge=0.0, le=1.0, description="Probability of CON")
    probability_pro: float = Field(..., ge=0.0, le=1.0, description="Probability of PRO")
    created_at: str = Field(..., description="Creation timestamp")
    updated_at: str = Field(..., description="Last update timestamp")


class AnalysisResponse(BaseModel):
    """Response model for analysis endpoint"""
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "results": [
                    {
                        "id": "123e4567-e89b-12d3-a456-426614174000",
                        "user_id": "123e4567-e89b-12d3-a456-426614174000",
                        "argument": "Social media companies must NOT be allowed to track people across websites.",
                        "topic": "social media tracking and cross-website user privacy",
                        "predicted_stance": "CON",
                        "confidence": 0.9234,
                        "probability_con": 0.9234,
                        "probability_pro": 0.0766,
                        "created_at": "2024-01-01T12:00:00Z",
                        "updated_at": "2024-01-01T12:00:00Z"
                    }
                ],
                "total_processed": 1,
                "timestamp": "2024-01-01T12:00:00Z"
            }
        }
    )

    results: List[AnalysisResult] = Field(..., description="List of analysis results")
    total_processed: int = Field(..., description="Number of arguments processed")
    timestamp: str = Field(..., description="Analysis timestamp")


class GetAnalysisRequest(BaseModel):
    """Request model for getting user's analysis results"""
    limit: Optional[int] = Field(100, ge=1, le=1000, description="Maximum number of results")
    offset: Optional[int] = Field(0, ge=0, description="Number of results to skip")


class GetAnalysisResponse(BaseModel):
    """Response model for getting analysis results"""
    results: List[AnalysisResult] = Field(..., description="List of analysis results")
    total: int = Field(..., description="Total number of results")
    limit: int = Field(..., description="Limit used")
    offset: int = Field(..., description="Offset used")
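A short sketch of how these models enforce the declared constraints, assuming Pydantic v2 (which the ConfigDict usage implies); the error type name below is Pydantic's, not the commit's:

from pydantic import ValidationError
from models.analysis import AnalysisRequest

ok = AnalysisRequest(arguments=["We must invest in renewable energy to combat climate change."])
print(len(ok.arguments))  # 1

try:
    AnalysisRequest(arguments=[])  # violates min_length=1
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # "too_short" under Pydantic v2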
routes/__init__.py CHANGED

@@ -1,7 +1,7 @@
 """API route handlers"""
 
 from fastapi import APIRouter
-from . import root, health, stance, label, generate, topic, user
+from . import root, health, stance, label, generate, topic, user, analysis
 from routes.tts_routes import router as audio_router
 # Create main router
 api_router = APIRouter()
@@ -14,6 +14,7 @@ api_router.include_router(label.router, prefix="/label")
 api_router.include_router(generate.router, prefix="/generate")
 api_router.include_router(topic.router, prefix="/topic")
 api_router.include_router(user.router, prefix="/user")
+api_router.include_router(analysis.router, prefix="")
 api_router.include_router(audio_router)
 
 __all__ = ["api_router"]
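Note that analysis.router is mounted with prefix="", so its routes are served at the application root (/analyse) rather than under a named prefix like the other routers. Assuming the usual FastAPI wiring (the application module is not part of this commit), the main router would be attached roughly like this:

from fastapi import FastAPI
from routes import api_router

app = FastAPI()
app.include_router(api_router)  # POST /analyse and GET /analyse end up at the root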
routes/analysis.py ADDED
@@ -0,0 +1,189 @@
"""Analysis endpoints for processing arguments, extracting topics, and predicting stance"""

from fastapi import APIRouter, HTTPException, Header, UploadFile, File
from typing import List, Optional  # List is required by parse_csv_file's annotation
import logging
import csv
import io
from datetime import datetime

from services.analysis_service import analysis_service
from models.analysis import (
    AnalysisRequest,
    AnalysisResponse,
    AnalysisResult,
    GetAnalysisRequest,
    GetAnalysisResponse,
)

router = APIRouter()
logger = logging.getLogger(__name__)


def parse_csv_file(file_content: bytes) -> List[str]:
    """
    Parse CSV file and extract arguments

    Args:
        file_content: CSV file content as bytes

    Returns:
        List of argument strings
    """
    try:
        # Decode bytes to string
        content = file_content.decode('utf-8')

        # Parse CSV
        csv_reader = csv.reader(io.StringIO(content))
        arguments = []

        # Skip header row if present, extract arguments from first column or 'argument' column
        rows = list(csv_reader)
        if len(rows) == 0:
            return []

        # Check if first row is header
        header = rows[0] if rows else []
        start_idx = 1 if any(col.lower() in ['argument', 'text', 'content'] for col in header) else 0

        # Find argument column index
        arg_col_idx = 0
        if start_idx == 1:
            for idx, col in enumerate(header):
                if col.lower() in ['argument', 'text', 'content']:
                    arg_col_idx = idx
                    break

        # Extract arguments
        for row in rows[start_idx:]:
            if row and len(row) > arg_col_idx:
                arg = row[arg_col_idx].strip()
                if arg:  # Only add non-empty arguments
                    arguments.append(arg)

        return arguments

    except Exception as e:
        logger.error(f"Error parsing CSV file: {str(e)}")
        raise ValueError(f"Failed to parse CSV file: {str(e)}")


@router.post("/analyse", response_model=AnalysisResponse, tags=["Analysis"])
async def analyse_arguments(
    request: Optional[AnalysisRequest] = None,
    file: Optional[UploadFile] = File(None),
    x_user_id: Optional[str] = Header(None, alias="X-User-ID")
):
    """
    Analyze arguments: extract topics and predict stance

    Accepts either:
    - JSON body with `arguments` array, OR
    - CSV file upload with arguments

    - **X-User-ID**: User UUID (required in header)
    - **arguments**: List of argument texts (if using JSON)
    - **file**: CSV file with arguments (if using file upload)

    Returns analysis results with extracted topics and stance predictions
    """
    if not x_user_id:
        raise HTTPException(status_code=400, detail="X-User-ID header is required")

    try:
        arguments = []

        # Get arguments from either JSON body or CSV file
        if file:
            # Read CSV file
            file_content = await file.read()
            arguments = parse_csv_file(file_content)

            if not arguments:
                raise HTTPException(
                    status_code=400,
                    detail="CSV file is empty or contains no valid arguments"
                )

            logger.info(f"Parsed {len(arguments)} arguments from CSV file")

        elif request and request.arguments:
            arguments = request.arguments
            logger.info(f"Received {len(arguments)} arguments from JSON body")
        else:
            raise HTTPException(
                status_code=400,
                detail="Either 'arguments' in JSON body or CSV 'file' must be provided"
            )

        # Analyze arguments
        results = analysis_service.analyze_arguments(
            user_id=x_user_id,
            arguments=arguments
        )

        # Convert to response models
        analysis_results = [
            AnalysisResult(**result) for result in results
        ]

        logger.info(f"Analysis completed: {len(analysis_results)} results")

        return AnalysisResponse(
            results=analysis_results,
            total_processed=len(analysis_results),
            timestamp=datetime.now().isoformat()
        )

    except HTTPException:
        raise
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Analysis error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")


@router.get("/analyse", response_model=GetAnalysisResponse, tags=["Analysis"])
async def get_analysis_results(
    limit: int = 100,
    offset: int = 0,
    x_user_id: Optional[str] = Header(None, alias="X-User-ID")
):
    """
    Get user's analysis results

    - **X-User-ID**: User UUID (required in header)
    - **limit**: Maximum number of results to return (default: 100, max: 1000)
    - **offset**: Number of results to skip (default: 0)

    Returns paginated analysis results
    """
    if not x_user_id:
        raise HTTPException(status_code=400, detail="X-User-ID header is required")

    try:
        results = analysis_service.get_user_analysis_results(
            user_id=x_user_id,
            limit=limit,
            offset=offset
        )

        # Convert to response models
        analysis_results = [
            AnalysisResult(**result) for result in results
        ]

        return GetAnalysisResponse(
            results=analysis_results,
            total=len(analysis_results),
            limit=limit,
            offset=offset
        )

    except Exception as e:
        logger.error(f"Error getting analysis results: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get analysis results: {str(e)}")
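A usage sketch for the two routes above, using requests; the base URL and user ID are placeholders, not values from the commit:

import requests

BASE = "http://localhost:8000"  # assumed local dev server
HEADERS = {"X-User-ID": "123e4567-e89b-12d3-a456-426614174000"}

# CSV upload path: parse_csv_file accepts a header named 'argument',
# 'text', or 'content', or headerless rows read from the first column.
csv_bytes = b"argument\nWe must invest in renewable energy to combat climate change.\n"
resp = requests.post(
    f"{BASE}/analyse",
    files={"file": ("arguments.csv", csv_bytes, "text/csv")},
    headers=HEADERS,
)
print(resp.json()["total_processed"])

# Paginated retrieval of previously saved results
resp = requests.get(f"{BASE}/analyse", params={"limit": 10, "offset": 0}, headers=HEADERS)
print(resp.json()["total"])

One caveat worth flagging: declaring both an optional JSON body (AnalysisRequest) and an optional File parameter on the same endpoint is a known FastAPI pain point, since File switches the expected request body to multipart/form-data; the CSV path is the one sketched here for that reason.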
services/__init__.py CHANGED

@@ -10,6 +10,7 @@ from .tts_service import text_to_speech
 from .topic_service import TopicService, topic_service
 from .database_service import DatabaseService, database_service
 from .user_service import UserService, user_service
+from .analysis_service import AnalysisService, analysis_service
 
 __all__ = [
     "StanceModelManager",
@@ -24,6 +25,8 @@ __all__ = [
     "database_service",
     "UserService",
     "user_service",
+    "AnalysisService",
+    "analysis_service",
 
     # NEW exports
     "speech_to_text",
services/analysis_service.py ADDED
@@ -0,0 +1,165 @@
"""Service for analysis operations: processing arguments, extracting topics, and predicting stance"""

import logging
from typing import List, Dict, Optional
from datetime import datetime

from services.database_service import database_service
from services.topic_service import topic_service
from services.stance_model_manager import stance_model_manager

logger = logging.getLogger(__name__)


class AnalysisService:
    """Service for analyzing arguments with topic extraction and stance prediction"""

    def __init__(self):
        self.table_name = "analysis_results"

    def _get_client(self):
        """Get Supabase client"""
        return database_service.get_client()

    def analyze_arguments(
        self,
        user_id: str,
        arguments: List[str]
    ) -> List[Dict]:
        """
        Analyze arguments: extract topics, predict stance, and save to database

        Args:
            user_id: User UUID
            arguments: List of argument texts

        Returns:
            List of analysis results with argument, topic, and stance prediction
        """
        try:
            if not arguments or len(arguments) == 0:
                raise ValueError("Arguments list cannot be empty")

            logger.info(f"Starting analysis for {len(arguments)} arguments for user {user_id}")

            # Step 1: Extract topics for all arguments
            logger.info("Step 1: Extracting topics...")
            topics = topic_service.batch_extract_topics(arguments)

            if len(topics) != len(arguments):
                raise RuntimeError(f"Topic extraction returned {len(topics)} topics but expected {len(arguments)}")

            # Step 2: Predict stance for each argument-topic pair
            logger.info("Step 2: Predicting stance for argument-topic pairs...")
            stance_results = []
            for arg, topic in zip(arguments, topics):
                if topic is None:
                    logger.warning(f"Skipping argument with null topic: {arg[:50]}...")
                    continue

                stance_result = stance_model_manager.predict(topic, arg)
                stance_results.append({
                    "argument": arg,
                    "topic": topic,
                    "predicted_stance": stance_result["predicted_stance"],
                    "confidence": stance_result["confidence"],
                    "probability_con": stance_result["probability_con"],
                    "probability_pro": stance_result["probability_pro"],
                })

            # Step 3: Save all results to database
            logger.info(f"Step 3: Saving {len(stance_results)} results to database...")
            saved_results = self._save_analysis_results(user_id, stance_results)

            logger.info(f"Analysis completed: {len(saved_results)} results saved")
            return saved_results

        except Exception as e:
            logger.error(f"Error in analyze_arguments: {str(e)}")
            raise RuntimeError(f"Analysis failed: {str(e)}")

    def _save_analysis_results(
        self,
        user_id: str,
        results: List[Dict]
    ) -> List[Dict]:
        """
        Save analysis results to database

        Args:
            user_id: User UUID
            results: List of analysis result dictionaries

        Returns:
            List of saved results with database IDs
        """
        try:
            client = self._get_client()

            # Prepare data for batch insert
            insert_data = []
            for result in results:
                insert_data.append({
                    "user_id": user_id,
                    "argument": result["argument"],
                    "topic": result["topic"],
                    "predicted_stance": result["predicted_stance"],
                    "confidence": result["confidence"],
                    "probability_con": result["probability_con"],
                    "probability_pro": result["probability_pro"],
                    "created_at": datetime.utcnow().isoformat(),
                    "updated_at": datetime.utcnow().isoformat()
                })

            # Batch insert
            response = client.table(self.table_name).insert(insert_data).execute()

            if response.data:
                logger.info(f"Successfully saved {len(response.data)} analysis results")
                return response.data
            else:
                raise RuntimeError("Failed to save analysis results: no data returned")

        except Exception as e:
            logger.error(f"Error saving analysis results: {str(e)}")
            raise RuntimeError(f"Failed to save analysis results: {str(e)}")

    def get_user_analysis_results(
        self,
        user_id: str,
        limit: Optional[int] = 100,
        offset: Optional[int] = 0
    ) -> List[Dict]:
        """
        Get analysis results for a user

        Args:
            user_id: User UUID
            limit: Maximum number of results to return
            offset: Number of results to skip

        Returns:
            List of analysis results
        """
        try:
            client = self._get_client()

            query = client.table(self.table_name)\
                .select("*")\
                .eq("user_id", user_id)\
                .order("created_at", desc=True)\
                .limit(limit)\
                .offset(offset)

            result = query.execute()

            return result.data if result.data else []

        except Exception as e:
            logger.error(f"Error getting user analysis results: {str(e)}")
            raise RuntimeError(f"Failed to get analysis results: {str(e)}")


# Initialize singleton instance
analysis_service = AnalysisService()
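For completeness, a minimal sketch of exercising analyze_arguments without live models or a database, patching the module-level collaborators that services.analysis_service imports; all fake values below are illustrative, not part of the commit:

from unittest.mock import patch

from services.analysis_service import AnalysisService

service = AnalysisService()

fake_prediction = {
    "predicted_stance": "PRO",
    "confidence": 0.9,
    "probability_con": 0.1,
    "probability_pro": 0.9,
}

with patch("services.analysis_service.topic_service") as topics, \
     patch("services.analysis_service.stance_model_manager") as stance, \
     patch.object(service, "_save_analysis_results", side_effect=lambda uid, rows: rows):
    topics.batch_extract_topics.return_value = ["renewable energy"]
    stance.predict.return_value = fake_prediction
    results = service.analyze_arguments("user-123", ["We must invest in renewable energy."])

assert results[0]["predicted_stance"] == "PRO"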