Yassine Mhirsi
refactor: Update file path handling in TopicSimilarityService for improved error reporting and ensure absolute paths for topic and embeddings cache files
e97ac87
| """Service for finding similar topics using Google Generative AI embeddings""" | |
| import logging | |
| import json | |
| import hashlib | |
| from pathlib import Path | |
| from typing import Optional, List, Dict | |
| from datetime import datetime | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from google import genai | |
| from google.genai import types | |
| from config import GOOGLE_API_KEY, PROJECT_ROOT | |
| logger = logging.getLogger(__name__) | |
| # Paths for topics and cache files | |
| TOPICS_FILE = PROJECT_ROOT / "data" / "topics.json" | |
| EMBEDDINGS_CACHE_FILE = PROJECT_ROOT / "data" / "topic_embeddings_cache.json" | |
| class TopicSimilarityService: | |
| """Service for finding the most similar topic from a predefined list using embeddings""" | |
| def __init__(self): | |
| self.client = None | |
| self.topics = [] | |
| self.topic_embeddings = None | |
| self.initialized = False | |
| self.model_name = "models/text-embedding-004" | |
| def initialize(self): | |
| """Initialize the Google Generative AI client and load topic embeddings""" | |
| if self.initialized: | |
| logger.info("Topic similarity service already initialized") | |
| return | |
| if not GOOGLE_API_KEY: | |
| raise ValueError("GOOGLE_API_KEY not found in environment variables") | |
| try: | |
| logger.info("Initializing topic similarity service with Google Generative AI") | |
| # Create Google Generative AI client | |
| self.client = genai.Client(api_key=GOOGLE_API_KEY) | |
| # Load topics | |
| self.topics = self._load_topics() | |
| logger.info(f"Loaded {len(self.topics)} topics from {TOPICS_FILE}") | |
| # Load or generate topic embeddings | |
| self.topic_embeddings = self._get_topic_embeddings() | |
| logger.info(f"Loaded {len(self.topic_embeddings)} topic embeddings") | |
| self.initialized = True | |
| logger.info("✓ Topic similarity service initialized successfully") | |
| except Exception as e: | |
| logger.error(f"Error initializing topic similarity service: {str(e)}") | |
| raise RuntimeError(f"Failed to initialize topic similarity service: {str(e)}") | |
| def _load_topics(self) -> List[str]: | |
| """Load topics from topics.json file""" | |
| # Ensure path is absolute | |
| topics_file = Path(TOPICS_FILE).resolve() | |
| if not topics_file.exists(): | |
| raise FileNotFoundError( | |
| f"Topics file not found: {topics_file}\n" | |
| f"Current working directory: {Path.cwd()}\n" | |
| f"PROJECT_ROOT: {PROJECT_ROOT}\n" | |
| f"TOPICS_FILE path: {TOPICS_FILE}" | |
| ) | |
| try: | |
| with open(topics_file, 'r', encoding='utf-8') as f: | |
| data = json.load(f) | |
| topics = data.get("topics", []) | |
| if not topics: | |
| raise ValueError(f"No topics found in {topics_file}") | |
| return topics | |
| except (json.JSONDecodeError, KeyError) as e: | |
| raise ValueError(f"Error loading topics from {topics_file}: {str(e)}") | |
| def _get_topics_hash(self, topics: List[str]) -> str: | |
| """Generate a hash of the topics list to verify cache validity""" | |
| topics_str = json.dumps(topics, sort_keys=True) | |
| return hashlib.md5(topics_str.encode('utf-8')).hexdigest() | |
| def _load_cached_embeddings(self) -> Optional[np.ndarray]: | |
| """Load cached topic embeddings if they exist and are valid""" | |
| # Ensure path is absolute | |
| cache_file = Path(EMBEDDINGS_CACHE_FILE).resolve() | |
| if not cache_file.exists(): | |
| return None | |
| try: | |
| with open(cache_file, 'r', encoding='utf-8') as f: | |
| cache_data = json.load(f) | |
| # Verify cache is valid by checking topics hash | |
| current_hash = self._get_topics_hash(self.topics) | |
| if cache_data.get("topics_hash") == current_hash: | |
| # Convert list embeddings back to numpy arrays | |
| embeddings = [np.array(emb) for emb in cache_data.get("embeddings", [])] | |
| logger.info(f"Loaded {len(embeddings)} topic embeddings from cache") | |
| return np.array(embeddings) | |
| else: | |
| # Topics have changed, cache is invalid | |
| logger.info("Topics have changed, cache is invalid") | |
| return None | |
| except (json.JSONDecodeError, KeyError, ValueError) as e: | |
| logger.warning(f"Could not load cached embeddings: {e}") | |
| return None | |
| def _save_cached_embeddings(self, embeddings: np.ndarray): | |
| """Save topic embeddings to cache file""" | |
| topics_hash = self._get_topics_hash(self.topics) | |
| # Convert numpy arrays to lists for JSON serialization | |
| embeddings_list = [emb.tolist() for emb in embeddings] | |
| cache_data = { | |
| "topics_hash": topics_hash, | |
| "embeddings": embeddings_list, | |
| "model": self.model_name, | |
| "cached_at": datetime.now().isoformat() | |
| } | |
| try: | |
| # Ensure path is absolute and directory exists | |
| cache_file = Path(EMBEDDINGS_CACHE_FILE).resolve() | |
| cache_file.parent.mkdir(parents=True, exist_ok=True) | |
| with open(cache_file, 'w', encoding='utf-8') as f: | |
| json.dump(cache_data, f, indent=2) | |
| logger.info(f"Cached {len(embeddings)} topic embeddings to {cache_file}") | |
| except Exception as e: | |
| logger.warning(f"Could not save cached embeddings: {e}") | |
| def _get_topic_embeddings(self) -> np.ndarray: | |
| """ | |
| Get topic embeddings, loading from cache if available, otherwise generating and caching them | |
| Returns: | |
| numpy.ndarray: Array of topic embeddings | |
| """ | |
| # Try to load from cache first | |
| cached_embeddings = self._load_cached_embeddings() | |
| if cached_embeddings is not None: | |
| return cached_embeddings | |
| # Cache miss or invalid - generate embeddings | |
| logger.info(f"Generating embeddings for {len(self.topics)} topics (this may take a moment)...") | |
| try: | |
| embedding_response = self.client.models.embed_content( | |
| model=self.model_name, | |
| contents=self.topics, | |
| config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY") | |
| ) | |
| if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None: | |
| raise RuntimeError("Embedding API did not return embeddings") | |
| embeddings = [np.array(e.values) for e in embedding_response.embeddings] | |
| embeddings_array = np.array(embeddings) | |
| # Save to cache for future use | |
| self._save_cached_embeddings(embeddings_array) | |
| return embeddings_array | |
| except Exception as e: | |
| logger.error(f"Error generating topic embeddings: {str(e)}") | |
| raise RuntimeError(f"Failed to generate topic embeddings: {str(e)}") | |
| def find_most_similar_topic(self, input_text: str) -> Dict[str, any]: | |
| """ | |
| Compare a single input text to all topics and return the highest cosine similarity | |
| Args: | |
| input_text: The text to compare against topics | |
| Returns: | |
| dict: Contains 'topic', 'similarity', and 'index' of the most similar topic | |
| """ | |
| if not self.initialized: | |
| self.initialize() | |
| if not input_text or not isinstance(input_text, str): | |
| raise ValueError("Input text must be a non-empty string") | |
| input_text = input_text.strip() | |
| if len(input_text) == 0: | |
| raise ValueError("Input text cannot be empty") | |
| if not self.topics: | |
| raise ValueError("No topics found in topics.json") | |
| try: | |
| # Embed the input text | |
| embedding_response = self.client.models.embed_content( | |
| model=self.model_name, | |
| contents=[input_text], | |
| config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY") | |
| ) | |
| if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None: | |
| raise RuntimeError("Embedding API did not return embeddings") | |
| # Extract input embedding | |
| input_embedding = np.array(embedding_response.embeddings[0].values).reshape(1, -1) | |
| # Calculate cosine similarity between input and each topic | |
| similarities = cosine_similarity(input_embedding, self.topic_embeddings)[0] | |
| # Find the highest similarity | |
| max_index = np.argmax(similarities) | |
| max_similarity = similarities[max_index] | |
| most_similar_topic = self.topics[max_index] | |
| return { | |
| "topic": most_similar_topic, | |
| "similarity": float(max_similarity), | |
| "index": int(max_index) | |
| } | |
| except Exception as e: | |
| logger.error(f"Error finding similar topic: {str(e)}") | |
| raise RuntimeError(f"Failed to find similar topic: {str(e)}") | |
| def batch_find_similar_topics(self, input_texts: List[str]) -> List[str]: | |
| """ | |
| Find the most similar topic for each input text | |
| Args: | |
| input_texts: List of input texts to compare against topics | |
| Returns: | |
| List of most similar topics (one per input text) | |
| """ | |
| if not self.initialized: | |
| self.initialize() | |
| if not input_texts or not isinstance(input_texts, list): | |
| raise ValueError("Input texts must be a non-empty list") | |
| results = [] | |
| for text in input_texts: | |
| try: | |
| result = self.find_most_similar_topic(text) | |
| results.append(result["topic"]) | |
| except Exception as e: | |
| logger.error(f"Error finding similar topic for text '{text[:50]}...': {str(e)}") | |
| results.append(None) # Or raise, depending on desired behavior | |
| return results | |
| # Initialize singleton instance | |
| topic_similarity_service = TopicSimilarityService() | |