|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class ProbabilityPathTracer:
    """Track how likely an oracle model finds the revealed tokens at each step.

    A token counts as "revealed" when it differs from the tokenizer's mask
    token id. Each call to :meth:`log_step` scores the current sequence with
    the oracle and stores the result under ``"trace_step_<idx>"``.
    """

    def __init__(self, oracle_model, tokenizer, device):
        """Store the scoring model and tokenizer and start with an empty trace.

        Args:
            oracle_model: Callable invoked as
                ``oracle_model(input_ids=..., attention_mask=...)``; its
                result must expose a ``.logits`` tensor.
            tokenizer: Object providing ``mask_token_id``.
            device: Stored as-is. The attention mask is built on the input
                tensor's own device rather than on this one.
        """
        self.oracle = oracle_model
        self.tokenizer = tokenizer
        self.device = device
        self.mask_id = tokenizer.mask_token_id
        # Maps "trace_step_<idx>" -> score returned by compute_loglikeli.
        self.history = {}

    @torch.inference_mode()
    def compute_loglikeli(self, xt: torch.Tensor) -> "float | list[float]":
        """Return the mean oracle log-likelihood of the revealed tokens.

        The oracle is run once on ``xt`` with an all-ones attention mask, the
        per-position cross entropy between its logits and ``xt`` is computed,
        and the negated values are averaged over the non-mask positions of
        each row.

        NOTE(review): logits are compared position-by-position with ``xt``
        (no shift). That presumes the oracle predicts the token *at* each
        position; confirm this if the oracle is a causal LM.

        Args:
            xt: Token ids, either ``(batch, seq)`` or a single 1-D sequence
                (treated as a batch of one).

        Returns:
            A ``float`` when ``xt`` holds a single sequence, otherwise a list
            with one ``float`` per row. Rows (or whole inputs) without any
            revealed token score ``0.0``.
        """
        # Accept an unbatched sequence; the per-row reduction needs a dim 1.
        if xt.dim() == 1:
            xt = xt.unsqueeze(0)

        batch_size = xt.size(0)
        is_revealed = xt != self.mask_id

        # Nothing to score: skip the forward pass entirely.
        if not is_revealed.any():
            return [0.0] * batch_size if batch_size > 1 else 0.0

        logits = self.oracle(
            input_ids=xt,
            attention_mask=torch.ones_like(xt),
        ).logits

        # reshape (not view) so non-contiguous inputs, e.g. a transposed or
        # sliced batch, are flattened instead of raising.
        nll = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            xt.reshape(-1),
            reduction="none",
        ).reshape(xt.shape)

        revealed = is_revealed.float()
        # clamp keeps fully-masked rows from dividing by zero (they yield 0).
        avg_ll = -(nll * revealed).sum(dim=1) / revealed.sum(dim=1).clamp(min=1)

        per_row = avg_ll.tolist()
        # Preserve the scalar return for the single-sequence case.
        return per_row[0] if len(per_row) == 1 else per_row

    def log_step(self, xt, step_idx):
        """Score ``xt`` and record the result under ``trace_step_<step_idx>``.

        Logging the same ``step_idx`` again overwrites the earlier entry.
        """
        score = self.compute_loglikeli(xt)
        self.history[f"trace_step_{step_idx}"] = score

    def get_trace(self) -> dict:
        """Return the recorded trace (the live internal dict, not a copy)."""
        return self.history