from datetime import datetime
import os
import json
import hashlib
from pathlib import Path

from dotenv import load_dotenv
from google import genai
from google.genai import types
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load environment variables from .env file
load_dotenv()
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    raise ValueError("GOOGLE_API_KEY is not set in environment variables.")

# Get the path to topics.json relative to this file
TOPICS_FILE = Path(__file__).parent.parent / "data" / "topics.json"

# Cache file for topic embeddings
EMBEDDINGS_CACHE_FILE = Path(__file__).parent.parent / "data" / "topic_embeddings_cache.json"

# Create a Google Generative AI client with the API key
client = genai.Client(api_key=GOOGLE_API_KEY)


def load_topics():
    """Load topics from topics.json file."""
    with open(TOPICS_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data.get("topics", [])


def get_topics_hash(topics):
    """Generate a hash of the topics list to verify cache validity."""
    topics_str = json.dumps(topics, sort_keys=True)
    return hashlib.md5(topics_str.encode('utf-8')).hexdigest()


def load_cached_embeddings():
    """Load cached topic embeddings if they exist and are valid."""
    if not EMBEDDINGS_CACHE_FILE.exists():
        return None
    try:
        with open(EMBEDDINGS_CACHE_FILE, 'r', encoding='utf-8') as f:
            cache_data = json.load(f)
        # Verify cache is valid by checking the topics hash
        current_topics = load_topics()
        current_hash = get_topics_hash(current_topics)
        if cache_data.get("topics_hash") == current_hash:
            # Convert list embeddings back to numpy arrays
            embeddings = [np.array(emb) for emb in cache_data.get("embeddings", [])]
            return embeddings
        else:
            # Topics have changed, cache is invalid
            return None
    except (json.JSONDecodeError, KeyError, ValueError) as e:
        # Cache file is corrupted or in an invalid format
        print(f"Warning: Could not load cached embeddings: {e}")
        return None


def save_cached_embeddings(embeddings, topics):
    """Save topic embeddings to cache file."""
    topics_hash = get_topics_hash(topics)
    # Convert numpy arrays to lists for JSON serialization
    embeddings_list = [emb.tolist() for emb in embeddings]
    cache_data = {
        "topics_hash": topics_hash,
        "embeddings": embeddings_list,
        "model": "models/text-embedding-004",
        "cached_at": datetime.now().isoformat()
    }
    try:
        with open(EMBEDDINGS_CACHE_FILE, 'w', encoding='utf-8') as f:
            json.dump(cache_data, f, indent=2)
        print(f"Cached {len(embeddings)} topic embeddings to {EMBEDDINGS_CACHE_FILE}")
    except Exception as e:
        print(f"Warning: Could not save cached embeddings: {e}")
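
# For reference, the cache file written by save_cached_embeddings above looks
# roughly like this (values are illustrative placeholders, not real data):
#
#   {
#     "topics_hash": "md5-of-sorted-topics-json",
#     "embeddings": [[0.01, -0.02, ...], ...],   # one vector per topic
#     "model": "models/text-embedding-004",
#     "cached_at": "2025-01-01T00:00:00"
#   }
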
def get_topic_embeddings():
    """
    Get topic embeddings, loading from cache if available,
    otherwise generating and caching them.

    Returns:
        numpy.ndarray: Array of topic embeddings
    """
    topics = load_topics()

    # Try to load from cache first
    cached_embeddings = load_cached_embeddings()
    if cached_embeddings is not None:
        print(f"Loaded {len(cached_embeddings)} topic embeddings from cache")
        return np.array(cached_embeddings)

    # Cache miss or invalid - generate embeddings
    print(f"Generating embeddings for {len(topics)} topics (this may take a moment)...")
    embedding_response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=topics,
        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
    )
    if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
        raise RuntimeError("Embedding API did not return embeddings.")
    embeddings = [np.array(e.values) for e in embedding_response.embeddings]

    # Save to cache for future use
    save_cached_embeddings(embeddings, topics)
    return np.array(embeddings)


def find_most_similar_topic(input_text: str):
    """
    Compare a single input text to all topics and return the highest cosine similarity.
    Uses cached topic embeddings to avoid re-embedding topics on every call.

    Args:
        input_text: The text to compare against topics

    Returns:
        dict: Contains 'topic', 'similarity', and 'index' of the most similar topic
    """
    # Load topics from JSON file
    topics = load_topics()
    if not topics:
        raise ValueError("No topics found in topics.json")

    # Get topic embeddings (from cache or generate)
    topic_embeddings = get_topic_embeddings()

    # Only embed the input text (much faster!)
    embedding_response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=[input_text],
        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
    )
    if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
        raise RuntimeError("Embedding API did not return embeddings.")

    # Extract input embedding as a (1, dim) row vector
    input_embedding = np.array(embedding_response.embeddings[0].values).reshape(1, -1)

    # Calculate cosine similarity between input and each topic
    similarities = cosine_similarity(input_embedding, topic_embeddings)[0]

    # Find the highest similarity
    max_index = np.argmax(similarities)
    max_similarity = similarities[max_index]
    most_similar_topic = topics[max_index]

    return {
        "topic": most_similar_topic,
        "similarity": float(max_similarity),
        "index": int(max_index)
    }


if __name__ == "__main__":
    # Example usage with a timer
    start_time = datetime.now()
    test_text = ("we should abandon the use of school uniform since one should be "
                 "allowed to express their individuality by the clothes they wear.")
    result = find_most_similar_topic(test_text)
    print(f"Input text: '{test_text}'")
    print(f"Most similar topic: '{result['topic']}'")
    print(f"Cosine similarity: {result['similarity']:.4f}")
    end_time = datetime.now()
    print(f"Time taken: {(end_time - start_time).total_seconds()} seconds")
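

# A minimal sketch of a batch variant; `find_most_similar_topics_batch` is a
# hypothetical helper, not part of the module API above. It relies on the same
# fact used for topics earlier in this file: embed_content accepts a list of
# strings, so several input texts cost one API call instead of one each.
def find_most_similar_topics_batch(input_texts):
    topics = load_topics()
    topic_embeddings = get_topic_embeddings()
    response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=input_texts,
        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
    )
    if not hasattr(response, "embeddings") or response.embeddings is None:
        raise RuntimeError("Embedding API did not return embeddings.")
    # One row per input text, one column per topic
    input_matrix = np.array([e.values for e in response.embeddings])
    similarities = cosine_similarity(input_matrix, topic_embeddings)
    best = np.argmax(similarities, axis=1)
    return [
        {"topic": topics[i], "similarity": float(similarities[row, i]), "index": int(i)}
        for row, i in enumerate(best)
    ]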