|
|
import sys |
|
|
sys.path.append("./BranchSBM") |
|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
import copy |
|
|
from pytorch_lightning import Trainer |
|
|
from pytorch_lightning.loggers import WandbLogger |
|
|
import wandb |
|
|
import hydra |
|
|
from omegaconf import DictConfig, OmegaConf |
|
|
from torchcfm.optimal_transport import OTPlanSampler |
|
|
|
|
|
from branchsbm.branchsbm import BranchSBM |
|
|
from branchsbm.branch_flow_net_train import FlowNetTrainCell, FlowNetTrainLidar |
|
|
from branchsbm.branch_interpolant_train import BranchInterpolantTrain |
|
|
from branchsbm.branch_growth_net_train import GrowthNetTrain, GrowthNetTrainCell, GrowthNetTrainLidar |
|
|
|
|
|
from dataloaders.trajectory_data import TemporalDataModule |
|
|
from dataloaders.mouse_data import WeightedBranchedCellDataModule |
|
|
from dataloaders.three_branch_data import ThreeBranchTahoeDataModule |
|
|
from dataloaders.clonidine_v2_data import ClonidineV2DataModule |
|
|
from dataloaders.clonidine_single_branch import ClonidineSingleBranchDataModule |
|
|
from dataloaders.trametinib_single import TrametinibSingleBranchDataModule |
|
|
from dataloaders.lidar_data import WeightedBranchedLidarDataModule |
|
|
from dataloaders.lidar_data_single import LidarSingleDataModule |
|
|
|
|
|
from networks.flow_mlp import VelocityNet |
|
|
from networks.growth_mlp import GrowthNet |
|
|
from networks.interpolant_mlp import GeoPathMLP |
|
|
from utils import set_seed |
|
|
|
|
|
from train.parsers import parse_args |
|
|
from branchsbm.ema import EMA |
|
|
from train.train_utils import ( |
|
|
load_config, |
|
|
merge_config, |
|
|
generate_group_string, |
|
|
dataset_name2datapath, |
|
|
create_callbacks, |
|
|
) |
|
|
from state_costs.metric_factory import DataManifoldMetric |
|
|
import torch.nn as nn |
|
|
|
|
|
def main(args: argparse.Namespace, seed: int, t_exclude: int) -> None: |
|
|
set_seed(seed) |
|
|
branches = args.branches |
|
|
|
|
|
skipped_time_points = [t_exclude] if t_exclude else [] |
|
|
|
|
|
|
|
|
if args.data_name == "lidar": |
|
|
datamodule = WeightedBranchedLidarDataModule(args=args) |
|
|
elif args.data_name == "lidarsingle": |
|
|
datamodule = LidarSingleDataModule(args=args) |
|
|
elif args.data_name == "mouse": |
|
|
datamodule = WeightedBranchedCellDataModule(args=args) |
|
|
elif args.data_name in ["clonidine50D", "clonidine100D", "clonidine150D"]: |
|
|
datamodule = ClonidineV2DataModule(args=args) |
|
|
elif args.data_name == "clonidine50Dsingle": |
|
|
datamodule = ClonidineSingleBranchDataModule(args=args) |
|
|
elif args.data_name == "trametinib": |
|
|
datamodule = ThreeBranchTahoeDataModule(args=args) |
|
|
elif args.data_name == "trametinibsingle": |
|
|
datamodule = TrametinibSingleBranchDataModule(args=args) |
|
|
|
|
|
flow_nets = nn.ModuleList() |
|
|
geopath_nets = nn.ModuleList() |
|
|
growth_nets = nn.ModuleList() |
|
|
|
|
|
|
|
|
for i in range(branches): |
|
|
flow_net = VelocityNet( |
|
|
dim=args.dim, |
|
|
hidden_dims=args.hidden_dims_flow, |
|
|
activation=args.activation_flow, |
|
|
batch_norm=False, |
|
|
) |
|
|
geopath_net = GeoPathMLP( |
|
|
input_dim=args.dim, |
|
|
hidden_dims=args.hidden_dims_geopath, |
|
|
time_geopath=args.time_geopath, |
|
|
activation=args.activation_geopath, |
|
|
batch_norm=False, |
|
|
) |
|
|
|
|
|
if i == 0: |
|
|
growth_net = GrowthNet( |
|
|
dim=args.dim, |
|
|
hidden_dims=args.hidden_dims_growth, |
|
|
activation=args.activation_growth, |
|
|
batch_norm=False, |
|
|
negative=True |
|
|
) |
|
|
else: |
|
|
growth_net = GrowthNet( |
|
|
dim=args.dim, |
|
|
hidden_dims=args.hidden_dims_growth, |
|
|
activation=args.activation_growth, |
|
|
batch_norm=False, |
|
|
negative=False |
|
|
) |
|
|
|
|
|
if args.ema_decay is not None: |
|
|
flow_net = EMA(model=flow_net, decay=args.ema_decay) |
|
|
geopath_net = EMA(model=geopath_net, decay=args.ema_decay) |
|
|
growth_net = EMA(model=growth_net, decay=args.ema_decay) |
|
|
|
|
|
flow_nets.append(flow_net) |
|
|
geopath_nets.append(geopath_net) |
|
|
growth_nets.append(growth_net) |
|
|
|
|
|
|
|
|
ot_sampler = ( |
|
|
OTPlanSampler(method=args.optimal_transport_method) |
|
|
if args.optimal_transport_method != "None" |
|
|
else None |
|
|
) |
|
|
|
|
|
wandb.init( |
|
|
project=f"branchsbm-{args.data_name}-{branches}-branches", |
|
|
group=args.group_name, |
|
|
config=vars(args), |
|
|
dir=args.working_dir, |
|
|
) |
|
|
|
|
|
flow_matcher_base = BranchSBM( |
|
|
geopath_nets=geopath_nets, |
|
|
sigma=args.sigma, |
|
|
alpha=int(args.branchsbm), |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
geopath_callbacks = create_callbacks( |
|
|
args, phase="geopath", data_type=args.data_type, run_id=wandb.run.id |
|
|
) |
|
|
|
|
|
|
|
|
data_manifold_metric = DataManifoldMetric( |
|
|
args=args, |
|
|
skipped_time_points=skipped_time_points, |
|
|
datamodule=datamodule, |
|
|
) |
|
|
geopath_model = BranchInterpolantTrain( |
|
|
flow_matcher=flow_matcher_base, |
|
|
skipped_time_points=skipped_time_points, |
|
|
ot_sampler=ot_sampler, |
|
|
args=args, |
|
|
data_manifold_metric=data_manifold_metric |
|
|
) |
|
|
|
|
|
wandb_logger = WandbLogger() |
|
|
|
|
|
trainer = Trainer( |
|
|
max_epochs=args.epochs, |
|
|
callbacks=geopath_callbacks, |
|
|
accelerator=args.accelerator, |
|
|
logger=wandb_logger, |
|
|
num_sanity_val_steps=0, |
|
|
default_root_dir=args.working_dir, |
|
|
gradient_clip_val=(1.0 if args.data_type == "image" else None), |
|
|
) |
|
|
|
|
|
if args.load_geopath_model_ckpt: |
|
|
best_model_path = args.load_geopath_model_ckpt |
|
|
else: |
|
|
trainer.fit( |
|
|
geopath_model, |
|
|
datamodule=datamodule, |
|
|
) |
|
|
|
|
|
best_model_path = geopath_callbacks[0].best_model_path |
|
|
|
|
|
geopath_model = BranchInterpolantTrain.load_from_checkpoint(best_model_path) |
|
|
|
|
|
flow_matcher_base.geopath_nets = geopath_model.geopath_nets |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flow_callbacks = create_callbacks( |
|
|
args, |
|
|
phase="flow", |
|
|
data_type=args.data_type, |
|
|
run_id=wandb.run.id, |
|
|
datamodule=datamodule, |
|
|
) |
|
|
if args.data_type == "lidar": |
|
|
FlowNetTrain = FlowNetTrainLidar |
|
|
else: |
|
|
FlowNetTrain = FlowNetTrainCell |
|
|
|
|
|
flow_train = FlowNetTrain( |
|
|
flow_matcher=flow_matcher_base, |
|
|
flow_nets=flow_nets, |
|
|
ot_sampler=ot_sampler, |
|
|
skipped_time_points=skipped_time_points, |
|
|
args=args, |
|
|
) |
|
|
|
|
|
wandb_logger = WandbLogger() |
|
|
|
|
|
trainer = Trainer( |
|
|
max_epochs=args.epochs, |
|
|
callbacks=flow_callbacks, |
|
|
check_val_every_n_epoch=args.check_val_every_n_epoch, |
|
|
accelerator=args.accelerator, |
|
|
logger=wandb_logger, |
|
|
default_root_dir=args.working_dir, |
|
|
gradient_clip_val=(1.0 if args.data_type == "image" else None), |
|
|
num_sanity_val_steps=(0 if args.data_type == "image" else None), |
|
|
) |
|
|
|
|
|
trainer.fit( |
|
|
flow_train, datamodule=datamodule, ckpt_path=args.resume_flow_model_ckpt |
|
|
) |
|
|
if args.data_type == "lidar": |
|
|
trainer.test(flow_train, datamodule=datamodule) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flow_nets = flow_train.flow_nets |
|
|
|
|
|
growth_callbacks = create_callbacks( |
|
|
args, |
|
|
phase="growth", |
|
|
data_type=args.data_type, |
|
|
run_id=wandb.run.id, |
|
|
datamodule=datamodule, |
|
|
) |
|
|
|
|
|
if args.data_type == "lidar": |
|
|
GrowthNetTrain = GrowthNetTrainLidar |
|
|
else: |
|
|
GrowthNetTrain = GrowthNetTrainCell |
|
|
|
|
|
growth_train = GrowthNetTrain( |
|
|
flow_nets = flow_nets, |
|
|
growth_nets = growth_nets, |
|
|
ot_sampler=ot_sampler, |
|
|
skipped_time_points=skipped_time_points, |
|
|
args=args, |
|
|
data_manifold_metric=data_manifold_metric, |
|
|
joint = False |
|
|
) |
|
|
|
|
|
wandb_logger = WandbLogger() |
|
|
|
|
|
trainer = Trainer( |
|
|
max_epochs=args.epochs, |
|
|
callbacks=growth_callbacks, |
|
|
check_val_every_n_epoch=args.check_val_every_n_epoch, |
|
|
accelerator=args.accelerator, |
|
|
logger=wandb_logger, |
|
|
default_root_dir=args.working_dir, |
|
|
gradient_clip_val=(1.0 if args.data_type == "image" else None), |
|
|
num_sanity_val_steps=(0 if args.data_type == "image" else None), |
|
|
) |
|
|
|
|
|
trainer.fit( |
|
|
growth_train, datamodule=datamodule, ckpt_path=None |
|
|
) |
|
|
trainer.test(growth_train, datamodule=datamodule) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
growth_nets = growth_train.growth_nets |
|
|
|
|
|
joint_callbacks = create_callbacks( |
|
|
args, |
|
|
phase="joint", |
|
|
data_type=args.data_type, |
|
|
run_id=wandb.run.id, |
|
|
datamodule=datamodule, |
|
|
) |
|
|
|
|
|
if args.data_type == "lidar": |
|
|
GrowthNetTrain = GrowthNetTrainLidar |
|
|
else: |
|
|
GrowthNetTrain = GrowthNetTrainCell |
|
|
|
|
|
joint_train = GrowthNetTrain( |
|
|
flow_nets = flow_nets, |
|
|
growth_nets = growth_nets, |
|
|
ot_sampler=ot_sampler, |
|
|
skipped_time_points=skipped_time_points, |
|
|
args=args, |
|
|
data_manifold_metric=data_manifold_metric, |
|
|
joint = True |
|
|
) |
|
|
|
|
|
wandb_logger = WandbLogger() |
|
|
|
|
|
trainer = Trainer( |
|
|
max_epochs=args.epochs, |
|
|
callbacks=joint_callbacks, |
|
|
check_val_every_n_epoch=args.check_val_every_n_epoch, |
|
|
accelerator=args.accelerator, |
|
|
logger=wandb_logger, |
|
|
default_root_dir=args.working_dir, |
|
|
gradient_clip_val=(1.0 if args.data_type == "image" else None), |
|
|
num_sanity_val_steps=(0 if args.data_type == "image" else None), |
|
|
) |
|
|
|
|
|
trainer.fit( |
|
|
joint_train, datamodule=datamodule, ckpt_path=None |
|
|
) |
|
|
trainer.test(joint_train, datamodule=datamodule) |
|
|
|
|
|
|
|
|
|
|
|
wandb.finish() |
|
|
|
|
|
if __name__ == "__main__": |
|
|
args = parse_args() |
|
|
updated_args = copy.deepcopy(args) |
|
|
if args.config_path: |
|
|
config = load_config(args.config_path) |
|
|
updated_args = merge_config(updated_args, config) |
|
|
|
|
|
updated_args.group_name = generate_group_string() |
|
|
updated_args.data_path = dataset_name2datapath( |
|
|
updated_args.data_name, updated_args.working_dir |
|
|
) |
|
|
for seed in updated_args.seeds: |
|
|
if updated_args.t_exclude: |
|
|
for i, t_exclude in enumerate(updated_args.t_exclude): |
|
|
updated_args.t_exclude_current = t_exclude |
|
|
updated_args.seed_current = seed |
|
|
updated_args.gamma_current = updated_args.gammas[i] |
|
|
main(updated_args, seed=seed, t_exclude=t_exclude) |
|
|
else: |
|
|
updated_args.seed_current = seed |
|
|
updated_args.gamma_current = updated_args.gammas[0] |
|
|
main(updated_args, seed=seed, t_exclude=None) |