"""
diffuser_opt.py — Optimized DeformDDPM subclass.

Inherits from Diffusion.diffuser.DeformDDPM and overrides only the methods
that benefit from optimization.

Key optimizations:
1. diff_recover(): hoist the img_org/msk_org .clone().detach() out of the loop,
   create timestep tensors directly on the device, and run frozen steps under
   torch.no_grad().
2. _random_ddf_generate(): scaling-and-squaring for O(log n) self-composition
   instead of O(n), crop-first upsampling (far fewer voxels to interpolate),
   and on-device tensors.
3. proc_cond_img(): skip the clone for the 'uncon' path (the most heavily
   weighted conditioning type).
4. _DDF_Encoder_init(): use OptSTN (register_buffer, no per-call .to(device)).
5. recover(): fix the t tensor bug (t previously stayed on the CPU) and avoid
   redundant torch.tensor() copies.
6. _multiscale_dvf_generate(): generate random tensors on the device to avoid
   CPU→GPU transfers of 3D volumes.
"""
from torch import nn
import torch
import numpy as np
import torch.nn.functional as F
import random
import math
import Diffusion.utils_diff as utils
from Diffusion.diffuser import DeformDDPM as _BaseDeformDDPM
from Diffusion.networks import *
from Diffusion.networks_opt import OptSTN
EPS = 1.0e-8  # small epsilon constant; not referenced inside this module (presumably imported elsewhere — confirm)
class DeformDDPM(_BaseDeformDDPM):
    """Drop-in replacement for ``Diffusion.diffuser.DeformDDPM`` with speed optimizations.

    Only the methods defined here are overridden; everything else is inherited
    from the base class unchanged.
    """

    # ------------------------------------------------------------------
    # Optimization 4: use OptSTN (register_buffer, no per-call .to())
    # ------------------------------------------------------------------
    def _DDF_Encoder_init(self, ctl_ratio=4, ctl_sz=None, resample_mode=None):
        """Set the control-grid / image sizes and build the spatial transformers.

        Args:
            ctl_ratio: Divisor applied to ``image_chw[1]`` to obtain the
                control-grid size when ``ctl_sz`` is not given.
            ctl_sz: Explicit control-grid side length; overrides ``ctl_ratio``.
            resample_mode: Unused; kept for signature compatibility. The image
                transformer always uses ``self.resample_mode``.
        """
        if ctl_sz is None:
            ctl_sz = self.image_chw[1] // ctl_ratio
        self.ctl_sz = ctl_sz
        self.img_sz = self.image_chw[1]
        # OPT: use OptSTN instead of STN — register_buffer for ref_grid/max_sz.
        # Warps/composes displacement fields on the control grid.
        self.ddf_stn_rec = OptSTN(img_sz=ctl_sz, ndims=self.ndims, device=self.device,
                                  padding_mode=self.ddf_pad_mode)
        # Warps images at full resolution.
        self.img_stn = OptSTN(img_sz=self.img_sz, ndims=self.ndims, device=self.device,
                              padding_mode=self.img_pad_mode, resample_mode=self.resample_mode)
        # Warps masks with nearest resampling so mask values are not interpolated.
        self.msk_stn = OptSTN(img_sz=self.img_sz, ndims=self.ndims, device=self.device,
                              padding_mode=self.img_pad_mode, resample_mode='nearest')

    def __init__(self, network, n_steps=50, beta_schedule_fn=None, device='cpu',
                 image_chw=(1, 28, 28), batch_size=1, img_pad_mode="zeros",
                 ddf_pad_mode="border", padding_mode="border",
                 v_scale=0.008/256, resample_mode=None, inf_mode=False):
        """Initialise the base model, then replace ``ddf_stn_full`` with an OptSTN.

        All arguments are forwarded unchanged to the base ``DeformDDPM``.
        """
        # The parent __init__ creates the STN instances (presumably via the
        # overridden _DDF_Encoder_init above — confirm in the base class).
        super().__init__(
            network=network, n_steps=n_steps, beta_schedule_fn=beta_schedule_fn,
            device=device, image_chw=image_chw, batch_size=batch_size,
            img_pad_mode=img_pad_mode, ddf_pad_mode=ddf_pad_mode,
            padding_mode=padding_mode, v_scale=v_scale, resample_mode=resample_mode,
            inf_mode=inf_mode,
        )
        # OPT: replace ddf_stn_full (full-resolution DDF composition) with OptSTN too.
        self.ddf_stn_full = OptSTN(
            img_sz=self.image_chw[1], ndims=self.ndims,
            padding_mode=self.padding_mode, device=self.device,
        )

    # ------------------------------------------------------------------
    # Optimization 5: fix recover() t tensor bug + avoid redundant copies
    # ------------------------------------------------------------------
    def recover(self, x, y, t, rec_num=2, text=None):
        """Run the network once to predict a recovery displacement field.

        Args:
            x: Current (deformed) image.
            y: Conditioning image.
            t: Timestep(s): a scalar / array-like, a tensor, or a list of these
                (a list is used for multi-step networks). Every value is placed
                on ``x.device``; tensors already there are passed through
                without a copy.
            rec_num: Forwarded to the network; ``None`` means ``self.rec_num``.
            text: Optional text conditioning forwarded to the network.

        Returns:
            Whatever ``self.network`` returns for these arguments.
        """
        # OPT: as_tensor reuses tensors already on the right device and only
        # moves/creates when needed (previously t could stay on the CPU).
        if isinstance(t, list):
            t = [torch.as_tensor(t0, device=x.device) for t0 in t]
        else:
            t = torch.as_tensor(t, device=x.device)
        if rec_num is None:
            rec_num = self.rec_num
        return self.network(x=x, y=y, t=t, rec_num=rec_num, text=text)

    # ------------------------------------------------------------------
    # Optimization 2: scaling-and-squaring + crop-first upsample
    # ------------------------------------------------------------------
    def _compose_n_times(self, dvf, n):
        """Return the n-fold self-composition of ``dvf`` via scaling-and-squaring.

        Uses the binary decomposition of ``n``: O(log n) STN calls instead of
        O(n). E.g. n=87 -> 10 calls, n=200 -> 9 calls.

        All powers of the same map commute, so the result is the same
        deformation as iterative composition. It is reached through a
        different sequence of grid_sample interpolations, so small numerical
        differences are expected (reported ~1e-2 to 1e-1). This is acceptable
        because DDF generation is stochastic augmentation.

        Args:
            dvf: Displacement field on the control grid.
            n: Number of compositions (int). ``n <= 0`` returns zeros (identity).
        """
        if n <= 0:
            return torch.zeros_like(dvf)
        result = None
        current = dvf  # current = dvf^(2^i), starts as dvf^1
        while n > 0:
            if n & 1:  # bit is set -> accumulate this power
                if result is None:
                    result = current.clone()
                else:
                    # result = current ∘ result (apply result first, then current)
                    result = result + self.ddf_stn_rec(current, result)
            n >>= 1
            if n > 0:
                # Square: current = current ∘ current
                current = current + self.ddf_stn_rec(current, current)
        return result

    def _crop_upsample(self, field):
        """Scale a control-grid field to image units and upsample it to ``img_sz``.

        Reference behaviour: multiply by ``img_sz / ctl_sz``, upsample the
        whole control grid by ``upscale = 2 * img_sz // ctl_sz`` and keep the
        central ``img_sz`` voxels per spatial axis. Linear interpolation only
        reads neighbouring control points, so only the central part of the grid
        (plus ``margin`` control points) is upsampled, e.g. 20^3 -> 160^3
        instead of 32^3 -> 256^3. For power-of-two sizes (e.g. ctl_sz=32,
        img_sz=128) the output is expected to match the reference exactly.
        """
        crop_rate = 2
        upscale = self.img_sz * crop_rate // self.ctl_sz  # e.g. 8
        margin = 2  # control points of margin for the interpolation boundary
        # Clamp to the grid: for ctl_sz < 8 (e.g. 28-px images -> ctl_sz 7) the
        # unclamped start was negative, which Python slicing reads as "from the
        # end" and silently produced a 1-voxel crop.
        lo = max(self.ctl_sz // 4 - margin, 0)             # e.g. 6
        hi = min(self.ctl_sz * 3 // 4 + margin, self.ctl_sz)  # e.g. 26
        up_sz = (hi - lo) * upscale                          # e.g. 160
        # Start of the central img_sz window in full-upsample coordinates,
        # shifted by the crop offset (lo control points = lo*upscale voxels).
        # Deriving it from `lo` keeps the window centred even when ctl_sz is
        # not a multiple of 4.
        start = (self.ctl_sz * upscale - self.img_sz) // 2 - lo * upscale  # e.g. 16
        stop = start + self.img_sz
        mode = 'bilinear' if self.ndims == 2 else 'trilinear'
        crop_idx = (Ellipsis,) + (slice(lo, hi),) * self.ndims
        field_crop = field[crop_idx] * self.img_sz / self.ctl_sz
        field_up = F.interpolate(field_crop, up_sz, mode=mode, align_corners=False)
        return field_up[(Ellipsis,) + (slice(start, stop),) * self.ndims]

    def _random_ddf_generate(self, rec_num=3, mul_num=None,
                             ddf0=None, keep_inverse=False, noise_ratio=0.08, select_num=3, flip_ratio=0.5):
        """Sample a random forward DDF pair on the control grid and upsample it.

        Args:
            rec_num: Number of sampling rounds. NOTE(review): every round
                overwrites ``ddf``/``dddf``, so only the last round affects the
                result (earlier rounds only consume RNG state), and ``ddf0`` is
                only used when ``rec_num == 0``. Confirm whether rounds were
                meant to compose.
            mul_num: ``[n_ddf, n_dvf]`` self-composition counts (tensors or
                ints). Defaults to ``[tensor([5]), tensor([5])]``. The maximum
                over all elements is used for the whole batch.
                NOTE(review): with batch_size > 1 and differing per-sample
                counts, every sample gets the largest count — confirm intended.
            ddf0: Initial control-grid DDF (see ``rec_num`` note).
            keep_inverse: Unused; kept for signature compatibility.
            noise_ratio: Relative magnitude of the extra noise field composed
                into the step used for ``ddf`` (0 disables it).
            select_num: Number of grid sizes randomly sampled per round.
            flip_ratio: Unused; kept for signature compatibility.

        Returns:
            ``(ddf, dddf)`` at image resolution: ``ddf`` is the ``n_ddf``-fold
            self-composition of the noisy step, ``dddf`` the ``n_dvf``-fold
            self-composition of the clean step.
        """
        if mul_num is None:
            mul_num = [torch.tensor([5]), torch.tensor([5])]
        # The composition counts are loop-invariant; resolve them once.
        n_ddf = int(torch.as_tensor(mul_num[0]).max().item())
        n_dvf = int(torch.as_tensor(mul_num[1]).max().item())

        ctl_ddf_sz = [self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
        ddf = ddf0 if ddf0 is not None else torch.zeros(ctl_ddf_sz, device=self.device)
        dddf = torch.zeros(ctl_ddf_sz, device=self.device)

        scale_num = min(8, int(math.log2(self.ctl_sz)))
        ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]
        for _ in range(rec_num):
            if len(ctl_szs_all) > select_num:
                ctl_szs = random.sample(ctl_szs_all, select_num)
            else:
                ctl_szs = ctl_szs_all
            dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs)
            if noise_ratio == 0:
                dvf0 = dvf
            else:
                dvf0 = dvf + self.ddf_stn_rec(
                    self._multiscale_dvf_generate(self.v_scale * noise_ratio, ctl_szs=ctl_szs, rand_v_scale=False),
                    dvf)
            # OPT: scaling-and-squaring — O(log n) STN calls instead of O(n).
            ddf = self._compose_n_times(dvf0, n_ddf)
            dddf = self._compose_n_times(dvf, n_dvf)
        # OPT: crop-first upsample — far fewer voxels to interpolate.
        ddf = self._crop_upsample(ddf)
        dddf = self._crop_upsample(dddf)
        return ddf, dddf

    # ------------------------------------------------------------------
    # Optimization 6: generate DVF on device to avoid CPU→GPU transfer
    # ------------------------------------------------------------------
    def _multiscale_dvf_generate(self, v_scale, ctl_szs=(4, 8, 16, 32, 64), rand_v_scale=True):
        """Sum random displacement components drawn at several grid sizes.

        For each size ``s`` in ``ctl_szs`` a Gaussian field of shape
        ``(batch_size, ndims, s, ..., s)`` is drawn directly on ``self.device``.
        It is scaled, multiplied by ``ctl_sz / s``, interpolated up to the
        ``ctl_sz`` control grid and added to the sum. Sets ``self.img_sz`` to
        ``max(ctl_szs)`` if it is still ``None``.

        Args:
            v_scale: Component magnitude; with ``rand_v_scale`` it is the upper
                bound passed to ``_sample_random_uniform_multi_order``.
            ctl_szs: Grid sizes of the components. If it contains 1, a field
                from ``utils.random_ddf`` (rot_range=pi/90) is added as well.
            rand_v_scale: Randomise the magnitude of each component.

        Returns:
            The summed DVF on the control grid (the int 0 if ``ctl_szs`` is empty).
        """
        dvf = 0
        if self.img_sz is None:
            self.img_sz = max(ctl_szs)
        if 1 in ctl_szs:
            # NOTE(review): the device/type of utils.random_ddf's output is not
            # visible here; if it is a CPU tensor, adding the on-device
            # components below would fail — confirm.
            dvf_rot = utils.random_ddf(
                batch_size=self.batch_size, ndims=self.ndims,
                img_sz=[self.ctl_sz] * self.ndims, range_gauss=0, rot_range=np.pi / 90)
            dvf = dvf + dvf_rot
        for grid_sz in ctl_szs:
            if rand_v_scale:
                _v_scale = self._sample_random_uniform_multi_order(
                    high=v_scale, low=0., order_num=random.choice([1, 1, 2]))
            else:
                _v_scale = v_scale
            if grid_sz <= 2:
                # The coarsest grids get half the magnitude.
                _v_scale = _v_scale / 2
            # OPT: generate the random tensor directly on the device.
            dvf_comp = torch.randn([self.batch_size, self.ndims] + [grid_sz] * self.ndims,
                                   device=self.device) * _v_scale
            dvf_comp = F.interpolate(dvf_comp * self.ctl_sz / grid_sz, [self.ctl_sz] * self.ndims,
                                     align_corners=False,
                                     mode='bilinear' if self.ndims == 2 else 'trilinear')
            dvf = dvf + dvf_comp
        return dvf

    # ------------------------------------------------------------------
    # Optimization 3: skip clone for 'uncon' (most heavily weighted type)
    # ------------------------------------------------------------------
    def proc_cond_img(self, img, proc_type=None, noise_scale=0.1):
        """Build a conditioning image from ``img`` according to ``proc_type``.

        Args:
            img: Source image tensor.
            proc_type: 'adding', 'independ', 'downsample', 'slice', 'slice1',
                'project', 'uncon', or 'none' / '' (unmodified copy). ``None``
                samples one of the first seven names at random, with 'uncon'
                weighted highest.
            noise_scale: Forwarded to ``create_noise_map``.

        Returns:
            ``(proc_img, mask, cond_ratio)``:
            - ``proc_img``: the processed conditioning image.
            - ``mask``: the mask multiplied into ``img`` for 'independ'/'slice*'.
              Otherwise it is a scalar tensor: 1, or 0 for 'uncon'.
            - ``cond_ratio``: a scalar tensor measuring how much of ``img`` is
              retained. It is 1 for a plain copy and 0 for 'uncon'.
        """
        if proc_type is None:
            # One weight per type. The previous 6-entry weight list for 7 types
            # made random.choices raise ValueError on every call.
            proc_types = ('adding', 'independ', 'downsample', 'slice', 'slice1', 'none', 'uncon')
            proc_weights = (1, 1, 1, 1, 1, 1, 3)
            proc_type = random.choices(proc_types, weights=proc_weights, k=1)[0]
        mask = torch.tensor(1, device=img.device)
        cond_ratio = torch.tensor(1., device=img.device)
        # Reset to 0 on every call, so the apply_noise branches below are
        # currently never taken.
        self.msk_noise_scale = torch.tensor(0, device=img.device)
        noise_type = random.choice(['gaussian', 'uniform', 'none'])

        if proc_type in ('none', None, ''):
            # 'none': the conditioning image is an unmodified copy.
            return img.clone().detach(), mask, cond_ratio

        if proc_type == 'uncon':
            # OPT: unconditional — only noise is needed, so img is not cloned.
            proc_img = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
            mask = torch.tensor(0, device=img.device)
            cond_ratio = torch.tensor(0, device=img.device)
            return proc_img, mask, cond_ratio

        # Only clone when the image data is actually needed.
        proc_img = img.clone().detach()
        noise_map = None
        if proc_type in ('adding', 'independ', 'slice', 'slice1'):
            noise_map = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
        if proc_type == 'adding':
            proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
            cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
        elif proc_type == 'independ':
            mask = self.create_noise_map(img, noise_type='binary')
            if self.msk_noise_scale == 0:
                proc_img = img * mask
            else:
                proc_img = self.apply_noise(proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
            with torch.no_grad():
                cond_ratio = mask.float().mean()
        elif proc_type == 'downsample':
            proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1. / 64, 1])
            cond_ratio = torch.tensor(down_ratio, device=img.device)
        elif proc_type in ('slice', 'slice1'):
            if proc_type == 'slice1':
                slice_num_max = 1
            else:
                slice_num_max = random.randint(1, 64)
            slice_num_max = random.randint(1, slice_num_max)
            mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
            if self.msk_noise_scale == 0:
                proc_img = img * mask
            else:
                proc_img = self.apply_noise(proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
            cond_ratio = torch.tensor(sample_ratio, device=img.device)
        elif proc_type == 'project':
            proc_img, proj_num = self.project(proc_img)
            # NOTE(review): 128 looks like a hard-coded image size — confirm.
            cond_ratio = torch.tensor(proj_num / (128 * self.ndims), device=img.device)
        return proc_img, mask, cond_ratio

    # ------------------------------------------------------------------
    # Optimization 1: hoist clone, create timestep tensors on device,
    # use torch.no_grad() for frozen iterations
    # ------------------------------------------------------------------
    def diff_recover(self, img_org, msk_org=None, T=(None, None), ddf_rand=None,
                     v_scale=None, t_save=None, cond_imgs=None, proc_type=None, text=None):
        """Deform ``img_org`` with a random DDF, then recover it with the network.

        Args:
            img_org: Clean input image, presumably ``(B, C, *spatial)`` — confirm.
            msk_org: Optional mask warped alongside the image (nearest resampling).
            T: ``(T_forward, T_reverse)``. ``T[0]`` is the timestep used to
                sample the forward DDF (None/0 -> identity). ``T[-1]`` is either
                an int (reverse steps ``T[-1]-1 .. 0``) or an explicit sequence
                of timesteps.
            ddf_rand: Precomputed forward DDF. When given, ``T[0]`` and
                ``v_scale`` are not used.
            v_scale: When given (and ``ddf_rand`` is None), overrides
                ``self.v_scale`` and rebuilds the transformers.
            t_save: Timesteps whose intermediate reconstructions are collected
                (iterative path only).
            cond_imgs: Conditioning images. When None, they are derived from
                ``img_org`` via ``proc_cond_img``.
            proc_type: Conditioning type forwarded to ``proc_cond_img``.
            text: Optional text conditioning forwarded to the network.

        Returns:
            ``[ddf_comp, ddf_rand], [img_rec, img_diff, img_save],
            [msk_rec, msk_diff, msk_save]`` (mask entries are None / empty
            when ``msk_org`` is None).

        In the iterative path, only the last two steps of an explicit ``T[-1]``
        sequence run with autograd (none for an int ``T[-1]``). All other
        steps run under ``torch.no_grad()``.
        """
        if cond_imgs is None:
            cond_imgs = img_org.clone().detach()
            cond_imgs, _, _ = self.proc_cond_img(cond_imgs, proc_type=proc_type)

        if ddf_rand is None:
            if v_scale is not None:
                self.v_scale = v_scale
                self._DDF_Encoder_init()
            if T[0] is None or T[0] == 0:
                img_diff = img_org.clone().detach()
                # Identity DDF with one displacement channel per spatial dim,
                # like every other DDF here. zeros_like(img) gave the image's
                # channel count instead.
                ddf_rand = img_diff.new_zeros((img_diff.shape[0], self.ndims, *img_diff.shape[2:]))
            else:
                img_diff, _, ddf_rand = self._get_random_ddf(
                    img=img_org, t=torch.tensor(np.array([T[0]])).to(self.device))
        else:
            img_diff = self.img_stn(img_org.clone().detach(), ddf_rand)

        ddf_comp = ddf_rand.clone().detach()
        img_rec = img_diff.clone().detach()
        if msk_org is not None:
            msk_diff = self.msk_stn(msk_org.clone().detach(), ddf_rand)
            msk_rec = msk_diff.clone().detach()
        else:
            msk_diff = None
            msk_rec = None
        img_save = []
        msk_save = []

        # OPT: clone the references once, outside the loop; the STNs only read them.
        img_org_ref = img_org.clone().detach()
        msk_org_ref = msk_org.clone().detach() if msk_org is not None else None

        if isinstance(self.network, DefRec_MutAttnNet):
            # All reverse timesteps are handled by a single network call.
            t_list = list(range(T[1] - 1, -1, -1))
            pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t_list, rec_num=None, text=text)
            ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
            img_rec = self.img_stn(img_org_ref, ddf_comp)
            if msk_org is not None:
                msk_rec = self.msk_stn(msk_org_ref, ddf_comp)
        else:
            if isinstance(T[-1], int):
                time_steps = range(T[-1] - 1, -1, -1)
                num_trainable = 0
            else:
                time_steps = T[-1]
                # The last k steps of an explicit schedule keep their gradients.
                k = 2
                num_trainable = min(k, len(time_steps))
            # OPT: positional threshold instead of value membership in a list.
            trainable_start_idx = len(time_steps) - num_trainable
            t_save_set = set(t_save) if t_save is not None else None
            for step_idx, i in enumerate(time_steps):
                # OPT: create the tensor directly on the device, no numpy intermediate.
                t = torch.tensor([i], device=self.device)
                if step_idx >= trainable_start_idx:
                    pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
                else:
                    # no_grad rather than inference_mode: ddf_comp is composed
                    # across frozen and trainable iterations.
                    with torch.no_grad():
                        pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
                ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
                # OPT: warp the pre-cloned reference instead of cloning each iteration.
                img_rec = self.img_stn(img_org_ref, ddf_comp)
                if msk_org is not None:
                    msk_rec = self.msk_stn(msk_org_ref, ddf_comp)
                if t_save_set is not None and i in t_save_set:
                    img_save.append(img_rec)
                    if msk_org is not None:
                        msk_save.append(msk_rec)
        return [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save]