| | import os |
| | import argparse |
| | import torch |
| | from safetensors.torch import load_file, save_file |
| | from safetensors import safe_open |
| | from tqdm import tqdm |
| |
|
| |
|
_RESIZE_METHODS = ("svd", "randomized_svd")


def _read_lora_metadata(model_path):
    """Read ``ss_network_dim`` / ``ss_network_alpha`` from a safetensors header.

    Args:
        model_path (str): Path to the ``.safetensors`` file.

    Returns:
        tuple: ``(original_dim, alpha)``. Either is ``None`` when the header is
        unreadable, the field is absent, or its value is not numeric.
    """
    try:
        with safe_open(model_path, framework="pt", device="cpu") as f:
            metadata = f.metadata() or {}
    except Exception as e:
        print(f"Could not read metadata: {e}. Dimension and alpha will be inferred.")
        return None, None

    original_dim = None
    alpha = None

    # Each field is parsed on its own so a malformed value in one field does
    # not discard the other.
    raw_dim = metadata.get("ss_network_dim")
    if raw_dim is not None:
        try:
            original_dim = int(raw_dim)
            print(f"Original dimension (from metadata): {original_dim}")
        except ValueError:
            print(f"Ignoring non-integer ss_network_dim in metadata: {raw_dim!r}")

    raw_alpha = metadata.get("ss_network_alpha")
    if raw_alpha is not None:
        try:
            alpha = float(raw_alpha)
            print(f"Original alpha (from metadata): {alpha}")
        except ValueError:
            print(f"Ignoring non-numeric ss_network_alpha in metadata: {raw_alpha!r}")

    return original_dim, alpha


def _low_rank_factorize(full_weight, rank, method, oversample=10):
    """Split a 2-D matrix into ``(up, down)`` with ``up @ down`` ~= ``full_weight``.

    The rank-``rank`` truncated SVD ``U S Vh`` is split symmetrically:
    ``up = U diag(sqrt(S))`` and ``down = diag(sqrt(S)) Vh``.

    Args:
        full_weight: 2-D float tensor of shape ``(out, in)``.
        rank (int): Target rank; callers ensure ``1 <= rank <= min(out, in)``.
        method (str): ``'svd'`` (exact ``torch.linalg.svd``) or
            ``'randomized_svd'`` (``torch.svd_lowrank``).
        oversample (int): Extra sketch columns used by ``'randomized_svd'``.
            ``svd_lowrank`` expects ``q`` to overestimate the wanted rank.

    Returns:
        tuple: ``(up, down)`` of shapes ``(out, rank)`` and ``(rank, in)``.

    Raises:
        ValueError: If ``method`` is not recognised.
    """
    if method == "svd":
        # Only the leading min(out, in) singular vectors are ever used, so the
        # full square U / Vh are not computed.
        U, S, Vh = torch.linalg.svd(full_weight, full_matrices=False)
    elif method == "randomized_svd":
        # Oversample for accuracy, but never ask for more than the matrix size.
        q = min(rank + oversample, *full_weight.shape)
        U, S, V = torch.svd_lowrank(full_weight, q=q)
        Vh = V.T
    else:
        raise ValueError(f"Unknown resize method: {method!r}")

    U = U[:, :rank]
    S = S[:rank]
    Vh = Vh[:rank, :]

    s_sqrt = torch.sqrt(S)
    # Broadcasting is equivalent to multiplying by diag(s_sqrt) without
    # materialising the diagonal matrix.
    up = U * s_sqrt
    down = s_sqrt.unsqueeze(1) * Vh
    return up, down


