"""
ColPali vision-based document retrieval embeddings.
Uses ColPali to generate multi-vector embeddings from document images,
enabling visual understanding of tables, figures, and complex layouts.
"""
from __future__ import annotations  # defer annotation evaluation so ColPali type hints work even when the import below fails

import logging
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from PIL import Image

# Try to import ColPali - graceful fallback if not available
try:
    from colpali_engine.models import ColPali
    from colpali_engine.utils import process_images
    COLPALI_AVAILABLE = True
except ImportError:
    COLPALI_AVAILABLE = False

log = logging.getLogger(__name__)

# Global model instance for efficiency
_colpali_model: Optional[ColPali] = None
_device: Optional[torch.device] = None


def _get_device() -> torch.device:
    """Get optimal device (GPU if available, else CPU)."""
    global _device
    if _device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log.info(f"ColPali using device: {_device}")
    return _device


def _get_colpali_model() -> Optional[ColPali]:
    """Load ColPali model (cached for efficiency)."""
    global _colpali_model
    if not COLPALI_AVAILABLE:
        log.warning("ColPali not available. Install colpali-engine for vision embeddings.")
        return None
    if _colpali_model is None:
        try:
            log.info("Loading ColPali model...")
            device = _get_device()
            model = ColPali.from_pretrained(
                "vidore/colpali",
                torch_dtype=torch.bfloat16,
                device_map=device,
            )
            model.eval()
            _colpali_model = model
            log.info("✅ ColPali model loaded successfully")
        except Exception as e:
            log.error(f"Failed to load ColPali model: {e}")
            return None
    return _colpali_model


def embed_image(image_path: str) -> Optional[dict]:
    """
    Generate ColPali multi-vector embeddings from an image.

    Args:
        image_path: Path to the image file

    Returns:
        Dict with:
            - 'embeddings': List of patch embeddings (multi-vector)
            - 'num_patches': Number of patches
            - 'image_id': ID of the image
        Or None if embedding fails
    """
    if not COLPALI_AVAILABLE:
        return None

    try:
        model = _get_colpali_model()
        if model is None:
            return None

        # Load and prepare image
        image_path = Path(image_path)
        if not image_path.exists():
            log.warning(f"Image not found: {image_path}")
            return None
        image = Image.open(image_path).convert("RGB")

        # Process with ColPali
        with torch.no_grad():
            # process_images returns a list of processed images
            processed_images = process_images([image])

            # Get embeddings
            embeddings = model.encode_images(
                processed_images,
                batch_size=1,
            )

        # Convert to CPU and numpy for storage
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.cpu().numpy()

        # embeddings shape: (num_patches, embedding_dim)
        num_patches = len(embeddings) if isinstance(embeddings, (list, np.ndarray)) else 1
        log.debug(f"Generated {num_patches} patch embeddings for {image_path.name}")

        return {
            "embeddings": embeddings.tolist() if isinstance(embeddings, np.ndarray) else embeddings,
            "num_patches": num_patches,
            "image_id": image_path.stem,
            "model": "colpali",
        }
    except Exception as e:
        log.error(f"ColPali embedding failed for {image_path}: {e}")
        return None


def batch_embed_images(image_paths: list) -> dict:
    """
    Batch embed multiple images efficiently.

    Args:
        image_paths: List of image file paths

    Returns:
        Dict mapping image_id -> embedding result
    """
    if not COLPALI_AVAILABLE:
        return {}

    try:
        model = _get_colpali_model()
        if model is None:
            return {}

        # Load all valid images
        images = []
        valid_paths = []
        for path in image_paths:
            path = Path(path)
            if path.exists():
                try:
                    img = Image.open(path).convert("RGB")
                    images.append(img)
                    valid_paths.append(path)
                except Exception as e:
                    log.warning(f"Could not load image {path}: {e}")

        if not images:
            return {}

        log.info(f"Batch embedding {len(images)} images with ColPali...")

        # Process all images
        with torch.no_grad():
            processed_images = process_images(images)

            # Embed in batches (adjust batch_size based on GPU memory)
            embeddings = model.encode_images(
                processed_images,
                batch_size=4,
            )

        # Build results dict
        results = {}
        for path, emb in zip(valid_paths, embeddings):
            emb_np = emb.cpu().numpy() if isinstance(emb, torch.Tensor) else emb
            results[path.stem] = {
                "embeddings": emb_np.tolist(),
                "num_patches": len(emb_np),
                "image_id": path.stem,
                "model": "colpali",
            }

        log.info(f"✅ Batch embedded {len(results)} images")
        return results
    except Exception as e:
        log.error(f"Batch embedding failed: {e}")
        return {}


def is_colpali_available() -> bool:
    """Check if ColPali is available."""
    return COLPALI_AVAILABLE and _get_colpali_model() is not None
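

# Example usage (illustrative sketch only): exercises embed_image and
# batch_embed_images. The image paths below are hypothetical placeholders,
# and running this assumes colpali-engine and a suitable torch build are
# installed; adjust paths and logging to your environment.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    if is_colpali_available():
        # Single-image embedding (hypothetical path)
        single = embed_image("docs/page_001.png")
        if single:
            print(f"{single['image_id']}: {single['num_patches']} patch embeddings")

        # Batch embedding (hypothetical paths)
        batch = batch_embed_images(["docs/page_001.png", "docs/page_002.png"])
        print(f"Batch embedded {len(batch)} image(s)")
    else:
        print("ColPali is not available; install colpali-engine to enable vision embeddings.")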