# BranchSBM / train / main_branches.py
# Repository page header (not code): sophtang -- "update" -- a03ffb8 (verified)
import sys
sys.path.append("./BranchSBM")
import os
import sys
import argparse
import copy
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import wandb
import hydra
from omegaconf import DictConfig, OmegaConf
from torchcfm.optimal_transport import OTPlanSampler
from branchsbm.branchsbm import BranchSBM
from branchsbm.branch_flow_net_train import FlowNetTrainCell, FlowNetTrainLidar
from branchsbm.branch_interpolant_train import BranchInterpolantTrain
from branchsbm.branch_growth_net_train import GrowthNetTrain, GrowthNetTrainCell, GrowthNetTrainLidar
from dataloaders.trajectory_data import TemporalDataModule
from dataloaders.mouse_data import WeightedBranchedCellDataModule
from dataloaders.three_branch_data import ThreeBranchTahoeDataModule
from dataloaders.clonidine_v2_data import ClonidineV2DataModule
from dataloaders.clonidine_single_branch import ClonidineSingleBranchDataModule
from dataloaders.trametinib_single import TrametinibSingleBranchDataModule
from dataloaders.lidar_data import WeightedBranchedLidarDataModule
from dataloaders.lidar_data_single import LidarSingleDataModule
from networks.flow_mlp import VelocityNet
from networks.growth_mlp import GrowthNet
from networks.interpolant_mlp import GeoPathMLP
from utils import set_seed
from train.parsers import parse_args
from branchsbm.ema import EMA
from train.train_utils import (
load_config,
merge_config,
generate_group_string,
dataset_name2datapath,
create_callbacks,
)
from state_costs.metric_factory import DataManifoldMetric
import torch.nn as nn
def _resolve_datamodule_cls(data_name: str):
    """Return the datamodule class that handles ``data_name``.

    Raises:
        ValueError: if ``data_name`` is not one of the supported dataset names.
    """
    if data_name == "lidar":
        return WeightedBranchedLidarDataModule
    if data_name == "lidarsingle":
        return LidarSingleDataModule
    if data_name == "mouse":
        return WeightedBranchedCellDataModule
    if data_name in ("clonidine50D", "clonidine100D", "clonidine150D"):
        return ClonidineV2DataModule
    if data_name == "clonidine50Dsingle":
        return ClonidineSingleBranchDataModule
    if data_name == "trametinib":
        return ThreeBranchTahoeDataModule
    if data_name == "trametinibsingle":
        return TrametinibSingleBranchDataModule
    raise ValueError(
        f"Unknown data_name {data_name!r}; expected one of: lidar, lidarsingle, "
        "mouse, clonidine50D, clonidine100D, clonidine150D, clonidine50Dsingle, "
        "trametinib, trametinibsingle"
    )


def _make_stage_trainer(args: argparse.Namespace, callbacks):
    """Build the Trainer shared by the flow, growth and joint stages.

    A fresh ``WandbLogger`` is created per stage. Gradient clipping (1.0) and
    disabled sanity validation are only applied when ``args.data_type`` is
    ``"image"``; otherwise both options are passed as ``None``.
    """
    stage_logger = WandbLogger()
    is_image = args.data_type == "image"
    return Trainer(
        max_epochs=args.epochs,
        callbacks=callbacks,
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        accelerator=args.accelerator,
        logger=stage_logger,
        default_root_dir=args.working_dir,
        gradient_clip_val=(1.0 if is_image else None),
        num_sanity_val_steps=(0 if is_image else None),
    )


