""" Database Module for LLM System - Qdrant Vector Store

- Contains the `VectorDB` class to manage a vector database using Qdrant and embeddings.
- Supports both traditional text embeddings (Ollama) and multimodal embeddings (Jina v4).
- Provides methods to initialize the database, retrieve embeddings, and perform similarity searches.
- Replaces FAISS with Qdrant for production-ready vector storage.
"""

import asyncio
import os
from typing import Tuple, Optional, List

from langchain.retrievers.ensemble import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import ConfigurableField
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

from llm_system.config import QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY, EMB_MODEL_NAME, OLLAMA_BASE_URL, USE_JINA_EMBEDDINGS
from llm_system.core.llm import get_embeddings
from logger import get_logger

# Module-level logger shared by everything in this file.
log = get_logger(name="core_database")


class VectorDB:
    """Manage a Qdrant-backed vector store with flexible embeddings.

    On construction the instance picks an embeddings backend (Jina multimodal
    embeddings or Ollama), creates a Qdrant client, creates the target
    collection when it is missing (sized from a probe embedding, cosine
    distance) and builds a LangChain retriever over that collection.

    Args:
        embed_model (str): The name of the embeddings model to use (for Ollama fallback).
        retriever_num_docs (int): Number of documents to retrieve for similarity search.
        verify_connection (bool): Whether to verify the connection to the embeddings model.
        use_jina (bool, optional): Whether to use Jina multimodal embeddings. If None, uses config.
        qdrant_url (str, optional): Qdrant server URL. Defaults to config value.
        collection_name (str, optional): Qdrant collection name. Defaults to config value.
        api_key (str, optional): API key for Qdrant Cloud. Defaults to config value.

    Raises:
        RuntimeError: If `verify_connection` is True and the probe embedding
            fails, or if constructing the Qdrant client raises.

    ## Functions:
        + `get_embeddings()`: Returns the embeddings model (Jina or Ollama).
        + `get_vector_store()`: Returns the Qdrant vector store.
        + `get_retriever()`: Returns the retriever configured for similarity search.
        + `get_hybrid_retriever()`: Returns a semantic + BM25 ensemble when BM25 is available.
        + `save_db_to_disk()`: No-op for Qdrant (data persisted automatically on server).
        + `get_collection_info()`: Returns basic statistics about the collection.
        + `delete_by_filter()`: Deletes the points matching exact-value conditions.
        + `create_hybrid_retriever_wrapper()`: Wraps a retriever with a manual semantic + BM25 merge.
    """

    def __init__(
        self, embed_model: str,
        retriever_num_docs: int = 5,
        verify_connection: bool = False,
        use_jina: Optional[bool] = None,
        qdrant_url: Optional[str] = None,
        collection_name: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        # Explicit arguments win; falsy values fall back to the configuration.
        self.qdrant_url: str = qdrant_url or QDRANT_URL
        self.collection_name: str = collection_name or QDRANT_COLLECTION_NAME
        self.api_key: Optional[str] = api_key or QDRANT_API_KEY
        self.retriever_k: int = retriever_num_docs
        self.embed_model = embed_model

        log.info(
            f"Initializing VectorDB with embeddings='{embed_model}', Qdrant URL='{self.qdrant_url}', "
            f"collection='{self.collection_name}', k={retriever_num_docs} docs."
        )

        # --- Embeddings backend -------------------------------------------
        if use_jina is None:
            use_jina = USE_JINA_EMBEDDINGS

        if use_jina:
            log.info("Using Jina multimodal embeddings")
            self.embeddings = get_embeddings(use_jina=True)
        else:
            log.info(f"Using Ollama embeddings: {embed_model}")
            # Imported here so langchain_ollama is only needed on this path.
            from langchain_ollama import OllamaEmbeddings
            self.embeddings = OllamaEmbeddings(model=embed_model, base_url=OLLAMA_BASE_URL, num_gpu=0, keep_alive=-1)

        if verify_connection:
            try:
                # A one-item embedding call proves the backend is reachable.
                self.embeddings.embed_documents(['a'])
                log.info("Embeddings model initialized and verified.")
            except Exception as e:
                log.error(f"Failed to initialize Embeddings: {e}")
                raise RuntimeError("Couldn't initialize Embeddings model") from e
        else:
            log.warning("Embeddings initialized without connection verification.")

        # --- Qdrant client --------------------------------------------------
        # NOTE(review): the client constructor may not contact the server at
        # all; `get_collections()` further down looks like the first explicit
        # round trip and is not covered by this handler - confirm.
        try:
            self.qdrant_client = QdrantClient(
                url=self.qdrant_url,
                api_key=self.api_key,
                timeout=30,
            )
            log.info(f"Connected to Qdrant at '{self.qdrant_url}'.")
        except Exception as e:
            log.error(f"Failed to connect to Qdrant: {e}")
            raise RuntimeError(f"Couldn't connect to Qdrant at '{self.qdrant_url}'") from e

        # The vector size of a new collection is taken from a probe embedding.
        sample_embedding = self.embeddings.embed_query("test")
        embedding_dim = len(sample_embedding)
        log.info(f"Embedding dimension: {embedding_dim}")

        # --- Collection -----------------------------------------------------
        collections = self.qdrant_client.get_collections()
        collection_names = [c.name for c in collections.collections]

        if self.collection_name not in collection_names:
            log.info(f"Collection '{self.collection_name}' not found. Creating new collection.")
            self.qdrant_client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
            )
            log.info(f"Created collection '{self.collection_name}' with dimension {embedding_dim}.")
        else:
            # NOTE(review): an existing collection is reused without checking
            # that its vector size equals `embedding_dim`.
            log.info(f"Using existing collection '{self.collection_name}'.")

        self.db = QdrantVectorStore(
            client=self.qdrant_client,
            collection_name=self.collection_name,
            embedding=self.embeddings,
        )
        log.info("Qdrant vector store initialized.")

        # --- Retrievers -----------------------------------------------------
        semantic_retriever = self.db.as_retriever(
            search_kwargs={
                "k": self.retriever_k,
            }
        )

        # NOTE(review): the BM25 index is built from an empty document list;
        # this presumably fails (or yields an index with nothing to match), in
        # which case `bm25_retriever` stays None and every hybrid path falls
        # back to semantic search - confirm and feed real documents if hybrid
        # search is expected to work.
        try:
            bm25_retriever = BM25Retriever.from_documents([], k=self.retriever_k)
            self.bm25_retriever = bm25_retriever
            log.info("Initialized BM25 retriever for hybrid search")
        except Exception as e:
            log.warning(f"Could not initialize BM25 retriever: {e}. Using semantic search only.")
            self.bm25_retriever = None

        self.retriever = semantic_retriever
        self.semantic_retriever = semantic_retriever

    def get_embeddings(self) -> "Embeddings":
        """Returns the embeddings model in use (Jina or Ollama)."""
        log.info("Returning the Embeddings model instance.")
        return self.embeddings

    def get_vector_store(self) -> "VectorStore":
        """Returns the Qdrant vector store instance."""
        log.info("Returning the Qdrant vector store instance.")
        return self.db

    def get_retriever(self) -> "VectorStoreRetriever":
        """Returns the semantic retriever configured with `k = retriever_num_docs`."""
        log.info("Returning the retriever for similarity search.")
        return self.retriever

    def get_hybrid_retriever(self, use_hybrid: bool = True, weights: Optional[List[float]] = None):
        """Returns a hybrid retriever combining semantic and keyword search.

        Args:
            use_hybrid: If True, returns ensemble of semantic + BM25 retrieval.
                If False, returns semantic search only.
            weights: Optional `[semantic, keyword]` ensemble weights.
                Defaults to `[0.6, 0.4]`.

        Returns:
            An `EnsembleRetriever` over the semantic and BM25 retrievers, or
            the plain semantic retriever when hybrid search is disabled, BM25
            is unavailable, or the ensemble cannot be built.
        """
        if not use_hybrid or self.bm25_retriever is None:
            log.info("Returning semantic search only")
            return self.semantic_retriever

        ensemble_weights = [0.6, 0.4] if weights is None else list(weights)
        try:
            ensemble = EnsembleRetriever(
                retrievers=[self.semantic_retriever, self.bm25_retriever],
                weights=ensemble_weights,
            )
            log.info("Returning hybrid retriever (semantic + BM25 ensemble)")
            return ensemble
        except Exception as e:
            # Best effort: a broken ensemble must not take retrieval down.
            log.warning(f"Could not create ensemble retriever: {e}. Using semantic search only.")
            return self.semantic_retriever

    def save_db_to_disk(self) -> bool:
        """No-op for Qdrant (data is automatically persisted on the server).

        Per the original author's note, ColPali embeddings are stored in
        document metadata and persisted together with the documents in Qdrant
        (not verifiable from this module).

        Returns:
            bool: Always returns True as Qdrant handles persistence automatically.
        """
        log.info("Qdrant automatically persists data on server. No manual save needed.")
        return True

    def get_collection_info(self) -> dict:
        """Get information about the current collection.

        Returns:
            dict: `name`, `vectors_count`, `points_count` and `status` of the
            collection, or an empty dict when the lookup fails.
            `vectors_count` is None when the client model does not expose it.
        """
        try:
            info = self.qdrant_client.get_collection(self.collection_name)
            log.info(f"Retrieved collection info for '{self.collection_name}'.")
            return {
                "name": self.collection_name,
                # Read defensively: the field is reportedly deprecated and may
                # be missing from newer client models; a missing attribute
                # must not discard the remaining statistics.
                "vectors_count": getattr(info, "vectors_count", None),
                "points_count": info.points_count,
                "status": info.status,
            }
        except Exception as e:
            log.error(f"Failed to get collection info: {e}")
            return {}

    def delete_by_filter(self, filter_dict: dict) -> bool:
        """Delete documents matching the filter criteria.

        Every key/value pair becomes an exact-match condition and all
        conditions must hold (`must`).

        Args:
            filter_dict: Mapping of payload key to required value
                (e.g., {"user_id": "test_user"}).
                NOTE(review): the vector store may nest document metadata
                under a payload prefix; verify the keys callers pass.

        Returns:
            bool: True if the delete request was issued successfully, False on
            error or when `filter_dict` is empty.
        """
        if not filter_dict:
            # A filter without conditions places no constraint on the
            # selection and would wipe the whole collection.
            log.warning("Refusing to delete with an empty filter; no documents were deleted.")
            return False

        try:
            from qdrant_client.models import Filter, FieldCondition, MatchValue

            conditions = [
                FieldCondition(key=key, match=MatchValue(value=value))
                for key, value in filter_dict.items()
            ]
            filter_obj = Filter(must=conditions)

            self.qdrant_client.delete(
                collection_name=self.collection_name,
                points_selector=filter_obj,
            )
            log.info(f"Deleted documents matching filter: {filter_dict}")
            return True
        except Exception as e:
            log.error(f"Failed to delete documents: {e}")
            return False

    @staticmethod
    def _doc_key(doc):
        """Return the deduplication key of a document.

        Uses `metadata['id']` when present; otherwise falls back to the page
        content so that documents without an id do not all collide on None.
        """
        doc_id = doc.metadata.get('id')
        if doc_id is not None:
            return ("id", doc_id)
        return ("content", doc.page_content)

    @staticmethod
    def _merge_results(semantic_docs, keyword_docs, limit: int) -> list:
        """Interleave two ranked result lists, dropping duplicates.

        Documents are taken rank by rank (semantic first, then keyword) until
        `limit` unique documents are collected, so keyword hits contribute
        even when the semantic list alone would already fill the limit.
        """
        if limit <= 0:
            return []

        merged = []
        seen = set()
        rankings = (semantic_docs, keyword_docs)
        for rank in range(max(len(semantic_docs), len(keyword_docs))):
            for ranked in rankings:
                if rank >= len(ranked):
                    continue
                doc = ranked[rank]
                key = VectorDB._doc_key(doc)
                if key in seen:
                    continue
                seen.add(key)
                merged.append(doc)
                if len(merged) >= limit:
                    return merged
        return merged

    def create_hybrid_retriever_wrapper(self, base_retriever):
        """Create a hybrid retriever that combines semantic and keyword search.

        The wrapper runs `base_retriever` and, when a BM25 retriever is
        available, merges its keyword results into the semantic ones
        (rank-interleaved, de-duplicated, capped at `retriever_k`).

        Args:
            base_retriever: The base semantic search retriever from Qdrant

        Returns:
            An object exposing `invoke(query, **kwargs)` and
            `ainvoke(query, **kwargs)` that performs the hybrid search.
        """
        def hybrid_invoke(query: str, **kwargs):
            semantic_docs = base_retriever.invoke(query, **kwargs)

            if self.bm25_retriever is None:
                log.debug("BM25 not available, returning semantic results only")
                return semantic_docs

            try:
                keyword_docs = self.bm25_retriever.invoke(query)
                merged = self._merge_results(semantic_docs, keyword_docs, self.retriever_k)
                log.debug(
                    f"Hybrid search merged {len(semantic_docs)} semantic + "
                    f"{len(keyword_docs)} keyword candidates into {len(merged)} results"
                )
                return merged
            except Exception as e:
                # Keyword search is an enhancement; never fail the query on it.
                log.warning(f"Error in hybrid search, falling back to semantic: {e}")
                return semantic_docs

        class HybridRetriever:
            """Minimal retriever-like adapter around a search callable."""

            def __init__(self, invoke_fn):
                self.invoke_fn = invoke_fn

            def invoke(self, query, **kwargs):
                return self.invoke_fn(query, **kwargs)

            async def ainvoke(self, query, **kwargs):
                # The search is blocking; run it in the default executor so
                # the event loop stays responsive.
                loop = asyncio.get_running_loop()
                return await loop.run_in_executor(
                    None, lambda: self.invoke_fn(query, **kwargs)
                )

        return HybridRetriever(hybrid_invoke)