#!/usr/bin/env python3
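"""Generate de novo peptide binders with MadSBM for a single target.

For each peptide length, the script samples one unconditional and one
guidance-steered sequence, scores both with ESM perplexity and a predicted
binding affinity, and writes the results to a CSV file.
"""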

import os
import random
import torch
import pandas as pd

from tqdm import tqdm
from omegaconf import OmegaConf
from datetime import datetime

from transformers import AutoTokenizer, AutoModelForMaskedLM

from src.madsbm.wt_peptide.sbm_module import MadSBM
from src.sampling.madsbm_sampler import MadSBMSampler

from src.utils.generate_utils import calc_entropy, mask_for_de_novo, calc_ppl
from src.utils.model_utils import _print


device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
os.chdir('/scratch/pranamlab/sgoel/MadSBM')
config = OmegaConf.load("./configs/wt_pep.yaml")

date = datetime.now().strftime("%Y-%m-%d")


def generate_sequence(masked_seq, target_toks, tokenizer, generator, device):
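    """Run one unconditional and one guided sampling pass over a fully masked input.

    Returns the two decoded sequences along with their predicted binding scores.
    """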
    input_ids = tokenizer(masked_seq, return_tensors="pt").to(device)['input_ids']

    uncond_ids, uncond_bind = generator.sample(xt=input_ids, num_steps=config.sampling.n_steps, target_toks=target_toks, guidance=False)
    guided_ids, guided_bind = generator.sample(xt=input_ids, num_steps=config.sampling.n_steps, target_toks=target_toks, guidance=True)
    
    uncond_seq = tokenizer.decode(uncond_ids[0].squeeze())[5:-5].replace(" ", "") # bos/eos tokens & spaces between residues
    guided_seq = tokenizer.decode(guided_ids[0].squeeze())[5:-5].replace(" ", "")
    
    return uncond_seq, guided_seq, uncond_bind, guided_bind


def main():
    csv_save_path = './results/guided/'
    os.makedirs(csv_save_path, exist_ok=True)

    # Load ESM model for eval
    esm_pth = config.model.esm_model
    esm_model = AutoModelForMaskedLM.from_pretrained(esm_pth).to(device)
    esm_model.eval()

    # Load SBM model
    gen_model = MadSBM(config)
    state_dict = gen_model.get_state_dict(config.checkpointing.best_ckpt_path)
    gen_model.load_state_dict(state_dict)
    gen_model.to(device)
    gen_model.eval()
    tokenizer = gen_model.tokenizer
    generator = MadSBMSampler(gen_model, config, device, guidance=True)


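    # Pick the target and tokenize its sequence; the known binder's predicted
    # affinity is printed as a reference point for the generated peptides.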
    tgt_name = "3HVE"
    df = pd.read_csv("./data/wt_pep/targets.csv")
    tgt_seq = df.loc[df['Target'] == tgt_name, 'Sequence'].iloc[0]
    target_toks = tokenizer(tgt_seq, return_tensors='pt')['input_ids'].to(device)

    existing_binder = df.loc[df['Target'] == tgt_name, 'Existing Binder'].iloc[0]
    existing_binder_pred = generator.peptiverse.predict_binding_affinity(
        mode='wt',
        target_ids=target_toks,
        binder_ids=tokenizer(existing_binder, return_tensors='pt')['input_ids'].to(device).detach()
    )['affinity']

    _print(f'EXISTING BINDER AFFINITY: {existing_binder_pred}')

    seq_lengths = [length for length in [10, 15, 20] for _ in range(20)]  # 20 samples per length
    generation_results = []

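    # For each length, generate one unconditional and one guided peptide, then
    # score both with ESM perplexity and predicted binding affinity.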
    for seq_len in tqdm(seq_lengths, desc="Generating sequences"):
        
        masked_seq = mask_for_de_novo(seq_len) # Sequence of all <MASK> tokens
        uncond_seq, guided_seq, uncond_bind, guided_bind = generate_sequence(masked_seq, target_toks, tokenizer, generator, device)

        uncond_ppl = calc_ppl(esm_model, tokenizer, uncond_seq, list(range(len(uncond_seq))), model_type='esm')
        guided_ppl = calc_ppl(esm_model, tokenizer, guided_seq, list(range(len(guided_seq))), model_type='esm')

        _print(f'uncond seq: {uncond_seq}')
        _print(f'uncond ppl: {uncond_ppl}')
        _print(f'uncond bind: {uncond_bind}')
        
        _print(f'guided seq: {guided_seq}')
        _print(f'guided ppl: {guided_ppl}')
        _print(f'guided bind: {guided_bind}')

        res_row = {
            "Uncond Generated Sequence": uncond_seq,
            "Guided Generated Sequence": guided_seq,
            "Uncond PPL": uncond_ppl,
            "Guided PPL": guided_ppl,
            "Uncond Affinity": uncond_bind,
            "Guided Affinity": guided_bind
        }

        generation_results.append(res_row) 


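    # Aggregate all generations and report summary statistics before saving.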
    df = pd.DataFrame(generation_results)
    
    _print(f"Uncond PPL Res: {df['Uncond PPL'].mean()}, {df['Uncond PPL'].std()}")
    _print(f"Guided PPL Res: {df['Guided PPL'].mean()}, {df['Guided PPL'].std()}")

    _print(f"Uncond Affinity Res: {df['Uncond Affinity'].mean()}, {df['Uncond Affinity'].std()}")
    _print(f"Guided Affinity Res: {df['Guided Affinity'].mean()}, {df['Guided Affinity'].std()}")
    
    # The per-target subdirectory is never created above, so build it here before writing.
    out_dir = os.path.join(csv_save_path, tgt_name)
    os.makedirs(out_dir, exist_ok=True)
    df.to_csv(
        os.path.join(out_dir, f"tau=0.5_topp=0.9_no-gumbel_rate=0.01_jump=0.05_ablate={config.model.ablate}_nsteps={config.sampling.n_steps}_seqs_with_ppl_{date}.csv"),
        index=False
    )


if __name__ == "__main__":
    main()