Lekr0's picture
Add files using upload-large-folder tool
212a146 verified
import argparse
import time
import torch
from specforge.core.loss import LogSoftmaxLoss, _compute_loss
TTT_LENGTH = 7
def benchmark_loss_method(
loss_method: str,
test_configs: list,
):
"""Benchmark a loss computation method for speed and GPU memory usage."""
print(f"\n=== Benchmarking {loss_method} Loss ===")
results = []
for config in test_configs:
B, T, V = config
print(f"\nTesting config: B={B}, T={T}, V={V}")
# Clear GPU cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Create tensors outside timing measurement
target = torch.softmax(
torch.randn(B, T, V, device="cuda", dtype=torch.float32), dim=-1
)
position_mask = torch.ones((B, T, 1), dtype=torch.bool, device="cuda")
# Pre-allocate logits tensors for each TTT step
logits_list = []
for i in range(TTT_LENGTH):
logits = torch.randn(
B, T, V, device="cuda", requires_grad=True, dtype=torch.float32
)
logits_list.append(logits)
torch.cuda.synchronize() # Ensure all operations are complete
start_time = time.time()
plosses = []
for i in range(TTT_LENGTH):
logits = logits_list[i]
if loss_method == "triton":
loss = LogSoftmaxLoss.apply(logits, target, position_mask)
else:
loss = _compute_loss(logits, target, position_mask)
plosses.append(loss)
ploss_weight = [0.8**i for i in range(len(plosses))]
ploss = (
sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))])
/ TTT_LENGTH
)
ploss.backward()
if torch.cuda.is_available():
torch.cuda.synchronize()
end_time = time.time()
total_time = end_time - start_time
# Record memory usage
peak_memory = 0
if torch.cuda.is_available():
peak_memory = torch.cuda.max_memory_allocated()
results.append(
{
"B": B,
"T": T,
"V": V,
"time_total": total_time,
"peak_memory": peak_memory,
}
)
print(f" Total time (forward + backward): {total_time*1000:.3f}ms")
print(f" Peak memory: {peak_memory / 1024**3:.3f} GB")
return results
def main():
parser = argparse.ArgumentParser(description="Benchmark loss computation methods")
parser.add_argument(
"--num-runs", type=int, default=5, help="Number of runs for averaging"
)
args = parser.parse_args()
print("PyTorch version:", torch.__version__)
if torch.cuda.is_available():
print("CUDA available:", torch.cuda.is_available())
print("GPU:", torch.cuda.get_device_name())
print(
"GPU memory:",
torch.cuda.get_device_properties(0).total_memory / 1024**3,
"GB",
)
else:
print("CUDA not available - running on CPU")
# Define test configurations (B, T, V)
test_configs = [
(1, 1024, 32000),
(1, 1024, 64000),
(1, 4096, 32000),
(1, 4096, 64000),
(1, 8192, 32000),
(1, 8192, 64000),
(1, 16384, 32000),
]
print(f"Testing configurations: {test_configs}")
# Run benchmarks
print("\n" + "=" * 60)
pytorch_results = benchmark_loss_method("pytorch", test_configs)
print("\n" + "=" * 60)
triton_results = benchmark_loss_method("triton", test_configs)
# Print results summary
print(f"\n=== Performance Summary ===")
print(f"Configurations tested: {len(test_configs)}")
# Print detailed results table
print(
f"\n{'Config (B,T,V)':<15} {'PyTorch (ms)':<15} {'Triton (ms)':<15} {'Speedup':<10} {'PyTorch Mem (GB)':<18} {'Triton Mem (GB)':<15} {'Memory Save':<12}"
)
print("-" * 115)
for i, config in enumerate(test_configs):
B, T, V = config
config_str = f"({B},{T},{V})"
pytorch_result = next(
(r for r in pytorch_results if r["B"] == B and r["T"] == T and r["V"] == V),
None,
)
triton_result = next(
(r for r in triton_results if r["B"] == B and r["T"] == T and r["V"] == V),
None,
)
if pytorch_result and triton_result:
pytorch_time_str = f"{pytorch_result['time_total']*1000:.2f}"
pytorch_mem_str = f"{pytorch_result['peak_memory']/1024**3:.2f}"
triton_time_str = f"{triton_result['time_total']*1000:.2f}"
triton_mem_str = f"{triton_result['peak_memory']/1024**3:.2f}"
if triton_result["time_total"] > 0:
speedup = pytorch_result["time_total"] / triton_result["time_total"]
speedup_str = f"{speedup:.2f}x"
else:
speedup_str = "N/A"
# Calculate memory savings percentage
if pytorch_result["peak_memory"] > 0:
memory_save_pct = (
(pytorch_result["peak_memory"] - triton_result["peak_memory"])
/ pytorch_result["peak_memory"]
) * 100
memory_save_str = f"{memory_save_pct:.1f}%"
else:
memory_save_str = "N/A"
print(
f"{config_str:<15} {pytorch_time_str:<15} {triton_time_str:<15} {speedup_str:<10} {pytorch_mem_str:<18} {triton_mem_str:<15} {memory_save_str:<12}"
)
if __name__ == "__main__":
main()