| | import torch |
| | import argparse |
| | from safetensors.torch import load_file, save_file |
| | from tqdm import tqdm |
| | import os |
| |
|
| | def slerp(t1, t2, alpha): |
| | """ |
| | Performs Spherical Linear Interpolation (SLERP) between two tensors. |
| | """ |
| | |
| | t1_float = t1.float() |
| | t2_float = t2.float() |
| |
|
| | |
| | t1_flat = t1_float.flatten() |
| | t2_flat = t2_float.flatten() |
| |
|
| | |
| | dot = torch.sum(t1_flat * t2_flat) / (torch.linalg.norm(t1_flat) * torch.linalg.norm(t2_flat)) |
| |
|
| | |
| | dot = torch.clamp(dot, -1.0, 1.0) |
| |
|
| | |
| | theta = torch.acos(dot) |
| |
|
| | |
| | |
| | if torch.abs(theta) < 1e-4: |
| | return torch.lerp(t1, t2, alpha) |
| |
|
| | sin_theta = torch.sin(theta) |
| |
|
| | |
| | factor1 = torch.sin((1.0 - alpha) * theta) / sin_theta |
| | factor2 = torch.sin(alpha * theta) / sin_theta |
| |
|
| | |
| | interp_flat = factor1 * t1_flat + factor2 * t2_flat |
| |
|
| | |
| | return interp_flat.reshape(t1.shape).to(t1.dtype) |
| |
|
def lerp(t1, t2, alpha):
    """
    Performs Linear Interpolation (LERP) between two tensors.

    Computes ``t1 + alpha * (t2 - t1)`` with ``torch.lerp``. ``torch.lerp``
    requires ``start`` and ``end`` to share a dtype. When the inputs differ
    (e.g. an fp16 model merged with an fp32 one), the computation runs in
    their promoted dtype and the result is cast back to ``t1.dtype``.

    Args:
        t1: Start tensor; returned when ``alpha == 0``.
        t2: End tensor with the same shape as ``t1``.
        alpha: Interpolation factor; 0.0 gives ``t1`` and 1.0 gives ``t2``.

    Returns:
        The interpolated tensor, with ``t1``'s dtype.
    """
    if t1.dtype == t2.dtype:
        return torch.lerp(t1, t2, alpha)
    common = torch.promote_types(t1.dtype, t2.dtype)
    return torch.lerp(t1.to(common), t2.to(common), alpha).to(t1.dtype)
| |
|
| |
|
def main():
    """
    Command-line entry point: merge two .safetensors files into one.

    For keys present in both inputs:
      * tensors with matching shapes and floating-point dtypes on both sides
        are blended with LERP or SLERP using ``--alpha``;
      * shape-mismatched or non-floating-point tensors are copied from A.
    Keys present in only one input are copied unchanged. The result is
    written to ``output``.

    Raises:
        SystemExit: with a non-zero status if an input path is not a file.
    """
    parser = argparse.ArgumentParser(description="Merge two Safetensor models using either Linear (LERP) or Spherical (SLERP) interpolation.")
    parser.add_argument("model_a", type=str, help="Path to the first model (A).")
    parser.add_argument("model_b", type=str, help="Path to the second model (B).")
    parser.add_argument("output", type=str, help="Path to save the merged model.")
    parser.add_argument("--alpha", type=float, default=0.5, help="Interpolation factor (alpha). 0.0 is 100%% model A, 1.0 is 100%% model B. Default is 0.5.")
    parser.add_argument("--method", type=str, default="lerp", choices=["lerp", "slerp"], help="Merge method to use: 'lerp' (linear) or 'slerp' (spherical). Default is 'lerp'.")

    args = parser.parse_args()

    # Exit non-zero so shell scripts and CI notice the failure. SystemExit with
    # a string prints it to stderr and exits with status 1. isfile() also
    # rejects directories, which exists() would accept.
    for path in (args.model_a, args.model_b):
        if not os.path.isfile(path):
            raise SystemExit(f"Error: Model file not found at {path}")

    print(f"Loading model A from: {args.model_a}")
    tensors_a = load_file(args.model_a)

    print(f"Loading model B from: {args.model_b}")
    tensors_b = load_file(args.model_b)

    keys_a = set(tensors_a)
    keys_b = set(tensors_b)
    # Sorted so that progress output and warnings appear in a reproducible order.
    common_keys = sorted(keys_a & keys_b)
    keys_only_in_a = sorted(keys_a - keys_b)
    keys_only_in_b = sorted(keys_b - keys_a)

    print(f"\nFound {len(keys_a)} keys in {args.model_a}.")
    print(f"Found {len(keys_b)} keys in {args.model_b}.")
    print(f"-> Found {len(common_keys)} common keys.")
    print(f"-> Found {len(keys_only_in_a)} keys unique to model A.")
    print(f"-> Found {len(keys_only_in_b)} keys unique to model B.\n")

    if not common_keys and not keys_only_in_a and not keys_only_in_b:
        print("Warning: No tensors found to merge or copy. The output file will be empty.")
        save_file({}, args.output)
        print("Operation complete.")
        return

    # Resolve the merge function once instead of on every iteration.
    merge_fn = slerp if args.method == "slerp" else lerp
    merged_tensors = {}

    print(f"Merging {len(common_keys)} common layers with alpha={args.alpha} using {args.method.upper()}...")
    for key in tqdm(common_keys, desc="Merging common layers"):
        tensor_a = tensors_a[key]
        tensor_b = tensors_b[key]

        # tqdm.write keeps warnings from corrupting the active progress bar.
        if tensor_a.shape != tensor_b.shape:
            tqdm.write(f"Warning: Skipping layer '{key}' due to shape mismatch: {tensor_a.shape} vs {tensor_b.shape}")
            merged_tensors[key] = tensor_a
            continue

        # Interpolation is only performed for floating-point tensors. Check
        # BOTH sides; checking only A would let an integer B reach merge_fn.
        if not (tensor_a.is_floating_point() and tensor_b.is_floating_point()):
            tqdm.write(f"Warning: Skipping merge for non-floating point tensor '{key}' (dtypes: {tensor_a.dtype} / {tensor_b.dtype}). Copying from model A.")
            merged_tensors[key] = tensor_a
            continue

        merged_tensors[key] = merge_fn(tensor_a, tensor_b, args.alpha)

    if keys_only_in_a:
        print(f"Copying {len(keys_only_in_a)} layers unique to model A...")
        for key in tqdm(keys_only_in_a, desc="Copying layers from A"):
            merged_tensors[key] = tensors_a[key]

    if keys_only_in_b:
        print(f"Copying {len(keys_only_in_b)} layers unique to model B...")
        for key in tqdm(keys_only_in_b, desc="Copying layers from B"):
            merged_tensors[key] = tensors_b[key]

    print(f"\nSaving merged model to: {args.output}")
    save_file(merged_tensors, args.output)
    print("Merge complete!")
| |
|
# Script entry point: run the merge CLI only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
| |
|
| |
|