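"""Topic similarity matching via Google Generative AI embeddings.

Loads a list of topics from topics.json, embeds them once with
text-embedding-004 (caching the embeddings to a JSON file keyed by a hash
of the topics), and scores input text against every topic with cosine
similarity.
"""
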
from datetime import datetime
import os
import json
import hashlib
from pathlib import Path
from dotenv import load_dotenv
from google import genai
from google.genai import types
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load environment variables from .env file
load_dotenv()
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    raise ValueError("GOOGLE_API_KEY is not set in environment variables.")

# Get the path to topics.json relative to this file
TOPICS_FILE = Path(__file__).parent.parent / "data" / "topics.json"
# Cache file for topic embeddings
EMBEDDINGS_CACHE_FILE = Path(__file__).parent.parent / "data" / "topic_embeddings_cache.json"
# Embedding model used for topics and inputs alike (also recorded in the cache)
EMBEDDING_MODEL = "models/text-embedding-004"

# Create a Google Generative AI client with the API key
client = genai.Client(api_key=GOOGLE_API_KEY)


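# The loader below reads only the top-level "topics" key, which is assumed to
# hold a flat list of topic strings (illustrative shape, not a confirmed schema):
#
#   {"topics": ["School uniforms should be mandatory", "..."]}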
def load_topics():
    """Load topics from topics.json file."""
    with open(TOPICS_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data.get("topics", [])


def get_topics_hash(topics):
    """Generate a hash of the topics list to verify cache validity."""
    topics_str = json.dumps(topics, sort_keys=True)
    return hashlib.md5(topics_str.encode('utf-8')).hexdigest()


def load_cached_embeddings():
    """Load cached topic embeddings if they exist and are valid."""
    if not EMBEDDINGS_CACHE_FILE.exists():
        return None
    
    try:
        with open(EMBEDDINGS_CACHE_FILE, 'r', encoding='utf-8') as f:
            cache_data = json.load(f)
        
        # Verify cache is valid by checking topics hash
        current_topics = load_topics()
        current_hash = get_topics_hash(current_topics)
        
        if cache_data.get("topics_hash") == current_hash:
            # Convert list embeddings back to numpy arrays
            embeddings = [np.array(emb) for emb in cache_data.get("embeddings", [])]
            return embeddings
        else:
            # Topics have changed, cache is invalid
            return None
    except (json.JSONDecodeError, KeyError, ValueError) as e:
        # Cache file is corrupted or invalid format
        print(f"Warning: Could not load cached embeddings: {e}")
        return None


def save_cached_embeddings(embeddings, topics):
    """Save topic embeddings to cache file."""
    topics_hash = get_topics_hash(topics)
    
    # Convert numpy arrays to lists for JSON serialization
    embeddings_list = [emb.tolist() for emb in embeddings]
    
    cache_data = {
        "topics_hash": topics_hash,
        "embeddings": embeddings_list,
        "model": "models/text-embedding-004",
        "cached_at": datetime.now().isoformat()
    }
    
    try:
        with open(EMBEDDINGS_CACHE_FILE, 'w', encoding='utf-8') as f:
            json.dump(cache_data, f, indent=2)
        print(f"Cached {len(embeddings)} topic embeddings to {EMBEDDINGS_CACHE_FILE}")
    except Exception as e:
        print(f"Warning: Could not save cached embeddings: {e}")


def get_topic_embeddings():
    """
    Get topic embeddings, loading from cache if available, otherwise generating and caching them.
    
    Returns:
        numpy.ndarray: Array of topic embeddings
    """
    topics = load_topics()
    
    # Try to load from cache first
    cached_embeddings = load_cached_embeddings()
    if cached_embeddings is not None:
        print(f"Loaded {len(cached_embeddings)} topic embeddings from cache")
        return np.array(cached_embeddings)
    
    # Cache miss or invalid - generate embeddings
    print(f"Generating embeddings for {len(topics)} topics (this may take a moment)...")
    embedding_response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=topics,
        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
    )
    
    if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
        raise RuntimeError("Embedding API did not return embeddings.")
    
    embeddings = [np.array(e.values) for e in embedding_response.embeddings]
    
    # Save to cache for future use
    save_cached_embeddings(embeddings, topics)
    
    return np.array(embeddings)


def find_most_similar_topic(input_text: str):
    """
    Compare a single input text to all topics and return the highest cosine similarity.
    Uses cached topic embeddings to avoid re-embedding topics on every call.
    
    Args:
        input_text: The text to compare against topics
        
    Returns:
        dict: Contains 'topic', 'similarity', and 'index' of the most similar topic
    """
    # Load topics from JSON file
    topics = load_topics()
    
    if not topics:
        raise ValueError("No topics found in topics.json")
    
    # Get topic embeddings (from cache or generate)
    topic_embeddings = get_topic_embeddings()
    
    # Only embed the input text (much faster!)
    embedding_response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=[input_text],
        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
    )
    
    if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
        raise RuntimeError("Embedding API did not return embeddings.")
    
    # Extract input embedding
    input_embedding = np.array(embedding_response.embeddings[0].values).reshape(1, -1)
    
    # Calculate cosine similarity between input and each topic
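    # cos(u, v) = (u · v) / (||u|| * ||v||); values lie in [-1, 1], higher means more similar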
    similarities = cosine_similarity(input_embedding, topic_embeddings)[0]
    
    # Find the highest similarity
    max_index = np.argmax(similarities)
    max_similarity = similarities[max_index]
    most_similar_topic = topics[max_index]
    
    return {
        "topic": most_similar_topic,
        "similarity": float(max_similarity),
        "index": int(max_index)
    }


if __name__ == "__main__":
    # Example usage
    # Start timer
    start_time = datetime.now()
    test_text = "we should abandon the use of school uniforms since one should be allowed to express their individuality by the clothes they wear."
    result = find_most_similar_topic(test_text)
    print(f"Input text: '{test_text}'")
    print(f"Most similar topic: '{result['topic']}'")
    print(f"Cosine similarity: {result['similarity']:.4f}%")
    # Stop timer and report elapsed time in seconds
    end_time = datetime.now()
    print(f"Time taken: {(end_time - start_time).total_seconds()} seconds")