Spaces:
Sleeping
Sleeping
| import collections | |
| from pathlib import Path | |
| import functools | |
| import os | |
| import safetensors.torch | |
| from huggingface_hub import model_info, hf_hub_download | |
| import tempfile | |
| import torch | |
| import functools | |
| import os | |
| import requests | |
| import struct | |
| from huggingface_hub import hf_hub_url | |
| DTYPE_MAP = {"FP32": torch.float32, "FP16": torch.float16, "BF16": torch.bfloat16} | |
| # https://huggingface.co/docs/safetensors/v0.3.2/metadata_parsing#python | |
| def _parse_single_file(url): | |
| print(f"{url=}") | |
| token = os.getenv("HF_TOKEN") | |
| headers = {"Range": "bytes=0-7", "Authorization": f"Bearer {token}"} | |
| response = requests.get(url, headers=headers) | |
| length_of_header = struct.unpack("<Q", response.content)[0] | |
| headers = {"Range": f"bytes=8-{7 + length_of_header}", "Authorization": f"Bearer {token}"} | |
| response = requests.get(url, headers=headers) | |
| header = response.json() | |
| return header | |
| def _get_dtype_from_safetensor_file(file_path): | |
| """Inspects a safetensors file and returns the dtype of the first tensor. | |
| If it's not a safetensors file and a URL instead, we query it. | |
| """ | |
| if "https" in file_path: | |
| metadata = _parse_single_file(file_path) | |
| except_format_metadata_keys = sorted({k for k in metadata if k != "__metadata__"}) | |
| string_dtype = metadata[except_format_metadata_keys[0]]["dtype"] | |
| return DTYPE_MAP[string_dtype] | |
| try: | |
| # load_file is simple and sufficient for this info-gathering purpose. | |
| state_dict = safetensors.torch.load_file(file_path) | |
| if not state_dict: | |
| return "N/A (empty)" | |
| # Get the dtype from the first tensor in the state dict | |
| first_tensor = next(iter(state_dict.values())) | |
| return first_tensor.dtype | |
| except Exception as e: | |
| print(f"Warning: Could not determine dtype from {file_path}. Error: {e}") | |
| return "N/A (error)" | |
| def _process_components(component_files, file_accessor_fn, disable_bf16=False): | |
| """ | |
| Generic function to process components, calculate size, and determine dtype. | |
| Args: | |
| component_files (dict): A dictionary mapping component names to lists of file objects. | |
| file_accessor_fn (function): A function that takes a file object and returns | |
| a tuple of (local_path_for_inspection, size_in_bytes, relative_filename). | |
| disable_bf16 (bool): To disable using `torch.bfloat16`. Use it at your own risk. | |
| Returns: | |
| dict: A dictionary containing the total memory and detailed component info. | |
| """ | |
| components_info = {} | |
| total_size_bytes = 0 | |
| for name, files in component_files.items(): | |
| # Get dtype by inspecting the first file of the component | |
| first_file = files[0] | |
| # The accessor function handles how to get the path (download vs local) | |
| # and its size and relative name. | |
| inspection_path, _, _ = file_accessor_fn(first_file) | |
| dtype = _get_dtype_from_safetensor_file(inspection_path) | |
| component_size_bytes = 0 | |
| component_file_details = [] | |
| for f in files: | |
| _, size_bytes, rel_filename = file_accessor_fn(f) | |
| component_size_bytes += size_bytes | |
| component_file_details.append({"filename": rel_filename, "size_mb": size_bytes / (1024**2)}) | |
| if dtype == torch.float32 and not disable_bf16: | |
| print( | |
| f"The `dtype` for component ({name}) is torch.float32. Since bf16 computation is not disabled " | |
| "we will slash the total size of this component by 2." | |
| ) | |
| total_size_bytes += component_size_bytes / 2 | |
| else: | |
| total_size_bytes += component_size_bytes | |
| components_info[name] = { | |
| "size_gb": round(component_size_bytes / (1024**3), 3), | |
| "dtype": dtype, | |
| "files": sorted(component_file_details, key=lambda x: x["filename"]), | |
| } | |
| return { | |
| "total_loading_memory_gb": round(total_size_bytes / (1024**3), 3), | |
| "components": components_info, | |
| } | |
| def _determine_memory_from_hub_ckpt(ckpt_id, variant=None, disable_bf16=False): | |
| """ | |
| Determines memory and dtypes for a checkpoint on the Hugging Face Hub. | |
| """ | |
| files_in_repo = model_info(ckpt_id, files_metadata=True, token=os.getenv("HF_TOKEN")).siblings | |
| all_safetensors_siblings = [ | |
| s for s in files_in_repo if s.rfilename.endswith(".safetensors") and "/" in s.rfilename | |
| ] | |
| if variant: | |
| all_safetensors_siblings = [f for f in all_safetensors_siblings if variant in f.rfilename] | |
| component_files = collections.defaultdict(list) | |
| for sibling in all_safetensors_siblings: | |
| component_name = Path(sibling.rfilename).parent.name | |
| component_files[component_name].append(sibling) | |
| with tempfile.TemporaryDirectory() as temp_dir: | |
| def hub_file_accessor(file_obj): | |
| """Accessor for Hub files: downloads them and returns path/size.""" | |
| print(f"Querying '{file_obj.rfilename}' for inspection...") | |
| url = hf_hub_url(ckpt_id, file_obj.rfilename) | |
| return url, file_obj.size, file_obj.rfilename | |
| # We only need to download one file per component for dtype inspection. | |
| # To make this efficient, we create a specialized accessor for the processing loop | |
| # that only downloads the *first* file encountered for a component. | |
| downloaded_for_inspection = {} | |
| def efficient_hub_accessor(file_obj): | |
| component_name = Path(file_obj.rfilename).parent.name | |
| if component_name not in downloaded_for_inspection: | |
| path, _, _ = hub_file_accessor(file_obj) | |
| downloaded_for_inspection[component_name] = path | |
| inspection_path = downloaded_for_inspection[component_name] | |
| return inspection_path, file_obj.size, file_obj.rfilename | |
| return _process_components(component_files, efficient_hub_accessor, disable_bf16) | |
| def _determine_memory_from_local_ckpt(path: str, variant=None, disable_bf16=False): | |
| """ | |
| Determines memory and dtypes for a local checkpoint. | |
| """ | |
| ckpt_path = Path(path) | |
| if not ckpt_path.is_dir(): | |
| return {"error": f"Checkpoint path '{path}' not found or is not a directory."} | |
| all_safetensors_paths = list(ckpt_path.glob("**/*.safetensors")) | |
| if variant: | |
| all_safetensors_paths = [p for p in all_safetensors_paths if variant in p.name] | |
| component_files = collections.defaultdict(list) | |
| for file_path in all_safetensors_paths: | |
| component_name = file_path.parent.name | |
| component_files[component_name].append(file_path) | |
| def local_file_accessor(file_path): | |
| """Accessor for local files: just returns their path and size.""" | |
| return file_path, file_path.stat().st_size, str(file_path.relative_to(ckpt_path)) | |
| return _process_components(component_files, local_file_accessor, disable_bf16) | |
| def determine_pipe_loading_memory(ckpt_id: str, variant=None, disable_bf16=False): | |
| """ | |
| Determines the memory and dtypes for a pipeline, whether it's local or on the Hub. | |
| """ | |
| if os.path.isdir(ckpt_id): | |
| return _determine_memory_from_local_ckpt(ckpt_id, variant, disable_bf16) | |
| else: | |
| return _determine_memory_from_hub_ckpt(ckpt_id, variant, disable_bf16) | |
| if __name__ == "__main__": | |
| output = _determine_memory_from_hub_ckpt("Wan-AI/Wan2.1-T2V-14B-Diffusers") | |
| total_size_gb = output["total_loading_memory_gb"] | |
| safetensor_files = output["components"] | |
| print(f"{total_size_gb=} GB") | |
| print(f"{safetensor_files=}") | |
| print("\n") | |
| # total_size_gb, safetensor_files = _determine_memory_from_local_ckpt("LOCAL_DIR") # change me. | |
| # print(f"{total_size_gb=} GB") | |
| # print(f"{safetensor_files=}") | |