import math
import sys
from collections import Counter

import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf

# Project config; provides the sampling settings (e.g. `sampling.n_steps`) used below.
config = OmegaConf.load("/scratch/pranamlab/sgoel/MeMDLM_v2/src/configs/lm.yaml")

def mask_for_de_novo(sequence_length):
    """Return a fully masked sequence of the given length for de novo generation."""
    return "<mask>" * sequence_length


def mask_for_scaffold(sequence, generate_type, mask_token):
    """Mask one case class of a mixed-case scaffold sequence.

    generate_type == "uppercase": mask the uppercase residues and uppercase the rest.
    generate_type == "lowercase": mask the lowercase residues and leave the rest as-is.
    """
    if generate_type == "uppercase":
        sequence = ''.join([mask_token if residue.isupper() else residue.upper() for residue in sequence])
    elif generate_type == "lowercase":
        sequence = ''.join([mask_token if residue.islower() else residue for residue in sequence])
    return sequence
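
# Illustrative examples (hypothetical inputs, not from a real scaffold):
#   mask_for_de_novo(3)                                   -> "<mask><mask><mask>"
#   mask_for_scaffold("peGFLGpk", "uppercase", "<mask>")  -> "PE<mask><mask><mask><mask>PK"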


def evodiff_infill(motif_seq, tokenizer, model, device, batch_size=1):
    """Infill the lowercase positions of a motif sequence with EvoDiff (OA-DM).

    Follows the official EvoDiff example:
    https://github.com/microsoft/evodiff/blob/main/examples/evodiff.ipynb
    """
    # Lowercase residues mark the regions to infill; "#" is EvoDiff's mask character.
    motif_seq = ''.join(["#" if aa.islower() else aa for aa in motif_seq])
    tkns = tokenizer.tokenize([motif_seq])
    sample = torch.as_tensor(tkns).to(device)

    # Collect the masked positions and decode them one at a time, in random order.
    loc = torch.arange(0, len(motif_seq)).to(device)[sample == tokenizer.mask_id].cpu().numpy()
    np.random.shuffle(loc)

    sample = sample.unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        for i in loc:
            # Placeholder timestep; the OA-DM forward pass does not use it.
            timestep = torch.tensor([0] * batch_size).to(device)
            prediction = model(sample, timestep)
            # Restrict the logits to the canonical amino acids (drop special tokens).
            p = prediction[:, i, :len(tokenizer.all_aas) - 6]
            p = F.softmax(p, dim=1)
            p_sample = torch.multinomial(p, num_samples=1)
            sample[:, i] = p_sample.squeeze()
    output = [tokenizer.untokenize(s) for s in sample]
    return output[0]
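
# Illustrative usage (a sketch; assumes the `evodiff` package, a CUDA device,
# and a hypothetical motif string whose lowercase span is to be infilled):
#   from evodiff.pretrained import OA_DM_38M
#   model, collater, tokenizer, scheme = OA_DM_38M()
#   model = model.to("cuda").eval()
#   filled = evodiff_infill("MKTAdefgLLV", tokenizer, model, "cuda")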


def dplm_infill(masked_seq, tokenizer, model, device):
    """Infill a masked sequence with DPLM's unconditional sampler."""
    from src.lm.dplm.diffusion_module import DPLM
    from src.lm.dplm.unconditional_sampler import UnconditionalSampler as DPLMUnconditionalSampler

    generator = DPLMUnconditionalSampler(tokenizer, model)
    xt = tokenizer(masked_seq, return_tensors='pt')['input_ids'].to(model.device)
    denoised_tokens = generator.sample_unconditional(xt, config.sampling.n_steps)[0].squeeze()
    # Strip whitespace, then the 5-character leading/trailing special tokens.
    generated_sequence = tokenizer.decode(denoised_tokens).replace(" ", "")[5:-5]
    return generated_sequence
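
# Illustrative usage (hypothetical; assumes a DPLM checkpoint and its tokenizer
# are loaded elsewhere in the project):
#   masked = mask_for_scaffold(scaffold, "lowercase", tokenizer.mask_token)
#   filled = dplm_infill(masked, tokenizer, model, model.device)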


def calc_progen_ppl(model, tokenizer, target, device, fp16=True):
    """Compute causal-LM perplexity for a single tokenized sequence.

    `target` is expected to be a 1D tensor of token ids (no batch dimension).
    """
    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=fp16):
            logits = model(
                input_ids=target,
                attention_mask=torch.ones_like(target)
            ).logits

        # Shift so that each position predicts the next token.
        logits = logits[:-1, ...]
        target = target[1:]

        loss = torch.nn.functional.cross_entropy(
            input=logits,
            target=target,
            reduction='mean'
        )
    return torch.exp(loss).item()
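
# Illustrative usage (hypothetical; assumes a ProGen-style causal LM whose
# tokenizer follows the `tokenizers` library API):
#   ids = torch.tensor(tokenizer.encode(sequence).ids).to(device)
#   ppl = calc_progen_ppl(model, tokenizer, ids, device)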


def calc_ppl(model, tokenizer, generated_sequence, mask_token_indices, model_type):
    """Compute masked-LM pseudo-perplexity by re-masking one position at a time.

    Only the positions in `mask_token_indices` are scored; indices refer to the
    tokenized sequence (including special tokens). `model_type` is currently unused.
    """
    total_loss = 0.0
    tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)

    for i in mask_token_indices:
        # Mask a single position and, via the labels, score only that position.
        masked_input = tensor_input.clone()
        masked_input[0, i] = tokenizer.mask_token_id

        labels = torch.full(tensor_input.shape, -100).to(model.device)
        labels[0, i] = tensor_input[0, i]

        with torch.no_grad():
            loss = model(masked_input, labels=labels).loss.item()
        total_loss += loss

    # Average over the number of scored (masked) positions.
    avg_loss = total_loss / len(mask_token_indices)
    perplexity = math.exp(avg_loss)

    return perplexity
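
# Illustrative usage (sketch; assumes an ESM-style masked LM whose tokenizer
# prepends a CLS token, hence the +1 position offset):
#   ppl = calc_ppl(model, tokenizer, seq, [i + 1 for i in range(len(seq))], "esm")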


def calc_entropy(seq):
    """Shannon entropy (in bits) of the residue distribution of a sequence."""
    counts = Counter(seq)
    total_len = len(seq)
    entropy = 0.0
    for count in counts.values():
        prob = count / total_len
        entropy -= prob * math.log2(prob)
    return entropy
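
# Examples: uniform residues give 0 bits; a balanced two-letter sequence gives 1 bit.
#   calc_entropy("AAAA")  -> 0.0
#   calc_entropy("AABB")  -> 1.0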