import argparse
import os


def get_parser():
    """Build the argument parser for PepTune training and evaluation."""
    parser = argparse.ArgumentParser(description='PepTune Training and Evaluation')

    noise_group = parser.add_argument_group('noise')
    noise_group.add_argument('--noise-type', type=str, default='loglinear',
                             help='Type of noise schedule')
    noise_group.add_argument('--sigma-min', type=float, default=1e-4,
                             help='Minimum sigma value')
    noise_group.add_argument('--sigma-max', type=float, default=20,
                             help='Maximum sigma value')
    # BooleanOptionalAction (Python 3.9+) also generates a matching --no-<flag>
    # switch, so flags that default to True can be disabled from the CLI.
    noise_group.add_argument('--state-dependent', action=argparse.BooleanOptionalAction,
                             default=True, help='Use state-dependent noise')

    parser.add_argument('--base-path', type=str, default='/path/to/PepTune',
                        help='Base path to PepTune')
    parser.add_argument('--mode', type=str, default='ppl_eval',
                        choices=['train', 'ppl_eval', 'sample_eval'],
                        help='Running mode')
    parser.add_argument('--diffusion', type=str, default='absorbing_state',
                        help='Diffusion type')
    parser.add_argument('--vocab', type=str, default='old_smiles',
                        choices=['old_smiles', 'new_smiles', 'selfies', 'helm'],
                        help='Vocabulary type')
    parser.add_argument('--backbone', type=str, default='roformer',
                        choices=['peptideclm', 'helmgpt', 'dit', 'roformer', 'finetune_roformer'],
                        help='Model backbone')
    parser.add_argument('--parameterization', type=str, default='subs',
                        help='Parameterization type')
    parser.add_argument('--time-conditioning', action='store_true', default=False,
                        help='Use time conditioning')
    parser.add_argument('--T', type=int, default=0,
                        help='Number of diffusion steps (0 for continuous time, 1000 for discrete)')
    parser.add_argument('--subs-masking', action='store_true', default=False,
                        help='Use substitution masking')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')

    mcts_group = parser.add_argument_group('mcts')
    mcts_group.add_argument('--mcts-num-children', type=int, default=50,
                            help='Number of children in MCTS')
    mcts_group.add_argument('--mcts-num-objectives', type=int, default=5,
                            help='Number of objectives in MCTS')
    mcts_group.add_argument('--mcts-topk', type=int, default=100,
                            help='Top-k for MCTS')
    mcts_group.add_argument('--mcts-mask-token', type=int, default=4,
                            help='Mask token ID')
    mcts_group.add_argument('--mcts-num-iter', type=int, default=128,
                            help='Number of MCTS iterations')
    mcts_group.add_argument('--mcts-sampling', type=int, default=0,
                            help='Sampling strategy (0 for Gumbel, >0 for top-k)')
    mcts_group.add_argument('--mcts-invalid-penalty', type=float, default=0.5,
                            help='Penalty for invalid sequences')
    mcts_group.add_argument('--mcts-sample-prob', type=float, default=1.0,
                            help='Sampling probability')
    mcts_group.add_argument('--mcts-perm', action=argparse.BooleanOptionalAction,
                            default=True, help='Use permutation in MCTS')
    mcts_group.add_argument('--mcts-dual', action='store_true', default=False,
                            help='Use dual mode')
    mcts_group.add_argument('--mcts-single', action='store_true', default=False,
                            help='Use single mode')
    mcts_group.add_argument('--mcts-time-dependent', action=argparse.BooleanOptionalAction,
                            default=True, help='Use time-dependent MCTS')

    data_group = parser.add_argument_group('data')
    data_group.add_argument('--train-data', type=str,
                            default='/path/to/your/home/PepTune/data/peptide_data',
                            help='Path to training data')
    data_group.add_argument('--valid-data', type=str,
                            default='/path/to/your/home/PepTune/data/peptide_data',
                            help='Path to validation data')
    data_group.add_argument('--data-batching', type=str, default='wrapping',
                            choices=['padding', 'wrapping'],
                            help='Batching strategy')

    loader_group = parser.add_argument_group('loader')
    loader_group.add_argument('--global-batch-size', type=int, default=64,
                              help='Global batch size')
    loader_group.add_argument('--eval-global-batch-size', type=int, default=None,
                              help='Evaluation global batch size (defaults to global-batch-size)')
    loader_group.add_argument('--num-workers', type=int, default=None,
                              help='Number of dataloader workers (defaults to available CPUs)')
    loader_group.add_argument('--pin-memory', action=argparse.BooleanOptionalAction,
                              default=True, help='Pin memory for dataloaders')

    sampling_group = parser.add_argument_group('sampling')
    sampling_group.add_argument('--predictor', type=str, default='ddpm_cache',
                                choices=['analytic', 'ddpm', 'ddpm_cache'],
                                help='Predictor type for sampling')
    sampling_group.add_argument('--num-sequences', type=int, default=100,
                                help='Number of sequences to generate')
    sampling_group.add_argument('--sampling-eps', type=float, default=1e-3,
                                help='Sampling epsilon')
    sampling_group.add_argument('--steps', type=int, default=128,
                                help='Number of sampling steps')
    sampling_group.add_argument('--seq-length', type=int, default=100,
                                help='Sequence length')
    sampling_group.add_argument('--noise-removal', action=argparse.BooleanOptionalAction,
                                default=True, help='Use noise removal')
    sampling_group.add_argument('--num-sample-batches', type=int, default=2,
                                help='Number of sample batches')
    sampling_group.add_argument('--num-sample-log', type=int, default=2,
                                help='Number of samples to log')
    sampling_group.add_argument('--stride-length', type=int, default=1,
                                help='Stride length for sampling')
    sampling_group.add_argument('--num-strides', type=int, default=1,
                                help='Number of strides')

    training_group = parser.add_argument_group('training')
    training_group.add_argument('--antithetic-sampling', action=argparse.BooleanOptionalAction,
                                default=True, help='Use antithetic sampling')
    training_group.add_argument('--training-sampling-eps', type=float, default=1e-3,
                                help='Training sampling epsilon')
    training_group.add_argument('--focus-mask', action='store_true', default=False,
                                help='Use focus mask')
    training_group.add_argument('--accumulator', action='store_true', default=False,
                                help='Use accumulator')

    eval_group = parser.add_argument_group('eval')
    eval_group.add_argument('--checkpoint-path', type=str, default=None,
                            help='Path to checkpoint for evaluation')
    eval_group.add_argument('--disable-ema', action='store_true', default=False,
                            help='Disable EMA')
    eval_group.add_argument('--compute-generative-perplexity', action='store_true', default=False,
                            help='Compute generative perplexity')
    eval_group.add_argument('--perplexity-batch-size', type=int, default=8,
                            help='Batch size for perplexity computation')
    eval_group.add_argument('--compute-perplexity-on-sanity', action='store_true', default=False,
                            help='Compute perplexity during sanity-check validation')
    eval_group.add_argument('--gen-ppl-eval-model', type=str, default='gpt2-large',
                            help='Model for generative perplexity evaluation')
    eval_group.add_argument('--generate-samples', action=argparse.BooleanOptionalAction,
                            default=True, help='Generate samples during evaluation')
    eval_group.add_argument('--generation-model', type=str, default=None,
                            help='Model for generation')

    optim_group = parser.add_argument_group('optim')
    optim_group.add_argument('--weight-decay', type=float, default=0.075,
                             help='Weight decay')
    optim_group.add_argument('--lr', type=float, default=3e-4,
                             help='Learning rate')
    optim_group.add_argument('--beta1', type=float, default=0.9,
                             help='Adam beta1')
    optim_group.add_argument('--beta2', type=float, default=0.999,
                             help='Adam beta2')
    optim_group.add_argument('--eps', type=float, default=1e-8,
                             help='Adam epsilon')

    pepclm_group = parser.add_argument_group('pepclm')
    pepclm_group.add_argument('--pepclm-hidden-size', type=int, default=768,
                              help='PepCLM hidden size')
    pepclm_group.add_argument('--pepclm-cond-dim', type=int, default=256,
                              help='PepCLM conditioning dimension')
    pepclm_group.add_argument('--pepclm-n-heads', type=int, default=20,
                              help='PepCLM number of attention heads')
    pepclm_group.add_argument('--pepclm-n-blocks', type=int, default=4,
                              help='PepCLM number of blocks')
    pepclm_group.add_argument('--pepclm-dropout', type=float, default=0.5,
                              help='PepCLM dropout rate')
    pepclm_group.add_argument('--pepclm-length', type=int, default=512,
                              help='PepCLM sequence length')

    model_group = parser.add_argument_group('model')
    model_group.add_argument('--model-type', type=str, default='ddit',
                             help='Model type')
    model_group.add_argument('--hidden-size', type=int, default=768,
                             help='Model hidden size')
    model_group.add_argument('--cond-dim', type=int, default=128,
                             help='Conditioning dimension')
    model_group.add_argument('--length', type=int, default=512,
                             help='Sequence length')
    model_group.add_argument('--n-blocks', type=int, default=12,
                             help='Number of blocks')
    model_group.add_argument('--n-heads', type=int, default=12,
                             help='Number of attention heads')
    model_group.add_argument('--scale-by-sigma', action=argparse.BooleanOptionalAction,
                             default=True, help='Scale by sigma')
    model_group.add_argument('--dropout', type=float, default=0.1,
                             help='Dropout rate')

    roformer_group = parser.add_argument_group('roformer')
    roformer_group.add_argument('--roformer-hidden-size', type=int, default=768,
                                help='RoFormer hidden size')
    roformer_group.add_argument('--roformer-n-layers', type=int, default=8,
                                help='RoFormer number of layers')
    roformer_group.add_argument('--roformer-n-heads', type=int, default=8,
                                help='RoFormer number of attention heads')
    roformer_group.add_argument('--roformer-max-position-embeddings', type=int, default=1035,
                                help='RoFormer max position embeddings')

    helmgpt_group = parser.add_argument_group('helmgpt')
    helmgpt_group.add_argument('--helmgpt-hidden-size', type=int, default=256,
                               help='HelmGPT hidden size')
    helmgpt_group.add_argument('--helmgpt-embd-pdrop', type=float, default=0.1,
                               help='HelmGPT embedding dropout')
    helmgpt_group.add_argument('--helmgpt-resid-pdrop', type=float, default=0.1,
                               help='HelmGPT residual dropout')
    helmgpt_group.add_argument('--helmgpt-attn-pdrop', type=float, default=0.1,
                               help='HelmGPT attention dropout')
    helmgpt_group.add_argument('--helmgpt-ff-dropout', type=float, default=0.0,
                               help='HelmGPT feedforward dropout')
    helmgpt_group.add_argument('--helmgpt-block-size', type=int, default=140,
                               help='HelmGPT block size')
    helmgpt_group.add_argument('--helmgpt-n-layer', type=int, default=8,
                               help='HelmGPT number of layers')
    helmgpt_group.add_argument('--helmgpt-n-heads', type=int, default=8,
                               help='HelmGPT number of attention heads')

    trainer_group = parser.add_argument_group('trainer')
    trainer_group.add_argument('--accelerator', type=str, default='cuda',
                               help='Accelerator type')
    trainer_group.add_argument('--num-nodes', type=int, default=1,
                               help='Number of nodes')
    trainer_group.add_argument('--devices', type=int, default=1,
                               help='Number of devices')
    trainer_group.add_argument('--gradient-clip-val', type=float, default=1.0,
                               help='Gradient clipping value')
    trainer_group.add_argument('--precision', type=str, default='64-true',
                               help='Training precision')
    trainer_group.add_argument('--num-sanity-val-steps', type=int, default=2,
                               help='Number of sanity validation steps')
    trainer_group.add_argument('--max-epochs', type=int, default=100,
                               help='Maximum number of epochs')
    trainer_group.add_argument('--max-steps', type=int, default=1_000_000,
                               help='Maximum number of steps')
    trainer_group.add_argument('--log-every-n-steps', type=int, default=10,
                               help='Log every n steps')
    trainer_group.add_argument('--limit-train-batches', type=float, default=1.0,
                               help='Limit training batches')
    trainer_group.add_argument('--limit-val-batches', type=float, default=1.0,
                               help='Limit validation batches')
    trainer_group.add_argument('--check-val-every-n-epoch', type=int, default=1,
                               help='Check validation every n epochs')

    wandb_group = parser.add_argument_group('wandb')
    wandb_group.add_argument('--wandb-project', type=str, default='peptune',
                             help='WandB project name')
    wandb_group.add_argument('--wandb-notes', type=str, default=None,
                             help='WandB notes')
    wandb_group.add_argument('--wandb-group', type=str, default=None,
                             help='WandB group')
    wandb_group.add_argument('--wandb-job-type', type=str, default=None,
                             help='WandB job type')
    wandb_group.add_argument('--wandb-name', type=str, default='sophia-tang',
                             help='WandB run name')
    wandb_group.add_argument('--wandb-id', type=str, default=None,
                             help='WandB run ID')

    checkpoint_group = parser.add_argument_group('checkpointing')
    checkpoint_group.add_argument('--save-dir', type=str, default=None,
                                  help='Directory to save checkpoints')
    checkpoint_group.add_argument('--resume-from-ckpt', action=argparse.BooleanOptionalAction,
                                  default=True, help='Resume from checkpoint')
    checkpoint_group.add_argument('--resume-ckpt-path', type=str, default=None,
                                  help='Path to checkpoint to resume from')
    checkpoint_group.add_argument('--checkpoint-every-n-epochs', type=int, default=1,
                                  help='Save checkpoint every n epochs')
    checkpoint_group.add_argument('--checkpoint-monitor', type=str, default='val/nll',
                                  help='Metric to monitor for checkpointing')
    checkpoint_group.add_argument('--checkpoint-save-top-k', type=int, default=10,
                                  help='Save top k checkpoints')
    checkpoint_group.add_argument('--checkpoint-mode', type=str, default='min',
                                  choices=['min', 'max'],
                                  help='Mode for checkpoint monitoring')
    checkpoint_group.add_argument('--checkpoint-dirpath', type=str,
                                  default='./checkpoints/11M-old-tokenizer',
                                  help='Directory path for checkpoints')

    scheduler_group = parser.add_argument_group('lr_scheduler')
    scheduler_group.add_argument('--lr-warmup-steps', type=int, default=2500,
                                 help='Number of warmup steps for learning rate')

    return parser
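
# A minimal sketch of exercising the parser programmatically (e.g. in tests
# or a notebook): pass an explicit argv list instead of reading sys.argv.
# The flag values below are arbitrary illustrations, not recommended settings.
#
#     parser = get_parser()
#     args = parser.parse_args(['--mode', 'train', '--lr', '1e-4'])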


def get_args():
    """Parse command-line arguments and fill in derived defaults."""
    parser = get_parser()
    args = parser.parse_args()

    if args.eval_global_batch_size is None:
        args.eval_global_batch_size = args.global_batch_size

    if args.num_workers is None:
        # os.sched_getaffinity is only available on Linux; fall back to
        # os.cpu_count() on platforms that do not provide it.
        try:
            args.num_workers = len(os.sched_getaffinity(0))
        except AttributeError:
            args.num_workers = os.cpu_count() or 1

    if args.wandb_id is None:
        args.wandb_id = f"{args.wandb_name}_nov12_set2"

    if args.save_dir is None:
        args.save_dir = os.getcwd()

    return args


if __name__ == '__main__':
    args = get_args()
    print(args)
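
# Example invocations (illustrative; the script name and data paths below are
# placeholders, not part of the repository):
#
#   python parser.py --mode train --backbone roformer \
#       --train-data /data/peptide_data --lr 1e-4
#
#   python parser.py --mode sample_eval --checkpoint-path ./checkpoints/last.ckpt \
#       --num-sequences 50 --steps 256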