# MadSBM/src/sampling/path_tracer.py
import torch
import torch.nn.functional as F


class ProbabilityPathTracer:
    """Tracks the oracle log-likelihood of a partially revealed sequence
    across sampling steps."""

    def __init__(self, oracle_model, tokenizer, device):
        self.oracle = oracle_model.to(device).eval()
        self.tokenizer = tokenizer
        self.device = device
        self.mask_id = tokenizer.mask_token_id
        # {"trace_step_<step_idx>": mean log-likelihood of revealed tokens}
        self.history = {}
    @torch.inference_mode()
    def compute_loglikeli(self, xt):
        """Mean log-likelihood of the revealed (non-mask) tokens in xt under
        the oracle, conditioned on the full partially masked sequence."""
        xt = xt.to(self.device)
        is_revealed = (xt != self.mask_id)
        if not is_revealed.any():
            return 0.0
        # Oracle (ESM-style masked LM) forward pass on the partially masked sequence
        logits = self.oracle(
            input_ids=xt,
            attention_mask=torch.ones_like(xt)
        ).logits
        # Per-token cross-entropy against the current tokens; masked positions
        # are excluded from the average below.
        nll = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            xt.view(-1),
            reduction='none'
        ).view(xt.shape)
        # Negate NLL to get log-likelihood and average over revealed positions only.
        avg_ll = -(nll * is_revealed.float()).sum(dim=1) / is_revealed.float().sum(dim=1).clamp(min=1)
        # Average over the batch so a single scalar is logged per step.
        return avg_ll.mean().item()
    def log_step(self, xt, step_idx):
        """Record the oracle log-likelihood of xt at the given sampling step."""
        score = self.compute_loglikeli(xt)
        self.history[f"trace_step_{step_idx}"] = score

    def get_trace(self):
        """Return the accumulated {step: log-likelihood} trace."""
        return self.history
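

# --- Usage sketch (illustrative, not part of the module) --------------------
# A minimal example of how the tracer might be driven from a masked-sampling
# loop. The ESM-2 checkpoint name and the token-revealing loop below are
# placeholders chosen for illustration, not APIs defined in this repository.
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Any masked LM with a mask token can serve as the oracle; a small public
    # ESM-2 checkpoint is used here purely as an example.
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    oracle = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

    tracer = ProbabilityPathTracer(oracle, tokenizer, device)

    # Start from a fully masked sequence and reveal one dummy token per step;
    # a real sampler would instead update xt via its reverse process.
    seq_len, num_steps = 32, 4
    xt = torch.full((1, seq_len), tokenizer.mask_token_id, device=device)
    for step in range(num_steps):
        xt[0, step] = tokenizer.convert_tokens_to_ids("A")
        tracer.log_step(xt, step)

    print(tracer.get_trace())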