| """
|
| OM_train_2modes.py — Dual-mode training (diffusion + registration) using OMorpher.
|
|
|
| Drop-in replacement for OM_train_2modes.py. Uses the OMorpher object-oriented
|
| wrapper instead of procedural DeformDDPM calls, while preserving the same
|
| training logic, DDP support, loss functions, and checkpoint format.
|
|
|
| Usage:
|
| # Single-GPU
|
| python Scripts/OM_train_2modes.py -C Config/config_om.yaml
|
|
|
| # Multi-GPU (DDP)
|
| CUDA_VISIBLE_DEVICES=0,1 python Scripts/OM_train_2modes.py -C Config/config_om.yaml
|
| """
|
|
|
| import os
|
| import sys
|
|
|
|
|
| ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| sys.path.insert(0, ROOT_DIR)
|
|
|
| import gc
|
| import glob
|
| import random
|
|
|
| import numpy as np
|
| import torch
|
| import torch.distributed as dist
|
| import torch.multiprocessing as mp
|
| from torch.nn.parallel import DistributedDataParallel as DDP
|
| from torch.optim import Adam
|
| from torch.utils.data import DataLoader
|
| from tqdm import tqdm
|
| import argparse
|
|
|
| from OMorpher import OMorpher
|
| from Diffusion.networks import DefRec_MutAttnNet
|
| from Diffusion.losses import Grad, LNCC, LMSE, MSLNCC
|
| from Dataloader.dataLoader import OMDataset_indiv, OMDataset_pair
|
| from Dataloader.dataloader_utils import thresh_img
|
| import utils
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Training hyper-parameters and global switches
# ---------------------------------------------------------------------------
EPS = 1e-5  # small numerical-stability constant (not referenced in this file)

MSK_EPS = 0.01  # epsilon added under sqrt when scaling losses by cond_ratio

TEXT_EMBED_PROB = 0.7  # probability of feeding the text embedding to the network

AUG_RESAMPLE_PROB = 0.6  # probability of resample / noise augmentations per batch

LOSS_WEIGHTS_DIFF = [2.0, 1.0, 16]  # diffusion mode: [angle, distance, DDF-regularizer] weights

LOSS_WEIGHTS_REGIST = [1.0, 0.05, 128]  # registration mode: [image-sim, MSE, DDF-regularizer] weights

DIFF_REG_BATCH_RATIO = 2  # diffusion batch size / registration (pair) batch size

use_distributed = True  # set False to run a single local process without DDP


# Command-line interface: the only argument is the YAML config path,
# parsed at import time so `args` is available to main_train().
parser = argparse.ArgumentParser()

parser.add_argument(

    "--config", "-C",

    help="Path for the config file",

    type=str,

    default="Config/config_all.yaml",

    required=False,

)

args = parser.parse_args()
|
|
|
|
|
|
|
|
|
def ddp_setup(rank, world_size):
    """Initialize the NCCL process group for one DDP worker.

    Args:
        rank: Unique identifier of this process within the group.
        world_size: Total number of participating processes.
    """
    # Single-node rendezvous: all ranks meet at localhost:12355.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
    # Pin this process to its own GPU before any CUDA work happens.
    torch.cuda.set_device(rank)
