import os
import sys

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

import json
import random
import warnings

import numpy as np
import SimpleITK as sitk
from skimage.transform import resize
from torch.utils.data import Dataset, DataLoader

sys.path.append(ROOT_DIR)
from Dataloader.dataloader_utils import *


mapping_files = {
    'MSD': 'nifty_mappings/MSD_mappings.json',
    'TotalSegmentor': 'nifty_mappings/TotalSegmentorCT_MRI_mappings.json',
    'Kaggle_osic': 'nifty_mappings/Kaggle_osic_mappings.json',
    'CancerImageArchive': 'nifty_mappings/CIA_mappings.json',
    'MnMs': 'nifty_mappings/MnMs_mappings.json',
    'Brats2020': 'nifty_mappings/Brats2020_mappings.json',
    'Brats2021': 'nifty_mappings/Brats2021_mappings.json',
    'OASIS_1': 'nifty_mappings/OASIS_1_mappings.json',
    'OASIS_2': 'nifty_mappings/OASIS_2_mappings.json',
    'PSMA-FDG-PET-CT-LESION': 'nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json',
    'PSMA-CT': 'nifty_mappings/PSMA-CT-Longitud_mappings.json',
    'AbdomenAtlas': 'nifty_mappings/AbdomenAtlas_mappings.json',
    'AbdomenCT1k': 'nifty_mappings/AbdomenCT1k_mappings.json',
    'OAI_ZIB': 'nifty_mappings/OAI_ZIB_KL_mappings.json',
}
for k, v in mapping_files.items():
    mapping_files[k] = os.path.join(ROOT_DIR, v)
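
# Illustrative sketch of what a single mapping-file entry is assumed to look
# like, inferred from the fields accessed throughout this module ('Size',
# 'Spacing_mm', 'ROI', 'Modality', 'embd', 'Label_path'). The actual schema
# lives in the JSON files under nifty_mappings/ and may differ; the path below
# is hypothetical.
#
# {
#     "/data/MSD/Task01_BrainTumour/imagesTr/BRATS_001.nii.gz": {
#         "ROI": "brain",
#         "Modality": "MRI",
#         "Size": [240, 240, 155],
#         "Spacing_mm": 1.0,
#         "embd": [0.12, -0.03, ...],
#         "Label_path": {"segmentation": {"brain": "/data/.../label.nii.gz"}}
#     },
#     ...
# }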


# Intensity clamp applied only to CT volumes (an HU window) before normalization.
CLAMP_RANGE = [-400, 400]

indivi_ROI_list = ['abdomen', 'arm', 'brain', 'hand', 'head', 'leg', 'neck', 'pelvis', 'skeleton', 'thorax']


def reverse_axis_order(arr):
    """Convert between SimpleITK (z, y, x) and NumPy (x, y, z) axis order."""
    return np.ascontiguousarray(arr.transpose(tuple(range(arr.ndim)[::-1])))
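
# Quick sanity check (illustrative, not executed by the pipeline): reversing
# the axis order of a (2, 3, 4) array yields shape (4, 3, 2).
# >>> reverse_axis_order(np.zeros((2, 3, 4))).shape
# (4, 3, 2)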


def sample_random_uniform_multi_order(high=1., low=0., order_num=2, type='high'):
    """Sample a value from [low, high] with an order-dependent bias.

    Args:
        high (float): Upper bound of the uniform distribution.
        low (float): Lower bound of the uniform distribution.
        order_num (int): Number of nested sampling rounds.
        type (str): 'high' or 'low'; determines the direction of the bias.

    Returns:
        float: The sampled value.

    Notes:
        - If type is 'high', each round draws from [previous_sample, high], so
          the result is increasingly biased toward `high` as order_num grows.
        - If type is 'low', each round draws from [low, previous_sample], so
          the result is increasingly biased toward `low`.
        - If order_num is 0, the starting bound itself is returned (`low` for
          type='high', `high` for type='low').
        - If order_num is 1, the result is a plain uniform sample.
    """
    if type == 'high':
        sample_value = low
        for _ in range(order_num):
            sample_value = np.random.uniform(low=sample_value, high=high)
    elif type == 'low':
        sample_value = high
        for _ in range(order_num):
            sample_value = np.random.uniform(low, high=sample_value)
    else:
        raise ValueError("type must be 'high' or 'low'")
    return sample_value
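
# A minimal sketch (illustrative, not used by the pipeline) of the bias this
# sampler produces: with order_num=2 and type='high', the empirical mean sits
# well above the 0.5 mean of a single uniform draw.
#
#   vals = [sample_random_uniform_multi_order(order_num=2, type='high')
#           for _ in range(10000)]
#   print(np.mean(vals))   # roughly 0.75 for low=0, high=1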


class OminiDataset(object):
    """Base class for OmniMorph datasets."""

    def __init__(self, out_sz, transform, clamp_range, min_crop_ratio, ROIs, modality, reverse_axis_order, min_dim, mapping_files):
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        self.reverse_axis_order = reverse_axis_order
        self.min_dim = min_dim
        self.clamp_range = clamp_range
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        self.ndims = 3
        # ROIs / modality are recorded here; subclasses apply the actual
        # filtering and may overwrite self.ROIs with the filtered set.
        self.ROIs = ROIs
        self.modality = modality

    def get_ALLdata(self):
        return self.ALLdata

    def get_all_ROI(self):
        ROIs = []
        for k in self.ALLdata_filtered.keys():
            ROIs.append(self.ALLdata[k]['ROI'])
        ROIs = set(ROIs)
        return ROIs

    def get_filter_ROIs(self, keep_single_roi=False):
        ALLdata_filtered = self.ALLdata_filtered.copy()
        # Iterate over the original dict so entries can be deleted from the
        # copy without mutating the dict being iterated.
        for k in self.ALLdata_filtered.keys():
            if self.ALLdata_filtered[k]['ROI'] not in self.ROIs:
                del ALLdata_filtered[k]
        return ALLdata_filtered

    def combine_data(self, mappings=mapping_files):
        ALLdata = {}
        total_entries = 0
        total_skipped = 0
        for j in mappings.keys():
            with open(mappings[j], 'r') as f:
                mappings_tmp = json.load(f)
            skipped = 0
            for k, v in mappings_tmp.items():
                if not os.path.exists(k) or os.path.getsize(k) == 0:
                    skipped += 1
                    continue
                ALLdata[k] = v
            accessible = len(mappings_tmp) - skipped
            total_entries += len(mappings_tmp)
            total_skipped += skipped
            if skipped > 0:
                print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
        if total_skipped > 0:
            print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
        if len(ALLdata) < 1000:
            print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
                  f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
        return ALLdata

    def get_3D_volume(self, volume, select_channel=None):
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                # np.random.randint's upper bound is exclusive, so shape[3]
                # makes every channel (including the last) selectable.
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_mindim(self):
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz / 2:
                del ALLdata[k]
        return ALLdata

    def normalize(self, volume, eps=1e-7):
        volume = volume.astype(np.float64)
        volume = (volume - np.min(volume)) / (np.ptp(volume) + eps)
        return volume

    def random_crop_3d(self, volume, crop_size=None):
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size

        # Pad any undersized axis up to the crop size, splitting the padding
        # randomly between the two sides. Each amount is drawn once so that
        # before + after always equals the required total.
        pad_d = max(0, crop_d - d)
        pad_h = max(0, crop_h - h)
        pad_w = max(0, crop_w - w)
        if pad_d or pad_h or pad_w:
            before_d = np.random.randint(0, pad_d + 1)
            before_h = np.random.randint(0, pad_h + 1)
            before_w = np.random.randint(0, pad_w + 1)
            pad_width = (
                (before_d, pad_d - before_d),
                (before_h, pad_h - before_h),
                (before_w, pad_w - before_w),
            )
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
            d, h, w = volume.shape

        start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0
        start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0
        start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0

        return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]
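
    # Illustrative usage (not part of the class): cropping a volume smaller
    # than the crop size first pads it, so the output shape is always exactly
    # (crop_size,)*3 regardless of the input size.
    #
    #   patch = self.random_crop_3d(np.zeros((100, 160, 90)), crop_size=128)
    #   assert patch.shape == (128, 128, 128)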


class OminiDataset_v1(Dataset):
    def __init__(self, out_sz=128, transform=None, clamp_range=CLAMP_RANGE, min_crop_ratio=0.2, reverse_axis_order=False):
        self.mappings = mapping_files
        self.ALLdata = self.combine_data()
        self.out_sz = out_sz
        self.reverse_axis_order = reverse_axis_order
        self.min_crop_ratio = min_crop_ratio
        self.crop_ratio_sample_order = 2
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3

        self.ALLdata_filtered = self.get_filter_mindim()

    def find_min_dim(self):
        min_dim = 100000
        for k in self.ALLdata.keys():
            value = self.ALLdata[k]
            if min(value['Size']) < min_dim:
                min_dim = min(value['Size'])
        return min_dim

    def random_crop_3d(self, volume, crop_size=None):
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size

        # Draw each per-axis padding split once so before + after == total pad.
        pad_d = max(0, crop_d - d)
        pad_h = max(0, crop_h - h)
        pad_w = max(0, crop_w - w)
        if pad_d or pad_h or pad_w:
            before_d = np.random.randint(0, pad_d + 1)
            before_h = np.random.randint(0, pad_h + 1)
            before_w = np.random.randint(0, pad_w + 1)
            pad_width = (
                (before_d, pad_d - before_d),
                (before_h, pad_h - before_h),
                (before_w, pad_w - before_w),
            )
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
            d, h, w = volume.shape

        start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0
        start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0
        start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0

        return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]

    def get_ALLdata(self):
        return self.ALLdata

    def get_3D_volume(self, volume, select_channel=None):
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_ROI(self, key_word):
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            # Look the ROI up in the metadata dict; k itself is a file path.
            if key_word not in self.ALLdata[k]["ROI"]:
                del ALLdata[k]
        return ALLdata

    def get_filter_mindim(self):
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz / 2:
                del ALLdata[k]
        return ALLdata

    def combine_data(self):
        ALLdata = {}
        total_entries = 0
        total_skipped = 0
        for j in self.mappings.keys():
            with open(self.mappings[j], 'r') as f:
                mappings = json.load(f)
            skipped = 0
            for k, v in mappings.items():
                if not os.path.exists(k) or os.path.getsize(k) == 0:
                    skipped += 1
                    continue
                ALLdata[k] = v
            accessible = len(mappings) - skipped
            total_entries += len(mappings)
            total_skipped += skipped
            if skipped > 0:
                print(f" WARNING: {j}: {accessible}/{len(mappings)} accessible ({skipped} missing/empty)")
        if total_skipped > 0:
            print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
        if len(ALLdata) < 1000:
            print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
                  f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
        return ALLdata

    def __len__(self):
        return len(self.ALLdata_filtered.keys())

    def normalize(self, volume, eps=1e-7):
        volume = volume.astype(np.float64)
        volume = (volume - np.min(volume)) / (np.ptp(volume) + eps)
        return volume

    def __getitem__(self, idx):
        key = list(self.ALLdata_filtered.keys())[idx]
        volume = sitk.ReadImage(key)
        volume = sitk.GetArrayFromImage(volume)
        volume = self.get_3D_volume(volume)

        # Clamp only CT volumes: the HU window is meaningless for MRI/PET.
        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)

        if self.min_crop_ratio is not None:
            crop_ratio = sample_random_uniform_multi_order(high=1., low=self.min_crop_ratio, order_num=self.crop_ratio_sample_order, type='high')
            crop_size = int(max(volume.shape) * crop_ratio)
            volume = self.random_crop_3d(volume, crop_size)
            volume = resize(volume, [self.out_sz] * self.ndims, anti_aliasing=True, preserve_range=True)
        else:
            volume = self.random_crop_3d(volume, self.out_sz)
        volume = volume[None, :, :, :]

        if self.transform is not None:
            return self.transform(volume)

        return volume


class OMDataset_indiv(Dataset):
    def __init__(self, out_sz=128, transform=None, clamp_range=CLAMP_RANGE, min_crop_ratio=0.3, reverse_axis_order=False):
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        self.max_sz = out_sz * 8
        self.reverse_axis_order = reverse_axis_order
        self.min_crop_ratio = min_crop_ratio
        self.crop_ratio_sample_order = 2
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3

        print(f"Diffusion mode: Total data size before filtering: {len(self.ALLdata)}")
        self.ALLdata_filtered = self.get_filter_mindim()
        print(f"Diffusion mode: Filtered data size: {len(self.ALLdata_filtered)}")

    def find_min_dim(self):
        min_dim = 100000
        for k in self.ALLdata.keys():
            value = self.ALLdata[k]
            if min(value['Size']) < min_dim:
                min_dim = min(value['Size'])
        return min_dim

    def random_crop_3d(self, volume, crop_size=None):
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size

        # Draw each per-axis padding split once so before + after == total pad.
        pad_d = max(0, crop_d - d)
        pad_h = max(0, crop_h - h)
        pad_w = max(0, crop_w - w)
        if pad_d or pad_h or pad_w:
            before_d = np.random.randint(0, pad_d + 1)
            before_h = np.random.randint(0, pad_h + 1)
            before_w = np.random.randint(0, pad_w + 1)
            pad_width = (
                (before_d, pad_d - before_d),
                (before_h, pad_h - before_h),
                (before_w, pad_w - before_w),
            )
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
            d, h, w = volume.shape

        start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0
        start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0
        start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0

        return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]

    def get_ALLdata(self):
        return self.ALLdata

    def get_3D_volume(self, volume, select_channel=None):
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_ROI(self, key_word):
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            # Look the ROI up in the metadata dict; k itself is a file path.
            if key_word not in self.ALLdata[k]["ROI"]:
                del ALLdata[k]
        return ALLdata

    def get_filter_mindim(self):
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz / 2:
                del ALLdata[k]
        return ALLdata

    def combine_data(self, mappings=mapping_files):
        ALLdata = {}
        total_entries = 0
        total_skipped = 0
        for j in mappings.keys():
            with open(mappings[j], 'r') as f:
                mappings_tmp = json.load(f)
            skipped = 0
            for k, v in mappings_tmp.items():
                if not os.path.exists(k) or os.path.getsize(k) == 0:
                    skipped += 1
                    continue
                ALLdata[k] = v
            accessible = len(mappings_tmp) - skipped
            total_entries += len(mappings_tmp)
            total_skipped += skipped
            if skipped > 0:
                print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
        if total_skipped > 0:
            print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
        if len(ALLdata) < 1000:
            print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
                  f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
        return ALLdata

    def __len__(self):
        return len(self.ALLdata_filtered.keys())

    def normalize(self, volume, eps=1e-7):
        volume = volume.astype(np.float64)
        volume = (volume - np.min(volume)) / (np.ptp(volume) + eps)
        return volume

    def __getitem__(self, idx):
        key = list(self.ALLdata_filtered.keys())[idx]
        embd = self.ALLdata_filtered[key]['embd']
        embd = np.array(embd, dtype=np.float32)

        volume = sitk.ReadImage(key)
        volume = sitk.GetArrayFromImage(volume)
        volume = self.get_3D_volume(volume)

        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)

        if self.min_crop_ratio is not None:
            crop_ratio = sample_random_uniform_multi_order(high=1., low=self.min_crop_ratio, order_num=self.crop_ratio_sample_order, type='high')
            crop_size = int(max(volume.shape) * crop_ratio)
            # Cap the crop so very large volumes are not downsampled too aggressively.
            crop_size = min(crop_size, self.max_sz)
            volume = self.random_crop_3d(volume, crop_size)
            volume = resize(volume, [self.out_sz] * self.ndims, anti_aliasing=True, preserve_range=True)
        else:
            volume = self.random_crop_3d(volume, self.out_sz)
        volume = volume[None, :, :, :]

        if self.transform is not None:
            return self.transform(volume)

        return [volume, embd]


class OminiDataset_paired(Dataset):
    def __init__(self, out_sz=128, transform=None, clamp_range=CLAMP_RANGE, min_crop_ratio=0.85, ROIs=None, modality=None, reverse_axis_order=False):
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        self.sz_range = get_sizeRange_dict()
        self.min_dim_ratio = 0.5
        self.reverse_axis_order = reverse_axis_order
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3

        self.ALLdata_filtered = self.get_filter_mindim()
        self.ALLdata_filtered = self.get_filter_modality(modality)

        if ROIs is None:
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()

    def combine_data(self, mappings=mapping_files):
        ALLdata = {}
        total_entries = 0
        total_skipped = 0
        for j in mappings.keys():
            with open(mappings[j], 'r') as f:
                mappings_tmp = json.load(f)
            skipped = 0
            for k, v in mappings_tmp.items():
                if not os.path.exists(k) or os.path.getsize(k) == 0:
                    skipped += 1
                    continue
                ALLdata[k] = v
            accessible = len(mappings_tmp) - skipped
            total_entries += len(mappings_tmp)
            total_skipped += skipped
            if skipped > 0:
                print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
        if total_skipped > 0:
            print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
        if len(ALLdata) < 1000:
            print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
                  f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
        return ALLdata

    def normalize(self, volume, eps=1e-7):
        volume = volume.astype(np.float64)
        volume = (volume - np.min(volume)) / (np.ptp(volume) + eps)
        return volume

    def random_crop_3d(self, volume, crop_size=None):
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size

        # Draw each per-axis padding split once so before + after == total pad.
        pad_d = max(0, crop_d - d)
        pad_h = max(0, crop_h - h)
        pad_w = max(0, crop_w - w)
        if pad_d or pad_h or pad_w:
            before_d = np.random.randint(0, pad_d + 1)
            before_h = np.random.randint(0, pad_h + 1)
            before_w = np.random.randint(0, pad_w + 1)
            pad_width = (
                (before_d, pad_d - before_d),
                (before_h, pad_h - before_h),
                (before_w, pad_w - before_w),
            )
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
            d, h, w = volume.shape

        start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0
        start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0
        start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0

        return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]
    def get_all_ROI(self):
        ROIs = []
        for k in self.ALLdata_filtered.keys():
            ROIs.append(self.ALLdata[k]['ROI'])
        ROIs = set(ROIs)
        return ROIs

    def find_min_dim(self):
        min_dim = 100000
        for k in self.ALLdata.keys():
            value = self.ALLdata[k]
            if min(value['Size']) < min_dim:
                min_dim = min(value['Size'])
        return min_dim

    def get_ALLdata(self):
        return self.ALLdata

    def get_filter_modality(self, key_words=None):
        ALLdata_filtered = self.ALLdata_filtered.copy()
        if key_words is not None:
            for k in self.ALLdata_filtered.keys():
                if ALLdata_filtered[k]["Modality"] not in key_words:
                    del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_filter_ROI(self, key_word):
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            # Look the ROI up in the metadata dict; k itself is a file path.
            if key_word not in self.ALLdata_filtered[k]["ROI"]:
                del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_key_by_ROI(self, key_word):
        keys = []
        for k in self.ALLdata_filtered.keys():
            if key_word == self.ALLdata_filtered[k]["ROI"]:
                keys.append(k)
        return keys

    def get_filter_ROIs(self):
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            if self.ALLdata_filtered[k]['ROI'] not in self.ROIs:
                del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_3D_volume(self, volume, select_channel=None):
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_mindim(self):
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            img_sz = self.ALLdata[k]['Size'][:self.ndims]
            del_flag = False
            # Drop volumes that are smaller than the output size, whose physical
            # extent is below the expected range for their ROI, or whose aspect
            # ratio is too anisotropic.
            del_flag = del_flag or min(img_sz) < self.out_sz
            del_flag = del_flag or (min(img_sz) * self.ALLdata[k]['Spacing_mm']) < self.sz_range[self.ALLdata[k]['ROI']][0]
            del_flag = del_flag or (min(img_sz) / max(img_sz) < self.min_dim_ratio)
            if del_flag:
                del ALLdata[k]
        return ALLdata

    def __getitem__(self, idx):
        key = list(self.ALLdata_filtered.keys())[idx]
        volume_A = sitk.ReadImage(key)
        volume_A = sitk.GetArrayFromImage(volume_A)

        # Pair volume_A with a random volume that shares the same ROI.
        paired_keys = self.get_key_by_ROI(self.ALLdata_filtered[key]['ROI'])
        paired_key = random.choice(paired_keys)

        volume_B = sitk.ReadImage(paired_key)
        volume_B = sitk.GetArrayFromImage(volume_B)

        volume_A = self.get_3D_volume(volume_A)
        volume_B = self.get_3D_volume(volume_B)

        # Pairs here share only the ROI, not necessarily the modality, so each
        # volume is clamped according to its own metadata.
        if self.clamp_range is not None:
            if self.ALLdata_filtered[key].get("Modality", None) == "CT":
                volume_A = np.clip(volume_A, self.clamp_range[0], self.clamp_range[1])
            if self.ALLdata_filtered[paired_key].get("Modality", None) == "CT":
                volume_B = np.clip(volume_B, self.clamp_range[0], self.clamp_range[1])
        volume_A = self.normalize(volume_A)
        volume_B = self.normalize(volume_B)

        if self.min_crop_ratio is not None:
            crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
            crop_size_A = int(min(volume_A.shape) * crop_ratio)
            crop_size_B = int(min(volume_B.shape) * crop_ratio)
            volume_A = self.random_crop_3d(volume_A, crop_size_A)
            volume_B = self.random_crop_3d(volume_B, crop_size_B)
            volume_A = resize(volume_A, [self.out_sz] * self.ndims, anti_aliasing=True, preserve_range=True)
            volume_B = resize(volume_B, [self.out_sz] * self.ndims, anti_aliasing=True, preserve_range=True)
        else:
            volume_A = self.random_crop_3d(volume_A, self.out_sz)
            volume_B = self.random_crop_3d(volume_B, self.out_sz)
        volume_A = volume_A[None, :, :, :]
        volume_B = volume_B[None, :, :, :]

        if self.transform is not None:
            return self.transform(volume_A), self.transform(volume_B)

        return volume_A, volume_B

    def __len__(self):
        return len(self.ALLdata_filtered.keys())


class OMDataset_pair(Dataset):
    def __init__(self, out_sz=128, transform=None, clamp_range=CLAMP_RANGE, min_crop_ratio=0.75, ROIs=indivi_ROI_list, modality=None, reverse_axis_order=False):
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        self.max_sz = out_sz * 8
        self.sz_range = get_sizeRange_dict()
        self.min_dim_ratio = 0.7
        self.reverse_axis_order = reverse_axis_order
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3

        print(f"Registration mode: Total data size before filtering: {len(self.ALLdata)}")

        self.ALLdata_filtered = self.get_filter_mindim()
        self.ALLdata_filtered = self.get_filter_modality(modality)

        if ROIs is None:
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()
        print(f"Registration mode: Number of images after filtering: {len(self.ALLdata_filtered.keys())}")

    def combine_data(self, mappings=mapping_files):
        ALLdata = {}
        total_entries = 0
        total_skipped = 0
        for j in mappings.keys():
            with open(mappings[j], 'r') as f:
                mappings_tmp = json.load(f)
            skipped = 0
            for k, v in mappings_tmp.items():
                if not os.path.exists(k) or os.path.getsize(k) == 0:
                    skipped += 1
                    continue
                ALLdata[k] = v
            accessible = len(mappings_tmp) - skipped
            total_entries += len(mappings_tmp)
            total_skipped += skipped
            if skipped > 0:
                print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
        if total_skipped > 0:
            print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
        if len(ALLdata) < 1000:
            print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
                  f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
        return ALLdata

    def normalize(self, volume, eps=1e-7):
        volume = volume.astype(np.float64)
        volume = (volume - np.min(volume)) / (np.ptp(volume) + eps)
        return volume

    def random_crop_3d(self, volume, crop_size=None):
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size

        # Draw each per-axis padding split once so before + after == total pad.
        pad_d = max(0, crop_d - d)
        pad_h = max(0, crop_h - h)
        pad_w = max(0, crop_w - w)
        if pad_d or pad_h or pad_w:
            before_d = np.random.randint(0, pad_d + 1)
            before_h = np.random.randint(0, pad_h + 1)
            before_w = np.random.randint(0, pad_w + 1)
            pad_width = (
                (before_d, pad_d - before_d),
                (before_h, pad_h - before_h),
                (before_w, pad_w - before_w),
            )
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
            d, h, w = volume.shape

        start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0
        start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0
        start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0

        return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]
    def get_all_ROI(self):
        ROIs = []
        for k in self.ALLdata_filtered.keys():
            ROIs.append(self.ALLdata[k]['ROI'])
        ROIs = set(ROIs)
        return ROIs

    def find_min_dim(self):
        min_dim = 100000
        for k in self.ALLdata.keys():
            value = self.ALLdata[k]
            if min(value['Size']) < min_dim:
                min_dim = min(value['Size'])
        return min_dim

    def get_ALLdata(self):
        return self.ALLdata

    def get_filter_modality(self, key_words=None):
        ALLdata_filtered = self.ALLdata_filtered.copy()
        if key_words is not None:
            for k in self.ALLdata_filtered.keys():
                if ALLdata_filtered[k]["Modality"] not in key_words:
                    del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_filter_ROI(self, key_word):
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            # Look the ROI up in the metadata dict; k itself is a file path.
            if key_word not in self.ALLdata_filtered[k]["ROI"]:
                del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_key_by_ROI(self, key_word):
        keys = []
        for k in self.ALLdata_filtered.keys():
            if key_word == self.ALLdata_filtered[k]["ROI"]:
                keys.append(k)
        return keys

    def filter_keys_by_xx(self, key_word, keys=None, term="ROI"):
        filtered_keys = []
        if keys is None:
            keys = self.ALLdata_filtered.keys()
        for k in keys:
            value = self.ALLdata_filtered[k].get(term, None)
            if value is not None and key_word == value:
                filtered_keys.append(k)
        return filtered_keys

    def get_filter_ROIs(self):
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            if self.ALLdata_filtered[k]['ROI'] not in self.ROIs:
                del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_3D_volume(self, volume, select_channel=None):
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_mindim(self):
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            img_sz = self.ALLdata[k]['Size'][:self.ndims]
            del_flag = False
            # Drop volumes that are smaller than the output size, whose physical
            # extent is below the expected range for their ROI, or whose aspect
            # ratio is too anisotropic.
            del_flag = del_flag or min(img_sz) < self.out_sz
            del_flag = del_flag or (min(img_sz) * self.ALLdata[k]['Spacing_mm']) < self.sz_range[self.ALLdata[k]['ROI']][0]
            del_flag = del_flag or (min(img_sz) / max(img_sz) < self.min_dim_ratio)
            if del_flag:
                del ALLdata[k]
        return ALLdata

    def __getitem__(self, idx):
        key = list(self.ALLdata_filtered.keys())[idx]
        volume_A = sitk.ReadImage(key)
        volume_A = sitk.GetArrayFromImage(volume_A)

        embd_A = self.ALLdata_filtered[key]['embd']
        embd_A = np.array(embd_A, dtype=np.float32)

        # Pair volume_A with a random volume of the same ROI and modality.
        all_keys = list(self.ALLdata_filtered.keys())
        paired_keys = self.filter_keys_by_xx(self.ALLdata_filtered[key]['ROI'], all_keys, term="ROI")
        paired_keys = self.filter_keys_by_xx(self.ALLdata_filtered[key]['Modality'], paired_keys, term="Modality")
        paired_key = random.choice(paired_keys)

        volume_B = sitk.ReadImage(paired_key)
        volume_B = sitk.GetArrayFromImage(volume_B)

        embd_B = self.ALLdata_filtered[paired_key]['embd']
        embd_B = np.array(embd_B, dtype=np.float32)

        volume_A = self.get_3D_volume(volume_A)
        volume_B = self.get_3D_volume(volume_B)

        # A and B share a modality here, so one check covers both volumes.
        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume_A = np.clip(volume_A, self.clamp_range[0], self.clamp_range[1])
                volume_B = np.clip(volume_B, self.clamp_range[0], self.clamp_range[1])
        volume_A = self.normalize(volume_A)
        volume_B = self.normalize(volume_B)

        if self.min_crop_ratio is not None:
            crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
            crop_size_A = int(max(volume_A.shape) * crop_ratio)
            crop_size_B = int(max(volume_B.shape) * crop_ratio)
            crop_size_A = min(crop_size_A, self.max_sz)
            crop_size_B = min(crop_size_B, self.max_sz)
            volume_A = self.random_crop_3d(volume_A, crop_size_A)
            volume_B = self.random_crop_3d(volume_B, crop_size_B)
            volume_A = resize(volume_A, [self.out_sz] * self.ndims, anti_aliasing=True, preserve_range=True)
            volume_B = resize(volume_B, [self.out_sz] * self.ndims, anti_aliasing=True, preserve_range=True)
        else:
            volume_A = self.random_crop_3d(volume_A, self.out_sz)
            volume_B = self.random_crop_3d(volume_B, self.out_sz)
        volume_A = volume_A[None, :, :, :]
        volume_B = volume_B[None, :, :, :]

        if self.transform is not None:
            return self.transform(volume_A), self.transform(volume_B)

        return [volume_A, volume_B, embd_A, embd_B]

    def __len__(self):
        return len(self.ALLdata_filtered.keys())


class OminiDataset_paired_inf(object):
    def __init__(self, out_sz=128, transform=None, clamp_range=CLAMP_RANGE, min_crop_ratio=0.3, ROIs=None):
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3

        self.ALLdata_filtered = self.get_filter_mindim()

        if ROIs is None:
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()

        self.roi_scan_mapping = self.build_ROI_scan_mapping()
        self.keys_dist, self.total = self.get_keys_dist()

    def get_all_ROI(self):
        ROIs = []
        for k in self.ALLdata_filtered.keys():
            ROIs.append(self.ALLdata[k]['ROI'])
        ROIs = set(ROIs)
        return ROIs

    def get_ALLdata(self):
        return self.ALLdata

    def combine_data(self, mappings=mapping_files):
        ALLdata = {}
        total_entries = 0
        total_skipped = 0
        for j in mappings.keys():
            with open(mappings[j], 'r') as f:
                mappings_tmp = json.load(f)
            skipped = 0
            for k, v in mappings_tmp.items():
                if not os.path.exists(k) or os.path.getsize(k) == 0:
                    skipped += 1
                    continue
                ALLdata[k] = v
            accessible = len(mappings_tmp) - skipped
            total_entries += len(mappings_tmp)
            total_skipped += skipped
            if skipped > 0:
                print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
        if total_skipped > 0:
            print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
        if len(ALLdata) < 1000:
            print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
                  f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
        return ALLdata

    def __len__(self):
        return len(self.ALLdata_filtered.keys())

    def random_crop_3d(self, volume, crop_size=None):
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size

        # Unlike the training datasets, this variant does not pad: the crop
        # must fit inside the original volume.
        if crop_d > d or crop_h > h or crop_w > w:
            raise ValueError("Crop size must be smaller than the original array size")

        start_d = np.random.randint(0, d - crop_d + 1)
        start_h = np.random.randint(0, h - crop_h + 1)
        start_w = np.random.randint(0, w - crop_w + 1)

        cropped_array = volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]
        return cropped_array

    def normalize(self, volume, eps=1e-7):
        volume = volume.astype(np.float64)
        volume = (volume - np.min(volume)) / (np.ptp(volume) + eps)
        return volume

    def get_3D_volume(self, volume, select_channel=None):
        volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_mindim(self):
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz / 2:
                del ALLdata[k]
        return ALLdata

    def get_filter_ROI(self, key_word):
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            # Look the ROI up in the metadata dict; k itself is a file path.
            if key_word not in self.ALLdata_filtered[k]["ROI"]:
                del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_filter_ROIs(self):
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            if self.ALLdata_filtered[k]['ROI'] not in self.ROIs:
                del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_keys_dist(self):
        keys_dist = {}
        total = 0
        for item in self.ALLdata_filtered.keys():
            if self.ALLdata_filtered[item]['ROI'] not in keys_dist:
                keys_dist[self.ALLdata_filtered[item]['ROI']] = 0
            keys_dist[self.ALLdata_filtered[item]['ROI']] += 1
            total += 1
        return keys_dist, total

    def build_ROI_scan_mapping(self):
        ROI_scan_mapping = {}
        for item in self.ALLdata_filtered.keys():
            if self.ALLdata_filtered[item]['ROI'] not in ROI_scan_mapping:
                ROI_scan_mapping[self.ALLdata_filtered[item]['ROI']] = []
            ROI_scan_mapping[self.ALLdata_filtered[item]['ROI']].append(item)
        return ROI_scan_mapping

    def get_random_2_items(self, mode='uniform'):
        if mode == 'uniform':
            # Sample an ROI uniformly, then two scans (with replacement) from it.
            idx = random.randint(0, len(self.keys_dist.keys()) - 1)
            key = list(self.keys_dist.keys())[idx]
            path_1 = random.choice(self.roi_scan_mapping[key])
            path_2 = random.choice(self.roi_scan_mapping[key])

            volume_A = sitk.ReadImage(path_1)
            volume_A = sitk.GetArrayFromImage(volume_A)
            volume_B = sitk.ReadImage(path_2)
            volume_B = sitk.GetArrayFromImage(volume_B)

            # Clamp per scan: `key` is an ROI name here, so the metadata must be
            # looked up via the scan paths.
            if self.clamp_range is not None:
                if self.ALLdata_filtered[path_1].get("Modality", None) == "CT":
                    volume_A = np.clip(volume_A, self.clamp_range[0], self.clamp_range[1])
                if self.ALLdata_filtered[path_2].get("Modality", None) == "CT":
                    volume_B = np.clip(volume_B, self.clamp_range[0], self.clamp_range[1])
            volume_A = self.normalize(volume_A)
            volume_B = self.normalize(volume_B)

            if self.min_crop_ratio is not None:
                crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
                crop_size_A = int(min(volume_A.shape) * crop_ratio)
                crop_size_B = int(min(volume_B.shape) * crop_ratio)
                volume_A = self.random_crop_3d(volume_A, crop_size_A)
                volume_B = self.random_crop_3d(volume_B, crop_size_B)
                volume_A = resize(volume_A, [self.out_sz] * self.ndims, anti_aliasing=True, preserve_range=True)
                volume_B = resize(volume_B, [self.out_sz] * self.ndims, anti_aliasing=True, preserve_range=True)
            else:
                volume_A = self.random_crop_3d(volume_A, self.out_sz)
                volume_B = self.random_crop_3d(volume_B, self.out_sz)
            volume_A = volume_A[None, :, :, :]
            volume_B = volume_B[None, :, :, :]
            if self.transform is not None:
                return self.transform(volume_A), self.transform(volume_B)
            return volume_A, volume_B

        elif mode == 'original':
            pass

    def build_batch(self, batch_size=2):
        batch_1 = []
        batch_2 = []
        for i in range(batch_size):
            V_a, V_b = self.get_random_2_items()
            batch_1.append(V_a)
            batch_2.append(V_b)
        return np.array(batch_1), np.array(batch_2)
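
    # Illustrative usage sketch (assumes the nifty_mappings JSONs resolve to
    # readable NIfTI files on this machine; not executed by this module):
    #
    #   ds = OminiDataset_paired_inf(out_sz=128, ROIs=['brain'])
    #   fixed, moving = ds.build_batch(batch_size=2)
    #   # fixed.shape == moving.shape == (2, 1, 128, 128, 128)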


class OminiDataset_inference_w_all(object):
    def __init__(self, out_sz=128, transform=None, clamp_range=CLAMP_RANGE, min_crop_ratio=0.75, ROIs=None, label_key=['brain'], task_key='segmentation', database=None, select_channels_dict={}):
        self.mappings = mapping_files

        # Optionally restrict loading to a subset of the mapping databases.
        if database is not None:
            self.mappings = {db: self.mappings[db] for db in database if db in self.mappings}

        self.select_channels_dict = select_channels_dict
        self.ALLdata = self.combine_data(mappings=self.mappings)
        self.out_sz = out_sz
        self.label_key = label_key
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3
        self.is_reverse_axis_order = True

        self.ALLdata_filtered = self.get_filter_mindim()

        if ROIs is None:
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()
        self.ALLdata_filtered = self.get_filter_labels(task_key=task_key, label_keys=label_key)

        self.roi_scan_mapping = self.build_ROI_scan_mapping()
        self.keys_dist, self.total = self.get_keys_dist()

    def get_all_ROI(self):
        ROIs = []
        for k in self.ALLdata_filtered.keys():
            ROIs.append(self.ALLdata[k]['ROI'])
        ROIs = set(ROIs)
        return ROIs

    def get_keys_dist(self):
        keys_dist = {}
        total = 0
        for item in self.ALLdata_filtered.keys():
            if self.ALLdata_filtered[item]['ROI'] not in keys_dist:
                keys_dist[self.ALLdata_filtered[item]['ROI']] = 0
            keys_dist[self.ALLdata_filtered[item]['ROI']] += 1
            total += 1
        return keys_dist, total

    def build_ROI_scan_mapping(self):
        ROI_scan_mapping = {}
        for item in self.ALLdata_filtered.keys():
            if self.ALLdata_filtered[item]['ROI'] not in ROI_scan_mapping:
                ROI_scan_mapping[self.ALLdata_filtered[item]['ROI']] = []
            ROI_scan_mapping[self.ALLdata_filtered[item]['ROI']].append(item)
        return ROI_scan_mapping

    def get_3D_volume(self, volume, select_channel=None):
        volume = reverse_axis_order(volume) if self.is_reverse_axis_order else volume
        if volume.ndim == 4:
            if select_channel is None:
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_mindim(self):
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz / 2:
                del ALLdata[k]
        return ALLdata

    def find_min_dim(self):
        min_dim = 100000
        for k in self.ALLdata.keys():
            value = self.ALLdata[k]
            if min(value['Size']) < min_dim:
                min_dim = min(value['Size'])
        return min_dim
    def combine_data(self, mappings=mapping_files):
        ALLdata = {}
        total_entries = 0
        total_skipped = 0
        for j in mappings.keys():
            with open(mappings[j], 'r') as f:
                mappings_tmp = json.load(f)
            skipped = 0
            for k, v in mappings_tmp.items():
                if not os.path.exists(k) or os.path.getsize(k) == 0:
                    skipped += 1
                    continue
                ALLdata[k] = v
            accessible = len(mappings_tmp) - skipped
            total_entries += len(mappings_tmp)
            total_skipped += skipped
            if skipped > 0:
                print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
        if total_skipped > 0:
            print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
        if len(ALLdata) < 1000:
            print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
                  f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
        return ALLdata

    def normalize(self, volume, eps=1e-7):
        volume = volume.astype(np.float64)
        volume = (volume - np.min(volume)) / (np.ptp(volume) + eps)
        return volume

    def get_key_by_ROI(self, key_word):
        keys = []
        for k in self.ALLdata_filtered.keys():
            if key_word == self.ALLdata_filtered[k]["ROI"]:
                keys.append(k)
        return keys

    def get_filter_task(self, task_key='segmentation'):
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            if 'Label_path' not in self.ALLdata_filtered[k] or task_key not in self.ALLdata_filtered[k]['Label_path']:
                # warnings.warn actually emits the message; a bare Warning(...)
                # only constructs an exception object without raising it.
                warnings.warn(f"Label path not found for {k} with task key {task_key}. This image will be removed from the dataset.")
                del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_filter_labels(self, task_key='segmentation', label_keys=['heart']):
        ALLdata_filtered = self.ALLdata_filtered.copy()
        keys_to_remove = []
        for k in list(ALLdata_filtered.keys()):
            label_path = ALLdata_filtered[k].get('Label_path', {})
            task_labels = label_path.get(task_key, {})
            # Keep the entry only if at least one requested label key is present.
            has_any_label = any((tk in label_keys) for tk in task_labels.keys())
            if not has_any_label:
                keys_to_remove.append(k)
        for k in keys_to_remove:
            del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_random_pad_crop_params(self, volume_shape, crop_size=None, random=True):
        d, h, w = volume_shape[:3]
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size

        # Pad any undersized axis up to the crop size, splitting the padding
        # randomly between the two sides.
        pad_width = []
        for size, crop in zip((d, h, w), (crop_d, crop_h, crop_w)):
            if crop > size:
                total_pad = crop - size
                pad_before = np.random.randint(0, total_pad + 1)
                pad_after = total_pad - pad_before
                pad_width.append((pad_before, pad_after))
            else:
                pad_width.append((0, 0))

        # Shape after padding.
        d_p = d + pad_width[0][0] + pad_width[0][1]
        h_p = h + pad_width[1][0] + pad_width[1][1]
        w_p = w + pad_width[2][0] + pad_width[2][1]

        if random:
            start_d = np.random.randint(0, d_p - crop_d + 1) if d_p > crop_d else 0
            start_h = np.random.randint(0, h_p - crop_h + 1) if h_p > crop_h else 0
            start_w = np.random.randint(0, w_p - crop_w + 1) if w_p > crop_w else 0
        else:
            # Deterministic center crop.
            start_d = max((d_p - crop_d) // 2, 0)
            start_h = max((h_p - crop_h) // 2, 0)
            start_w = max((w_p - crop_w) // 2, 0)

        crop_slices = (start_d, start_h, start_w, crop_d, crop_h, crop_w)
        return pad_width, crop_slices

    def apply_pad_crop(self, volume, pad_width, crop_slices):
        if any(pad != (0, 0) for pad in pad_width):
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
        start_d, start_h, start_w, crop_d, crop_h, crop_w = crop_slices
        cropped_array = volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]
        return cropped_array
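
    # Why the parameters and their application are split (illustrative sketch,
    # mirroring __getitem__ below): computing the pad/crop parameters once and
    # applying them to both the image and its label keeps the two aligned.
    #
    #   pad_width, crop_slices = self.get_random_pad_crop_params(img.shape, 128)
    #   img_patch = self.apply_pad_crop(img, pad_width, crop_slices)
    #   lab_patch = self.apply_pad_crop(lab, pad_width, crop_slices)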

    def get_filter_ROIs(self):
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            if self.ALLdata_filtered[k]['ROI'] not in self.ROIs:
                del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_channel_ids(self, key):
        """Get the indices where ImgDict values match the selected channels (e.g., 'ed', 'es').

        Returns:
            list: List of integer indices matching the selected channels.
        """
        img_dict = self.ALLdata_filtered[key].get("ImgDict", {})
        selected_values = self.select_channels_dict.get("ImgDict", [])

        value_to_idx = {value: int(idx) for idx, value in img_dict.items()}

        indices = [
            value_to_idx[val] for val in selected_values
            if val in value_to_idx
        ]
        return indices

    def __len__(self):
        return len(self.ALLdata_filtered.keys())

    def __getitem__(self, idx):
        key = list(self.ALLdata_filtered.keys())[idx]
        return_dict = dict()

        print(f"Processing key: {key}")

        volume = sitk.ReadImage(key)
        volume = sitk.GetArrayFromImage(volume)

        # channel_ids / channel_id must exist even for 3D volumes: both are
        # consulted again when slicing multi-channel labels below.
        channel_ids = []
        channel_id = None
        if volume.ndim == 4:
            channel_ids = self.get_channel_ids(key)
            if len(channel_ids) == 0:
                warnings.warn(f"No matching channels found for key: {key} with ImgDict: {self.ALLdata_filtered[key].get('ImgDict', {})} and selected channels: {self.select_channels_dict.get('ImgDict', [])}. Using random channel.")
                channel_id = None
            else:
                channel_id = channel_ids[0]

        volume = self.get_3D_volume(volume, select_channel=channel_id)

        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)

        crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
        crop_size = int(max(volume.shape) * crop_ratio)
        pad_width, crop_slices = self.get_random_pad_crop_params(volume.shape, crop_size)
        volume = self.apply_pad_crop(volume, pad_width, crop_slices)

        label_dict = dict()
        if 'Label_path' in self.ALLdata_filtered[key]:
            for lk in self.label_key:
                if lk in self.ALLdata_filtered[key]['Label_path']['segmentation'].keys():
                    label = sitk.ReadImage(self.ALLdata_filtered[key]['Label_path']['segmentation'][lk])
                    label = sitk.GetArrayFromImage(label)
                    label = reverse_axis_order(label) if self.is_reverse_axis_order else label

                    if label.ndim > self.ndims:
                        if len(channel_ids) != 0:
                            label = label[..., channel_ids]
                        # Do not pad the extra channel axes.
                        pad_width_lab = pad_width + [(0, 0)] * (label.ndim - self.ndims)
                    else:
                        pad_width_lab = pad_width

                    label = self.apply_pad_crop(label, pad_width_lab, crop_slices)
                    # Nearest-neighbour resize (order=0) preserves label values.
                    label_dict[lk] = resize(label, [self.out_sz] * self.ndims, anti_aliasing=False, preserve_range=True, order=0)
                    if label.ndim > self.ndims:
                        if self.ndims == 3:
                            label_dict[lk] = np.transpose(label_dict[lk], (3, 0, 1, 2))
                        elif self.ndims == 4:
                            label_dict[lk] = np.transpose(label_dict[lk], (4, 0, 1, 2, 3))
                else:
                    label_dict[lk] = np.full([self.out_sz] * self.ndims, -1)
                    warnings.warn(f"Label path not found for {key} with label key {lk}.")
                label_dict[lk] = label_dict[lk][None, :, :, :] if label_dict[lk].ndim == 3 else label_dict[lk]
        else:
            for lk in self.label_key:
                label_dict[lk] = np.full([self.out_sz] * self.ndims, -1)
                warnings.warn(f"Label path not found for {key} with label key {lk}.")
                label_dict[lk] = label_dict[lk][None, :, :, :]

        volume = resize(volume, [self.out_sz] * self.ndims, anti_aliasing=True, preserve_range=True)

        return_dict['labels'] = np.concatenate([v for v in label_dict.values()], axis=1)
        return_dict['img'] = volume[None, :, :, :]
        return_dict['label_channels'] = list(self.select_channels_dict.get("ImgDict", []))
        return return_dict


class OminiDataset_bertembd(OminiDataset):
    def __init__(self,
                 out_sz=128,
                 transform=None,
                 clamp_range=CLAMP_RANGE,
                 min_crop_ratio=0.85,
                 ROIs=None,
                 modality=None,
                 reverse_axis_order=False,
                 min_dim=3,
                 mapping_files=mapping_files):
        super().__init__(out_sz=out_sz,
                         transform=transform,
                         clamp_range=clamp_range,
                         min_crop_ratio=min_crop_ratio,
                         ROIs=ROIs,
                         modality=modality,
                         reverse_axis_order=reverse_axis_order,
                         min_dim=min_dim,
                         mapping_files=mapping_files)

        self.ALLdata_filtered = self.get_filter_mindim()
        if ROIs is None:
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()

    def __getitem__(self, idx):
        key = list(self.ALLdata_filtered.keys())[idx]
        embd = self.ALLdata_filtered[key]['embd']

        volume = sitk.ReadImage(key)
        volume = sitk.GetArrayFromImage(volume)
        volume = self.get_3D_volume(volume)

        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)

        if self.min_crop_ratio is not None:
            crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
            crop_size = int(max(volume.shape) * crop_ratio)
            volume = self.random_crop_3d(volume, crop_size)
            volume = resize(volume, [self.out_sz] * self.ndims, anti_aliasing=True, preserve_range=True)
        else:
            volume = self.random_crop_3d(volume, self.out_sz)
        volume = volume[None, :, :, :]

        if self.transform is not None:
            return self.transform(volume)

        return volume, np.array(embd)

    def __len__(self):
        return len(self.ALLdata_filtered.keys())

    def filter_embd(self):
        # Materialize the key list first: deleting from a dict while iterating
        # over its live keys raises a RuntimeError.
        for k in list(self.ALLdata_filtered.keys()):
            if 'BERT_embedding_keys' not in self.ALLdata_filtered[k]['Metadata']:
                del self.ALLdata_filtered[k]
        return self.ALLdata_filtered
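

# Minimal smoke-test sketch (illustrative only; assumes the nifty_mappings JSON
# files resolve to readable NIfTI volumes on this machine). Shows the intended
# DataLoader-side usage of the unpaired training dataset.
if __name__ == "__main__":
    dataset = OminiDataset_v1(out_sz=128)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
    volumes = next(iter(loader))
    # Each item is a (1, 128, 128, 128) volume; batched: (2, 1, 128, 128, 128).
    print(volumes.shape)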