Yassine Mhirsi commited on
Commit
e97ac87
·
1 Parent(s): d25effa

refactor: Update file path handling in TopicSimilarityService for improved error reporting and ensure absolute paths for topic and embeddings cache files

Browse files
Files changed (2) hide show
  1. config.py +1 -1
  2. services/topic_similarity_service.py +26 -11
config.py CHANGED
@@ -12,7 +12,7 @@ load_dotenv()
12
 
13
  # ============ DIRECTORIES ============
14
  API_DIR = Path(__file__).parent
15
- PROJECT_ROOT = API_DIR.parent
16
 
17
  # ============ HUGGING FACE MODELS ============
18
  HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "")
 
12
 
13
  # ============ DIRECTORIES ============
14
  API_DIR = Path(__file__).parent
15
+ PROJECT_ROOT = API_DIR # config.py is in the project root, so API_DIR is the project root
16
 
17
  # ============ HUGGING FACE MODELS ============
18
  HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "")
services/topic_similarity_service.py CHANGED
@@ -62,15 +62,26 @@ class TopicSimilarityService:
62
 
63
  def _load_topics(self) -> List[str]:
64
  """Load topics from topics.json file"""
65
- if not TOPICS_FILE.exists():
66
- raise FileNotFoundError(f"Topics file not found: {TOPICS_FILE}")
 
 
 
 
 
 
 
 
67
 
68
  try:
69
- with open(TOPICS_FILE, 'r', encoding='utf-8') as f:
70
  data = json.load(f)
71
- return data.get("topics", [])
 
 
 
72
  except (json.JSONDecodeError, KeyError) as e:
73
- raise ValueError(f"Error loading topics from {TOPICS_FILE}: {str(e)}")
74
 
75
  def _get_topics_hash(self, topics: List[str]) -> str:
76
  """Generate a hash of the topics list to verify cache validity"""
@@ -79,11 +90,14 @@ class TopicSimilarityService:
79
 
80
  def _load_cached_embeddings(self) -> Optional[np.ndarray]:
81
  """Load cached topic embeddings if they exist and are valid"""
82
- if not EMBEDDINGS_CACHE_FILE.exists():
 
 
 
83
  return None
84
 
85
  try:
86
- with open(EMBEDDINGS_CACHE_FILE, 'r', encoding='utf-8') as f:
87
  cache_data = json.load(f)
88
 
89
  # Verify cache is valid by checking topics hash
@@ -117,12 +131,13 @@ class TopicSimilarityService:
117
  }
118
 
119
  try:
120
- # Ensure directory exists
121
- EMBEDDINGS_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
 
122
 
123
- with open(EMBEDDINGS_CACHE_FILE, 'w', encoding='utf-8') as f:
124
  json.dump(cache_data, f, indent=2)
125
- logger.info(f"Cached {len(embeddings)} topic embeddings to {EMBEDDINGS_CACHE_FILE}")
126
  except Exception as e:
127
  logger.warning(f"Could not save cached embeddings: {e}")
128
 
 
62
 
63
  def _load_topics(self) -> List[str]:
64
  """Load topics from topics.json file"""
65
+ # Ensure path is absolute
66
+ topics_file = Path(TOPICS_FILE).resolve()
67
+
68
+ if not topics_file.exists():
69
+ raise FileNotFoundError(
70
+ f"Topics file not found: {topics_file}\n"
71
+ f"Current working directory: {Path.cwd()}\n"
72
+ f"PROJECT_ROOT: {PROJECT_ROOT}\n"
73
+ f"TOPICS_FILE path: {TOPICS_FILE}"
74
+ )
75
 
76
  try:
77
+ with open(topics_file, 'r', encoding='utf-8') as f:
78
  data = json.load(f)
79
+ topics = data.get("topics", [])
80
+ if not topics:
81
+ raise ValueError(f"No topics found in {topics_file}")
82
+ return topics
83
  except (json.JSONDecodeError, KeyError) as e:
84
+ raise ValueError(f"Error loading topics from {topics_file}: {str(e)}")
85
 
86
  def _get_topics_hash(self, topics: List[str]) -> str:
87
  """Generate a hash of the topics list to verify cache validity"""
 
90
 
91
  def _load_cached_embeddings(self) -> Optional[np.ndarray]:
92
  """Load cached topic embeddings if they exist and are valid"""
93
+ # Ensure path is absolute
94
+ cache_file = Path(EMBEDDINGS_CACHE_FILE).resolve()
95
+
96
+ if not cache_file.exists():
97
  return None
98
 
99
  try:
100
+ with open(cache_file, 'r', encoding='utf-8') as f:
101
  cache_data = json.load(f)
102
 
103
  # Verify cache is valid by checking topics hash
 
131
  }
132
 
133
  try:
134
+ # Ensure path is absolute and directory exists
135
+ cache_file = Path(EMBEDDINGS_CACHE_FILE).resolve()
136
+ cache_file.parent.mkdir(parents=True, exist_ok=True)
137
 
138
+ with open(cache_file, 'w', encoding='utf-8') as f:
139
  json.dump(cache_data, f, indent=2)
140
+ logger.info(f"Cached {len(embeddings)} topic embeddings to {cache_file}")
141
  except Exception as e:
142
  logger.warning(f"Could not save cached embeddings: {e}")
143