# Omini3D / Dataloader / dataLoader.py
# (repository header residue, kept as comments so the module stays importable)
# maxmo2009's picture
# Sync from local: code + epoch-110 checkpoint, clean README
# 2af0e94 verified
import os, sys
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
import torch
from torch.utils.data import Dataset, DataLoader
import json
import SimpleITK as sitk
import numpy as np
from skimage.transform import rescale, resize, downscale_local_mean
# from torchvision.transforms import v2
# sys.path.append('./')
sys.path.append(ROOT_DIR)
from Dataloader.dataloader_utils import *
import random
# add your mapping files here
# mapping_files = {
# 'TotalSegmentor': '/home/data/Github/data/data_gen_def/DATASETS_processed/TotalSegmentorCT_MRI/nifti_mappings.json',
# 'MSD': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/MSD_processed/nifti_mappings_updated.json',
# # 'CancerImageArchive': '/home/data/Github/data/data_gen_def/DATASETS_processed/CancerImageArchive_1/nifti_mappings.json',
# }
# mapping_files = {
# 'MSD': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/MSD_mappings.json',
# 'TotalSegmentor': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json',
# 'Kaggle_osic': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/Kaggle_osic_mappings.json',
# 'CancerImageArchive': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/CIA_mappings.json',
# 'MnMs': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/MnMs_mappings.json',
# # 'Brats2019': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2019_mappings.json',
# 'Brats2020': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2020_mappings.json',
# 'Brats2021': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2021_mappings.json',
# 'OASIS_1': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_1_mappings.json',
# 'OASIS_2': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_2_mappings.json',
# 'PSMA-FDG-PET-CT-LESION':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json',
# 'PSMA-CT':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json',
# 'AbdomenAtlas':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenAtlas_mappings.json',
# 'AbdomenCT1k':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenCT1k_mappings.json',
# }
# Per-dataset mapping JSON files, relative to this module's directory.
# Each JSON maps an on-disk NIfTI path to its metadata entry.
mapping_files = {
    'MSD': 'nifty_mappings/MSD_mappings.json',
    'TotalSegmentor': 'nifty_mappings/TotalSegmentorCT_MRI_mappings.json',
    'Kaggle_osic': 'nifty_mappings/Kaggle_osic_mappings.json',
    'CancerImageArchive': 'nifty_mappings/CIA_mappings.json',
    'MnMs': 'nifty_mappings/MnMs_mappings.json',
    # 'Brats2019': 'nifty_mappings/Brats2019_mappings.json', # should be commented out after testing
    'Brats2020': 'nifty_mappings/Brats2020_mappings.json',
    'Brats2021': 'nifty_mappings/Brats2021_mappings.json',
    'OASIS_1': 'nifty_mappings/OASIS_1_mappings.json',
    'OASIS_2': 'nifty_mappings/OASIS_2_mappings.json',
    'PSMA-FDG-PET-CT-LESION': 'nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json',
    'PSMA-CT': 'nifty_mappings/PSMA-CT-Longitud_mappings.json',
    'AbdomenAtlas': 'nifty_mappings/AbdomenAtlas_mappings.json',
    'AbdomenCT1k': 'nifty_mappings/AbdomenCT1k_mappings.json',
    'OAI_ZIB': 'nifty_mappings/OAI_ZIB_KL_mappings.json',
    # 'OAI_ZIB': 'nifty_mappings/OAI_ZIB_WOMAC_mappings.json', # alternative: WOMAC scores instead of KL-grade
}
# Resolve every mapping path against the module directory.
mapping_files = {name: os.path.join(ROOT_DIR, rel) for name, rel in mapping_files.items()}
# Default intensity clamp window applied to CT volumes (presumably Hounsfield
# units — TODO confirm).
CLAMP_RANGE = [-400, 400]
# ROI labels accepted by the single-region datasets.
indivi_ROI_list = ['abdomen', 'arm', 'brain', 'hand', 'head', 'leg', 'neck', 'pelvis', 'skeleton', 'thorax']
def reverse_axis_order(arr):
    """Reverse the axis order of ``arr`` (SimpleITK <-> NumPy convention).

    ``arr.T`` reverses all axes for any rank; ``ascontiguousarray``
    materializes the (non-contiguous) reversed view as a C-contiguous copy.
    """
    return np.ascontiguousarray(arr.T)
def sample_random_uniform_multi_order(high=1., low=0., order_num=2, type='high'):
    """Sample a random value from a uniform distribution with multiple orders.

    Args:
        high (float): Upper bound of the uniform distribution.
        low (float): Lower bound of the uniform distribution.
        order_num (int): Number of times to sample.
        type (str): 'high' or 'low', determines the sampling direction.
    Returns:
        sample_value (float): The sampled value after multiple orders.
    Raises:
        ValueError: if ``type`` is neither 'high' nor 'low'.
    Notes:
        - If type is 'high', samples are drawn iteratively from [low, high], each time using the previous sample as the new lower bound.
        - If type is 'low', samples are drawn iteratively from [low, high], each time using the previous sample as the new upper bound.
        - If order_num is 0, returns the low value (for 'high') or the high value (for 'low').
        - If order_num is 1, returns a single random value from the uniform distribution.
        - Higher order_num biases the sample toward ``high`` ('high') or ``low`` ('low').
    """
    if type == 'high':
        sample_value = low
        for _ in range(order_num):
            sample_value = np.random.uniform(low=sample_value, high=high)
    elif type == 'low':
        sample_value = high
        for _ in range(order_num):
            sample_value = np.random.uniform(low, high=sample_value)
    else:
        # BUG FIX: previously an unknown type fell through and raised
        # UnboundLocalError on the return; fail loudly and clearly instead.
        raise ValueError(f"type must be 'high' or 'low', got {type!r}")
    return sample_value
class OminiDataset(object):
    """Base class for OmniMorph datasets.

    Merges the per-dataset mapping JSONs into a single {file_path: metadata}
    dict and provides shared helpers: size filtering, channel selection,
    min-max normalization and random 3D cropping.  Subclasses implement the
    torch ``Dataset`` protocol on top of these helpers.
    """

    def __init__(self, out_sz, transform, clamp_range, min_crop_ratio, ROIs,
                 modality, reverse_axis_order, min_dim, mapping_files):
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        self.reverse_axis_order = reverse_axis_order
        self.min_dim = min_dim
        self.clamp_range = clamp_range
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        # BUG FIX: ROIs/modality were accepted but never stored, yet
        # get_filter_ROIs() reads self.ROIs; keep them on the instance.
        self.ROIs = ROIs
        self.modality = modality
        self.ndims = 3  # volumes are treated as 3D throughout

    def get_ALLdata(self):
        """Return the unfiltered {file_path: metadata} mapping."""
        return self.ALLdata

    def get_all_ROI(self):
        """Return the set of distinct ROI labels present in the filtered data.

        NOTE(review): relies on ``self.ALLdata_filtered`` being populated by a
        subclass before this is called — confirm against callers.
        """
        ROIs = set()
        for k in self.ALLdata_filtered.keys():
            ROIs.add(self.ALLdata[k]['ROI'])
        return ROIs

    def get_filter_ROIs(self, keep_single_roi=False):
        """Return the filtered mapping restricted to ROIs listed in self.ROIs.

        BUG FIX: the previous implementation deleted entries from the same
        dict it was iterating, which raises RuntimeError in Python 3.
        """
        return {k: v for k, v in self.ALLdata_filtered.items()
                if v['ROI'] in self.ROIs}

    def combine_data(self, mappings=None):
        """Merge every mapping JSON into a single {file_path: metadata} dict.

        Entries whose file is missing or zero bytes on disk are skipped with
        a warning; a summary is printed when anything was skipped.
        """
        # Lazy default: avoid evaluating the module-level ``mapping_files``
        # at class-definition time (None is not a meaningful value here).
        if mappings is None:
            mappings = mapping_files
        ALLdata = {}
        total_entries = 0
        total_skipped = 0
        for j in mappings.keys():
            with open(mappings[j], 'r') as f:
                mappings_tmp = json.load(f)
            skipped = 0
            for k, v in mappings_tmp.items():
                # Skip entries whose volume is missing or empty on disk.
                if not os.path.exists(k) or os.path.getsize(k) == 0:
                    skipped += 1
                    continue
                ALLdata[k] = v
            accessible = len(mappings_tmp) - skipped
            total_entries += len(mappings_tmp)
            total_skipped += skipped
            if skipped > 0:
                print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
        if total_skipped > 0:
            print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
        if len(ALLdata) < 1000:
            print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
                  f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
        return ALLdata

    def get_3D_volume(self, volume, select_channel=None):
        """Reduce a possibly-4D volume to 3D by selecting one channel.

        BUG FIX: the random channel is drawn with randint(0, C) — previously
        randint(0, C - 1) could never pick the last channel and raised
        ValueError for single-channel volumes.
        """
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_mindim(self):
        """Drop entries whose smallest spatial dimension is below out_sz / 2."""
        return {k: v for k, v in self.ALLdata.items()
                if min(v['Size'][:self.ndims]) >= self.out_sz / 2}

    def normalize(self, volume, eps=1e-7):
        """Min-max normalize the volume into [0, 1] as float64."""
        volume = volume.astype(np.float64)
        volume = (volume - np.min(volume)) / (np.ptp(volume) + eps)
        return volume

    def random_crop_3d(self, volume, crop_size=None):
        """Random cubic crop; zero-pads first when the volume is smaller.

        BUG FIX: each axis' padding amount is now drawn once and split
        between the two sides.  The previous code drew two independent
        values per axis, so the padded volume could end up smaller than the
        crop size and the returned patch could have the wrong shape.
        """
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size
        pad_d = max(0, crop_d - d)
        pad_h = max(0, crop_h - h)
        pad_w = max(0, crop_w - w)
        if pad_d or pad_h or pad_w:
            front = np.random.randint(0, pad_d + 1)
            top = np.random.randint(0, pad_h + 1)
            left = np.random.randint(0, pad_w + 1)
            pad_width = ((front, pad_d - front),
                         (top, pad_h - top),
                         (left, pad_w - left))
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
            d, h, w = volume.shape
        # Random crop start per axis (0 when the axis exactly fits).
        start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0
        start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0
        start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0
        return volume[start_d:start_d + crop_d,
                      start_h:start_h + crop_h,
                      start_w:start_w + crop_w]
