Yassine Mhirsi commited on
Commit
22ad0ba
·
1 Parent(s): e97ac87

similarity

Browse files
routes/topic.py CHANGED
@@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException
4
  from datetime import datetime
5
  import logging
6
 
7
- from services.topic_service import topic_service
8
  from models.topic import (
9
  TopicRequest,
10
  TopicResponse,
@@ -19,15 +19,16 @@ logger = logging.getLogger(__name__)
19
  @router.post("/extract", response_model=TopicResponse, tags=["Topic Extraction"])
20
  async def extract_topic(request: TopicRequest):
21
  """
22
- Extract a topic from a given text/argument
23
 
24
- - **text**: The input text or argument to extract topic from (5-5000 chars)
25
 
26
- Returns the extracted topic description
27
  """
28
  try:
29
- # Extract topic
30
- topic = topic_service.extract_topic(request.text)
 
31
 
32
  # Build response
33
  response = TopicResponse(
@@ -36,29 +37,29 @@ async def extract_topic(request: TopicRequest):
36
  timestamp=datetime.now().isoformat()
37
  )
38
 
39
- logger.info(f"Topic extracted: {topic[:50]}...")
40
  return response
41
 
42
  except ValueError as e:
43
  logger.error(f"Validation error: {str(e)}")
44
  raise HTTPException(status_code=400, detail=str(e))
45
  except Exception as e:
46
- logger.error(f"Topic extraction error: {str(e)}")
47
- raise HTTPException(status_code=500, detail=f"Topic extraction failed: {str(e)}")
48
 
49
 
50
  @router.post("/batch-extract", response_model=BatchTopicResponse, tags=["Topic Extraction"])
51
  async def batch_extract_topics(request: BatchTopicRequest):
52
  """
53
- Extract topics from multiple texts/arguments
54
 
55
- - **texts**: List of texts to extract topics from (max 50)
56
 
57
- Returns extracted topics for all texts
58
  """
59
  try:
60
- # Batch extract topics
61
- topics = topic_service.batch_extract_topics(request.texts)
62
 
63
  # Build response
64
  results = []
@@ -74,10 +75,10 @@ async def batch_extract_topics(request: BatchTopicRequest):
74
  )
75
  )
76
  else:
77
- # Skip failed extractions or handle as needed
78
- logger.warning(f"Failed to extract topic for text at index {i}")
79
 
80
- logger.info(f"Batch topic extraction completed: {len(results)}/{len(request.texts)} successful")
81
 
82
  return BatchTopicResponse(
83
  results=results,
@@ -89,6 +90,6 @@ async def batch_extract_topics(request: BatchTopicRequest):
89
  logger.error(f"Validation error: {str(e)}")
90
  raise HTTPException(status_code=400, detail=str(e))
91
  except Exception as e:
92
- logger.error(f"Batch topic extraction error: {str(e)}")
93
- raise HTTPException(status_code=500, detail=f"Batch topic extraction failed: {str(e)}")
94
 
 
4
  from datetime import datetime
5
  import logging
6
 
7
+ from services.topic_similarity_service import topic_similarity_service
8
  from models.topic import (
9
  TopicRequest,
10
  TopicResponse,
 
19
  @router.post("/extract", response_model=TopicResponse, tags=["Topic Extraction"])
20
  async def extract_topic(request: TopicRequest):
21
  """
22
+ Find the most similar topic from predefined topics for a given text/argument
23
 
24
+ - **text**: The input text or argument to find similar topic for (5-5000 chars)
25
 
26
+ Returns the most similar topic from the predefined list
27
  """
28
  try:
29
+ # Find most similar topic
30
+ result = topic_similarity_service.find_most_similar_topic(request.text)
31
+ topic = result["topic"]
32
 
33
  # Build response
34
  response = TopicResponse(
 
37
  timestamp=datetime.now().isoformat()
38
  )
39
 
40
+ logger.info(f"Most similar topic found: {topic[:50]}... (similarity: {result['similarity']:.4f})")
41
  return response
42
 
43
  except ValueError as e:
44
  logger.error(f"Validation error: {str(e)}")
45
  raise HTTPException(status_code=400, detail=str(e))
46
  except Exception as e:
47
+ logger.error(f"Topic similarity error: {str(e)}")
48
+ raise HTTPException(status_code=500, detail=f"Topic similarity search failed: {str(e)}")
49
 
50
 
51
  @router.post("/batch-extract", response_model=BatchTopicResponse, tags=["Topic Extraction"])
52
  async def batch_extract_topics(request: BatchTopicRequest):
53
  """
54
+ Find the most similar topics from predefined topics for multiple texts/arguments
55
 
56
+ - **texts**: List of texts to find similar topics for (max 50)
57
 
58
+ Returns the most similar topics from the predefined list for all texts
59
  """
60
  try:
61
+ # Batch find similar topics
62
+ topics = topic_similarity_service.batch_find_similar_topics(request.texts)
63
 
64
  # Build response
65
  results = []
 
75
  )
76
  )
