import unittest

import torch

from specforge.core.loss import LogSoftmaxLoss, _compute_loss

from .utils import norm_tensor


class TestLogSoftmaxLoss(unittest.TestCase):
    """Compare the fused LogSoftmaxLoss autograd Function against the
    reference _compute_loss implementation, for both the forward loss
    value and the gradients."""

    TTT_LENGTH = 7
    def _test_loss_and_gradient_calculation(self, B, T, V):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # norm_tensor is expected to return a leaf tensor with
        # requires_grad=True, so that logits.grad is populated below.
        logits = norm_tensor((B, T, V), device, torch.float32)
        logits2 = logits.clone().detach().requires_grad_(True)
        target = norm_tensor((B, T, V), device, torch.float32)
        position_mask = torch.randint(0, 2, (B, T, 1), dtype=torch.bool, device=device)

        # Forward: both implementations must produce the same loss.
        output1 = LogSoftmaxLoss.apply(logits, target, position_mask)
        output2 = _compute_loss(logits2, target, position_mask)
        torch.testing.assert_close(output1, output2, rtol=1e-4, atol=1e-4)

        # Backward: the gradients w.r.t. the logits must match as well.
        output1.backward()
        output2.backward()
        torch.testing.assert_close(logits.grad, logits2.grad, rtol=1e-4, atol=1e-4)

    def test_loss(self):
        # Sweep batch size, sequence length, and vocabulary size.
        B = [1, 2, 4]
        T = [1024, 2048, 4096, 6000]
        V = [4096, 8192, 10000]
        for b in B:
            for t in T:
                for v in V:
                    self._test_loss_and_gradient_calculation(b, t, v)

    def test_ttt_loss_accumulation(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        B, T, V = 1, 1024, 3200
        plosses = []
        plosses_compare = []
        logits_list = [
            norm_tensor((B, T, V), device, torch.float32)
            for _ in range(self.TTT_LENGTH)
        ]
        logits_list_copy = [
            logits.clone().detach().requires_grad_(True) for logits in logits_list
        ]

        # Compute one loss per TTT step with both implementations and
        # check them against each other step by step.
        for i in range(self.TTT_LENGTH):
            logits = logits_list[i]
            logits2 = logits_list_copy[i]
            target = norm_tensor((B, T, V), device, torch.float32)
            position_mask = torch.randint(
                0, 2, (B, T, 1), dtype=torch.bool, device=device
            )
            output1 = LogSoftmaxLoss.apply(logits, target, position_mask)
            output2 = _compute_loss(logits2, target, position_mask)
            torch.testing.assert_close(output1, output2, rtol=1e-4, atol=1e-4)
            plosses.append(output1)
            plosses_compare.append(output2)

        # Accumulate the per-step losses with exponentially decaying
        # weights (0.8 ** step), averaged over the TTT length.
        ploss_weight = [0.8**i for i in range(len(plosses))]
        ploss = sum(w * l for w, l in zip(ploss_weight, plosses)) / self.TTT_LENGTH
        ploss_compare = (
            sum(w * l for w, l in zip(ploss_weight, plosses_compare))
            / self.TTT_LENGTH
        )
        torch.testing.assert_close(ploss, ploss_compare, rtol=1e-4, atol=1e-4)

        # Backprop through the accumulated loss: every step's logits
        # must receive matching gradients from both implementations.
        ploss.backward()
        ploss_compare.backward()
        for i in range(self.TTT_LENGTH):
            torch.testing.assert_close(
                logits_list[i].grad, logits_list_copy[i].grad, rtol=1e-4, atol=1e-4
            )
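

# Standard unittest entry point so the tests can be launched directly;
# note that the relative `.utils` import means this file must be run as
# a module within its package (python -m <package>.<module>).
if __name__ == "__main__":
    unittest.main()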