import os
import sys
import time
from contextlib import contextmanager

import torch

from fudge.constants import *

@contextmanager
def suppress_stdout():
    """Redirect stdout to os.devnull for the duration of the with-block."""
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout

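# Illustrative usage (a sketch, assuming the @contextmanager decorator above):
# prints inside the block are discarded and sys.stdout is restored on exit.
#   with suppress_stdout():
#       print("this line is silenced")
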
def save_checkpoint(state, save_path):
    """Save a checkpoint (e.g. a state dict) to save_path, creating parent directories as needed."""
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(state, save_path)

def freeze(module):
    """Disable gradient updates for all parameters of module."""
    for param in module.parameters():
        param.requires_grad = False

def num_params(model):
    """Count the trainable (requires_grad) parameters of model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def clamp(x, limit):
    """Clamp scalar x to the range [-limit, limit]."""
    return max(-limit, min(x, limit))

def pad_to_length(tensor, length, dim, value=0):
    """
    Pad tensor to given length in given dim using given value (value should be numeric).
    """
    assert tensor.size(dim) <= length
    if tensor.size(dim) < length:
        zeros_shape = list(tensor.shape)
        zeros_shape[dim] = length - tensor.size(dim)
        zeros_shape = tuple(zeros_shape)
        return torch.cat([tensor, torch.zeros(zeros_shape).type(tensor.type()).to(tensor.device).fill_(value)], dim=dim)
    else:
        return tensor

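# Illustrative example (a sketch; the shapes and fill value are my own, not
# from the original source): pad a (2, 3) tensor to length 5 along dim=1.
#   x = torch.ones(2, 3)
#   pad_to_length(x, 5, dim=1, value=-1)  # shape (2, 5); last 2 columns are -1
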
def pad_mask(lengths: torch.LongTensor) -> torch.BoolTensor:
    """
    Create a mask of seq x batch where seq = max(lengths), with 0 in padding locations and 1 otherwise.
    """
    # lengths: bs. Ex: [2, 3, 1]
    max_seqlen = torch.max(lengths)
    expanded_lengths = lengths.unsqueeze(0).repeat((max_seqlen, 1))  # [[2, 3, 1], [2, 3, 1], [2, 3, 1]]
    indices = torch.arange(max_seqlen).unsqueeze(1).repeat((1, lengths.size(0))).to(lengths.device)  # [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
    # Comparison ops return a bool tensor in PyTorch >= 1.2, hence the BoolTensor annotation.
    return expanded_lengths > indices  # pad locations are 0. [[1, 1, 1], [1, 1, 0], [0, 1, 0]]. seqlen x bs

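# Worked example (values taken from the inline comments above): for lengths
# [2, 3, 1] the mask is 3 x 3 (seqlen x batch); column b is True at positions
# < lengths[b]:
#   pad_mask(torch.LongTensor([2, 3, 1]))
#   # tensor([[ True,  True,  True],
#   #         [ True,  True, False],
#   #         [False,  True, False]])
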
class ProgressMeter(object):
    """
    Display meter: prints the batch index, a timestamp, and each tracked meter.
    """
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries.append(time.ctime(time.time()))
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

class AverageMeter(object):
    """
    Display meter.
    Computes and stores the average and current value.
    """
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
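

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; not part of the original module):
    # track a running loss with AverageMeter and report it via ProgressMeter.
    losses = AverageMeter('loss', ':.4f')
    progress = ProgressMeter(num_batches=3, meters=[losses], prefix='epoch 0 ')
    for batch, loss in enumerate([0.9, 0.7, 0.5]):
        losses.update(loss, n=8)  # n = number of samples in the batch
        progress.display(batch)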