import os, sys ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) import torch from torch.utils.data import Dataset, DataLoader import json import SimpleITK as sitk import numpy as np from skimage.transform import rescale, resize, downscale_local_mean # from torchvision.transforms import v2 # sys.path.append('./') sys.path.append(ROOT_DIR) from Dataloader.dataloader_utils import * import random # add your mapping files here # mapping_files = { # 'TotalSegmentor': '/home/data/Github/data/data_gen_def/DATASETS_processed/TotalSegmentorCT_MRI/nifti_mappings.json', # 'MSD': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/MSD_processed/nifti_mappings_updated.json', # # 'CancerImageArchive': '/home/data/Github/data/data_gen_def/DATASETS_processed/CancerImageArchive_1/nifti_mappings.json', # } # mapping_files = { # 'MSD': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/MSD_mappings.json', # 'TotalSegmentor': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json', # 'Kaggle_osic': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/Kaggle_osic_mappings.json', # 'CancerImageArchive': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/CIA_mappings.json', # 'MnMs': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/MnMs_mappings.json', # # 'Brats2019': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2019_mappings.json', # 'Brats2020': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2020_mappings.json', # 'Brats2021': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2021_mappings.json', # 'OASIS_1': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_1_mappings.json', # 'OASIS_2': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_2_mappings.json', # 'PSMA-FDG-PET-CT-LESION':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json', # 'PSMA-CT':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json', # 'AbdomenAtlas':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenAtlas_mappings.json', # 'AbdomenCT1k':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenCT1k_mappings.json', # } mapping_files = { 'MSD': 'nifty_mappings/MSD_mappings.json', 'TotalSegmentor': 'nifty_mappings/TotalSegmentorCT_MRI_mappings.json', 'Kaggle_osic': 'nifty_mappings/Kaggle_osic_mappings.json', 'CancerImageArchive': 'nifty_mappings/CIA_mappings.json', 'MnMs': 'nifty_mappings/MnMs_mappings.json', # 'Brats2019': 'nifty_mappings/Brats2019_mappings.json', # should be commented out after testing 'Brats2020': 'nifty_mappings/Brats2020_mappings.json', 'Brats2021': 'nifty_mappings/Brats2021_mappings.json', 'OASIS_1': 'nifty_mappings/OASIS_1_mappings.json', 'OASIS_2': 'nifty_mappings/OASIS_2_mappings.json', 'PSMA-FDG-PET-CT-LESION':'nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json', 'PSMA-CT':'nifty_mappings/PSMA-CT-Longitud_mappings.json', 'AbdomenAtlas':'nifty_mappings/AbdomenAtlas_mappings.json', 'AbdomenCT1k':'nifty_mappings/AbdomenCT1k_mappings.json', 'OAI_ZIB': 'nifty_mappings/OAI_ZIB_KL_mappings.json', # 'OAI_ZIB': 'nifty_mappings/OAI_ZIB_WOMAC_mappings.json', # alternative: WOMAC scores instead of KL-grade } for k,v in mapping_files.items(): mapping_files[k] = os.path.join(ROOT_DIR, v) CLAMP_RANGE = [-400, 400] # default clamp range for the images indivi_ROI_list = ['abdomen','arm','brain','hand','head','leg','neck','pelvis','skeleton','thorax'] def reverse_axis_order(arr): """SimpleITK to NumPy axis order conversion.""" # For 3D or 4D arrays, this is just a fast view, not a copy. return np.ascontiguousarray(arr.transpose(tuple(range(arr.ndim)[::-1]))) def sample_random_uniform_multi_order(high=1., low=0., order_num=2, type='high'): """Sample a random value from a uniform distribution with multiple orders. Args: high (float): Upper bound of the uniform distribution. low (float): Lower bound of the uniform distribution. order_num (int): Number of times to sample. type (str): 'high' or 'low', determines the sampling direction. Returns: sample_value (float): The sampled value after multiple orders. Notes: - If type is 'high', samples are drawn iteratively from [low, high], each time using the previous sample as the new lower bound. - If type is 'low', samples are drawn iteratively from [low, high], each time using the previous sample as the new upper bound. - If order_num is 0, returns the low value. - If order_num is 1, returns a single random value from the uniform distribution. - If order_num is 2, returns a value from a linear distribution. - If order_num is 3, returns a value from a quadratic distribution. """ if type == 'high': sample_value = low for _ in range(order_num): sample_value = np.random.uniform(low=sample_value, high=high) elif type == 'low': sample_value = high for _ in range(order_num): sample_value = np.random.uniform(low, high=sample_value) return sample_value class OminiDataset(object): """Base class for OmniMorph datasets.""" def __init__(self, out_sz, transform, clamp_range, min_crop_ratio, ROIs, modality,reverse_axis_order ,min_dim,mapping_files): # self.mappings = mapping_files self.ALLdata = self.combine_data(mappings = mapping_files) self.out_sz = out_sz self.reverse_axis_order = reverse_axis_order self.min_dim = min_dim self.clamp_range = clamp_range self.min_crop_ratio = min_crop_ratio self.transform = transform self.ndims = 3 def get_ALLdata(self): return self.ALLdata def get_all_ROI(self): # Get all the ROI options. and remove the reduntant ones ROIs = [] # ALLdata_filtered = data for k in self.ALLdata_filtered.keys(): ROIs.append(self.ALLdata[k]['ROI']) ROIs = set(ROIs) return ROIs def get_filter_ROIs(self,keep_single_roi=False): ALLdata_filtered = self.ALLdata_filtered.copy() # if keep_single_roi == True: # for k in self.ALLdata_filtered.keys(): # if '-' in self.ALLdata_filtered[k]['ROI']: # del ALLdata_filtered[k] # d = {k: v for k, v in ALLdata_filtered.items() if v['ROI'] in self.ROIs} for k in ALLdata_filtered.keys(): if self.ALLdata_filtered[k]['ROI'] not in self.ROIs: del ALLdata_filtered[k] return ALLdata_filtered def combine_data(self, mappings = mapping_files): ALLdata = {} total_entries = 0 total_skipped = 0 for j in mappings.keys(): with open(mappings[j], 'r') as f: mappings_tmp = json.load(f) skipped = 0 for k, v in mappings_tmp.items(): if not os.path.exists(k) or os.path.getsize(k) == 0: skipped += 1 continue ALLdata[k] = v accessible = len(mappings_tmp) - skipped total_entries += len(mappings_tmp) total_skipped += skipped if skipped > 0: print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)") if total_skipped > 0: print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)") if len(ALLdata) < 1000: print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. " f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***") return ALLdata def get_3D_volume(self, volume, select_channel = None): # Get a 3D volume from the 4D volume, sometime the input image may have 4 dimensions if self.reverse_axis_order: volume = reverse_axis_order(volume) if volume.ndim == 4: if select_channel is None: select_channel = np.random.randint(0, volume.shape[3] - 1) volume = volume[:, :, :, select_channel] return volume def get_filter_mindim(self): # Filter out images with dimensions less than min_dim # Top priority is to filter out images with dimensions less than min_dim ALLdata = self.ALLdata.copy() for k in self.ALLdata.keys(): if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz/2: del ALLdata[k] return ALLdata def normalize(self, volume, eps=1e-7): # Normalize the image (0-1) volume = volume.astype(np.float64) volume = (volume - np.min(volume)) / (np.ptp(volume) + eps) return volume def random_crop_3d(self, volume, crop_size=None): # Fast random crop with optional padding using NumPy d, h, w = volume.shape if crop_size is None: crop_size = self.out_sz crop_d, crop_h, crop_w = crop_size, crop_size, crop_size # Only pad if needed (avoid np.pad if not necessary) pad_d = max(0, crop_d - d) pad_h = max(0, crop_h - h) pad_w = max(0, crop_w - w) if pad_d or pad_h or pad_w: pad_width = ( (np.random.randint(0, pad_d + 1), pad_d - np.random.randint(0, pad_d + 1)), (np.random.randint(0, pad_h + 1), pad_h - np.random.randint(0, pad_h + 1)), (np.random.randint(0, pad_w + 1), pad_w - np.random.randint(0, pad_w + 1)), ) volume = np.pad(volume, pad_width, mode='constant', constant_values=0) d, h, w = volume.shape # Crop indices start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0 start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0 start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0 # Use NumPy slicing (very fast) return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w] class OminiDataset_v1(Dataset): def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.2, reverse_axis_order = False): self.mappings = mapping_files self.ALLdata = self.combine_data() self.out_sz = out_sz self.reverse_axis_order = reverse_axis_order self.min_crop_ratio = min_crop_ratio self.crop_ratio_sample_order = 2 self.transform = transform self.clamp_range = clamp_range self.ndims = 3 # Start you filtering here self.ALLdata_filtered = self.get_filter_mindim() # self.min_dim = self.find_min_dim() def find_min_dim(self): # Find the minimum dimension of the images min_dim = 100000 for k in self.ALLdata.keys(): value = self.ALLdata[k] if min(value['Size']) < min_dim: min_dim = min(value['Size']) return min_dim def random_crop_3d(self, volume, crop_size=None): # Fast random crop with optional padding using NumPy d, h, w = volume.shape if crop_size is None: crop_size = self.out_sz crop_d, crop_h, crop_w = crop_size, crop_size, crop_size # Only pad if needed (avoid np.pad if not necessary) pad_d = max(0, crop_d - d) pad_h = max(0, crop_h - h) pad_w = max(0, crop_w - w) if pad_d or pad_h or pad_w: pad_width = ( (np.random.randint(0, pad_d + 1), pad_d - np.random.randint(0, pad_d + 1)), (np.random.randint(0, pad_h + 1), pad_h - np.random.randint(0, pad_h + 1)), (np.random.randint(0, pad_w + 1), pad_w - np.random.randint(0, pad_w + 1)), ) volume = np.pad(volume, pad_width, mode='constant', constant_values=0) d, h, w = volume.shape # Crop indices start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0 start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0 start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0 # Use NumPy slicing (very fast) return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w] def get_ALLdata(self): # Return all data return self.ALLdata def get_3D_volume(self, volume, select_channel = None): if self.reverse_axis_order: volume = reverse_axis_order(volume) if volume.ndim == 4: if select_channel is None: select_channel = np.random.randint(0, volume.shape[3] - 1) volume = volume[:, :, :, select_channel] # print(f"Volume shape: {volume.shape}, selected channel: {select_channel}") return volume def get_filter_ROI(self, key_word): # Filter out images with a key word ALLdata = self.ALLdata.copy() for k in self.ALLdata.keys(): if key_word not in k["ROI"]: del ALLdata[k] return ALLdata def get_filter_mindim(self): # Filter out images with dimensions less than min_dim # Top priority is to filter out images with dimensions less than min_dim ALLdata = self.ALLdata.copy() for k in self.ALLdata.keys(): if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz/2: del ALLdata[k] return ALLdata def combine_data(self): ALLdata = {} total_entries = 0 total_skipped = 0 for j in self.mappings.keys(): with open(self.mappings[j], 'r') as f: mappings = json.load(f) skipped = 0 for k, v in mappings.items(): if not os.path.exists(k) or os.path.getsize(k) == 0: skipped += 1 continue ALLdata[k] = v accessible = len(mappings) - skipped total_entries += len(mappings) total_skipped += skipped if skipped > 0: print(f" WARNING: {j}: {accessible}/{len(mappings)} accessible ({skipped} missing/empty)") if total_skipped > 0: print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)") if len(ALLdata) < 1000: print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. " f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***") return ALLdata def __len__(self): return len(self.ALLdata_filtered.keys()) def normalize(self, volume, eps=1e-7): # Normalize the image (0-1) volume = volume.astype(np.float64) volume = (volume - np.min(volume)) / (np.ptp(volume) + eps) return volume def __getitem__(self, idx): key = list(self.ALLdata_filtered.keys())[idx] if 0: print(key) volume = sitk.ReadImage(key) volume = sitk.GetArrayFromImage(volume) # if volume.ndim == 4: volume = self.get_3D_volume(volume) if self.clamp_range is not None: modality = self.ALLdata_filtered[key].get("Modality", None) if modality == "CT": volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1]) volume = self.normalize(volume) if self.min_crop_ratio is not None: # print(f'before volume_shape: {volume.shape}') # crop_ratio = np.random.uniform(self.min_crop_ratio, 1) crop_ratio = sample_random_uniform_multi_order(high=1., low=self.min_crop_ratio, order_num=self.crop_ratio_sample_order, type='high') # crop_size = int(min(volume.shape) * crop_ratio) crop_size = int(max(volume.shape) * crop_ratio) volume = self.random_crop_3d(volume, crop_size) volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True) else: volume = self.random_crop_3d(volume, self.out_sz) volume = volume[None, :, :, :] if self.transform is not None: return self.transform(volume) return volume class OMDataset_indiv(Dataset): def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.3, reverse_axis_order = False): # self.mappings = mapping_files self.ALLdata = self.combine_data(mappings=mapping_files) self.out_sz = out_sz self.max_sz = out_sz*8 self.reverse_axis_order = reverse_axis_order self.min_crop_ratio = min_crop_ratio self.crop_ratio_sample_order = 2 self.transform = transform self.clamp_range = clamp_range self.ndims = 3 # Start you filtering here # print(f"Filtering data with out_sz: {self.out_sz}, min_crop_ratio: {min_crop_ratio}") print(f"Diffusion mode: Total data size before filtering: {len(self.ALLdata)}") self.ALLdata_filtered = self.get_filter_mindim() print(f"Diffusion mode: Filtered data size: {len(self.ALLdata_filtered)}") # self.min_dim = self.find_min_dim() def find_min_dim(self): # Find the minimum dimension of the images min_dim = 100000 for k in self.ALLdata.keys(): value = self.ALLdata[k] if min(value['Size']) < min_dim: min_dim = min(value['Size']) return min_dim def random_crop_3d(self, volume, crop_size=None): # Fast random crop with optional padding using NumPy d, h, w = volume.shape if crop_size is None: crop_size = self.out_sz crop_d, crop_h, crop_w = crop_size, crop_size, crop_size # Only pad if needed (avoid np.pad if not necessary) pad_d = max(0, crop_d - d) pad_h = max(0, crop_h - h) pad_w = max(0, crop_w - w) if pad_d or pad_h or pad_w: pad_width = ( (np.random.randint(0, pad_d + 1), pad_d - np.random.randint(0, pad_d + 1)), (np.random.randint(0, pad_h + 1), pad_h - np.random.randint(0, pad_h + 1)), (np.random.randint(0, pad_w + 1), pad_w - np.random.randint(0, pad_w + 1)), ) volume = np.pad(volume, pad_width, mode='constant', constant_values=0) d, h, w = volume.shape # Crop indices start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0 start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0 start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0 # Use NumPy slicing (very fast) return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w] def get_ALLdata(self): # Return all data return self.ALLdata def get_3D_volume(self, volume, select_channel = None): if self.reverse_axis_order: volume = reverse_axis_order(volume) if volume.ndim == 4: if select_channel is None: select_channel = np.random.randint(0, volume.shape[3] - 1) volume = volume[:, :, :, select_channel] # print(f"Volume shape: {volume.shape}, selected channel: {select_channel}") return volume def get_filter_ROI(self, key_word): # Filter out images with a key word ALLdata = self.ALLdata.copy() for k in self.ALLdata.keys(): if key_word not in k["ROI"]: del ALLdata[k] return ALLdata def get_filter_mindim(self): # Filter out images with dimensions less than min_dim # Top priority is to filter out images with dimensions less than min_dim ALLdata = self.ALLdata.copy() for k in self.ALLdata.keys(): if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz/2: del ALLdata[k] return ALLdata def combine_data(self, mappings = mapping_files): ALLdata = {} total_entries = 0 total_skipped = 0 for j in mappings.keys(): with open(mappings[j], 'r') as f: mappings_tmp = json.load(f) skipped = 0 for k, v in mappings_tmp.items(): if not os.path.exists(k) or os.path.getsize(k) == 0: skipped += 1 continue ALLdata[k] = v accessible = len(mappings_tmp) - skipped total_entries += len(mappings_tmp) total_skipped += skipped if skipped > 0: print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)") if total_skipped > 0: print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)") if len(ALLdata) < 1000: print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. " f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***") return ALLdata def __len__(self): return len(self.ALLdata_filtered.keys()) def normalize(self, volume, eps=1e-7): # Normalize the image (0-1) volume = volume.astype(np.float64) volume = (volume - np.min(volume)) / (np.ptp(volume) + eps) return volume def __getitem__(self, idx): key = list(self.ALLdata_filtered.keys())[idx] embd = self.ALLdata_filtered[key]['embd'] embd = np.array(embd, dtype=np.float32) if 0: print(key) volume = sitk.ReadImage(key) volume = sitk.GetArrayFromImage(volume) # if volume.ndim == 4: volume = self.get_3D_volume(volume) if self.clamp_range is not None: modality = self.ALLdata_filtered[key].get("Modality", None) if modality == "CT": volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1]) volume = self.normalize(volume) if self.min_crop_ratio is not None: # print(f'before volume_shape: {volume.shape}') # crop_ratio = np.random.uniform(self.min_crop_ratio, 1) crop_ratio = sample_random_uniform_multi_order(high=1., low=self.min_crop_ratio, order_num=self.crop_ratio_sample_order, type='high') # crop_size = int(min(volume.shape) * crop_ratio) crop_size = int(max(volume.shape) * crop_ratio) crop_size = min(crop_size, self.max_sz) volume = self.random_crop_3d(volume, crop_size) volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True) else: volume = self.random_crop_3d(volume, self.out_sz) volume = volume[None, :, :, :] if self.transform is not None: return self.transform(volume) return [volume, embd] class OminiDataset_paired(Dataset): def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.85, ROIs = None, modality = None, reverse_axis_order = False): # self.mappings = mapping_files self.ALLdata = self.combine_data(mappings=mapping_files) self.out_sz = out_sz self.sz_range = get_sizeRange_dict() self.min_dim_ratio = 0.5 self.reverse_axis_order = reverse_axis_order self.min_crop_ratio = min_crop_ratio self.transform = transform self.clamp_range = clamp_range self.ndims = 3 # Start you filtering here # print(f"Number of images before filtering: {len(self.ALLdata.keys())}") self.ALLdata_filtered = self.get_filter_mindim() # print(f"Number of images after filtering: {len(self.ALLdata_filtered.keys())}") self.ALLdata_filtered = self.get_filter_modality(modality) # print(f"Number of images after modality filtering: {len(self.ALLdata_filtered.keys())}") if ROIs is None:# if no ROIs are provided, get all the ROIs from filtered data self.ROIs = self.get_all_ROI() else: self.ROIs = ROIs self.ALLdata_filtered = self.get_filter_ROIs() # print(f"Number of images after ROI filtering: {len(self.ALLdata_filtered.keys())}") # filtering ends here def combine_data(self, mappings = mapping_files): ALLdata = {} total_entries = 0 total_skipped = 0 for j in mappings.keys(): with open(mappings[j], 'r') as f: mappings_tmp = json.load(f) skipped = 0 for k, v in mappings_tmp.items(): if not os.path.exists(k) or os.path.getsize(k) == 0: skipped += 1 continue ALLdata[k] = v accessible = len(mappings_tmp) - skipped total_entries += len(mappings_tmp) total_skipped += skipped if skipped > 0: print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)") if total_skipped > 0: print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)") if len(ALLdata) < 1000: print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. " f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***") return ALLdata def normalize(self, volume, eps=1e-7): # Normalize the image (0-1) volume = volume.astype(np.float64) volume = (volume - np.min(volume)) / (np.ptp(volume) + eps) return volume def random_crop_3d(self, volume, crop_size=None): # Fast random crop with optional padding using NumPy d, h, w = volume.shape if crop_size is None: crop_size = self.out_sz crop_d, crop_h, crop_w = crop_size, crop_size, crop_size # Only pad if needed (avoid np.pad if not necessary) pad_d = max(0, crop_d - d) pad_h = max(0, crop_h - h) pad_w = max(0, crop_w - w) if pad_d or pad_h or pad_w: pad_width = ( (np.random.randint(0, pad_d + 1), pad_d - np.random.randint(0, pad_d + 1)), (np.random.randint(0, pad_h + 1), pad_h - np.random.randint(0, pad_h + 1)), (np.random.randint(0, pad_w + 1), pad_w - np.random.randint(0, pad_w + 1)), ) volume = np.pad(volume, pad_width, mode='constant', constant_values=0) d, h, w = volume.shape # Crop indices start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0 start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0 start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0 # Use NumPy slicing (very fast) return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w] # def random_crop_3d(self, volume, crop_size=None): # # Randomly crop the image # d, h, w = volume.shape # if crop_size is None: # crop_size = self.out_sz # crop_d, crop_h, crop_w = crop_size, crop_size, crop_size # if crop_d > d or crop_h > h or crop_w > w: # raise ValueError("Crop size must be smaller than the original array size") # start_d = np.random.randint(0, d - crop_d + 1) # start_h = np.random.randint(0, h - crop_h + 1) # start_w = np.random.randint(0, w - crop_w + 1) # cropped_array = volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w] # return cropped_array def get_all_ROI(self): # Get all the ROI options. and remove the reduntant ones ROIs = [] for k in self.ALLdata_filtered.keys(): ROIs.append(self.ALLdata[k]['ROI']) ROIs = set(ROIs) return ROIs def find_min_dim(self): # Find the minimum dimension of the images min_dim = 100000 for k in self.ALLdata.keys(): value = self.ALLdata[k] if min(value['Size']) < min_dim: min_dim = min(value['Size']) return min_dim def get_ALLdata(self): # Return all data return self.ALLdata def get_filter_modality(self, key_words=None): ALLdata_filtered = self.ALLdata_filtered.copy() if key_words is not None: for k in self.ALLdata_filtered.keys(): if ALLdata_filtered[k]["Modality"] not in key_words: del ALLdata_filtered[k] return ALLdata_filtered def get_filter_ROI(self, key_word): # Filter out images with a key word ALLdata_filtered = self.ALLdata_filtered.copy() for k in self.ALLdata_filtered.keys(): if key_word not in k["ROI"]: del ALLdata_filtered[k] return ALLdata_filtered def get_key_by_ROI(self, key_word): # Get all the keys with a key word keys = [] for k in self.ALLdata_filtered.keys(): if key_word == self.ALLdata_filtered[k]["ROI"]: keys.append(k) return keys def get_filter_ROIs(self): ALLdata_filtered = self.ALLdata_filtered.copy() for k in self.ALLdata_filtered.keys(): if self.ALLdata_filtered[k]['ROI'] not in self.ROIs: del ALLdata_filtered[k] return ALLdata_filtered def get_3D_volume(self, volume, select_channel = None): if self.reverse_axis_order: volume = reverse_axis_order(volume) if volume.ndim == 4: if select_channel is None: select_channel = np.random.randint(0, volume.shape[3] - 1) volume = volume[:, :, :, select_channel] return volume def get_filter_mindim(self): # Filter out images with dimensions less than min_dim # Top priority is to filter out images with dimensions less than min_dim ALLdata = self.ALLdata.copy() for k in self.ALLdata.keys(): img_sz = self.ALLdata[k]['Size'][:self.ndims] del_flag = False del_flag = del_flag or min(img_sz) < self.out_sz # print(f"Size: {self.ALLdata[k]['Size']}, Spacing_mm: {self.ALLdata[k]['Spacing_mm']}, ROI: {self.ALLdata[k]['ROI']}") # print(f"sz_range: {self.sz_range[self.ALLdata[k]['ROI']]}, min_dim_ratio: {self.min_dim_ratio}") del_flag = del_flag or (min(img_sz)*self.ALLdata[k]['Spacing_mm']) < self.sz_range[self.ALLdata[k]['ROI']][0] del_flag = del_flag or (min(img_sz)/max(img_sz) < self.min_dim_ratio) # del_flag = min(self.ALLdata[k]['Size']) < self.out_sz or (min(self.ALLdata[k]['Size'])*self.ALLdata[k]['Spacing_mm']) < self.sz_range[self.ALLdata[k]['ROI']] or (min(self.ALLdata[k]['Size'])/max(self.ALLdata[k]['Size']) < self.min_dim_ratio) if del_flag: del ALLdata[k] return ALLdata def __getitem__(self,idx): key = list(self.ALLdata_filtered.keys())[idx] volume_A = sitk.ReadImage(key) volume_A = sitk.GetArrayFromImage(volume_A) paired_keys = self.get_key_by_ROI(self.ALLdata_filtered[key]['ROI']) paired_key = random.choice(paired_keys) volume_B = sitk.ReadImage(paired_key) volume_B = sitk.GetArrayFromImage(volume_B) # if volume_A.ndim == 4 or volume_B.ndim == 4: volume_A = self.get_3D_volume(volume_A) volume_B = self.get_3D_volume(volume_B) if self.clamp_range is not None: modality = self.ALLdata_filtered[key].get("Modality", None) if modality == "CT": volume_A = np.clip(volume_A, self.clamp_range[0], self.clamp_range[1]) volume_B = np.clip(volume_B, self.clamp_range[0], self.clamp_range[1]) volume_A = self.normalize(volume_A) volume_B = self.normalize(volume_B) if self.min_crop_ratio is not None: # print(f'before volume_shape: {volume.shape}') crop_ratio = np.random.uniform(self.min_crop_ratio, 1) crop_size_A = int(min(volume_A.shape) * crop_ratio) crop_size_B = int(min(volume_B.shape) * crop_ratio) # crop_size_A = int(max(volume_A.shape) * crop_ratio) # crop_size_B = int(max(volume_B.shape) * crop_ratio) volume_A = self.random_crop_3d(volume_A, crop_size_A) volume_B = self.random_crop_3d(volume_B, crop_size_B) volume_A = resize(volume_A, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True) volume_B = resize(volume_B, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True) else: volume_A = self.random_crop_3d(volume_A, self.out_sz) volume_B = self.random_crop_3d(volume_B, self.out_sz) volume_A = volume_A[None, :, :, :] volume_B = volume_B[None, :, :, :] if self.transform is not None: return self.transform(volume_A), self.transform(volume_B) # print(self.ALLdata_filtered[key]['ROI'],self.ALLdata_filtered[key]['Modality'],self.ALLdata_filtered[key]['Dataset_name'],'---',self.ALLdata_filtered[paired_key]['ROI'], self.ALLdata_filtered[paired_key]['Modality'], self.ALLdata_filtered[paired_key]['Dataset_name']) return volume_A, volume_B def __len__(self): return len(self.ALLdata_filtered.keys()) class OMDataset_pair(Dataset): def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.75, ROIs = indivi_ROI_list, modality = None, reverse_axis_order = False): # self.mappings = mapping_files self.ALLdata = self.combine_data(mappings=mapping_files) self.out_sz = out_sz self.max_sz = out_sz*8 self.sz_range = get_sizeRange_dict() self.min_dim_ratio = 0.7 self.reverse_axis_order = reverse_axis_order self.min_crop_ratio = min_crop_ratio self.transform = transform self.clamp_range = clamp_range self.ndims = 3 # Start you filtering here # print(f"Number of images before filtering: {len(self.ALLdata.keys())}") print(f"Registration mode: Total data size before filtering: {len(self.ALLdata)}") self.ALLdata_filtered = self.get_filter_mindim() # print(f"Number of images after filtering: {len(self.ALLdata_filtered.keys())}") self.ALLdata_filtered = self.get_filter_modality(modality) # print(f"Number of images after modality filtering: {len(self.ALLdata_filtered.keys())}") if ROIs is None:# if no ROIs are provided, get all the ROIs from filtered data self.ROIs = self.get_all_ROI() else: self.ROIs = ROIs self.ALLdata_filtered = self.get_filter_ROIs() print(f"Registration mode: Number of images after filtering: {len(self.ALLdata_filtered.keys())}") # filtering ends here def combine_data(self, mappings = mapping_files): ALLdata = {} total_entries = 0 total_skipped = 0 for j in mappings.keys(): with open(mappings[j], 'r') as f: mappings_tmp = json.load(f) skipped = 0 for k, v in mappings_tmp.items(): if not os.path.exists(k) or os.path.getsize(k) == 0: skipped += 1 continue ALLdata[k] = v accessible = len(mappings_tmp) - skipped total_entries += len(mappings_tmp) total_skipped += skipped if skipped > 0: print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)") if total_skipped > 0: print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)") if len(ALLdata) < 1000: print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. " f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***") return ALLdata def normalize(self, volume, eps=1e-7): # Normalize the image (0-1) volume = volume.astype(np.float64) volume = (volume - np.min(volume)) / (np.ptp(volume) + eps) return volume def random_crop_3d(self, volume, crop_size=None): # Fast random crop with optional padding using NumPy d, h, w = volume.shape if crop_size is None: crop_size = self.out_sz crop_d, crop_h, crop_w = crop_size, crop_size, crop_size # Only pad if needed (avoid np.pad if not necessary) pad_d = max(0, crop_d - d) pad_h = max(0, crop_h - h) pad_w = max(0, crop_w - w) if pad_d or pad_h or pad_w: pad_width = ( (np.random.randint(0, pad_d + 1), pad_d - np.random.randint(0, pad_d + 1)), (np.random.randint(0, pad_h + 1), pad_h - np.random.randint(0, pad_h + 1)), (np.random.randint(0, pad_w + 1), pad_w - np.random.randint(0, pad_w + 1)), ) volume = np.pad(volume, pad_width, mode='constant', constant_values=0) d, h, w = volume.shape # Crop indices start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0 start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0 start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0 # Use NumPy slicing (very fast) return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w] # def random_crop_3d(self, volume, crop_size=None): # # Randomly crop the image # d, h, w = volume.shape # if crop_size is None: # crop_size = self.out_sz # crop_d, crop_h, crop_w = crop_size, crop_size, crop_size # if crop_d > d or crop_h > h or crop_w > w: # raise ValueError("Crop size must be smaller than the original array size") # start_d = np.random.randint(0, d - crop_d + 1) # start_h = np.random.randint(0, h - crop_h + 1) # start_w = np.random.randint(0, w - crop_w + 1) # cropped_array = volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w] # return cropped_array def get_all_ROI(self): # Get all the ROI options. and remove the reduntant ones ROIs = [] for k in self.ALLdata_filtered.keys(): ROIs.append(self.ALLdata[k]['ROI']) ROIs = set(ROIs) return ROIs def find_min_dim(self): # Find the minimum dimension of the images min_dim = 100000 for k in self.ALLdata.keys(): value = self.ALLdata[k] if min(value['Size']) < min_dim: min_dim = min(value['Size']) return min_dim def get_ALLdata(self): # Return all data return self.ALLdata def get_filter_modality(self, key_words=None): ALLdata_filtered = self.ALLdata_filtered.copy() if key_words is not None: for k in self.ALLdata_filtered.keys(): if ALLdata_filtered[k]["Modality"] not in key_words: del ALLdata_filtered[k] return ALLdata_filtered def get_filter_ROI(self, key_word): # Filter out images with a key word ALLdata_filtered = self.ALLdata_filtered.copy() for k in self.ALLdata_filtered.keys(): if key_word not in k["ROI"]: del ALLdata_filtered[k] return ALLdata_filtered def get_key_by_ROI(self, key_word): # Get all the keys with a key word keys = [] for k in self.ALLdata_filtered.keys(): if key_word == self.ALLdata_filtered[k]["ROI"]: keys.append(k) return keys def filter_keys_by_xx(self, key_word, keys=None, term="ROI"): # Filter out images with a key word filtered_keys = [] if keys is None: keys = self.ALLdata_filtered.keys() for k in keys: value = self.ALLdata_filtered[k].get(term, None) if value is not None and key_word == value: filtered_keys.append(k) return filtered_keys def get_filter_ROIs(self): ALLdata_filtered = self.ALLdata_filtered.copy() for k in self.ALLdata_filtered.keys(): if self.ALLdata_filtered[k]['ROI'] not in self.ROIs: del ALLdata_filtered[k] return ALLdata_filtered def get_3D_volume(self, volume, select_channel = None): if self.reverse_axis_order: volume = reverse_axis_order(volume) if volume.ndim == 4: if select_channel is None: select_channel = np.random.randint(0, volume.shape[3] - 1) volume = volume[:, :, :, select_channel] return volume def get_filter_mindim(self): # Filter out images with dimensions less than min_dim # Top priority is to filter out images with dimensions less than min_dim ALLdata = self.ALLdata.copy() for k in self.ALLdata.keys(): img_sz = self.ALLdata[k]['Size'][:self.ndims] del_flag = False del_flag = del_flag or min(img_sz) < self.out_sz # print(f"Size: {self.ALLdata[k]['Size']}, Spacing_mm: {self.ALLdata[k]['Spacing_mm']}, ROI: {self.ALLdata[k]['ROI']}") # print(f"sz_range: {self.sz_range[self.ALLdata[k]['ROI']]}, min_dim_ratio: {self.min_dim_ratio}") del_flag = del_flag or (min(img_sz)*self.ALLdata[k]['Spacing_mm']) < self.sz_range[self.ALLdata[k]['ROI']][0] del_flag = del_flag or (min(img_sz)/max(img_sz) < self.min_dim_ratio) # del_flag = min(self.ALLdata[k]['Size']) < self.out_sz or (min(self.ALLdata[k]['Size'])*self.ALLdata[k]['Spacing_mm']) < self.sz_range[self.ALLdata[k]['ROI']] or (min(self.ALLdata[k]['Size'])/max(self.ALLdata[k]['Size']) < self.min_dim_ratio) if del_flag: del ALLdata[k] return ALLdata def __getitem__(self,idx): key = list(self.ALLdata_filtered.keys())[idx] volume_A = sitk.ReadImage(key) volume_A = sitk.GetArrayFromImage(volume_A) embd_A = self.ALLdata_filtered[key]['embd'] embd_A = np.array(embd_A, dtype=np.float32) all_keys = list(self.ALLdata_filtered.keys()) paired_keys = self.filter_keys_by_xx(self.ALLdata_filtered[key]['ROI'], all_keys, term="ROI") paired_keys = self.filter_keys_by_xx(self.ALLdata_filtered[key]['Modality'], paired_keys, term="Modality") # paired_keys = self.get_key_by_ROI(self.ALLdata_filtered[key]['ROI']) paired_key = random.choice(paired_keys) # print(f"Key: {key}, Paired Key: {paired_key}") # print(f"ROI: {self.ALLdata_filtered[key]['ROI']}, {self.ALLdata_filtered[paired_key]['ROI']}; Modality: {self.ALLdata_filtered[key]['Modality']}, {self.ALLdata_filtered[paired_key]['Modality']}") volume_B = sitk.ReadImage(paired_key) volume_B = sitk.GetArrayFromImage(volume_B) embd_B = self.ALLdata_filtered[paired_key]['embd'] embd_B = np.array(embd_B, dtype=np.float32) # if volume_A.ndim == 4 or volume_B.ndim == 4: volume_A = self.get_3D_volume(volume_A) volume_B = self.get_3D_volume(volume_B) if self.clamp_range is not None: modality = self.ALLdata_filtered[key].get("Modality", None) if modality == "CT": volume_A = np.clip(volume_A, self.clamp_range[0], self.clamp_range[1]) volume_B = np.clip(volume_B, self.clamp_range[0], self.clamp_range[1]) volume_A = self.normalize(volume_A) volume_B = self.normalize(volume_B) if self.min_crop_ratio is not None: # print(f'before volume_shape: {volume.shape}') crop_ratio = np.random.uniform(self.min_crop_ratio, 1) # crop_size_A = int(min(volume_A.shape) * crop_ratio) # crop_size_B = int(min(volume_B.shape) * crop_ratio) crop_size_A = int(max(volume_A.shape) * crop_ratio) crop_size_B = int(max(volume_B.shape) * crop_ratio) crop_size_A = min(crop_size_A, self.max_sz) crop_size_B = min(crop_size_B, self.max_sz) volume_A = self.random_crop_3d(volume_A, crop_size_A) volume_B = self.random_crop_3d(volume_B, crop_size_B) volume_A = resize(volume_A, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True) volume_B = resize(volume_B, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True) else: volume_A = self.random_crop_3d(volume_A, self.out_sz) volume_B = self.random_crop_3d(volume_B, self.out_sz) volume_A = volume_A[None, :, :, :] volume_B = volume_B[None, :, :, :] if self.transform is not None: return self.transform(volume_A), self.transform(volume_B) # print(self.ALLdata_filtered[key]['ROI'],self.ALLdata_filtered[key]['Modality'],self.ALLdata_filtered[key]['Dataset_name'],'---',self.ALLdata_filtered[paired_key]['ROI'], self.ALLdata_filtered[paired_key]['Modality'], self.ALLdata_filtered[paired_key]['Dataset_name']) return [volume_A, volume_B, embd_A, embd_B] def __len__(self): return len(self.ALLdata_filtered.keys()) class OminiDataset_paired_inf(object): def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.3, ROIs = None): # self.mappings = mapping_files self.ALLdata = self.combine_data(mappings=mapping_files) self.out_sz = out_sz self.min_crop_ratio = min_crop_ratio self.transform = transform self.clamp_range = clamp_range self.ndims = 3 # Start you filtering here: # filter out images with dimensions less than min_dim self.ALLdata_filtered = self.get_filter_mindim() # filter out images with ROIs that are not in the provided ROIs if ROIs is None: self.ROIs = self.get_all_ROI() else: self.ROIs = ROIs self.ALLdata_filtered = self.get_filter_ROIs() # filtering ends here self.roi_scan_mapping = self.build_ROI_scan_mapping() self.keys_dist, self.total = self.get_keys_dist() def get_all_ROI(self): # Get all the ROI options. and remove the reduntant ones ROIs = [] for k in self.ALLdata_filtered.keys(): ROIs.append(self.ALLdata[k]['ROI']) ROIs = set(ROIs) return ROIs def get_ALLdata(self): # Return all data return self.ALLdata def combine_data(self, mappings = mapping_files): ALLdata = {} total_entries = 0 total_skipped = 0 for j in mappings.keys(): with open(mappings[j], 'r') as f: mappings_tmp = json.load(f) skipped = 0 for k, v in mappings_tmp.items(): if not os.path.exists(k) or os.path.getsize(k) == 0: skipped += 1 continue ALLdata[k] = v accessible = len(mappings_tmp) - skipped total_entries += len(mappings_tmp) total_skipped += skipped if skipped > 0: print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)") if total_skipped > 0: print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)") if len(ALLdata) < 1000: print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. " f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***") return ALLdata def __len__(self): return len(self.ALLdata_filtered.keys()) def random_crop_3d(self, volume, crop_size=None): # Randomly crop the image d, h, w = volume.shape if crop_size is None: crop_size = self.out_sz crop_d, crop_h, crop_w = crop_size, crop_size, crop_size if crop_d > d or crop_h > h or crop_w > w: raise ValueError("Crop size must be smaller than the original array size") start_d = np.random.randint(0, d - crop_d + 1) start_h = np.random.randint(0, h - crop_h + 1) start_w = np.random.randint(0, w - crop_w + 1) cropped_array = volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w] return cropped_array def normalize(self, volume, eps=1e-7): # Normalize the image (0-1) volume = volume.astype(np.float64) volume = (volume - np.min(volume)) / (np.ptp(volume) + eps) return volume def get_3D_volume(self, volume, select_channel = None): volume = reverse_axis_order(volume) if volume.ndim == 4: if select_channel is None: select_channel = np.random.randint(0, volume.shape[3] - 1) volume = volume[:, :, :, select_channel] return volume def get_filter_mindim(self): # Filter out images with dimensions less than min_dim # Top priority is to filter out images with dimensions less than min_dim ALLdata = self.ALLdata.copy() for k in self.ALLdata.keys(): if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz/2: del ALLdata[k] return ALLdata def get_filter_ROI(self, key_word): # Filter out images with a key word ALLdata_filtered = self.ALLdata_filtered.copy() for k in self.ALLdata_filtered.keys(): if key_word not in k["ROI"]: del ALLdata_filtered[k] return ALLdata_filtered def get_filter_ROIs(self): ALLdata_filtered = self.ALLdata_filtered.copy() for k in self.ALLdata_filtered.keys(): if self.ALLdata_filtered[k]['ROI'] not in self.ROIs: del ALLdata_filtered[k] return ALLdata_filtered def get_keys_dist(self): ROIs = self.get_all_ROI() keys_dist = {} total = 0 for item in self.ALLdata_filtered.keys(): if self.ALLdata_filtered[item]['ROI'] not in keys_dist: keys_dist[self.ALLdata_filtered[item]['ROI']] = 0 keys_dist[self.ALLdata_filtered[item]['ROI']] += 1 return keys_dist, total def build_ROI_scan_mapping(self): # Build a mapping of ROIs to scans ROI_scan_mapping = {} for item in self.ALLdata_filtered.keys(): if self.ALLdata_filtered[item]['ROI'] not in ROI_scan_mapping: ROI_scan_mapping[self.ALLdata_filtered[item]['ROI']] = [] ROI_scan_mapping[self.ALLdata_filtered[item]['ROI']].append(item) return ROI_scan_mapping def get_random_2_items(self, mode = 'uniform'): # Get a random pair of items from the dataset with the same ROI if mode == 'uniform': idx = random.randint(0, len(self.keys_dist.keys()) - 1) key = list(self.keys_dist.keys())[idx] path_1 = random.choice(self.roi_scan_mapping[key]) path_2 = random.choice(self.roi_scan_mapping[key]) volume_A = sitk.ReadImage(path_1) volume_A = sitk.GetArrayFromImage(volume_A) volume_B = sitk.ReadImage(path_2) volume_B = sitk.GetArrayFromImage(volume_B) if self.clamp_range is not None: modality = self.ALLdata_filtered[key].get("Modality", None) if modality == "CT": volume_A = np.clip(volume_A, self.clamp_range[0], self.clamp_range[1]) volume_B = np.clip(volume_B, self.clamp_range[0], self.clamp_range[1]) volume_A = self.normalize(volume_A) volume_B = self.normalize(volume_B) if self.min_crop_ratio is not None: crop_ratio = np.random.uniform(self.min_crop_ratio, 1) crop_size_A = int(min(volume_A.shape) * crop_ratio) crop_size_B = int(min(volume_B.shape) * crop_ratio) volume_A = self.random_crop_3d(volume_A, crop_size_A) volume_B = self.random_crop_3d(volume_B, crop_size_B) volume_A = resize(volume_A, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True) volume_B = resize(volume_B, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True) else: volume_A = self.radndom_crop_3d(volume_A, self.out_sz) volume_B = self.radndom_crop_3d(volume_B, self.out_sz) volume_A = volume_A[None, :, :, :] volume_B = volume_B[None, :, :, :] if self.transform is not None: return self.transform(volume_A), self.transform(volume_B) return volume_A, volume_B elif mode == 'original': pass def build_batch(self, batch_size = 2): batch_1 = [] batch_2 = [] for i in range(batch_size): V_a, V_b = self.get_random_2_items() batch_1.append(V_a) batch_2.append(V_b) return np.array(batch_1), np.array(batch_2) class OminiDataset_inference_w_all(object): def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.75, ROIs = None, label_key = ['brain'], task_key = 'segmentation', database = None, select_channels_dict = {}): self.mappings = mapping_files # database=['MSD', 'TotalSegmentor'] if database is not None: self.mappings = {db: self.mappings[db] for db in database if db in self.mappings} # select_channels_dict={ # "ImgDict":["ed","es"] # } self.select_channels_dict = select_channels_dict self.ALLdata = self.combine_data(mappings=self.mappings) self.out_sz = out_sz self.label_key = label_key self.min_crop_ratio = min_crop_ratio self.transform = transform self.clamp_range = clamp_range self.ndims = 3 self.is_reverse_axis_order = True # for inference, always reverse axis order (nifty is reverse order than numpy) # Start you filtering here: # self.ALLdata_filtered = self.ALLdata.copy() # filter out images with dimensions less than min_dim self.ALLdata_filtered = self.get_filter_mindim() # filter out images with ROIs that are not in the provided ROIs if ROIs is None: self.ROIs = self.get_all_ROI() else: self.ROIs = ROIs self.ALLdata_filtered = self.get_filter_ROIs() self.ALLdata_filtered = self.get_filter_labels(task_key=task_key,label_keys=label_key) # filtering ends here self.roi_scan_mapping = self.build_ROI_scan_mapping() self.keys_dist, self.total = self.get_keys_dist() def get_all_ROI(self): # Get all the ROI options. and remove the reduntant ones ROIs = [] for k in self.ALLdata_filtered.keys(): ROIs.append(self.ALLdata[k]['ROI']) ROIs = set(ROIs) return ROIs def get_keys_dist(self): ROIs = self.get_all_ROI() keys_dist = {} total = 0 for item in self.ALLdata_filtered.keys(): if self.ALLdata_filtered[item]['ROI'] not in keys_dist: keys_dist[self.ALLdata_filtered[item]['ROI']] = 0 keys_dist[self.ALLdata_filtered[item]['ROI']] += 1 return keys_dist, total def build_ROI_scan_mapping(self): # Build a mapping of ROIs to scans ROI_scan_mapping = {} for item in self.ALLdata_filtered.keys(): if self.ALLdata_filtered[item]['ROI'] not in ROI_scan_mapping: ROI_scan_mapping[self.ALLdata_filtered[item]['ROI']] = [] ROI_scan_mapping[self.ALLdata_filtered[item]['ROI']].append(item) return ROI_scan_mapping def get_3D_volume(self, volume, select_channel = None): volume = reverse_axis_order(volume) if self.is_reverse_axis_order else volume if volume.ndim == 4: if select_channel is None: select_channel = np.random.randint(0, volume.shape[3] - 1) volume = volume[:, :, :, select_channel] # print(f"Volume shape: {volume.shape}, selected channel: {select_channel}") return volume def get_filter_mindim(self): # Filter out images with dimensions less than min_dim # Top priority is to filter out images with dimensions less than min_dim ALLdata = self.ALLdata.copy() for k in self.ALLdata.keys(): if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz/2: del ALLdata[k] return ALLdata def find_min_dim(self): # Find the minimum dimension of the images min_dim = 100000 for k in self.ALLdata.keys(): value = self.ALLdata[k] if min(value['Size']) < min_dim: min_dim = min(value['Size']) return min_dim # def combine_data(self): # ALLdata = {} # for j in self.mappings.keys(): # with open(self.mappings[j], 'r') as f: # mappings = json.load(f) # ALLdata.update(mappings) # return ALLdata def combine_data(self, mappings = mapping_files): ALLdata = {} total_entries = 0 total_skipped = 0 for j in mappings.keys(): with open(mappings[j], 'r') as f: mappings_tmp = json.load(f) skipped = 0 for k, v in mappings_tmp.items(): if not os.path.exists(k) or os.path.getsize(k) == 0: skipped += 1 continue ALLdata[k] = v accessible = len(mappings_tmp) - skipped total_entries += len(mappings_tmp) total_skipped += skipped if skipped > 0: print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)") if total_skipped > 0: print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)") if len(ALLdata) < 1000: print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. " f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***") return ALLdata def normalize(self, volume, eps=1e-7): # Normalize the image (0-1) volume = volume.astype(np.float64) volume = (volume - np.min(volume)) / (np.ptp(volume) + eps) return volume def get_key_by_ROI(self, key_word): # Get all the keys with a key word keys = [] for k in self.ALLdata_filtered.keys(): if key_word == self.ALLdata_filtered[k]["ROI"]: keys.append(k) return keys def get_filter_task(self, task_key = 'segmentation'): # Filter out images with task type that are not in the provided labels_path ALLdata_filtered = self.ALLdata_filtered.copy() for k in self.ALLdata_filtered.keys(): if 'Label_path' not in self.ALLdata_filtered[k] or task_key not in self.ALLdata_filtered[k]['Label_path']: del ALLdata_filtered[k] Warning(f"Label path not found for {k} with task key {task_key}. This image will be removed from the dataset.") return ALLdata_filtered def get_filter_labels(self, task_key='segmentation', label_keys=['heart']): # Filter out images where 'Label_path' does not contain any of the label_keys for the given task_key ALLdata_filtered = self.ALLdata_filtered.copy() keys_to_remove = [] for k in list(ALLdata_filtered.keys()): label_path = ALLdata_filtered[k].get('Label_path', {}) task_labels = label_path.get(task_key, {}) # Check if any label_keys are present in task_labels # print(f"Checking {k} for task key {task_labels.keys()} with label keys {label_keys}") has_any_label = any((tk in label_keys) for tk in task_labels.keys()) # print(f"Has any label: {has_any_label}") if not has_any_label: keys_to_remove.append(k) # print(f"Label path not found for {k} with task key {task_key} and label keys {label_keys}. This image will be removed from the dataset.") for k in keys_to_remove: del ALLdata_filtered[k] return ALLdata_filtered def get_random_pad_crop_params(self, volume_shape, crop_size=None, random=True): # Get random padding and cropping parameters for a given shape d, h, w = volume_shape[:3] if crop_size is None: crop_size = self.out_sz crop_d, crop_h, crop_w = crop_size, crop_size, crop_size # Calculate padding pad_width = [] for size, crop in zip((d, h, w), (crop_d, crop_h, crop_w)): if crop > size: total_pad = crop - size pad_before = np.random.randint(0, total_pad + 1) pad_after = total_pad - pad_before pad_width.append((pad_before, pad_after)) else: pad_width.append((0, 0)) # Update shape after padding d_p, h_p, w_p = d + pad_width[0][0] + pad_width[0][1], h + pad_width[1][0] + pad_width[1][1], w + pad_width[2][0] + pad_width[2][1] if random: # Calculate cropping start indices (random crop) start_d = np.random.randint(0, d_p - crop_d + 1) if d_p > crop_d else 0 start_h = np.random.randint(0, h_p - crop_h + 1) if h_p > crop_h else 0 start_w = np.random.randint(0, w_p - crop_w + 1) if w_p > crop_w else 0 else: # Calculate cropping start indices (center crop) start_d = max((d_p - crop_d) // 2, 0) start_h = max((h_p - crop_h) // 2, 0) start_w = max((w_p - crop_w) // 2, 0) crop_slices = (start_d, start_h, start_w, crop_d, crop_h, crop_w) return pad_width, crop_slices def apply_pad_crop(self, volume, pad_width, crop_slices): # Apply padding and cropping to the volume if any(pad != (0, 0) for pad in pad_width): volume = np.pad(volume, pad_width, mode='constant', constant_values=0) start_d, start_h, start_w, crop_d, crop_h, crop_w = crop_slices cropped_array = volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w] return cropped_array def get_filter_ROIs(self): ALLdata_filtered = self.ALLdata_filtered.copy() for k in self.ALLdata_filtered.keys(): if self.ALLdata_filtered[k]['ROI'] not in self.ROIs: del ALLdata_filtered[k] return ALLdata_filtered def get_channel_ids(self, key): """ Get the indices where ImgDict values match the selected channels (e.g., 'ed', 'es'). Returns: list: List of integer indices matching the selected channels """ img_dict = self.ALLdata_filtered[key].get("ImgDict", {}) selected_values = self.select_channels_dict.get("ImgDict", []) # Build reverse mapping: value -> index value_to_idx = {value: int(idx) for idx, value in img_dict.items()} # Get indices in the order of selected_values indices = [ value_to_idx[val] for val in selected_values if val in value_to_idx ] return indices # return sorted(indices) def __len__(self): return len(self.ALLdata_filtered.keys()) def __getitem__(self, idx): key = list(self.ALLdata_filtered.keys())[idx] return_dict = dict() print(f"Processing key: {key}") volume = sitk.ReadImage(key) volume = sitk.GetArrayFromImage(volume) if volume.ndim == 4: channel_ids = self.get_channel_ids(key) if len(channel_ids) == 0: # warning message that this key has no matching channels Warning(f"No matching channels found for key: {key} with ImgDict: {self.ALLdata_filtered[key].get('ImgDict', {})} and selected channels: {self.select_channels_dict.get('ImgDict', [])}. Using random channel.") channel_id = None else: channel_id=channel_ids[0] volume = self.get_3D_volume(volume, select_channel = channel_id) if self.clamp_range is not None: modality = self.ALLdata_filtered[key].get("Modality", None) if modality == "CT": volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1]) volume = self.normalize(volume) crop_ratio = np.random.uniform(self.min_crop_ratio, 1) crop_size = int(max(volume.shape) * crop_ratio) pad_width, crop_slices = self.get_random_pad_crop_params(volume.shape, crop_size) # print(f"Pad width: {pad_width}, Crop slices: {crop_slices}, Original shape: {volume.shape}") volume = self.apply_pad_crop(volume, pad_width, crop_slices) label_dict = dict() if 'Label_path' in self.ALLdata_filtered[key]: for lk in self.label_key: if lk in self.ALLdata_filtered[key]['Label_path']['segmentation'].keys(): label = sitk.ReadImage(self.ALLdata_filtered[key]['Label_path']['segmentation'][lk]) label = sitk.GetArrayFromImage(label) # print(f"Label shape: {label.shape}, key: {key}, label key: {lk}") label = reverse_axis_order(label) if self.is_reverse_axis_order else label # print(f"Label shape: {label.shape}, key: {key}, label key: {lk}") if label.ndim > self.ndims: if len(channel_ids) != 0: label = label[...,channel_ids] # assuming channel last pad_width_lab = pad_width + [(0,0)]*(label.ndim - self.ndims) # print(f"Label with channels, pad_width_lab: {pad_width_lab}") else: pad_width_lab = pad_width label = self.apply_pad_crop(label, pad_width_lab, crop_slices) # print(f"After pad and crop, label shape: {label.shape}, key: {key}, label key: {lk}") label_dict[lk] = resize(label,[self.out_sz]*self.ndims, anti_aliasing = False, preserve_range = True, order=0) if label.ndim > self.ndims: if self.ndims==3: label_dict[lk] = np.transpose(label_dict[lk], (3,0,1,2)) # assuming channel last elif self.ndims==4: label_dict[lk] = np.transpose(label_dict[lk], (4,0,1,2,3)) # assuming channel last # print(f"After resize, label shape: {label_dict[lk].shape}, key: {key}, label key: {lk}") else: label_dict[lk] = np.full([self.out_sz]*self.ndims, -1) Warning(f"Label path not found for {key} with label key {lk}.") label_dict[lk] = label_dict[lk][None, :, :, :] if label_dict[lk].ndim == 3 else label_dict[lk] else: for lk in self.label_key: label_dict[lk] = np.full([self.out_sz]*self.ndims, -1) Warning(f"Label path not found for {key} with label key {lk}.") label_dict[lk] = label_dict[lk][None, :, :, :] volume =resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True) # return_dict['labels'] = label_dict return_dict['labels'] = np.concatenate([v for v in label_dict.values()], axis=1) return_dict['img'] = volume[None, :, :, :] return_dict['label_channels'] = list(self.select_channels_dict.get("ImgDict", [])) return return_dict class OminiDataset_bertembd(OminiDataset): def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.85, ROIs = None, modality = None, reverse_axis_order = False, min_dim = 3, mapping_files = mapping_files): super().__init__(out_sz = out_sz, transform = transform, clamp_range = clamp_range, min_crop_ratio = min_crop_ratio, ROIs = ROIs, modality = modality, reverse_axis_order = reverse_axis_order, min_dim = min_dim, mapping_files=mapping_files) # start you filtering here self.ALLdata_filtered = self.get_filter_mindim() if ROIs is None: # if no ROIs are provided, get all the ROIs from filtered data self.ROIs = self.get_all_ROI() else: self.ROIs = ROIs self.ALLdata_filtered = self.get_filter_ROIs() # self.ALLdata_filtered = self.filter_embd() # self.ALLdata_filtered = self.get_filter_labels(task_key=task_key,label_keys=label_key) # end your filtering here def __getitem__(self, idx): key = list(self.ALLdata_filtered.keys())[idx] embd = self.ALLdata_filtered[key]['embd'] if 0: print(key) volume = sitk.ReadImage(key) volume = sitk.GetArrayFromImage(volume) volume = self.get_3D_volume(volume) if self.clamp_range is not None: modality = self.ALLdata_filtered[key].get("Modality", None) if modality == "CT": volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1]) volume = self.normalize(volume) if self.min_crop_ratio is not None: crop_ratio = np.random.uniform(self.min_crop_ratio, 1) crop_size = int(max(volume.shape) * crop_ratio) volume = self.random_crop_3d(volume, crop_size) volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True) else: volume = self.random_crop_3d(volume, self.out_sz) volume = volume[None, :, :, :] if self.transform is not None: return self.transform(volume) return volume,np.array(embd) def __len__(self): return len(self.ALLdata_filtered.keys()) def filter_embd(self): for k in self.ALLdata_filtered.keys(): if 'BERT_embedding_keys' not in self.ALLdata_filtered[k]['Metadata']: del self.ALLdata_filtered[k] return self.ALLdata_filtered