|
|
|
|
|
|
|
|
|
|
|
def reverse_diffuse_train(network, om, img_org, cond_imgs, T, text=None):
    """Run reverse diffusion for registration, with gradients on the tail only.

    Mirrors DeformDDPM.diff_recover() called with T=[None, T_regist]: the
    composed DDF is refined over the whole schedule, but only the final two
    timesteps are executed with autograd enabled, keeping training memory
    bounded.

    Args:
        network: DDP-wrapped (or raw) network module.
        om: OMorpher instance (supplies STNs, image size, ndims, device).
        img_org: Source image [B, 1, S, S, S].
        cond_imgs: Processed conditioning image [B, 1, S, S, S].
        T: [T_init, T_schedule]; T_init=None means no forward diffusion,
            T_schedule is a list of batched timestep lists.
        text: Optional text embedding [B, 1024].

    Returns:
        (ddf_comp, img_rec): Composed DDF and the recovered image.
    """
    batch = img_org.shape[0]
    side = om.img_size

    # Start from an identity (all-zero) displacement field.
    ddf_comp = torch.zeros(
        [batch, om.ndims] + [side] * om.ndims,
        dtype=torch.float32, device=om.device,
    )
    img_rec = img_org.clone().detach()

    schedule = T[1]
    num_grad_steps = 2
    # Only the last `num_grad_steps` entries of the schedule get gradients.
    grad_steps = schedule[-num_grad_steps:]

    # For no-grad steps, call the underlying module directly so DDP does not
    # attempt gradient synchronization.
    raw_net = network.module if isinstance(network, DDP) else network

    for step in schedule:
        t = torch.tensor(np.array([step])).to(om.device)

        if step in grad_steps:
            pre_dvf = network(x=img_rec, y=cond_imgs, t=t, rec_num=2, text=text)
        else:
            with torch.no_grad():
                pre_dvf = raw_net(x=img_rec, y=cond_imgs, t=t, rec_num=2, text=text)

        # Compose the incremental DVF into the running field, then warp the
        # original image by the full composition (not incrementally).
        ddf_comp = om.stn_full(ddf_comp, pre_dvf) + pre_dvf
        img_rec = om.img_stn(img_org.clone().detach(), ddf_comp)

    return ddf_comp, img_rec
|
|
|
|
|
def ddp_load_checkpoint(gpu_id, network, optimizer, model_file,
                        use_dist=True, load_strict=False):
    """Load a checkpoint and broadcast parameters to all DDP ranks.

    Fixes vs the previous revision:
    - `torch.load` now passes `map_location` so each rank deserializes onto
      its own device instead of the device the checkpoint was saved from
      (typically cuda:0), which could OOM rank 0 in multi-GPU runs.
    - The resume epoch prefers the `'epoch'` value stored by
      `save_checkpoint`, falling back to the zero-padded filename prefix for
      older checkpoints.
    - The DDP unwrap uses `isinstance(network, DDP)`, consistent with
      `save_checkpoint`.

    Args:
        gpu_id: Rank of this process (0 prints memory diagnostics).
        network: DDP-wrapped or raw network.
        optimizer: Optimizer whose state is restored when `load_strict`.
        model_file: Path to the `.pth` checkpoint.
        use_dist: Whether distributed broadcast/barriers should run.
        load_strict: If True, require an exact state-dict match and also
            restore the optimizer state.

    Returns:
        int: The epoch to resume training from (saved epoch + 1).
    """
    if gpu_id == 0:
        utils.print_memory_usage("Before Loading Model")
    gc.collect()
    torch.cuda.empty_cache()

    # Deserialize onto this rank's device (or CPU when no GPU is visible).
    map_loc = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(model_file, map_location=map_loc)
    state_dict = checkpoint['model_state_dict']

    # Strip DDP 'module.' and legacy 'network.' key prefixes.
    cleaned = {}
    for k, v in state_dict.items():
        k = k.replace("module.", "")
        if k.startswith("network."):
            k = k[len("network."):]
        cleaned[k] = v

    net = network.module if use_dist and isinstance(network, DDP) else network
    net_keys = set(net.state_dict().keys())
    # Drop keys the current architecture does not define (e.g. after a
    # network change); strict loading will still catch missing keys.
    filtered = {k: v for k, v in cleaned.items() if k in net_keys}
    net.load_state_dict(filtered, strict=load_strict)

    if load_strict:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    utils.print_memory_usage("After Loading Checkpoint on GPU")

    if use_dist:
        # Ensure all ranks see rank 0's (freshly loaded) parameters.
        dist.barrier()
        for param in network.parameters():
            dist.broadcast(param.data, src=0)
        dist.barrier()
        for param_group in optimizer.param_groups:
            for param in param_group['params']:
                if param.grad is not None:
                    dist.broadcast(param.grad, src=0)

    # Prefer the epoch recorded inside the checkpoint; fall back to the
    # 6-digit epoch prefix of the filename for older checkpoint files.
    saved_epoch = checkpoint.get('epoch')
    if saved_epoch is None:
        saved_epoch = int(os.path.basename(model_file).split('.')[0][:6])
    return int(saved_epoch) + 1
