"""
OM_train_3modes.py — Three-mode training (diffusion + registration + contrastive)
using OMorpher.
Drop-in replacement for the original procedural training script: it uses the
OMorpher object-oriented wrapper instead of direct DeformDDPM calls, while
preserving the same training logic, DDP support, loss functions, and
checkpoint format.
Usage:
# Single-GPU
python Scripts/OM_train_3modes.py -C Config/config_om.yaml
# Multi-GPU (DDP)
CUDA_VISIBLE_DEVICES=0,1 python Scripts/OM_train_3modes.py -C Config/config_om.yaml
# Dummy data for testing (no real dataset needed)
python Scripts/OM_train_3modes.py -C Config/config_om.yaml --dummy-samples 20
"""
import os
import sys
# Add project root to path so imports work from Scripts/
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT_DIR)
import gc
import glob
import random
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
import argparse
from OMorpher import OMorpher
from Diffusion.networks import DefRec_MutAttnNet
from Diffusion.losses import Grad, LNCC, LMSE, MSLNCC
from Dataloader.dataLoader import OMDataset_indiv, OMDataset_pair
from Dataloader.dataloader_utils import thresh_img
import utils
# ========================== Constants ==========================
EPS = 1e-5
MSK_EPS = 0.01
TEXT_EMBED_PROB = 0.7
AUG_RESAMPLE_PROB = 0.6
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 16.0]  # [ang, dist, reg]
LOSS_WEIGHTS_REGIST = [1.0, 0.05, 128.0]  # [imgsim, imgmse, ddf]
DIFF_REG_BATCH_RATIO = 2
LOSS_WEIGHT_CONTRASTIVE = 1.0
CONTRASTIVE_STEP_RATIO = 2
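# With, e.g., batchsize=4: the diffusion loader (Mode 1) draws batches of 4,
# the paired registration loader (Mode 3) draws 4 // DIFF_REG_BATCH_RATIO = 2,
# and the contrastive update (Mode 2) runs on every CONTRASTIVE_STEP_RATIO-th
# step of the training loop.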
# Auto-detect: use DDP only when multiple CUDA GPUs are available
use_distributed = torch.cuda.is_available() and torch.cuda.device_count() > 1
# use_distributed = True
# use_distributed = False
# ========================== Arguments ==========================
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", "-C",
help="Path for the config file",
type=str,
default="Config/config_all.yaml",
required=False,
)
parser.add_argument("--dummy-samples", type=int, default=0, help="Use dummy random data for testing (0=use real data)")
parser.add_argument("--batchsize", type=int, default=0, help="Override batch size from config (0=use config value)")
args = parser.parse_args()
# ========================== Dummy Datasets ==========================
class _DummyIndiv(torch.utils.data.Dataset):
def __init__(self, n, sz, embd_dim=1024):
self.n, self.sz, self.embd_dim = n, sz, embd_dim
def __len__(self): return self.n
def __getitem__(self, i):
return np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64), np.random.randn(self.embd_dim).astype(np.float32)
class _DummyPair(torch.utils.data.Dataset):
def __init__(self, n, sz, embd_dim=1024):
self.n, self.sz, self.embd_dim = n, sz, embd_dim
def __len__(self): return self.n
def __getitem__(self, i):
return (np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
np.random.randn(self.embd_dim).astype(np.float32),
np.random.randn(self.embd_dim).astype(np.float32))
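# Both dummies mirror the item layout of OMDataset_indiv / OMDataset_pair:
# [1, S, S, S] float volumes paired with 1024-dim text embeddings.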
# ========================== DDP Setup ==========================
def ddp_setup(rank, world_size):
"""
Args:
rank: Unique identifier of each process
world_size: Total number of processes
"""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
# ========================== Helpers ==========================
def reverse_diffuse_train(network, om, img_org, cond_imgs, T, text=None):
"""Registration reverse diffusion with selective gradient control.
Mirrors DeformDDPM.diff_recover() with T=[None, T_regist].
Only the last k=2 timesteps have gradients enabled for efficient training.
Args:
network: DDP-wrapped (or raw) network module.
om: OMorpher instance (provides STN instances and device info).
img_org: Source image [B, 1, S, S, S].
cond_imgs: Processed conditioning image [B, 1, S, S, S].
T: [T_init, T_schedule]. T_init=None means no forward diffusion.
T_schedule is a list of batched timestep lists from the training loop.
text: Optional text embedding [B, 1024].
Returns:
(ddf_comp, img_rec): Composed DDF and recovered image.
"""
B = img_org.shape[0]
S = om.img_size
# T[0] = None → no forward diffusion, start from original image
ddf_comp = torch.zeros(
[B, om.ndims] + [S] * om.ndims,
dtype=torch.float32, device=om.device,
)
img_rec = img_org.clone().detach()
time_steps = T[1]
k = 2
    trainable_iterations = time_steps[-k:]  # last k entries; order is irrelevant for the membership test below
net_module = network.module if isinstance(network, DDP) else network
for i in time_steps:
t = torch.tensor(np.array([i])).to(om.device)
if i in trainable_iterations:
# Gradients enabled — call through DDP wrapper for gradient sync
pre_dvf = network(x=img_rec, y=cond_imgs, t=t, rec_num=2, text=text)
else:
# No gradients — call underlying module directly
with torch.no_grad():
pre_dvf = net_module(x=img_rec, y=cond_imgs, t=t, rec_num=2, text=text)
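        # Compose displacement fields: warp the accumulated field by the new
        # increment, then add the increment (phi_new = phi_old o u + u).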
ddf_comp = om.stn_full(ddf_comp, pre_dvf) + pre_dvf
img_rec = om.img_stn(img_org.clone().detach(), ddf_comp)
return ddf_comp, img_rec
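# A minimal usage sketch (hypothetical values; T_regist is built as in
# main_train below): only the last k=2 network calls carry gradients, so
# activation memory stays bounded while ddf_comp still reflects the full
# reverse trajectory.
#
#   ddf, rec = reverse_diffuse_train(network, om, x1, y1_proc,
#                                    T=[None, T_regist], text=None)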
def ddp_load_checkpoint(gpu_id, network, optimizer, model_file,
use_dist=True, load_strict=False):
"""Load checkpoint with DDP-aware parameter broadcast."""
if gpu_id == 0:
utils.print_memory_usage("Before Loading Model")
if torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
checkpoint = torch.load(model_file, map_location='cpu')
state_dict = checkpoint['model_state_dict']
# Strip DDP 'module.' and DeformDDPM 'network.' prefixes
cleaned = {}
for k, v in state_dict.items():
k = k.replace("module.", "")
if k.startswith("network."):
k = k[len("network."):]
cleaned[k] = v
net = network.module if use_dist else network
net_keys = set(net.state_dict().keys())
filtered = {k: v for k, v in cleaned.items() if k in net_keys}
net.load_state_dict(filtered, strict=load_strict)
if load_strict:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
utils.print_memory_usage("After Loading Checkpoint on GPU")
if use_dist:
# Broadcast model weights from rank 0 to all other GPUs
dist.barrier()
for param in network.parameters():
dist.broadcast(param.data, src=0)
dist.barrier()
for param_group in optimizer.param_groups:
for param in param_group['params']:
if param.grad is not None:
dist.broadcast(param.grad, src=0)
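    # Resume epoch comes from the 6-digit zero-padded prefix of the checkpoint
    # filename (see save_path in main_train), e.g. a hypothetical
    # '000110_<data>_<net>.pth' resumes at epoch 111.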
initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1
return initial_epoch
def save_checkpoint(network, optimizer, epoch, save_path, use_dist=True):
"""Save checkpoint with 'network.' key prefix for backward compatibility."""
net = network.module if use_dist and isinstance(network, DDP) else network
state_dict = {f"network.{k}": v for k, v in net.state_dict().items()}
torch.save({
'model_state_dict': state_dict,
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch,
}, save_path)
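# The 'network.' prefix written above is exactly what ddp_load_checkpoint
# strips on resume, so new checkpoints stay loadable by the original
# DeformDDPM-based code.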
# ========================== Main Training ==========================
def main_train(rank=0, world_size=1, train_mode_ratio=1, thresh_imgsim=0.01):
if use_distributed:
ddp_setup(rank, world_size)
if torch.distributed.is_initialized():
print(f"World size: {torch.distributed.get_world_size()}")
print(f"Communication backend: {torch.distributed.get_backend()}")
gpu_id = rank
device = f"cuda:{rank}" if use_distributed else None
# ---- OMorpher initialisation (config, network, STN, losses, auto-checkpoint) ----
om = OMorpher(config=args.config, device=device)
config = om.config
if args.batchsize > 0:
config['batchsize'] = args.batchsize
if gpu_id == 0:
print(config)
epoch_per_save = config['epoch_per_save']
suffix_pth = f"_{config['data_name']}_{config['net_name']}.pth"
model_dir = os.path.join(ROOT_DIR, 'Models', f"{config['data_name']}_{config['net_name']}/")
# ---- Additional loss functions for the three training modes ----
# Diffusion losses reused from OMorpher: om._loss_dist (MRSE), om._loss_ang (NCC)
loss_reg = Grad(
penalty=['l1', 'negdetj', 'range'], ndims=om.ndims,
outrange_thresh=0.2, outrange_weight=1e3,
)
loss_reg1 = Grad(
penalty=['l1', 'negdetj', 'range'], ndims=om.ndims,
outrange_thresh=0.6, outrange_weight=1e3,
)
    # loss_imgsim = LNCC()
    loss_imgsim = MSLNCC()  # multi-scale LNCC as the image-similarity term
    loss_imgmse = LMSE()
# ---- DDP wrapping ----
if use_distributed:
om.network.to(rank)
om.stn_full.to(rank)
om.stn_ctl.to(rank)
om.img_stn.to(rank)
om.msk_stn.to(rank)
network = DDP(om.network, device_ids=[rank])
else:
om.network.to(om.device)
network = om.network
# ---- Optimizer ----
optimizer = Adam(network.parameters(), lr=config["lr"])
# ---- Data loaders ----
if args.dummy_samples > 0:
dataset = _DummyIndiv(args.dummy_samples, config['img_size'])
datasetp = _DummyPair(args.dummy_samples, config['img_size'])
else:
dataset = OMDataset_indiv(transform=None)
datasetp = OMDataset_pair(transform=None)
train_loader = DataLoader(
dataset, batch_size=config['batchsize'], shuffle=True, drop_last=True,
)
train_loader_p = DataLoader(
datasetp,
batch_size=max(1, config['batchsize'] // DIFF_REG_BATCH_RATIO),
shuffle=True, drop_last=True,
)
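    # Note: no DistributedSampler is attached, so under DDP each rank draws
    # its own independent shuffle of the full dataset rather than a shard.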
# ---- Auto-resume from checkpoint ----
os.makedirs(model_dir, exist_ok=True)
model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
if model_files:
if gpu_id == 0:
print(model_files)
initial_epoch = ddp_load_checkpoint(
gpu_id, network, optimizer, model_files[-1], use_distributed,
)
else:
initial_epoch = 0
if gpu_id == 0:
print('len_train_data: ', len(dataset))
is_defrec = isinstance(om.network, DefRec_MutAttnNet)
# ---- Training loop ----
for epoch in range(initial_epoch, config["epoch"]):
epoch_loss_tot = 0.0
epoch_loss_gen_d = 0.0
epoch_loss_gen_a = 0.0
epoch_loss_reg_ = 0.0
epoch_loss_regist = 0.0
epoch_loss_imgsim_ = 0.0
epoch_loss_imgmse_ = 0.0
epoch_loss_ddfreg = 0.0
epoch_loss_contrastive = 0.0
network.train()
loss_nan_step = 0
total = min(len(train_loader), len(train_loader_p))
for step, (batch, batch_p) in tqdm(
enumerate(zip(train_loader, train_loader_p)), total=total,
):
# ==========================================================
# Mode 1: Diffusion training on single image
# ==========================================================
[x0, embd] = batch
x0 = x0.to(om.device).type(torch.float32)
embd_dev = embd.to(om.device).type(torch.float32)
if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
embd_in = embd_dev
else:
embd_in = None
            n = x0.size(0)
blind_mask = utils.get_random_deformed_mask(
x0.shape[2:], apply_possibility=0.6,
).to(om.device)
# Data augmentation
if om.ndims > 2:
if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
x0 = utils.random_resample(x0, deform_scale=0)
else:
[x0] = utils.random_permute([x0], select_dims=[-1, -2, -3])
if config['noise_scale'] > 0:
if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
x0 = thresh_img(x0, [0, 1 * config['noise_scale']])
                    x0 = x0 * np.random.normal(1, config['noise_scale']) + np.random.normal(0, config['noise_scale'])
t = torch.randint(0, om.timesteps, (n,)).to(om.device)
proc_type = random.choice(
['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'],
)
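            # 'uncon' is listed 3 of 8 times, so roughly 3/8 of diffusion steps
            # use the (presumably unconditional) 'uncon' processing mode, i.e.
            # conditioning dropout in the classifier-free-guidance sense.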
cond_img, _, cond_ratio = om._proc_cond_img(x0, proc_type=proc_type)
# Forward diffusion + network prediction
noisy_img, dvf_gt, _ = om._get_random_ddf(x0, t)
if is_defrec:
pre_dvf_I = network(
x=noisy_img * blind_mask, y=cond_img, t=[t], rec_num=2, text=embd_in,
)
else:
pre_dvf_I = network(
x=noisy_img * blind_mask, y=cond_img, t=t, rec_num=2, text=embd_in,
)
# Diffusion losses
loss_tot = 0
loss_ddf = loss_reg(pre_dvf_I, img=x0)
trm_pred = om.stn_full(pre_dvf_I, dvf_gt)
loss_gen_d = om._loss_dist(
pred=trm_pred, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask,
)
loss_gen_a = om._loss_ang(
pred=trm_pred, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask,
)
loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
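            # Weight by how much conditioning was withheld: fully conditioned
            # samples (cond_ratio -> 1) are down-weighted toward sqrt(MSK_EPS).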
loss_tot = torch.sqrt(1. + MSK_EPS - cond_ratio) * loss_tot
# NaN / divergence checks
if torch.isnan(x0).any():
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
if loss_ddf > 0.001:
print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
            if torch.isnan(loss_tot) or torch.isinf(loss_tot):
                print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
                loss_nan_step += 1
                # Check the abort threshold before `continue`, otherwise an
                # all-NaN run would print forever without ever terminating.
                if loss_nan_step > 5:
                    print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                    raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
                continue
optimizer.zero_grad()
loss_tot.backward()
optimizer.step()
epoch_loss_tot += loss_tot.item() / total
epoch_loss_gen_d += loss_gen_d.item() / total
epoch_loss_gen_a += loss_gen_a.item() / total
epoch_loss_reg_ += loss_ddf.item() / total
# ==========================================================
# Mode 2: Contrastive training (text-image alignment)
# ==========================================================
loss_contra_val = None
if step % CONTRASTIVE_STEP_RATIO == 0:
# Access raw network (not DDP-wrapped) for contrastive forward pass
raw_network = network.module if isinstance(network, DDP) else network
n_contra = x0.size()[0]
                t_contra = torch.randint(0, om.timesteps, (n_contra,)).to(om.device)
                # Mirror the timestep format Mode 1 uses for each network type
                t_arg = [t_contra] if is_defrec else t_contra
                _ = raw_network(x=(x0 * blind_mask).detach(), y=cond_img.detach(), t=t_arg, text=embd_dev.detach())
if hasattr(raw_network, 'img_embd') and raw_network.img_embd is not None:
img_embd = raw_network.img_embd # [B, 1024]
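                    # 1 - mean cosine similarity pulls the network's pooled
                    # image embedding toward the paired text embedding.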
loss_contra = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean())
optimizer.zero_grad()
loss_contra.backward()
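                    # Far tighter clip than the registration update's 0.4,
                    # presumably so this auxiliary alignment objective only
                    # nudges the shared weights.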
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.05)
optimizer.step()
loss_contra_val = loss_contra.item()
epoch_loss_contrastive += loss_contra_val / total
# ==========================================================
# Mode 3: Registration training on paired images
# ==========================================================
if step % train_mode_ratio == 0:
[x1, y1, _, embd_y] = batch_p
if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
embd_y = embd_y.to(om.device).type(torch.float32)
else:
embd_y = None
x1 = x1.to(om.device).type(torch.float32)
y1 = y1.to(om.device).type(torch.float32)
n = x1.size()[0]
# Augmentation
[x1, y1] = utils.random_permute([x1, y1], select_dims=[-1, -2, -3])
if config['noise_scale'] > 0:
[x1, y1] = thresh_img([x1, y1], [0, 2 * config['noise_scale']])
random_scale = np.random.normal(1, config['noise_scale'] * 1)
random_shift = np.random.normal(0, config['noise_scale'] * 1)
x1 = x1 * random_scale + random_shift
y1 = y1 * random_scale + random_shift
# Timestep schedule for reverse diffusion
scale_regist = np.random.uniform(0.0, 0.7)
T_regist = sorted(
random.sample(
range(int(om.timesteps * scale_regist), om.timesteps), 16,
),
reverse=True,
)
                T_regist = [
                    [t_val for _ in range(max(1, config["batchsize"] // DIFF_REG_BATCH_RATIO))]
                    for t_val in T_regist
                ]
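                # e.g. (hypothetical; timesteps=1000, scale_regist=0.4, batch 2):
                # T_regist = [[912, 912], [878, 878], ..., [423, 423]]
                # 16 distinct timesteps, descending, repeated per pair sample.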
proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
y1_proc, msk_tgt, cond_ratio = om._proc_cond_img(
y1, proc_type=proc_type,
)
# Reverse diffusion for registration (via OMorpher's network & STN)
ddf_comp, img_rec = reverse_diffuse_train(
network, om, x1, y1_proc, T=[None, T_regist], text=embd_y,
)
# Registration losses
loss_sim = loss_imgsim(img_rec, y1, label=(y1 > thresh_imgsim))
loss_mse = loss_imgmse(img_rec, y1)
loss_ddf1 = loss_reg1(ddf_comp, img=y1)
loss_regist = 0
loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
# NaN / divergence checks
                if torch.isnan(x1).any() or torch.isnan(y1).any():
                    print(f"*** Encountered NaN in paired inputs x1/y1 at epoch {epoch}, step {step}.")
if loss_ddf1 > 0.002:
print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
loss_regist = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist
optimizer.zero_grad()
loss_regist.backward()
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.4)
optimizer.step()
epoch_loss_regist += loss_regist.item() / total
epoch_loss_imgsim_ += loss_sim.item() / total
epoch_loss_imgmse_ += loss_mse.item() / total
epoch_loss_ddfreg += loss_ddf1.item() / total
            if step % 10 == 0:
                print('step:', step, ':', loss_tot.item(), '=', loss_gen_a.item(), '+', loss_gen_d.item(), '+', loss_ddf.item())
                if loss_contra_val is not None:
                    print(f'    loss_contrastive: {loss_contra_val:.6f}')
                if step % train_mode_ratio == 0:
                    print(f'    loss_regist: {loss_regist.item():.6f} = {loss_sim.item():.6f} (imgsim) + {loss_mse.item():.6f} (imgmse) + {loss_ddf1.item():.6f} (ddf)')
        if gpu_id == 0:
            print('==================')
            print(epoch, ':', epoch_loss_tot, '=', epoch_loss_gen_a, '+', epoch_loss_gen_d, '+', epoch_loss_reg_, ' (ang+dist+regul)')
            print(f'    loss_contrastive: {epoch_loss_contrastive}')
            print(f'    loss_regist: {epoch_loss_regist} = {epoch_loss_imgsim_} (imgsim) + {epoch_loss_imgmse_} (imgmse) + {epoch_loss_ddfreg} (ddf)')
            print('==================')
        if epoch % epoch_per_save == 0:
save_path = os.path.join(model_dir, str(epoch).rjust(6, '0') + suffix_pth)
os.makedirs(model_dir, exist_ok=True)
if not use_distributed:
print(f"saved in {save_path}")
save_checkpoint(network, optimizer, epoch, save_path, use_dist=False)
elif gpu_id == 0:
print(f"saved in {save_path}")
save_checkpoint(network, optimizer, epoch, save_path, use_dist=True)
# Resource cleanup
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
if use_distributed and dist.is_initialized():
dist.destroy_process_group()
if __name__ == "__main__":
if use_distributed:
world_size = torch.cuda.device_count()
print(f"Distributed GPU number = {world_size}")
mp.spawn(main_train, args=(world_size,), nprocs=world_size)
else:
main_train(0, 1)