from datetime import datetime
import os
import json
import hashlib
from pathlib import Path
from dotenv import load_dotenv
from google import genai
from google.genai import types
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load environment variables from .env file
load_dotenv()

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    raise ValueError("GOOGLE_API_KEY is not set in environment variables.")
# Get the path to topics.json relative to this file
TOPICS_FILE = Path(__file__).parent.parent / "data" / "topics.json"
# Cache file for topic embeddings
EMBEDDINGS_CACHE_FILE = Path(__file__).parent.parent / "data" / "topic_embeddings_cache.json"
# Create a Google Generative AI client with the API key
client = genai.Client(api_key=GOOGLE_API_KEY)

def load_topics():
    """Load topics from topics.json file."""
    with open(TOPICS_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data.get("topics", [])

def get_topics_hash(topics):
    """Generate a hash of the topics list to verify cache validity."""
    topics_str = json.dumps(topics, sort_keys=True)
    return hashlib.md5(topics_str.encode('utf-8')).hexdigest()

def load_cached_embeddings():
    """Load cached topic embeddings if they exist and are valid."""
    if not EMBEDDINGS_CACHE_FILE.exists():
        return None
    try:
        with open(EMBEDDINGS_CACHE_FILE, 'r', encoding='utf-8') as f:
            cache_data = json.load(f)
        # Verify the cache is still valid by checking the topics hash
        current_topics = load_topics()
        current_hash = get_topics_hash(current_topics)
        if cache_data.get("topics_hash") == current_hash:
            # Convert list embeddings back to numpy arrays
            embeddings = [np.array(emb) for emb in cache_data.get("embeddings", [])]
            return embeddings
        else:
            # Topics have changed, so the cache is invalid
            return None
    except (json.JSONDecodeError, KeyError, ValueError) as e:
        # Cache file is corrupted or in an invalid format
        print(f"Warning: Could not load cached embeddings: {e}")
        return None

def save_cached_embeddings(embeddings, topics):
    """Save topic embeddings to cache file."""
    topics_hash = get_topics_hash(topics)
    # Convert numpy arrays to lists for JSON serialization
    embeddings_list = [emb.tolist() for emb in embeddings]
    cache_data = {
        "topics_hash": topics_hash,
        "embeddings": embeddings_list,
        "model": "models/text-embedding-004",
        "cached_at": datetime.now().isoformat()
    }
    try:
        with open(EMBEDDINGS_CACHE_FILE, 'w', encoding='utf-8') as f:
            json.dump(cache_data, f, indent=2)
        print(f"Cached {len(embeddings)} topic embeddings to {EMBEDDINGS_CACHE_FILE}")
    except Exception as e:
        print(f"Warning: Could not save cached embeddings: {e}")

def get_topic_embeddings():
    """
    Get topic embeddings, loading from cache if available, otherwise generating and caching them.

    Returns:
        numpy.ndarray: Array of topic embeddings
    """
    topics = load_topics()
    # Try to load from cache first
    cached_embeddings = load_cached_embeddings()
    if cached_embeddings is not None:
        print(f"Loaded {len(cached_embeddings)} topic embeddings from cache")
        return np.array(cached_embeddings)
    # Cache miss or invalid cache - generate embeddings
    print(f"Generating embeddings for {len(topics)} topics (this may take a moment)...")
    embedding_response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=topics,
        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
    )
    if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
        raise RuntimeError("Embedding API did not return embeddings.")
    embeddings = [np.array(e.values) for e in embedding_response.embeddings]
    # Save to cache for future use
    save_cached_embeddings(embeddings, topics)
    return np.array(embeddings)
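
# Optional convenience helper (an addition, not part of the original module):
# force a cache rebuild, e.g. after editing topics.json by hand or switching
# embedding models. A minimal sketch that only uses pieces defined above.
def refresh_topic_embeddings():
    """Delete the cache file (if any) and regenerate topic embeddings."""
    EMBEDDINGS_CACHE_FILE.unlink(missing_ok=True)
    return get_topic_embeddings()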

def find_most_similar_topic(input_text: str):
    """
    Compare a single input text to all topics and return the highest cosine similarity.
    Uses cached topic embeddings to avoid re-embedding topics on every call.

    Args:
        input_text: The text to compare against topics

    Returns:
        dict: Contains 'topic', 'similarity', and 'index' of the most similar topic
    """
    # Load topics from JSON file
    topics = load_topics()
    if not topics:
        raise ValueError("No topics found in topics.json")
    # Get topic embeddings (from cache, or generate them)
    topic_embeddings = get_topic_embeddings()
    # Only the input text needs to be embedded, which is much faster
    embedding_response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=[input_text],
        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
    )
    if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
        raise RuntimeError("Embedding API did not return embeddings.")
    # Extract the input embedding as a 2D row vector for cosine_similarity
    input_embedding = np.array(embedding_response.embeddings[0].values).reshape(1, -1)
    # Calculate cosine similarity between the input and each topic
    similarities = cosine_similarity(input_embedding, topic_embeddings)[0]
    # Find the highest similarity
    max_index = np.argmax(similarities)
    max_similarity = similarities[max_index]
    most_similar_topic = topics[max_index]
    return {
        "topic": most_similar_topic,
        "similarity": float(max_similarity),
        "index": int(max_index)
    }
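
# Hypothetical batch variant (an assumption, not in the original code): embed
# several inputs in one API call and score each against the cached topic
# embeddings. The embed_content call above already accepts a list of contents,
# so this just reuses that pattern to amortize the API round trip.
def find_most_similar_topics_batch(input_texts: list[str]):
    """Return the best-matching topic for each input text."""
    topics = load_topics()
    if not topics:
        raise ValueError("No topics found in topics.json")
    topic_embeddings = get_topic_embeddings()
    embedding_response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=input_texts,
        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
    )
    if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
        raise RuntimeError("Embedding API did not return embeddings.")
    input_embeddings = np.array([e.values for e in embedding_response.embeddings])
    # One row of topic similarities per input text
    similarities = cosine_similarity(input_embeddings, topic_embeddings)
    results = []
    for row in similarities:
        idx = int(np.argmax(row))
        results.append({"topic": topics[idx], "similarity": float(row[idx]), "index": idx})
    return results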

if __name__ == "__main__":
    # Example usage, with simple wall-clock timing around the lookup
    start_time = datetime.now()
    test_text = "we should abandon the use of school uniform since one should be allowed to express their individuality by the clothes they wear."
    result = find_most_similar_topic(test_text)
    print(f"Input text: '{test_text}'")
    print(f"Most similar topic: '{result['topic']}'")
    # Cosine similarity lies in [-1, 1]; it is not a percentage
    print(f"Cosine similarity: {result['similarity']:.4f}")
    end_time = datetime.now()
    print(f"Time taken: {(end_time - start_time).total_seconds()} seconds")