77
  else:
78
+ # Skip failed searches or handle as needed
79
+ logger.warning(f"Failed to find similar topic for text at index {i}")
80
 
81
+ logger.info(f"Batch topic similarity search completed: {len(results)}/{len(request.texts)} successful")
82
 
83
  return BatchTopicResponse(
84
  results=results,
 
90
  logger.error(f"Validation error: {str(e)}")
91
  raise HTTPException(status_code=400, detail=str(e))
92
  except Exception as e:
93
+ logger.error(f"Batch topic similarity error: {str(e)}")
94
+ raise HTTPException(status_code=500, detail=f"Batch topic similarity search failed: {str(e)}")
95
 
services/topic_service.py CHANGED
@@ -22,7 +22,7 @@ class TopicService:
22
 
23
  def __init__(self):
24
  self.llm = None
25
- self.model_name = "openai/gpt-oss-safeguard-120b" # another model meta-llama/llama-4-scout-17b-16e-instruct
26
  self.initialized = False
27
 
28
  def initialize(self, model_name: Optional[str] = None):
 
22
 
23
  def __init__(self):
24
  self.llm = None
25
+ self.model_name = "openai/gpt-oss-safeguard-20b" # another model meta-llama/llama-4-scout-17b-16e-instruct
26
  self.initialized = False
27
 
28
  def initialize(self, model_name: Optional[str] = None):
topic_similarity_google_example.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import os
3
+ import json
4
+ import hashlib
5
+ from pathlib import Path
6
+ from dotenv import load_dotenv
7
+ from google import genai
8
+ from google.genai import types
9
+ import numpy as np
10
+ from sklearn.metrics.pairwise import cosine_similarity
11
+
12
+ # Load environment variables from .env file
13
+ load_dotenv()
14
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
15
+ if not GOOGLE_API_KEY:
16
+ raise ValueError("GOOGLE_API_KEY is not set in environment variables.")
17
+
18
+ # Get the path to topics.json relative to this file
19
+ TOPICS_FILE = Path(__file__).parent.parent / "data" / "topics.json"
20
+ # Cache file for topic embeddings
21
+ EMBEDDINGS_CACHE_FILE = Path(__file__).parent.parent / "data" / "topic_embeddings_cache.json"
22
+
23
+ # Create a Google Generative AI client with the API key
24
+ client = genai.Client(api_key=GOOGLE_API_KEY)
25
+
26
+
27
+ def load_topics():
28
+ """Load topics from topics.json file."""
29
+ with open(TOPICS_FILE, 'r', encoding='utf-8') as f:
30
+ data = json.load(f)
31
+ return data.get("topics", [])
32
+
33
+
34
+ def get_topics_hash(topics):
35
+ """Generate a hash of the topics list to verify cache validity."""
36
+ topics_str = json.dumps(topics, sort_keys=True)
37
+ return hashlib.md5(topics_str.encode('utf-8')).hexdigest()
38
+
39
+
40
+ def load_cached_embeddings():
41
+ """Load cached topic embeddings if they exist and are valid."""
42
+ if not EMBEDDINGS_CACHE_FILE.exists():
43
+ return None
44
+
45
+ try:
46
+ with open(EMBEDDINGS_CACHE_FILE, 'r', encoding='utf-8') as f:
47
+ cache_data = json.load(f)
48
+
49
+ # Verify cache is valid by checking topics hash
50
+ current_topics = load_topics()
51
+ current_hash = get_topics_hash(current_topics)
52
+
53
+ if cache_data.get("topics_hash") == current_hash:
54
+ # Convert list embeddings back to numpy arrays
55
+ embeddings = [np.array(emb) for emb in cache_data.get("embeddings", [])]
56
+ return embeddings
57
+ else:
58
+ # Topics have changed, cache is invalid
59
+ return None
60
+ except (json.JSONDecodeError, KeyError, ValueError) as e:
61
+ # Cache file is corrupted or invalid format
62
+ print(f"Warning: Could not load cached embeddings: {e}")
63
+ return None
64
+
65
+
66
+ def save_cached_embeddings(embeddings, topics):
67
+ """Save topic embeddings to cache file."""
68
+ topics_hash = get_topics_hash(topics)
69
+
70
+ # Convert numpy arrays to lists for JSON serialization
71
+ embeddings_list = [emb.tolist() for emb in embeddings]
72
+
73
+ cache_data = {
74
+ "topics_hash": topics_hash,
75
+ "embeddings": embeddings_list,
76
+ "model": "models/text-embedding-004",
77
+ "cached_at": datetime.now().isoformat()
78
+ }
79
+
80
+ try:
81
+ with open(EMBEDDINGS_CACHE_FILE, 'w', encoding='utf-8') as f:
82
+ json.dump(cache_data, f, indent=2)
83
+ print(f"Cached {len(embeddings)} topic embeddings to {EMBEDDINGS_CACHE_FILE}")
84
+ except Exception as e:
85
+ print(f"Warning: Could not save cached embeddings: {e}")
86
+
87
+
88
+ def get_topic_embeddings():
89
+ """
90
+ Get topic embeddings, loading from cache if available, otherwise generating and caching them.
91
+
92
+ Returns:
93
+ numpy.ndarray: Array of topic embeddings
94
+ """
95
+ topics = load_topics()
96
+
97
+ # Try to load from cache first
98
+ cached_embeddings = load_cached_embeddings()
99
+ if cached_embeddings is not None:
100
+ print(f"Loaded {len(cached_embeddings)} topic embeddings from cache")
101
+ return np.array(cached_embeddings)
102
+
103
+ # Cache miss or invalid - generate embeddings
104
+ print(f"Generating embeddings for {len(topics)} topics (this may take a moment)...")
105
+ embedding_response = client.models.embed_content(
106
+ model="models/text-embedding-004",
107
+ contents=topics,
108
+ config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
109
+ )
110
+
111
+ if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
112
+ raise RuntimeError("Embedding API did not return embeddings.")
113
+
114
+ embeddings = [np.array(e.values) for e in embedding_response.embeddings]
115
+
116
+ # Save to cache for future use
117
+ save_cached_embeddings(embeddings, topics)
118
+
119
+ return np.array(embeddings)
120
+
121
+
122
+ def find_most_similar_topic(input_text: str):
123
+ """
124
+ Compare a single input text to all topics and return the highest cosine similarity.
125
+ Uses cached topic embeddings to avoid re-embedding topics on every call.
126
+
127
+ Args:
128
+ input_text: The text to compare against topics
129
+
130
+ Returns:
131
+ dict: Contains 'topic', 'similarity', and 'index' of the most similar topic
132
+ """
133
+ # Load topics from JSON file
134
+ topics = load_topics()
135
+
136
+ if not topics:
137
+ raise ValueError("No topics found in topics.json")
138
+
139
+ # Get topic embeddings (from cache or generate)
140
+ topic_embeddings = get_topic_embeddings()
141
+
142
+ # Only embed the input text (much faster!)
143
+ embedding_response = client.models.embed_content(
144
+ model="models/text-embedding-004",
145
+ contents=[input_text],
146
+ config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
147
+ )
148
+
149
+ if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
150
+ raise RuntimeError("Embedding API did not return embeddings.")
151
+
152
+ # Extract input embedding
153
+ input_embedding = np.array(embedding_response.embeddings[0].values).reshape(1, -1)
154
+
155
+ # Calculate cosine similarity between input and each topic
156
+ similarities = cosine_similarity(input_embedding, topic_embeddings)[0]
157
+
158
+ # Find the highest similarity
159
+ max_index = np.argmax(similarities)
160
+ max_similarity = similarities[max_index]
161
+ most_similar_topic = topics[max_index]
162
+
163
+ return {
164
+ "topic": most_similar_topic,
165
+ "similarity": float(max_similarity),
166
+ "index": int(max_index)
167
+ }
168
+
169
+
170
+ if __name__ == "__main__":
171
+ # Example usage
172
+ #start time
173
+ start_time = datetime.now()
174
+ test_text = "we should abandon the use of school uniform since one should be allowed to express their individuality by the clothes they were."
175
+ result = find_most_similar_topic(test_text)
176
+ print(f"Input text: '{test_text}'")
177
+ print(f"Most similar topic: '{result['topic']}'")
178
+ print(f"Cosine similarity: {result['similarity']:.4f}%")
179
+ #end time
180
+ end_time = datetime.now()
181
+ #in seconds
182
+ print(f"Time taken: {(end_time - start_time).total_seconds()} seconds")
topic_similarity_langchain_example.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from datetime import datetime
4
+ from dotenv import load_dotenv
5
+ load_dotenv()
6
+
7
+ from langchain_community.vectorstores import FAISS
8
+ from langchain_core.example_selectors import (
9
+ SemanticSimilarityExampleSelector,
10
+ )
11
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
12
+
13
+ # Load topics from data file
14
+ with open(
15
+ file="data/topics.json",
16
+ encoding="utf-8"
17
+ ) as f:
18
+ data = json.load(f)
19
+
20
+ # Make sure each example is a dict with "topic" key (wrap as dict if plain string)
21
+ def format_examples(examples):
22
+ formatted = []
23
+ for ex in examples:
24
+ if isinstance(ex, str):
25
+ formatted.append({"topic": ex})
26
+ elif isinstance(ex, dict) and "topic" in ex:
27
+ formatted.append({"topic": ex["topic"]})
28
+ else:
29
+ formatted.append({"topic": str(ex)})
30
+ return formatted
31
+
32
+ # topics.json should have a top-level "topics" key
33
+ examples = data.get("topics", [])
34
+ formatted_examples = format_examples(examples)
35
+
36
+ start_time = datetime.now()
37
+ example_selector = SemanticSimilarityExampleSelector.from_examples(
38
+ examples=formatted_examples,
39
+ embeddings=GoogleGenerativeAIEmbeddings(
40
+ model="models/text-embedding-004",
41
+ api_key=os.getenv("GOOGLE_API_KEY")
42
+ ),
43
+ vectorstore_cls=FAISS,
44
+ k=1,
45
+ input_keys=["topic"],
46
+ )
47
+
48
+ # Example call to selector (for demonstration; remove in production)
49
+ result = example_selector.select_examples(
50
+ {"topic": "people who are terminally ill and suffering greatly should have the right to end their own life if they so desire."}
51
+ )
52
+ print(result)
53
+ end_time = datetime.now()
54
+ print(f"Time taken: {(end_time - start_time).total_seconds()} seconds")