class OminiDataset_v1(Dataset):
    """Single-volume dataset: each item is one randomly cropped, resized,
    min-max normalized 3D volume of shape (1, out_sz, out_sz, out_sz)."""

    def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.2, reverse_axis_order = False):
        self.mappings = mapping_files
        self.ALLdata = self.combine_data()
        self.out_sz = out_sz
        self.reverse_axis_order = reverse_axis_order
        self.min_crop_ratio = min_crop_ratio
        self.crop_ratio_sample_order = 2  # biases sampled crop ratios toward 1
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3
        # Filtering: keep only volumes large enough to crop from.
        self.ALLdata_filtered = self.get_filter_mindim()

    def find_min_dim(self):
        """Return the smallest spatial dimension over all volumes (capped at 100000)."""
        min_dim = 100000
        for k in self.ALLdata.keys():
            value = self.ALLdata[k]
            if min(value['Size']) < min_dim:
                min_dim = min(value['Size'])
        return min_dim

    def random_crop_3d(self, volume, crop_size=None):
        """Random cubic crop with zero padding when the volume is smaller.

        BUG FIX: padding per axis is drawn once and split between the two
        sides; two independent draws could previously make the padded volume
        smaller than the crop size.
        """
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size
        pad_d = max(0, crop_d - d)
        pad_h = max(0, crop_h - h)
        pad_w = max(0, crop_w - w)
        if pad_d or pad_h or pad_w:
            front = np.random.randint(0, pad_d + 1)
            top = np.random.randint(0, pad_h + 1)
            left = np.random.randint(0, pad_w + 1)
            pad_width = ((front, pad_d - front),
                         (top, pad_h - top),
                         (left, pad_w - left))
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
            d, h, w = volume.shape
        start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0
        start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0
        start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0
        return volume[start_d:start_d + crop_d,
                      start_h:start_h + crop_h,
                      start_w:start_w + crop_w]

    def get_ALLdata(self):
        """Return the unfiltered {file_path: metadata} mapping."""
        return self.ALLdata

    def get_3D_volume(self, volume, select_channel=None):
        """Reduce a possibly-4D volume to 3D by selecting one channel.

        BUG FIX: randint upper bound is shape[3] (exclusive) so the last
        channel is selectable and 1-channel volumes no longer raise.
        """
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_ROI(self, key_word):
        """Keep only entries whose ROI label contains ``key_word``.

        BUG FIX: previously indexed the path string (k["ROI"]) instead of
        the metadata dict, which raised TypeError on first use.
        """
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            if key_word not in self.ALLdata[k]["ROI"]:
                del ALLdata[k]
        return ALLdata

    def get_filter_mindim(self):
        """Drop entries whose smallest spatial dimension is below out_sz / 2."""
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz/2:
                del ALLdata[k]
        return ALLdata

    def combine_data(self):
        """Merge all mapping JSONs, skipping missing/empty files with warnings."""
        ALLdata = {}
        total_entries = 0
        total_skipped = 0
        for j in self.mappings.keys():
            with open(self.mappings[j], 'r') as f:
                mappings = json.load(f)
            skipped = 0
            for k, v in mappings.items():
                if not os.path.exists(k) or os.path.getsize(k) == 0:
                    skipped += 1
                    continue
                ALLdata[k] = v
            accessible = len(mappings) - skipped
            total_entries += len(mappings)
            total_skipped += skipped
            if skipped > 0:
                print(f" WARNING: {j}: {accessible}/{len(mappings)} accessible ({skipped} missing/empty)")
        if total_skipped > 0:
            print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
        if len(ALLdata) < 1000:
            print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
                  f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
        return ALLdata

    def __len__(self):
        return len(self.ALLdata_filtered)

    def normalize(self, volume, eps=1e-7):
        """Min-max normalize the volume into [0, 1] as float64."""
        volume = volume.astype(np.float64)
        volume = (volume - np.min(volume)) / (np.ptp(volume) + eps)
        return volume

    def __getitem__(self, idx):
        """Load one volume: clamp (CT only), normalize, random-crop, resize."""
        key = list(self.ALLdata_filtered.keys())[idx]
        volume = sitk.ReadImage(key)
        volume = sitk.GetArrayFromImage(volume)
        volume = self.get_3D_volume(volume)
        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)
        if self.min_crop_ratio is not None:
            # Higher-order sampling biases the crop ratio toward 1 (large crops).
            crop_ratio = sample_random_uniform_multi_order(high=1., low=self.min_crop_ratio, order_num=self.crop_ratio_sample_order, type='high')
            crop_size = int(max(volume.shape) * crop_ratio)
            volume = self.random_crop_3d(volume, crop_size)
            volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        else:
            volume = self.random_crop_3d(volume, self.out_sz)
        volume = volume[None, :, :, :]  # add channel axis
        if self.transform is not None:
            return self.transform(volume)
        return volume
