import torch.nn as nn
class LoRALayer(nn.Module):
    """Low-rank adaptation (LoRA) layer: computes the scaled update B(A(x)) * (alpha / rank)."""

    def __init__(self, in_features, out_features, rank, alpha):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        # A projects down to the low-rank space; B projects back up.
        self.loraA = nn.Linear(in_features, rank, bias=False)
        self.loraB = nn.Linear(rank, out_features, bias=False)
        # A gets the standard nn.Linear init; B starts at zero so the initial
        # LoRA update is zero and training begins from the base weights.
        nn.init.kaiming_uniform_(self.loraA.weight, a=5 ** 0.5)
        nn.init.zeros_(self.loraB.weight)

    def forward(self, x):
        # x @ A @ B: (B, S, D_in) -> (B, S, R) -> (B, S, D_out)
        delta = self.loraB(self.loraA(x))
        # Return only the scaled low-rank update; the caller adds it to the
        # frozen base layer's output.
        return self.scaling * delta
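

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how this LoRALayer is typically composed with a frozen
# pretrained nn.Linear. The LinearWithLoRA wrapper and the shapes below are
# assumptions for illustration, not definitions from this file.

import torch


class LinearWithLoRA(nn.Module):
    """Hypothetical wrapper: frozen base linear plus trainable LoRA update."""

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        # Frozen base projection plus the scaled low-rank delta.
        return self.linear(x) + self.lora(x)


if __name__ == "__main__":
    base = nn.Linear(768, 768)
    for p in base.parameters():
        p.requires_grad_(False)  # freeze the pretrained weights
    layer = LinearWithLoRA(base, rank=8, alpha=16)
    x = torch.randn(2, 10, 768)  # (batch, seq_len, hidden)
    print(layer(x).shape)        # torch.Size([2, 10, 768])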