chat-with-your-data / server / llm_system / core / colpali_embeddings.py
"""
ColPali vision-based document retrieval embeddings.
Uses ColPali to generate multi-vector embeddings from document images,
enabling visual understanding of tables, figures, and complex layouts.
"""
from __future__ import annotations  # keep ColPali type hints unevaluated if the import below fails

import logging
from typing import Optional
from pathlib import Path

import numpy as np
from PIL import Image
import torch

# Try to import ColPali - graceful fallback if not available
try:
    from colpali_engine.models import ColPali
    from colpali_engine.utils import process_images
    COLPALI_AVAILABLE = True
except ImportError:
    COLPALI_AVAILABLE = False

log = logging.getLogger(__name__)

# Global model instance for efficiency
_colpali_model: Optional[ColPali] = None
_device: Optional[torch.device] = None

def _get_device() -> torch.device:
    """Get optimal device (GPU if available, else CPU)."""
    global _device
    if _device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log.info(f"ColPali using device: {_device}")
    return _device

def _get_colpali_model() -> Optional[ColPali]:
    """Load ColPali model (cached for efficiency)."""
    global _colpali_model
    if not COLPALI_AVAILABLE:
        log.warning("ColPali not available. Install colpali-engine for vision embeddings.")
        return None
    if _colpali_model is None:
        try:
            log.info("Loading ColPali model...")
            device = _get_device()
            model = ColPali.from_pretrained(
                "vidore/colpali",
                torch_dtype=torch.bfloat16,
                device_map=device
            )
            model.eval()
            _colpali_model = model
            log.info("✅ ColPali model loaded successfully")
        except Exception as e:
            log.error(f"Failed to load ColPali model: {e}")
            return None
    return _colpali_model

def embed_image(image_path: str) -> Optional[dict]:
    """
    Generate ColPali multi-vector embeddings from an image.

    Args:
        image_path: Path to the image file

    Returns:
        Dict with:
            - 'embeddings': List of patch embeddings (multi-vector)
            - 'num_patches': Number of patches
            - 'image_id': ID of the image
            - 'model': Name of the embedding model
        Or None if embedding fails
    """
    if not COLPALI_AVAILABLE:
        return None
    try:
        model = _get_colpali_model()
        if model is None:
            return None

        # Load and prepare image
        image_path = Path(image_path)
        if not image_path.exists():
            log.warning(f"Image not found: {image_path}")
            return None
        image = Image.open(image_path).convert("RGB")

        # Process with ColPali
        with torch.no_grad():
            # process_images returns a list of processed images
            processed_images = process_images([image])
            # Get embeddings
            embeddings = model.encode_images(
                processed_images,
                batch_size=1
            )

        # Convert to CPU and numpy for storage
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.cpu().numpy()
        # A single image may come back with a leading batch dimension; drop it so
        # the shape is (num_patches, embedding_dim).
        if isinstance(embeddings, np.ndarray) and embeddings.ndim == 3 and embeddings.shape[0] == 1:
            embeddings = embeddings[0]

        num_patches = len(embeddings) if isinstance(embeddings, (list, np.ndarray)) else 1
        log.debug(f"Generated {num_patches} patch embeddings for {image_path.name}")

        return {
            "embeddings": embeddings.tolist() if isinstance(embeddings, np.ndarray) else embeddings,
            "num_patches": num_patches,
            "image_id": image_path.stem,
            "model": "colpali"
        }
    except Exception as e:
        log.error(f"ColPali embedding failed for {image_path}: {e}")
        return None
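
# Illustrative helper (not part of the original pipeline): ColPali produces one
# vector per image patch, so retrieval is usually scored with late-interaction
# MaxSim -- each query-token vector is matched to its best page-patch vector and
# the maxima are summed. This sketch assumes the query embeddings come from the
# matching ColPali query encoder and arrive as a (num_query_tokens, dim) array;
# it is a minimal reference, not this module's retrieval path.
def maxsim_score(query_embeddings: np.ndarray, page_embeddings: np.ndarray) -> float:
    """Late-interaction MaxSim score between query and page multi-vectors."""
    # similarity[i, j] = dot(query token i, page patch j)
    similarity = np.asarray(query_embeddings) @ np.asarray(page_embeddings).T
    # Best-matching patch per query token, summed over tokens.
    return float(similarity.max(axis=1).sum())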

def batch_embed_images(image_paths: list) -> dict:
    """
    Batch embed multiple images efficiently.

    Args:
        image_paths: List of image file paths

    Returns:
        Dict mapping image_id -> embedding result
    """
    if not COLPALI_AVAILABLE:
        return {}
    try:
        model = _get_colpali_model()
        if model is None:
            return {}

        # Load all valid images
        images = []
        valid_paths = []
        for path in image_paths:
            path = Path(path)
            if path.exists():
                try:
                    img = Image.open(path).convert("RGB")
                    images.append(img)
                    valid_paths.append(path)
                except Exception as e:
                    log.warning(f"Could not load image {path}: {e}")
        if not images:
            return {}

        log.info(f"Batch embedding {len(images)} images with ColPali...")

        # Process all images
        with torch.no_grad():
            processed_images = process_images(images)
            # Embed in batches (adjust batch_size based on GPU memory)
            embeddings = model.encode_images(
                processed_images,
                batch_size=4
            )

        # Build results dict
        results = {}
        for path, emb in zip(valid_paths, embeddings):
            emb_np = emb.cpu().numpy() if isinstance(emb, torch.Tensor) else emb
            results[path.stem] = {
                "embeddings": emb_np.tolist(),
                "num_patches": len(emb_np),
                "image_id": path.stem,
                "model": "colpali"
            }
        log.info(f"✅ Batch embedded {len(results)} images")
        return results
    except Exception as e:
        log.error(f"Batch embedding failed: {e}")
        return {}
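
# Example (illustrative): rank embedded pages against a query's multi-vector with
# the maxsim_score() sketch above. `query_vectors` is assumed to come from the
# matching ColPali query encoder, which this module does not wrap.
#
#   pages = batch_embed_images(["page_001.png", "page_002.png"])
#   ranked = sorted(
#       pages.items(),
#       key=lambda item: maxsim_score(query_vectors, np.array(item[1]["embeddings"])),
#       reverse=True,
#   )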

def is_colpali_available() -> bool:
    """Check if ColPali is available (import succeeded and the model loads)."""
    return COLPALI_AVAILABLE and _get_colpali_model() is not None
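
# Minimal smoke test (a sketch: assumes colpali-engine is installed and that the
# sample image path below exists; both are placeholders, not deployment paths).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    if not is_colpali_available():
        log.warning("ColPali unavailable; skipping smoke test.")
    else:
        sample = "sample_page.png"  # hypothetical page image
        result = embed_image(sample)
        if result:
            log.info("Embedded %s into %d patch vectors", result["image_id"], result["num_patches"])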