Yassine Mhirsi committed on
Commit
430b54f
·
1 Parent(s): 6e8d513

feat: Add analysis endpoints and service for processing arguments, extracting topics, and predicting stance

Browse files
models/__init__.py CHANGED
@@ -44,6 +44,15 @@ from .user import (
44
  UserGetRequest,
45
  )
46
 
 
 
 
 
 
 
 
 
 
47
  # Import MCP-related schemas
48
  from .mcp_models import (
49
  ToolCallRequest,
@@ -85,6 +94,12 @@ __all__ = [
85
  "UserResponse",
86
  "UserUpdateNameRequest",
87
  "UserGetRequest",
 
 
 
 
 
 
88
  # MCP schemas
89
  "ToolCallRequest",
90
  "ToolCallResponse",
 
44
  UserGetRequest,
45
  )
46
 
47
+ # Import analysis-related schemas
48
+ from .analysis import (
49
+ AnalysisRequest,
50
+ AnalysisResponse,
51
+ AnalysisResult,
52
+ GetAnalysisRequest,
53
+ GetAnalysisResponse,
54
+ )
55
+
56
  # Import MCP-related schemas
57
  from .mcp_models import (
58
  ToolCallRequest,
 
94
  "UserResponse",
95
  "UserUpdateNameRequest",
96
  "UserGetRequest",
97
+ # Analysis schemas
98
+ "AnalysisRequest",
99
+ "AnalysisResponse",
100
+ "AnalysisResult",
101
+ "GetAnalysisRequest",
102
+ "GetAnalysisResponse",
103
  # MCP schemas
104
  "ToolCallRequest",
105
  "ToolCallResponse",
models/analysis.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic models for analysis endpoints"""
2
+
3
+ from pydantic import BaseModel, Field, ConfigDict
4
+ from typing import List, Optional
5
+
6
+
7
class AnalysisRequest(BaseModel):
    """Request model for analysis with arguments"""

    # Batch of argument texts to analyze; validated to 1..100 entries.
    arguments: List[str] = Field(
        ...,
        min_length=1,
        max_length=100,
        description="List of argument texts to analyze (max 100)",
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "arguments": [
                    "Social media companies must NOT be allowed to track people across websites.",
                    "I don't think universal basic income is a good idea — it'll disincentivize work.",
                    "We must invest in renewable energy to combat climate change."
                ]
            }
        }
    )
25
+
26
+
27
class AnalysisResult(BaseModel):
    """Model for a single analysis result"""

    # Identity
    id: str = Field(..., description="Analysis result UUID")
    user_id: str = Field(..., description="User UUID")
    # Analyzed content
    argument: str = Field(..., description="The argument text")
    topic: str = Field(..., description="Extracted topic")
    # Stance prediction; probabilities are bounded to [0, 1].
    predicted_stance: str = Field(..., description="PRO or CON")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    probability_con: float = Field(..., ge=0.0, le=1.0, description="Probability of CON")
    probability_pro: float = Field(..., ge=0.0, le=1.0, description="Probability of PRO")
    # Timestamps as ISO-8601 strings (as stored in the database).
    created_at: str = Field(..., description="Creation timestamp")
    updated_at: str = Field(..., description="Last update timestamp")

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "id": "123e4567-e89b-12d3-a456-426614174000",
                "user_id": "123e4567-e89b-12d3-a456-426614174000",
                "argument": "Social media companies must NOT be allowed to track people across websites.",
                "topic": "social media tracking and cross-website user privacy",
                "predicted_stance": "CON",
                "confidence": 0.9234,
                "probability_con": 0.9234,
                "probability_pro": 0.0766,
                "created_at": "2024-01-01T12:00:00Z",
                "updated_at": "2024-01-01T12:00:00Z"
            }
        }
    )
56
+
57
+
58
class AnalysisResponse(BaseModel):
    """Response model for analysis endpoint"""

    results: List[AnalysisResult] = Field(..., description="List of analysis results")
    total_processed: int = Field(..., description="Number of arguments processed")
    timestamp: str = Field(..., description="Analysis timestamp")

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "results": [
                    {
                        "id": "123e4567-e89b-12d3-a456-426614174000",
                        "user_id": "123e4567-e89b-12d3-a456-426614174000",
                        "argument": "Social media companies must NOT be allowed to track people across websites.",
                        "topic": "social media tracking and cross-website user privacy",
                        "predicted_stance": "CON",
                        "confidence": 0.9234,
                        "probability_con": 0.9234,
                        "probability_pro": 0.0766,
                        "created_at": "2024-01-01T12:00:00Z",
                        "updated_at": "2024-01-01T12:00:00Z"
                    }
                ],
                "total_processed": 1,
                "timestamp": "2024-01-01T12:00:00Z"
            }
        }
    )
86
+
87
+
88
class GetAnalysisRequest(BaseModel):
    """Request model for getting user's analysis results"""

    # Pagination window: both fields are optional and defaulted.
    limit: Optional[int] = Field(
        100, ge=1, le=1000, description="Maximum number of results"
    )
    offset: Optional[int] = Field(
        0, ge=0, description="Number of results to skip"
    )
92
+
93
+
94
class GetAnalysisResponse(BaseModel):
    """Response model for getting analysis results"""

    # Page of results plus the pagination parameters that produced it,
    # echoed back so clients can compute the next offset.
    results: List[AnalysisResult] = Field(..., description="List of analysis results")
    total: int = Field(..., description="Total number of results")
    limit: int = Field(..., description="Limit used")
    offset: int = Field(..., description="Offset used")
100
+
routes/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
  """API route handlers"""
2
 
3
  from fastapi import APIRouter
4
- from . import root, health, stance, label, generate, topic, user
5
  from routes.tts_routes import router as audio_router
6
  # Create main router
7
  api_router = APIRouter()
@@ -14,6 +14,7 @@ api_router.include_router(label.router, prefix="/label")
14
  api_router.include_router(generate.router, prefix="/generate")
15
  api_router.include_router(topic.router, prefix="/topic")
16
  api_router.include_router(user.router, prefix="/user")
 
17
  api_router.include_router(audio_router)
18
 
19
  __all__ = ["api_router"]
 
1
  """API route handlers"""
2
 
3
  from fastapi import APIRouter
4
+ from . import root, health, stance, label, generate, topic, user, analysis
5
  from routes.tts_routes import router as audio_router
6
  # Create main router
7
  api_router = APIRouter()
 
14
  api_router.include_router(generate.router, prefix="/generate")
15
  api_router.include_router(topic.router, prefix="/topic")
16
  api_router.include_router(user.router, prefix="/user")
17
+ api_router.include_router(analysis.router, prefix="")
18
  api_router.include_router(audio_router)
19
 
20
  __all__ = ["api_router"]
routes/analysis.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Analysis endpoints for processing arguments, extracting topics, and predicting stance"""
2
+
3
import csv
import io
import logging
from datetime import datetime
from typing import List, Optional

from fastapi import APIRouter, HTTPException, Header, UploadFile, File
9
+
10
+ from services.analysis_service import analysis_service
11
+ from models.analysis import (
12
+ AnalysisRequest,
13
+ AnalysisResponse,
14
+ AnalysisResult,
15
+ GetAnalysisRequest,
16
+ GetAnalysisResponse,
17
+ )
18
+
19
+ router = APIRouter()
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
def parse_csv_file(file_content: bytes) -> list[str]:
    """
    Parse CSV file content and extract argument strings.

    The first row is treated as a header when any of its columns is named
    ``argument``, ``text``, or ``content`` (case-insensitive); arguments are
    then read from that column. Otherwise every row is read from the first
    column. Empty cells are skipped.

    Note: annotated with the builtin ``list[str]`` because this module only
    imports ``Optional`` from ``typing`` — the original ``List[str]``
    annotation raised ``NameError`` at import time.

    Args:
        file_content: CSV file content as bytes (assumed UTF-8 encoded).

    Returns:
        List of non-empty argument strings.

    Raises:
        ValueError: If the content cannot be decoded or parsed as CSV.
    """
    try:
        # Decode bytes to string (UTF-8 only; other encodings raise ValueError below).
        content = file_content.decode('utf-8')

        # Parse CSV into rows up front so we can inspect the header.
        rows = list(csv.reader(io.StringIO(content)))
        if not rows:
            return []

        # Heuristic header detection: any recognized column name means row 0 is a header.
        header = rows[0]
        known_names = ('argument', 'text', 'content')
        start_idx = 1 if any(col.lower() in known_names for col in header) else 0

        # Locate the argument column; default to the first column.
        arg_col_idx = 0
        if start_idx == 1:
            for idx, col in enumerate(header):
                if col.lower() in known_names:
                    arg_col_idx = idx
                    break

        # Extract non-empty, stripped arguments from the chosen column.
        arguments = []
        for row in rows[start_idx:]:
            if row and len(row) > arg_col_idx:
                arg = row[arg_col_idx].strip()
                if arg:  # Only add non-empty arguments
                    arguments.append(arg)

        return arguments

    except Exception as e:
        logger.error(f"Error parsing CSV file: {str(e)}")
        # Chain the original exception so the decode/parse cause stays visible.
        raise ValueError(f"Failed to parse CSV file: {str(e)}") from e
70
+
71
+
72
@router.post("/analyse", response_model=AnalysisResponse, tags=["Analysis"])
async def analyse_arguments(
    request: Optional[AnalysisRequest] = None,
    file: Optional[UploadFile] = File(None),
    x_user_id: Optional[str] = Header(None, alias="X-User-ID")
):
    """
    Analyze a batch of arguments: extract a topic and predict stance for each.

    Input comes from exactly one of two sources:
    - a JSON body with an `arguments` array, or
    - an uploaded CSV file containing the arguments.

    - **X-User-ID**: User UUID (required in header)
    - **arguments**: List of argument texts (JSON path)
    - **file**: CSV file with arguments (upload path)

    Returns analysis results with extracted topics and stance predictions.

    NOTE(review): accepting an optional JSON body and an optional multipart
    File on the same endpoint is fragile in FastAPI — confirm clients can
    actually hit both paths.
    """
    # The user id header is mandatory; reject early with a 400.
    if not x_user_id:
        raise HTTPException(status_code=400, detail="X-User-ID header is required")

    try:
        arguments = []

        # CSV upload takes precedence over a JSON body when both are supplied.
        if file is not None:
            file_content = await file.read()
            arguments = parse_csv_file(file_content)

            if not arguments:
                raise HTTPException(
                    status_code=400,
                    detail="CSV file is empty or contains no valid arguments"
                )

            logger.info(f"Parsed {len(arguments)} arguments from CSV file")

        elif request is not None and request.arguments:
            arguments = request.arguments
            logger.info(f"Received {len(arguments)} arguments from JSON body")

        else:
            raise HTTPException(
                status_code=400,
                detail="Either 'arguments' in JSON body or CSV 'file' must be provided"
            )

        # Run topic extraction + stance prediction (persists to the database).
        results = analysis_service.analyze_arguments(
            user_id=x_user_id,
            arguments=arguments
        )

        # Re-validate each stored row through the response schema.
        analysis_results = [AnalysisResult(**row) for row in results]

        logger.info(f"Analysis completed: {len(analysis_results)} results")

        return AnalysisResponse(
            results=analysis_results,
            total_processed=len(analysis_results),
            timestamp=datetime.now().isoformat()
        )

    except HTTPException:
        # Already shaped for the client — pass through unchanged.
        raise
    except ValueError as e:
        # Validation problems (e.g. unparseable CSV) map to a 400.
        logger.error(f"Validation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Analysis error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
147
+
148
+
149
@router.get("/analyse", response_model=GetAnalysisResponse, tags=["Analysis"])
async def get_analysis_results(
    limit: int = 100,
    offset: int = 0,
    x_user_id: Optional[str] = Header(None, alias="X-User-ID")
):
    """
    Return a page of the calling user's stored analysis results.

    - **X-User-ID**: User UUID (required in header)
    - **limit**: Maximum number of results to return (default: 100, max: 1000)
    - **offset**: Number of results to skip (default: 0)

    Returns paginated analysis results.
    """
    # Reject requests without a user id before touching the service layer.
    if not x_user_id:
        raise HTTPException(status_code=400, detail="X-User-ID header is required")

    try:
        rows = analysis_service.get_user_analysis_results(
            user_id=x_user_id,
            limit=limit,
            offset=offset
        )

        # Validate each stored row through the response schema and echo the
        # pagination parameters back to the client.
        return GetAnalysisResponse(
            results=[AnalysisResult(**row) for row in rows],
            total=len(rows),
            limit=limit,
            offset=offset
        )

    except Exception as e:
        logger.error(f"Error getting analysis results: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get analysis results: {str(e)}")
189
+
services/__init__.py CHANGED
@@ -10,6 +10,7 @@ from .tts_service import text_to_speech
10
  from .topic_service import TopicService, topic_service
11
  from .database_service import DatabaseService, database_service
12
  from .user_service import UserService, user_service
 
13
 
14
  __all__ = [
15
  "StanceModelManager",
@@ -24,6 +25,8 @@ __all__ = [
24
  "database_service",
25
  "UserService",
26
  "user_service",
 
 
27
 
28
  # NEW exports
29
  "speech_to_text",
 
10
  from .topic_service import TopicService, topic_service
11
  from .database_service import DatabaseService, database_service
12
  from .user_service import UserService, user_service
13
+ from .analysis_service import AnalysisService, analysis_service
14
 
15
  __all__ = [
16
  "StanceModelManager",
 
25
  "database_service",
26
  "UserService",
27
  "user_service",
28
+ "AnalysisService",
29
+ "analysis_service",
30
 
31
  # NEW exports
32
  "speech_to_text",
services/analysis_service.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Service for analysis operations: processing arguments, extracting topics, and predicting stance"""
2
+
3
+ import logging
4
+ from typing import List, Dict, Optional
5
+ from datetime import datetime
6
+
7
+ from services.database_service import database_service
8
+ from services.topic_service import topic_service
9
+ from services.stance_model_manager import stance_model_manager
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class AnalysisService:
    """Service for analyzing arguments with topic extraction and stance prediction.

    Pipeline: extract one topic per argument, predict PRO/CON stance for each
    (argument, topic) pair, then persist all rows to the ``analysis_results``
    table through the shared Supabase client.
    """

    def __init__(self):
        # Supabase table that stores one row per analyzed argument.
        self.table_name = "analysis_results"

    def _get_client(self):
        """Return the shared Supabase client from the database service."""
        return database_service.get_client()

    def analyze_arguments(
        self,
        user_id: str,
        arguments: List[str]
    ) -> List[Dict]:
        """
        Analyze arguments: extract topics, predict stance, and save to database.

        Args:
            user_id: User UUID
            arguments: List of argument texts

        Returns:
            List of analysis results with argument, topic, and stance prediction

        Raises:
            ValueError: If the arguments list is empty.
            RuntimeError: If topic extraction, stance prediction, or
                persistence fails.
        """
        try:
            if not arguments:
                raise ValueError("Arguments list cannot be empty")

            logger.info(f"Starting analysis for {len(arguments)} arguments for user {user_id}")

            # Step 1: Extract topics for all arguments
            logger.info("Step 1: Extracting topics...")
            topics = topic_service.batch_extract_topics(arguments)

            # The topic list must align 1:1 with the argument list for zip below.
            if len(topics) != len(arguments):
                raise RuntimeError(f"Topic extraction returned {len(topics)} topics but expected {len(arguments)}")

            # Step 2: Predict stance for each argument-topic pair
            logger.info("Step 2: Predicting stance for argument-topic pairs...")
            stance_results = []
            for arg, topic in zip(arguments, topics):
                if topic is None:
                    # Per-item extraction failure: skip the item rather than abort the batch.
                    logger.warning(f"Skipping argument with null topic: {arg[:50]}...")
                    continue

                stance_result = stance_model_manager.predict(topic, arg)
                stance_results.append({
                    "argument": arg,
                    "topic": topic,
                    "predicted_stance": stance_result["predicted_stance"],
                    "confidence": stance_result["confidence"],
                    "probability_con": stance_result["probability_con"],
                    "probability_pro": stance_result["probability_pro"],
                })

            # Step 3: Save all results to database
            logger.info(f"Step 3: Saving {len(stance_results)} results to database...")
            saved_results = self._save_analysis_results(user_id, stance_results)

            logger.info(f"Analysis completed: {len(saved_results)} results saved")
            return saved_results

        except ValueError:
            # Propagate validation errors unchanged so callers (e.g. the API
            # layer) can map them to a 400 instead of a generic 500.
            raise
        except Exception as e:
            logger.error(f"Error in analyze_arguments: {str(e)}")
            # Chain the original exception so the root cause stays in the traceback.
            raise RuntimeError(f"Analysis failed: {str(e)}") from e

    def _save_analysis_results(
        self,
        user_id: str,
        results: List[Dict]
    ) -> List[Dict]:
        """
        Save analysis results to database in one batch insert.

        Args:
            user_id: User UUID
            results: List of analysis result dictionaries

        Returns:
            List of saved results with database IDs

        Raises:
            RuntimeError: If the insert fails or returns no data.
        """
        try:
            client = self._get_client()

            # One timestamp for the whole batch so created_at == updated_at
            # and every row of a single call agrees.
            # NOTE(review): utcnow() produces a naive ISO string (no tz suffix),
            # matching existing rows; utcnow() is deprecated since Python 3.12 —
            # consider datetime.now(timezone.utc) once stored-format impact is confirmed.
            now = datetime.utcnow().isoformat()

            insert_data = [
                {
                    "user_id": user_id,
                    "argument": result["argument"],
                    "topic": result["topic"],
                    "predicted_stance": result["predicted_stance"],
                    "confidence": result["confidence"],
                    "probability_con": result["probability_con"],
                    "probability_pro": result["probability_pro"],
                    "created_at": now,
                    "updated_at": now,
                }
                for result in results
            ]

            # Batch insert
            response = client.table(self.table_name).insert(insert_data).execute()

            if not response.data:
                raise RuntimeError("Failed to save analysis results: no data returned")

            logger.info(f"Successfully saved {len(response.data)} analysis results")
            return response.data

        except RuntimeError:
            # Already our own error shape; avoid double-wrapping the message.
            raise
        except Exception as e:
            logger.error(f"Error saving analysis results: {str(e)}")
            raise RuntimeError(f"Failed to save analysis results: {str(e)}") from e

    def get_user_analysis_results(
        self,
        user_id: str,
        limit: Optional[int] = 100,
        offset: Optional[int] = 0
    ) -> List[Dict]:
        """
        Get analysis results for a user, newest first.

        Args:
            user_id: User UUID
            limit: Maximum number of results to return
            offset: Number of results to skip

        Returns:
            List of analysis results (empty list when the user has none)

        Raises:
            RuntimeError: If the database query fails.
        """
        try:
            client = self._get_client()

            query = (
                client.table(self.table_name)
                .select("*")
                .eq("user_id", user_id)
                .order("created_at", desc=True)
                .limit(limit)
                .offset(offset)
            )

            result = query.execute()

            return result.data or []

        except Exception as e:
            logger.error(f"Error getting user analysis results: {str(e)}")
            raise RuntimeError(f"Failed to get analysis results: {str(e)}") from e


# Initialize singleton instance
analysis_service = AnalysisService()
165
+