|
|
|
|
|
|
|
|
import os |
|
|
import random |
|
|
import torch |
|
|
import pandas as pd |
|
|
|
|
|
from tqdm import tqdm |
|
|
from omegaconf import OmegaConf |
|
|
from datetime import datetime |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForMaskedLM |
|
|
|
|
|
from src.madsbm.wt_peptide.sbm_module import MadSBM |
|
|
from src.sampling.madsbm_sampler import MadSBMSampler |
|
|
|
|
|
from src.utils.generate_utils import calc_entropy, mask_for_de_novo, calc_ppl |
|
|
from src.utils.model_utils import _print |
|
|
|
|
|
|
|
|
# Pin computation to GPU 3 when CUDA is available, else fall back to CPU.
# NOTE(review): the device index is hard-coded — confirm it matches the
# intended GPU on the target machine.
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')


# Switch to the project root so all relative paths below (configs, data,
# results) resolve correctly regardless of where the script is launched from.
os.chdir('/scratch/pranamlab/sgoel/MadSBM')


# Experiment configuration: model paths, checkpoint location, and sampling
# hyperparameters (e.g. config.sampling.n_steps) used throughout this script.
config = OmegaConf.load("./configs/wt_pep.yaml")


# Date stamp (YYYY-MM-DD) appended to the output CSV filename.
date = datetime.now().strftime("%Y-%m-%d")
|
|
|
|
|
|
|
|
def generate_sequence(masked_seq, target_toks, tokenizer, generator, device):
    """Sample one unconditional and one guided peptide from a masked template.

    Tokenizes *masked_seq*, runs the sampler twice (guidance off, then on)
    with the step count from the module-level ``config``, and decodes both
    samples back to plain amino-acid strings.

    Returns:
        (uncond_seq, guided_seq, uncond_bind, guided_bind) where the first
        two are decoded sequence strings and the latter two are the binding
        scores emitted by ``generator.sample``.
    """
    xt = tokenizer(masked_seq, return_tensors="pt").to(device)['input_ids']

    def _to_sequence(token_ids):
        # Drop the first and last 5 characters of the decoded string
        # (special start/end markers) and strip inter-token spaces.
        # NOTE(review): assumes fixed-width special tokens — confirm for
        # this tokenizer.
        return tokenizer.decode(token_ids[0].squeeze())[5:-5].replace(" ", "")

    outputs = {}
    for use_guidance in (False, True):
        sampled_ids, bind_score = generator.sample(
            xt=xt,
            num_steps=config.sampling.n_steps,
            target_toks=target_toks,
            guidance=use_guidance,
        )
        outputs[use_guidance] = (_to_sequence(sampled_ids), bind_score)

    uncond_seq, uncond_bind = outputs[False]
    guided_seq, guided_bind = outputs[True]
    return uncond_seq, guided_seq, uncond_bind, guided_bind
|
|
|
|
|
|
|
|
def main():
    """Generate unconditional and guided peptide binders for one target,
    score them (ESM pseudo-perplexity and predicted binding affinity), and
    write the per-sequence results plus summary stats to a date-stamped CSV.
    """
    tgt_name = "3HVE"

    # Output directory: results/guided/<target>/.
    # BUGFIX: the final CSV path includes the <target> subdirectory, but the
    # original code only created ./results/guided/, so to_csv raised
    # FileNotFoundError. Create the full directory up front instead.
    csv_save_path = os.path.join('./results/guided', tgt_name)
    os.makedirs(csv_save_path, exist_ok=True)

    # Frozen ESM masked-LM, used only for pseudo-perplexity scoring.
    esm_model = AutoModelForMaskedLM.from_pretrained(config.model.esm_model).to(device)
    esm_model.eval()

    # Generative model: load the best checkpoint and build the guided sampler.
    gen_model = MadSBM(config)
    state_dict = gen_model.get_state_dict(config.checkpointing.best_ckpt_path)
    gen_model.load_state_dict(state_dict)
    gen_model.to(device)
    gen_model.eval()
    tokenizer = gen_model.tokenizer
    generator = MadSBMSampler(gen_model, config, device, guidance=True)

    # Target sequence and its known binder (reference affinity baseline).
    targets_df = pd.read_csv("./data/wt_pep/targets.csv")
    tgt_seq = targets_df.loc[targets_df['Target'] == tgt_name, 'Sequence'].iloc[0]
    target_toks = tokenizer(tgt_seq, return_tensors='pt')['input_ids'].to(device)

    existing_binder = targets_df.loc[targets_df['Target'] == tgt_name, 'Existing Binder'].iloc[0]
    existing_binder_pred = generator.peptiverse.predict_binding_affinity(
        mode='wt',
        target_ids=target_toks,
        binder_ids=tokenizer(existing_binder, return_tensors='pt')['input_ids'].to(device).detach()
    )['affinity']
    _print(f'EXISTING BINDER AFFINITY: {existing_binder_pred}')

    # 20 samples at each of three peptide lengths.
    seq_lengths = [length for length in [10, 15, 20] for _ in range(20)]
    generation_results = []

    for seq_len in tqdm(seq_lengths, desc="Generating sequences: "):
        masked_seq = mask_for_de_novo(seq_len)
        uncond_seq, guided_seq, uncond_bind, guided_bind = generate_sequence(
            masked_seq, target_toks, tokenizer, generator, device
        )

        # Pseudo-perplexity over every position of each generated sequence.
        uncond_ppl = calc_ppl(esm_model, tokenizer, uncond_seq, list(range(len(uncond_seq))), model_type='esm')
        # BUGFIX: the original scored uncond_seq here too, so the
        # "Guided PPL" column silently duplicated the unconditional value.
        guided_ppl = calc_ppl(esm_model, tokenizer, guided_seq, list(range(len(guided_seq))), model_type='esm')

        _print(f'uncond seq: {uncond_seq}')
        _print(f'uncond ppl: {uncond_ppl}')
        _print(f'uncond bind: {uncond_bind}')

        _print(f'guided seq: {guided_seq}')
        _print(f'guided ppl: {guided_ppl}')
        _print(f'guided bind: {guided_bind}')

        generation_results.append({
            "Uncond Generated Sequence": uncond_seq,
            "Guided Generated Sequence": guided_seq,
            "Uncond PPL": uncond_ppl,
            "Guided PPL": guided_ppl,
            "Uncond Affinity": uncond_bind,
            "Guided Affinity": guided_bind
        })

    # Summary statistics across all generated sequences.
    results_df = pd.DataFrame(generation_results)

    _print(f"Uncond PPL Res: {results_df['Uncond PPL'].mean()}, {results_df['Uncond PPL'].std()}")
    _print(f"Guided PPL Res: {results_df['Guided PPL'].mean()}, {results_df['Guided PPL'].std()}")

    _print(f"Uncond Affinity Res: {results_df['Uncond Affinity'].mean()}, {results_df['Uncond Affinity'].std()}")
    _print(f"Guided Affinity Res: {results_df['Guided Affinity'].mean()}, {results_df['Guided Affinity'].std()}")

    results_df.to_csv(
        os.path.join(
            csv_save_path,
            f"tau=0.5_topp=0.9_no-gumbel_rate=0.01_jump=0.05_ablate={config.model.ablate}_nsteps={config.sampling.n_steps}_seqs_with_ppl_{date}.csv"
        ),
        index=False
    )
|
|
|
|
|
|
|
|
# Entry point: run the full generation + scoring pipeline when executed
# as a script (no-op when this module is imported).
if __name__ == "__main__":
    main()