# Omini3D / Diffusion / diffuser-reg.py
from torch import nn
import torch
import torch.nn.functional as F
import math
import numpy as np
from torch.nn.utils.stateless import functional_call
import Diffusion.utils_diff as utils
from Diffusion.networks import *
# from networks import *
import random
EPS = 1e-8
class DeformDDPM(nn.Module):
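    """
    Deformation-based DDPM.

    Instead of adding Gaussian pixel noise, the forward ("diffuse") process warps the
    input image with randomly generated dense displacement fields (DDFs); the wrapped
    network predicts the displacement field that undoes the warping, optionally
    conditioned on a degraded copy of the target image (see `proc_cond_img`).
    """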
def __init__(
self,
network,
n_steps=50,
beta_schedule_fn = None,
device='cpu',
image_chw=(1, 28, 28),
batch_size = 1,
img_pad_mode = "zeros",
ddf_pad_mode="border",
padding_mode="border",
v_scale = 0.008/256,
resample_mode=None,
inf_mode = False,
):
super(DeformDDPM, self).__init__()
self.rec_num=2
self.ndims=len(image_chw)-1
self.n_steps = n_steps
self.v_scale = v_scale
self.device = device
self.msk_noise_scale = torch.tensor(0)
# self.msk_noise_scale = torch.tensor(1)
self.num_device = torch.cuda.device_count()
self.batch_size = batch_size #//self.num_device
self.img_pad_mode = img_pad_mode
self.ddf_pad_mode = ddf_pad_mode
self.padding_mode = padding_mode
self.resample_mode = resample_mode
self.image_chw = image_chw
self.network = network#.to(self.device)
self.ddf_stn_full = STN(
img_sz = self.image_chw[1],
ndims = self.ndims,
padding_mode = self.padding_mode,
device = self.device,
)
self._DDF_Encoder_init()
self.copy_opt = nn.Identity()
self.inf_mode = inf_mode
return
def get_stn(self):
return self.img_stn, self.ddf_stn_full
def _DDF_Encoder_init(self, ctl_ratio=4, ctl_sz=None, resample_mode=None):
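        """Build the spatial transformers: `ddf_stn_rec` composes DDFs on the control
        grid (img_sz // ctl_ratio), while `img_stn` / `msk_stn` warp images and masks
        at full resolution (masks with nearest-neighbour resampling)."""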
if ctl_sz is None:
ctl_sz = self.image_chw[1] // ctl_ratio
        self.ctl_sz = ctl_sz
        self.img_sz = self.image_chw[1]
        self.ddf_stn_rec = STN(img_sz=ctl_sz, ndims=self.ndims, device=self.device, padding_mode=self.ddf_pad_mode)
        self.img_stn = STN(img_sz=self.img_sz, ndims=self.ndims, device=self.device, padding_mode=self.img_pad_mode, resample_mode=self.resample_mode)
        self.msk_stn = STN(img_sz=self.img_sz, ndims=self.ndims, device=self.device, padding_mode=self.img_pad_mode, resample_mode='nearest')
def _get_ddf_scale(self,t,divide_num=1,max_ddf_num=200): # 128
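        """Map a time step t to the number of DDF/DVF compositions.

        Both counts grow with t (roughly ~t**1.3 and ~t**0.6, scaled by divide_num)
        and are clamped to at most max_ddf_num, so a larger t yields a stronger
        random deformation.
        """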
rec_num = 1
mul_num_ddf = torch.floor_divide(2*torch.pow(t,1.3), 3*divide_num).int()
mul_num_dvf = torch.floor_divide(torch.pow(t,0.6), divide_num).int()
# print("time_step:",t,"mul_num_ddf:",mul_num_ddf,"mul_num_dvf:",mul_num_dvf)
# mul_num_ddf = self._sample_random_uniform_multi_order(high=mul_num_ddf)
# mul_num_dvf = self._sample_random_uniform_multi_order(high=mul_num_dvf)
mul_num_ddf = torch.clamp(mul_num_ddf, min=1, max=max_ddf_num)
mul_num_dvf = torch.clamp(mul_num_dvf, min=0, max=max_ddf_num)
# print("time_step:",t,"mul_num_ddf:",mul_num_ddf,"mul_num_dvf:",mul_num_dvf)
return rec_num,mul_num_ddf,mul_num_dvf
# def _sample_random_uniform_multi_order(self, high=None, low=0, order_num=3):
# # high: tensor of shape (...), low: int or tensor broadcastable to high
# sample_num = torch.full_like(high, low) if not isinstance(low, torch.Tensor) else low.clone()
# for _ in range(order_num):
# # For each element, sample in [sample_num, high]
# # torch.randint requires scalar low/high, so we use elementwise sampling
# rand_shape = high.shape
# # Clamp sample_num to be <= high
# sample_num = torch.minimum(sample_num, high)
# # Generate random numbers for each element
# rand = torch.empty(rand_shape, dtype=high.dtype, device=high.device)
# for idx in np.ndindex(rand_shape):
# l = sample_num[idx].item()
# h = high[idx].item()
# if l >= h:
# rand[idx] = l
# else:
# rand[idx] = torch.randint(l, h + 1, (1,), device=high.device)
# sample_num = rand.to(high.dtype)
# return sample_num
def _get_random_ddf(self,img,t):
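        """Warp `img` with a random DDF whose strength is scheduled by time step t;
        returns the warped image, the accompanying DVF (composed with fewer steps)
        and the full forward DDF."""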
rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
ddf_forward,dvf_forward = self._random_ddf_generate(rec_num=rec_num, mul_num=[mul_num_ddf,mul_num_dvf])
warped_img = self.img_stn(img,ddf_forward)
return warped_img, dvf_forward,ddf_forward
def _multiscale_dvf_generate(self,v_scale,ctl_szs=[4,8,16,32,64], rand_v_scale=True):
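        """Sum random velocity fields sampled at several control-grid sizes (each with
        an optionally randomised scale) and resampled to the control resolution,
        yielding one multi-scale random DVF."""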
dvf=0
if self.img_sz is None:
self.img_sz=max(ctl_szs)
if 1 in ctl_szs:
dvf_rot = utils.random_ddf(batch_size=self.batch_size, ndims=self.ndims, img_sz=[self.ctl_sz]*self.ndims, range_gauss=0, rot_range=np.pi/90)
dvf = dvf + dvf_rot
for ctl_sz in ctl_szs:
_v_scale = self._sample_random_uniform_multi_order(high=v_scale, low=1e-8, order_num=2) if rand_v_scale else v_scale
# temp>>
if ctl_sz <= 2:
_v_scale = _v_scale/2
# temp<<
dvf_comp = torch.randn([self.batch_size, self.ndims] + [ctl_sz]*self.ndims) * _v_scale
dvf_comp = F.interpolate(dvf_comp * self.ctl_sz / ctl_sz, [self.ctl_sz]*self.ndims, align_corners=False, mode='bilinear' if self.ndims == 2 else 'trilinear')
dvf=dvf+dvf_comp
return dvf
def _sample_random_uniform_multi_order(self, high=None, low=0., order_num=3):
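        """Draw uniformly in [low, high] `order_num` times, each round raising the
        lower bound to the previous draw, which biases the final sample towards
        `high` (used here to randomise v_scale)."""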
sample_value = low
for _ in range(order_num):
sample_value = np.random.uniform(low=sample_value, high=high)
return sample_value
def _random_ddf_generate(self,rec_num=3,mul_num=[torch.tensor([5]),torch.tensor([5])],ddf0=None,keep_inverse=False,noise_ratio=0.08,select_num=4, flip_ratio=0.5):
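        """Build a random DDF on the control grid: `ddf` is composed `mul_num[0]` times
        from the (optionally noise-perturbed) multi-scale DVF, and a second field
        `dddf` is composed `mul_num[1]` times from the clean DVF; both are upsampled to
        crop_rate x the image size and centre-cropped back to the image size.

        Returns (ddf, dddf).
        """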
crop_rate=2
for _ in range(self.ndims+1):
mul_num=[torch.unsqueeze(n,-1) for n in mul_num]
# v_scale = v_scale *crop_rate
ctl_ddf_sz=[self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
if ddf0 is not None:
ddf=ddf0
else:
            ddf = torch.zeros(ctl_ddf_sz)
        dddf = torch.zeros(ctl_ddf_sz)
scale_num = min(8,int(math.log2(self.ctl_sz))) # allow affine
# scale_num = min(5,int(math.log2(self.ctl_sz))-1) # semi-allow affine
# scale_num = min(5,int(math.log2(self.ctl_sz))-2) # avoid coupling between deformation and affine
ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]
        for i in range(rec_num):
            # Randomly pick `select_num` control sizes; use them all if there are fewer.
            if len(ctl_szs_all) > select_num:
                ctl_szs = random.sample(ctl_szs_all, select_num)
            else:
                ctl_szs = ctl_szs_all
            dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs).to(self.device)
# if True:
if noise_ratio==0:
dvf0=dvf
else:
dvf0=dvf+self.ddf_stn_rec(self._multiscale_dvf_generate(self.v_scale*noise_ratio,ctl_szs=ctl_szs, rand_v_scale=False).to(self.device),dvf)
# print([num.shape for num in mul_num])
for j in range(torch.max(mul_num[0]).item()):
flag = [(n>j).int().to(self.device) for n in mul_num]
ddf = dvf0*flag[0] + self.ddf_stn_rec(ddf, dvf0*flag[0])
dddf = dvf*flag[1] + self.ddf_stn_rec(dddf, dvf*flag[1])
ddf = F.interpolate(ddf * self.img_sz/self.ctl_sz, self.img_sz*crop_rate, mode='bilinear' if self.ndims == 2 else 'trilinear')
# ddf = ddf[...,img_sz//2:img_sz*3//2,img_sz//2:img_sz*3//2]
if self.ndims==2:
ddf = ddf[..., self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2]
else:
ddf = ddf[..., self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2]
# if rec_num==1:
if True:
dddf = F.interpolate(dddf * self.img_sz/self.ctl_sz, self.img_sz*crop_rate, mode='bilinear' if self.ndims == 2 else 'trilinear')
# dddf = dddf[...,img_sz//2:img_sz*3//2,img_sz//2:img_sz*3//2]
if self.ndims == 2:
dddf = dddf[..., self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2]
else:
dddf = dddf[..., self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2]
return ddf,dddf
else:
return ddf
def create_noise_map(self, img, noise_type='gaussian', noise_scale=0.1):
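        """Create a noise image matching `img`: Gaussian (std = noise_scale), uniform
        in [-noise_scale, noise_scale], binary {0, 1}, or all zeros."""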
if noise_type == 'gaussian':
noise_map = torch.randn_like(img) * noise_scale
elif noise_type == 'uniform':
            noise_map = torch.rand_like(img) * noise_scale * 2 - noise_scale  # uniform in [-noise_scale, noise_scale]
elif noise_type == 'binary':
noise_map = torch.bernoulli(torch.rand_like(img))
else:
noise_map = torch.zeros_like(img)
noise_map = noise_map.to(img.device)
return noise_map
def add_noise(self, img, noise_map=None, noise_ratio_range=[0.,1.]):
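        """Blend `img` with `noise_map` using a ratio drawn uniformly from
        `noise_ratio_range`; returns the blend and the sampled ratio."""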
noise_ratio = np.random.uniform(noise_ratio_range[0], noise_ratio_range[1])
return img * (1-noise_ratio) + noise_map * noise_ratio, noise_ratio
def apply_noise(self, img, noise_map=None, apply_mask=None):
return img * apply_mask + noise_map * (1-apply_mask)
def downsample(self, img, down_ratio_range=[1./32,1]):
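        """Downsample `img` by a random per-axis factor, resize back to the original
        resolution, and return it together with a conditioning weight (sqrt of the
        product of the per-axis factors)."""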
down_ratio = list(np.random.uniform(down_ratio_range[0], down_ratio_range[1],[self.ndims]))
# print(down_ratio)
down_img = F.interpolate(img, scale_factor=down_ratio, mode='bilinear' if self.ndims == 2 else 'trilinear')
# print(down_img)
# return F.interpolate(down_img, size=[self.image_chw[1]]*self.ndims, mode='bilinear' if self.ndims == 2 else 'trilinear', align_corners=False), np.prod(down_ratio)
return F.interpolate(down_img, size=[self.image_chw[1]]*self.ndims, mode='bilinear' if self.ndims == 2 else 'trilinear', align_corners=False), np.sqrt(np.prod(down_ratio)) # jzheng: cond weight based on entropy
def get_slice_mask(self, img, slice_num_range=[0,32]):
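        """Build a binary mask that keeps a random number of slices along each axis
        (a single middle slice in inference mode) and return it together with a
        conditioning weight derived from the kept-slice fraction."""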
        # Work on a copy so the (mutable) default argument is never modified in place.
        slice_num_range = [slice_num_range[0], min(slice_num_range[1], self.image_chw[1])]
mask = torch.zeros_like(img)
sample_ratio = 0
for i in range(self.ndims):
            if self.inf_mode:
                slice_num = 1  # at inference, condition on a single slice
                slice_idx = [self.image_chw[1] // 2]  # always use the middle slice for a deterministic condition
else:
slice_num = random.randint(slice_num_range[0], slice_num_range[1])
slice_idx = random.sample(range(self.image_chw[1]), slice_num)
transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
for idx in slice_idx:
mask[..., idx] = 1
mask = mask.permute(*transpose_list)
# sample_ratio += slice_num / self.image_chw[1] / self.ndims
sample_ratio += np.sqrt(slice_num / self.image_chw[1]) / self.ndims # jzheng: cond weight based on entropy
# print(mask)
# print("sample_ratio:", sample_ratio)
return mask, sample_ratio
def project(self, img):
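        """Average `img` along a random subset of spatial axes (mean projections), sum
        the projections and normalise by the number of projected axes; also returns
        that number."""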
proj_img = torch.zeros_like(img)
rand_bourn = np.random.randint(0, 2, size=[self.ndims])
proj_dim_num = np.sum(rand_bourn)
for i,pflag in zip(range(2, 2 + self.ndims), rand_bourn):
if pflag:
proj_img += torch.mean(img, dim=i, keepdim=True)
# print("projecting dim:", i)
return proj_img/(proj_dim_num+EPS), proj_dim_num
def proc_cond_img(self, img, proc_type=None,noise_scale=0.1):
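        """Turn `img` into a conditioning image via a (possibly randomly chosen)
        corruption: noise blending ('adding'), random pixel drop-out ('independ'),
        random downsampling ('downsample'), sparse slices ('slice'/'slice1'),
        pure noise ('uncon'), or no change ('none').

        Returns (proc_img, mask, cond_ratio); cond_ratio is a scalar weight for how
        much information the condition retains.
        """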
# Remove torch.no_grad() since most operations are not differentiable anyway
proc_img = img.clone().detach()
if proc_type is None:
            # Bias the choice towards 'uncon' (unconditional / pure-noise condition)
            proc_type = random.choices(
                # ['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon'],
                # weights=[1, 1, 1, 1, 1, 1, 3], k=1
                ['adding', 'independ', 'downsample', 'slice', 'slice1', 'none', 'uncon'],
                weights=[1, 1, 1, 1, 1, 1, 3], k=1
)[0]
mask = torch.tensor(1, device=img.device)
cond_ratio = torch.tensor(1., device=img.device)
self.msk_noise_scale = torch.tensor(0, device=img.device)
noise_type = random.choice(['gaussian', 'uniform', 'none'])
# Precompute noise_map only if needed
noise_map = None
if proc_type not in ['none', None, '']:
if proc_type == 'uncon':
noise_map = self.create_noise_map(img, noise_type=noise_type,noise_scale=noise_scale)
proc_img = noise_map
mask = torch.tensor(0, device=img.device)
cond_ratio = torch.tensor(0, device=img.device)
return proc_img, mask, cond_ratio
if proc_type in ['adding', 'independ', 'slice','slice1']:
# self.msk_noise_scale = 0
noise_map = self.create_noise_map(img, noise_type=noise_type,noise_scale=noise_scale)
if proc_type == 'adding':
proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
elif proc_type == 'independ':
mask = self.create_noise_map(img, noise_type='binary')
if self.msk_noise_scale == 0:
proc_img = img * mask
else:
proc_img = self.apply_noise(proc_img, noise_map=noise_map*self.msk_noise_scale, apply_mask=mask)
with torch.no_grad():
cond_ratio = mask.float().mean()
elif proc_type == 'downsample':
# proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./32, 1])
proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./64, 1])
cond_ratio = torch.tensor(down_ratio, device=img.device)
elif proc_type == 'slice' or proc_type == 'slice1':
if proc_type == 'slice1':
slice_num_max = 1
else:
slice_num_max = random.randint(1, 64)
slice_num_max = random.randint(1, slice_num_max)
mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
if self.msk_noise_scale == 0:
proc_img = img * mask
else:
proc_img = self.apply_noise(proc_img, noise_map=noise_map*self.msk_noise_scale, apply_mask=mask)
cond_ratio = torch.tensor(sample_ratio, device=img.device)
elif proc_type == 'project':
proc_img, proj_num = self.project(proc_img)
cond_ratio = torch.tensor(proj_num / (128 * self.ndims), device=img.device)
# cond_ratio = torch.tensor(proj_num / (32 * self.ndims), device=img.device) # jzheng: cond weight based on entropy
return proc_img, mask, cond_ratio
def diffuse(self, x_0, t):
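        """Forward process: deform x_0 with a random DDF whose strength grows with t."""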
t=torch.tensor(t)
# img_t, dvf_forward, ddf_forward, ddf_stn, img_stn = self.ddf_enc(img= x_0, t=t)
# return img_t, dvf_forward,ddf_forward,ddf_stn,img_stn
return self._get_random_ddf(img = x_0, t = t)
def recover(self, x, y, t,rec_num=2, text=None):
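        """Reverse step: let the network predict the displacement field(s) from the
        deformed image x, the conditioning image y and time step(s) t."""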
        if isinstance(t, list):
            t = [torch.tensor(t0, device=x.device) for t0 in t]
        else:
            t = torch.tensor(t, device=x.device)
if rec_num is None:
rec_num = self.rec_num
return self.network(x=x, y=y, t=t, rec_num=rec_num, text=text)
def recover_frozen_params_but_grad_input(self, x, y, t,rec_num=2, text=None):
"""
use detach to recover:
- but not include no_grad
"""
if isinstance(t, list):
t = [torch.tensor(t0, device=x.device) for t0 in t]
else:
t = torch.tensor(t, device=x.device)
if rec_num is None:
rec_num = self.rec_num
        # 1) Detach the parameters so gradients never reach the network weights,
        #    while the graph with respect to the inputs is preserved.
        params = {k: v.detach() for k, v in self.network.named_parameters()}
        # 2) Buffers (e.g. BatchNorm running stats) are passed through unchanged.
        buffers = dict(self.network.named_buffers())
        # 3) PyTorch 2.x also accepts functional_call(module, (params, buffers), ...),
        #    but older versions expect a single dict, so merge them here.
        params_and_buffers = {}
        params_and_buffers.update(params)
        params_and_buffers.update(buffers)
return functional_call(
self.network,
params_and_buffers,
(),
kwargs=dict(x=x, y=y, t=t, rec_num=rec_num, text=text),
)
def _single_step(self, x0, t, rec_num=2, proc_type=None,mask=None, cond_imgs=None, text=None):
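        """Single training step: build a conditioning image (unless given), deform x0
        at time t, and return the network prediction together with the sampled DVF."""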
if mask is None:
mask = 1
# org_imgs=self.copy_opt(x0)
if cond_imgs is None:
cond_imgs, mask_tgt, cond_ratio = self.proc_cond_img(x0,proc_type=proc_type)
noisy_imgs, dvf_I,_ = self.diffuse(x0, t)
if isinstance(self.network,DefRec_MutAttnNet):
t = [t] * 1
return self.recover(x=noisy_imgs*mask, y=cond_imgs, t=t, rec_num=rec_num, text=text), dvf_I
def forward(self, img_org, cond_imgs=None, proc_type=None, T=None, **kwargs):
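        """If a diffusion horizon T is given, run the iterative `diff_recover`;
        otherwise perform a single training step via `_single_step`."""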
if T is not None:
return self.diff_recover(img_org=img_org, T=T, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
else:
return self._single_step(x0=img_org, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
def diff_recover(self,
img_org,
msk_org=None,
T=[None,None],
ddf_rand=None,
v_scale = None,
t_save=None,
cond_imgs=None,
proc_type=None,
text=None,
):
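        """Full reverse process: deform `img_org` with a random DDF of strength T[0]
        (unless `ddf_rand` is given), then denoise for T[1] steps (or the explicit
        list of steps in T[-1]), composing the predicted fields into `ddf_comp`.

        Returns [ddf_comp, ddf_rand], [img_rec, img_diff, img_save],
        [msk_rec, msk_diff, msk_save].
        """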
if cond_imgs is None:
cond_imgs = img_org.clone().detach()
# if proc_type is not None:
cond_imgs,mask_tgt,cond_ratio=self.proc_cond_img(cond_imgs, proc_type=proc_type)
if ddf_rand is None:
if v_scale is not None:
self.v_scale=v_scale
self._DDF_Encoder_init()
if T[0] is None or T[0] == 0:
img_diff = img_org.clone().detach()
ddf_rand = torch.zeros_like(img_diff)
else:
img_diff, _, ddf_rand = self._get_random_ddf(img= img_org, t=torch.tensor(np.array([T[0]])).to(self.device))
else:
img_diff = self.img_stn(img_org.clone().detach(), ddf_rand)
ddf_comp = ddf_rand.clone().detach()
img_rec = img_diff.clone().detach()
if msk_org is not None:
msk_diff = self.msk_stn(msk_org.clone().detach(), ddf_rand)
else:
msk_diff = None
msk_rec = msk_diff.clone().detach() if msk_org is not None else None
img_save=[]
msk_save=[]
if isinstance(self.network,DefRec_MutAttnNet):
            # Denoising the image over a list of time steps t in one call
t_list = list(range(T[1]-1, -1, -1))
pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t_list,rec_num=None, text=text)
ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
if msk_org is not None:
msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
else:
            # Denoising the image step by step
if isinstance(T[-1], int):
time_steps = range(T[-1] - 1, -1, -1)
trainable_iterations =[]
else:
time_steps = T[-1]
# # Randomly select k iterations to make their parameters trainable
# win_len = 2 # Number of iterations to make trainable
# if len(time_steps) <= win_len:
# win_start = 0
# else:
# win_start = random.randint(len(time_steps)//2, len(time_steps) - win_len)
# win_end = win_start + win_len - 1
k=2
# trainable_iterations = time_steps[win_start: win_start + win_len]
# trainable_iterations = random.sample(time_steps, k)
trainable_iterations = time_steps[-1:-k-1:-1]
# print(time_steps)
# print("trainable_iterations:", trainable_iterations)
for i in time_steps:
t = torch.tensor(np.array([i])).to(self.device)
if i in trainable_iterations:
# Make parameters trainable for this iteration
pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
else:
# Freeze parameters for this iteration using torch.no_grad()
with torch.no_grad():
pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
# for idx, i in enumerate(time_steps):
# t = torch.tensor(np.array([i])).to(self.device)
# if idx < win_start:
# # just no_grad
# with torch.no_grad():
# pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
# elif win_start <= idx <= win_end:
# # normal update
# pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
# else:
# # freeze params but keep grad for input
# pre_dvf_I = self.recover_frozen_params_but_grad_input(
# x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text
# )
ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
# Apply to image
img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
if msk_org is not None:
msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
if t_save is not None:
if i in t_save:
img_save.append(img_rec)
if msk_org is not None:
msk_save.append(msk_rec)
# print(torch.max(torch.abs(ddf_comp)))
# print(torch.max(torch.abs(ddf_rand)))
return [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],[msk_rec,msk_diff,msk_save]
if __name__ == "__main__":
H, W = 8, 8
    deformddpm = DeformDDPM(
        network=get_net(name="recmutattnnet")(n_steps=80, ndims=2, num_input_chn=1),
        image_chw=(1, H, W),
        device='cpu',
    )
# img = torch.zeros([1, 1, H, W])
img = torch.randn([1, 1, H, W])
t = 1
rec_num = 2
# proc_type = 'adding'
# proc_type = 'independ'
# proc_type = 'downsample'
proc_type = 'slice'
# proc_type = 'project'
# proc_type = 'none'
print(img)
    cond_imgs, mask_tgt, cond_ratio = deformddpm.proc_cond_img(img, proc_type=proc_type)
print(cond_imgs)
# img_rec, dvf_I = deformddpm.forward(img, t, rec_num=rec_num, proc_type=proc_type)
# print(img_rec.shape, dvf_I.shape)
    # proc_type = 'adding'
    # ddfs, imgs, msks = deformddpm.diff_recover(img, T=[1, 1], proc_type=proc_type)
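    # Minimal end-to-end sketch of the reverse process (illustrative only; reuses the
    # toy 2D setup above and unpacks the three lists that diff_recover returns):
    # ddfs, imgs, msks = deformddpm(img, T=[1, 1], proc_type='uncon')
    # print(ddfs[0].shape, imgs[0].shape)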