class OMDataset_indiv(Dataset):
    """Diffusion-mode dataset: each item is one cropped/resized volume plus
    its conditioning embedding vector from the mapping metadata."""

    def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.3, reverse_axis_order = False):
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        self.max_sz = out_sz*8  # upper bound on the sampled crop size
        self.reverse_axis_order = reverse_axis_order
        self.min_crop_ratio = min_crop_ratio
        self.crop_ratio_sample_order = 2  # biases sampled crop ratios toward 1
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3
        # Filtering: keep only volumes large enough to crop from.
        print(f"Diffusion mode: Total data size before filtering: {len(self.ALLdata)}")
        self.ALLdata_filtered = self.get_filter_mindim()
        print(f"Diffusion mode: Filtered data size: {len(self.ALLdata_filtered)}")

    def find_min_dim(self):
        """Return the smallest spatial dimension over all volumes (capped at 100000)."""
        min_dim = 100000
        for k in self.ALLdata.keys():
            value = self.ALLdata[k]
            if min(value['Size']) < min_dim:
                min_dim = min(value['Size'])
        return min_dim

    def random_crop_3d(self, volume, crop_size=None):
        """Random cubic crop with zero padding when the volume is smaller.

        BUG FIX: padding per axis is drawn once and split between the two
        sides; two independent draws could previously make the padded volume
        smaller than the crop size.
        """
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size
        pad_d = max(0, crop_d - d)
        pad_h = max(0, crop_h - h)
        pad_w = max(0, crop_w - w)
        if pad_d or pad_h or pad_w:
            front = np.random.randint(0, pad_d + 1)
            top = np.random.randint(0, pad_h + 1)
            left = np.random.randint(0, pad_w + 1)
            pad_width = ((front, pad_d - front),
                         (top, pad_h - top),
                         (left, pad_w - left))
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
            d, h, w = volume.shape
        start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0
        start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0
        start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0
        return volume[start_d:start_d + crop_d,
                      start_h:start_h + crop_h,
                      start_w:start_w + crop_w]

    def get_ALLdata(self):
        """Return the unfiltered {file_path: metadata} mapping."""
        return self.ALLdata

    def get_3D_volume(self, volume, select_channel=None):
        """Reduce a possibly-4D volume to 3D by selecting one channel.

        BUG FIX: randint upper bound is shape[3] (exclusive) so the last
        channel is selectable and 1-channel volumes no longer raise.
        """
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_ROI(self, key_word):
        """Keep only entries whose ROI label contains ``key_word``.

        BUG FIX: previously indexed the path string (k["ROI"]) instead of
        the metadata dict, which raised TypeError on first use.
        """
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            if key_word not in self.ALLdata[k]["ROI"]:
                del ALLdata[k]
        return ALLdata

    def get_filter_mindim(self):
        """Drop entries whose smallest spatial dimension is below out_sz / 2."""
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz/2:
                del ALLdata[k]
        return ALLdata

    def combine_data(self, mappings = mapping_files):
        """Merge all mapping JSONs, skipping missing/empty files with warnings."""
        ALLdata = {}
        total_entries = 0
        total_skipped = 0
        for j in mappings.keys():
            with open(mappings[j], 'r') as f:
                mappings_tmp = json.load(f)
            skipped = 0
            for k, v in mappings_tmp.items():
                if not os.path.exists(k) or os.path.getsize(k) == 0:
                    skipped += 1
                    continue
                ALLdata[k] = v
            accessible = len(mappings_tmp) - skipped
            total_entries += len(mappings_tmp)
            total_skipped += skipped
            if skipped > 0:
                print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
        if total_skipped > 0:
            print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
        if len(ALLdata) < 1000:
            print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
                  f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
        return ALLdata

    def __len__(self):
        return len(self.ALLdata_filtered)

    def normalize(self, volume, eps=1e-7):
        """Min-max normalize the volume into [0, 1] as float64."""
        volume = volume.astype(np.float64)
        volume = (volume - np.min(volume)) / (np.ptp(volume) + eps)
        return volume

    def __getitem__(self, idx):
        """Load one volume and its embedding; clamp (CT), normalize, crop, resize."""
        key = list(self.ALLdata_filtered.keys())[idx]
        embd = self.ALLdata_filtered[key]['embd']
        embd = np.array(embd, dtype=np.float32)
        volume = sitk.ReadImage(key)
        volume = sitk.GetArrayFromImage(volume)
        volume = self.get_3D_volume(volume)
        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)
        if self.min_crop_ratio is not None:
            # Higher-order sampling biases the crop ratio toward 1 (large crops).
            crop_ratio = sample_random_uniform_multi_order(high=1., low=self.min_crop_ratio, order_num=self.crop_ratio_sample_order, type='high')
            crop_size = int(max(volume.shape) * crop_ratio)
            crop_size = min(crop_size, self.max_sz)  # cap very large crops
            volume = self.random_crop_3d(volume, crop_size)
            volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        else:
            volume = self.random_crop_3d(volume, self.out_sz)
        volume = volume[None, :, :, :]  # add channel axis
        if self.transform is not None:
            # NOTE(review): when a transform is set the embedding is not
            # returned — confirm this asymmetry is intended.
            return self.transform(volume)
        return [volume, embd]
