| | import torch |
| | import argparse |
| | from safetensors.torch import save_file, safe_open |
| | from tqdm import tqdm |
| | import sys |
| |
|
| |
|
def normalize_key(key):
    """Return ``key`` without a leading 'model.diffusion_model.' prefix.

    Keys that do not start with that prefix are returned unchanged.
    """
    marker = 'model.diffusion_model.'
    return key[len(marker):] if key.startswith(marker) else key
| |
|
| |
|
def get_torch_dtype(dtype_str: str):
    """Map a precision name ("fp32", "fp16" or "bf16") to its torch.dtype.

    Raises:
        ValueError: if ``dtype_str`` is not one of the supported names.
    """
    supported = (
        ("fp32", torch.float32),
        ("fp16", torch.float16),
        ("bf16", torch.bfloat16),
    )
    for name, torch_dtype in supported:
        if dtype_str == name:
            return torch_dtype
    raise ValueError(f"Unsupported dtype: {dtype_str}")
| |
|
| |
|
def randomized_svd(matrix, rank, n_oversamples=10):
    """Approximate a truncated SVD of a 2-D matrix with a random range finder.

    The matrix is multiplied by a Gaussian random matrix with
    ``rank + n_oversamples`` columns, the result is orthonormalised with QR,
    and an exact SVD is taken of the (much smaller) projected matrix.

    All arithmetic is carried out in float32: the input is converted once up
    front so that half-precision inputs do not lose accuracy (or hit missing
    half-precision kernels) in the sketching matmul.

    Args:
        matrix: 2-D tensor of shape (m, n).
        rank: Number of singular triplets to return. Values larger than
            ``min(m, n)`` are clamped to ``min(m, n)``.
        n_oversamples: Extra random directions used to improve the
            approximation. Negative values are treated as 0.

    Returns:
        Tuple ``(U, S, Vh)`` of float32 tensors with shapes ``(m, r)``,
        ``(r,)`` and ``(r, n)`` where ``r = min(rank, m, n)``.

    Raises:
        ValueError: if ``matrix`` is not 2-D or ``rank`` is negative.
    """
    if matrix.dim() != 2:
        raise ValueError(f"randomized_svd expects a 2-D matrix, got {matrix.dim()} dimensions")
    if rank < 0:
        # A negative rank would make the final slices count from the end and
        # silently return the wrong number of components.
        raise ValueError(f"rank must be non-negative, got {rank}")

    work = matrix.float()
    max_rank = min(work.shape)
    rank = min(rank, max_rank)
    # Oversampling never needs to exceed the full rank of the matrix.
    target_rank = min(rank + max(n_oversamples, 0), max_rank)

    # Random sketch of the column space of the matrix.
    P = torch.randn((work.shape[1], target_rank), device=work.device, dtype=work.dtype)
    Y = work @ P

    # Orthonormal basis for the sketched range.
    Q, _ = torch.linalg.qr(Y)

    # Project onto that basis and decompose the small matrix exactly.
    B = Q.T @ work
    U_b, S, Vh = torch.linalg.svd(B, full_matrices=False)
    U = Q @ U_b

    # Drop the oversampled components.
    return U[:, :rank], S[:rank], Vh[:rank, :]
| |
|
| |
|
def _low_rank_factors(delta_w, rank, method, oversamples):
    """Factor a 2-D weight delta into ``(lora_up, lora_down)`` of at most ``rank``."""
    # torch.linalg.svd does not accept half/bfloat16 inputs, so the
    # decomposition always runs in float32 regardless of the working precision.
    work = delta_w.float()

    if method == 'rsvd':
        U, S, Vh = randomized_svd(work, rank, n_oversamples=oversamples)
    else:
        U, S, Vh = torch.linalg.svd(work, full_matrices=False)
        keep = min(rank, S.size(0))
        U, S, Vh = U[:, :keep], S[:keep], Vh[:keep, :]

    # Scaling the columns of U by S equals U @ diag(S) without building the
    # diagonal matrix; the singular values are folded into the "up" factor.
    return U * S, Vh


