File size: 776 Bytes
dde4f12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch, torch.nn as nn

batch_size = 64   # batch size (used by the __main__ demo)
max_len = 256     # sequence length (used by the __main__ demo)
d_model = 384     # feature width of the model, i.e. the last tensor dimension
n_head = 6        # number of heads; d_model is split evenly across them

# Integer division with an explicit check: int(d_model / n_head) would silently
# truncate if d_model were not a multiple of n_head.
if d_model % n_head:
    raise ValueError(
        f"d_model ({d_model}) must be divisible by n_head ({n_head})"
    )
d_q = d_model // n_head  # width of each head's slice of d_model

dropout = 0.2     # default dropout probability used by FeedForward

class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Linear -> Dropout.

    The first projection widens the last dimension from ``d_model`` to
    ``expansion * d_model``. The second projects it back to ``d_model``, so the
    output has the same shape as the input.

    Args:
        d_model: Size of the last dimension of the input and the output.
        dropout_p: Dropout probability applied to the output. When ``None``
            (the default), the module-level ``dropout`` setting is read when
            the layer is constructed, which matches the previous behaviour.
        expansion: Width multiplier for the hidden layer. The default of 4 is
            the d_ff / d_model ratio used by the original Transformer.
    """

    def __init__(self, d_model, *, dropout_p=None, expansion=4):
        super().__init__()
        if dropout_p is None:
            # Looked up at call time, so later reassignment of the module
            # global is still honoured, exactly as before.
            dropout_p = dropout
        hidden = expansion * d_model
        # Kept as one nn.Sequential named ``seq`` with the same layer order, so
        # state_dict keys (seq.0.*, seq.2.*) stay checkpoint-compatible.
        self.seq = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout_p),
        )

    def forward(self, x):
        """Apply the feed-forward network over the last dimension of ``x``.

        The last dimension of ``x`` is expected to be ``d_model``. Any leading
        dimensions (e.g. batch, sequence) pass through unchanged.
        """
        return self.seq(x)

if __name__ == '__main__':
    # Smoke test: push a random batch through a freshly built FeedForward and
    # report the input and output shapes (they should match).
    sample = torch.randn(batch_size, max_len, d_model)
    result = FeedForward(d_model)(sample)

    for label, tensor in (
        ("Input shape:", sample),
        ("Output shape from FeedForward network:", result),
    ):
        print(label, tensor.shape)