import torch
import torch.nn as nn

from multiHead import MultiHead
from feedForward import FeedForward

# Hyperparameters
batch_size = 64
max_len = 256              # maximum sequence length
d_model = 384              # embedding (model) dimension
n_head = 6                 # number of attention heads
d_q = d_model // n_head    # per-head query/key/value dimension
dropout = 0.2
class Block(nn.Module):
    """One Transformer block: multi-head self-attention followed by a feed-forward network."""

    def __init__(self, d_model, n_head):
        super().__init__()
        # Derive the per-head size from the constructor arguments rather than
        # the module-level globals, so the block works for any d_model/n_head pair.
        self.multiHead = MultiHead(n_head, d_model // n_head)
        self.ffwd = FeedForward(d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Pre-LN ordering: normalizing before each sub-layer stabilizes training
        # better than the post-LN arrangement of the original Transformer.
        x = x + self.multiHead(self.ln1(x))  # residual addition around attention
        x = x + self.ffwd(self.ln2(x))       # residual addition around feed-forward
        return x
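
# For contrast: the original Transformer ("Attention Is All You Need") applied
# LayerNorm *after* each residual addition. The post-LN variant below is purely
# illustrative; the name PostLNBlock is introduced here for comparison and is
# not part of this project.
class PostLNBlock(nn.Module):
    """Post-LN ordering, shown only to contrast with the pre-LN Block above."""

    def __init__(self, d_model, n_head):
        super().__init__()
        self.multiHead = MultiHead(n_head, d_model // n_head)
        self.ffwd = FeedForward(d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.ln1(x + self.multiHead(x))  # normalize after the residual addition
        x = self.ln2(x + self.ffwd(x))
        return x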

if __name__ == "__main__":
    # Smoke test: one block should preserve the (batch, seq, d_model) shape.
    x = torch.randn(batch_size, max_len, d_model)
    block = Block(d_model, n_head)
    output = block(x)
    print("Input shape:", x.shape)
    print("Output shape from one Transformer Block:", output.shape)