def _resize_module(down_weight, up_weight, new_dim, device, method):
    """Re-factor one LoRA ``(down, up)`` pair at rank ``min(new_dim, out, in)``.

    Both weights are flattened past their first dimension, merged into
    ``up @ down``, re-factored, and reshaped back to their original trailing
    dimensions. For 4-D (conv) pairs, this requires ``up``'s trailing dims to
    multiply to 1 (e.g. a 1x1 kernel); otherwise the merge raises.

    Args:
        down_weight: ``lora_down`` / ``lora_A`` tensor, rank on dim 0.
        up_weight: ``lora_up`` / ``lora_B`` tensor, rank on dim 1.
        new_dim (int): Requested rank.
        device (str): Device used for the computation.
        method (str): Passed to :func:`_low_rank_factorize`.

    Returns:
        tuple: ``(new_down, new_up, new_rank)``. Both tensors are on the CPU,
        contiguous, and cast back to their own original dtypes.
    """
    down_shape, up_shape = down_weight.shape, up_weight.shape
    down_dtype, up_dtype = down_weight.dtype, up_weight.dtype

    # flatten(1) is a no-op for 2-D (linear) weights and collapses conv kernels.
    down = down_weight.to(device, dtype=torch.float32).flatten(1)
    up = up_weight.to(device, dtype=torch.float32).flatten(1)
    full_weight = up @ down

    # An (out, in) matrix has at most min(out, in) singular values.
    new_rank = min(new_dim, *full_weight.shape)
    new_up, new_down = _low_rank_factorize(full_weight, new_rank, method)

    # Restore the original layouts. Conv up weights keep their trailing kernel
    # dims, just like conv down weights.
    new_down = new_down.reshape(new_rank, *down_shape[1:])
    new_up = new_up.reshape(up_shape[0], new_rank, *up_shape[2:])

    # Move to CPU per module so the compute device never holds the whole model.
    return (
        new_down.contiguous().to(device="cpu", dtype=down_dtype),
        new_up.contiguous().to(device="cpu", dtype=up_dtype),
        new_rank,
    )


