|
|
""" |
|
|
ColPali vision-based document retrieval embeddings. |
|
|
|
|
|
Uses ColPali to generate multi-vector embeddings from document images, |
|
|
enabling visual understanding of tables, figures, and complex layouts. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
from typing import Optional |
|
|
from pathlib import Path |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torch |
|
|
|
|
|
|
|
|
# colpali-engine is an optional dependency: importing this module must succeed
# without it, and COLPALI_AVAILABLE records whether the import worked.
try:
    from colpali_engine.models import ColPali
    from colpali_engine.utils import process_images

    COLPALI_AVAILABLE = True
except ImportError:
    # Bind placeholders so that annotations mentioning ColPali (which Python
    # evaluates eagerly at import time) do not raise NameError when the
    # package is missing. Code paths that would call these are gated on
    # COLPALI_AVAILABLE.
    ColPali = None
    process_images = None
    COLPALI_AVAILABLE = False

log = logging.getLogger(__name__)

# Process-wide caches, populated lazily on first use.
_colpali_model: Optional["ColPali"] = None
_device: Optional[torch.device] = None
|
|
|
|
|
|
|
|
def _get_device() -> torch.device:
    """Return the torch device used for ColPali, choosing it on the first call.

    CUDA is selected when ``torch.cuda.is_available()`` reports True, otherwise
    CPU. The choice is stored in the module-level ``_device`` and logged once;
    later calls return the stored value.
    """
    global _device

    # Fast path: a device was already picked by an earlier call.
    if _device is not None:
        return _device

    backend = "cuda" if torch.cuda.is_available() else "cpu"
    _device = torch.device(backend)
    log.info(f"ColPali using device: {_device}")
    return _device
|
|
|
|
|
|
|
|
def _get_colpali_model() -> Optional["ColPali"]:
    """Return the shared ColPali model, loading it on first use.

    The model is created with ``ColPali.from_pretrained("vidore/colpali", ...)``
    in bfloat16 on the device returned by ``_get_device()``, switched to eval
    mode, and cached in the module-level ``_colpali_model``.

    The return annotation is a quoted forward reference on purpose: an
    unquoted ``ColPali`` is evaluated when the function is defined and would
    raise NameError at import time if colpali-engine is not installed.

    Returns:
        The cached model, or None when colpali-engine is unavailable or the
        load raised. A failed load is not cached, so the next call retries.
    """
    global _colpali_model

    if not COLPALI_AVAILABLE:
        log.warning("ColPali not available. Install colpali-engine for vision embeddings.")
        return None

    if _colpali_model is not None:
        return _colpali_model

    try:
        log.info("Loading ColPali model...")
        model = ColPali.from_pretrained(
            "vidore/colpali",
            torch_dtype=torch.bfloat16,
            device_map=_get_device(),
        )
        model.eval()
    except Exception as e:
        # Best-effort boundary: callers treat None as "vision embeddings
        # unavailable". Log with traceback so the root cause is not lost.
        log.exception(f"Failed to load ColPali model: {e}")
        return None

    _colpali_model = model
    log.info("✅ ColPali model loaded successfully")
    return _colpali_model
|
|
|
|
|
|
|
|
def embed_image(image_path: str) -> Optional[dict]:
    """
    Generate ColPali multi-vector embeddings from an image.

    Args:
        image_path: Path to the image file

    Returns:
        Dict with:
            - 'embeddings': List of patch embeddings (multi-vector)
            - 'num_patches': Number of patches
            - 'image_id': ID of the image (the file name without extension)
            - 'model': Always "colpali"
        Or None if ColPali is unavailable, the file does not exist, or
        embedding fails (failures are logged, not raised)
    """
    if not COLPALI_AVAILABLE:
        return None

    try:
        model = _get_colpali_model()
        if model is None:
            return None

        image_path = Path(image_path)
        if not image_path.exists():
            log.warning(f"Image not found: {image_path}")
            return None

        # convert() forces the pixel data to load and returns a new image, so
        # the underlying file can be closed as soon as the copy exists.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB")

        with torch.no_grad():
            # NOTE(review): process_images / encode_images are called exactly
            # as before; confirm both against the installed colpali-engine.
            processed_images = process_images([image])
            embeddings = model.encode_images(processed_images, batch_size=1)

        # Exactly one image went in. If the encoder hands back a batch of one
        # (a 1-element sequence, or an array with a leading axis of size 1),
        # keep only that image's patch matrix so 'num_patches' counts patches
        # rather than the batch size.
        # NOTE(review): assumes the first axis of the encoder output is the
        # batch axis, as batch_embed_images also does -- confirm.
        if (
            isinstance(embeddings, (list, tuple))
            and len(embeddings) == 1
            and isinstance(embeddings[0], (torch.Tensor, np.ndarray))
        ):
            embeddings = embeddings[0]

        if isinstance(embeddings, torch.Tensor):
            # NumPy has no bfloat16 dtype, so Tensor.numpy() raises TypeError
            # on bfloat16 tensors; upcast before converting.
            if embeddings.dtype == torch.bfloat16:
                embeddings = embeddings.to(torch.float32)
            embeddings = embeddings.detach().cpu().numpy()

        if (
            isinstance(embeddings, np.ndarray)
            and embeddings.ndim == 3
            and embeddings.shape[0] == 1
        ):
            embeddings = embeddings[0]

        num_patches = len(embeddings) if isinstance(embeddings, (list, np.ndarray)) else 1

        log.debug(f"Generated {num_patches} patch embeddings for {image_path.name}")

        return {
            "embeddings": embeddings.tolist() if isinstance(embeddings, np.ndarray) else embeddings,
            "num_patches": num_patches,
            "image_id": image_path.stem,
            "model": "colpali",
        }

    except Exception as e:
        # Best-effort API: report the failure (with traceback) and return None.
        log.exception(f"ColPali embedding failed for {image_path}: {e}")
        return None
|
|
|
|
|
|
|
|
def batch_embed_images(image_paths: list, *, batch_size: int = 4) -> dict:
    """
    Batch embed multiple images efficiently.

    Paths that do not exist or cannot be opened as images are skipped with a
    warning; the remaining images are encoded together.

    Args:
        image_paths: List of image file paths
        batch_size: Value forwarded to the encoder's ``batch_size`` argument
            (defaults to 4, the previously hard-coded value)

    Returns:
        Dict mapping image_id -> embedding result, where image_id is the file
        name without extension and each result has the keys 'embeddings',
        'num_patches', 'image_id' and 'model'. Inputs that share a file stem
        collide on the same key; the later one wins and a warning is logged.
        An empty dict is returned if ColPali is unavailable, no image could be
        loaded, or embedding fails (failures are logged, not raised)
    """
    if not COLPALI_AVAILABLE:
        return {}

    try:
        model = _get_colpali_model()
        if model is None:
            return {}

        images = []
        valid_paths = []

        for path in image_paths:
            path = Path(path)
            if not path.exists():
                log.warning(f"Image not found: {path}")
                continue
            try:
                # convert() loads the pixels into a new image, so the file
                # handle can be released immediately.
                with Image.open(path) as raw:
                    images.append(raw.convert("RGB"))
            except Exception as e:
                log.warning(f"Could not load image {path}: {e}")
                continue
            # Only record the path once its image was appended, keeping the
            # two lists aligned for the zip() below.
            valid_paths.append(path)

        if not images:
            return {}

        log.info(f"Batch embedding {len(images)} images with ColPali...")

        with torch.no_grad():
            # NOTE(review): process_images / encode_images are called exactly
            # as before; confirm both against the installed colpali-engine.
            processed_images = process_images(images)
            embeddings = model.encode_images(processed_images, batch_size=batch_size)

        results = {}
        # NOTE(review): assumes iterating the encoder output yields one item
        # per input image, in input order -- confirm.
        for path, emb in zip(valid_paths, embeddings):
            if isinstance(emb, torch.Tensor):
                # NumPy has no bfloat16 dtype, so Tensor.numpy() raises
                # TypeError on bfloat16 tensors; upcast before converting.
                if emb.dtype == torch.bfloat16:
                    emb = emb.to(torch.float32)
                emb = emb.detach().cpu().numpy()
            emb_np = np.asarray(emb)

            if path.stem in results:
                log.warning(
                    f"Duplicate image_id '{path.stem}' from {path}; overwriting earlier result"
                )

            results[path.stem] = {
                "embeddings": emb_np.tolist(),
                "num_patches": len(emb_np),
                "image_id": path.stem,
                "model": "colpali",
            }

        log.info(f"✅ Batch embedded {len(results)} images")
        return results

    except Exception as e:
        # Best-effort API: report the failure (with traceback) and return {}.
        log.exception(f"Batch embedding failed: {e}")
        return {}
|
|
|
|
|
|
|
|
def is_colpali_available() -> bool:
    """Report whether ColPali embeddings can be produced.

    Returns False immediately when the colpali-engine import failed.
    Otherwise asks ``_get_colpali_model()`` for the model, so the first call
    may trigger the actual model load, and returns True only if a model came
    back.
    """
    if not COLPALI_AVAILABLE:
        return False
    loaded_model = _get_colpali_model()
    return loaded_model is not None
|
|
|