|
|
|
|
|
def save_checkpoint(network, optimizer, epoch, save_path, use_dist=True):
    """Write model + optimizer state to disk.

    Model keys are prefixed with 'network.' to keep the on-disk format
    backward compatible with older checkpoints.
    """
    # Unwrap the DDP container so keys are free of the 'module.' prefix.
    if use_dist and isinstance(network, DDP):
        net = network.module
    else:
        net = network

    prefixed = {}
    for key, value in net.state_dict().items():
        prefixed[f"network.{key}"] = value

    payload = {
        'model_state_dict': prefixed,
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(payload, save_path)
|
|
|
|
|
|
|
|
|
|
|
def main_train(rank=0, world_size=1, train_mode_ratio=1, thresh_imgsim=0.01):
    """Train the network in two alternating modes on each step.

    Mode 1 (diffusion): predict the inverse DVF that denoises a randomly
    deformed image, supervised against the ground-truth random DDF.
    Mode 2 (registration): run a truncated reverse-diffusion rollout to
    register image pairs, supervised by image similarity + DDF regularity.

    Args:
        rank: Process rank (also the GPU index when distributed).
        world_size: Number of DDP processes.
        train_mode_ratio: Run the registration branch every N-th step.
        thresh_imgsim: Intensity threshold for the similarity-loss mask.
    """
    if use_distributed:
        ddp_setup(rank, world_size)
        if torch.distributed.is_initialized():
            print(f"World size: {torch.distributed.get_world_size()}")
            print(f"Communication backend: {torch.distributed.get_backend()}")
    gpu_id = rank
    # device=None lets OMorpher pick its own default in single-process mode.
    device = f"cuda:{rank}" if use_distributed else None

    om = OMorpher(config=args.config, device=device)
    config = om.config
    if gpu_id == 0:
        print(config)

    epoch_per_save = config['epoch_per_save']
    suffix_pth = f"_{config['data_name']}_{config['net_name']}.pth"
    model_dir = os.path.join(ROOT_DIR, 'Models', f"{config['data_name']}_{config['net_name']}/")

    # Two Grad regularizers with different out-of-range thresholds: a tight
    # one (0.2) for single-step diffusion DVFs, a looser one (0.6) for the
    # larger composed registration DDF.
    loss_reg = Grad(
        penalty=['l1', 'negdetj', 'range'], ndims=om.ndims,
        outrange_thresh=0.2, outrange_weight=1e3,
    )
    loss_reg1 = Grad(
        penalty=['l1', 'negdetj', 'range'], ndims=om.ndims,
        outrange_thresh=0.6, outrange_weight=1e3,
    )

    loss_imgsim = MSLNCC()
    loss_imgmse = LMSE()

    # Move the network and all spatial-transformer modules to this rank's GPU.
    if use_distributed:
        om.network.to(rank)
        om.stn_full.to(rank)
        om.stn_ctl.to(rank)
        om.img_stn.to(rank)
        om.msk_stn.to(rank)
        network = DDP(om.network, device_ids=[rank])
    else:
        om.network.to(om.device)
        network = om.network

    optimizer = Adam(network.parameters(), lr=config["lr"])

    # Single-image loader (diffusion mode) and pair loader (registration mode).
    dataset = OMDataset_indiv(transform=None)
    train_loader = DataLoader(
        dataset, batch_size=config['batchsize'], shuffle=True, drop_last=True,
    )

    datasetp = OMDataset_pair(transform=None)
    train_loader_p = DataLoader(
        datasetp,
        batch_size=config['batchsize'] // DIFF_REG_BATCH_RATIO,
        shuffle=True, drop_last=True,
    )

    # Resume from the newest checkpoint in model_dir, if any.
    os.makedirs(model_dir, exist_ok=True)
    model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    if model_files:
        if gpu_id == 0:
            print(model_files)
        initial_epoch = ddp_load_checkpoint(
            gpu_id, network, optimizer, model_files[-1], use_distributed,
        )
    else:
        initial_epoch = 0

    if gpu_id == 0:
        print('len_train_data: ', len(dataset))

    # DefRec_MutAttnNet expects the timestep wrapped in a list; others don't.
    is_defrec = isinstance(om.network, DefRec_MutAttnNet)

    for epoch in range(initial_epoch, config["epoch"]):
        # Per-epoch running averages (each step adds loss/total).
        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg_ = 0.0
        epoch_loss_regist = 0.0
        epoch_loss_imgsim_ = 0.0
        epoch_loss_imgmse_ = 0.0
        epoch_loss_ddfreg = 0.0

        network.train()
        loss_nan_step = 0
        # zip() stops at the shorter loader, so normalize by the min length.
        total = min(len(train_loader), len(train_loader_p))

        for step, (batch, batch_p) in tqdm(
            enumerate(zip(train_loader, train_loader_p)), total=total,
        ):

            # ---------------- Mode 1: diffusion training ----------------
            [x0, embd] = batch
            x0 = x0.to(om.device).type(torch.float32)
            # Randomly drop the text embedding (classifier-free style).
            if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
                embd = embd.to(om.device).type(torch.float32)
            else:
                embd = None

            n = x0.size()[0]
            x0 = x0.to(om.device)

            # Random mask that blinds part of the input with prob 0.6.
            blind_mask = utils.get_random_deformed_mask(
                x0.shape[2:], apply_possibility=0.6,
            ).to(om.device)

            # Spatial augmentation: resample or axis permutation (3D only).
            if om.ndims > 2:
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = utils.random_resample(x0, deform_scale=0)
                else:
                    [x0] = utils.random_permute([x0], select_dims=[-1, -2, -3])

            # Intensity augmentation: clamp then random affine scale/shift.
            if config['noise_scale'] > 0:
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = thresh_img(x0, [0, 2 * config['noise_scale']])
                x0 = x0 * (np.random.normal(1, config['noise_scale'] * 1)) + np.random.normal(0, config['noise_scale'] * 1)

            # One random timestep per sample.
            t = torch.randint(0, om.timesteps, (n,)).to(om.device)

            # Random conditioning degradation; 'uncon' (unconditional) is
            # oversampled 3x in the choice list.
            proc_type = random.choice(
                ['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'],
            )
            cond_img, _, cond_ratio = om._proc_cond_img(x0, proc_type=proc_type)

            # Forward diffusion: deform x0 by a random DDF of magnitude t.
            noisy_img, dvf_gt, _ = om._get_random_ddf(x0, t)

            if is_defrec:
                pre_dvf_I = network(
                    x=noisy_img * blind_mask, y=cond_img, t=[t], rec_num=2, text=embd,
                )
            else:
                pre_dvf_I = network(
                    x=noisy_img * blind_mask, y=cond_img, t=t, rec_num=2, text=embd,
                )

            # Compose predicted DVF with the ground-truth DDF: a perfect
            # prediction makes trm_pred cancel dvf_gt.
            loss_tot = 0
            loss_ddf = loss_reg(pre_dvf_I, img=x0)
            trm_pred = om.stn_full(pre_dvf_I, dvf_gt)
            loss_gen_d = om._loss_dist(
                pred=trm_pred, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask,
            )
            loss_gen_a = om._loss_ang(
                pred=trm_pred, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask,
            )

            loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
            loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
            # Down-weight batches whose conditioning retains more content.
            loss_tot = torch.sqrt(1. + MSK_EPS - cond_ratio) * loss_tot

            # NaN/instability guards: skip bad batches, abort after 6 total.
            if torch.isnan(x0).any():
                print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
            if loss_ddf > 0.001:
                print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
            if torch.isnan(loss_tot) or torch.isinf(loss_tot):
                print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
                loss_nan_step += 1
                continue
            if loss_nan_step > 5:
                print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")

            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            epoch_loss_tot += loss_tot.item() / total
            epoch_loss_gen_d += loss_gen_d.item() / total
            epoch_loss_gen_a += loss_gen_a.item() / total
            epoch_loss_reg_ += loss_ddf.item() / total

            # ---------------- Mode 2: registration training ----------------
            # Runs every train_mode_ratio-th step (every step by default).
            if step % train_mode_ratio == 0:
                [x1, y1, _, embd_y] = batch_p
                if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
                    embd_y = embd_y.to(om.device).type(torch.float32)
                else:
                    embd_y = None

                x1 = x1.to(om.device).type(torch.float32)
                y1 = y1.to(om.device).type(torch.float32)
                n = x1.size()[0]

                # Identical permutation and intensity jitter for the pair so
                # they remain registered targets of each other.
                [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1, -2, -3])
                if config['noise_scale'] > 0:
                    [x1, y1] = thresh_img([x1, y1], [0, 2 * config['noise_scale']])
                    random_scale = np.random.normal(1, config['noise_scale'] * 1)
                    random_shift = np.random.normal(0, config['noise_scale'] * 1)
                    x1 = x1 * random_scale + random_shift
                    y1 = y1 * random_scale + random_shift

                # Sample 16 distinct timesteps from a randomly truncated
                # upper portion of the schedule, descending.
                scale_regist = np.random.uniform(0.0, 0.7)
                T_regist = sorted(
                    random.sample(
                        range(int(om.timesteps * scale_regist), om.timesteps), 16,
                    ),
                    reverse=True,
                )
                # NOTE(review): batchsize // 2 assumes DIFF_REG_BATCH_RATIO == 2;
                # consider reusing the constant here.
                T_regist = [
                    [t_val for _ in range(config["batchsize"] // 2)]
                    for t_val in T_regist
                ]

                proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
                y1_proc, msk_tgt, cond_ratio = om._proc_cond_img(
                    y1, proc_type=proc_type,
                )

                # Truncated reverse-diffusion rollout registering x1 -> y1.
                ddf_comp, img_rec = reverse_diffuse_train(
                    network, om, x1, y1_proc, T=[None, T_regist], text=embd_y,
                )

                loss_sim = loss_imgsim(img_rec, y1, label=(y1 > thresh_imgsim))
                loss_mse = loss_imgmse(img_rec, y1, label=(y1 >= 0.0))
                loss_ddf1 = loss_reg1(ddf_comp, img=y1)

                loss_regist = 0
                loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
                loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
                loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1

                # NOTE(review): this checks x0 (the diffusion batch) inside the
                # registration branch; likely intended x1 — confirm.
                if torch.isnan(x0).any():
                    print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
                if loss_ddf1 > 0.002:
                    print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")

                # Up-weight batches with stronger conditioning content.
                loss_regist = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist
                optimizer.zero_grad()
                loss_regist.backward()

                # Clip gradients: the rollout amplifies gradient magnitude.
                torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.4)
                optimizer.step()

                epoch_loss_regist += loss_regist.item() / total
                epoch_loss_imgsim_ += loss_sim.item() / total
                epoch_loss_imgmse_ += loss_mse.item() / total
                epoch_loss_ddfreg += loss_ddf1.item() / total

            if step % 10 == 0:
                print('step:', step, ':', loss_tot.item(), '=', loss_gen_a.item(), '+', loss_gen_d.item(), '+', loss_ddf.item())
                print(f' loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')

        # Epoch summary (all ranks print).
        if 1:
            print('==================')
            print(epoch, ':', epoch_loss_tot, '=', epoch_loss_gen_a, '+', epoch_loss_gen_d, '+', epoch_loss_reg_, ' (ang+dist+regul)')
            print(f' loss_regist: {epoch_loss_regist} = {epoch_loss_imgsim_} (imgsim) + {epoch_loss_imgmse_} (imgmse) + {epoch_loss_ddfreg} (ddf)')
            print('==================')

        # Periodic checkpoint; in DDP only rank 0 writes.
        if 0 == epoch % epoch_per_save:
            save_path = os.path.join(model_dir, str(epoch).rjust(6, '0') + suffix_pth)
            os.makedirs(model_dir, exist_ok=True)
            if not use_distributed:
                print(f"saved in {save_path}")
                save_checkpoint(network, optimizer, epoch, save_path, use_dist=False)
            elif gpu_id == 0:
                print(f"saved in {save_path}")
                save_checkpoint(network, optimizer, epoch, save_path, use_dist=True)

    # Final cleanup: free CUDA memory and tear down the process group.
    torch.cuda.empty_cache()
    gc.collect()
    if use_distributed and dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
|
if __name__ == "__main__":
    # Launch one training process per visible GPU, or a single local
    # process when distributed mode is disabled.
    if not use_distributed:
        main_train(0, 1)
    else:
        n_gpus = torch.cuda.device_count()
        print(f"Distributed GPU number = {n_gpus}")
        mp.spawn(main_train, args=(n_gpus,), nprocs=n_gpus)
|
|
|