| import argparse |
| import time |
|
|
| import torch |
|
|
| from specforge.core.loss import LogSoftmaxLoss, _compute_loss |
|
|
# Number of TTT (test-time-training) steps unrolled per benchmark iteration —
# one independent logits tensor is created and one loss computed per step.
# Presumably mirrors specforge's TTT unroll length; TODO confirm.
TTT_LENGTH = 7
|
|
|
|
def benchmark_loss_method(
    loss_method: str,
    test_configs: list,
) -> list:
    """Benchmark a loss computation method for speed and GPU memory usage.

    Args:
        loss_method: ``"triton"`` selects the fused ``LogSoftmaxLoss`` autograd
            function; any other value uses the eager ``_compute_loss``
            reference implementation.
        test_configs: iterable of ``(B, T, V)`` tuples — batch size, sequence
            length, vocabulary size — to benchmark.

    Returns:
        A list of dicts with keys ``"B"``, ``"T"``, ``"V"``,
        ``"time_total"`` (seconds, forward + backward over all TTT steps)
        and ``"peak_memory"`` (bytes; 0 when CUDA is unavailable).
    """
    print(f"\n=== Benchmarking {loss_method} Loss ===")

    # Fix: the original hard-coded device="cuda" for every tensor and called
    # torch.cuda.synchronize() unconditionally, so the CPU fallbacks guarded
    # by torch.cuda.is_available() below could never be reached without
    # crashing first. Pick the device once and guard all CUDA-only calls.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    results = []

    for config in test_configs:
        B, T, V = config
        print(f"\nTesting config: B={B}, T={T}, V={V}")

        if use_cuda:
            # Start each config from a clean slate so the peak-memory number
            # is attributable to this config alone.
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

        # Fixed soft target distribution shared by all TTT steps.
        target = torch.softmax(
            torch.randn(B, T, V, device=device, dtype=torch.float32), dim=-1
        )
        # All positions active.
        position_mask = torch.ones((B, T, 1), dtype=torch.bool, device=device)

        # One independent leaf logits tensor per TTT step.
        logits_list = [
            torch.randn(
                B, T, V, device=device, requires_grad=True, dtype=torch.float32
            )
            for _ in range(TTT_LENGTH)
        ]

        if use_cuda:
            torch.cuda.synchronize()
        # perf_counter() is monotonic and higher-resolution than time.time(),
        # which can jump on wall-clock adjustments — the right benchmark clock.
        start_time = time.perf_counter()

        plosses = []
        for logits in logits_list:
            if loss_method == "triton":
                loss = LogSoftmaxLoss.apply(logits, target, position_mask)
            else:
                loss = _compute_loss(logits, target, position_mask)
            plosses.append(loss)

        # Exponentially decayed weighting across TTT steps (0.8^i), then a
        # single backward pass through the combined loss.
        ploss_weight = [0.8**i for i in range(len(plosses))]
        ploss = sum(w * p for w, p in zip(ploss_weight, plosses)) / TTT_LENGTH
        ploss.backward()

        if use_cuda:
            # Ensure all kernels have finished before stopping the clock.
            torch.cuda.synchronize()

        total_time = time.perf_counter() - start_time

        peak_memory = torch.cuda.max_memory_allocated() if use_cuda else 0

        results.append(
            {
                "B": B,
                "T": T,
                "V": V,
                "time_total": total_time,
                "peak_memory": peak_memory,
            }
        )

        print(f" Total time (forward + backward): {total_time*1000:.3f}ms")
        print(f" Peak memory: {peak_memory / 1024**3:.3f} GB")

    return results
|
|
|
|
def main():
    """Run the PyTorch-vs-Triton loss benchmark sweep and print a summary table."""
    cli = argparse.ArgumentParser(description="Benchmark loss computation methods")
    cli.add_argument(
        "--num-runs", type=int, default=5, help="Number of runs for averaging"
    )
    args = cli.parse_args()
    # NOTE(review): args.num_runs is parsed but never used — every
    # configuration is benchmarked exactly once. Confirm intent.

    print("PyTorch version:", torch.__version__)
    if torch.cuda.is_available():
        print("CUDA available:", torch.cuda.is_available())
        print("GPU:", torch.cuda.get_device_name())
        print(
            "GPU memory:",
            torch.cuda.get_device_properties(0).total_memory / 1024**3,
            "GB",
        )
    else:
        print("CUDA not available - running on CPU")

    # Sweep of (B, T, V) = (batch, sequence length, vocab size) shapes.
    test_configs = [
        (1, 1024, 32000),
        (1, 1024, 64000),
        (1, 4096, 32000),
        (1, 4096, 64000),
        (1, 8192, 32000),
        (1, 8192, 64000),
        (1, 16384, 32000),
    ]

    print(f"Testing configurations: {test_configs}")

    print("\n" + "=" * 60)
    pytorch_results = benchmark_loss_method("pytorch", test_configs)

    print("\n" + "=" * 60)
    triton_results = benchmark_loss_method("triton", test_configs)

    print(f"\n=== Performance Summary ===")
    print(f"Configurations tested: {len(test_configs)}")

    print(
        f"\n{'Config (B,T,V)':<15} {'PyTorch (ms)':<15} {'Triton (ms)':<15} {'Speedup':<10} {'PyTorch Mem (GB)':<18} {'Triton Mem (GB)':<15} {'Memory Save':<12}"
    )
    print("-" * 115)

    # Index each result list by its (B, T, V) key so pairing the two methods
    # is a dict lookup rather than a linear scan per config.
    pt_by_cfg = {(r["B"], r["T"], r["V"]): r for r in pytorch_results}
    tr_by_cfg = {(r["B"], r["T"], r["V"]): r for r in triton_results}

    for cfg in test_configs:
        B, T, V = cfg
        config_str = f"({B},{T},{V})"

        pt = pt_by_cfg.get(cfg)
        tr = tr_by_cfg.get(cfg)
        if pt is None or tr is None:
            # Skip configs missing from either run (guard clause).
            continue

        pytorch_time_str = f"{pt['time_total']*1000:.2f}"
        pytorch_mem_str = f"{pt['peak_memory']/1024**3:.2f}"
        triton_time_str = f"{tr['time_total']*1000:.2f}"
        triton_mem_str = f"{tr['peak_memory']/1024**3:.2f}"

        speedup_str = (
            f"{pt['time_total'] / tr['time_total']:.2f}x"
            if tr["time_total"] > 0
            else "N/A"
        )

        if pt["peak_memory"] > 0:
            saved = pt["peak_memory"] - tr["peak_memory"]
            memory_save_str = f"{saved / pt['peak_memory'] * 100:.1f}%"
        else:
            memory_save_str = "N/A"

        print(
            f"{config_str:<15} {pytorch_time_str:<15} {triton_time_str:<15} {speedup_str:<10} {pytorch_mem_str:<18} {triton_mem_str:<15} {memory_save_str:<12}"
        )
|
|
|
|
# Run the benchmark sweep only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
|