| from torch import nn
|
| import torch
|
| import numpy as np
|
| from torch.nn.utils.stateless import functional_call
|
|
|
| import Diffusion.utils_diff as utils
|
| from Diffusion.networks import *
|
|
|
|
|
| import random
|
|
|
| EPS = 1e-8
|
|
|
|
|
|
|
| class DeformDDPM(nn.Module):
|
def __init__(
    self,
    network,
    n_steps=50,
    beta_schedule_fn=None,
    device='cpu',
    image_chw=(1, 28, 28),
    batch_size=1,
    img_pad_mode="zeros",
    ddf_pad_mode="border",
    padding_mode="border",
    v_scale=0.008 / 256,
    resample_mode=None,
    inf_mode=False,
):
    """Deformation-based DDPM: the forward process warps an image with
    progressively composed random dense displacement fields (DDFs); the
    reverse process predicts deformations that undo them.

    Args:
        network: denoising network, called as
            ``network(x=..., y=..., t=..., rec_num=..., text=...)``.
        n_steps: number of diffusion timesteps (stored; usage not in this block).
        beta_schedule_fn: unused in this constructor; kept for interface
            compatibility.
        device: torch device string used for generated fields and STNs.
        image_chw: (channels, H, W) or (channels, H, W, D); spatial dims are
            assumed equal — only ``image_chw[1]`` is used as the size.
        batch_size: batch size used when sampling random fields.
        img_pad_mode: padding mode for warping images.
        ddf_pad_mode: padding mode for composing DDFs on the control grid.
        padding_mode: padding mode for the full-resolution DDF STN.
        v_scale: base magnitude of the random velocity fields.
        resample_mode: interpolation mode for image warping (None = STN default).
        inf_mode: inference flag; makes slice-mask conditioning deterministic.
    """
    super().__init__()  # modernized from super(DeformDDPM, self).__init__()
    self.rec_num = 2  # default recursion count forwarded to the network
    self.ndims = len(image_chw) - 1  # 2 for (C,H,W), 3 for (C,H,W,D)
    self.n_steps = n_steps
    self.v_scale = v_scale
    self.device = device
    # Extra noise blended into masked-out conditioning regions; 0 disables it.
    self.msk_noise_scale = torch.tensor(0)
    self.num_device = torch.cuda.device_count()
    self.batch_size = batch_size
    self.img_pad_mode = img_pad_mode
    self.ddf_pad_mode = ddf_pad_mode
    self.padding_mode = padding_mode
    self.resample_mode = resample_mode
    self.image_chw = image_chw
    self.network = network
    # Full-resolution STN used to compose predicted DDFs during recovery.
    self.ddf_stn_full = STN(
        img_sz=self.image_chw[1],
        ndims=self.ndims,
        padding_mode=self.padding_mode,
        device=self.device,
    )
    self._DDF_Encoder_init()
    self.copy_opt = nn.Identity()
    self.inf_mode = inf_mode
|
|
|
| def get_stn(self):
|
| return self.img_stn, self.ddf_stn_full
|
|
|
def _DDF_Encoder_init(self, ctl_ratio=4, ctl_sz=None, resample_mode=None):
    """(Re)build the three STN resamplers used by the model.

    Args:
        ctl_ratio: control-grid downscale factor relative to the image size,
            used when ``ctl_sz`` is not given.
        ctl_sz: explicit control-grid size (overrides ``ctl_ratio``).
        resample_mode: interpolation mode for the image STN.

    Fix: ``resample_mode`` was accepted but silently ignored (the method
    always used ``self.resample_mode``); it is now honored, with a
    backward-compatible fallback to ``self.resample_mode`` when None.
    """
    if ctl_sz is None:
        ctl_sz = self.image_chw[1] // ctl_ratio
    if resample_mode is None:
        resample_mode = self.resample_mode
    self.ctl_sz = ctl_sz
    self.img_sz = self.image_chw[1]
    # Control-grid STN composes DVFs/DDFs at low resolution.
    self.ddf_stn_rec = STN(img_sz=ctl_sz, ndims=self.ndims, device=self.device, padding_mode=self.ddf_pad_mode)
    # Image STN warps intensities; mask STN uses nearest to keep labels crisp.
    self.img_stn = STN(img_sz=self.img_sz, ndims=self.ndims, device=self.device, padding_mode=self.img_pad_mode, resample_mode=resample_mode)
    self.msk_stn = STN(img_sz=self.img_sz, ndims=self.ndims, device=self.device, padding_mode=self.img_pad_mode, resample_mode='nearest')
