File size: 18,475 Bytes
2af0e94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
"""

diffuser_opt.py — Optimized DeformDDPM subclass.



Inherits from Diffusion.diffuser.DeformDDPM and overrides only the methods

that benefit from optimization.



Key optimizations:

  1. diff_recover(): hoist img_org/msk_org .clone().detach() outside the loop,

     pre-compute timestep tensors, use torch.no_grad() for frozen steps

  2. _random_ddf_generate(): scaling-and-squaring for O(log n) composition

     instead of O(n), crop-first upsampling (4x faster), on-device tensors.

  3. proc_cond_img(): skip clone for 'uncon' path (most common, ~3/8 weight)

  4. _DDF_Encoder_init(): use OptSTN (register_buffer, no per-call .to(device))

  5. recover(): fix t tensor bug (was staying on CPU), avoid redundant torch.tensor()

  6. _multiscale_dvf_generate(): generate random tensors on device to avoid

     CPU→GPU transfer of 3D volumes.

"""

from torch import nn
import torch
import numpy as np
import torch.nn.functional as F
import random
import math

import Diffusion.utils_diff as utils
from Diffusion.diffuser import DeformDDPM as _BaseDeformDDPM
from Diffusion.networks import *
from Diffusion.networks_opt import OptSTN

EPS = 1e-8


