# Omini3D/Diffusion/diffuser_opt.py
"""
diffuser_opt.py — Optimized DeformDDPM subclass.
Inherits from Diffusion.diffuser.DeformDDPM and overrides only the methods
that benefit from optimization.
Key optimizations:
1. diff_recover(): hoist img_org/msk_org .clone().detach() outside the loop,
pre-compute timestep tensors, use torch.no_grad() for frozen steps
2. _random_ddf_generate(): scaling-and-squaring for O(log n) composition
instead of O(n), crop-first upsampling (4x faster), on-device tensors.
3. proc_cond_img(): skip clone for 'uncon' path (most common, ~3/8 weight)
4. _DDF_Encoder_init(): use OptSTN (register_buffer, no per-call .to(device))
5. recover(): fix t tensor bug (was staying on CPU), avoid redundant torch.tensor()
6. _multiscale_dvf_generate(): generate random tensors on device to avoid
CPU→GPU transfer of 3D volumes.
"""
from torch import nn
import torch
import numpy as np
import torch.nn.functional as F
import random
import math
import Diffusion.utils_diff as utils
from Diffusion.diffuser import DeformDDPM as _BaseDeformDDPM
from Diffusion.networks import *
from Diffusion.networks_opt import OptSTN
EPS = 1e-8
class DeformDDPM(_BaseDeformDDPM):
"""Drop-in replacement for DeformDDPM with speed optimizations."""
# ------------------------------------------------------------------
# Optimization 4: use OptSTN (register_buffer, no per-call .to())
# ------------------------------------------------------------------
def _DDF_Encoder_init(self, ctl_ratio=4, ctl_sz=None, resample_mode=None):
if ctl_sz is None:
ctl_sz = self.image_chw[1] // ctl_ratio
self.ctl_sz = ctl_sz
self.img_sz = self.image_chw[1]
# OPT: use OptSTN instead of STN — register_buffer for ref_grid/max_sz
self.ddf_stn_rec = OptSTN(img_sz=ctl_sz, ndims=self.ndims, device=self.device,
padding_mode=self.ddf_pad_mode)
self.img_stn = OptSTN(img_sz=self.img_sz, ndims=self.ndims, device=self.device,
padding_mode=self.img_pad_mode, resample_mode=self.resample_mode)
self.msk_stn = OptSTN(img_sz=self.img_sz, ndims=self.ndims, device=self.device,
padding_mode=self.img_pad_mode, resample_mode='nearest')
def __init__(self, network, n_steps=50, beta_schedule_fn=None, device='cpu',
image_chw=(1, 28, 28), batch_size=1, img_pad_mode="zeros",
ddf_pad_mode="border", padding_mode="border",
v_scale=0.008/256, resample_mode=None, inf_mode=False):
# Call parent __init__ — it creates STN instances
super().__init__(
network=network, n_steps=n_steps, beta_schedule_fn=beta_schedule_fn,
device=device, image_chw=image_chw, batch_size=batch_size,
img_pad_mode=img_pad_mode, ddf_pad_mode=ddf_pad_mode,
padding_mode=padding_mode, v_scale=v_scale, resample_mode=resample_mode,
inf_mode=inf_mode,
)
# OPT: replace ddf_stn_full with OptSTN too
self.ddf_stn_full = OptSTN(
img_sz=self.image_chw[1], ndims=self.ndims,
padding_mode=self.padding_mode, device=self.device,
)
# ------------------------------------------------------------------
# Optimization 5: fix recover() t tensor bug + avoid redundant copies
# ------------------------------------------------------------------
def recover(self, x, y, t, rec_num=2, text=None):
# OPT: don't recreate t if already a tensor on the right device
if isinstance(t, list):
t = [t0 if isinstance(t0, torch.Tensor) else torch.tensor(t0, device=x.device)
for t0 in t]
t = [t0.to(x.device) if t0.device != x.device else t0 for t0 in t]
elif isinstance(t, torch.Tensor):
# OPT: skip torch.tensor() copy — just ensure correct device
if t.device != x.device:
t = t.to(x.device)
else:
t = torch.tensor(t, device=x.device)
if rec_num is None:
rec_num = self.rec_num
return self.network(x=x, y=y, t=t, rec_num=rec_num, text=text)
# ------------------------------------------------------------------
# Optimization 2: scaling-and-squaring + crop-first upsample
# ------------------------------------------------------------------
def _compose_n_times(self, dvf, n):
"""Compute n-fold self-composition of dvf using scaling-and-squaring.
Uses binary decomposition: O(log n) STN calls instead of O(n).
E.g. n=87 → ~10 calls, n=200 → ~9 calls (vs 87/200 iterative calls).
The result is the same deformation (n-fold composition) but computed
via a different sequence of grid_sample interpolations, so there are
small numerical differences (~1e-2 to 1e-1) vs iterative composition.
This is acceptable because DDF generation is stochastic augmentation.
"""
if n <= 0:
return torch.zeros_like(dvf)
result = None
current = dvf # current = dvf^(2^i), starts as dvf^1
while n > 0:
if n & 1: # bit is set → accumulate this power
if result is None:
result = current.clone()
else:
# result = current ∘ result (apply result first, then current)
result = result + self.ddf_stn_rec(current, result)
n >>= 1
if n > 0:
# Square: current = current ∘ current
current = current + self.ddf_stn_rec(current, current)
return result
def _crop_upsample(self, field):
"""Upsample DDF from ctl_sz to img_sz with 2x oversampling + center crop.
Instead of upsampling the full ctl_sz→img_sz*2 (e.g. 32³→256³) then
cropping to img_sz (128³), we crop the control-point field first
(to ~20³) then upsample to ~160³ and crop to 128³. This is 4x faster
and bit-identical because trilinear interpolation is local.
"""
crop_rate = 2
upscale = self.img_sz * crop_rate // self.ctl_sz # e.g. 8
margin = 2 # voxels of margin for interpolation boundary
lo = self.ctl_sz // 4 - margin # e.g. 6
hi = self.ctl_sz * 3 // 4 + margin # e.g. 26
crop_sz = hi - lo # e.g. 20
up_sz = crop_sz * upscale # e.g. 160
pad = (up_sz - self.img_sz) // 2 # e.g. 16
mode = 'bilinear' if self.ndims == 2 else 'trilinear'
if self.ndims == 2:
field_crop = field[..., lo:hi, lo:hi] * self.img_sz / self.ctl_sz
field_up = F.interpolate(field_crop, up_sz, mode=mode)
return field_up[..., pad:pad + self.img_sz, pad:pad + self.img_sz]
else:
field_crop = field[..., lo:hi, lo:hi, lo:hi] * self.img_sz / self.ctl_sz
field_up = F.interpolate(field_crop, up_sz, mode=mode)
return field_up[..., pad:pad + self.img_sz,
pad:pad + self.img_sz,
pad:pad + self.img_sz]
def _random_ddf_generate(self, rec_num=3, mul_num=[torch.tensor([5]), torch.tensor([5])],
ddf0=None, keep_inverse=False, noise_ratio=0.08, select_num=3, flip_ratio=0.5):
for _ in range(self.ndims + 1):
mul_num = [torch.unsqueeze(n, -1) for n in mul_num]
ctl_ddf_sz = [self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
if ddf0 is not None:
ddf = ddf0
else:
ddf = torch.zeros(ctl_ddf_sz, device=self.device)
dddf = torch.zeros(ctl_ddf_sz, device=self.device)
scale_num = min(8, int(math.log2(self.ctl_sz)))
ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]
for i in range(rec_num):
if len(ctl_szs_all) > select_num:
ctl_szs = random.sample(ctl_szs_all, select_num)
else:
ctl_szs = ctl_szs_all
dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs)
if noise_ratio == 0:
dvf0 = dvf
else:
dvf0 = dvf + self.ddf_stn_rec(
self._multiscale_dvf_generate(self.v_scale * noise_ratio, ctl_szs=ctl_szs, rand_v_scale=False),
dvf)
mul_num_ddf_val = int(torch.max(mul_num[0]).item())
mul_num_dvf_val = int(torch.max(mul_num[1]).item())
# OPT: scaling-and-squaring — O(log n) STN calls instead of O(n)
# For t=40: 10 calls instead of 80. For t=79: 9 calls instead of 195.
ddf = self._compose_n_times(dvf0, mul_num_ddf_val)
dddf = self._compose_n_times(dvf, mul_num_dvf_val)
# OPT: crop-first upsample — 4x fewer voxels to interpolate (bit-identical)
ddf = self._crop_upsample(ddf)
dddf = self._crop_upsample(dddf)
return ddf, dddf
# ------------------------------------------------------------------
# Optimization 6: generate DVF on device to avoid CPU→GPU transfer
# ------------------------------------------------------------------
def _multiscale_dvf_generate(self, v_scale, ctl_szs=[4, 8, 16, 32, 64], rand_v_scale=True):
dvf = 0
if self.img_sz is None:
self.img_sz = max(ctl_szs)
if 1 in ctl_szs:
dvf_rot = utils.random_ddf(
batch_size=self.batch_size, ndims=self.ndims,
img_sz=[self.ctl_sz] * self.ndims, range_gauss=0, rot_range=np.pi / 90)
dvf = dvf + dvf_rot
for ctl_sz in ctl_szs:
_v_scale = self._sample_random_uniform_multi_order(
high=v_scale, low=0., order_num=random.choice([1, 1, 2])) if rand_v_scale else v_scale
if ctl_sz <= 2:
_v_scale = _v_scale / 2
# OPT: generate random tensor directly on device
dvf_comp = torch.randn([self.batch_size, self.ndims] + [ctl_sz] * self.ndims,
device=self.device) * _v_scale
dvf_comp = F.interpolate(dvf_comp * self.ctl_sz / ctl_sz, [self.ctl_sz] * self.ndims,
align_corners=False,
mode='bilinear' if self.ndims == 2 else 'trilinear')
dvf = dvf + dvf_comp
return dvf
# ------------------------------------------------------------------
# Optimization 3: skip clone for 'uncon' (most common conditioning type)
# ------------------------------------------------------------------
def proc_cond_img(self, img, proc_type=None, noise_scale=0.1):
if proc_type is None:
proc_type = random.choices(
['adding', 'independ', 'downsample', 'slice', 'slice1', 'none', 'uncon'],
weights=[1, 1, 1, 1, 1, 3], k=1
)[0]
mask = torch.tensor(1, device=img.device)
cond_ratio = torch.tensor(1., device=img.device)
self.msk_noise_scale = torch.tensor(0, device=img.device)
noise_type = random.choice(['gaussian', 'uniform', 'none'])
if proc_type not in ['none', None, '']:
# OPT: handle 'uncon' before cloning — no need to clone img
if proc_type == 'uncon':
noise_map = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
proc_img = noise_map
mask = torch.tensor(0, device=img.device)
cond_ratio = torch.tensor(0, device=img.device)
return proc_img, mask, cond_ratio
# Only clone when we actually need the image data
proc_img = img.clone().detach()
noise_map = None
if proc_type in ['adding', 'independ', 'slice', 'slice1']:
noise_map = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
if proc_type == 'adding':
proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
elif proc_type == 'independ':
mask = self.create_noise_map(img, noise_type='binary')
if self.msk_noise_scale == 0:
proc_img = img * mask
else:
proc_img = self.apply_noise(proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
with torch.no_grad():
cond_ratio = mask.float().mean()
elif proc_type == 'downsample':
proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1. / 64, 1])
cond_ratio = torch.tensor(down_ratio, device=img.device)
elif proc_type == 'slice' or proc_type == 'slice1':
if proc_type == 'slice1':
slice_num_max = 1
else:
slice_num_max = random.randint(1, 64)
slice_num_max = random.randint(1, slice_num_max)
mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
if self.msk_noise_scale == 0:
proc_img = img * mask
else:
proc_img = self.apply_noise(proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
cond_ratio = torch.tensor(sample_ratio, device=img.device)
elif proc_type == 'project':
proc_img, proj_num = self.project(proc_img)
cond_ratio = torch.tensor(proj_num / (128 * self.ndims), device=img.device)
return proc_img, mask, cond_ratio
else:
# 'none' type — still need clone
proc_img = img.clone().detach()
return proc_img, mask, cond_ratio
# ------------------------------------------------------------------
# Optimization 1: hoist clone, pre-compute timestep tensors,
# use inference_mode for frozen iterations
# ------------------------------------------------------------------
def diff_recover(self, img_org, msk_org=None, T=[None, None], ddf_rand=None,
v_scale=None, t_save=None, cond_imgs=None, proc_type=None, text=None):
if cond_imgs is None:
cond_imgs = img_org.clone().detach()
cond_imgs, mask_tgt, cond_ratio = self.proc_cond_img(cond_imgs, proc_type=proc_type)
if ddf_rand is None:
if v_scale is not None:
self.v_scale = v_scale
self._DDF_Encoder_init()
if T[0] is None or T[0] == 0:
img_diff = img_org.clone().detach()
ddf_rand = torch.zeros_like(img_diff)
else:
img_diff, _, ddf_rand = self._get_random_ddf(
img=img_org, t=torch.tensor(np.array([T[0]])).to(self.device))
else:
img_diff = self.img_stn(img_org.clone().detach(), ddf_rand)
ddf_comp = ddf_rand.clone().detach()
img_rec = img_diff.clone().detach()
if msk_org is not None:
msk_diff = self.msk_stn(msk_org.clone().detach(), ddf_rand)
else:
msk_diff = None
msk_rec = msk_diff.clone().detach() if msk_org is not None else None
img_save = []
msk_save = []
# OPT: hoist clone().detach() outside the loop — grid_sample is read-only
img_org_ref = img_org.clone().detach()
msk_org_ref = msk_org.clone().detach() if msk_org is not None else None
if isinstance(self.network, DefRec_MutAttnNet):
t_list = list(range(T[1] - 1, -1, -1))
pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t_list, rec_num=None, text=text)
ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
img_rec = self.img_stn(img_org_ref, ddf_comp)
if msk_org is not None:
msk_rec = self.msk_stn(msk_org_ref, ddf_comp)
else:
if isinstance(T[-1], int):
time_steps = range(T[-1] - 1, -1, -1)
trainable_iterations = []
else:
time_steps = T[-1]
k = 2
trainable_iterations = time_steps[-1:-k - 1:-1]
# OPT: pre-compute trainable index threshold — avoid unhashable list issue
t_save_set = set(t_save) if t_save is not None else None
num_time_steps = len(time_steps) if not isinstance(time_steps, range) else len(time_steps)
trainable_start_idx = num_time_steps - len(trainable_iterations)
for step_idx, i in enumerate(time_steps):
# OPT: create tensor directly on device, no numpy intermediate
t = torch.tensor([i], device=self.device)
if step_idx >= trainable_start_idx:
pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
else:
# OPT: no_grad for frozen iterations (inference_mode not safe here
# because ddf_comp is composed across frozen+trainable iterations)
with torch.no_grad():
pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
# OPT: use pre-cloned reference instead of cloning each iteration
img_rec = self.img_stn(img_org_ref, ddf_comp)
if msk_org is not None:
msk_rec = self.msk_stn(msk_org_ref, ddf_comp)
if t_save_set is not None:
if i in t_save_set:
img_save.append(img_rec)
if msk_org is not None:
msk_save.append(msk_rec)
return [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save]