File size: 18,181 Bytes
40e7e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
import argparse
import os


def _add_bool_flag(group, flag, *, default, help_text):
    """Register an on/off boolean option as a ``--flag`` / ``--no-flag`` pair.

    Both options write to the same destination (``flag`` with dashes turned
    into underscores), so ``--flag`` behaves exactly like a ``store_true``
    option while ``--no-flag`` makes it possible to switch off an option
    whose default is ``True`` (a plain ``store_true`` with ``default=True``
    can never be disabled from the command line).

    Args:
        group: The parser or argument group to add the options to.
        flag: The positive option string, e.g. ``'--pin-memory'``.
        default: Value used when neither form is given.
        help_text: Help string for the positive option.
    """
    bare = flag.lstrip('-')
    dest = bare.replace('-', '_')
    group.add_argument(flag, dest=dest, action='store_true', default=default,
                       help=help_text)
    group.add_argument(f'--no-{bare}', dest=dest, action='store_false',
                       default=default, help=f'Disable {flag}')


def get_parser():
    """Build the argument parser for PepTune training and evaluation.

    Options are organised into argparse groups (noise, mcts, data, loader,
    sampling, training, eval, optim, pepclm, model, roformer, helmgpt,
    trainer, wandb, checkpointing, lr_scheduler) purely for ``--help``
    readability; every option lands in the same flat namespace.

    Boolean options that default to ``True`` also accept a ``--no-<name>``
    form so they can be turned off.

    Returns:
        argparse.ArgumentParser: the configured (unparsed) parser.
    """
    parser = argparse.ArgumentParser(description='PepTune Training and Evaluation')

    # Noise parameters
    noise_group = parser.add_argument_group('noise')
    noise_group.add_argument('--noise-type', type=str, default='loglinear',
                             help='Type of noise schedule')
    noise_group.add_argument('--sigma-min', type=float, default=1e-4,
                             help='Minimum sigma value')
    noise_group.add_argument('--sigma-max', type=float, default=20,
                             help='Maximum sigma value')
    _add_bool_flag(noise_group, '--state-dependent', default=True,
                   help_text='Use state-dependent noise')

    # Base parameters
    parser.add_argument('--base-path', type=str, default='/path/to/PepTune',
                        help='Base path to PepTune')
    parser.add_argument('--mode', type=str, default='ppl_eval',
                        choices=['train', 'ppl_eval', 'sample_eval'],
                        help='Running mode')
    parser.add_argument('--diffusion', type=str, default='absorbing_state',
                        help='Diffusion type')
    parser.add_argument('--vocab', type=str, default='old_smiles',
                        choices=['old_smiles', 'new_smiles', 'selfies', 'helm'],
                        help='Vocabulary type')
    parser.add_argument('--backbone', type=str, default='roformer',
                        choices=['peptideclm', 'helmgpt', 'dit', 'roformer', 'finetune_roformer'],
                        help='Model backbone')
    parser.add_argument('--parameterization', type=str, default='subs',
                        help='Parameterization type')
    parser.add_argument('--time-conditioning', action='store_true', default=False,
                        help='Use time conditioning')
    parser.add_argument('--T', type=int, default=0,
                        help='Number of diffusion steps (0 for continuous time, 1000 for discrete)')
    parser.add_argument('--subs-masking', action='store_true', default=False,
                        help='Use substitution masking')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')

    # MCTS parameters
    mcts_group = parser.add_argument_group('mcts')
    mcts_group.add_argument('--mcts-num-children', type=int, default=50,
                            help='Number of children in MCTS')
    mcts_group.add_argument('--mcts-num-objectives', type=int, default=5,
                            help='Number of objectives in MCTS')
    mcts_group.add_argument('--mcts-topk', type=int, default=100,
                            help='Top-k for MCTS')
    mcts_group.add_argument('--mcts-mask-token', type=int, default=4,
                            help='Mask token ID')
    mcts_group.add_argument('--mcts-num-iter', type=int, default=128,
                            help='Number of MCTS iterations')
    mcts_group.add_argument('--mcts-sampling', type=int, default=0,
                            help='Sampling strategy (0 for gumbel, >0 for top-k)')
    mcts_group.add_argument('--mcts-invalid-penalty', type=float, default=0.5,
                            help='Penalty for invalid sequences')
    mcts_group.add_argument('--mcts-sample-prob', type=float, default=1.0,
                            help='Sampling probability')
    _add_bool_flag(mcts_group, '--mcts-perm', default=True,
                   help_text='Use permutation in MCTS')
    mcts_group.add_argument('--mcts-dual', action='store_true', default=False,
                            help='Use dual mode')
    mcts_group.add_argument('--mcts-single', action='store_true', default=False,
                            help='Use single mode')
    _add_bool_flag(mcts_group, '--mcts-time-dependent', default=True,
                   help_text='Use time-dependent MCTS')

    # Data parameters
    data_group = parser.add_argument_group('data')
    data_group.add_argument('--train-data', type=str,
                            default='/path/to/your/home/PepTune/data/peptide_data',
                            help='Path to training data')
    data_group.add_argument('--valid-data', type=str,
                            default='/path/to/your/home/PepTune/data/peptide_data',
                            help='Path to validation data')
    data_group.add_argument('--data-batching', type=str, default='wrapping',
                            choices=['padding', 'wrapping'],
                            help='Batching strategy')

    # Loader parameters
    loader_group = parser.add_argument_group('loader')
    loader_group.add_argument('--global-batch-size', type=int, default=64,
                              help='Global batch size')
    # None is resolved to --global-batch-size by get_args().
    loader_group.add_argument('--eval-global-batch-size', type=int, default=None,
                              help='Evaluation global batch size (defaults to global-batch-size)')
    # None is resolved to the usable CPU count by get_args().
    loader_group.add_argument('--num-workers', type=int, default=None,
                              help='Number of dataloader workers (defaults to available CPUs)')
    _add_bool_flag(loader_group, '--pin-memory', default=True,
                   help_text='Pin memory for dataloaders')

    # Sampling parameters
    sampling_group = parser.add_argument_group('sampling')
    sampling_group.add_argument('--predictor', type=str, default='ddpm_cache',
                                choices=['analytic', 'ddpm', 'ddpm_cache'],
                                help='Predictor type for sampling')
    sampling_group.add_argument('--num-sequences', type=int, default=100,
                                help='Number of sequences to generate')
    sampling_group.add_argument('--sampling-eps', type=float, default=1e-3,
                                help='Sampling epsilon')
    sampling_group.add_argument('--steps', type=int, default=128,
                                help='Number of sampling steps')
    sampling_group.add_argument('--seq-length', type=int, default=100,
                                help='Sequence length')
    _add_bool_flag(sampling_group, '--noise-removal', default=True,
                   help_text='Use noise removal')
    sampling_group.add_argument('--num-sample-batches', type=int, default=2,
                                help='Number of sample batches')
    sampling_group.add_argument('--num-sample-log', type=int, default=2,
                                help='Number of samples to log')
    sampling_group.add_argument('--stride-length', type=int, default=1,
                                help='Stride length for sampling')
    sampling_group.add_argument('--num-strides', type=int, default=1,
                                help='Number of strides')

    # Training parameters
    training_group = parser.add_argument_group('training')
    _add_bool_flag(training_group, '--antithetic-sampling', default=True,
                   help_text='Use antithetic sampling')
    training_group.add_argument('--training-sampling-eps', type=float, default=1e-3,
                                help='Training sampling epsilon')
    training_group.add_argument('--focus-mask', action='store_true', default=False,
                                help='Use focus mask')
    training_group.add_argument('--accumulator', action='store_true', default=False,
                                help='Use accumulator')

    # Evaluation parameters
    eval_group = parser.add_argument_group('eval')
    eval_group.add_argument('--checkpoint-path', type=str, default=None,
                            help='Path to checkpoint for evaluation')
    eval_group.add_argument('--disable-ema', action='store_true', default=False,
                            help='Disable EMA')
    eval_group.add_argument('--compute-generative-perplexity', action='store_true', default=False,
                            help='Compute generative perplexity')
    eval_group.add_argument('--perplexity-batch-size', type=int, default=8,
                            help='Batch size for perplexity computation')
    eval_group.add_argument('--compute-perplexity-on-sanity', action='store_true', default=False,
                            help='Compute perplexity on sanity check')
    eval_group.add_argument('--gen-ppl-eval-model', type=str, default='gpt2-large',
                            help='Model for generative perplexity evaluation')
    _add_bool_flag(eval_group, '--generate-samples', default=True,
                   help_text='Generate samples during evaluation')
    eval_group.add_argument('--generation-model', type=str, default=None,
                            help='Model for generation')

    # Optimizer parameters
    optim_group = parser.add_argument_group('optim')
    optim_group.add_argument('--weight-decay', type=float, default=0.075,
                             help='Weight decay')
    optim_group.add_argument('--lr', type=float, default=3e-4,
                             help='Learning rate')
    optim_group.add_argument('--beta1', type=float, default=0.9,
                             help='Adam beta1')
    optim_group.add_argument('--beta2', type=float, default=0.999,
                             help='Adam beta2')
    optim_group.add_argument('--eps', type=float, default=1e-8,
                             help='Adam epsilon')

    # PepCLM model parameters
    pepclm_group = parser.add_argument_group('pepclm')
    # NOTE(review): the default hidden size (768) is not divisible by the
    # default head count (20); confirm whether the PepCLM backbone requires
    # hidden_size % n_heads == 0.
    pepclm_group.add_argument('--pepclm-hidden-size', type=int, default=768,
                              help='PepCLM hidden size')
    pepclm_group.add_argument('--pepclm-cond-dim', type=int, default=256,
                              help='PepCLM conditioning dimension')
    pepclm_group.add_argument('--pepclm-n-heads', type=int, default=20,
                              help='PepCLM number of attention heads')
    pepclm_group.add_argument('--pepclm-n-blocks', type=int, default=4,
                              help='PepCLM number of blocks')
    pepclm_group.add_argument('--pepclm-dropout', type=float, default=0.5,
                              help='PepCLM dropout rate')
    pepclm_group.add_argument('--pepclm-length', type=int, default=512,
                              help='PepCLM sequence length')

    # General model parameters
    model_group = parser.add_argument_group('model')
    model_group.add_argument('--model-type', type=str, default='ddit',
                             help='Model type')
    model_group.add_argument('--hidden-size', type=int, default=768,
                             help='Model hidden size')
    model_group.add_argument('--cond-dim', type=int, default=128,
                             help='Conditioning dimension')
    model_group.add_argument('--length', type=int, default=512,
                             help='Sequence length')
    model_group.add_argument('--n-blocks', type=int, default=12,
                             help='Number of blocks')
    model_group.add_argument('--n-heads', type=int, default=12,
                             help='Number of attention heads')
    _add_bool_flag(model_group, '--scale-by-sigma', default=True,
                   help_text='Scale by sigma')
    model_group.add_argument('--dropout', type=float, default=0.1,
                             help='Dropout rate')

    # RoFormer parameters
    roformer_group = parser.add_argument_group('roformer')
    roformer_group.add_argument('--roformer-hidden-size', type=int, default=768,
                                help='RoFormer hidden size')
    roformer_group.add_argument('--roformer-n-layers', type=int, default=8,
                                help='RoFormer number of layers')
    roformer_group.add_argument('--roformer-n-heads', type=int, default=8,
                                help='RoFormer number of attention heads')
    roformer_group.add_argument('--roformer-max-position-embeddings', type=int, default=1035,
                                help='RoFormer max position embeddings')

    # HelmGPT parameters
    helmgpt_group = parser.add_argument_group('helmgpt')
    helmgpt_group.add_argument('--helmgpt-hidden-size', type=int, default=256,
                               help='HelmGPT hidden size')
    helmgpt_group.add_argument('--helmgpt-embd-pdrop', type=float, default=0.1,
                               help='HelmGPT embedding dropout')
    helmgpt_group.add_argument('--helmgpt-resid-pdrop', type=float, default=0.1,
                               help='HelmGPT residual dropout')
    helmgpt_group.add_argument('--helmgpt-attn-pdrop', type=float, default=0.1,
                               help='HelmGPT attention dropout')
    helmgpt_group.add_argument('--helmgpt-ff-dropout', type=float, default=0.0,
                               help='HelmGPT feedforward dropout')
    helmgpt_group.add_argument('--helmgpt-block-size', type=int, default=140,
                               help='HelmGPT block size')
    helmgpt_group.add_argument('--helmgpt-n-layer', type=int, default=8,
                               help='HelmGPT number of layers')
    helmgpt_group.add_argument('--helmgpt-n-heads', type=int, default=8,
                               help='HelmGPT number of attention heads')

    # Trainer parameters
    trainer_group = parser.add_argument_group('trainer')
    trainer_group.add_argument('--accelerator', type=str, default='cuda',
                               help='Accelerator type')
    trainer_group.add_argument('--num-nodes', type=int, default=1,
                               help='Number of nodes')
    trainer_group.add_argument('--devices', type=int, default=1,
                               help='Number of devices')
    trainer_group.add_argument('--gradient-clip-val', type=float, default=1.0,
                               help='Gradient clipping value')
    trainer_group.add_argument('--precision', type=str, default='64-true',
                               help='Training precision')
    trainer_group.add_argument('--num-sanity-val-steps', type=int, default=2,
                               help='Number of sanity validation steps')
    trainer_group.add_argument('--max-epochs', type=int, default=100,
                               help='Maximum number of epochs')
    trainer_group.add_argument('--max-steps', type=int, default=1_000_000,
                               help='Maximum number of steps')
    trainer_group.add_argument('--log-every-n-steps', type=int, default=10,
                               help='Log every n steps')
    trainer_group.add_argument('--limit-train-batches', type=float, default=1.0,
                               help='Limit training batches')
    trainer_group.add_argument('--limit-val-batches', type=float, default=1.0,
                               help='Limit validation batches')
    trainer_group.add_argument('--check-val-every-n-epoch', type=int, default=1,
                               help='Check validation every n epochs')

    # WandB parameters
    wandb_group = parser.add_argument_group('wandb')
    wandb_group.add_argument('--wandb-project', type=str, default='peptune',
                             help='WandB project name')
    wandb_group.add_argument('--wandb-notes', type=str, default=None,
                             help='WandB notes')
    wandb_group.add_argument('--wandb-group', type=str, default=None,
                             help='WandB group')
    wandb_group.add_argument('--wandb-job-type', type=str, default=None,
                             help='WandB job type')
    wandb_group.add_argument('--wandb-name', type=str, default='sophia-tang',
                             help='WandB run name')
    # None is resolved from --wandb-name by get_args().
    wandb_group.add_argument('--wandb-id', type=str, default=None,
                             help='WandB run ID')

    # Checkpointing parameters
    checkpoint_group = parser.add_argument_group('checkpointing')
    # None is resolved to the current working directory by get_args().
    checkpoint_group.add_argument('--save-dir', type=str, default=None,
                                  help='Directory to save checkpoints')
    _add_bool_flag(checkpoint_group, '--resume-from-ckpt', default=True,
                   help_text='Resume from checkpoint')
    checkpoint_group.add_argument('--resume-ckpt-path', type=str, default=None,
                                  help='Path to checkpoint to resume from')
    checkpoint_group.add_argument('--checkpoint-every-n-epochs', type=int, default=1,
                                  help='Save checkpoint every n epochs')
    checkpoint_group.add_argument('--checkpoint-monitor', type=str, default='val/nll',
                                  help='Metric to monitor for checkpointing')
    checkpoint_group.add_argument('--checkpoint-save-top-k', type=int, default=10,
                                  help='Save top k checkpoints')
    checkpoint_group.add_argument('--checkpoint-mode', type=str, default='min',
                                  choices=['min', 'max'],
                                  help='Mode for checkpoint monitoring')
    checkpoint_group.add_argument('--checkpoint-dirpath', type=str,
                                  default='./checkpoints/11M-old-tokenizer',
                                  help='Directory path for checkpoints')

    # LR Scheduler parameters
    scheduler_group = parser.add_argument_group('lr_scheduler')
    scheduler_group.add_argument('--lr-warmup-steps', type=int, default=2500,
                                 help='Number of warmup steps for learning rate')

    return parser


