# sanchitshaleen
# Initial deployment of RAG with Gemma-3 to Hugging Face Spaces
# 4aec76b
""" Database Module for LLM System - Qdrant Vector Store
- Contains the `VectorDB` class to manage a vector database using Qdrant and embeddings.
- Supports both traditional text embeddings (Ollama) and multimodal embeddings (Jina v4).
- Provides methods to initialize the database, retrieve embeddings, and perform similarity searches.
- Replaces FAISS with Qdrant for production-ready vector storage
"""
import os
from typing import Tuple, Optional, List
from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore
from langchain_core.runnables import ConfigurableField
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
# For type hinting
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStoreRetriever
# For hybrid search
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers.ensemble import EnsembleRetriever
# config:
from llm_system.config import QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY, EMB_MODEL_NAME, OLLAMA_BASE_URL, USE_JINA_EMBEDDINGS
# Import embeddings factory
from llm_system.core.llm import get_embeddings
from logger import get_logger
log = get_logger(name="core_database")


class VectorDB:
    """Manage a Qdrant-backed vector store with pluggable embeddings.

    On construction the class builds the embeddings model (Jina via the
    project's `get_embeddings` factory, or `OllamaEmbeddings`), connects to
    Qdrant, creates the collection if it does not exist yet, and prepares a
    similarity-search retriever.

    Args:
        embed_model (str): The name of the embeddings model to use (for Ollama fallback).
        retriever_num_docs (int): Number of documents to retrieve for similarity search.
        verify_connection (bool): Whether to verify the connection to the embeddings model.
        use_jina (bool, optional): Whether to use Jina multimodal embeddings. If None, uses config.
        qdrant_url (str, optional): Qdrant server URL. Defaults to config value.
        collection_name (str, optional): Qdrant collection name. Defaults to config value.
        api_key (str, optional): API key for Qdrant Cloud. Defaults to config value.

    ## Functions:
        + `get_embeddings()`: Returns the embeddings model (Jina or Ollama).
        + `get_vector_store()`: Returns the Qdrant vector store.
        + `get_retriever()`: Returns the retriever configured for similarity search.
        + `get_hybrid_retriever()`: Returns a semantic + BM25 ensemble when BM25 is available.
        + `save_db_to_disk()`: No-op for Qdrant (data persisted automatically on server).
        + `get_collection_info()`: Returns basic statistics about the collection.
        + `delete_by_filter()`: Deletes the points matching a simple equality filter.
        + `create_hybrid_retriever_wrapper()`: Wraps a retriever to append BM25 results.
    """

    def __init__(
        self, embed_model: str,
        retriever_num_docs: int = 5,
        verify_connection: bool = False,
        use_jina: Optional[bool] = None,
        qdrant_url: Optional[str] = None,
        collection_name: Optional[str] = None,
        api_key: Optional[str] = None
    ):
        """Build the embeddings model, connect to Qdrant and set up retrievers.

        Raises:
            RuntimeError: If `verify_connection` is set and the embeddings model
                cannot embed a probe document, or if the Qdrant client cannot
                be constructed.
        """
        # Explicit arguments win; otherwise fall back to the configured values.
        self.qdrant_url: str = qdrant_url or QDRANT_URL
        self.collection_name: str = collection_name or QDRANT_COLLECTION_NAME
        self.api_key: Optional[str] = api_key or QDRANT_API_KEY
        self.retriever_k: int = retriever_num_docs
        self.embed_model = embed_model

        log.info(
            f"Initializing VectorDB with embeddings='{embed_model}', Qdrant URL='{self.qdrant_url}', "
            f"collection='{self.collection_name}', k={retriever_num_docs} docs."
        )

        # Initialize embeddings (Jina multimodal or Ollama fallback)
        if use_jina is None:
            use_jina = USE_JINA_EMBEDDINGS

        if use_jina:
            log.info("Using Jina multimodal embeddings")
            self.embeddings = get_embeddings(use_jina=True)
        else:
            log.info(f"Using Ollama embeddings: {embed_model}")
            # Imported lazily so the Ollama package is only needed on this path.
            from langchain_ollama import OllamaEmbeddings
            self.embeddings = OllamaEmbeddings(
                model=embed_model, base_url=OLLAMA_BASE_URL, num_gpu=0, keep_alive=-1
            )

        if verify_connection:
            try:
                # Embedding a throwaway document proves the model is reachable.
                self.embeddings.embed_documents(['a'])
                log.info("Embeddings model initialized and verified.")
            except Exception as e:
                log.error(f"Failed to initialize Embeddings: {e}")
                raise RuntimeError("Couldn't initialize Embeddings model") from e
        else:
            log.warning("Embeddings initialized without connection verification.")

        # Initialize Qdrant client
        try:
            self.qdrant_client = QdrantClient(
                url=self.qdrant_url,
                api_key=self.api_key,
                timeout=30
            )
            log.info(f"Connected to Qdrant at '{self.qdrant_url}'.")
        except Exception as e:
            log.error(f"Failed to connect to Qdrant: {e}")
            raise RuntimeError(f"Couldn't connect to Qdrant at '{self.qdrant_url}'") from e

        # The vector size of a new collection must match the embeddings model,
        # so measure it from a real embedding instead of hard-coding it.
        sample_embedding = self.embeddings.embed_query("test")
        embedding_dim = len(sample_embedding)
        log.info(f"Embedding dimension: {embedding_dim}")

        # Check if collection exists, create if not
        collections = self.qdrant_client.get_collections()
        collection_names = [c.name for c in collections.collections]

        if self.collection_name not in collection_names:
            log.info(f"Collection '{self.collection_name}' not found. Creating new collection.")
            self.qdrant_client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
            )
            log.info(f"Created collection '{self.collection_name}' with dimension {embedding_dim}.")
        else:
            log.info(f"Using existing collection '{self.collection_name}'.")

        # Initialize Qdrant vector store
        self.db = QdrantVectorStore(
            client=self.qdrant_client,
            collection_name=self.collection_name,
            embedding=self.embeddings,
        )
        log.info("Qdrant vector store initialized.")

        # Create semantic search retriever (similarity)
        semantic_retriever = self.db.as_retriever(
            search_kwargs={
                "k": self.retriever_k,
            }
        )

        # Create BM25 keyword-based retriever for hybrid search.
        # NOTE(review): the retriever is built from an EMPTY corpus, so it can
        # never return a document, and the underlying BM25 implementation may
        # even reject an empty corpus (in which case the except branch runs and
        # hybrid search is silently disabled). It presumably needs to be built
        # from the ingested documents - confirm against the ingestion code.
        try:
            bm25_retriever = BM25Retriever.from_documents([], k=self.retriever_k)
            self.bm25_retriever = bm25_retriever
            log.info("Initialized BM25 retriever for hybrid search")
        except Exception as e:
            log.warning(f"Could not initialize BM25 retriever: {e}. Using semantic search only.")
            self.bm25_retriever = None

        self.retriever = semantic_retriever
        self.semantic_retriever = semantic_retriever

    def get_embeddings(self) -> "Embeddings":
        """Returns the embeddings model in use (Jina or Ollama)."""
        log.info("Returning the Embeddings model instance.")
        return self.embeddings

    def get_vector_store(self) -> "VectorStore":
        """Returns the Qdrant vector store instance."""
        log.info("Returning the Qdrant vector store instance.")
        return self.db

    def get_retriever(self) -> "VectorStoreRetriever":
        """Returns the retriever for similarity search."""
        log.info("Returning the retriever for similarity search.")
        return self.retriever  # type: ignore[return-value]

    def get_hybrid_retriever(self, use_hybrid: bool = True):
        """Returns a hybrid retriever combining semantic and keyword search.

        Args:
            use_hybrid: If True, returns ensemble of semantic + BM25 retrieval.
                If False, returns semantic search only.

        Returns:
            An `EnsembleRetriever` weighting semantic results 0.6 and BM25
            results 0.4, or the plain semantic retriever when hybrid search is
            disabled, BM25 is unavailable, or the ensemble cannot be built.
        """
        if not use_hybrid or self.bm25_retriever is None:
            log.info("Returning semantic search only")
            return self.semantic_retriever

        try:
            ensemble = EnsembleRetriever(
                retrievers=[self.semantic_retriever, self.bm25_retriever],
                weights=[0.6, 0.4]
            )
            log.info("Returning hybrid retriever (semantic + BM25 ensemble)")
            return ensemble
        except Exception as e:
            # Best effort: hybrid search is an enhancement, never a hard requirement.
            log.warning(f"Could not create ensemble retriever: {e}. Using semantic search only.")
            return self.semantic_retriever

    def save_db_to_disk(self) -> bool:
        """No-op for Qdrant (data is automatically persisted on the server).

        Performs no work beyond logging.

        Returns:
            bool: Always returns True as Qdrant handles persistence automatically.
        """
        log.info("Qdrant automatically persists data on server. No manual save needed.")
        return True

    def get_collection_info(self) -> dict:
        """Get information about the current collection.

        Returns:
            dict: `name`, `vectors_count`, `points_count` and `status` of the
            collection, or an empty dict if the lookup fails. `vectors_count`
            is None when the client response does not carry that field.
        """
        try:
            info = self.qdrant_client.get_collection(self.collection_name)
            log.info(f"Retrieved collection info for '{self.collection_name}'.")
            return {
                "name": self.collection_name,
                # Read defensively: `vectors_count` is reportedly deprecated and
                # may be missing in newer qdrant-client releases (TODO confirm).
                # A direct attribute access would then discard the whole
                # result through the except branch below.
                "vectors_count": getattr(info, "vectors_count", None),
                "points_count": info.points_count,
                "status": info.status,
            }
        except Exception as e:
            log.error(f"Failed to get collection info: {e}")
            return {}

    def delete_by_filter(self, filter_dict: dict) -> bool:
        """Delete documents matching the filter criteria.

        Every key/value pair becomes an exact-match condition and all
        conditions must hold (logical AND).

        Args:
            filter_dict: Mapping of payload key to required value
                (e.g., {"user_id": "test_user"}).
                NOTE(review): keys are passed to Qdrant unchanged; if documents
                are written through the LangChain store the payload key may
                need a prefix such as "metadata." - confirm against the caller.

        Returns:
            bool: True if deletion was successful. False if the request failed
            or if `filter_dict` is empty.
        """
        if not filter_dict:
            # An empty `must` list places no constraint, so it would presumably
            # select every point in the collection. Refuse instead of wiping it.
            log.warning("Refusing to delete with an empty filter: it would match every document.")
            return False

        try:
            from qdrant_client.models import Filter, FieldCondition, MatchValue

            # Convert simple dict to Qdrant filter
            conditions = [
                FieldCondition(key=key, match=MatchValue(value=value))
                for key, value in filter_dict.items()
            ]
            filter_obj = Filter(must=conditions)

            self.qdrant_client.delete(
                collection_name=self.collection_name,
                points_selector=filter_obj
            )
            log.info(f"Deleted documents matching filter: {filter_dict}")
            return True
        except Exception as e:
            log.error(f"Failed to delete documents: {e}")
            return False

    @staticmethod
    def _doc_key(doc):
        """Return a hashable identity for `doc`, used to drop duplicate results.

        Uses `metadata['id']` when present and falls back to the page content,
        so documents without an id are not all collapsed onto one `None` key.
        """
        doc_id = doc.metadata.get('id')
        if doc_id is not None:
            return ("id", doc_id)
        return ("content", doc.page_content)

    def create_hybrid_retriever_wrapper(self, base_retriever):
        """Create a hybrid retriever that combines semantic and keyword search.

        The wrapper returns the semantic results first and then fills any
        remaining slots (up to `retriever_k`) with BM25 results that are not
        already present.

        Args:
            base_retriever: The base semantic search retriever; only its
                `invoke(query, **kwargs)` method is used.

        Returns:
            An object exposing `invoke` and `ainvoke` that performs the hybrid
            search. It is a plain wrapper, not a LangChain retriever subclass.
        """
        def hybrid_invoke(query: str, **kwargs):
            # Get semantic results
            semantic_docs = base_retriever.invoke(query, **kwargs)

            # Try to enhance with keyword search if BM25 is available
            if self.bm25_retriever is None:
                log.debug("BM25 not available, returning semantic results only")
                return semantic_docs

            try:
                # Get keyword results
                keyword_docs = self.bm25_retriever.invoke(query)

                # Combine and deduplicate (prioritize semantic, add unique keyword results)
                combined = list(semantic_docs)
                seen_keys = {self._doc_key(doc) for doc in combined}
                added = 0
                for doc in keyword_docs:
                    # Check capacity BEFORE appending so the count logged below
                    # only includes keyword documents that are actually returned.
                    if len(combined) >= self.retriever_k:
                        break
                    key = self._doc_key(doc)
                    if key in seen_keys:
                        continue
                    seen_keys.add(key)
                    combined.append(doc)
                    added += 1

                log.debug(f"Hybrid search returned {len(semantic_docs)} semantic + {added} keyword results")
                return combined[:self.retriever_k]
            except Exception as e:
                log.warning(f"Error in hybrid search, falling back to semantic: {e}")
                return semantic_docs

        # Create a simple wrapper class
        class HybridRetriever:
            """Expose `hybrid_invoke` through sync and async entry points."""

            def __init__(self, invoke_fn):
                self.invoke_fn = invoke_fn

            def invoke(self, query, **kwargs):
                return self.invoke_fn(query, **kwargs)

            async def ainvoke(self, query, **kwargs):
                # NOTE(review): runs the synchronous search inline, so the
                # calling coroutine does not yield until it finishes.
                return self.invoke_fn(query, **kwargs)

        return HybridRetriever(hybrid_invoke)