from torch import nn
import torch
import torch.nn.functional as F
import numpy as np
import math
import random
from torch.nn.utils.stateless import functional_call
import Diffusion.utils_diff as utils
from Diffusion.networks import *

EPS = 1e-8


class DeformDDPM(nn.Module):
    """Deformation-based DDPM: diffusion corrupts images with random dense
    displacement fields (DDFs) instead of additive noise, and the network
    predicts incremental DDFs that undo the deformation.

    NOTE(review): relies on `STN`, `get_net` and `DefRec_MutAttnNet` coming
    from `Diffusion.networks` — confirm against that module.
    """

    def __init__(
        self,
        network,
        n_steps=50,
        beta_schedule_fn=None,   # kept for interface compatibility; unused here
        device='cpu',
        image_chw=(1, 28, 28),
        batch_size=1,
        img_pad_mode="zeros",
        ddf_pad_mode="border",
        padding_mode="border",
        v_scale=0.008 / 256,
        resample_mode=None,
        inf_mode=False,
    ):
        """
        Args:
            network: denoising network predicting incremental DDFs.
            n_steps: number of diffusion time steps.
            device: torch device string.
            image_chw: (channels, H, W) for 2-D or (channels, D, H, W) for 3-D;
                ndims is inferred as len(image_chw) - 1.
            v_scale: base velocity-field magnitude for random DDF generation.
            img_pad_mode / ddf_pad_mode / padding_mode: padding modes for the
                image, control-grid and full-resolution STNs respectively.
            inf_mode: inference flag; makes slice conditioning deterministic.
        """
        super(DeformDDPM, self).__init__()
        self.rec_num = 2
        self.ndims = len(image_chw) - 1
        self.n_steps = n_steps
        self.v_scale = v_scale
        self.device = device
        # Scale of noise injected into masked-out regions (0 disables it).
        self.msk_noise_scale = torch.tensor(0)
        self.num_device = torch.cuda.device_count()
        self.batch_size = batch_size
        self.img_pad_mode = img_pad_mode
        self.ddf_pad_mode = ddf_pad_mode
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode
        self.image_chw = image_chw
        self.network = network
        # Full-resolution STN used to compose predicted DDFs during recovery.
        self.ddf_stn_full = STN(
            img_sz=self.image_chw[1],
            ndims=self.ndims,
            padding_mode=self.padding_mode,
            device=self.device,
        )
        self._DDF_Encoder_init()
        self.copy_opt = nn.Identity()
        self.inf_mode = inf_mode

    def get_stn(self):
        """Return the (image STN, full-resolution DDF STN) pair."""
        return self.img_stn, self.ddf_stn_full

    def _DDF_Encoder_init(self, ctl_ratio=4, ctl_sz=None, resample_mode=None):
        """(Re)build the three STNs used by the random-DDF encoder.

        ctl_sz: control-grid size for coarse DDF composition; defaults to
        image size // ctl_ratio.
        """
        if ctl_sz is None:
            ctl_sz = self.image_chw[1] // ctl_ratio
        self.ctl_sz = ctl_sz
        self.img_sz = self.image_chw[1]
        # Coarse STN that composes DDFs on the control grid.
        self.ddf_stn_rec = STN(img_sz=ctl_sz, ndims=self.ndims,
                               device=self.device, padding_mode=self.ddf_pad_mode)
        # Image warping STN.
        self.img_stn = STN(img_sz=self.img_sz, ndims=self.ndims,
                           device=self.device, padding_mode=self.img_pad_mode,
                           resample_mode=self.resample_mode)
        # Mask warping STN: nearest-neighbour resampling keeps labels discrete.
        self.msk_stn = STN(img_sz=self.img_sz, ndims=self.ndims,
                           device=self.device, padding_mode=self.img_pad_mode,
                           resample_mode='nearest')

    def _get_ddf_scale(self, t, divide_num=1, max_ddf_num=200):
        """Map time step t to composition counts for DDF and DVF.

        Returns (rec_num, mul_num_ddf, mul_num_dvf): number of outer
        recursions and per-step composition counts, clamped to
        [1, max_ddf_num] and [0, max_ddf_num] respectively.
        """
        rec_num = 1
        mul_num_ddf = torch.floor_divide(2 * torch.pow(t, 1.3), 3 * divide_num).int()
        mul_num_dvf = torch.floor_divide(torch.pow(t, 0.6), divide_num).int()
        mul_num_ddf = torch.clamp(mul_num_ddf, min=1, max=max_ddf_num)
        mul_num_dvf = torch.clamp(mul_num_dvf, min=0, max=max_ddf_num)
        return rec_num, mul_num_ddf, mul_num_dvf

    def _get_random_ddf(self, img, t):
        """Generate a random forward DDF for time t and warp img with it.

        Returns (warped_img, dvf_forward, ddf_forward).
        """
        rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
        ddf_forward, dvf_forward = self._random_ddf_generate(
            rec_num=rec_num, mul_num=[mul_num_ddf, mul_num_dvf])
        warped_img = self.img_stn(img, ddf_forward)
        return warped_img, dvf_forward, ddf_forward

    def _multiscale_dvf_generate(self, v_scale, ctl_szs=(4, 8, 16, 32, 64),
                                 rand_v_scale=True):
        """Sum random velocity fields sampled at several control-grid scales.

        Each component is Gaussian noise on a ctl_sz grid, scaled and
        upsampled to the control resolution. A ctl_sz of 1 additionally adds
        a small random rotation field. If rand_v_scale, each component's
        magnitude is itself sampled uniformly below v_scale.
        """
        dvf = 0
        if self.img_sz is None:
            self.img_sz = max(ctl_szs)
        if 1 in ctl_szs:
            # Global (rotation-only) component.
            dvf_rot = utils.random_ddf(batch_size=self.batch_size,
                                       ndims=self.ndims,
                                       img_sz=[self.ctl_sz] * self.ndims,
                                       range_gauss=0, rot_range=np.pi / 90)
            dvf = dvf + dvf_rot
        for ctl_sz in ctl_szs:
            _v_scale = (self._sample_random_uniform_multi_order(
                            high=v_scale, low=1e-8, order_num=2)
                        if rand_v_scale else v_scale)
            if ctl_sz <= 2:
                # Damp the near-affine (very coarse) components.
                _v_scale = _v_scale / 2
            dvf_comp = torch.randn(
                [self.batch_size, self.ndims] + [ctl_sz] * self.ndims) * _v_scale
            dvf_comp = F.interpolate(
                dvf_comp * self.ctl_sz / ctl_sz, [self.ctl_sz] * self.ndims,
                align_corners=False,
                mode='bilinear' if self.ndims == 2 else 'trilinear')
            dvf = dvf + dvf_comp
        return dvf

    def _sample_random_uniform_multi_order(self, high=None, low=0., order_num=3):
        """Draw a value biased toward `low` by chaining `order_num` uniform
        draws, each lower-bounded by the previous draw."""
        sample_value = low
        for _ in range(order_num):
            sample_value = np.random.uniform(low=sample_value, high=high)
        return sample_value

    def _random_ddf_generate(self, rec_num=3,
                             mul_num=[torch.tensor([5]), torch.tensor([5])],
                             ddf0=None, keep_inverse=False, noise_ratio=0.08,
                             select_num=4, flip_ratio=0.5):
        """Compose random multiscale DVFs into a large DDF on the control
        grid, then upsample and centre-crop to image resolution.

        mul_num: [ddf composition count, dvf composition count] tensors,
        broadcast per-batch-element via trailing unsqueezes.
        Returns (ddf, dddf): the composed DDF and the (noise-free) composed
        DVF at image resolution.
        """
        crop_rate = 2
        # Add trailing singleton dims so counts broadcast against
        # (B, ndims, *spatial) fields.
        for _ in range(self.ndims + 1):
            mul_num = [torch.unsqueeze(n, -1) for n in mul_num]
        ctl_ddf_sz = [self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
        ddf = ddf0 if ddf0 is not None else torch.zeros(ctl_ddf_sz)
        dddf = torch.zeros(ctl_ddf_sz)
        scale_num = min(8, int(math.log2(self.ctl_sz)))  # allow affine
        ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]
        for _ in range(rec_num):
            # FIX: previously `ctl_szs` was unbound (UnboundLocalError) when
            # the pyramid had <= select_num levels; fall back to all scales.
            if len(ctl_szs_all) > select_num:
                ctl_szs = random.sample(ctl_szs_all, select_num)
            else:
                ctl_szs = list(ctl_szs_all)
            dvf = self._multiscale_dvf_generate(
                self.v_scale, ctl_szs=ctl_szs).to(self.device)
            if noise_ratio == 0:
                dvf0 = dvf
            else:
                # DDF stream gets an extra, fixed-scale noise component
                # composed through the current dvf.
                dvf0 = dvf + self.ddf_stn_rec(
                    self._multiscale_dvf_generate(
                        self.v_scale * noise_ratio, ctl_szs=ctl_szs,
                        rand_v_scale=False).to(self.device),
                    dvf)
            # Compose each field with itself up to its per-sample count;
            # `flag` zeroes the increment once a sample's count is exhausted.
            for j in range(torch.max(mul_num[0]).item()):
                flag = [(n > j).int().to(self.device) for n in mul_num]
                ddf = dvf0 * flag[0] + self.ddf_stn_rec(ddf, dvf0 * flag[0])
                dddf = dvf * flag[1] + self.ddf_stn_rec(dddf, dvf * flag[1])
        # Upsample to crop_rate * img_sz then centre-crop to img_sz, so the
        # field's borders come from real (not padded) content.
        ddf = F.interpolate(ddf * self.img_sz / self.ctl_sz,
                            self.img_sz * crop_rate,
                            mode='bilinear' if self.ndims == 2 else 'trilinear')
        ddf = self._center_crop(ddf)
        dddf = F.interpolate(dddf * self.img_sz / self.ctl_sz,
                             self.img_sz * crop_rate,
                             mode='bilinear' if self.ndims == 2 else 'trilinear')
        dddf = self._center_crop(dddf)
        # (The original had an unreachable `else: return ddf` behind
        # `if True:`; both fields are always returned.)
        return ddf, dddf

    def _center_crop(self, field):
        """Centre-crop an upsampled field back to img_sz per spatial dim."""
        lo, hi = self.img_sz // 2, self.img_sz * 3 // 2
        if self.ndims == 2:
            return field[..., lo:hi, lo:hi]
        return field[..., lo:hi, lo:hi, lo:hi]

    def create_noise_map(self, img, noise_type='gaussian', noise_scale=0.1):
        """Return a noise tensor shaped like img on img's device.

        'gaussian': N(0, noise_scale^2); 'uniform': U(-noise_scale,
        noise_scale); 'binary': Bernoulli(p~U(0,1)) mask; otherwise zeros.
        """
        if noise_type == 'gaussian':
            noise_map = torch.randn_like(img) * noise_scale
        elif noise_type == 'uniform':
            noise_map = torch.rand_like(img) * noise_scale * 2 - noise_scale
        elif noise_type == 'binary':
            noise_map = torch.bernoulli(torch.rand_like(img))
        else:
            noise_map = torch.zeros_like(img)
        return noise_map.to(img.device)

    def add_noise(self, img, noise_map=None, noise_ratio_range=[0., 1.]):
        """Blend img with noise_map at a random ratio; returns (blend, ratio)."""
        noise_ratio = np.random.uniform(noise_ratio_range[0], noise_ratio_range[1])
        return img * (1 - noise_ratio) + noise_map * noise_ratio, noise_ratio

    def apply_noise(self, img, noise_map=None, apply_mask=None):
        """Keep img where apply_mask==1, substitute noise_map elsewhere."""
        return img * apply_mask + noise_map * (1 - apply_mask)

    def downsample(self, img, down_ratio_range=[1. / 32, 1]):
        """Downsample img by a random per-axis factor, then upsample back.

        Returns (degraded image at original size, sqrt of the retained-area
        ratio) — the sqrt biases the conditioning weight (entropy-based).
        """
        down_ratio = list(np.random.uniform(down_ratio_range[0],
                                            down_ratio_range[1], [self.ndims]))
        down_img = F.interpolate(
            img, scale_factor=down_ratio,
            mode='bilinear' if self.ndims == 2 else 'trilinear')
        return (F.interpolate(down_img, size=[self.image_chw[1]] * self.ndims,
                              mode='bilinear' if self.ndims == 2 else 'trilinear',
                              align_corners=False),
                np.sqrt(np.prod(down_ratio)))

    def get_slice_mask(self, img, slice_num_range=[0, 32]):
        """Build a binary mask selecting random slices along each spatial axis.

        In inf_mode, deterministically selects the middle slice per axis.
        Returns (mask, sample_ratio); sample_ratio uses a sqrt weighting
        (entropy-based conditioning weight).
        """
        # FIX: do not mutate the (possibly default/shared) argument list.
        lo = slice_num_range[0]
        hi = min(slice_num_range[1], self.image_chw[1])
        mask = torch.zeros_like(img)
        sample_ratio = 0
        for _ in range(self.ndims):
            if self.inf_mode:
                slice_num = 1
                slice_idx = [self.image_chw[1] // 2]  # middle slice at inference
            else:
                slice_num = random.randint(lo, hi)
                slice_idx = random.sample(range(self.image_chw[1]), slice_num)
            # Rotate spatial axes so each iteration marks slices along a
            # different dimension via the same trailing-axis indexing.
            transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
            for idx in slice_idx:
                mask[..., idx] = 1
            mask = mask.permute(*transpose_list)
            sample_ratio += np.sqrt(slice_num / self.image_chw[1]) / self.ndims
        return mask, sample_ratio

    def project(self, img):
        """Average img along a random subset of spatial axes.

        Returns (mean of the selected-axis projections, number of projected
        axes). EPS guards division when no axis was selected.
        """
        proj_img = torch.zeros_like(img)
        rand_bourn = np.random.randint(0, 2, size=[self.ndims])
        proj_dim_num = np.sum(rand_bourn)
        for i, pflag in zip(range(2, 2 + self.ndims), rand_bourn):
            if pflag:
                proj_img += torch.mean(img, dim=i, keepdim=True)
        return proj_img / (proj_dim_num + EPS), proj_dim_num

    def proc_cond_img(self, img, proc_type=None, noise_scale=0.1):
        """Degrade img into a conditioning image of the requested type.

        proc_type in {'adding', 'independ', 'downsample', 'slice', 'slice1',
        'project', 'none', 'uncon'}; None picks one at random (biased to
        'uncon'). Returns (proc_img, mask, cond_ratio) where cond_ratio is a
        scalar tensor measuring how informative the condition is.
        """
        # No torch.no_grad(): operations here are non-differentiable anyway.
        proc_img = img.clone().detach()
        if proc_type is None:
            # FIX: population had 7 entries but only 6 weights, which made
            # random.choices raise ValueError; 'uncon' keeps the heavy bias.
            proc_type = random.choices(
                ['adding', 'independ', 'downsample', 'slice', 'slice1',
                 'none', 'uncon'],
                weights=[1, 1, 1, 1, 1, 1, 3], k=1)[0]
        mask = torch.tensor(1, device=img.device)
        cond_ratio = torch.tensor(1., device=img.device)
        self.msk_noise_scale = torch.tensor(0, device=img.device)
        noise_type = random.choice(['gaussian', 'uniform', 'none'])
        noise_map = None
        if proc_type not in ['none', None, '']:
            if proc_type == 'uncon':
                # Unconditional: replace the condition with pure noise.
                noise_map = self.create_noise_map(
                    img, noise_type=noise_type, noise_scale=noise_scale)
                proc_img = noise_map
                mask = torch.tensor(0, device=img.device)
                cond_ratio = torch.tensor(0, device=img.device)
                return proc_img, mask, cond_ratio
            if proc_type in ['adding', 'independ', 'slice', 'slice1']:
                noise_map = self.create_noise_map(
                    img, noise_type=noise_type, noise_scale=noise_scale)
            if proc_type == 'adding':
                proc_img, noise_ratio = self.add_noise(
                    proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
                cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
            elif proc_type == 'independ':
                # Independent per-pixel dropout mask.
                mask = self.create_noise_map(img, noise_type='binary')
                if self.msk_noise_scale == 0:
                    proc_img = img * mask
                else:
                    proc_img = self.apply_noise(
                        proc_img, noise_map=noise_map * self.msk_noise_scale,
                        apply_mask=mask)
                with torch.no_grad():
                    cond_ratio = mask.float().mean()
            elif proc_type == 'downsample':
                proc_img, down_ratio = self.downsample(
                    proc_img, down_ratio_range=[1. / 64, 1])
                cond_ratio = torch.tensor(down_ratio, device=img.device)
            elif proc_type == 'slice' or proc_type == 'slice1':
                if proc_type == 'slice1':
                    slice_num_max = 1
                else:
                    # Double sampling biases toward few slices.
                    slice_num_max = random.randint(1, 64)
                    slice_num_max = random.randint(1, slice_num_max)
                mask, sample_ratio = self.get_slice_mask(
                    img, slice_num_range=[0, slice_num_max])
                if self.msk_noise_scale == 0:
                    proc_img = img * mask
                else:
                    proc_img = self.apply_noise(
                        proc_img, noise_map=noise_map * self.msk_noise_scale,
                        apply_mask=mask)
                cond_ratio = torch.tensor(sample_ratio, device=img.device)
            elif proc_type == 'project':
                proc_img, proj_num = self.project(proc_img)
                cond_ratio = torch.tensor(proj_num / (128 * self.ndims),
                                          device=img.device)
        return proc_img, mask, cond_ratio

    def diffuse(self, x_0, t):
        """Forward diffusion: warp x_0 with a random DDF scaled by t.

        Returns (warped_img, dvf_forward, ddf_forward).
        """
        t = torch.as_tensor(t)  # as_tensor: no copy/warning if already a tensor
        return self._get_random_ddf(img=x_0, t=t)

    def recover(self, x, y, t, rec_num=2, text=None):
        """Run the denoising network on (x, y) at time step(s) t.

        rec_num=None falls back to self.rec_num.
        """
        if isinstance(t, list):
            t = [torch.as_tensor(t0).to(x.device) for t0 in t]
        else:
            # FIX: the original did `t.to(x.device)` and discarded the
            # result, leaving t on its original device.
            t = torch.as_tensor(t).to(x.device)
        if rec_num is None:
            rec_num = self.rec_num
        return self.network(x=x, y=y, t=t, rec_num=rec_num, text=text)

    def recover_frozen_params_but_grad_input(self, x, y, t, rec_num=2, text=None):
        """Like recover(), but with network parameters detached so gradients
        flow only to the inputs (no torch.no_grad, no parameter grads).
        """
        if isinstance(t, list):
            t = [torch.tensor(t0, device=x.device) for t0 in t]
        else:
            t = torch.tensor(t, device=x.device)
        if rec_num is None:
            rec_num = self.rec_num
        # Detach parameters; keep buffers (e.g. BN running stats) as-is.
        params = {k: v.detach() for k, v in self.network.named_parameters()}
        buffers = dict(self.network.named_buffers())
        # Older PyTorch versions accept only a single merged mapping.
        params_and_buffers = {**params, **buffers}
        return functional_call(
            self.network,
            params_and_buffers,
            (),
            kwargs=dict(x=x, y=y, t=t, rec_num=rec_num, text=text),
        )

    def _single_step(self, x0, t, rec_num=2, proc_type=None, mask=None,
                     cond_imgs=None, text=None):
        """One training step: diffuse x0 to time t and predict the recovery.

        Returns (network output, forward DVF).
        """
        if mask is None:
            mask = 1
        if cond_imgs is None:
            cond_imgs, mask_tgt, cond_ratio = self.proc_cond_img(
                x0, proc_type=proc_type)
        noisy_imgs, dvf_I, _ = self.diffuse(x0, t)
        if isinstance(self.network, DefRec_MutAttnNet):
            t = [t] * 1  # this network expects a list of time steps
        return self.recover(x=noisy_imgs * mask, y=cond_imgs, t=t,
                            rec_num=rec_num, text=text), dvf_I

    def forward(self, img_org, cond_imgs=None, proc_type=None, T=None, **kwargs):
        """Dispatch: full diffuse-recover loop when T is given, otherwise a
        single training step."""
        if T is not None:
            return self.diff_recover(img_org=img_org, T=T, proc_type=proc_type,
                                     cond_imgs=cond_imgs, **kwargs)
        return self._single_step(x0=img_org, proc_type=proc_type,
                                 cond_imgs=cond_imgs, **kwargs)

    def diff_recover(self, img_org, msk_org=None, T=(None, None), ddf_rand=None,
                     v_scale=None, t_save=None, cond_imgs=None, proc_type=None,
                     text=None):
        """Diffuse img_org to time T[0] (or warp with ddf_rand), then run the
        reverse process for T[-1] steps, composing predicted DDFs.

        Args:
            T: (forward step, reverse steps). T[-1] may be an int (full
               no_grad sweep) or an explicit list of steps (last k=2 steps
               run with gradients enabled).
            t_save: optional iterable of steps at which intermediate
               images/masks are stashed.
        Returns:
            ([ddf_comp, ddf_rand], [img_rec, img_diff, img_save],
             [msk_rec, msk_diff, msk_save]).
        """
        if cond_imgs is None:
            cond_imgs = img_org.clone().detach()
        # NOTE(review): applied unconditionally (the `if proc_type is not
        # None:` guard was commented out upstream) — confirm intent.
        cond_imgs, mask_tgt, cond_ratio = self.proc_cond_img(
            cond_imgs, proc_type=proc_type)
        if ddf_rand is None:
            if v_scale is not None:
                self.v_scale = v_scale
                self._DDF_Encoder_init()
            if T[0] is None or T[0] == 0:
                img_diff = img_org.clone().detach()
                ddf_rand = torch.zeros_like(img_diff)
            else:
                img_diff, _, ddf_rand = self._get_random_ddf(
                    img=img_org,
                    t=torch.tensor(np.array([T[0]])).to(self.device))
        else:
            img_diff = self.img_stn(img_org.clone().detach(), ddf_rand)
        ddf_comp = ddf_rand.clone().detach()
        img_rec = img_diff.clone().detach()
        if msk_org is not None:
            msk_diff = self.msk_stn(msk_org.clone().detach(), ddf_rand)
        else:
            msk_diff = None
        msk_rec = msk_diff.clone().detach() if msk_org is not None else None
        img_save = []
        msk_save = []
        if isinstance(self.network, DefRec_MutAttnNet):
            # This network consumes the whole reverse schedule in one call.
            t_list = list(range(T[1] - 1, -1, -1))
            pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t_list,
                                     rec_num=None, text=text)
            ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
            img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
            if msk_org is not None:
                msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
        else:
            if isinstance(T[-1], int):
                time_steps = range(T[-1] - 1, -1, -1)
                trainable_iterations = []  # full sweep under no_grad
            else:
                time_steps = T[-1]
                k = 2  # only the last k steps keep gradients
                trainable_iterations = time_steps[-1:-k - 1:-1]
            for i in time_steps:
                t = torch.tensor(np.array([i])).to(self.device)
                if i in trainable_iterations:
                    pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t,
                                             rec_num=None, text=text)
                else:
                    with torch.no_grad():
                        pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t,
                                                 rec_num=None, text=text)
                # Compose the predicted increment onto the running DDF, then
                # re-warp from the ORIGINAL image to avoid resampling drift.
                ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
                img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
                if msk_org is not None:
                    msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
                if t_save is not None and i in t_save:
                    img_save.append(img_rec)
                    if msk_org is not None:
                        msk_save.append(msk_rec)
        return ([ddf_comp, ddf_rand],
                [img_rec, img_diff, img_save],
                [msk_rec, msk_diff, msk_save])


if __name__ == "__main__":
    # Smoke test on a tiny 2-D image.
    H, W = 8, 8
    deformddpm = DeformDDPM(
        network=get_net(name="recmutattnnet")(n_steps=80, ndims=2,
                                              num_input_chn=1),
        image_chw=(1, H, W), device='cpu')
    img = torch.randn([1, 1, H, W])
    t = 1
    rec_num = 2
    # Other options: 'adding', 'independ', 'downsample', 'project', 'none'.
    proc_type = 'slice'
    print(img)
    # FIX: proc_cond_img returns three values; the original two-name unpack
    # raised ValueError.
    cond_imgs, mask_tgt, cond_ratio = deformddpm.proc_cond_img(
        img, proc_type=proc_type)
    print(cond_imgs)