"""diffuser_opt.py — Optimized DeformDDPM subclass.

Inherits from Diffusion.diffuser.DeformDDPM and overrides only the methods
that benefit from optimization.

Key optimizations:
1. diff_recover(): hoist img_org/msk_org .clone().detach() outside the loop,
   create timestep tensors directly on the device, use torch.no_grad() for
   frozen steps.
2. _random_ddf_generate(): scaling-and-squaring for O(log n) composition
   instead of O(n), crop-first upsampling, on-device tensors.
3. proc_cond_img(): skip clone for the 'uncon' path (the most heavily
   weighted conditioning type).
4. _DDF_Encoder_init(): use OptSTN (register_buffer, no per-call .to(device)).
5. recover(): fix t tensor bug (was staying on CPU), avoid redundant
   torch.tensor().
6. _multiscale_dvf_generate(): generate random tensors on device to avoid
   CPU→GPU transfer of 3D volumes.
"""
from torch import nn
import torch
import numpy as np
import torch.nn.functional as F
import random
import math
import Diffusion.utils_diff as utils
from Diffusion.diffuser import DeformDDPM as _BaseDeformDDPM
from Diffusion.networks import *
from Diffusion.networks_opt import OptSTN

EPS = 1e-8


class DeformDDPM(_BaseDeformDDPM):
    """Drop-in replacement for DeformDDPM with speed optimizations."""

    # Conditioning types drawn by proc_cond_img() when proc_type is None,
    # paired with their sampling weights so the two can never drift apart.
    # NOTE(review): the previous code passed 7 types but only 6 weights
    # ([1, 1, 1, 1, 1, 3]), which makes random.choices raise ValueError.
    # 'none' is given weight 1 here — confirm against the base class.
    _COND_PROC_WEIGHTS = (
        ('adding', 1),
        ('independ', 1),
        ('downsample', 1),
        ('slice', 1),
        ('slice1', 1),
        ('none', 1),
        ('uncon', 3),
    )

    # ------------------------------------------------------------------
    # Optimization 4: use OptSTN (register_buffer, no per-call .to())
    # ------------------------------------------------------------------
    def _DDF_Encoder_init(self, ctl_ratio=4, ctl_sz=None, resample_mode=None):
        """Create the OptSTN warpers for control-point fields, images and masks.

        Args:
            ctl_ratio: divisor applied to ``self.image_chw[1]`` to derive the
                control-grid size when ``ctl_sz`` is not given.
            ctl_sz: explicit control-grid size; overrides ``ctl_ratio``.
            resample_mode: not read by this override — the image warper uses
                ``self.resample_mode`` instead.

        Sets ``self.ctl_sz``, ``self.img_sz``, ``self.ddf_stn_rec``,
        ``self.img_stn`` and ``self.msk_stn``.
        """
        if ctl_sz is None:
            ctl_sz = self.image_chw[1] // ctl_ratio
        self.ctl_sz = ctl_sz
        self.img_sz = self.image_chw[1]
        # OPT: use OptSTN instead of STN — register_buffer for ref_grid/max_sz
        self.ddf_stn_rec = OptSTN(
            img_sz=ctl_sz,
            ndims=self.ndims,
            device=self.device,
            padding_mode=self.ddf_pad_mode,
        )
        self.img_stn = OptSTN(
            img_sz=self.img_sz,
            ndims=self.ndims,
            device=self.device,
            padding_mode=self.img_pad_mode,
            resample_mode=self.resample_mode,
        )
        # Masks are always resampled with nearest-neighbour.
        self.msk_stn = OptSTN(
            img_sz=self.img_sz,
            ndims=self.ndims,
            device=self.device,
            padding_mode=self.img_pad_mode,
            resample_mode='nearest',
        )

    def __init__(self, network, n_steps=50, beta_schedule_fn=None, device='cpu',
                 image_chw=(1, 28, 28), batch_size=1, img_pad_mode="zeros",
                 ddf_pad_mode="border", padding_mode="border",
                 v_scale=0.008 / 256, resample_mode=None, inf_mode=False):
        """Initialise via the base class, then swap the full-size DDF warper.

        All arguments are forwarded unchanged to the base ``DeformDDPM``
        constructor.
        """
        # Call parent __init__ — it creates STN instances
        super().__init__(
            network=network,
            n_steps=n_steps,
            beta_schedule_fn=beta_schedule_fn,
            device=device,
            image_chw=image_chw,
            batch_size=batch_size,
            img_pad_mode=img_pad_mode,
            ddf_pad_mode=ddf_pad_mode,
            padding_mode=padding_mode,
            v_scale=v_scale,
            resample_mode=resample_mode,
            inf_mode=inf_mode,
        )
        # OPT: replace ddf_stn_full with OptSTN too
        self.ddf_stn_full = OptSTN(
            img_sz=self.image_chw[1],
            ndims=self.ndims,
            padding_mode=self.padding_mode,
            device=self.device,
        )

    # ------------------------------------------------------------------
    # Optimization 5: fix recover() t tensor bug + avoid redundant copies
    # ------------------------------------------------------------------
    def recover(self, x, y, t, rec_num=2, text=None):
        """Run the network after moving the timestep(s) onto ``x.device``.

        Args:
            x: network input; its device is the target device for ``t``.
            y: conditioning input, forwarded to the network unchanged.
            t: a tensor, a scalar/sequence convertible by ``torch.tensor``,
                or a list whose items are each converted individually.
            rec_num: forwarded to the network; ``None`` selects
                ``self.rec_num``.
            text: forwarded to the network unchanged.

        Returns:
            Whatever ``self.network`` returns.
        """
        target = x.device

        def _on_device(value):
            # OPT: existing tensors are moved, never re-created;
            # Tensor.to() is a no-op when the device already matches.
            if isinstance(value, torch.Tensor):
                return value.to(target)
            return torch.tensor(value, device=target)

        if isinstance(t, list):
            t = [_on_device(t0) for t0 in t]
        else:
            t = _on_device(t)
        if rec_num is None:
            rec_num = self.rec_num
        return self.network(x=x, y=y, t=t, rec_num=rec_num, text=text)

    # ------------------------------------------------------------------
    # Optimization 2: scaling-and-squaring + crop-first upsample
    # ------------------------------------------------------------------
    def _compose_n_times(self, dvf, n):
        """Compute n-fold self-composition of dvf using scaling-and-squaring.

        Uses binary decomposition: O(log n) STN calls instead of O(n).

        The result is the same deformation (n-fold composition) but computed
        via a different sequence of grid_sample interpolations, so there are
        small numerical differences (~1e-2 to 1e-1) vs iterative composition.
        This is acceptable because DDF generation is stochastic augmentation.

        Args:
            dvf: displacement field at control-grid resolution.
            n: number of compositions.

        Returns:
            The composed field; a zero field shaped like ``dvf`` when
            ``n <= 0``.
        """
        if n <= 0:
            return torch.zeros_like(dvf)
        result = None
        current = dvf  # current = dvf^(2^i), starts as dvf^1
        while n > 0:
            if n & 1:  # bit is set → accumulate this power
                if result is None:
                    result = current.clone()
                else:
                    # result = current ∘ result (apply result first, then current)
                    result = result + self.ddf_stn_rec(current, result)
            n >>= 1
            if n > 0:
                # Square: current = current ∘ current
                current = current + self.ddf_stn_rec(current, current)
        return result

    def _crop_upsample(self, field):
        """Upsample DDF from ctl_sz to img_sz with 2x oversampling + center crop.

        The reference computation upsamples the whole control field to
        ``img_sz * 2`` and keeps the central ``img_sz`` window. When the
        geometry allows it, only the central part of the control field (plus
        a small margin so interpolation at the window edge still sees its
        neighbours) is upsampled, which interpolates far fewer voxels and is
        intended to give the same values because linear interpolation is
        local.

        The shortcut is used only when ``ctl_sz`` is a multiple of 4 and
        ``img_sz * 2`` is a multiple of ``ctl_sz``; otherwise the crop would
        be off-centre, so the full field is upsampled instead. The margin is
        clamped so the crop never starts at a negative index (which would
        silently wrap around for ``ctl_sz < 8``).

        Args:
            field: displacement field whose trailing ``ndims`` axes have
                size ``ctl_sz``; displacements are rescaled by
                ``img_sz / ctl_sz``.

        Returns:
            Field whose trailing ``ndims`` axes have size ``img_sz``.
        """
        crop_rate = 2
        full_sz = self.img_sz * crop_rate
        mode = 'bilinear' if self.ndims == 2 else 'trilinear'
        # The original code used 2 spatial slices for ndims == 2 and 3 otherwise.
        n_spatial = 2 if self.ndims == 2 else 3

        crop_is_exact = self.ctl_sz % 4 == 0 and full_sz % self.ctl_sz == 0
        if crop_is_exact:
            upscale = full_sz // self.ctl_sz                # e.g. 8
            margin = min(2, self.ctl_sz // 4)               # keeps lo >= 0
            lo = self.ctl_sz // 4 - margin                  # e.g. 6
            hi = self.ctl_sz * 3 // 4 + margin              # e.g. 26
            up_sz = (hi - lo) * upscale                     # e.g. 160
        else:
            # Fall back to upsampling the whole field.
            lo, hi = 0, self.ctl_sz
            up_sz = full_sz
        pad = (up_sz - self.img_sz) // 2                    # e.g. 16

        src_window = (Ellipsis,) + (slice(lo, hi),) * n_spatial
        out_window = (Ellipsis,) + (slice(pad, pad + self.img_sz),) * n_spatial

        field_crop = field[src_window] * self.img_sz / self.ctl_sz
        field_up = F.interpolate(field_crop, up_sz, mode=mode)
        return field_up[out_window]

    def _random_ddf_generate(self, rec_num=3, mul_num=[torch.tensor([5]), torch.tensor([5])],
                             ddf0=None, keep_inverse=False, noise_ratio=0.08,
                             select_num=3, flip_ratio=0.5):
        """Generate a random dense displacement field pair at image resolution.

        Args:
            rec_num: number of random velocity fields drawn.
            mul_num: two tensors; the maximum of each is the number of
                self-compositions for ``ddf`` and ``dddf`` respectively.
            ddf0: optional initial control-grid field; only returned
                (upsampled) when ``rec_num`` is 0.
            keep_inverse: not read by this override.
            noise_ratio: relative scale of the extra perturbation composed
                onto the velocity field used for ``ddf``; 0 disables it.
            select_num: maximum number of control scales sampled per draw.
            flip_ratio: not read by this override.

        Returns:
            ``(ddf, dddf)`` upsampled by ``_crop_upsample``.
        """
        ctl_ddf_sz = [self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
        if ddf0 is not None:
            ddf = ddf0
        else:
            ddf = torch.zeros(ctl_ddf_sz, device=self.device)
        dddf = torch.zeros(ctl_ddf_sz, device=self.device)

        scale_num = min(8, int(math.log2(self.ctl_sz)))
        ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]

        # Only the largest requested count is used, so reduce once up front.
        mul_num_ddf_val = int(torch.as_tensor(mul_num[0]).max().item())
        mul_num_dvf_val = int(torch.as_tensor(mul_num[1]).max().item())

        # NOTE(review): every iteration overwrites ddf/dddf, so only the last
        # draw reaches the output and ddf0 is ignored whenever rec_num > 0.
        # Nesting here was reconstructed from whitespace-collapsed source —
        # confirm against the base implementation whether the draws were
        # meant to be accumulated.
        for _ in range(rec_num):
            if len(ctl_szs_all) > select_num:
                ctl_szs = random.sample(ctl_szs_all, select_num)
            else:
                ctl_szs = ctl_szs_all
            dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs)
            if noise_ratio == 0:
                dvf0 = dvf
            else:
                dvf0 = dvf + self.ddf_stn_rec(
                    self._multiscale_dvf_generate(
                        self.v_scale * noise_ratio, ctl_szs=ctl_szs, rand_v_scale=False),
                    dvf)

            # OPT: scaling-and-squaring — O(log n) STN calls instead of O(n)
            ddf = self._compose_n_times(dvf0, mul_num_ddf_val)
            dddf = self._compose_n_times(dvf, mul_num_dvf_val)

        # OPT: crop-first upsample — fewer voxels to interpolate
        ddf = self._crop_upsample(ddf)
        dddf = self._crop_upsample(dddf)
        return ddf, dddf

    # ------------------------------------------------------------------
    # Optimization 6: generate DVF on device to avoid CPU→GPU transfer
    # ------------------------------------------------------------------
    def _multiscale_dvf_generate(self, v_scale, ctl_szs=[4, 8, 16, 32, 64], rand_v_scale=True):
        """Sum random Gaussian velocity fields drawn at several grid sizes.

        Each component is sampled on a ``ctl_sz``-sized grid, rescaled by
        ``self.ctl_sz / ctl_sz`` and interpolated to the control grid.

        Args:
            v_scale: amplitude, or the upper bound of the sampled amplitude
                when ``rand_v_scale`` is true.
            ctl_szs: grid sizes to draw components at; a size of 1 adds a
                component from ``utils.random_ddf`` as well.
            rand_v_scale: draw the amplitude per component instead of using
                ``v_scale`` directly.

        Returns:
            The summed field (the int ``0`` if ``ctl_szs`` is empty).
        """
        dvf = 0
        if self.img_sz is None:
            self.img_sz = max(ctl_szs)
        if 1 in ctl_szs:
            # NOTE(review): the device of utils.random_ddf's output is not
            # visible here — confirm it matches self.device.
            dvf_rot = utils.random_ddf(
                batch_size=self.batch_size,
                ndims=self.ndims,
                img_sz=[self.ctl_sz] * self.ndims,
                range_gauss=0,
                rot_range=np.pi / 90)
            dvf = dvf + dvf_rot

        interp_mode = 'bilinear' if self.ndims == 2 else 'trilinear'
        target_sz = [self.ctl_sz] * self.ndims
        for ctl_sz in ctl_szs:
            if rand_v_scale:
                _v_scale = self._sample_random_uniform_multi_order(
                    high=v_scale, low=0., order_num=random.choice([1, 1, 2]))
            else:
                _v_scale = v_scale
            if ctl_sz <= 2:
                _v_scale = _v_scale / 2
            # OPT: generate random tensor directly on device
            dvf_comp = torch.randn(
                [self.batch_size, self.ndims] + [ctl_sz] * self.ndims,
                device=self.device) * _v_scale
            dvf_comp = F.interpolate(
                dvf_comp * self.ctl_sz / ctl_sz,
                target_sz,
                align_corners=False,
                mode=interp_mode)
            dvf = dvf + dvf_comp
        return dvf

    # ------------------------------------------------------------------
    # Optimization 3: skip clone for 'uncon' (heavily weighted conditioning type)
    # ------------------------------------------------------------------
    def proc_cond_img(self, img, proc_type=None, noise_scale=0.1):
        """Degrade a conditioning image according to ``proc_type``.

        Args:
            img: conditioning image tensor.
            proc_type: one of 'adding', 'independ', 'downsample', 'slice',
                'slice1', 'project', 'uncon', or 'none' / '' for a plain
                copy. ``None`` draws a type at random using
                ``_COND_PROC_WEIGHTS``.
            noise_scale: forwarded to ``create_noise_map``.

        Returns:
            ``(proc_img, mask, cond_ratio)`` — the processed image, the mask
            applied (scalar tensor 1 when no mask was used, 0 for 'uncon'),
            and a scalar tensor set per branch (1 by default, 0 for 'uncon').

        Also resets ``self.msk_noise_scale`` to a zero tensor.
        """
        if proc_type is None:
            type_names, type_weights = zip(*self._COND_PROC_WEIGHTS)
            proc_type = random.choices(type_names, weights=type_weights, k=1)[0]

        mask = torch.tensor(1, device=img.device)
        cond_ratio = torch.tensor(1., device=img.device)
        self.msk_noise_scale = torch.tensor(0, device=img.device)
        noise_type = random.choice(['gaussian', 'uniform', 'none'])

        if proc_type not in ['none', None, '']:
            # OPT: handle 'uncon' before cloning — no need to clone img
            if proc_type == 'uncon':
                noise_map = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
                proc_img = noise_map
                mask = torch.tensor(0, device=img.device)
                cond_ratio = torch.tensor(0, device=img.device)
                return proc_img, mask, cond_ratio

            # Only clone when we actually need the image data
            proc_img = img.clone().detach()
            noise_map = None
            if proc_type in ['adding', 'independ', 'slice', 'slice1']:
                noise_map = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)

            if proc_type == 'adding':
                proc_img, noise_ratio = self.add_noise(
                    proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
                cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
            elif proc_type == 'independ':
                mask = self.create_noise_map(img, noise_type='binary')
                if self.msk_noise_scale == 0:
                    proc_img = img * mask
                else:
                    proc_img = self.apply_noise(
                        proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
                with torch.no_grad():
                    cond_ratio = mask.float().mean()
            elif proc_type == 'downsample':
                proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1. / 64, 1])
                cond_ratio = torch.tensor(down_ratio, device=img.device)
            elif proc_type == 'slice' or proc_type == 'slice1':
                if proc_type == 'slice1':
                    slice_num_max = 1
                else:
                    # Two-stage draw biases the upper bound towards small values.
                    slice_num_max = random.randint(1, 64)
                    slice_num_max = random.randint(1, slice_num_max)
                mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
                if self.msk_noise_scale == 0:
                    proc_img = img * mask
                else:
                    proc_img = self.apply_noise(
                        proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
                cond_ratio = torch.tensor(sample_ratio, device=img.device)
            elif proc_type == 'project':
                proc_img, proj_num = self.project(proc_img)
                cond_ratio = torch.tensor(proj_num / (128 * self.ndims), device=img.device)
            return proc_img, mask, cond_ratio
        else:
            # 'none' type — still need clone
            proc_img = img.clone().detach()
            return proc_img, mask, cond_ratio

    # ------------------------------------------------------------------
    # Optimization 1: hoist clone, create timestep tensors on device,
    # use torch.no_grad() for frozen iterations
    # ------------------------------------------------------------------
    def diff_recover(self, img_org, msk_org=None, T=[None, None], ddf_rand=None, v_scale=None,
                     t_save=None, cond_imgs=None, proc_type=None, text=None):
        """Deform an image, then iteratively recover it with the network.

        Args:
            img_org: original image.
            msk_org: optional mask warped alongside the image.
            T: ``T[0]`` is the timestep used to draw the random deformation
                (``None`` or 0 means no deformation); ``T[-1]`` is either an
                int number of recovery steps (all run under no_grad) or a
                sequence of timesteps whose last two entries run with
                gradients enabled.
            ddf_rand: optional precomputed deformation; when given, ``T[0]``
                and ``v_scale`` are not used.
            v_scale: optional replacement for ``self.v_scale``.
            t_save: optional iterable of timesteps at which intermediate
                results are stored.
            cond_imgs: conditioning image; defaults to a copy of ``img_org``.
            proc_type: forwarded to ``proc_cond_img``.
            text: forwarded to ``recover``.

        Returns:
            ``[ddf_comp, ddf_rand], [img_rec, img_diff, img_save],
            [msk_rec, msk_diff, msk_save]``.
        """
        if cond_imgs is None:
            cond_imgs = img_org.clone().detach()
        # NOTE(review): nesting reconstructed from whitespace-collapsed
        # source — conditioning is processed whether or not cond_imgs was
        # supplied; confirm against the base implementation.
        cond_imgs, mask_tgt, cond_ratio = self.proc_cond_img(cond_imgs, proc_type=proc_type)

        if ddf_rand is None:
            if v_scale is not None:
                # NOTE(review): the encoder re-init is assumed to belong to
                # the v_scale override — confirm.
                self.v_scale = v_scale
                self._DDF_Encoder_init()
            if T[0] is None or T[0] == 0:
                img_diff = img_org.clone().detach()
                ddf_rand = torch.zeros_like(img_diff)
            else:
                img_diff, _, ddf_rand = self._get_random_ddf(
                    img=img_org, t=torch.tensor(np.array([T[0]])).to(self.device))
        else:
            img_diff = self.img_stn(img_org.clone().detach(), ddf_rand)

        ddf_comp = ddf_rand.clone().detach()
        img_rec = img_diff.clone().detach()
        if msk_org is not None:
            msk_diff = self.msk_stn(msk_org.clone().detach(), ddf_rand)
        else:
            msk_diff = None
        msk_rec = msk_diff.clone().detach() if msk_org is not None else None

        img_save = []
        msk_save = []

        # OPT: hoist clone().detach() outside the loop — grid_sample is read-only
        img_org_ref = img_org.clone().detach()
        msk_org_ref = msk_org.clone().detach() if msk_org is not None else None

        if isinstance(self.network, DefRec_MutAttnNet):
            # This network type receives the whole descending schedule at once.
            t_list = list(range(T[1] - 1, -1, -1))
            pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t_list, rec_num=None, text=text)
            ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
            img_rec = self.img_stn(img_org_ref, ddf_comp)
            if msk_org is not None:
                msk_rec = self.msk_stn(msk_org_ref, ddf_comp)
        else:
            if isinstance(T[-1], int):
                time_steps = range(T[-1] - 1, -1, -1)
                num_trainable = 0
            else:
                time_steps = T[-1]
                # The last k scheduled steps keep gradients.
                k = 2
                num_trainable = min(k, len(time_steps))

            # OPT: set lookup for saved timesteps, index threshold for trainable steps
            t_save_set = set(t_save) if t_save is not None else None
            trainable_start_idx = len(time_steps) - num_trainable

            for step_idx, i in enumerate(time_steps):
                # OPT: create tensor directly on device, no numpy intermediate
                t = torch.tensor([i], device=self.device)

                if step_idx >= trainable_start_idx:
                    pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
                else:
                    # OPT: no_grad for frozen iterations (inference_mode not safe here
                    # because ddf_comp is composed across frozen+trainable iterations)
                    with torch.no_grad():
                        pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)

                ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
                # OPT: use pre-cloned reference instead of cloning each iteration
                img_rec = self.img_stn(img_org_ref, ddf_comp)
                if msk_org is not None:
                    msk_rec = self.msk_stn(msk_org_ref, ddf_comp)

                if t_save_set is not None and i in t_save_set:
                    img_save.append(img_rec)
                    if msk_org is not None:
                        msk_save.append(msk_rec)

        return [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save]