|
|
|
| def _get_ddf_scale(self,t,divide_num=1,max_ddf_num=200):
|
| rec_num = 1
|
| mul_num_ddf = torch.floor_divide(2*torch.pow(t,1.3), 3*divide_num).int()
|
| mul_num_dvf = torch.floor_divide(torch.pow(t,0.6), divide_num).int()
|
|
|
|
|
|
|
| mul_num_ddf = torch.clamp(mul_num_ddf, min=1, max=max_ddf_num)
|
| mul_num_dvf = torch.clamp(mul_num_dvf, min=0, max=max_ddf_num)
|
|
|
| return rec_num,mul_num_ddf,mul_num_dvf
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| def _get_random_ddf(self,img,t):
|
| rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
|
| ddf_forward,dvf_forward = self._random_ddf_generate(rec_num=rec_num, mul_num=[mul_num_ddf,mul_num_dvf])
|
| warped_img = self.img_stn(img,ddf_forward)
|
| return warped_img, dvf_forward,ddf_forward
|
|
|
def _multiscale_dvf_generate(self,v_scale,ctl_szs=[4,8,16,32,64], rand_v_scale=True):
    # Sum band-limited Gaussian noise fields sampled at several control-grid
    # resolutions, each upsampled to the model control size, to obtain one
    # random velocity field of shape (batch, ndims, ctl_sz, ...).
    dvf=0
    if self.img_sz is None:
        # NOTE(review): img_sz is set in _DDF_Encoder_init, so this fallback
        # looks unreachable in normal use — kept as a guard; confirm.
        self.img_sz=max(ctl_szs)
    if 1 in ctl_szs:
        # A control size of 1 is interpreted as "add a small global random
        # rotation" (rot_range = pi/90, i.e. 2 degrees) rather than noise.
        dvf_rot = utils.random_ddf(batch_size=self.batch_size, ndims=self.ndims, img_sz=[self.ctl_sz]*self.ndims, range_gauss=0, rot_range=np.pi/90)
        dvf = dvf + dvf_rot
    for ctl_sz in ctl_szs:
        # Optionally randomize the per-scale magnitude; the nested-uniform
        # sampler biases draws toward `high` (= v_scale).
        _v_scale = self._sample_random_uniform_multi_order(high=v_scale, low=1e-8, order_num=2) if rand_v_scale else v_scale
        if ctl_sz <= 2:
            # Halve the amplitude of the coarsest (near-global) components.
            _v_scale = _v_scale/2
        dvf_comp = torch.randn([self.batch_size, self.ndims] + [ctl_sz]*self.ndims) * _v_scale
        # Upsample to the control grid; the ctl_sz/ctl_sz factor rescales
        # displacement magnitude with the resolution change.
        dvf_comp = F.interpolate(dvf_comp * self.ctl_sz / ctl_sz, [self.ctl_sz]*self.ndims, align_corners=False, mode='bilinear' if self.ndims == 2 else 'trilinear')
        dvf=dvf+dvf_comp
    return dvf