def get_args(argv=None):
    """Parse command-line arguments and fill in derived defaults.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace from ``get_parser()``, with these ``None``
        defaults resolved:

        - ``eval_global_batch_size`` -> ``global_batch_size``
        - ``num_workers`` -> number of CPUs usable by this process
        - ``wandb_id`` -> ``f"{wandb_name}_nov12_set2"``
        - ``save_dir`` -> current working directory
    """
    parser = get_parser()
    args = parser.parse_args(argv)

    # Post-process arguments
    if args.eval_global_batch_size is None:
        args.eval_global_batch_size = args.global_batch_size

    if args.num_workers is None:
        # os.sched_getaffinity only exists on some Unix platforms (e.g. Linux);
        # elsewhere fall back to the total CPU count, which may itself be None.
        try:
            args.num_workers = len(os.sched_getaffinity(0))
        except AttributeError:
            args.num_workers = os.cpu_count() or 1

    if args.wandb_id is None:
        args.wandb_id = f"{args.wandb_name}_nov12_set2"

    if args.save_dir is None:
        args.save_dir = os.getcwd()

    return args


if __name__ == '__main__':
    # Script entry point: parse the command line (with derived defaults
    # filled in) and echo the resulting namespace for inspection.
    parsed = get_args()
    print(parsed)