import torch
import torch.nn.functional as F


class ProbabilityPathTracer:
    """Track an oracle model's average log-likelihood of revealed tokens.

    At each logged step the current sequence(s) ``xt`` are scored by the
    oracle. Positions equal to the tokenizer's mask token are treated as
    unrevealed and excluded from the score. The per-step scores are
    accumulated in ``self.history``.
    """

    def __init__(self, oracle_model, tokenizer, device):
        """Store the oracle and tokenizer.

        Args:
            oracle_model: Callable invoked as
                ``oracle_model(input_ids=..., attention_mask=...)``. It must
                return an object with a ``.logits`` tensor of shape
                ``xt.shape + (vocab_size,)``.
            tokenizer: Object exposing ``mask_token_id``.
            device: Stored as ``self.device``. NOTE(review): it is not used by
                the methods in this class; inputs are sent to the oracle on
                whatever device ``xt`` already lives on.
        """
        self.oracle = oracle_model
        self.tokenizer = tokenizer
        self.device = device
        self.mask_id = tokenizer.mask_token_id
        # Maps "trace_step_{step_idx}" -> average log-likelihood (float).
        self.history = {}

    @torch.inference_mode()
    def compute_loglikeli(self, xt) -> float:
        """Return the oracle's average log-likelihood of the revealed tokens.

        For each sequence, the per-token log-likelihood of every non-mask
        position is averaged. When ``xt`` holds several sequences, the result
        is the mean of those per-sequence averages. Only sequences with at
        least one revealed token are included.

        Args:
            xt: Integer token tensor of shape (batch, seq_len). A 1-D
                (seq_len,) tensor is treated as a batch of one.

        Returns:
            The average log-likelihood as a Python float (higher is better).
            Returns 0.0 when no token in ``xt`` is revealed.
        """
        if xt.dim() == 1:
            xt = xt.unsqueeze(0)

        is_revealed = xt != self.mask_id
        if not is_revealed.any():
            return 0.0

        # Forward pass; the original comment indicates an ESM-style masked LM.
        logits = self.oracle(
            input_ids=xt,
            attention_mask=torch.ones_like(xt),
        ).logits
        # Half-precision log-softmax loses accuracy; compute the loss in fp32.
        if logits.dtype in (torch.float16, torch.bfloat16):
            logits = logits.float()

        # Per-position NLL of the observed token. reshape (not view) tolerates
        # non-contiguous logits.
        nll = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            xt.reshape(-1),
            reduction="none",
        ).reshape(xt.shape)
        # Zero out masked positions with where(), not multiplication. An
        # infinite NLL at a masked position would otherwise become NaN (inf * 0).
        nll = torch.where(is_revealed, nll, torch.zeros_like(nll))

        n_revealed = is_revealed.sum(dim=1)
        # Lower NLL = better  -->  higher LL = better.
        per_seq_ll = -nll.sum(dim=1) / n_revealed.clamp(min=1)
        # Average only over sequences that have something revealed, so that
        # fully masked rows do not pull the score towards 0. For a single
        # sequence this is exactly that sequence's score.
        return per_seq_ll[n_revealed > 0].mean().item()

    def log_step(self, xt, step_idx) -> None:
        """Score ``xt`` and store the result under ``"trace_step_{step_idx}"``."""
        score = self.compute_loglikeli(xt)
        self.history[f"trace_step_{step_idx}"] = score

    def get_trace(self) -> dict:
        """Return the recorded scores (the live ``history`` dict, not a copy)."""
        return self.history