File size: 2,124 Bytes
dde4f12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch, torch.nn as nn, torch.nn.functional as F

# Hyper-parameters shared by this module.
batch_size = 64    # batch size used by the __main__ demo
max_len = 256      # longest sequence the causal mask in Head supports
d_model = 384      # width of the input embeddings fed to each Head
n_head = 6         # number of heads the model width is split across
# Integer division (not int(a / b)) so a non-divisible config fails loudly
# here instead of being silently truncated.
if d_model % n_head:
    raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")
d_q = d_model // n_head  # per-head query/key/value width: 384 // 6 = 64
dropout = 0.2      # dropout probability applied to the attention weights

class Head(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Projects the input into query, key and value vectors of width ``d_q``.
    Each position may attend only to itself and to earlier positions. The
    result is the attention-weighted sum of the values.

    Args:
        d_q: Width of the per-head query/key/value projections.
        d_in: Width of the last axis of the input. Defaults to the
            module-level ``d_model``.
        block_size: Longest sequence the causal mask supports. Defaults to
            the module-level ``max_len``.
        p_drop: Dropout probability applied to the attention weights.
            Defaults to the module-level ``dropout``.
    """

    def __init__(self, d_q, *, d_in=None, block_size=None, p_drop=None):
        super().__init__()
        # Fall back to the module-level hyper-parameters so ``Head(d_q)``
        # behaves exactly as before. Explicit values allow reuse with
        # other model sizes.
        if d_in is None:
            d_in = d_model
        if block_size is None:
            block_size = max_len
        if p_drop is None:
            p_drop = dropout

        # Bias-free projections: q = x Wq, k = x Wk, v = x Wv.
        # Each maps (..., d_in) -> (..., d_q).
        self.query = nn.Linear(d_in, d_q, bias=False)
        self.key = nn.Linear(d_in, d_q, bias=False)
        self.value = nn.Linear(d_in, d_q, bias=False)

        # Lower-triangular ones: entry (i, j) is 1 iff j <= i.
        # It is registered as a buffer, so it moves with the module across
        # devices and is saved in the state_dict, but is not trained.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(p_drop)

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, d_in). ``seq_len`` must not
                exceed the ``block_size`` this head was built with.

        Returns:
            Tensor of shape (batch, seq_len, d_q). This is the per-head
            width, not the input width. Combining several heads back to
            ``d_in`` is left to the caller (not shown here).

        Raises:
            ValueError: If ``x`` is not 3-D, or if ``seq_len`` exceeds the
                size of the causal mask.
        """
        _, seq_len, _ = x.shape  # unpacking also rejects non-3-D inputs
        mask_size = self.tril.size(0)
        if seq_len > mask_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds the causal mask size {mask_size}"
            )

        q = self.query(x)  # (B, S, d_q)
        k = self.key(x)    # (B, S, d_q)
        v = self.value(x)  # (B, S, d_q)

        # Scaling by sqrt(d_k) follows "Attention Is All You Need". It keeps
        # the logits from growing with d_q and saturating the softmax.
        scores = q @ k.transpose(-2, -1) / (k.size(-1) ** 0.5)  # (B, S, S)

        # Future positions (j > i, where tril == 0) get a score of -inf.
        # Softmax turns that into weight 0. The diagonal is always allowed,
        # so no row is fully masked.
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))

        weights = F.softmax(scores, dim=-1)  # normalise over keys, per query
        weights = self.dropout(weights)
        return weights @ v  # (B, S, S) @ (B, S, d_q) -> (B, S, d_q)

if __name__ == "__main__":
    # Smoke check: push one full-length random batch through a single head
    # and report the tensor shapes on either side.
    sample = torch.randn(batch_size, max_len, d_model)
    head = Head(d_q)
    result = head(sample)

    report = (
        ("Input shape:", sample.shape),
        ("Output shape from a single head:", result.shape),
    )
    for label, shape in report:
        print(label, shape)