|
|
|
| def _sample_random_uniform_multi_order(self, high=None, low=0., order_num=3):
|
| sample_value = low
|
| for _ in range(order_num):
|
| sample_value = np.random.uniform(low=sample_value, high=high)
|
| return sample_value
|
|
|
def _random_ddf_generate(self, rec_num=3, mul_num=[torch.tensor([5]), torch.tensor([5])], ddf0=None, keep_inverse=False, noise_ratio=0.08, select_num=4, flip_ratio=0.5):
    """Generate a random composed deformation field (DDF) and a single-step
    velocity field (DVF) on the control grid, then upsample and centre-crop
    both to the image resolution.

    Args:
        rec_num: number of independent DVF draws composed into the result.
        mul_num: [ddf_count, dvf_count] tensors — per-sample numbers of times
            each drawn DVF is composed into the DDF / DVF accumulators.
        ddf0: optional starting DDF on the control grid (default: zeros).
        keep_inverse, flip_ratio: unused here; kept for interface
            compatibility — TODO confirm they can be dropped.
        noise_ratio: relative magnitude of an extra perturbation blended into
            the DVF used for the DDF accumulator (0 disables it).
        select_num: number of control resolutions sampled per draw.

    Returns:
        (ddf, dddf): composed field and single-step field, each of shape
        (batch, ndims, img_sz, ...).

    Fix: when fewer than ``select_num`` control scales exist (small ctl_sz,
    e.g. 28px images with the default ctl_ratio), ``ctl_szs`` was never
    assigned and the method raised NameError; it now falls back to using all
    available scales. The dead ``if True / else: return ddf`` branch was
    removed (the else arm could never execute).
    """
    crop_rate = 2
    # Broadcast the per-sample counts to (batch, 1, 1, ...) so they can gate
    # the spatial fields element-wise.
    for _ in range(self.ndims + 1):
        mul_num = [torch.unsqueeze(n, -1) for n in mul_num]

    ctl_ddf_sz = [self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
    ddf = ddf0 if ddf0 is not None else torch.zeros(ctl_ddf_sz)
    dddf = torch.zeros(ctl_ddf_sz)

    # Candidate control resolutions: ctl_sz, ctl_sz/2, ... (at most 8 scales).
    scale_num = min(8, int(math.log2(self.ctl_sz)))
    ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]

    for _ in range(rec_num):
        if len(ctl_szs_all) > select_num:
            ctl_szs = random.sample(ctl_szs_all, select_num)
        else:
            ctl_szs = list(ctl_szs_all)  # FIX: previously left undefined here
        dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs).to(self.device)
        if noise_ratio == 0:
            dvf0 = dvf
        else:
            # Perturb the DVF used for the DDF accumulator so the composed
            # field is not an exact power of the single-step field.
            dvf0 = dvf + self.ddf_stn_rec(self._multiscale_dvf_generate(self.v_scale * noise_ratio, ctl_szs=ctl_szs, rand_v_scale=False).to(self.device), dvf)

        # Compose the DVF into each accumulator up to its per-sample count;
        # `flag` zeroes the update for samples whose count is exhausted.
        for j in range(torch.max(mul_num[0]).item()):
            flag = [(n > j).int().to(self.device) for n in mul_num]
            ddf = dvf0 * flag[0] + self.ddf_stn_rec(ddf, dvf0 * flag[0])
            dddf = dvf * flag[1] + self.ddf_stn_rec(dddf, dvf * flag[1])

    # Upsample to crop_rate x image size, then centre-crop to image size.
    # Displacements are rescaled by img_sz/ctl_sz to stay metrically correct.
    interp_mode = 'bilinear' if self.ndims == 2 else 'trilinear'
    ddf = F.interpolate(ddf * self.img_sz / self.ctl_sz, self.img_sz * crop_rate, mode=interp_mode)
    dddf = F.interpolate(dddf * self.img_sz / self.ctl_sz, self.img_sz * crop_rate, mode=interp_mode)
    lo, hi = self.img_sz // 2, self.img_sz * 3 // 2
    if self.ndims == 2:
        ddf = ddf[..., lo:hi, lo:hi]
        dddf = dddf[..., lo:hi, lo:hi]
    else:
        ddf = ddf[..., lo:hi, lo:hi, lo:hi]
        dddf = dddf[..., lo:hi, lo:hi, lo:hi]
    return ddf, dddf
|
|
|
def create_noise_map(self, img, noise_type='gaussian', noise_scale=0.1):
    """Build a noise tensor shaped like ``img`` and placed on its device.

    'gaussian': scaled standard normal; 'uniform': U(-noise_scale, noise_scale);
    'binary': Bernoulli draws with uniform random probabilities (ignores
    noise_scale); any other value: zeros.
    """
    if noise_type == 'binary':
        out = torch.bernoulli(torch.rand_like(img))
    elif noise_type == 'uniform':
        out = torch.rand_like(img) * noise_scale * 2 - noise_scale
    elif noise_type == 'gaussian':
        out = torch.randn_like(img) * noise_scale
    else:
        out = torch.zeros_like(img)
    return out.to(img.device)