def extract_and_svd_lora(args):
    """Extract a LoRA from the difference of two checkpoints and save it.

    Keys of both files are compared after ``normalize_key`` has been applied.
    For every common ``.weight`` tensor with matching shapes and at least two
    dimensions, the delta ``B - A`` is flattened to 2-D, decomposed with the
    selected method, and stored as
    ``diffusion_model.<name>.lora_down.weight`` / ``.lora_up.weight`` in
    float32. Common ``.bias`` keys are counted but produce no output.
    Layers that fail to process are reported and skipped.

    Args:
        args: Namespace providing ``model_a``, ``model_b``, ``output``,
            ``rank``, ``method`` ('svd' or 'rsvd'), ``device``, ``precision``
            and ``oversamples``.

    Raises:
        ValueError: if ``args.rank`` is smaller than 1, or the precision name
            is rejected by ``get_torch_dtype``.
        SystemExit: with status 1 when the two files share no weight or bias
            keys.
    """
    if args.rank < 1:
        # A non-positive rank would make the truncation slices count from the
        # end (or be empty) and silently produce a meaningless file.
        raise ValueError(f"rank must be a positive integer, got {args.rank}")

    print(f"Loading base model A: {args.model_a}")
    print(f"Loading finetuned model B: {args.model_b}")
    print(f"Using decomposition method: {args.method}")

    weight_suffix = '.weight'
    lora_tensors = {}
    dtype = get_torch_dtype(args.precision)

    with safe_open(args.model_a, framework="pt", device="cpu") as f_a, \
            safe_open(args.model_b, framework="pt", device="cpu") as f_b:

        keys_a_original = set(f_a.keys())
        keys_b_original = set(f_b.keys())
        print(f"\nFound {len(keys_a_original)} keys in model A.")
        print(f"Found {len(keys_b_original)} keys in model B.")

        # Map prefix-stripped key -> key as stored in each file.
        normalized_keys_a = {normalize_key(k): k for k in keys_a_original}
        normalized_keys_b = {normalize_key(k): k for k in keys_b_original}

        common_normalized_keys = normalized_keys_a.keys() & normalized_keys_b.keys()
        print(f"Found {len(common_normalized_keys)} common keys after normalization.\n")

        processable_keys = {
            k for k in common_normalized_keys
            if k.endswith((weight_suffix, '.bias')) and 'lora_' not in k
        }

        if not processable_keys:
            print("No common weight or bias keys found to process. Check if models are compatible.")
            sys.exit(1)

        print(f"Found {len(processable_keys)} common keys to process.")

        for norm_key in tqdm(sorted(processable_keys), desc="Processing Layers"):
            if not norm_key.endswith(weight_suffix):
                # Bias deltas are never written to the output, so avoid
                # loading and subtracting them at all.
                continue

            try:
                tensor_a = f_a.get_tensor(normalized_keys_a[norm_key]).to(device=args.device, dtype=dtype)
                tensor_b = f_b.get_tensor(normalized_keys_b[norm_key]).to(device=args.device, dtype=dtype)

                if tensor_a.shape != tensor_b.shape:
                    tqdm.write(f"Skipping key {norm_key} due to shape mismatch")
                    continue

                delta_w = tensor_b - tensor_a

                if delta_w.dim() < 2:
                    tqdm.write(f"Skipping weight key {norm_key} as it's not a 2D matrix.")
                    continue
                if delta_w.dim() > 2:
                    # Collapse e.g. conv kernels to (out_features, everything else).
                    delta_w = delta_w.reshape(delta_w.shape[0], -1)

                lora_up, lora_down = _low_rank_factors(
                    delta_w, args.rank, args.method, args.oversamples
                )

                # Strip only the trailing suffix; str.replace would also
                # remove '.weight' occurring earlier in the key.
                base_name = norm_key[:-len(weight_suffix)]
                prefixed_base_name = f"diffusion_model.{base_name}"
                lora_down_name = f"{prefixed_base_name}.lora_down.weight"
                lora_up_name = f"{prefixed_base_name}.lora_up.weight"

                lora_tensors[lora_down_name] = lora_down.contiguous().cpu().to(torch.float32)
                lora_tensors[lora_up_name] = lora_up.contiguous().cpu().to(torch.float32)

            except Exception as e:
                # Best effort: report the failing layer and keep going.
                tqdm.write(f"Failed to process key {norm_key}: {e}")

    if not lora_tensors:
        print("No tensors were processed. Output file will not be created.")
        return

    print(f"\nSaving {len(lora_tensors)} tensors to {args.output}...")
    save_file(lora_tensors, args.output)
    print("✅ Done!")
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: parse the options, validate the numeric ones,
    # pick a usable device, then run the extraction.
    parser = argparse.ArgumentParser(description="Extract a LoRA/LoRA+ from two SafeTensors checkpoints.")
    parser.add_argument("model_a", type=str, help="Path to the base model (A) checkpoint.")
    parser.add_argument("model_b", type=str, help="Path to the finetuned model (B) checkpoint.")
    parser.add_argument("output", type=str, help="Path to save the output file.")

    parser.add_argument("--rank", type=int, required=True, help="The target rank for the decomposition.")
    parser.add_argument("--alpha", type=float, default=1.0,
                        help="Informational alpha value for scaling. This value is NOT saved in the file.")
    parser.add_argument("--method", type=str, default="rsvd", choices=["svd", "rsvd"], help="Decomposition method.")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
                        help="Device to use for computation.")
    parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"],
                        help="Precision for calculations.")

    parser.add_argument("--oversamples", type=int, default=10,
                        help="Oversampling parameter for Randomized SVD for better accuracy.")

    args = parser.parse_args()

    # Zero or negative values would turn the truncation slices into
    # "count from the end" slices and silently yield a wrong or empty LoRA,
    # so reject them with a normal usage error.
    if args.rank < 1:
        parser.error("--rank must be a positive integer")
    if args.oversamples < 0:
        parser.error("--oversamples must be a non-negative integer")

    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        args.device = "cpu"

    extract_and_svd_lora(args)