class OminiDataset_paired(Dataset):
    """Paired dataset: each item is two volumes sharing the same ROI label
    (the second chosen at random), both cropped/resized identically in size."""

    def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.85, ROIs = None, modality = None, reverse_axis_order = False):
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        self.sz_range = get_sizeRange_dict()  # per-ROI physical size bounds (mm)
        self.min_dim_ratio = 0.5  # minimum allowed min/max spatial aspect ratio
        self.reverse_axis_order = reverse_axis_order
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3
        # Filtering pipeline (order matters: size -> modality -> ROI).
        self.ALLdata_filtered = self.get_filter_mindim()
        self.ALLdata_filtered = self.get_filter_modality(modality)
        if ROIs is None:  # if no ROIs are provided, get all the ROIs from filtered data
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()

    def combine_data(self, mappings = mapping_files):
        """Merge all mapping JSONs, skipping missing/empty files with warnings."""
        ALLdata = {}
        total_entries = 0
        total_skipped = 0
        for j in mappings.keys():
            with open(mappings[j], 'r') as f:
                mappings_tmp = json.load(f)
            skipped = 0
            for k, v in mappings_tmp.items():
                if not os.path.exists(k) or os.path.getsize(k) == 0:
                    skipped += 1
                    continue
                ALLdata[k] = v
            accessible = len(mappings_tmp) - skipped
            total_entries += len(mappings_tmp)
            total_skipped += skipped
            if skipped > 0:
                print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
        if total_skipped > 0:
            print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
        if len(ALLdata) < 1000:
            print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
                  f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
        return ALLdata

    def normalize(self, volume, eps=1e-7):
        """Min-max normalize the volume into [0, 1] as float64."""
        volume = volume.astype(np.float64)
        volume = (volume - np.min(volume)) / (np.ptp(volume) + eps)
        return volume

    def random_crop_3d(self, volume, crop_size=None):
        """Random cubic crop with zero padding when the volume is smaller.

        BUG FIX: padding per axis is drawn once and split between the two
        sides; two independent draws could previously make the padded volume
        smaller than the crop size.
        """
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size
        pad_d = max(0, crop_d - d)
        pad_h = max(0, crop_h - h)
        pad_w = max(0, crop_w - w)
        if pad_d or pad_h or pad_w:
            front = np.random.randint(0, pad_d + 1)
            top = np.random.randint(0, pad_h + 1)
            left = np.random.randint(0, pad_w + 1)
            pad_width = ((front, pad_d - front),
                         (top, pad_h - top),
                         (left, pad_w - left))
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
            d, h, w = volume.shape
        start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0
        start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0
        start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0
        return volume[start_d:start_d + crop_d,
                      start_h:start_h + crop_h,
                      start_w:start_w + crop_w]

    def get_all_ROI(self):
        """Return the set of distinct ROI labels present in the filtered data."""
        ROIs = set()
        for k in self.ALLdata_filtered.keys():
            ROIs.add(self.ALLdata[k]['ROI'])
        return ROIs

    def find_min_dim(self):
        """Return the smallest spatial dimension over all volumes (capped at 100000)."""
        min_dim = 100000
        for k in self.ALLdata.keys():
            value = self.ALLdata[k]
            if min(value['Size']) < min_dim:
                min_dim = min(value['Size'])
        return min_dim

    def get_ALLdata(self):
        """Return the unfiltered {file_path: metadata} mapping."""
        return self.ALLdata

    def get_filter_modality(self, key_words=None):
        """Restrict the filtered mapping to the given modalities (None keeps all)."""
        ALLdata_filtered = self.ALLdata_filtered.copy()
        if key_words is not None:
            for k in self.ALLdata_filtered.keys():
                if ALLdata_filtered[k]["Modality"] not in key_words:
                    del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_filter_ROI(self, key_word):
        """Keep only entries whose ROI label contains ``key_word``.

        BUG FIX: previously indexed the path string (k["ROI"]) instead of
        the metadata dict, which raised TypeError on first use.
        """
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            if key_word not in self.ALLdata_filtered[k]["ROI"]:
                del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_key_by_ROI(self, key_word):
        """Return all file keys whose ROI equals ``key_word``."""
        keys = []
        for k in self.ALLdata_filtered.keys():
            if key_word == self.ALLdata_filtered[k]["ROI"]:
                keys.append(k)
        return keys

    def get_filter_ROIs(self):
        """Restrict the filtered mapping to entries whose ROI is in self.ROIs."""
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            if self.ALLdata_filtered[k]['ROI'] not in self.ROIs:
                del ALLdata_filtered[k]
        return ALLdata_filtered

    def get_3D_volume(self, volume, select_channel=None):
        """Reduce a possibly-4D volume to 3D by selecting one channel.

        BUG FIX: randint upper bound is shape[3] (exclusive) so the last
        channel is selectable and 1-channel volumes no longer raise.
        """
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_mindim(self):
        """Drop entries that are too small in voxels, too small physically for
        their ROI (size * spacing below the ROI's lower bound), or too
        anisotropic (min/max dimension ratio below min_dim_ratio)."""
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            img_sz = self.ALLdata[k]['Size'][:self.ndims]
            del_flag = False
            del_flag = del_flag or min(img_sz) < self.out_sz
            del_flag = del_flag or (min(img_sz)*self.ALLdata[k]['Spacing_mm']) < self.sz_range[self.ALLdata[k]['ROI']][0]
            del_flag = del_flag or (min(img_sz)/max(img_sz) < self.min_dim_ratio)
            if del_flag:
                del ALLdata[k]
        return ALLdata

    def __getitem__(self, idx):
        """Return two same-ROI volumes (A indexed, B random), both processed."""
        key = list(self.ALLdata_filtered.keys())[idx]
        volume_A = sitk.ReadImage(key)
        volume_A = sitk.GetArrayFromImage(volume_A)
        # Pick a random partner with the same ROI (may be the same file).
        paired_keys = self.get_key_by_ROI(self.ALLdata_filtered[key]['ROI'])
        paired_key = random.choice(paired_keys)
        volume_B = sitk.ReadImage(paired_key)
        volume_B = sitk.GetArrayFromImage(volume_B)
        volume_A = self.get_3D_volume(volume_A)
        volume_B = self.get_3D_volume(volume_B)
        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume_A = np.clip(volume_A, self.clamp_range[0], self.clamp_range[1])
                volume_B = np.clip(volume_B, self.clamp_range[0], self.clamp_range[1])
        volume_A = self.normalize(volume_A)
        volume_B = self.normalize(volume_B)
        if self.min_crop_ratio is not None:
            # One shared ratio so both crops cover a comparable fraction.
            crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
            crop_size_A = int(min(volume_A.shape) * crop_ratio)
            crop_size_B = int(min(volume_B.shape) * crop_ratio)
            volume_A = self.random_crop_3d(volume_A, crop_size_A)
            volume_B = self.random_crop_3d(volume_B, crop_size_B)
            volume_A = resize(volume_A, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
            volume_B = resize(volume_B, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        else:
            volume_A = self.random_crop_3d(volume_A, self.out_sz)
            volume_B = self.random_crop_3d(volume_B, self.out_sz)
        volume_A = volume_A[None, :, :, :]
        volume_B = volume_B[None, :, :, :]
        if self.transform is not None:
            return self.transform(volume_A), self.transform(volume_B)
        return volume_A, volume_B

    def __len__(self):
        return len(self.ALLdata_filtered)
class OMDataset_pair(Dataset):
    def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.75, ROIs = indivi_ROI_list, modality = None, reverse_axis_order = False):
        """Registration-mode paired dataset.

        Args:
            out_sz: cubic output size of each returned volume.
            transform: optional callable applied to each volume.
            clamp_range: [low, high] clamp applied to CT volumes; None disables.
            min_crop_ratio: lower bound of the random crop ratio; None disables cropping+resize.
            ROIs: ROI labels to keep (defaults to the single-region list); None keeps all ROIs found.
            modality: modality whitelist passed to get_filter_modality; None keeps all.
            reverse_axis_order: if True, reverse volume axes after loading.
        """
        # self.mappings = mapping_files
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        # Upper bound on the sampled crop size.
        self.max_sz = out_sz*8
        # Per-ROI physical size bounds (from dataloader_utils).
        self.sz_range = get_sizeRange_dict()
        # Minimum allowed min/max spatial aspect ratio for a volume.
        self.min_dim_ratio = 0.7
        self.reverse_axis_order = reverse_axis_order
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3
        # Filtering pipeline — order matters: size -> modality -> ROI.
        # print(f"Number of images before filtering: {len(self.ALLdata.keys())}")
        print(f"Registration mode: Total data size before filtering: {len(self.ALLdata)}")
        self.ALLdata_filtered = self.get_filter_mindim()
        # print(f"Number of images after filtering: {len(self.ALLdata_filtered.keys())}")
        self.ALLdata_filtered = self.get_filter_modality(modality)
        # print(f"Number of images after modality filtering: {len(self.ALLdata_filtered.keys())}")
        if ROIs is None:# if no ROIs are provided, get all the ROIs from filtered data
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()
        print(f"Registration mode: Number of images after filtering: {len(self.ALLdata_filtered.keys())}")
        # filtering ends here
def combine_data(self, mappings = mapping_files):
ALLdata = {}
total_entries = 0
total_skipped = 0
for j in mappings.keys():
with open(mappings[j], 'r') as f:
mappings_tmp = json.load(f)
skipped = 0
for k, v in mappings_tmp.items():
if not os.path.exists(k) or os.path.getsize(k) == 0:
skipped += 1
continue
ALLdata[k] = v
accessible = len(mappings_tmp) - skipped
total_entries += len(mappings_tmp)
total_skipped += skipped
if skipped > 0:
print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
if total_skipped > 0:
print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
if len(ALLdata) < 1000:
print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
return ALLdata
def normalize(self, volume, eps=1e-7):
# Normalize the image (0-1)
volume = volume.astype(np.float64)
volume = (volume - np.min(volume)) / (np.ptp(volume) + eps)
return volume
def random_crop_3d(self, volume, crop_size=None):
# Fast random crop with optional padding using NumPy
d, h, w = volume.shape
if crop_size is None:
crop_size = self.out_sz
crop_d, crop_h, crop_w = crop_size, crop_size, crop_size
# Only pad if needed (avoid np.pad if not necessary)
pad_d = max(0, crop_d - d)
pad_h = max(0, crop_h - h)
pad_w = max(0, crop_w - w)
if pad_d or pad_h or pad_w:
pad_width = (
(np.random.randint(0, pad_d + 1), pad_d - np.random.randint(0, pad_d + 1)),
(np.random.randint(0, pad_h + 1), pad_h - np.random.randint(0, pad_h + 1)),
(np.random.randint(0, pad_w + 1), pad_w - np.random.randint(0, pad_w + 1)),
)
volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
d, h, w = volume.shape
# Crop indices
start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0
start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0
start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0
# Use NumPy slicing (very fast)
return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]
# def random_crop_3d(self, volume, crop_size=None):
# # Randomly crop the image
# d, h, w = volume.shape
# if crop_size is None:
# crop_size = self.out_sz
# crop_d, crop_h, crop_w = crop_size, crop_size, crop_size
# if crop_d > d or crop_h > h or crop_w > w:
# raise ValueError("Crop size must be smaller than the original array size")
# start_d = np.random.randint(0, d - crop_d + 1)
# start_h = np.random.randint(0, h - crop_h + 1)
# start_w = np.random.randint(0, w - crop_w + 1)
# cropped_array = volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]
# return cropped_array
def get_all_ROI(self):
# Get all the ROI options. and remove the reduntant ones
ROIs = []
for k in self.ALLdata_filtered.keys():
ROIs.append(self.ALLdata[k]['ROI'])
ROIs = set(ROIs)
return ROIs
def find_min_dim(self):
# Find the minimum dimension of the images
min_dim = 100000
for k in self.ALLdata.keys():
value = self.ALLdata[k]
if min(value['Size']) < min_dim:
min_dim = min(value['Size'])
return min_dim
def get_ALLdata(self):
# Return all data
return self.ALLdata
def get_filter_modality(self, key_words=None):
ALLdata_filtered = self.ALLdata_filtered.copy()
if key_words is not None:
for k in self.ALLdata_filtered.keys():
if ALLdata_filtered[k]["Modality"] not in key_words:
del ALLdata_filtered[k]
return ALLdata_filtered
def get_filter_ROI(self, key_word):
# Filter out images with a key word
ALLdata_filtered = self.ALLdata_filtered.copy()
for k in self.ALLdata_filtered.keys():
if key_word not in k["ROI"]:
del ALLdata_filtered[k]
return ALLdata_filtered
def get_key_by_ROI(self, key_word):
# Get all the keys with a key word
keys = []
for k in self.ALLdata_filtered.keys():
if key_word == self.ALLdata_filtered[k]["ROI"]:
keys.append(k)
return keys
def filter_keys_by_xx(self, key_word, keys=None, term="ROI"):
# Filter out images with a key word
filtered_keys = []
if keys is None:
keys = self.ALLdata_filtered.keys()
for k in keys:
value = self.ALLdata_filtered[k].get(term, None)
if value is not None and key_word == value:
filtered_keys.append(k)
return filtered_keys
def get_filter_ROIs(self):
ALLdata_filtered = self.ALLdata_filtered.copy()
for k in self.ALLdata_filtered.keys():
if self.ALLdata_filtered[k]['ROI'] not in self.ROIs:
del ALLdata_filtered[k]
return ALLdata_filtered
def get_3D_volume(self, volume, select_channel = None):
if self.reverse_axis_order:
volume = reverse_axis_order(volume)
if volume.ndim == 4:
if select_channel is None:
select_channel = np.random.randint(0, volume.shape[3] - 1)
volume = volume[:, :, :, select_channel]
return volume
def get_filter_mindim(self):
# Filter out images with dimensions less than min_dim
# Top priority is to filter out images with dimensions less than min_dim
ALLdata = self.ALLdata.copy()
for k in self.ALLdata.keys():
img_sz = self.ALLdata[k]['Size'][:self.ndims]
del_flag = False
del_flag = del_flag or min(img_sz) < self.out_sz
# print(f"Size: {self.ALLdata[k]['Size']}, Spacing_mm: {self.ALLdata[k]['Spacing_mm']}, ROI: {self.ALLdata[k]['ROI']}")
# print(f"sz_range: {self.sz_range[self.ALLdata[k]['ROI']]}, min_dim_ratio: {self.min_dim_ratio}")
del_flag = del_flag or (min(img_sz)*self.ALLdata[k]['Spacing_mm']) < self.sz_range[self.ALLdata[k]['ROI']][0]
del_flag = del_flag or (min(img_sz)/max(img_sz) < self.min_dim_ratio)
# del_flag = min(self.ALLdata[k]['Size']) < self.out_sz or (min(self.ALLdata[k]['Size'])*self.ALLdata[k]['Spacing_mm']) < self.sz_range[self.ALLdata[k]['ROI']] or (min(self.ALLdata[k]['Size'])/max(self.ALLdata[k]['Size']) < self.min_dim_ratio)
if del_flag:
del ALLdata[k]
return ALLdata
    def __getitem__(self,idx):
        """Return one training pair: two scans sharing ROI and modality.

        The scan at *idx* is volume A; volume B is sampled uniformly at random
        from all scans with the same ROI and modality (A itself may be
        re-drawn).  Both volumes are clipped (CT only), min-max normalized,
        randomly cropped and resized to (out_sz,)*ndims, then given a leading
        channel axis.  Returns ``[volume_A, volume_B, embd_A, embd_B]``, or
        ``(transform(volume_A), transform(volume_B))`` when a transform is set
        (note: the transform branch does not return the embeddings).
        """
        key = list(self.ALLdata_filtered.keys())[idx]
        volume_A = sitk.ReadImage(key)
        volume_A = sitk.GetArrayFromImage(volume_A)
        # Per-scan text embedding stored in the mapping JSON.
        embd_A = self.ALLdata_filtered[key]['embd']
        embd_A = np.array(embd_A, dtype=np.float32)
        # Candidate partners: same ROI first, then same modality.
        all_keys = list(self.ALLdata_filtered.keys())
        paired_keys = self.filter_keys_by_xx(self.ALLdata_filtered[key]['ROI'], all_keys, term="ROI")
        paired_keys = self.filter_keys_by_xx(self.ALLdata_filtered[key]['Modality'], paired_keys, term="Modality")
        # paired_keys = self.get_key_by_ROI(self.ALLdata_filtered[key]['ROI'])
        paired_key = random.choice(paired_keys)
        # print(f"Key: {key}, Paired Key: {paired_key}")
        # print(f"ROI: {self.ALLdata_filtered[key]['ROI']}, {self.ALLdata_filtered[paired_key]['ROI']}; Modality: {self.ALLdata_filtered[key]['Modality']}, {self.ALLdata_filtered[paired_key]['Modality']}")
        volume_B = sitk.ReadImage(paired_key)
        volume_B = sitk.GetArrayFromImage(volume_B)
        embd_B = self.ALLdata_filtered[paired_key]['embd']
        embd_B = np.array(embd_B, dtype=np.float32)
        # if volume_A.ndim == 4 or volume_B.ndim == 4:
        # Reduce potential 4-D scans to 3-D (random channel) and apply the
        # configured axis-order reversal.
        volume_A = self.get_3D_volume(volume_A)
        volume_B = self.get_3D_volume(volume_B)
        if self.clamp_range is not None:
            # NOTE(review): the modality test uses A's metadata only, yet both
            # volumes are clipped; B has the same modality by construction of
            # paired_keys, so this is consistent -- but confirm if the pairing
            # filter ever changes.
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume_A = np.clip(volume_A, self.clamp_range[0], self.clamp_range[1])
                volume_B = np.clip(volume_B, self.clamp_range[0], self.clamp_range[1])
        volume_A = self.normalize(volume_A)
        volume_B = self.normalize(volume_B)
        if self.min_crop_ratio is not None:
            # print(f'before volume_shape: {volume.shape}')
            # One crop ratio shared by both volumes; sizes derive from each
            # volume's largest dimension, capped at self.max_sz.
            # NOTE(review): self.max_sz is set outside this view -- presumably
            # an upper bound on crop size; confirm in __init__.
            crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
            # crop_size_A = int(min(volume_A.shape) * crop_ratio)
            # crop_size_B = int(min(volume_B.shape) * crop_ratio)
            crop_size_A = int(max(volume_A.shape) * crop_ratio)
            crop_size_B = int(max(volume_B.shape) * crop_ratio)
            crop_size_A = min(crop_size_A, self.max_sz)
            crop_size_B = min(crop_size_B, self.max_sz)
            volume_A = self.random_crop_3d(volume_A, crop_size_A)
            volume_B = self.random_crop_3d(volume_B, crop_size_B)
            volume_A = resize(volume_A, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
            volume_B = resize(volume_B, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        else:
            volume_A = self.random_crop_3d(volume_A, self.out_sz)
            volume_B = self.random_crop_3d(volume_B, self.out_sz)
        # Add the leading channel axis expected by downstream models.
        volume_A = volume_A[None, :, :, :]
        volume_B = volume_B[None, :, :, :]
        if self.transform is not None:
            return self.transform(volume_A), self.transform(volume_B)
        # print(self.ALLdata_filtered[key]['ROI'],self.ALLdata_filtered[key]['Modality'],self.ALLdata_filtered[key]['Dataset_name'],'---',self.ALLdata_filtered[paired_key]['ROI'], self.ALLdata_filtered[paired_key]['Modality'], self.ALLdata_filtered[paired_key]['Dataset_name'])
        return [volume_A, volume_B, embd_A, embd_B]
def __len__(self):
return len(self.ALLdata_filtered.keys())
class OminiDataset_paired_inf(object):
    """Inference dataset yielding random pairs of scans that share an ROI.

    Scans are indexed from the mapping JSON files, filtered by minimum
    dimension and ROI, and sampled uniformly over ROI labels via
    ``get_random_2_items`` / ``build_batch``.
    """
    def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.3, ROIs = None):
        """Build the scan index and apply dimension/ROI filtering.

        out_sz: side length of the cubic output volumes.
        transform: optional callable applied to each volume before returning.
        clamp_range: (lo, hi) intensity clip applied to CT scans only.
        min_crop_ratio: lower bound of the random crop ratio; None switches to
            a plain random crop of size out_sz (no resize).
        ROIs: iterable of ROI labels to keep, or None for all.
        """
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3
        # --- filtering starts here ---
        # filter out images with dimensions less than min_dim
        self.ALLdata_filtered = self.get_filter_mindim()
        # filter out images with ROIs that are not in the provided ROIs
        if ROIs is None:
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()
        # --- filtering ends here ---
        self.roi_scan_mapping = self.build_ROI_scan_mapping()
        self.keys_dist, self.total = self.get_keys_dist()
    def get_all_ROI(self):
        """Return the set of distinct ROI labels among the filtered scans."""
        return {self.ALLdata[k]['ROI'] for k in self.ALLdata_filtered}
    def get_ALLdata(self):
        """Accessor for the full (unfiltered) path -> metadata mapping."""
        return self.ALLdata
    def combine_data(self, mappings = mapping_files):
        """Merge all mapping JSON files into one dict, skipping missing/empty files.

        Prints warnings when listed files are not accessible, since that
        usually means a mount/path problem on the current node.
        """
        ALLdata = {}
        total_entries = 0
        total_skipped = 0
        for j in mappings.keys():
            with open(mappings[j], 'r') as f:
                mappings_tmp = json.load(f)
            skipped = 0
            for k, v in mappings_tmp.items():
                # Drop entries whose image file is absent or zero bytes.
                if not os.path.exists(k) or os.path.getsize(k) == 0:
                    skipped += 1
                    continue
                ALLdata[k] = v
            accessible = len(mappings_tmp) - skipped
            total_entries += len(mappings_tmp)
            total_skipped += skipped
            if skipped > 0:
                print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
        if total_skipped > 0:
            print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
        if len(ALLdata) < 1000:
            print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
                  f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
        return ALLdata
    def __len__(self):
        """Number of scans after filtering."""
        return len(self.ALLdata_filtered)
    def random_crop_3d(self, volume, crop_size=None):
        """Randomly crop a cubic patch of side *crop_size* (default out_sz).

        Raises ValueError when the volume is smaller than the crop -- callers
        are expected to have filtered such scans out already.
        """
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size
        if crop_d > d or crop_h > h or crop_w > w:
            raise ValueError("Crop size must be smaller than the original array size")
        start_d = np.random.randint(0, d - crop_d + 1)
        start_h = np.random.randint(0, h - crop_h + 1)
        start_w = np.random.randint(0, w - crop_w + 1)
        return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]
    def normalize(self, volume, eps=1e-7):
        """Min-max normalize *volume* to [0, 1) as float64; eps avoids /0."""
        volume = volume.astype(np.float64)
        return (volume - np.min(volume)) / (np.ptp(volume) + eps)
    def get_3D_volume(self, volume, select_channel = None):
        """Reverse axis order (nifti -> numpy) and reduce 4-D scans to 3-D.

        select_channel: channel index for 4-D scans; random when None.
        """
        volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                # BUGFIX: randint's upper bound is exclusive; the previous
                # ``shape[3] - 1`` never selected the last channel and raised
                # ValueError for single-channel scans.
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume
    def get_filter_mindim(self):
        """Drop scans whose smallest spatial dimension is below out_sz / 2."""
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz/2:
                del ALLdata[k]
        return ALLdata
    def get_filter_ROI(self, key_word):
        """Keep only entries whose ROI label contains *key_word* (substring).

        BUGFIX: the original indexed the path string (``k["ROI"]``), which
        raises TypeError; the ROI label lives in the metadata dict.
        """
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            if key_word not in self.ALLdata_filtered[k]["ROI"]:
                del ALLdata_filtered[k]
        return ALLdata_filtered
    def get_filter_ROIs(self):
        """Drop entries whose ROI label is not in ``self.ROIs``; return a copy."""
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            if self.ALLdata_filtered[k]['ROI'] not in self.ROIs:
                del ALLdata_filtered[k]
        return ALLdata_filtered
    def get_keys_dist(self):
        """Return (per-ROI scan counts, total scan count).

        BUGFIX: ``total`` was never incremented and always came back 0.
        """
        keys_dist = {}
        total = 0
        for item in self.ALLdata_filtered.keys():
            roi = self.ALLdata_filtered[item]['ROI']
            if roi not in keys_dist:
                keys_dist[roi] = 0
            keys_dist[roi] += 1
            total += 1
        return keys_dist, total
    def build_ROI_scan_mapping(self):
        """Map each ROI label to the list of scan paths that carry it."""
        ROI_scan_mapping = {}
        for item in self.ALLdata_filtered.keys():
            roi = self.ALLdata_filtered[item]['ROI']
            if roi not in ROI_scan_mapping:
                ROI_scan_mapping[roi] = []
            ROI_scan_mapping[roi].append(item)
        return ROI_scan_mapping
    def get_random_2_items(self, mode = 'uniform'):
        """Sample an ROI label uniformly, then two random scans of that ROI.

        Returns two preprocessed volumes of shape (1, out_sz, out_sz, out_sz)
        (or the transformed pair when a transform is configured).
        mode='original' is a placeholder and returns None.
        """
        if mode == 'uniform':
            idx = random.randint(0, len(self.keys_dist.keys()) - 1)
            key = list(self.keys_dist.keys())[idx]
            path_1 = random.choice(self.roi_scan_mapping[key])
            path_2 = random.choice(self.roi_scan_mapping[key])
            volume_A = sitk.ReadImage(path_1)
            volume_A = sitk.GetArrayFromImage(volume_A)
            volume_B = sitk.ReadImage(path_2)
            volume_B = sitk.GetArrayFromImage(volume_B)
            if self.clamp_range is not None:
                # BUGFIX: ``key`` is an ROI label, not a file path -- the
                # original looked up self.ALLdata_filtered[key] and raised
                # KeyError.  Clip each scan based on its own modality.
                if self.ALLdata_filtered[path_1].get("Modality", None) == "CT":
                    volume_A = np.clip(volume_A, self.clamp_range[0], self.clamp_range[1])
                if self.ALLdata_filtered[path_2].get("Modality", None) == "CT":
                    volume_B = np.clip(volume_B, self.clamp_range[0], self.clamp_range[1])
            volume_A = self.normalize(volume_A)
            volume_B = self.normalize(volume_B)
            if self.min_crop_ratio is not None:
                crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
                crop_size_A = int(min(volume_A.shape) * crop_ratio)
                crop_size_B = int(min(volume_B.shape) * crop_ratio)
                volume_A = self.random_crop_3d(volume_A, crop_size_A)
                volume_B = self.random_crop_3d(volume_B, crop_size_B)
                volume_A = resize(volume_A, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
                volume_B = resize(volume_B, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
            else:
                # BUGFIX: method name was misspelled ("radndom_crop_3d"),
                # raising AttributeError whenever min_crop_ratio is None.
                volume_A = self.random_crop_3d(volume_A, self.out_sz)
                volume_B = self.random_crop_3d(volume_B, self.out_sz)
            volume_A = volume_A[None, :, :, :]
            volume_B = volume_B[None, :, :, :]
            if self.transform is not None:
                return self.transform(volume_A), self.transform(volume_B)
            return volume_A, volume_B
        elif mode == 'original':
            pass
    def build_batch(self, batch_size = 2):
        """Stack *batch_size* random pairs into two (B, 1, D, H, W) arrays."""
        batch_1 = []
        batch_2 = []
        for _ in range(batch_size):
            V_a, V_b = self.get_random_2_items()
            batch_1.append(V_a)
            batch_2.append(V_b)
        return np.array(batch_1), np.array(batch_2)
class OminiDataset_inference_w_all(object):
    """Inference dataset returning a preprocessed scan plus its segmentation labels.

    Each item is a dict: 'img' -> (1, out_sz, out_sz, out_sz) volume,
    'labels' -> concatenated label maps (-1 filled when missing),
    'label_channels' -> the configured channel names.
    """
    def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.75, ROIs = None, label_key = ['brain'], task_key = 'segmentation', database = None, select_channels_dict = {}):
        """Index, filter and prepare the label-aware inference dataset.

        label_key: label names (e.g. ['brain']) to load for each scan.
        task_key: branch inside 'Label_path' metadata (e.g. 'segmentation').
        database: optional subset of mapping-file names to use,
            e.g. ['MSD', 'TotalSegmentor'].
        select_channels_dict: e.g. {"ImgDict": ["ed", "es"]} -- named channels
            to select from 4-D scans.
        NOTE(review): the mutable defaults (list/dict) are only read, never
        mutated, so they are safe; signature kept for caller compatibility.
        """
        self.mappings = mapping_files
        if database is not None:
            self.mappings = {db: self.mappings[db] for db in database if db in self.mappings}
        self.select_channels_dict = select_channels_dict
        self.ALLdata = self.combine_data(mappings=self.mappings)
        self.out_sz = out_sz
        self.label_key = label_key
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3
        self.is_reverse_axis_order = True # for inference, always reverse axis order (nifty is reverse order than numpy)
        # --- filtering starts here ---
        self.ALLdata_filtered = self.get_filter_mindim()
        if ROIs is None:
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()
        self.ALLdata_filtered = self.get_filter_labels(task_key=task_key, label_keys=label_key)
        # --- filtering ends here ---
        self.roi_scan_mapping = self.build_ROI_scan_mapping()
        self.keys_dist, self.total = self.get_keys_dist()
    def get_all_ROI(self):
        """Return the set of distinct ROI labels among the filtered scans."""
        return {self.ALLdata[k]['ROI'] for k in self.ALLdata_filtered}
    def get_keys_dist(self):
        """Return (per-ROI scan counts, total scan count).

        BUGFIX: ``total`` was never incremented and always came back 0.
        """
        keys_dist = {}
        total = 0
        for item in self.ALLdata_filtered.keys():
            roi = self.ALLdata_filtered[item]['ROI']
            if roi not in keys_dist:
                keys_dist[roi] = 0
            keys_dist[roi] += 1
            total += 1
        return keys_dist, total
    def build_ROI_scan_mapping(self):
        """Map each ROI label to the list of scan paths that carry it."""
        ROI_scan_mapping = {}
        for item in self.ALLdata_filtered.keys():
            roi = self.ALLdata_filtered[item]['ROI']
            if roi not in ROI_scan_mapping:
                ROI_scan_mapping[roi] = []
            ROI_scan_mapping[roi].append(item)
        return ROI_scan_mapping
    def get_3D_volume(self, volume, select_channel = None):
        """Optionally reverse axis order and reduce 4-D scans to 3-D."""
        volume = reverse_axis_order(volume) if self.is_reverse_axis_order else volume
        if volume.ndim == 4:
            if select_channel is None:
                # BUGFIX: randint's upper bound is exclusive; ``shape[3] - 1``
                # skipped the last channel and crashed on single-channel scans.
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume
    def get_filter_mindim(self):
        """Drop scans whose smallest spatial dimension is below out_sz / 2."""
        ALLdata = self.ALLdata.copy()
        for k in self.ALLdata.keys():
            if min(self.ALLdata[k]['Size'][:self.ndims]) < self.out_sz/2:
                del ALLdata[k]
        return ALLdata
    def find_min_dim(self):
        """Return the smallest dimension across all loaded volumes (100000 if empty)."""
        min_dim = 100000
        for k in self.ALLdata.keys():
            value = self.ALLdata[k]
            if min(value['Size']) < min_dim:
                min_dim = min(value['Size'])
        return min_dim
    def combine_data(self, mappings = mapping_files):
        """Merge all mapping JSON files into one dict, skipping missing/empty files.

        Prints warnings when listed files are not accessible, since that
        usually means a mount/path problem on the current node.
        """
        ALLdata = {}
        total_entries = 0
        total_skipped = 0
        for j in mappings.keys():
            with open(mappings[j], 'r') as f:
                mappings_tmp = json.load(f)
            skipped = 0
            for k, v in mappings_tmp.items():
                # Drop entries whose image file is absent or zero bytes.
                if not os.path.exists(k) or os.path.getsize(k) == 0:
                    skipped += 1
                    continue
                ALLdata[k] = v
            accessible = len(mappings_tmp) - skipped
            total_entries += len(mappings_tmp)
            total_skipped += skipped
            if skipped > 0:
                print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
        if total_skipped > 0:
            print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
        if len(ALLdata) < 1000:
            print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
                  f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
        return ALLdata
    def normalize(self, volume, eps=1e-7):
        """Min-max normalize *volume* to [0, 1) as float64; eps avoids /0."""
        volume = volume.astype(np.float64)
        return (volume - np.min(volume)) / (np.ptp(volume) + eps)
    def get_key_by_ROI(self, key_word):
        """Return all file-path keys whose ROI label equals *key_word* exactly."""
        keys = []
        for k in self.ALLdata_filtered.keys():
            if key_word == self.ALLdata_filtered[k]["ROI"]:
                keys.append(k)
        return keys
    def get_filter_task(self, task_key = 'segmentation'):
        """Drop entries lacking a 'Label_path' branch for *task_key*; return a copy."""
        import warnings
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            if 'Label_path' not in self.ALLdata_filtered[k] or task_key not in self.ALLdata_filtered[k]['Label_path']:
                del ALLdata_filtered[k]
                # BUGFIX: ``Warning(...)`` only built an exception object and
                # discarded it; actually emit the warning.
                warnings.warn(f"Label path not found for {k} with task key {task_key}. This image will be removed from the dataset.")
        return ALLdata_filtered
    def get_filter_labels(self, task_key='segmentation', label_keys=['heart']):
        """Keep entries whose Label_path[task_key] contains at least one of *label_keys*."""
        ALLdata_filtered = self.ALLdata_filtered.copy()
        keys_to_remove = []
        for k in list(ALLdata_filtered.keys()):
            task_labels = ALLdata_filtered[k].get('Label_path', {}).get(task_key, {})
            if not any(tk in label_keys for tk in task_labels.keys()):
                keys_to_remove.append(k)
        for k in keys_to_remove:
            del ALLdata_filtered[k]
        return ALLdata_filtered
    def get_random_pad_crop_params(self, volume_shape, crop_size=None, random=True):
        """Compute padding and crop parameters for a (d, h, w) shape.

        Returns (pad_width, (start_d, start_h, start_w, crop_d, crop_h, crop_w)).
        random=True draws a random crop origin; False centers the crop.
        NOTE(review): the *random* parameter shadows the ``random`` module
        inside this method (harmless here -- only np.random is used); kept
        for caller compatibility.
        """
        d, h, w = volume_shape[:3]
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size
        # Pad only the axes that are smaller than the crop; split randomly.
        pad_width = []
        for size, crop in zip((d, h, w), (crop_d, crop_h, crop_w)):
            if crop > size:
                total_pad = crop - size
                pad_before = np.random.randint(0, total_pad + 1)
                pad_width.append((pad_before, total_pad - pad_before))
            else:
                pad_width.append((0, 0))
        # Shape after padding.
        d_p = d + pad_width[0][0] + pad_width[0][1]
        h_p = h + pad_width[1][0] + pad_width[1][1]
        w_p = w + pad_width[2][0] + pad_width[2][1]
        if random:
            start_d = np.random.randint(0, d_p - crop_d + 1) if d_p > crop_d else 0
            start_h = np.random.randint(0, h_p - crop_h + 1) if h_p > crop_h else 0
            start_w = np.random.randint(0, w_p - crop_w + 1) if w_p > crop_w else 0
        else:
            start_d = max((d_p - crop_d) // 2, 0)
            start_h = max((h_p - crop_h) // 2, 0)
            start_w = max((w_p - crop_w) // 2, 0)
        return pad_width, (start_d, start_h, start_w, crop_d, crop_h, crop_w)
    def apply_pad_crop(self, volume, pad_width, crop_slices):
        """Apply precomputed zero-padding then crop to *volume*."""
        if any(pad != (0, 0) for pad in pad_width):
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
        start_d, start_h, start_w, crop_d, crop_h, crop_w = crop_slices
        return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]
    def get_filter_ROIs(self):
        """Drop entries whose ROI label is not in ``self.ROIs``; return a copy."""
        ALLdata_filtered = self.ALLdata_filtered.copy()
        for k in self.ALLdata_filtered.keys():
            if self.ALLdata_filtered[k]['ROI'] not in self.ROIs:
                del ALLdata_filtered[k]
        return ALLdata_filtered
    def get_channel_ids(self, key):
        """
        Get the indices where ImgDict values match the selected channels (e.g., 'ed', 'es').
        Returns:
            list: integer indices, ordered like select_channels_dict['ImgDict'].
        """
        img_dict = self.ALLdata_filtered[key].get("ImgDict", {})
        selected_values = self.select_channels_dict.get("ImgDict", [])
        # Reverse mapping: channel name -> integer index.
        value_to_idx = {value: int(idx) for idx, value in img_dict.items()}
        return [value_to_idx[val] for val in selected_values if val in value_to_idx]
    def __len__(self):
        """Number of scans after filtering."""
        return len(self.ALLdata_filtered)
    def __getitem__(self, idx):
        """Load scan *idx*, preprocess it, and return it with its labels.

        Returns a dict with 'img', 'labels' and 'label_channels' (see class
        docstring).  Labels missing on disk are filled with -1.
        """
        import warnings
        key = list(self.ALLdata_filtered.keys())[idx]
        return_dict = dict()
        print(f"Processing key: {key}")
        volume = sitk.ReadImage(key)
        volume = sitk.GetArrayFromImage(volume)
        # BUGFIX: channel_ids/channel_id must exist for 3-D scans too -- they
        # are consulted again below when cropping multi-channel labels, which
        # previously raised NameError on 3-D inputs.
        channel_ids = []
        channel_id = None
        if volume.ndim == 4:
            channel_ids = self.get_channel_ids(key)
            if len(channel_ids) == 0:
                # BUGFIX: ``Warning(...)`` was a no-op; emit a real warning.
                warnings.warn(f"No matching channels found for key: {key} with ImgDict: {self.ALLdata_filtered[key].get('ImgDict', {})} and selected channels: {self.select_channels_dict.get('ImgDict', [])}. Using random channel.")
            else:
                channel_id = channel_ids[0]
        # Called unconditionally so that 3-D scans also receive the axis-order
        # reversal that is applied to their labels below.
        volume = self.get_3D_volume(volume, select_channel = channel_id)
        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)
        # Shared pad/crop parameters so image and labels stay aligned.
        crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
        crop_size = int(max(volume.shape) * crop_ratio)
        pad_width, crop_slices = self.get_random_pad_crop_params(volume.shape, crop_size)
        volume = self.apply_pad_crop(volume, pad_width, crop_slices)
        label_dict = dict()
        if 'Label_path' in self.ALLdata_filtered[key]:
            for lk in self.label_key:
                if lk in self.ALLdata_filtered[key]['Label_path']['segmentation'].keys():
                    label = sitk.ReadImage(self.ALLdata_filtered[key]['Label_path']['segmentation'][lk])
                    label = sitk.GetArrayFromImage(label)
                    label = reverse_axis_order(label) if self.is_reverse_axis_order else label
                    if label.ndim > self.ndims:
                        if len(channel_ids) != 0:
                            label = label[..., channel_ids]  # assuming channel last
                        # BUGFIX: always extend pad_width with (0, 0) pairs for
                        # trailing channel axes so np.pad gets one pair per
                        # axis (the old code could hand a 3-entry pad_width to
                        # a 4-D label when channel_ids was empty).
                        pad_width_lab = pad_width + [(0, 0)] * (label.ndim - self.ndims)
                    else:
                        pad_width_lab = pad_width
                    label = self.apply_pad_crop(label, pad_width_lab, crop_slices)
                    # order=0 keeps labels nearest-neighbour (no interpolation).
                    label_dict[lk] = resize(label, [self.out_sz]*self.ndims, anti_aliasing = False, preserve_range = True, order=0)
                    if label.ndim > self.ndims:
                        if self.ndims == 3:
                            label_dict[lk] = np.transpose(label_dict[lk], (3, 0, 1, 2))  # assuming channel last
                        elif self.ndims == 4:
                            label_dict[lk] = np.transpose(label_dict[lk], (4, 0, 1, 2, 3))  # assuming channel last
                else:
                    # Missing label on disk: fill with -1 sentinel.
                    label_dict[lk] = np.full([self.out_sz]*self.ndims, -1)
                    warnings.warn(f"Label path not found for {key} with label key {lk}.")
                label_dict[lk] = label_dict[lk][None, :, :, :] if label_dict[lk].ndim == 3 else label_dict[lk]
        else:
            for lk in self.label_key:
                label_dict[lk] = np.full([self.out_sz]*self.ndims, -1)
                warnings.warn(f"Label path not found for {key} with label key {lk}.")
                label_dict[lk] = label_dict[lk][None, :, :, :]
        volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        # NOTE(review): concatenating along axis=1 stacks label maps along the
        # depth axis; axis=0 (channel axis) may have been intended -- confirm
        # with downstream consumers before changing.
        return_dict['labels'] = np.concatenate([v for v in label_dict.values()], axis=1)
        return_dict['img'] = volume[None, :, :, :]
        return_dict['label_channels'] = list(self.select_channels_dict.get("ImgDict", []))
        return return_dict
class OminiDataset_bertembd(OminiDataset):
    """OminiDataset variant returning (volume, BERT text-embedding) pairs."""
    def __init__(self,
                 out_sz = 128,
                 transform=None,
                 clamp_range = CLAMP_RANGE,
                 min_crop_ratio = 0.85,
                 ROIs = None,
                 modality = None,
                 reverse_axis_order = False,
                 min_dim = 3,
                 mapping_files = mapping_files):
        """Build the index via the parent class, then re-apply dim/ROI filtering."""
        super().__init__(out_sz = out_sz,
                         transform = transform,
                         clamp_range = clamp_range,
                         min_crop_ratio = min_crop_ratio,
                         ROIs = ROIs,
                         modality = modality,
                         reverse_axis_order = reverse_axis_order,
                         min_dim = min_dim,
                         mapping_files=mapping_files)
        # --- filtering starts here ---
        self.ALLdata_filtered = self.get_filter_mindim()
        if ROIs is None:
            # If no ROIs are provided, keep every ROI present after filtering.
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()
        # self.ALLdata_filtered = self.filter_embd()
        # --- filtering ends here ---
    def __getitem__(self, idx):
        """Return (volume, embedding) for the scan at *idx*.

        The volume is clipped (CT only), normalized, randomly cropped and
        resized to (out_sz,)*ndims with a leading channel axis.
        """
        key = list(self.ALLdata_filtered.keys())[idx]
        embd = self.ALLdata_filtered[key]['embd']
        volume = sitk.ReadImage(key)
        volume = sitk.GetArrayFromImage(volume)
        volume = self.get_3D_volume(volume)
        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)
        if self.min_crop_ratio is not None:
            crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
            crop_size = int(max(volume.shape) * crop_ratio)
            volume = self.random_crop_3d(volume, crop_size)
            volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        else:
            volume = self.random_crop_3d(volume, self.out_sz)
        volume = volume[None, :, :, :]
        if self.transform is not None:
            # NOTE(review): the transform branch drops the embedding -- confirm
            # callers using a transform do not expect (volume, embd).
            return self.transform(volume)
        return volume, np.array(embd)
    def __len__(self):
        """Number of scans after filtering."""
        return len(self.ALLdata_filtered)
    def filter_embd(self):
        """Drop entries lacking BERT embedding metadata (in place); return the dict.

        BUGFIX: the original deleted entries while iterating the same dict,
        raising ``RuntimeError: dictionary changed size during iteration``;
        iterate over a snapshot of the keys instead.
        """
        for k in list(self.ALLdata_filtered.keys()):
            if 'BERT_embedding_keys' not in self.ALLdata_filtered[k]['Metadata']:
                del self.ALLdata_filtered[k]
        return self.ALLdata_filtered