|
|
|
def add_noise(self, img, noise_map=None, noise_ratio_range=[0., 1.]):
    """Blend ``img`` with ``noise_map`` at a mixing ratio drawn uniformly
    from ``noise_ratio_range``; returns (blended, ratio)."""
    lo, hi = noise_ratio_range
    ratio = np.random.uniform(lo, hi)
    blended = img * (1 - ratio) + noise_map * ratio
    return blended, ratio
|
|
|
def apply_noise(self, img, noise_map=None, apply_mask=None):
    """Keep ``img`` where ``apply_mask`` is 1 and substitute ``noise_map``
    where it is 0 (assumes a {0,1} mask — TODO confirm callers)."""
    kept = img * apply_mask
    filled = noise_map * (1 - apply_mask)
    return kept + filled
|
|
|
| def downsample(self, img, down_ratio_range=[1./32,1]):
|
| down_ratio = list(np.random.uniform(down_ratio_range[0], down_ratio_range[1],[self.ndims]))
|
|
|
| down_img = F.interpolate(img, scale_factor=down_ratio, mode='bilinear' if self.ndims == 2 else 'trilinear')
|
|
|
|
|
| return F.interpolate(down_img, size=[self.image_chw[1]]*self.ndims, mode='bilinear' if self.ndims == 2 else 'trilinear', align_corners=False), np.sqrt(np.prod(down_ratio))
|
|
|
| def get_slice_mask(self, img, slice_num_range=[0,32]):
|
| slice_num_range[1] = min(slice_num_range[1], self.image_chw[1])
|
| mask = torch.zeros_like(img)
|
| sample_ratio = 0
|
| for i in range(self.ndims):
|
| if self.inf_mode:
|
| slice_num = 1
|
| slice_idx = [self.image_chw[1]//2]
|
| else:
|
| slice_num = random.randint(slice_num_range[0], slice_num_range[1])
|
| slice_idx = random.sample(range(self.image_chw[1]), slice_num)
|
| transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
|
| for idx in slice_idx:
|
| mask[..., idx] = 1
|
| mask = mask.permute(*transpose_list)
|
|
|
| sample_ratio += np.sqrt(slice_num / self.image_chw[1]) / self.ndims
|
|
|
|
|
|
|
| return mask, sample_ratio
|
|
|
def project(self, img):
    """Average-project ``img`` along a random subset of spatial axes.

    Returns (mean of the selected per-axis projections broadcast back to the
    image shape, number of projected axes). EPS keeps the division finite
    when no axis is selected.
    """
    axis_flags = np.random.randint(0, 2, size=[self.ndims])
    n_projected = np.sum(axis_flags)
    acc = torch.zeros_like(img)
    for axis, chosen in zip(range(2, 2 + self.ndims), axis_flags):
        if chosen:
            acc = acc + torch.mean(img, dim=axis, keepdim=True)
    return acc / (n_projected + EPS), n_projected
|
|
|
| def proc_cond_img(self, img, proc_type=None,noise_scale=0.1):
|
|
|
| proc_img = img.clone().detach()
|
| if proc_type is None:
|
|
|
| proc_type = random.choices(
|
|
|
|
|
| ['adding', 'independ', 'downsample', 'slice','slice1', 'none', 'uncon'],
|
| weights=[1, 1, 1, 1, 1, 3], k=1
|
| )[0]
|
| mask = torch.tensor(1, device=img.device)
|
| cond_ratio = torch.tensor(1., device=img.device)
|
| self.msk_noise_scale = torch.tensor(0, device=img.device)
|
| noise_type = random.choice(['gaussian', 'uniform', 'none'])
|
|
|
| noise_map = None
|
| if proc_type not in ['none', None, '']:
|
| if proc_type == 'uncon':
|
| noise_map = self.create_noise_map(img, noise_type=noise_type,noise_scale=noise_scale)
|
| proc_img = noise_map
|
| mask = torch.tensor(0, device=img.device)
|
| cond_ratio = torch.tensor(0, device=img.device)
|
| return proc_img, mask, cond_ratio
|
| if proc_type in ['adding', 'independ', 'slice','slice1']:
|
|
|
| noise_map = self.create_noise_map(img, noise_type=noise_type,noise_scale=noise_scale)
|
| if proc_type == 'adding':
|
| proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
|
| cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
|
| elif proc_type == 'independ':
|
| mask = self.create_noise_map(img, noise_type='binary')
|
| if self.msk_noise_scale == 0:
|
| proc_img = img * mask
|
| else:
|
| proc_img = self.apply_noise(proc_img, noise_map=noise_map*self.msk_noise_scale, apply_mask=mask)
|
| with torch.no_grad():
|
| cond_ratio = mask.float().mean()
|
| elif proc_type == 'downsample':
|
|
|
| proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./64, 1])
|
| cond_ratio = torch.tensor(down_ratio, device=img.device)
|
| elif proc_type == 'slice' or proc_type == 'slice1':
|
| if proc_type == 'slice1':
|
| slice_num_max = 1
|
| else:
|
| slice_num_max = random.randint(1, 64)
|
| slice_num_max = random.randint(1, slice_num_max)
|
| mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
|
| if self.msk_noise_scale == 0:
|
| proc_img = img * mask
|
| else:
|
| proc_img = self.apply_noise(proc_img, noise_map=noise_map*self.msk_noise_scale, apply_mask=mask)
|
| cond_ratio = torch.tensor(sample_ratio, device=img.device)
|
| elif proc_type == 'project':
|
| proc_img, proj_num = self.project(proc_img)
|
| cond_ratio = torch.tensor(proj_num / (128 * self.ndims), device=img.device)
|
|
|
| return proc_img, mask, cond_ratio
|
|
|
| def diffuse(self, x_0, t):
|
| t=torch.tensor(t)
|
|
|
|
|
| return self._get_random_ddf(img = x_0, t = t)
|
|
|
|
|
| def recover(self, x, y, t,rec_num=2, text=None):
|
| if isinstance(t, list):
|
| t=[torch.tensor(t0) for t0 in t]
|
| t=[t0.to(x.device) for t0 in t]
|
| else:
|
| t=torch.tensor(t)
|
| t.to(x.device)
|
| if rec_num is None:
|
| rec_num = self.rec_num
|
| return self.network(x=x, y=y, t=t, rec_num=rec_num, text=text)
|
|
|
| def recover_frozen_params_but_grad_input(self, x, y, t,rec_num=2, text=None):
|
| """
|
| use detach to recover:
|
| - but not include no_grad
|
| """
|
| if isinstance(t, list):
|
| t = [torch.tensor(t0, device=x.device) for t0 in t]
|
| else:
|
| t = torch.tensor(t, device=x.device)
|
|
|
| if rec_num is None:
|
| rec_num = self.rec_num
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| params = {k: v.detach() for k, v in self.network.named_parameters()}
|
|
|
| buffers = dict(self.network.named_buffers())
|
|
|
|
|
| params_and_buffers = {}
|
| params_and_buffers.update(params)
|
| params_and_buffers.update(buffers)
|
| return functional_call(
|
| self.network,
|
| params_and_buffers,
|
| (),
|
| kwargs=dict(x=x, y=y, t=t, rec_num=rec_num, text=text),
|
| )
|
|
|
|
|
def _single_step(self, x0, t, rec_num=2, proc_type=None, mask=None, cond_imgs=None, text=None):
    """One training step: build a conditioning image from ``x0`` (unless
    provided), deform ``x0`` forward by ``t`` steps, and ask the network to
    predict the recovery field. Returns (network_output, dvf_forward)."""
    mask = 1 if mask is None else mask
    if cond_imgs is None:
        cond_imgs, mask_tgt, cond_ratio = self.proc_cond_img(x0, proc_type=proc_type)
    noisy_imgs, dvf_I, _ = self.diffuse(x0, t)
    if isinstance(self.network, DefRec_MutAttnNet):
        # This network type expects the timestep wrapped in a list.
        t = [t]
    prediction = self.recover(x=noisy_imgs * mask, y=cond_imgs, t=t, rec_num=rec_num, text=text)
    return prediction, dvf_I