def resize_lora_model(model_path, output_path, new_dim, device, method):
    """
    Resizes the LoRA dimension of a model using SVD or Randomized SVD.
    Also scales the alpha value(s) proportionally.

    How each module is resized:
    - Each LoRA pair is merged into ``up @ down``, re-factored at the target
      rank with a truncated SVD, and split back into up/down factors.
    - A module's ``.alpha`` tensor is scaled by that module's own
      ``new_rank / old_rank``, which keeps ``alpha / rank`` unchanged per module.
      Only ``.alpha`` tensors stored in the file are updated. For formats that
      keep alpha elsewhere (e.g. a separate config), the caller must adjust it.

    Modules that cannot be resized are written to the output unchanged, with a
    warning. This covers:
    - a missing partner weight;
    - a ``lora_mid`` weight;
    - a numerical failure.

    Every tensor not rewritten here is copied through as-is.

    Args:
        model_path (str): Path to the LoRA model to resize.
        output_path (str): Path to save the new resized model.
        new_dim (int): The target new dimension for the LoRA weights. A module
            whose merged weight is smaller than this is resized to
            ``min(out_features, in_features)`` instead.
        device (str): The device to run calculations on ('cuda' or 'cpu').
        method (str): The resizing method to use ('svd' or 'randomized_svd').

    Returns:
        None. Errors are reported with ``print`` and cause an early return
        without writing ``output_path``.
    """
    # Validate arguments before touching the file.
    if method not in _RESIZE_METHODS:
        print(f"Error: Unknown resize method '{method}'. Expected one of: {', '.join(_RESIZE_METHODS)}.")
        return
    if new_dim < 1:
        print(f"Error: new_dim must be a positive integer, got {new_dim}.")
        return

    print(f"Loading model from: {model_path}")
    model = load_file(model_path, device="cpu")

    original_dim, alpha = _read_lora_metadata(model_path)

    # Fall back to the rank of the first down/A weight found.
    if not original_dim:
        for key, tensor in model.items():
            if key.endswith((".lora_down.weight", ".lora_A.weight")):
                original_dim = tensor.shape[0]
                print(f"Inferred original dimension from weights: {original_dim}")
                break

    if not original_dim:
        print("Error: Could not determine original LoRA dimension.")
        return

    if original_dim == new_dim:
        print("Error: New dimension is the same as the original dimension. No changes to make.")
        return

    if alpha is None:
        for key, tensor in model.items():
            if key.endswith(".alpha"):
                alpha = tensor.item()
                print(f"Inferred alpha from weights: {alpha}")
                break

    if alpha is None:
        alpha = float(original_dim)
        print(f"Alpha not found, falling back to using dimension value: {alpha}")

    # Global ratio, used only for the ss_network_alpha metadata value.
    ratio = new_dim / original_dim
    print(f"Dimension resize ratio: {ratio:.4f}")

    lora_keys_to_process = {
        key.split(".lora_")[0]
        for key in model
        if "lora_" in key and key.endswith(".weight")
    }
    if not lora_keys_to_process:
        print("Error: No LoRA weights found in the model.")
        return

    print(f"\nFound {len(lora_keys_to_process)} LoRA modules to resize...")
    print(f"Using '{method}' method for resizing.")

    new_model = {}
    for base_key in tqdm(sorted(lora_keys_to_process), desc="Resizing modules"):
        # Support both kohya-style (lora_down/lora_up) and lora_A/lora_B naming.
        if base_key + ".lora_down.weight" in model:
            down_key, up_key = base_key + ".lora_down.weight", base_key + ".lora_up.weight"
        elif base_key + ".lora_A.weight" in model:
            down_key, up_key = base_key + ".lora_A.weight", base_key + ".lora_B.weight"
        else:
            continue

        if base_key + ".lora_mid.weight" in model:
            # A three-factor module is not equal to up @ down, so it cannot be
            # re-factored here. It is carried over unchanged below.
            print(f"Warning: Skipping {base_key}: modules with a lora_mid weight are not supported. Keeping original weights.")
            continue

        try:
            new_down, new_up, new_rank = _resize_module(
                model[down_key], model[up_key], new_dim, device, method
            )
        except Exception as e:
            # Best-effort: keep the module's original tensors (copied below)
            # instead of dropping them from the output.
            print(f"Warning: Failed to process {base_key}. Error: {e}. Keeping original weights.")
            continue

        new_model[down_key] = new_down
        new_model[up_key] = new_up

        alpha_key = base_key + ".alpha"
        if alpha_key in model:
            old_alpha = model[alpha_key]
            old_rank = model[down_key].shape[0]
            # Use this module's own rank change rather than the global ratio,
            # so modules whose rank differs from original_dim also keep alpha/rank.
            new_model[alpha_key] = torch.tensor(
                old_alpha.item() * new_rank / old_rank, dtype=old_alpha.dtype
            )

    # Carry over everything not rewritten above. This covers non-LoRA tensors,
    # alphas of unchanged modules, and modules that were skipped or failed.
    for key, value in model.items():
        new_model.setdefault(key, value)

    new_alpha = alpha * ratio
    new_metadata = {
        "ss_network_dim": str(new_dim),
        "ss_network_alpha": str(new_alpha),
    }
    print(f"\nNew global alpha scaled to: {new_alpha:.2f}")

    # All tensors are already on the CPU (_resize_module moves its results).
    print(f"\nSaving resized model to: {output_path}")
    save_file(new_model, output_path, metadata=new_metadata)
    print("Done! 🎉")
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: parse arguments, choose a compute device, run the resize.
    cli = argparse.ArgumentParser(
        description="Resize a LoRA model to a new dimension and scales alpha proportionally.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    cli.add_argument("model_path", type=str, help="Path to the LoRA model (.safetensors).")
    cli.add_argument("output_path", type=str, help="Path to save the resized LoRA model.")
    cli.add_argument("new_dim", type=int, help="The new LoRA dimension (rank).")
    cli.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g., 'cpu', 'cuda'). Autodetects if not specified.",
    )
    cli.add_argument(
        "--method",
        type=str,
        default="svd",
        choices=["svd", "randomized_svd"],
        help=(
            "Resizing method:\n"
            "'svd' (default): Accurate but slower. Uses full SVD for optimal weight preservation.\n"
            "'randomized_svd': Faster approximation of SVD. Excellent for speed on large models."
        ),
    )
    cli_args = cli.parse_args()

    # An explicit (non-empty) --device wins; otherwise use CUDA when available.
    device = cli_args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    resize_lora_model(
        cli_args.model_path,
        cli_args.output_path,
        cli_args.new_dim,
        device,
        cli_args.method,
    )