#!/usr/bin/env python3
import os
import random
import torch
import pandas as pd
from tqdm import tqdm
from omegaconf import OmegaConf
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForMaskedLM

from src.madsbm.wt_peptide.sbm_module import MadSBM
from src.sampling.madsbm_sampler import MadSBMSampler
from src.utils.generate_utils import calc_entropy, mask_for_de_novo, calc_ppl
from src.utils.model_utils import _print

device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
os.chdir('/scratch/pranamlab/sgoel/MadSBM')
config = OmegaConf.load("./configs/wt_pep.yaml")
date = datetime.now().strftime("%Y-%m-%d")


def generate_sequence(masked_seq, target_toks, tokenizer, generator, device):
    """Sample an unconditional and a guided peptide from the same masked input."""
    input_ids = tokenizer(masked_seq, return_tensors="pt").to(device)['input_ids']

    uncond_ids, uncond_bind = generator.sample(
        xt=input_ids, num_steps=config.sampling.n_steps,
        target_toks=target_toks, guidance=False
    )
    guided_ids, guided_bind = generator.sample(
        xt=input_ids, num_steps=config.sampling.n_steps,
        target_toks=target_toks, guidance=True
    )

    # Strip bos/eos tokens and the spaces the tokenizer inserts between residues
    uncond_seq = tokenizer.decode(uncond_ids[0].squeeze())[5:-5].replace(" ", "")
    guided_seq = tokenizer.decode(guided_ids[0].squeeze())[5:-5].replace(" ", "")

    return uncond_seq, guided_seq, uncond_bind, guided_bind


def main():
    csv_save_path = './results/guided/'
    os.makedirs(csv_save_path, exist_ok=True)

    # Load ESM model for PPL evaluation
    esm_pth = config.model.esm_model
    esm_model = AutoModelForMaskedLM.from_pretrained(esm_pth).to(device)
    esm_model.eval()

    # Load SBM generative model from the best checkpoint
    gen_model = MadSBM(config)
    state_dict = gen_model.get_state_dict(config.checkpointing.best_ckpt_path)
    gen_model.load_state_dict(state_dict)
    gen_model.to(device)
    gen_model.eval()
    tokenizer = gen_model.tokenizer

    generator = MadSBMSampler(gen_model, config, device, guidance=True)

    # Target sequence and its known binder, used as an affinity reference point
    tgt_name = "3HVE"
    targets_df = pd.read_csv("./data/wt_pep/targets.csv")
    tgt_seq = targets_df.loc[targets_df['Target'] == tgt_name, 'Sequence'].iloc[0]
    target_toks = tokenizer(tgt_seq, return_tensors='pt')['input_ids'].to(device)

    existing_binder = targets_df.loc[targets_df['Target'] == tgt_name, 'Existing Binder'].iloc[0]
    existing_binder_pred = generator.peptiverse.predict_binding_affinity(
        mode='wt',
        target_ids=target_toks,
        binder_ids=tokenizer(existing_binder, return_tensors='pt')['input_ids'].to(device).detach()
    )['affinity']
    _print(f'EXISTING BINDER AFFINITY: {existing_binder_pred}')

    # 20 samples at each of three peptide lengths
    seq_lengths = [length for length in [10, 15, 20] for _ in range(20)]
    generation_results = []

    for seq_len in tqdm(seq_lengths, desc="Generating sequences"):
        masked_seq = mask_for_de_novo(seq_len)  # sequence of seq_len mask tokens
        uncond_seq, guided_seq, uncond_bind, guided_bind = generate_sequence(
            masked_seq, target_toks, tokenizer, generator, device
        )

        uncond_ppl = calc_ppl(esm_model, tokenizer, uncond_seq, list(range(len(uncond_seq))), model_type='esm')
        guided_ppl = calc_ppl(esm_model, tokenizer, guided_seq, list(range(len(guided_seq))), model_type='esm')

        _print(f'uncond seq: {uncond_seq}')
        _print(f'uncond ppl: {uncond_ppl}')
        _print(f'uncond bind: {uncond_bind}')
        _print(f'guided seq: {guided_seq}')
        _print(f'guided ppl: {guided_ppl}')
        _print(f'guided bind: {guided_bind}')

        generation_results.append({
            "Uncond Generated Sequence": uncond_seq,
            "Guided Generated Sequence": guided_seq,
            "Uncond PPL": uncond_ppl,
            "Guided PPL": guided_ppl,
            "Uncond Affinity": uncond_bind,
            "Guided Affinity": guided_bind
        })

    results_df = pd.DataFrame(generation_results)
    _print(f"Uncond PPL Res: {results_df['Uncond PPL'].mean()}, {results_df['Uncond PPL'].std()}")
    _print(f"Guided PPL Res: {results_df['Guided PPL'].mean()}, {results_df['Guided PPL'].std()}")
    _print(f"Uncond Affinity Res: {results_df['Uncond Affinity'].mean()}, {results_df['Uncond Affinity'].std()}")
    _print(f"Guided Affinity Res: {results_df['Guided Affinity'].mean()}, {results_df['Guided Affinity'].std()}")

    # Ensure the per-target output directory exists before writing the CSV
    out_dir = os.path.join(csv_save_path, tgt_name)
    os.makedirs(out_dir, exist_ok=True)
    results_df.to_csv(
        os.path.join(
            out_dir,
            f"tau=0.5_topp=0.9_no-gumbel_rate=0.01_jump=0.05_ablate={config.model.ablate}"
            f"_nsteps={config.sampling.n_steps}_seqs_with_ppl_{date}.csv"
        ),
        index=False
    )


if __name__ == "__main__":
    main()