Yassine Mhirsi
feat: Add topic similarity service using Google Generative AI embeddings for topic matching and similarity analysis, and update the topic extraction logic in the analysis flow to use it.
a453c29
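The commit message above references a topic similarity service built on Google Generative AI embeddings, but that service is not part of this file. As a rough illustration only, a minimal sketch of the batch_find_similar_topics interface that the analysis service below relies on might look like the following; the candidate topic list, similarity threshold, embedding model choice, and all helper names are assumptions, not the committed implementation.

# Hypothetical sketch of services/topic_similarity_service.py (not the committed code).
# Assumes the google-generativeai SDK, a fixed candidate topic list, and a
# cosine-similarity threshold below which an argument gets no topic (None).
import os
from typing import List, Optional

import numpy as np
import google.generativeai as genai

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])  # assumes the key is supplied via env var

CANDIDATE_TOPICS = ["climate change", "gun control", "school uniforms"]  # illustrative placeholders


def _embed(texts: List[str]) -> np.ndarray:
    """Embed a batch of texts with a Gemini embedding model."""
    result = genai.embed_content(model="models/text-embedding-004", content=texts)
    return np.array(result["embedding"])


class TopicSimilarityService:
    def __init__(self, similarity_threshold: float = 0.5):
        self.similarity_threshold = similarity_threshold
        self._topic_vectors = _embed(CANDIDATE_TOPICS)

    def batch_find_similar_topics(self, arguments: List[str]) -> List[Optional[str]]:
        """Return the closest candidate topic per argument, or None below the threshold."""
        arg_vectors = _embed(arguments)
        # Cosine similarity between every argument and every candidate topic
        arg_norm = arg_vectors / np.linalg.norm(arg_vectors, axis=1, keepdims=True)
        topic_norm = self._topic_vectors / np.linalg.norm(self._topic_vectors, axis=1, keepdims=True)
        similarities = arg_norm @ topic_norm.T
        topics: List[Optional[str]] = []
        for row in similarities:
            best = int(np.argmax(row))
            topics.append(CANDIDATE_TOPICS[best] if row[best] >= self.similarity_threshold else None)
        return topics


topic_similarity_service = TopicSimilarityService()

The committed analysis service below consumes this interface through the topic_similarity_service singleton.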
| """Service for analysis operations: processing arguments, extracting topics, and predicting stance""" | |
| import logging | |
| from typing import List, Dict, Optional | |
| from datetime import datetime | |
| from services.database_service import database_service | |
| from services.topic_similarity_service import topic_similarity_service | |
| from services.stance_model_manager import stance_model_manager | |
| logger = logging.getLogger(__name__) | |
| class AnalysisService: | |
| """Service for analyzing arguments with topic extraction and stance prediction""" | |
| def __init__(self): | |
| self.table_name = "analysis_results" | |
| def _get_client(self): | |
| """Get Supabase client""" | |
| return database_service.get_client() | |
| def analyze_arguments( | |
| self, | |
| user_id: str, | |
| arguments: List[str] | |
| ) -> List[Dict]: | |
| """ | |
| Analyze arguments: extract topics, predict stance, and save to database | |
| Args: | |
| user_id: User UUID | |
| arguments: List of argument texts | |
| Returns: | |
| List of analysis results with argument, topic, and stance prediction | |
| """ | |
| try: | |
| if not arguments or len(arguments) == 0: | |
| raise ValueError("Arguments list cannot be empty") | |
| logger.info(f"Starting analysis for {len(arguments)} arguments for user {user_id}") | |
| # Step 1: Find most similar topics for all arguments | |
| logger.info("Step 1: Finding most similar topics...") | |
| topics = topic_similarity_service.batch_find_similar_topics(arguments) | |
| if len(topics) != len(arguments): | |
| raise RuntimeError(f"Topic extraction returned {len(topics)} topics but expected {len(arguments)}") | |
| # Step 2: Predict stance for each argument-topic pair | |
| logger.info("Step 2: Predicting stance for argument-topic pairs...") | |
| stance_results = [] | |
| for arg, topic in zip(arguments, topics): | |
| if topic is None: | |
| logger.warning(f"Skipping argument with null topic: {arg[:50]}...") | |
| continue | |
| stance_result = stance_model_manager.predict(topic, arg) | |
| stance_results.append({ | |
| "argument": arg, | |
| "topic": topic, | |
| "predicted_stance": stance_result["predicted_stance"], | |
| "confidence": stance_result["confidence"], | |
| "probability_con": stance_result["probability_con"], | |
| "probability_pro": stance_result["probability_pro"], | |
| }) | |
| # Step 3: Save all results to database | |
| logger.info(f"Step 3: Saving {len(stance_results)} results to database...") | |
| saved_results = self._save_analysis_results(user_id, stance_results) | |
| logger.info(f"Analysis completed: {len(saved_results)} results saved") | |
| return saved_results | |
| except Exception as e: | |
| logger.error(f"Error in analyze_arguments: {str(e)}") | |
| raise RuntimeError(f"Analysis failed: {str(e)}") | |
| def _save_analysis_results( | |
| self, | |
| user_id: str, | |
| results: List[Dict] | |
| ) -> List[Dict]: | |
| """ | |
| Save analysis results to database | |
| Args: | |
| user_id: User UUID | |
| results: List of analysis result dictionaries | |
| Returns: | |
| List of saved results with database IDs | |
| """ | |
| try: | |
| client = self._get_client() | |
| # Prepare data for batch insert | |
| insert_data = [] | |
| for result in results: | |
| insert_data.append({ | |
| "user_id": user_id, | |
| "argument": result["argument"], | |
| "topic": result["topic"], | |
| "predicted_stance": result["predicted_stance"], | |
| "confidence": result["confidence"], | |
| "probability_con": result["probability_con"], | |
| "probability_pro": result["probability_pro"], | |
| "created_at": datetime.utcnow().isoformat(), | |
| "updated_at": datetime.utcnow().isoformat() | |
| }) | |
| # Batch insert | |
| response = client.table(self.table_name).insert(insert_data).execute() | |
| if response.data: | |
| logger.info(f"Successfully saved {len(response.data)} analysis results") | |
| return response.data | |
| else: | |
| raise RuntimeError("Failed to save analysis results: no data returned") | |
| except Exception as e: | |
| logger.error(f"Error saving analysis results: {str(e)}") | |
| raise RuntimeError(f"Failed to save analysis results: {str(e)}") | |
| def get_user_analysis_results( | |
| self, | |
| user_id: str, | |
| limit: Optional[int] = 100, | |
| offset: Optional[int] = 0 | |
| ) -> List[Dict]: | |
| """ | |
| Get analysis results for a user | |
| Args: | |
| user_id: User UUID | |
| limit: Maximum number of results to return | |
| offset: Number of results to skip | |
| Returns: | |
| List of analysis results | |
| """ | |
| try: | |
| client = self._get_client() | |
| query = client.table(self.table_name)\ | |
| .select("*")\ | |
| .eq("user_id", user_id)\ | |
| .order("created_at", desc=True)\ | |
| .limit(limit)\ | |
| .offset(offset) | |
| result = query.execute() | |
| return result.data if result.data else [] | |
| except Exception as e: | |
| logger.error(f"Error getting user analysis results: {str(e)}") | |
| raise RuntimeError(f"Failed to get analysis results: {str(e)}") | |
| # Initialize singleton instance | |
| analysis_service = AnalysisService() | |
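For reference, a hedged usage sketch of the analysis_service singleton defined above; the import path, user UUID, and argument strings are placeholder assumptions, not values from the application.

# Illustrative call into the singleton defined above; user_id and arguments
# are placeholder values, and the import path assumes the file lives at
# services/analysis_service.py.
from services.analysis_service import analysis_service

results = analysis_service.analyze_arguments(
    user_id="00000000-0000-0000-0000-000000000000",
    arguments=[
        "Remote work reduces commuting emissions.",
        "School uniforms limit students' self-expression.",
    ],
)
for row in results:
    print(row["topic"], row["predicted_stance"], round(row["confidence"], 3))

Each returned row is the saved database record, so it carries the generated database ID and the created_at/updated_at timestamps alongside the argument, topic, and stance fields.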