# PepTune/src/config.py
import argparse
import os
def get_parser():
    """Build the argument parser for PepTune training and evaluation."""
    parser = argparse.ArgumentParser(description='PepTune Training and Evaluation')
# Noise parameters
noise_group = parser.add_argument_group('noise')
noise_group.add_argument('--noise-type', type=str, default='loglinear',
help='Type of noise schedule')
noise_group.add_argument('--sigma-min', type=float, default=1e-4,
help='Minimum sigma value')
noise_group.add_argument('--sigma-max', type=float, default=20,
help='Maximum sigma value')
    # BooleanOptionalAction (Python >= 3.9) also registers --no-state-dependent;
    # a plain store_true with default=True could never be switched off.
    noise_group.add_argument('--state-dependent', action=argparse.BooleanOptionalAction, default=True,
                             help='Use state-dependent noise')
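    # For reference, the log-linear schedule in MDLM-style codebases (an assumption
    # here, not checked against this repo's noise module) is
    #     sigma(t) = -log(1 - (1 - eps) * t),   t in (0, 1],
    # so the per-token mask probability 1 - exp(-sigma(t)) = (1 - eps) * t grows
    # linearly in t, with --sigma-min/--sigma-max bounding the usable range.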
# Base parameters
parser.add_argument('--base-path', type=str, default='/path/to/PepTune',
help='Base path to PepTune')
parser.add_argument('--mode', type=str, default='ppl_eval',
choices=['train', 'ppl_eval', 'sample_eval'],
help='Running mode')
parser.add_argument('--diffusion', type=str, default='absorbing_state',
help='Diffusion type')
parser.add_argument('--vocab', type=str, default='old_smiles',
choices=['old_smiles', 'new_smiles', 'selfies', 'helm'],
help='Vocabulary type')
parser.add_argument('--backbone', type=str, default='roformer',
choices=['peptideclm', 'helmgpt', 'dit', 'roformer', 'finetune_roformer'],
help='Model backbone')
parser.add_argument('--parameterization', type=str, default='subs',
help='Parameterization type')
parser.add_argument('--time-conditioning', action='store_true', default=False,
help='Use time conditioning')
parser.add_argument('--T', type=int, default=0,
help='Number of diffusion steps (0 for continuous time, 1000 for discrete)')
parser.add_argument('--subs-masking', action='store_true', default=False,
help='Use substitution masking')
parser.add_argument('--seed', type=int, default=42,
help='Random seed')
# MCTS parameters
mcts_group = parser.add_argument_group('mcts')
mcts_group.add_argument('--mcts-num-children', type=int, default=50,
help='Number of children in MCTS')
mcts_group.add_argument('--mcts-num-objectives', type=int, default=5,
help='Number of objectives in MCTS')
mcts_group.add_argument('--mcts-topk', type=int, default=100,
help='Top-k for MCTS')
mcts_group.add_argument('--mcts-mask-token', type=int, default=4,
help='Mask token ID')
mcts_group.add_argument('--mcts-num-iter', type=int, default=128,
help='Number of MCTS iterations')
    mcts_group.add_argument('--mcts-sampling', type=int, default=0,
                            help='Sampling strategy (0 for Gumbel, >0 for top-k)')
mcts_group.add_argument('--mcts-invalid-penalty', type=float, default=0.5,
help='Penalty for invalid sequences')
mcts_group.add_argument('--mcts-sample-prob', type=float, default=1.0,
help='Sampling probability')
    mcts_group.add_argument('--mcts-perm', action=argparse.BooleanOptionalAction, default=True,
                            help='Use permutation in MCTS')
mcts_group.add_argument('--mcts-dual', action='store_true', default=False,
help='Use dual mode')
mcts_group.add_argument('--mcts-single', action='store_true', default=False,
help='Use single mode')
    mcts_group.add_argument('--mcts-time-dependent', action=argparse.BooleanOptionalAction, default=True,
                            help='Use time-dependent MCTS')
# Data parameters
data_group = parser.add_argument_group('data')
data_group.add_argument('--train-data', type=str,
default='/path/to/your/home/PepTune/data/peptide_data',
help='Path to training data')
data_group.add_argument('--valid-data', type=str,
default='/path/to/your/home/PepTune/data/peptide_data',
help='Path to validation data')
data_group.add_argument('--data-batching', type=str, default='wrapping',
choices=['padding', 'wrapping'],
help='Batching strategy')
# Loader parameters
loader_group = parser.add_argument_group('loader')
loader_group.add_argument('--global-batch-size', type=int, default=64,
help='Global batch size')
loader_group.add_argument('--eval-global-batch-size', type=int, default=None,
help='Evaluation global batch size (defaults to global-batch-size)')
loader_group.add_argument('--num-workers', type=int, default=None,
help='Number of dataloader workers (defaults to available CPUs)')
    loader_group.add_argument('--pin-memory', action=argparse.BooleanOptionalAction, default=True,
                              help='Pin memory for dataloaders')
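    # These presumably feed torch.utils.data.DataLoader (assumed wiring; the
    # actual datamodule may differ):
    #   DataLoader(dataset, batch_size=global_batch_size // (num_nodes * devices),
    #              num_workers=args.num_workers, pin_memory=args.pin_memory)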
# Sampling parameters
sampling_group = parser.add_argument_group('sampling')
sampling_group.add_argument('--predictor', type=str, default='ddpm_cache',
choices=['analytic', 'ddpm', 'ddpm_cache'],
help='Predictor type for sampling')
sampling_group.add_argument('--num-sequences', type=int, default=100,
help='Number of sequences to generate')
sampling_group.add_argument('--sampling-eps', type=float, default=1e-3,
help='Sampling epsilon')
sampling_group.add_argument('--steps', type=int, default=128,
help='Number of sampling steps')
sampling_group.add_argument('--seq-length', type=int, default=100,
help='Sequence length')
    sampling_group.add_argument('--noise-removal', action=argparse.BooleanOptionalAction, default=True,
                                help='Use noise removal')
sampling_group.add_argument('--num-sample-batches', type=int, default=2,
help='Number of sample batches')
sampling_group.add_argument('--num-sample-log', type=int, default=2,
help='Number of samples to log')
sampling_group.add_argument('--stride-length', type=int, default=1,
help='Stride length for sampling')
sampling_group.add_argument('--num-strides', type=int, default=1,
help='Number of strides')
# Training parameters
training_group = parser.add_argument_group('training')
    training_group.add_argument('--antithetic-sampling', action=argparse.BooleanOptionalAction, default=True,
                                help='Use antithetic sampling')
training_group.add_argument('--training-sampling-eps', type=float, default=1e-3,
help='Training sampling epsilon')
training_group.add_argument('--focus-mask', action='store_true', default=False,
help='Use focus mask')
training_group.add_argument('--accumulator', action='store_true', default=False,
help='Use accumulator')
# Evaluation parameters
eval_group = parser.add_argument_group('eval')
eval_group.add_argument('--checkpoint-path', type=str, default=None,
help='Path to checkpoint for evaluation')
eval_group.add_argument('--disable-ema', action='store_true', default=False,
help='Disable EMA')
eval_group.add_argument('--compute-generative-perplexity', action='store_true', default=False,
help='Compute generative perplexity')
eval_group.add_argument('--perplexity-batch-size', type=int, default=8,
help='Batch size for perplexity computation')
eval_group.add_argument('--compute-perplexity-on-sanity', action='store_true', default=False,
help='Compute perplexity on sanity check')
eval_group.add_argument('--gen-ppl-eval-model', type=str, default='gpt2-large',
help='Model for generative perplexity evaluation')
    eval_group.add_argument('--generate-samples', action=argparse.BooleanOptionalAction, default=True,
                            help='Generate samples during evaluation')
eval_group.add_argument('--generation-model', type=str, default=None,
help='Model for generation')
# Optimizer parameters
optim_group = parser.add_argument_group('optim')
optim_group.add_argument('--weight-decay', type=float, default=0.075,
help='Weight decay')
optim_group.add_argument('--lr', type=float, default=3e-4,
help='Learning rate')
optim_group.add_argument('--beta1', type=float, default=0.9,
help='Adam beta1')
optim_group.add_argument('--beta2', type=float, default=0.999,
help='Adam beta2')
optim_group.add_argument('--eps', type=float, default=1e-8,
help='Adam epsilon')
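    # These map onto a torch optimizer in the usual way (assumed wiring):
    #   torch.optim.AdamW(model.parameters(), lr=args.lr,
    #                     betas=(args.beta1, args.beta2), eps=args.eps,
    #                     weight_decay=args.weight_decay)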
# PepCLM model parameters
pepclm_group = parser.add_argument_group('pepclm')
pepclm_group.add_argument('--pepclm-hidden-size', type=int, default=768,
help='PepCLM hidden size')
pepclm_group.add_argument('--pepclm-cond-dim', type=int, default=256,
help='PepCLM conditioning dimension')
pepclm_group.add_argument('--pepclm-n-heads', type=int, default=20,
help='PepCLM number of attention heads')
pepclm_group.add_argument('--pepclm-n-blocks', type=int, default=4,
help='PepCLM number of blocks')
pepclm_group.add_argument('--pepclm-dropout', type=float, default=0.5,
help='PepCLM dropout rate')
pepclm_group.add_argument('--pepclm-length', type=int, default=512,
help='PepCLM sequence length')
# General model parameters
model_group = parser.add_argument_group('model')
model_group.add_argument('--model-type', type=str, default='ddit',
help='Model type')
model_group.add_argument('--hidden-size', type=int, default=768,
help='Model hidden size')
model_group.add_argument('--cond-dim', type=int, default=128,
help='Conditioning dimension')
model_group.add_argument('--length', type=int, default=512,
help='Sequence length')
model_group.add_argument('--n-blocks', type=int, default=12,
help='Number of blocks')
model_group.add_argument('--n-heads', type=int, default=12,
help='Number of attention heads')
    model_group.add_argument('--scale-by-sigma', action=argparse.BooleanOptionalAction, default=True,
                             help='Scale by sigma')
model_group.add_argument('--dropout', type=float, default=0.1,
help='Dropout rate')
# RoFormer parameters
roformer_group = parser.add_argument_group('roformer')
roformer_group.add_argument('--roformer-hidden-size', type=int, default=768,
help='RoFormer hidden size')
roformer_group.add_argument('--roformer-n-layers', type=int, default=8,
help='RoFormer number of layers')
roformer_group.add_argument('--roformer-n-heads', type=int, default=8,
help='RoFormer number of attention heads')
roformer_group.add_argument('--roformer-max-position-embeddings', type=int, default=1035,
help='RoFormer max position embeddings')
# HelmGPT parameters
helmgpt_group = parser.add_argument_group('helmgpt')
helmgpt_group.add_argument('--helmgpt-hidden-size', type=int, default=256,
help='HelmGPT hidden size')
helmgpt_group.add_argument('--helmgpt-embd-pdrop', type=float, default=0.1,
help='HelmGPT embedding dropout')
helmgpt_group.add_argument('--helmgpt-resid-pdrop', type=float, default=0.1,
help='HelmGPT residual dropout')
helmgpt_group.add_argument('--helmgpt-attn-pdrop', type=float, default=0.1,
help='HelmGPT attention dropout')
helmgpt_group.add_argument('--helmgpt-ff-dropout', type=float, default=0.0,
help='HelmGPT feedforward dropout')
helmgpt_group.add_argument('--helmgpt-block-size', type=int, default=140,
help='HelmGPT block size')
helmgpt_group.add_argument('--helmgpt-n-layer', type=int, default=8,
help='HelmGPT number of layers')
helmgpt_group.add_argument('--helmgpt-n-heads', type=int, default=8,
help='HelmGPT number of attention heads')
# Trainer parameters
trainer_group = parser.add_argument_group('trainer')
trainer_group.add_argument('--accelerator', type=str, default='cuda',
help='Accelerator type')
trainer_group.add_argument('--num-nodes', type=int, default=1,
help='Number of nodes')
trainer_group.add_argument('--devices', type=int, default=1,
help='Number of devices')
trainer_group.add_argument('--gradient-clip-val', type=float, default=1.0,
help='Gradient clipping value')
trainer_group.add_argument('--precision', type=str, default='64-true',
help='Training precision')
trainer_group.add_argument('--num-sanity-val-steps', type=int, default=2,
help='Number of sanity validation steps')
trainer_group.add_argument('--max-epochs', type=int, default=100,
help='Maximum number of epochs')
trainer_group.add_argument('--max-steps', type=int, default=1_000_000,
help='Maximum number of steps')
trainer_group.add_argument('--log-every-n-steps', type=int, default=10,
help='Log every n steps')
trainer_group.add_argument('--limit-train-batches', type=float, default=1.0,
help='Limit training batches')
trainer_group.add_argument('--limit-val-batches', type=float, default=1.0,
help='Limit validation batches')
trainer_group.add_argument('--check-val-every-n-epoch', type=int, default=1,
help='Check validation every n epochs')
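    # These mirror pytorch_lightning.Trainer keyword arguments (assumed wiring):
    #   pl.Trainer(accelerator=args.accelerator, num_nodes=args.num_nodes,
    #              devices=args.devices, precision=args.precision,
    #              max_epochs=args.max_epochs, max_steps=args.max_steps,
    #              gradient_clip_val=args.gradient_clip_val,
    #              log_every_n_steps=args.log_every_n_steps,
    #              limit_train_batches=args.limit_train_batches,
    #              limit_val_batches=args.limit_val_batches,
    #              num_sanity_val_steps=args.num_sanity_val_steps,
    #              check_val_every_n_epoch=args.check_val_every_n_epoch)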
# WandB parameters
wandb_group = parser.add_argument_group('wandb')
wandb_group.add_argument('--wandb-project', type=str, default='peptune',
help='WandB project name')
wandb_group.add_argument('--wandb-notes', type=str, default=None,
help='WandB notes')
wandb_group.add_argument('--wandb-group', type=str, default=None,
help='WandB group')
wandb_group.add_argument('--wandb-job-type', type=str, default=None,
help='WandB job type')
wandb_group.add_argument('--wandb-name', type=str, default='sophia-tang',
help='WandB run name')
wandb_group.add_argument('--wandb-id', type=str, default=None,
help='WandB run ID')
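    # These line up with pytorch_lightning.loggers.WandbLogger, whose extra
    # kwargs pass through to wandb.init (assumed wiring):
    #   WandbLogger(project=args.wandb_project, name=args.wandb_name,
    #               id=args.wandb_id, notes=args.wandb_notes,
    #               group=args.wandb_group, job_type=args.wandb_job_type)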
# Checkpointing parameters
checkpoint_group = parser.add_argument_group('checkpointing')
checkpoint_group.add_argument('--save-dir', type=str, default=None,
help='Directory to save checkpoints')
    checkpoint_group.add_argument('--resume-from-ckpt', action=argparse.BooleanOptionalAction, default=True,
                                  help='Resume from checkpoint')
checkpoint_group.add_argument('--resume-ckpt-path', type=str, default=None,
help='Path to checkpoint to resume from')
checkpoint_group.add_argument('--checkpoint-every-n-epochs', type=int, default=1,
help='Save checkpoint every n epochs')
checkpoint_group.add_argument('--checkpoint-monitor', type=str, default='val/nll',
help='Metric to monitor for checkpointing')
checkpoint_group.add_argument('--checkpoint-save-top-k', type=int, default=10,
help='Save top k checkpoints')
checkpoint_group.add_argument('--checkpoint-mode', type=str, default='min',
choices=['min', 'max'],
help='Mode for checkpoint monitoring')
checkpoint_group.add_argument('--checkpoint-dirpath', type=str,
default='./checkpoints/11M-old-tokenizer',
help='Directory path for checkpoints')
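    # These line up with pytorch_lightning.callbacks.ModelCheckpoint (assumed wiring):
    #   ModelCheckpoint(dirpath=args.checkpoint_dirpath, monitor=args.checkpoint_monitor,
    #                   mode=args.checkpoint_mode, save_top_k=args.checkpoint_save_top_k,
    #                   every_n_epochs=args.checkpoint_every_n_epochs)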
# LR Scheduler parameters
scheduler_group = parser.add_argument_group('lr_scheduler')
scheduler_group.add_argument('--lr-warmup-steps', type=int, default=2500,
help='Number of warmup steps for learning rate')
return parser
def get_args():
"""Parse and return arguments."""
parser = get_parser()
args = parser.parse_args()
# Post-process arguments
if args.eval_global_batch_size is None:
args.eval_global_batch_size = args.global_batch_size
    if args.num_workers is None:
        # os.sched_getaffinity is Linux-only; fall back to os.cpu_count() elsewhere
        try:
            args.num_workers = len(os.sched_getaffinity(0))
        except AttributeError:
            args.num_workers = os.cpu_count() or 1
if args.wandb_id is None:
args.wandb_id = f"{args.wandb_name}_nov12_set2"
if args.save_dir is None:
args.save_dir = os.getcwd()
return args
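def group_args(args, prefix):
    """Collect attributes of `args` whose names start with `prefix` into a new
    Namespace, with the prefix stripped (e.g. mcts_num_iter -> num_iter).

    Convenience helper sketched in for illustration; not part of the original
    training code, which may consume the flat namespace directly.
    """
    sub = argparse.Namespace()
    head = prefix.rstrip('_') + '_'
    for key, value in vars(args).items():
        if key.startswith(head):
            setattr(sub, key[len(head):], value)
    return sub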
if __name__ == '__main__':
args = get_args()
print(args)
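    # Example invocations (hypothetical; several defaults above are placeholder paths):
    #   python src/config.py --mode train --backbone roformer --global-batch-size 64
    #   python src/config.py --mode sample_eval --steps 128 --no-noise-removal
    print(group_args(args, 'mcts'))  # e.g. Namespace(num_children=50, num_iter=128, ...)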