|
|
|
| def forward(self, img_org, cond_imgs=None, proc_type=None, T=None, **kwargs):
|
| if T is not None:
|
| return self.diff_recover(img_org=img_org, T=T, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
|
| else:
|
| return self._single_step(x0=img_org, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def diff_recover(self,
                 img_org,
                 msk_org=None,
                 T=[None,None],
                 ddf_rand=None,
                 v_scale = None,
                 t_save=None,
                 cond_imgs=None,
                 proc_type=None,
                 text=None,
                 ):
    # Full reverse (recovery) pass: optionally diffuse img_org forward by
    # T[0] random deformation steps, then iteratively predict and compose
    # recovery deformations over T[-1] reverse steps.
    #
    # Returns:
    #   [ddf_comp, ddf_rand]          -- composed recovery DDF, forward DDF
    #   [img_rec, img_diff, img_save] -- recovered / diffused / snapshots
    #   [msk_rec, msk_diff, msk_save] -- same for the optional mask
    #
    # NOTE(review): the mutable default T=[None, None] is read-only here,
    # but a None-sentinel default would be safer -- confirm before changing.
    if cond_imgs is None:
        cond_imgs = img_org.clone().detach()
    # Degrade the conditioning image (noise / masking / downsampling ...).
    cond_imgs,mask_tgt,cond_ratio=self.proc_cond_img(cond_imgs, proc_type=proc_type)
    if ddf_rand is None:
        if v_scale is not None:
            # Rebuild the STNs after changing the field magnitude.
            self.v_scale=v_scale
            self._DDF_Encoder_init()
        if T[0] is None or T[0] == 0:
            # No forward diffusion: start from the clean image, zero field.
            img_diff = img_org.clone().detach()
            ddf_rand = torch.zeros_like(img_diff)
        else:
            img_diff, _, ddf_rand = self._get_random_ddf(img= img_org, t=torch.tensor(np.array([T[0]])).to(self.device))
    else:
        # Caller supplied the forward deformation explicitly.
        img_diff = self.img_stn(img_org.clone().detach(), ddf_rand)
    ddf_comp = ddf_rand.clone().detach()
    img_rec = img_diff.clone().detach()
    if msk_org is not None:
        msk_diff = self.msk_stn(msk_org.clone().detach(), ddf_rand)
    else:
        msk_diff = None
    msk_rec = msk_diff.clone().detach() if msk_org is not None else None
    img_save=[]
    msk_save=[]
    if isinstance(self.network,DefRec_MutAttnNet):
        # This network consumes the whole reversed timestep list in one call.
        t_list = list(range(T[1]-1, -1, -1))
        pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t_list,rec_num=None, text=text)
        # Compose the predicted recovery field into the accumulated DDF,
        # then re-warp the originals with the updated composition.
        ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
        img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
        if msk_org is not None:
            msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
    else:
        if isinstance(T[-1], int):
            # Integer T: iterate t = T-1 ... 0 entirely under no_grad.
            time_steps = range(T[-1] - 1, -1, -1)
            trainable_iterations =[]
        else:
            # Explicit timestep sequence: keep gradients for the last k steps.
            time_steps = T[-1]
            k=2
            trainable_iterations = time_steps[-1:-k-1:-1]
        for i in time_steps:
            t = torch.tensor(np.array([i])).to(self.device)
            if i in trainable_iterations:
                pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
            else:
                with torch.no_grad():
                    pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
            # Compose this step's predicted field into the running DDF and
            # resample the originals with the composed field.
            ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
            img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
            if msk_org is not None:
                msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
            if t_save is not None:
                if i in t_save:
                    # Snapshot intermediate reconstructions at requested steps.
                    img_save.append(img_rec)
                    if msk_org is not None:
                        msk_save.append(msk_rec)
    return [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],[msk_rec,msk_diff,msk_save]
|
|
|
if __name__ == "__main__":
    # Smoke test on a small 2-D image with the mutual-attention network.
    H, W = 8, 8
    deformddpm = DeformDDPM(network=get_net(name="recmutattnnet")(n_steps=80, ndims=2, num_input_chn=1), image_chw=(1, H, W), device='cpu')

    img = torch.randn([1, 1, H, W])
    t = 1
    rec_num = 2
    proc_type = 'slice'

    print(img)
    # BUG FIX: proc_cond_img returns three values (proc_img, mask, cond_ratio);
    # the previous two-name unpacking raised ValueError at runtime.
    cond_imgs, mask_tgt, cond_ratio = deformddpm.proc_cond_img(img, proc_type=proc_type)
    print(cond_imgs)
|
|
|
|
|
|
|
|
|
|
|
|
|
| |