def main(args: argparse.Namespace, seed: int, t_exclude: "int | None") -> None:
    """Run the four-stage BranchSBM training pipeline for one seed.

    Stages, in order: (1) geodesic interpolant training (or loading from
    ``args.load_geopath_model_ckpt``), (2) flow matching, (3) growth network
    training, (4) joint training. All stages log to a single wandb run that
    is opened here and finished at the end.

    Args:
        args: Merged run configuration.
        seed: Random seed passed to ``set_seed``.
        t_exclude: Time point to hold out, or ``None``. Any falsy value
            (including ``0``) results in no skipped time points.

    Raises:
        ValueError: if ``args.data_name`` is not a supported dataset name.
    """
    # Fail fast on an unsupported dataset name, before building networks or
    # opening a wandb run. Previously this surfaced later as an unbound
    # ``datamodule`` variable.
    datamodule_cls = _resolve_datamodule_cls(args.data_name)

    set_seed(seed)
    branches = args.branches
    # NOTE: truthiness check, so t_exclude == 0 is treated the same as None.
    skipped_time_points = [t_exclude] if t_exclude else []

    ### DATAMODULES ###
    datamodule = datamodule_cls(args=args)

    flow_nets = nn.ModuleList()
    geopath_nets = nn.ModuleList()
    growth_nets = nn.ModuleList()

    ##### initialize branched flow and growth networks #####
    for i in range(branches):
        flow_net = VelocityNet(
            dim=args.dim,
            hidden_dims=args.hidden_dims_flow,
            activation=args.activation_flow,
            batch_norm=False,
        )
        geopath_net = GeoPathMLP(
            input_dim=args.dim,
            hidden_dims=args.hidden_dims_geopath,
            time_geopath=args.time_geopath,
            activation=args.activation_geopath,
            batch_norm=False,
        )
        # Only the first branch's growth net is built with negative=True.
        growth_net = GrowthNet(
            dim=args.dim,
            hidden_dims=args.hidden_dims_growth,
            activation=args.activation_growth,
            batch_norm=False,
            negative=(i == 0),
        )

        if args.ema_decay is not None:
            flow_net = EMA(model=flow_net, decay=args.ema_decay)
            geopath_net = EMA(model=geopath_net, decay=args.ema_decay)
            growth_net = EMA(model=growth_net, decay=args.ema_decay)

        flow_nets.append(flow_net)
        geopath_nets.append(geopath_net)
        growth_nets.append(growth_net)

    # The OT method is configured as a string; the literal "None" disables it.
    ot_sampler = (
        OTPlanSampler(method=args.optimal_transport_method)
        if args.optimal_transport_method != "None"
        else None
    )

    wandb.init(
        project=f"branchsbm-{args.data_name}-{branches}-branches",
        group=args.group_name,
        config=vars(args),
        dir=args.working_dir,
    )

    flow_matcher_base = BranchSBM(
        geopath_nets=geopath_nets,
        sigma=args.sigma,
        alpha=int(args.branchsbm),
    )

    ##### STAGE 1: Training of Geodesic Interpolants Beginning #####
    geopath_callbacks = create_callbacks(
        args, phase="geopath", data_type=args.data_type, run_id=wandb.run.id
    )

    # define state cost
    data_manifold_metric = DataManifoldMetric(
        args=args,
        skipped_time_points=skipped_time_points,
        datamodule=datamodule,
    )

    geopath_model = BranchInterpolantTrain(
        flow_matcher=flow_matcher_base,
        skipped_time_points=skipped_time_points,
        ot_sampler=ot_sampler,
        args=args,
        data_manifold_metric=data_manifold_metric,
    )

    # Stage 1 uses its own Trainer settings: sanity validation is always
    # disabled and check_val_every_n_epoch is left at the Trainer default.
    wandb_logger = WandbLogger()
    trainer = Trainer(
        max_epochs=args.epochs,
        callbacks=geopath_callbacks,
        accelerator=args.accelerator,
        logger=wandb_logger,
        num_sanity_val_steps=0,
        default_root_dir=args.working_dir,
        gradient_clip_val=(1.0 if args.data_type == "image" else None),
    )

    if args.load_geopath_model_ckpt:
        # A pre-trained interpolant checkpoint was supplied: skip stage-1 training.
        best_model_path = args.load_geopath_model_ckpt
    else:
        trainer.fit(
            geopath_model,
            datamodule=datamodule,
        )
        # The first geopath callback is the one exposing best_model_path.
        best_model_path = geopath_callbacks[0].best_model_path

    # Reload from the chosen checkpoint and hand its interpolant nets to the
    # flow matcher used by the later stages.
    geopath_model = BranchInterpolantTrain.load_from_checkpoint(best_model_path)
    flow_matcher_base.geopath_nets = geopath_model.geopath_nets
    ##### STAGE 1: Training of Geodesic Interpolants End #####

    ##### STAGE 2: Flow Matching Beginning #####
    flow_callbacks = create_callbacks(
        args,
        phase="flow",
        data_type=args.data_type,
        run_id=wandb.run.id,
        datamodule=datamodule,
    )

    flow_train_cls = (
        FlowNetTrainLidar if args.data_type == "lidar" else FlowNetTrainCell
    )
    flow_train = flow_train_cls(
        flow_matcher=flow_matcher_base,
        flow_nets=flow_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
    )

    trainer = _make_stage_trainer(args, flow_callbacks)
    trainer.fit(
        flow_train, datamodule=datamodule, ckpt_path=args.resume_flow_model_ckpt
    )
    # The flow stage is only tested for lidar data.
    if args.data_type == "lidar":
        trainer.test(flow_train, datamodule=datamodule)
    ##### STAGE 2: Flow Matching End #####

    ##### STAGE 3: Training Growth Networks Beginning ####
    flow_nets = flow_train.flow_nets

    growth_callbacks = create_callbacks(
        args,
        phase="growth",
        data_type=args.data_type,
        run_id=wandb.run.id,
        datamodule=datamodule,
    )

    # Local name chosen so the module-level GrowthNetTrain import is not shadowed.
    growth_train_cls = (
        GrowthNetTrainLidar if args.data_type == "lidar" else GrowthNetTrainCell
    )
    growth_train = growth_train_cls(
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
        data_manifold_metric=data_manifold_metric,
        joint=False,
    )

    trainer = _make_stage_trainer(args, growth_callbacks)
    trainer.fit(growth_train, datamodule=datamodule, ckpt_path=None)
    trainer.test(growth_train, datamodule=datamodule)
    ##### STAGE 3: Training Growth Networks End ####

    ##### STAGE 4: Joint Training Beginning ####
    growth_nets = growth_train.growth_nets

    joint_callbacks = create_callbacks(
        args,
        phase="joint",
        data_type=args.data_type,
        run_id=wandb.run.id,
        datamodule=datamodule,
    )

    # Same training-module class as stage 3, now with joint=True.
    joint_train_cls = (
        GrowthNetTrainLidar if args.data_type == "lidar" else GrowthNetTrainCell
    )
    joint_train = joint_train_cls(
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
        data_manifold_metric=data_manifold_metric,
        joint=True,
    )

    trainer = _make_stage_trainer(args, joint_callbacks)
    trainer.fit(joint_train, datamodule=datamodule, ckpt_path=None)
    trainer.test(joint_train, datamodule=datamodule)
    ##### STAGE 4: Joint Training End ####

    wandb.finish()
if __name__ == "__main__":
    # Parse the CLI and work on a deep copy of the parsed arguments.
    args = parse_args()
    updated_args = copy.deepcopy(args)

    # When a config path is given, load it and merge it into the copy.
    if args.config_path:
        updated_args = merge_config(updated_args, load_config(args.config_path))

    # Run-wide values derived once for every seed / held-out time point.
    updated_args.group_name = generate_group_string()
    updated_args.data_path = dataset_name2datapath(
        updated_args.data_name, updated_args.working_dir
    )

    for seed in updated_args.seeds:
        if not updated_args.t_exclude:
            # No held-out time points: a single run per seed with the first gamma.
            updated_args.seed_current = seed
            updated_args.gamma_current = updated_args.gammas[0]
            main(updated_args, seed=seed, t_exclude=None)
            continue

        # One run per held-out time point, paired by index with its gamma.
        for idx, held_out in enumerate(updated_args.t_exclude):
            updated_args.t_exclude_current = held_out
            updated_args.seed_current = seed
            updated_args.gamma_current = updated_args.gammas[idx]
            main(updated_args, seed=seed, t_exclude=held_out)