File size: 13,128 Bytes
4aec76b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
""" Database Module for LLM System - Qdrant Vector Store
- Contains the `VectorDB` class to manage a vector database using Qdrant and embeddings.
- Supports both traditional text embeddings (Ollama) and multimodal embeddings (Jina v4).
- Provides methods to initialize the database, retrieve embeddings, and perform similarity searches.
- Replaces FAISS with Qdrant for production-ready vector storage
"""

import os
from typing import Tuple, Optional, List
from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore
from langchain_core.runnables import ConfigurableField
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

# For type hinting
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStoreRetriever

# For hybrid search
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers.ensemble import EnsembleRetriever

# config:
from llm_system.config import QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY, EMB_MODEL_NAME, OLLAMA_BASE_URL, USE_JINA_EMBEDDINGS

# Import embeddings factory
from llm_system.core.llm import get_embeddings

from logger import get_logger
log = get_logger(name="core_database")


class VectorDB:
    """Manage a Qdrant vector store backed by Jina or Ollama embeddings.

    Construction builds the embeddings model, connects to Qdrant, creates the
    target collection when it does not exist yet (cosine distance, dimension
    taken from a probe embedding), and prepares a similarity retriever plus an
    optional BM25 keyword retriever used for hybrid search.

    Args:
        embed_model (str): Name of the Ollama embeddings model; only used when
            Jina embeddings are disabled.
        retriever_num_docs (int): Number of documents each retriever returns.
        verify_connection (bool): If True, embed a probe document up front and
            raise ``RuntimeError`` when the embeddings model is unusable.
        use_jina (bool, optional): Whether to use Jina multimodal embeddings.
            ``None`` falls back to ``USE_JINA_EMBEDDINGS`` from config.
        qdrant_url (str, optional): Qdrant server URL. Defaults to ``QDRANT_URL``.
        collection_name (str, optional): Qdrant collection name. Defaults to
            ``QDRANT_COLLECTION_NAME``.
        api_key (str, optional): API key for Qdrant Cloud. Defaults to
            ``QDRANT_API_KEY``.
        bm25_documents (list[Document], optional): Corpus for the BM25 keyword
            retriever. BM25 only indexes the documents it is given, so when this
            is empty hybrid search stays disabled (hybrid accessors fall back to
            semantic search) until ``set_bm25_documents()`` is called.

    Raises:
        RuntimeError: If ``verify_connection`` is True and the embeddings model
            fails, or if the Qdrant server cannot be reached.

    ## Functions:
        + `get_embeddings()`: Returns the embeddings model (Jina or Ollama).
        + `get_vector_store()`: Returns the Qdrant vector store.
        + `get_retriever()`: Returns the semantic (similarity) retriever.
        + `get_hybrid_retriever()`: Returns a semantic + BM25 ensemble retriever.
        + `set_bm25_documents()`: (Re)builds the BM25 keyword index.
        + `save_db_to_disk()`: No-op for Qdrant (data persisted automatically on server).
        + `get_collection_info()`: Returns basic statistics about the collection.
        + `delete_by_filter()`: Deletes points whose payload matches a filter.
        + `create_hybrid_retriever_wrapper()`: Wraps a retriever with BM25 top-up.
    """
    def __init__(
        self, embed_model: str,
        retriever_num_docs: int = 5,
        verify_connection: bool = False,
        use_jina: Optional[bool] = None,
        qdrant_url: Optional[str] = None,
        collection_name: Optional[str] = None,
        api_key: Optional[str] = None,
        bm25_documents: Optional[List[Document]] = None,
    ):
        self.qdrant_url: str = qdrant_url or QDRANT_URL
        self.collection_name: str = collection_name or QDRANT_COLLECTION_NAME
        self.api_key: Optional[str] = api_key or QDRANT_API_KEY
        self.retriever_k: int = retriever_num_docs
        self.embed_model = embed_model

        log.info(
            f"Initializing VectorDB with embeddings='{embed_model}', Qdrant URL='{self.qdrant_url}', "
            f"collection='{self.collection_name}', k={retriever_num_docs} docs."
        )

        # Initialize embeddings (Jina multimodal or Ollama fallback)
        if use_jina is None:
            use_jina = USE_JINA_EMBEDDINGS

        if use_jina:
            log.info("Using Jina multimodal embeddings")
            self.embeddings = get_embeddings(use_jina=True)
        else:
            log.info(f"Using Ollama embeddings: {embed_model}")
            from langchain_ollama import OllamaEmbeddings
            self.embeddings = OllamaEmbeddings(model=embed_model, base_url=OLLAMA_BASE_URL, num_gpu=0, keep_alive=-1)

        if verify_connection:
            try:
                self.embeddings.embed_documents(['a'])
                log.info("Embeddings model initialized and verified.")
            except Exception as e:
                log.error(f"Failed to initialize Embeddings: {e}")
                raise RuntimeError("Couldn't initialize Embeddings model") from e
        else:
            log.warning("Embeddings initialized without connection verification.")

        # Constructing the client alone does not prove the server is reachable,
        # so the first real request (listing collections) is kept inside the
        # same try: an unreachable server then surfaces as the RuntimeError below.
        try:
            self.qdrant_client = QdrantClient(
                url=self.qdrant_url,
                api_key=self.api_key,
                timeout=30
            )
            existing_collections = {
                c.name for c in self.qdrant_client.get_collections().collections
            }
        except Exception as e:
            log.error(f"Failed to connect to Qdrant: {e}")
            raise RuntimeError(f"Couldn't connect to Qdrant at '{self.qdrant_url}'") from e
        log.info(f"Connected to Qdrant at '{self.qdrant_url}'.")

        # Probe once to learn the vector size needed to create the collection.
        embedding_dim = len(self.embeddings.embed_query("test"))
        log.info(f"Embedding dimension: {embedding_dim}")

        if self.collection_name in existing_collections:
            log.info(f"Using existing collection '{self.collection_name}'.")
        else:
            log.info(f"Collection '{self.collection_name}' not found. Creating new collection.")
            self.qdrant_client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
            )
            log.info(f"Created collection '{self.collection_name}' with dimension {embedding_dim}.")

        # Initialize Qdrant vector store
        self.db = QdrantVectorStore(
            client=self.qdrant_client,
            collection_name=self.collection_name,
            embedding=self.embeddings,
        )
        log.info("Qdrant vector store initialized.")

        # Semantic (similarity) retriever; `retriever` is kept as an alias.
        self.semantic_retriever = self.db.as_retriever(
            search_kwargs={
                "k": self.retriever_k,
            }
        )
        self.retriever = self.semantic_retriever

        # BM25 keyword retriever for hybrid search. It is only built from an
        # explicit corpus: an empty corpus cannot be indexed and would never
        # match anything anyway.
        self.bm25_retriever: Optional[BM25Retriever] = None
        if bm25_documents:
            self.set_bm25_documents(bm25_documents)
        else:
            log.info("No BM25 corpus provided; hybrid search disabled, using semantic search only.")

    def set_bm25_documents(self, documents: List[Document]) -> bool:
        """(Re)build the BM25 keyword retriever from *documents*.

        Args:
            documents: Corpus to index for keyword search. The retriever returns
                at most ``retriever_k`` documents per query.

        Returns:
            bool: True if the BM25 retriever was built. On an empty corpus or a
            build failure, ``bm25_retriever`` is set to None (hybrid search
            disabled) and False is returned.
        """
        documents = list(documents or [])
        if not documents:
            log.warning("Empty BM25 corpus; hybrid search disabled. Using semantic search only.")
            self.bm25_retriever = None
            return False
        try:
            self.bm25_retriever = BM25Retriever.from_documents(documents, k=self.retriever_k)
        except Exception as e:
            log.warning(f"Could not initialize BM25 retriever: {e}. Using semantic search only.")
            self.bm25_retriever = None
            return False
        log.info(f"Initialized BM25 retriever for hybrid search over {len(documents)} documents")
        return True

    def get_embeddings(self) -> Embeddings:
        """Return the embeddings model in use (Jina or Ollama)."""
        log.info("Returning the Embeddings model instance.")
        return self.embeddings

    def get_vector_store(self) -> VectorStore:
        """Returns the Qdrant vector store instance."""
        log.info("Returning the Qdrant vector store instance.")
        return self.db

    def get_retriever(self) -> VectorStoreRetriever:
        """Return the semantic (similarity) retriever returning ``retriever_k`` docs."""
        log.info("Returning the retriever for similarity search.")
        return self.retriever # type: ignore[return-value]

    def get_hybrid_retriever(self, use_hybrid: bool = True):
        """Return a retriever combining semantic and keyword search.

        Args:
            use_hybrid: If True and a BM25 index exists, return an ensemble of
                semantic (weight 0.6) and BM25 (weight 0.4) retrieval.
                Otherwise return the semantic retriever only.

        Returns:
            The ensemble retriever, or the semantic retriever as a fallback when
            hybrid search is disabled or the ensemble cannot be created.
        """
        if not use_hybrid or self.bm25_retriever is None:
            log.info("Returning semantic search only")
            return self.semantic_retriever

        try:
            ensemble = EnsembleRetriever(
                retrievers=[self.semantic_retriever, self.bm25_retriever],
                weights=[0.6, 0.4]
            )
        except Exception as e:
            log.warning(f"Could not create ensemble retriever: {e}. Using semantic search only.")
            return self.semantic_retriever
        log.info("Returning hybrid retriever (semantic + BM25 ensemble)")
        return ensemble

    def save_db_to_disk(self) -> bool:
        """No-op for Qdrant: points are stored on the Qdrant server.

        Document metadata is written as point payload together with the
        vectors, so there is nothing to flush locally.

        Returns:
            bool: Always True.
        """
        log.info("Qdrant automatically persists data on server. No manual save needed.")
        return True

    def get_collection_info(self) -> dict:
        """Get basic information about the current collection.

        Returns:
            dict: ``name``, ``vectors_count``, ``points_count`` and ``status``;
            an empty dict if the collection could not be fetched.
        """
        try:
            info = self.qdrant_client.get_collection(self.collection_name)
        except Exception as e:
            log.error(f"Failed to get collection info: {e}")
            return {}
        log.info(f"Retrieved collection info for '{self.collection_name}'.")
        return {
            "name": self.collection_name,
            # `vectors_count` is deprecated in newer qdrant-client releases;
            # read it defensively so this keeps working when it is absent.
            "vectors_count": getattr(info, "vectors_count", None),
            "points_count": info.points_count,
            "status": info.status,
        }

    def delete_by_filter(self, filter_dict: dict) -> bool:
        """Delete points whose payload matches every key/value in *filter_dict*.

        Each entry becomes an exact-match condition; all conditions must hold.

        NOTE(review): QdrantVectorStore stores Document metadata under a
        "metadata" payload key by default, so metadata fields may need a
        "metadata." key prefix (e.g. {"metadata.user_id": "test_user"}) —
        verify against callers.

        Args:
            filter_dict: Mapping of payload key to required value. Must be
                non-empty: an empty filter would match every point.

        Returns:
            bool: True if the delete request succeeded, False otherwise
            (including when *filter_dict* is empty).
        """
        if not filter_dict:
            # Filter(must=[]) has no conditions and matches every point;
            # refuse rather than wipe the whole collection.
            log.warning("Refusing to delete with an empty filter (it would match every point).")
            return False

        try:
            from qdrant_client.models import Filter, FieldCondition, MatchValue

            filter_obj = Filter(must=[
                FieldCondition(key=key, match=MatchValue(value=value))
                for key, value in filter_dict.items()
            ])

            self.qdrant_client.delete(
                collection_name=self.collection_name,
                points_selector=filter_obj
            )
        except Exception as e:
            log.error(f"Failed to delete documents: {e}")
            return False
        log.info(f"Deleted documents matching filter: {filter_dict}")
        return True

    def create_hybrid_retriever_wrapper(self, base_retriever):
        """Wrap a semantic retriever so BM25 keyword hits fill any free slots.

        Semantic results always come first. If they number fewer than
        ``retriever_k``, unique BM25 results are appended until ``retriever_k``
        documents are returned. Without a BM25 index, or if keyword search
        fails, the semantic results are returned unchanged.

        Args:
            base_retriever: A LangChain retriever (e.g. from ``get_retriever()``)
                providing ``invoke`` and ``ainvoke``.

        Returns:
            An object exposing ``invoke(query, **kwargs)`` and
            ``ainvoke(query, **kwargs)``.
        """
        def dedup_key(doc):
            # Prefer an explicit metadata id; fall back to the text so documents
            # without an id are not all collapsed onto a single ``None`` key.
            doc_id = doc.metadata.get('id')
            return ("id", doc_id) if doc_id is not None else ("content", doc.page_content)

        def merge(semantic_docs, keyword_docs):
            limit = self.retriever_k
            combined = list(semantic_docs)[:limit]
            seen = {dedup_key(doc) for doc in combined}
            for doc in keyword_docs:
                if len(combined) >= limit:
                    break
                key = dedup_key(doc)
                if key not in seen:
                    seen.add(key)
                    combined.append(doc)
            n_semantic = min(len(semantic_docs), limit)
            log.debug(f"Hybrid search returned {n_semantic} semantic + {len(combined) - n_semantic} keyword results")
            return combined

        def hybrid_invoke(query: str, **kwargs):
            semantic_docs = base_retriever.invoke(query, **kwargs)
            bm25 = self.bm25_retriever
            if bm25 is None:
                log.debug("BM25 not available, returning semantic results only")
                return semantic_docs
            try:
                return merge(semantic_docs, bm25.invoke(query))
            except Exception as e:
                log.warning(f"Error in hybrid search, falling back to semantic: {e}")
                return semantic_docs

        async def hybrid_ainvoke(query: str, **kwargs):
            # Await the retrievers' async APIs instead of running the blocking
            # sync path inside the event loop.
            semantic_docs = await base_retriever.ainvoke(query, **kwargs)
            bm25 = self.bm25_retriever
            if bm25 is None:
                log.debug("BM25 not available, returning semantic results only")
                return semantic_docs
            try:
                return merge(semantic_docs, await bm25.ainvoke(query))
            except Exception as e:
                log.warning(f"Error in hybrid search, falling back to semantic: {e}")
                return semantic_docs

        class HybridRetriever:
            """Minimal retriever-like object delegating to the hybrid functions."""

            def __init__(self, invoke_fn, ainvoke_fn):
                self.invoke_fn = invoke_fn
                self.ainvoke_fn = ainvoke_fn

            def invoke(self, query, **kwargs):
                return self.invoke_fn(query, **kwargs)

            async def ainvoke(self, query, **kwargs):
                return await self.ainvoke_fn(query, **kwargs)

        return HybridRetriever(hybrid_invoke, hybrid_ainvoke)