| | |
| |
|
| | import argparse |
| | import torch |
| | from safetensors import safe_open |
| |
|
| |
|
def compare_safetensors(
    filepath1: str,
    filepath2: str,
    prefix_to_ignore: str = "model.diffusion_model.",
) -> None:
    """Compare two .safetensors files and print a summary of the differences.

    Tensor names are normalized by stripping ``prefix_to_ignore`` from the
    start of each key, so a tensor stored with the prefix in one file is
    matched against the un-prefixed tensor of the same name in the other.

    Args:
        filepath1: Path to the first .safetensors file.
        filepath2: Path to the second .safetensors file.
        prefix_to_ignore: Leading key prefix removed before names are
            matched. Defaults to ``"model.diffusion_model."``; pass ``""``
            to compare the raw key names.

    Returns:
        None. The report is printed to stdout. Any exception raised while
        opening or reading the files is caught and reported as a printed
        message instead of being propagated to the caller.
    """
    results = {
        "only_in_file1": [],
        "only_in_file2": [],
        "different_content": [],
    }

    print("\nLoading files and preparing for comparison...")
    print(f"Ignoring prefix: '{prefix_to_ignore}'")

    try:
        with safe_open(filepath1, framework="pt", device="cpu") as f1, \
                safe_open(filepath2, framework="pt", device="cpu") as f2:
            # Map each normalized name back to the key actually stored in
            # the file, so tensors can be fetched and reported by real name.
            map1 = {key.removeprefix(prefix_to_ignore): key for key in f1.keys()}
            map2 = {key.removeprefix(prefix_to_ignore): key for key in f2.keys()}

            normalized_keys1 = set(map1)
            normalized_keys2 = set(map2)

            results["only_in_file1"] = sorted(normalized_keys1 - normalized_keys2)
            results["only_in_file2"] = sorted(normalized_keys2 - normalized_keys1)

            common_normalized_keys = normalized_keys1 & normalized_keys2
            print(f"Comparing {len(common_normalized_keys)} common tensors...")

            # Tensors are fetched one pair at a time inside the loop rather
            # than all up front.
            for norm_key in sorted(common_normalized_keys):
                tensor1 = f1.get_tensor(map1[norm_key])
                tensor2 = f2.get_tensor(map2[norm_key])

                # Check metadata explicitly first: a shape or dtype mismatch
                # is always a difference, and this avoids depending on how
                # torch.equal treats tensors of mixed dtype.
                if (
                    tensor1.shape != tensor2.shape
                    or tensor1.dtype != tensor2.dtype
                    or not torch.equal(tensor1, tensor2)
                ):
                    results["different_content"].append(norm_key)

        print("\n" + "=" * 60)
        print("🔍 Safetensor Comparison Results")
        print("=" * 60)
        print(f"File 1: {filepath1}")
        print(f"File 2: {filepath2}")
        print("-" * 60)

        total_diffs = (
            len(results["only_in_file1"])
            + len(results["only_in_file2"])
            + len(results["different_content"])
        )
        if total_diffs == 0:
            print("\n✅ The files are identical after normalization. No differences found.")
            print("=" * 60 + "\n")
            return

        if results["different_content"]:
            print(f"\n↔️ Tensors with Different Content ({len(results['different_content'])}):")
            for norm_key in results["different_content"]:
                print(f" - Normalized Key: {norm_key}")
                print(f" (File 1 Original: {map1[norm_key]})")
                print(f" (File 2 Original: {map2[norm_key]})")

        if results["only_in_file1"]:
            print(f"\n→ Tensors Only in File 1 ({len(results['only_in_file1'])}):")
            for norm_key in results["only_in_file1"]:
                print(f" - Normalized Key: {norm_key} (Original: {map1[norm_key]})")

        if results["only_in_file2"]:
            print(f"\n← Tensors Only in File 2 ({len(results['only_in_file2'])}):")
            for norm_key in results["only_in_file2"]:
                print(f" - Normalized Key: {norm_key} (Original: {map2[norm_key]})")

        print("\n" + "=" * 60 + "\n")

    except FileNotFoundError as e:
        print(f"❌ Error: Could not find a file. Details: {e}")
    except Exception as e:
        # Top-level boundary for this CLI helper: report the failure to the
        # user instead of surfacing a traceback.
        print(f"❌ An error occurred: {e}")
        print("Please ensure both files are valid .safetensors files.")
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: takes two positional file paths and hands
    # them to compare_safetensors.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Compares two .safetensors files and lists the differences in their layers (tensors), ignoring a specific prefix.",
    )

    # Both positionals share the same shape, so register them from a table.
    positional_specs = (
        ("file1", "Path to the first .safetensors file."),
        ("file2", "Path to the second .safetensors file."),
    )
    for arg_name, arg_help in positional_specs:
        cli.add_argument(arg_name, type=str, help=arg_help)

    parsed = cli.parse_args()

    compare_safetensors(parsed.file1, parsed.file2)