class DeformDDPM(_BaseDeformDDPM):
    """Drop-in replacement for DeformDDPM with speed optimizations."""

    # ------------------------------------------------------------------
    # Optimization 4: use OptSTN (register_buffer, no per-call .to())
    # ------------------------------------------------------------------
    def _DDF_Encoder_init(self, ctl_ratio=4, ctl_sz=None, resample_mode=None):
        """Create the OptSTN warpers for the control grid, images and masks.

        Sets ``self.ctl_sz``, ``self.img_sz``, ``self.ddf_stn_rec``,
        ``self.img_stn`` and ``self.msk_stn``.

        Args:
            ctl_ratio: Divisor applied to ``self.image_chw[1]`` to derive the
                control-grid size when ``ctl_sz`` is not given.
            ctl_sz: Explicit control-grid size; takes precedence over
                ``ctl_ratio``.
            resample_mode: Not read by this method; the image warper uses
                ``self.resample_mode`` instead.
        """
        full_res = self.image_chw[1]
        self.ctl_sz = full_res // ctl_ratio if ctl_sz is None else ctl_sz
        self.img_sz = full_res

        # Control-grid warper: used when composing displacement fields.
        self.ddf_stn_rec = OptSTN(
            img_sz=self.ctl_sz,
            ndims=self.ndims,
            device=self.device,
            padding_mode=self.ddf_pad_mode,
        )
        # Full-resolution image warper, with the configured resampling mode.
        self.img_stn = OptSTN(
            img_sz=self.img_sz,
            ndims=self.ndims,
            device=self.device,
            padding_mode=self.img_pad_mode,
            resample_mode=self.resample_mode,
        )
        # Mask warper: nearest-neighbour resampling.
        self.msk_stn = OptSTN(
            img_sz=self.img_sz,
            ndims=self.ndims,
            device=self.device,
            padding_mode=self.img_pad_mode,
            resample_mode='nearest',
        )

    def __init__(self, network, n_steps=50, beta_schedule_fn=None, device='cpu',
                 image_chw=(1, 28, 28), batch_size=1, img_pad_mode="zeros",
                 ddf_pad_mode="border", padding_mode="border",
                 v_scale=0.008/256, resample_mode=None, inf_mode=False):
        """Initialise via the base class, then swap ``ddf_stn_full`` for an OptSTN.

        Every argument is forwarded unchanged, by keyword, to the base-class
        constructor. Afterwards ``self.ddf_stn_full`` is replaced with an
        ``OptSTN`` built from ``self.image_chw[1]``, ``self.ndims``,
        ``self.padding_mode`` and ``self.device``.
        """
        base_kwargs = dict(
            network=network,
            n_steps=n_steps,
            beta_schedule_fn=beta_schedule_fn,
            device=device,
            image_chw=image_chw,
            batch_size=batch_size,
            img_pad_mode=img_pad_mode,
            ddf_pad_mode=ddf_pad_mode,
            padding_mode=padding_mode,
            v_scale=v_scale,
            resample_mode=resample_mode,
            inf_mode=inf_mode,
        )
        super().__init__(**base_kwargs)

        # Full-resolution warper used to compose displacement fields.
        self.ddf_stn_full = OptSTN(
            img_sz=self.image_chw[1],
            ndims=self.ndims,
            padding_mode=self.padding_mode,
            device=self.device,
        )

    # ------------------------------------------------------------------
    # Optimization 5: fix recover() t tensor bug + avoid redundant copies
    # ------------------------------------------------------------------
    def recover(self, x, y, t, rec_num=2, text=None):
        # OPT: don't recreate t if already a tensor on the right device
        if isinstance(t, list):
            t = [t0 if isinstance(t0, torch.Tensor) else torch.tensor(t0, device=x.device)
                 for t0 in t]
            t = [t0.to(x.device) if t0.device != x.device else t0 for t0 in t]
        elif isinstance(t, torch.Tensor):
            # OPT: skip torch.tensor() copy — just ensure correct device
            if t.device != x.device:
                t = t.to(x.device)
        else:
            t = torch.tensor(t, device=x.device)
        if rec_num is None:
            rec_num = self.rec_num
        return self.network(x=x, y=y, t=t, rec_num=rec_num, text=text)

    # ------------------------------------------------------------------
    # Optimization 2: scaling-and-squaring + crop-first upsample
    # ------------------------------------------------------------------
    def _compose_n_times(self, dvf, n):
        """Compute n-fold self-composition of dvf using scaling-and-squaring.



        Uses binary decomposition: O(log n) STN calls instead of O(n).

        E.g. n=87 → ~10 calls, n=200 → ~9 calls (vs 87/200 iterative calls).



        The result is the same deformation (n-fold composition) but computed

        via a different sequence of grid_sample interpolations, so there are

        small numerical differences (~1e-2 to 1e-1) vs iterative composition.

        This is acceptable because DDF generation is stochastic augmentation.

        """
        if n <= 0:
            return torch.zeros_like(dvf)
        result = None
        current = dvf  # current = dvf^(2^i), starts as dvf^1
        while n > 0:
            if n & 1:  # bit is set → accumulate this power
                if result is None:
                    result = current.clone()
                else:
                    # result = current ∘ result (apply result first, then current)
                    result = result + self.ddf_stn_rec(current, result)
            n >>= 1
            if n > 0:
                # Square: current = current ∘ current
                current = current + self.ddf_stn_rec(current, current)
        return result

    def _crop_upsample(self, field):
        """Upsample DDF from ctl_sz to img_sz with 2x oversampling + center crop.



        Instead of upsampling the full ctl_sz→img_sz*2 (e.g. 32³→256³) then

        cropping to img_sz (128³), we crop the control-point field first

        (to ~20³) then upsample to ~160³ and crop to 128³. This is 4x faster

        and bit-identical because trilinear interpolation is local.

        """
        crop_rate = 2
        upscale = self.img_sz * crop_rate // self.ctl_sz  # e.g. 8
        margin = 2  # voxels of margin for interpolation boundary
        lo = self.ctl_sz // 4 - margin    # e.g. 6
        hi = self.ctl_sz * 3 // 4 + margin  # e.g. 26
        crop_sz = hi - lo                   # e.g. 20
        up_sz = crop_sz * upscale           # e.g. 160
        pad = (up_sz - self.img_sz) // 2    # e.g. 16

        mode = 'bilinear' if self.ndims == 2 else 'trilinear'
        if self.ndims == 2:
            field_crop = field[..., lo:hi, lo:hi] * self.img_sz / self.ctl_sz
            field_up = F.interpolate(field_crop, up_sz, mode=mode)
            return field_up[..., pad:pad + self.img_sz, pad:pad + self.img_sz]
        else:
            field_crop = field[..., lo:hi, lo:hi, lo:hi] * self.img_sz / self.ctl_sz
            field_up = F.interpolate(field_crop, up_sz, mode=mode)
            return field_up[..., pad:pad + self.img_sz,
                                 pad:pad + self.img_sz,
                                 pad:pad + self.img_sz]

    def _random_ddf_generate(self, rec_num=3, mul_num=[torch.tensor([5]), torch.tensor([5])],

                             ddf0=None, keep_inverse=False, noise_ratio=0.08, select_num=3, flip_ratio=0.5):
        for _ in range(self.ndims + 1):
            mul_num = [torch.unsqueeze(n, -1) for n in mul_num]
        ctl_ddf_sz = [self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
        if ddf0 is not None:
            ddf = ddf0
        else:
            ddf = torch.zeros(ctl_ddf_sz, device=self.device)
        dddf = torch.zeros(ctl_ddf_sz, device=self.device)
        scale_num = min(8, int(math.log2(self.ctl_sz)))
        ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]

        for i in range(rec_num):
            if len(ctl_szs_all) > select_num:
                ctl_szs = random.sample(ctl_szs_all, select_num)
            else:
                ctl_szs = ctl_szs_all
            dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs)
            if noise_ratio == 0:
                dvf0 = dvf
            else:
                dvf0 = dvf + self.ddf_stn_rec(
                    self._multiscale_dvf_generate(self.v_scale * noise_ratio, ctl_szs=ctl_szs, rand_v_scale=False),
                    dvf)

            mul_num_ddf_val = int(torch.max(mul_num[0]).item())
            mul_num_dvf_val = int(torch.max(mul_num[1]).item())

            # OPT: scaling-and-squaring — O(log n) STN calls instead of O(n)
            # For t=40: 10 calls instead of 80. For t=79: 9 calls instead of 195.
            ddf = self._compose_n_times(dvf0, mul_num_ddf_val)
            dddf = self._compose_n_times(dvf, mul_num_dvf_val)

        # OPT: crop-first upsample — 4x fewer voxels to interpolate (bit-identical)
        ddf = self._crop_upsample(ddf)
        dddf = self._crop_upsample(dddf)
        return ddf, dddf

    # ------------------------------------------------------------------
    # Optimization 6: generate DVF on device to avoid CPU→GPU transfer
    # ------------------------------------------------------------------
    def _multiscale_dvf_generate(self, v_scale, ctl_szs=[4, 8, 16, 32, 64], rand_v_scale=True):
        dvf = 0
        if self.img_sz is None:
            self.img_sz = max(ctl_szs)
        if 1 in ctl_szs:
            dvf_rot = utils.random_ddf(
                batch_size=self.batch_size, ndims=self.ndims,
                img_sz=[self.ctl_sz] * self.ndims, range_gauss=0, rot_range=np.pi / 90)
            dvf = dvf + dvf_rot
        for ctl_sz in ctl_szs:
            _v_scale = self._sample_random_uniform_multi_order(
                high=v_scale, low=0., order_num=random.choice([1, 1, 2])) if rand_v_scale else v_scale
            if ctl_sz <= 2:
                _v_scale = _v_scale / 2
            # OPT: generate random tensor directly on device
            dvf_comp = torch.randn([self.batch_size, self.ndims] + [ctl_sz] * self.ndims,
                                   device=self.device) * _v_scale
            dvf_comp = F.interpolate(dvf_comp * self.ctl_sz / ctl_sz, [self.ctl_sz] * self.ndims,
                                     align_corners=False,
                                     mode='bilinear' if self.ndims == 2 else 'trilinear')
            dvf = dvf + dvf_comp
        return dvf

    # ------------------------------------------------------------------
    # Optimization 3: skip clone for 'uncon' (most common conditioning type)
    # ------------------------------------------------------------------
    def proc_cond_img(self, img, proc_type=None, noise_scale=0.1):
        if proc_type is None:
            proc_type = random.choices(
                ['adding', 'independ', 'downsample', 'slice', 'slice1', 'none', 'uncon'],
                weights=[1, 1, 1, 1, 1, 3], k=1
            )[0]
        mask = torch.tensor(1, device=img.device)
        cond_ratio = torch.tensor(1., device=img.device)
        self.msk_noise_scale = torch.tensor(0, device=img.device)
        noise_type = random.choice(['gaussian', 'uniform', 'none'])

        if proc_type not in ['none', None, '']:
            # OPT: handle 'uncon' before cloning — no need to clone img
            if proc_type == 'uncon':
                noise_map = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
                proc_img = noise_map
                mask = torch.tensor(0, device=img.device)
                cond_ratio = torch.tensor(0, device=img.device)
                return proc_img, mask, cond_ratio

            # Only clone when we actually need the image data
            proc_img = img.clone().detach()
            noise_map = None
            if proc_type in ['adding', 'independ', 'slice', 'slice1']:
                noise_map = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
            if proc_type == 'adding':
                proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
                cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
            elif proc_type == 'independ':
                mask = self.create_noise_map(img, noise_type='binary')
                if self.msk_noise_scale == 0:
                    proc_img = img * mask
                else:
                    proc_img = self.apply_noise(proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
                with torch.no_grad():
                    cond_ratio = mask.float().mean()
            elif proc_type == 'downsample':
                proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1. / 64, 1])
                cond_ratio = torch.tensor(down_ratio, device=img.device)
            elif proc_type == 'slice' or proc_type == 'slice1':
                if proc_type == 'slice1':
                    slice_num_max = 1
                else:
                    slice_num_max = random.randint(1, 64)
                    slice_num_max = random.randint(1, slice_num_max)
                mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
                if self.msk_noise_scale == 0:
                    proc_img = img * mask
                else:
                    proc_img = self.apply_noise(proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
                cond_ratio = torch.tensor(sample_ratio, device=img.device)
            elif proc_type == 'project':
                proc_img, proj_num = self.project(proc_img)
                cond_ratio = torch.tensor(proj_num / (128 * self.ndims), device=img.device)
            return proc_img, mask, cond_ratio
        else:
            # 'none' type — still need clone
            proc_img = img.clone().detach()
            return proc_img, mask, cond_ratio

    # ------------------------------------------------------------------
    # Optimization 1: hoist clone, create timestep tensors on device,
    #                  use torch.no_grad() for frozen iterations
    # ------------------------------------------------------------------
    def diff_recover(self, img_org, msk_org=None, T=[None, None], ddf_rand=None,
                     v_scale=None, t_save=None, cond_imgs=None, proc_type=None, text=None):
        """Deform an image, then recover it step by step with the network.

        Args:
            img_org: Original image tensor.
            msk_org: Optional mask warped alongside the image.
            T: ``[t_start, schedule]``. ``t_start`` (``None`` or 0 means no
                initial deformation) selects the timestep for the random
                deformation. ``schedule`` is either an int ``n`` (steps
                ``n-1 .. 0``, all run under ``torch.no_grad()``) or an explicit
                sequence of steps, of which the last two are run with
                whatever gradient mode is active in the caller.
            ddf_rand: Optional pre-computed initial deformation; when given,
                ``t_start`` and ``v_scale`` are ignored.
            v_scale: If given (and ``ddf_rand`` is ``None``), replaces
                ``self.v_scale`` and re-runs ``self._DDF_Encoder_init()``.
            t_save: Optional iterable of step values at which intermediate
                results are recorded.
            cond_imgs: Optional conditioning image; defaults to a detached
                copy of ``img_org``. It is passed through
                ``self.proc_cond_img`` either way.
            proc_type: Forwarded to ``self.proc_cond_img``.
            text: Forwarded to ``self.recover``.

        Returns:
            ``[ddf_comp, ddf_rand], [img_rec, img_diff, img_save],
            [msk_rec, msk_diff, msk_save]``; the mask entries are ``None``
            (and ``msk_save`` empty) when no mask was supplied.
        """
        has_mask = msk_org is not None

        # Conditioning input (only the processed image is used below).
        cond_source = img_org.clone().detach() if cond_imgs is None else cond_imgs
        cond_imgs, _, _ = self.proc_cond_img(cond_source, proc_type=proc_type)

        # Initial (forward) deformation.
        if ddf_rand is not None:
            img_diff = self.img_stn(img_org.clone().detach(), ddf_rand)
        else:
            if v_scale is not None:
                self.v_scale = v_scale
                self._DDF_Encoder_init()
            if T[0] is None or T[0] == 0:
                img_diff = img_org.clone().detach()
                ddf_rand = torch.zeros_like(img_diff)
            else:
                img_diff, _, ddf_rand = self._get_random_ddf(
                    img=img_org, t=torch.tensor(np.array([T[0]])).to(self.device))

        ddf_comp = ddf_rand.clone().detach()
        img_rec = img_diff.clone().detach()
        msk_diff = self.msk_stn(msk_org.clone().detach(), ddf_rand) if has_mask else None
        msk_rec = msk_diff.clone().detach() if has_mask else None
        img_save, msk_save = [], []

        # Detached references cloned once; every step warps these originals
        # with the composed field instead of re-cloning per iteration.
        img_org_ref = img_org.clone().detach()
        msk_org_ref = msk_org.clone().detach() if has_mask else None

        if isinstance(self.network, DefRec_MutAttnNet):
            # This network type receives the whole descending schedule at once.
            all_steps = list(range(T[1] - 1, -1, -1))
            pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=all_steps, rec_num=None, text=text)
            ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
            img_rec = self.img_stn(img_org_ref, ddf_comp)
            if has_mask:
                msk_rec = self.msk_stn(msk_org_ref, ddf_comp)
            return [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save]

        if isinstance(T[-1], int):
            time_steps = range(T[-1] - 1, -1, -1)
            trainable_iterations = []
        else:
            time_steps = T[-1]
            k = 2
            trainable_iterations = time_steps[-1:-k - 1:-1]

        t_save_set = set(t_save) if t_save is not None else None
        # Steps from this index onward are not wrapped in no_grad.
        first_trainable = len(time_steps) - len(trainable_iterations)

        for step_idx, step in enumerate(time_steps):
            t = torch.tensor([step], device=self.device)

            if step_idx < first_trainable:
                # no_grad rather than inference_mode: ddf_comp is carried from
                # frozen steps into the trainable ones.
                with torch.no_grad():
                    pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
            else:
                pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)

            ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
            img_rec = self.img_stn(img_org_ref, ddf_comp)
            if has_mask:
                msk_rec = self.msk_stn(msk_org_ref, ddf_comp)

            if t_save_set is not None and step in t_save_set:
                img_save.append(img_rec)
                if has_mask:
                    msk_save.append(msk_rec)

        return [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save]