#!/usr/bin/env python3
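"""Guided peptide generation with MadSBM.

Samples unconditional and guidance-steered peptide sequences against a fixed
target (3HVE here), scores each design with an ESM-based perplexity and a
predicted binding affinity, and writes the results to a CSV. Typically run
from the repository root, e.g. ``python -m src.sampling.guided_sample``
(assumed from the absolute ``src.*`` imports below).
"""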
import os
import torch
import pandas as pd
from tqdm import tqdm
from omegaconf import OmegaConf
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForMaskedLM
from src.madsbm.wt_peptide.sbm_module import MadSBM
from src.sampling.madsbm_sampler import MadSBMSampler
from src.utils.generate_utils import calc_entropy, mask_for_de_novo, calc_ppl
from src.utils.model_utils import _print
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
os.chdir('/scratch/pranamlab/sgoel/MadSBM')
config = OmegaConf.load("./configs/wt_pep.yaml")
date = datetime.now().strftime("%Y-%m-%d")
def generate_sequence(masked_seq, target_toks, tokenizer, generator, device):
    """Sample one unconditional and one guided sequence from the same fully masked input."""
    input_ids = tokenizer(masked_seq, return_tensors="pt").to(device)['input_ids']
    uncond_ids, uncond_bind = generator.sample(xt=input_ids, num_steps=config.sampling.n_steps, target_toks=target_toks, guidance=False)
    guided_ids, guided_bind = generator.sample(xt=input_ids, num_steps=config.sampling.n_steps, target_toks=target_toks, guidance=True)
    uncond_seq = tokenizer.decode(uncond_ids[0].squeeze())[5:-5].replace(" ", "")  # strip bos/eos tokens & spaces between residues
    guided_seq = tokenizer.decode(guided_ids[0].squeeze())[5:-5].replace(" ", "")
    return uncond_seq, guided_seq, uncond_bind, guided_bind
def main():
    csv_save_path = './results/guided/'
    os.makedirs(csv_save_path, exist_ok=True)
    # Load ESM model for eval
    esm_pth = config.model.esm_model
    esm_model = AutoModelForMaskedLM.from_pretrained(esm_pth).to(device)
    esm_model.eval()

    # Load SBM model
    gen_model = MadSBM(config)
    state_dict = gen_model.get_state_dict(config.checkpointing.best_ckpt_path)
    gen_model.load_state_dict(state_dict)
    gen_model.to(device)
    gen_model.eval()
    tokenizer = gen_model.tokenizer
    generator = MadSBMSampler(gen_model, config, device, guidance=True)
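
    # Target table: ./data/wt_pep/targets.csv must provide 'Target', 'Sequence',
    # and 'Existing Binder' columns (these are the only fields read below).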
    tgt_name = "3HVE"
    df = pd.read_csv("./data/wt_pep/targets.csv")
    tgt_seq = df.loc[df['Target'] == tgt_name, 'Sequence'].iloc[0]
    target_toks = tokenizer(tgt_seq, return_tensors='pt')['input_ids'].to(device)

    existing_binder = df.loc[df['Target'] == tgt_name, 'Existing Binder'].iloc[0]
    existing_binder_pred = generator.peptiverse.predict_binding_affinity(
        mode='wt',
        target_ids=target_toks,
        binder_ids=tokenizer(existing_binder, return_tensors='pt')['input_ids'].to(device).detach()
    )['affinity']
    _print(f'EXISTING BINDER AFFINITY: {existing_binder_pred}')
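
    # The known binder's predicted affinity serves as a reference point for the
    # guided and unconditional peptides generated below.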
    seq_lengths = [length for length in [10, 15, 20] for _ in range(20)]  # 20 samples per length
    generation_results = []
    for seq_len in tqdm(seq_lengths, desc="Generating sequences"):
        masked_seq = mask_for_de_novo(seq_len)  # sequence of all <MASK> tokens
        uncond_seq, guided_seq, uncond_bind, guided_bind = generate_sequence(masked_seq, target_toks, tokenizer, generator, device)
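        # calc_ppl receives the full list of position indices, so perplexity is
        # computed over every residue of the generated sequence.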
        uncond_ppl = calc_ppl(esm_model, tokenizer, uncond_seq, list(range(len(uncond_seq))), model_type='esm')
        guided_ppl = calc_ppl(esm_model, tokenizer, guided_seq, list(range(len(guided_seq))), model_type='esm')
        _print(f'uncond seq: {uncond_seq}')
        _print(f'uncond ppl: {uncond_ppl}')
        _print(f'uncond bind: {uncond_bind}')
        _print(f'guided seq: {guided_seq}')
        _print(f'guided ppl: {guided_ppl}')
        _print(f'guided bind: {guided_bind}')

        res_row = {
            "Uncond Generated Sequence": uncond_seq,
            "Guided Generated Sequence": guided_seq,
            "Uncond PPL": uncond_ppl,
            "Guided PPL": guided_ppl,
            "Uncond Affinity": uncond_bind,
            "Guided Affinity": guided_bind
        }
        generation_results.append(res_row)
    results_df = pd.DataFrame(generation_results)
    _print(f"Uncond PPL Res: {results_df['Uncond PPL'].mean()}, {results_df['Uncond PPL'].std()}")
    _print(f"Guided PPL Res: {results_df['Guided PPL'].mean()}, {results_df['Guided PPL'].std()}")
    _print(f"Uncond Affinity Res: {results_df['Uncond Affinity'].mean()}, {results_df['Uncond Affinity'].std()}")
    _print(f"Guided Affinity Res: {results_df['Guided Affinity'].mean()}, {results_df['Guided Affinity'].std()}")

    # Save per-target results; create the target subdirectory if it does not exist
    os.makedirs(os.path.join(csv_save_path, tgt_name), exist_ok=True)
    results_df.to_csv(
        os.path.join(
            csv_save_path,
            tgt_name,
            f"tau=0.5_topp=0.9_no-gumbel_rate=0.01_jump=0.05_ablate={config.model.ablate}_nsteps={config.sampling.n_steps}_seqs_with_ppl_{date}.csv"
        ),
        index=False
    )
if __name__ == "__main__":
    main()