| | """SHARP inference utilities (PLY export only). |
| | |
| | This module intentionally does *not* implement MP4/video rendering. |
| | It provides a small, Spaces/ZeroGPU-friendly wrapper that: |
| | - Caches model weights and predictor construction across requests. |
| | - Runs SHARP inference and exports a canonical `.ply`. |
| | |
| | Public API (used by the Gradio app): |
| | - predict_to_ply_gpu(...) |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import os |
| | import threading |
| | import time |
| | import uuid |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Final |
| |
|
| | import torch |
| |
|
| | try: |
| | import spaces |
| | except Exception: |
| | spaces = None |
| |
|
| | try: |
| | |
| | from huggingface_hub import hf_hub_download, try_to_load_from_cache |
| | except Exception: |
| | hf_hub_download = None |
| | try_to_load_from_cache = None |
| |
|
| | from sharp.cli.predict import DEFAULT_MODEL_URL, predict_image |
| | from sharp.models import PredictorParams, create_predictor |
| | from sharp.utils import io |
| | from sharp.utils.gaussians import save_ply |
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def _now_ms() -> int: |
| | return int(time.time() * 1000) |
| |
|
| |
|
| | def _ensure_dir(path: Path) -> Path: |
| | path.mkdir(parents=True, exist_ok=True) |
| | return path |
| |
|
| |
|
| | def _make_even(x: int) -> int: |
| | return x if x % 2 == 0 else x + 1 |
| |
|
| |
|
| | def _select_device(preference: str = "auto") -> torch.device: |
| | """Select the best available device for inference (CPU/CUDA/MPS).""" |
| | if preference not in {"auto", "cpu", "cuda", "mps"}: |
| | raise ValueError("device preference must be one of: auto|cpu|cuda|mps") |
| |
|
| | if preference == "cpu": |
| | return torch.device("cpu") |
| | if preference == "cuda": |
| | return torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| | if preference == "mps": |
| | return torch.device("mps" if torch.backends.mps.is_available() else "cpu") |
| |
|
| | |
| | if torch.cuda.is_available(): |
| | return torch.device("cuda") |
| | if torch.backends.mps.is_available(): |
| | return torch.device("mps") |
| | return torch.device("cpu") |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | @dataclass(frozen=True, slots=True) |
| | class PredictionOutputs: |
| | """Outputs of SHARP inference.""" |
| |
|
| | ply_path: Path |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class ModelWrapper:
    """Cached SHARP model wrapper for Gradio/Spaces.

    Memoizes the checkpoint state dict and the constructed predictor across
    requests so repeated calls do not re-download weights or rebuild the
    model. All cache mutation is guarded by a re-entrant lock, so a single
    instance can be shared across request threads.
    """

    def __init__(
        self,
        *,
        outputs_dir: str | Path = "outputs",
        checkpoint_url: str = DEFAULT_MODEL_URL,
        checkpoint_path: str | Path | None = None,
        device_preference: str = "auto",
        keep_model_on_device: bool | None = None,
        hf_repo_id: str | None = None,
        hf_filename: str | None = None,
        hf_revision: str | None = None,
    ) -> None:
        """Configure checkpoint sources, output directory, and device policy.

        Args:
            outputs_dir: Directory where exported ``.ply`` files are written
                (created if missing).
            checkpoint_url: Last-resort checkpoint URL (torch hub download).
            checkpoint_path: Optional local checkpoint path; takes priority
                over the ``SHARP_CHECKPOINT_PATH`` / ``SHARP_CHECKPOINT``
                environment variables.
            device_preference: ``auto|cpu|cuda|mps`` (see ``_select_device``).
            keep_model_on_device: If True, leave weights on the accelerator
                between requests; ``None`` defers to the
                ``SHARP_KEEP_MODEL_ON_DEVICE=1`` environment variable.
            hf_repo_id: Hugging Face repo id; defaults via ``SHARP_HF_REPO_ID``.
            hf_filename: Checkpoint filename on the Hub; defaults via
                ``SHARP_HF_FILENAME``.
            hf_revision: Optional Hub revision; defaults via
                ``SHARP_HF_REVISION``.
        """
        self.outputs_dir = _ensure_dir(Path(outputs_dir))
        self.checkpoint_url = checkpoint_url

        # Explicit argument wins over either environment variable.
        env_ckpt = os.getenv("SHARP_CHECKPOINT_PATH") or os.getenv("SHARP_CHECKPOINT")
        if checkpoint_path:
            self.checkpoint_path = Path(checkpoint_path)
        elif env_ckpt:
            self.checkpoint_path = Path(env_ckpt)
        else:
            self.checkpoint_path = None

        # Hugging Face Hub coordinates, overridable via environment.
        self.hf_repo_id = hf_repo_id or os.getenv("SHARP_HF_REPO_ID", "apple/Sharp")
        self.hf_filename = hf_filename or os.getenv(
            "SHARP_HF_FILENAME", "sharp_2572gikvuh.pt"
        )
        self.hf_revision = hf_revision or os.getenv("SHARP_HF_REVISION") or None

        self.device_preference = device_preference

        # Device retention policy: argument wins; otherwise opt in via env.
        if keep_model_on_device is None:
            self.keep_model_on_device = (
                os.getenv("SHARP_KEEP_MODEL_ON_DEVICE") == "1"
            )
        else:
            self.keep_model_on_device = keep_model_on_device

        # Lazily populated, lock-protected caches.
        self._lock = threading.RLock()
        self._predictor: torch.nn.Module | None = None
        self._predictor_device: torch.device | None = None
        self._state_dict: dict | None = None

    def has_cuda(self) -> bool:
        """Return True when a CUDA device is visible to torch."""
        return torch.cuda.is_available()

    def _load_state_dict(self) -> dict:
        """Load (and cache) the checkpoint state dict, trying sources in order.

        Order: explicit/env local path, HF local cache, HF Hub (offline probe
        then real download), finally the plain checkpoint URL via torch hub.

        Raises:
            RuntimeError: when a configured local path fails to load, or when
                every remote source fails (the message lists each attempt).
        """
        with self._lock:
            if self._state_dict is not None:
                return self._state_dict

            # 1) Explicit local checkpoint (argument or env var). A failure
            #    here is a configuration error, so fail loudly — no fallback.
            if self.checkpoint_path is not None:
                try:
                    self._state_dict = torch.load(
                        self.checkpoint_path,
                        weights_only=True,
                        map_location="cpu",
                    )
                    return self._state_dict
                except Exception as e:
                    raise RuntimeError(
                        "Failed to load SHARP checkpoint from local path.\n\n"
                        f"Path:\n {self.checkpoint_path}\n\n"
                        f"Original error:\n {type(e).__name__}: {e}"
                    ) from e

            # 2) HF cache (e.g. pre-seeded by Spaces' preload_from_hub).
            hf_cache_error: Exception | None = None
            if try_to_load_from_cache is not None:
                try:
                    cached = try_to_load_from_cache(
                        repo_id=self.hf_repo_id,
                        filename=self.hf_filename,
                        revision=self.hf_revision,
                        repo_type="model",
                    )
                except TypeError:
                    # Older huggingface_hub: positional-only signature.
                    cached = try_to_load_from_cache(self.hf_repo_id, self.hf_filename)

                try:
                    if isinstance(cached, str) and Path(cached).exists():
                        self._state_dict = torch.load(
                            cached, weights_only=True, map_location="cpu"
                        )
                        return self._state_dict
                except Exception as e:
                    hf_cache_error = e

            # 3) HF Hub download: try an offline cache probe first (when the
            #    installed client supports local_files_only), then a real
            #    network download.
            hf_error: Exception | None = None
            if hf_hub_download is not None:
                try:
                    import inspect

                    if "local_files_only" in inspect.signature(hf_hub_download).parameters:
                        ckpt_path = hf_hub_download(
                            repo_id=self.hf_repo_id,
                            filename=self.hf_filename,
                            revision=self.hf_revision,
                            local_files_only=True,
                        )
                        if Path(ckpt_path).exists():
                            self._state_dict = torch.load(
                                ckpt_path, weights_only=True, map_location="cpu"
                            )
                            return self._state_dict
                except Exception:
                    # Cache miss or unsupported probe — fall through to the
                    # online download below.
                    pass

                try:
                    ckpt_path = hf_hub_download(
                        repo_id=self.hf_repo_id,
                        filename=self.hf_filename,
                        revision=self.hf_revision,
                    )
                    self._state_dict = torch.load(
                        ckpt_path,
                        weights_only=True,
                        map_location="cpu",
                    )
                    return self._state_dict
                except Exception as e:
                    hf_error = e

            # 4) Last resort: plain URL download through torch hub's cache.
            url_error: Exception | None = None
            try:
                self._state_dict = torch.hub.load_state_dict_from_url(
                    self.checkpoint_url,
                    progress=True,
                    map_location="cpu",
                )
                return self._state_dict
            except Exception as e:
                url_error = e

            # Every source failed — build one actionable error message that
            # records each attempt and its failure.
            hint_lines = [
                "Failed to load SHARP checkpoint.",
                "",
                "Tried (in order):",
                f" 1) HF cache (preload_from_hub): repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
                f" 2) HF Hub download: repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
                f" 3) URL (torch hub): {self.checkpoint_url}",
                "",
                "If network access is restricted, set a local checkpoint path:",
                " - SHARP_CHECKPOINT_PATH=/path/to/sharp_2572gikvuh.pt",
                "",
                "Original errors:",
            ]
            if try_to_load_from_cache is None:
                hint_lines.append(" HF cache: huggingface_hub not installed")
            elif hf_cache_error is not None:
                hint_lines.append(
                    f" HF cache: {type(hf_cache_error).__name__}: {hf_cache_error}"
                )
            else:
                hint_lines.append(" HF cache: (not found in cache)")

            if hf_hub_download is None:
                hint_lines.append(" HF download: huggingface_hub not installed")
            else:
                hint_lines.append(f" HF download: {type(hf_error).__name__}: {hf_error}")

            hint_lines.append(f" URL: {type(url_error).__name__}: {url_error}")

            raise RuntimeError("\n".join(hint_lines))

    def _get_predictor(self, device: torch.device) -> torch.nn.Module:
        """Return the cached predictor, building it on first use and moving
        it to *device* if it is not already there."""
        with self._lock:
            if self._predictor is None:
                state_dict = self._load_state_dict()
                predictor = create_predictor(PredictorParams())
                predictor.load_state_dict(state_dict)
                predictor.eval()
                self._predictor = predictor
                # Weights are loaded with map_location="cpu" above.
                self._predictor_device = torch.device("cpu")

            assert self._predictor is not None
            assert self._predictor_device is not None

            if self._predictor_device != device:
                self._predictor.to(device)
                self._predictor_device = device

            return self._predictor

    def _maybe_move_model_back_to_cpu(self) -> None:
        """Move the predictor off the accelerator unless configured to stay.

        Frees accelerator memory between requests (important on shared
        ZeroGPU hardware); also empties the CUDA allocator cache.
        """
        if self.keep_model_on_device:
            return
        with self._lock:
            if self._predictor is not None and self._predictor_device is not None:
                if self._predictor_device.type != "cpu":
                    self._predictor.to("cpu")
                    self._predictor_device = torch.device("cpu")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

    def _make_output_stem(self, input_path: Path) -> str:
        """Build a collision-resistant output stem: name + timestamp + nonce."""
        return f"{input_path.stem}-{_now_ms()}-{uuid.uuid4().hex[:8]}"

    def predict_to_ply(self, image_path: str | Path) -> PredictionOutputs:
        """Run SHARP inference on an image and export a canonical ``.ply``.

        Args:
            image_path: Path to the input RGB image.

        Returns:
            PredictionOutputs holding the exported ``.ply`` path.

        Raises:
            FileNotFoundError: if *image_path* does not exist.
        """
        image_path = Path(image_path)
        if not image_path.exists():
            raise FileNotFoundError(f"Image does not exist: {image_path}")

        device = _select_device(self.device_preference)
        predictor = self._get_predictor(device)

        try:
            image_np, _, f_px = io.load_rgb(image_path)
            height, width = image_np.shape[:2]

            with torch.no_grad():
                gaussians = predict_image(predictor, image_np, f_px, device)

            stem = self._make_output_stem(image_path)
            ply_path = self.outputs_dir / f"{stem}.ply"
            save_ply(gaussians, f_px, (height, width), ply_path)
        finally:
            # Bug fix: previously this ran only on success, so a failure in
            # inference/export left the model parked on the accelerator.
            # Always release the device (no-op when keep_model_on_device).
            self._maybe_move_model_back_to_cpu()

        return PredictionOutputs(ply_path=ply_path)
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
# Default export directory, located next to this module and created eagerly
# at import time.
DEFAULT_OUTPUTS_DIR: Final[Path] = _ensure_dir(Path(__file__).resolve().parent / "outputs")

# Process-wide lazily constructed ModelWrapper; the init lock ensures that
# concurrent first calls to get_global_model() build it exactly once.
_GLOBAL_MODEL: ModelWrapper | None = None
_GLOBAL_MODEL_INIT_LOCK: Final[threading.Lock] = threading.Lock()
| |
|
| |
|
def get_global_model(*, outputs_dir: str | Path = DEFAULT_OUTPUTS_DIR) -> ModelWrapper:
    """Return the process-wide ModelWrapper, building it lazily on first use.

    Note: *outputs_dir* only takes effect on the call that creates the
    singleton; later calls return the already-built instance.
    """
    global _GLOBAL_MODEL
    with _GLOBAL_MODEL_INIT_LOCK:
        if _GLOBAL_MODEL is not None:
            return _GLOBAL_MODEL
        _GLOBAL_MODEL = ModelWrapper(outputs_dir=outputs_dir)
        return _GLOBAL_MODEL
| |
|
def predict_to_ply(
    image_path: str | Path,
) -> Path:
    """Run SHARP on *image_path* using the shared model; return the .ply path."""
    wrapper = get_global_model()
    outputs = wrapper.predict_to_ply(image_path)
    return outputs.ply_path
| |
|
| |
|
| | |
# On Hugging Face Spaces the `spaces` package is importable: wrap the entry
# point with spaces.GPU so ZeroGPU allocates a GPU for up to 180 s per call.
# Outside Spaces, expose the plain CPU/local function under the same name.
if spaces is not None:
    predict_to_ply_gpu = spaces.GPU(duration=180)(predict_to_ply)
else:
    predict_to_ply_gpu = predict_to_ply
| |
|