OmniMorph

Deform All-in-One Framework for Medical Image Generation, Restoration and Registration based on a conditional Deformation-Recovery Diffusion Model (DeformDDPM).

OmniMorph is a unified framework for 2D/3D multi-modal medical imaging (CT, MRI, PET) supporting:

  • Generation: text-conditioned image synthesis via BERT embeddings.
  • Restoration: recover anatomically plausible images from degraded inputs.
  • Registration: paired / unpaired / flexible-resolution registration via diffused deformation vector fields.

Repository Contents

  • OM_train*.py: Training entrypoints (single-/2-/3-mode variants, CUDA + Intel XPU)
  • OM_aug*.py, OM_reg*.py, OM_contrastive*.py: Inference / augmentation / registration / contrastive scripts
  • Diffusion/: DeformDDPM core (diffuser.py, networks, losses, spatial utils)
  • OMorpher/: Higher-level model wrapper
  • Dataloader/: Multi-modality dataloaders + dataset mappings (16 datasets)
  • Config/: YAML training/inference configs
  • Scripts/: Auxiliary scripts (registration, evaluation)
  • tests/: Pytest suite for OMorpher and loss functions
  • bash_*.sh, *.slurm: SLURM submission scripts (CUDA + Intel XPU/Dawn)
  • Models/all_om_net/000110_all_om_net.pth: Trained checkpoint, production multi-modal recmulmodmutattnnet (epoch 110, ~3.0 GB)
  • Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth: Earlier recmulmodmutattnnet run (epoch 10, ~906 MB)

Note: Only the final checkpoint of each training run is shipped; intermediate epochs and the bert_large_uncased weights are not bundled. Download bert-large-uncased from the official Hugging Face repository if you need the contrastive text encoder.

Setup

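# The .pth checkpoints are stored with Git LFS; enable it so clone fetches the real files
git lfs install
git clone https://huggingface.co/DRDMsig/Omini3D
cd Omini3D
pip install -r requirements.txt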

For Intel XPU (e.g. the Dawn cluster), install the intel-extension-for-pytorch build that matches your PyTorch version before installing the rest of the requirements.

Quick Start

Training

# Single-mode diffusion
CUDA_VISIBLE_DEVICES=0 python OM_train.py -C Config/config_om.yaml

# Dual mode (diffusion + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_2modes.py -C Config/config_om.yaml

# Triple mode (diffusion + contrastive + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_3modes.py -C Config/config_om.yaml

# Intel XPU (single node)
sbatch bash_train_single_node.sh

Inference

# Augmentation / restoration with a trained model
python OM_aug.py -C Config/config_om.yaml

# Paired registration
python OM_reg.py -C Config/config_om.yaml

# Flexible-resolution registration
python OM_reg_flexres.py -C Config/config_om.yaml

Loading the checkpoint

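import torch
from Diffusion.networks import get_net

# Production network (multi-modal recmulmodmutattnnet)
net = get_net("recmulmodmutattnnet")

# Production checkpoint (epoch 110)
ckpt_path = "Models/all_om_net/000110_all_om_net.pth"
# Or the earlier run: "Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth"

# map_location="cpu" lets the checkpoint load on machines without a GPU.
# On PyTorch >= 2.6, torch.load defaults to weights_only=True; if loading fails because
# the file contains non-tensor Python objects, pass weights_only=False (trusted files only).
state = torch.load(ckpt_path, map_location="cpu")

# Weights are either the top-level object or stored under the "model" key
net.load_state_dict(state["model"] if isinstance(state, dict) and "model" in state else state)
net.eval()  # inference mode (affects dropout / normalization layers)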
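If load_state_dict reports missing or unexpected keys, one common cause is a checkpoint saved from a DataParallel / DistributedDataParallel wrapper, which prefixes every parameter name with "module.". The helper below is a hypothetical sketch (extract_state_dict is not part of the repository) that unwraps a "model" entry and strips that prefix; whether the shipped checkpoints actually carry the prefix is an assumption you should check by inspecting the keys.

```python
from collections import OrderedDict
from typing import Any, Mapping


def extract_state_dict(ckpt: Any, key: str = "model") -> "OrderedDict[str, Any]":
    """Hypothetical helper: return a plain state dict from a loaded checkpoint.

    Assumes the weights are either the top-level mapping or stored under `key`,
    and that a leading "module." prefix (added by DataParallel/DDP wrappers)
    should be removed so the keys match an unwrapped network.
    """
    if isinstance(ckpt, Mapping) and key in ckpt:
        ckpt = ckpt[key]
    if not isinstance(ckpt, Mapping):
        raise TypeError(f"expected a mapping of parameters, got {type(ckpt).__name__}")

    prefix = "module."
    cleaned: "OrderedDict[str, Any]" = OrderedDict()
    for name, value in ckpt.items():
        # Strip the wrapper prefix only when present; keep other names unchanged
        cleaned[name[len(prefix):] if name.startswith(prefix) else name] = value
    return cleaned
```

Usage: net.load_state_dict(extract_state_dict(torch.load(ckpt_path, map_location="cpu"))).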

Architecture

Config YAML → DataLoader(s) → DeformDDPM(Network, STN) → Loss → Checkpoint
  • DeformDDPM (Diffusion/diffuser.py): forward/reverse diffusion over deformation vector fields (DVFs); multi-scale DDFs at control-point ratios [4, 8, 16, 32, 64].
  • Networks (Diffusion/networks.py), selectable via get_net(name):
    • recmulmodmutattnnet: current production multi-modal multi-head-attention net (used by 000110_all_om_net.pth)
    • recmutattnnet, recmutattnnet_contrastive, recresacnet, defrecmutattnnet
  • STN: Spatial Transformer for differentiable warping; composes deformations as comp_ddf = dvf + stn(ddf, dvf).
  • Losses (Diffusion/losses.py, losses_ncc0.py): Grad, LNCC, LMSE, NCC, MRSE, RMSE.
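The composition rule comp_ddf = dvf + stn(ddf, dvf) can be illustrated with a minimal 2D PyTorch sketch. It assumes displacements are stored in pixel units with shape (N, 2, H, W), channel 0 along x (width) and channel 1 along y (height); the repository's own STN may order channels or normalize displacements differently, and identity_grid, stn and compose here are hypothetical names. Under this convention, warping an image with comp_ddf is equivalent to warping it by ddf first and then warping the result by dvf.

```python
import torch
import torch.nn.functional as F


def identity_grid(h: int, w: int, device=None) -> torch.Tensor:
    """Pixel coordinates stacked as (x, y), shape (1, 2, H, W)."""
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=torch.float32, device=device),
        torch.arange(w, dtype=torch.float32, device=device),
        indexing="ij",
    )
    return torch.stack((xs, ys), dim=0).unsqueeze(0)


def stn(src: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor:
    """Warp src (N, C, H, W) by sampling it at x + ddf(x) (pixel units)."""
    _, _, h, w = ddf.shape
    coords = identity_grid(h, w, ddf.device) + ddf  # (N, 2, H, W) sampling positions
    # grid_sample wants (N, H, W, 2) in [-1, 1]; with align_corners=True,
    # -1 and +1 map to the centres of the first and last pixels
    x = 2.0 * coords[:, 0] / max(w - 1, 1) - 1.0
    y = 2.0 * coords[:, 1] / max(h - 1, 1) - 1.0
    grid = torch.stack((x, y), dim=-1)
    return F.grid_sample(src, grid, mode="bilinear", padding_mode="border", align_corners=True)


def compose(ddf: torch.Tensor, dvf: torch.Tensor) -> torch.Tensor:
    """comp_ddf = dvf + stn(ddf, dvf): warping by comp_ddf == warp by ddf, then by dvf."""
    return dvf + stn(ddf, dvf)
```

The 3D case follows the same pattern with (N, 3, D, H, W) fields and a (N, D, H, W, 3) sampling grid, which grid_sample also accepts.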
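In most registration codebases, a loss called Grad is a finite-difference smoothness penalty on the displacement field. A minimal sketch, assuming that common form; the repository's losses.py may differ in penalty type, normalization or weighting, and grad_loss is a hypothetical name:

```python
import torch


def grad_loss(ddf: torch.Tensor, penalty: str = "l2") -> torch.Tensor:
    """Smoothness penalty on a displacement field.

    Accepts 2-D (N, 2, H, W) or 3-D (N, 3, D, H, W) fields and averages the
    forward differences taken along every spatial axis.
    """
    if penalty not in ("l1", "l2"):
        raise ValueError(f"unknown penalty: {penalty}")
    terms = []
    for axis in range(2, ddf.dim()):  # skip batch and channel axes
        d = torch.diff(ddf, dim=axis)
        terms.append(d.pow(2).mean() if penalty == "l2" else d.abs().mean())
    return sum(terms) / len(terms)
```

A constant field (a pure translation) costs nothing, while spatially varying displacements are penalized, which is why such a term is usually added with a small weight alongside a similarity loss such as LNCC.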

Datasets Supported

Dataloader/nifty_mappings/ contains pre-computed mappings for 16 public medical-imaging datasets, including: AbdomenAtlas, AbdomenCT-1k, BraTS 2019/2020/2021, MSD, OASIS-1/2, OAI-ZIB, MnMs, Kaggle OSIC, TotalSegmentator (CT+MRI), PSMA-FDG-PET-CT-Lesion, CIA.

The dataset files themselves are not included; obtain them from their respective sources and update the mapping paths.

Citation

@article{omnimorph,
  title  = {OmniMorph: Deform All-in-One Framework for Medical Image Generation,
            Restoration and Registration via Conditional Deformation-Recovery
            Diffusion Models},
  author = {Zheng, J. and Mo, M. and others},
  year   = {2025}
}

License

MIT. See LICENSE.
