Lekr0's picture
Add files using upload-large-folder tool
62dca4c verified
import torch
import torch.nn.init
def norm_tensor(shape, device, dtype, std=0.02, *, mean=0.0, a=-2.0, b=2.0,
                requires_grad=True):
    """Create a new leaf tensor filled from a truncated normal distribution.

    Args:
        shape: Shape of the tensor, as accepted by ``torch.empty``.
        device: Device on which to allocate the tensor.
        dtype: Element dtype of the tensor.
        std: Standard deviation of the (untruncated) normal distribution.
        mean: Mean of the normal distribution (keyword-only).
        a: Lower cutoff, in absolute units rather than multiples of ``std``
            (keyword-only). Values are never below ``a``.
        b: Upper cutoff, in absolute units rather than multiples of ``std``
            (keyword-only). Values are never above ``b``.
        requires_grad: Whether the returned tensor tracks gradients
            (keyword-only).

    Returns:
        A leaf tensor of the given shape, device and dtype, filled in place by
        ``torch.nn.init.trunc_normal_``.
    """
    # NOTE: the defaults a=-2, b=2 match trunc_normal_'s own defaults and are
    # absolute bounds. With the default std=0.02 they lie ~100 std from the
    # mean, so truncation is effectively inactive. Pass a=-2*std, b=2*std if
    # a +/-2-sigma cutoff is intended.
    t = torch.empty(shape, device=device, dtype=dtype,
                    requires_grad=requires_grad)
    # trunc_normal_ fills under torch.no_grad(), so t remains a leaf tensor
    # even when requires_grad=True.
    torch.nn.init.trunc_normal_(t, mean=mean, std=std, a=a, b=b)
    return t