FastAPI-Backend-Models / services / analysis_service.py
Yassine Mhirsi
feat: Add topic similarity service using Google Generative AI embeddings and update topic extraction to use it for improved topic matching.
"""Service for analysis operations: processing arguments, extracting topics, and predicting stance"""
import logging
from typing import List, Dict
from datetime import datetime, timezone
from services.database_service import database_service
from services.topic_similarity_service import topic_similarity_service
from services.stance_model_manager import stance_model_manager
logger = logging.getLogger(__name__)
class AnalysisService:
"""Service for analyzing arguments with topic extraction and stance prediction"""
def __init__(self):
self.table_name = "analysis_results"
def _get_client(self):
"""Get Supabase client"""
return database_service.get_client()
def analyze_arguments(
self,
user_id: str,
arguments: List[str]
) -> List[Dict]:
"""
Analyze arguments: extract topics, predict stance, and save to database
Args:
user_id: User UUID
arguments: List of argument texts
Returns:
List of analysis results with argument, topic, and stance prediction
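        Example element of the returned list (values illustrative; database-generated
        fields such as "id" and "created_at" also appear, depending on the table schema):
            {
                "argument": "...",
                "topic": "...",
                "predicted_stance": "pro",
                "confidence": 0.93,
                "probability_con": 0.07,
                "probability_pro": 0.93
            }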
"""
try:
            if not arguments:
                raise ValueError("Arguments list cannot be empty")
logger.info(f"Starting analysis for {len(arguments)} arguments for user {user_id}")
# Step 1: Find most similar topics for all arguments
logger.info("Step 1: Finding most similar topics...")
topics = topic_similarity_service.batch_find_similar_topics(arguments)
if len(topics) != len(arguments):
raise RuntimeError(f"Topic extraction returned {len(topics)} topics but expected {len(arguments)}")
# Step 2: Predict stance for each argument-topic pair
logger.info("Step 2: Predicting stance for argument-topic pairs...")
stance_results = []
for arg, topic in zip(arguments, topics):
if topic is None:
logger.warning(f"Skipping argument with null topic: {arg[:50]}...")
continue
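                # predict() returns a dict; the keys consumed below are
                # predicted_stance, confidence, probability_con, and probability_pro.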
stance_result = stance_model_manager.predict(topic, arg)
stance_results.append({
"argument": arg,
"topic": topic,
"predicted_stance": stance_result["predicted_stance"],
"confidence": stance_result["confidence"],
"probability_con": stance_result["probability_con"],
"probability_pro": stance_result["probability_pro"],
})
# Step 3: Save all results to database
logger.info(f"Step 3: Saving {len(stance_results)} results to database...")
saved_results = self._save_analysis_results(user_id, stance_results)
logger.info(f"Analysis completed: {len(saved_results)} results saved")
return saved_results
        except Exception as e:
            logger.exception(f"Error in analyze_arguments: {e}")
            raise RuntimeError(f"Analysis failed: {e}") from e
def _save_analysis_results(
self,
user_id: str,
results: List[Dict]
) -> List[Dict]:
"""
Save analysis results to database
Args:
user_id: User UUID
results: List of analysis result dictionaries
Returns:
List of saved results with database IDs
"""
try:
client = self._get_client()
            # Prepare data for batch insert; use one timestamp for the whole batch
            now = datetime.now(timezone.utc).isoformat()
            insert_data = []
            for result in results:
                insert_data.append({
                    "user_id": user_id,
                    "argument": result["argument"],
                    "topic": result["topic"],
                    "predicted_stance": result["predicted_stance"],
                    "confidence": result["confidence"],
                    "probability_con": result["probability_con"],
                    "probability_pro": result["probability_pro"],
                    "created_at": now,
                    "updated_at": now
                })
# Batch insert
response = client.table(self.table_name).insert(insert_data).execute()
if response.data:
logger.info(f"Successfully saved {len(response.data)} analysis results")
return response.data
else:
raise RuntimeError("Failed to save analysis results: no data returned")
        except Exception as e:
            logger.exception(f"Error saving analysis results: {e}")
            raise RuntimeError(f"Failed to save analysis results: {e}") from e
def get_user_analysis_results(
self,
user_id: str,
        limit: int = 100,
        offset: int = 0
) -> List[Dict]:
"""
Get analysis results for a user
Args:
user_id: User UUID
limit: Maximum number of results to return
offset: Number of results to skip
Returns:
List of analysis results
"""
try:
client = self._get_client()
            query = (
                client.table(self.table_name)
                .select("*")
                .eq("user_id", user_id)
                .order("created_at", desc=True)
                .limit(limit)
                .offset(offset)
            )
result = query.execute()
return result.data if result.data else []
        except Exception as e:
            logger.exception(f"Error getting user analysis results: {e}")
            raise RuntimeError(f"Failed to get analysis results: {e}") from e
# Initialize singleton instance
analysis_service = AnalysisService()
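
# Example usage (illustrative sketch, not part of the service): assumes Supabase
# credentials are configured for database_service, the "analysis_results" table
# exists with the columns inserted above, and the models behind
# topic_similarity_service / stance_model_manager are loaded. The UUID and
# argument text below are hypothetical.
#
#     results = analysis_service.analyze_arguments(
#         user_id="00000000-0000-0000-0000-000000000000",
#         arguments=["School uniforms should be mandatory in public schools."],
#     )
#     for r in results:
#         print(r["topic"], r["predicted_stance"], r["confidence"])
#
#     recent = analysis_service.get_user_analysis_results(
#         user_id="00000000-0000-0000-0000-000000000000", limit=10, offset=0
#     )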