import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size = 64
max_len = 256
d_model = 384
n_head = 6
d_q = d_model // n_head  # 384 / 6 = 64
dropout = 0.2


class Head(nn.Module):
    def __init__(self, d_q):
        super().__init__()
        self.query = nn.Linear(d_model, d_q, bias=False)  # Query projection Wq: x (B, S, 384) @ Wq (384, 64) -> q (B, S, 64)
        self.key = nn.Linear(d_model, d_q, bias=False)    # k = x @ Wk
        self.value = nn.Linear(d_model, d_q, bias=False)  # v = x @ Wv
        self.register_buffer('tril', torch.tril(torch.ones(max_len, max_len)))  # Causal mask, registered as a non-trainable buffer (moves with the module, ignored by the optimizer)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, S, D = x.shape  # B -> batch, S -> sequence length, D -> embedding dimension (d_model)
        q = self.query(x)  # Shape of q: (B, S, d_q) = (B, S, 64)
        k = self.key(x)
        v = self.value(x)
        attention_matrix = torch.matmul(q, k.transpose(-2, -1))  # (B, S, 64) @ (B, 64, S) -> (B, S, S)
        attention_matrix = attention_matrix / (k.size(-1) ** 0.5)  # Scale by sqrt(d_q) to keep the logits well-behaved before softmax
        attention_matrix = attention_matrix.masked_fill(self.tril[:S, :S] == 0, float('-inf'))  # Entries above the diagonal are 0 in tril; fill them with -inf so each token attends only to itself and earlier positions (causal decoder, not bidirectional)
        attention_matrix = F.softmax(attention_matrix, dim=-1)  # dim=-1 applies softmax row-wise; the -inf positions become 0
        attention_matrix = self.dropout(attention_matrix)  # Apply 20% dropout to the attention weights to reduce overfitting
        output = torch.matmul(attention_matrix, v)  # (B, S, S) @ (B, S, 64) -> (B, S, 64); concatenating n_head such outputs restores d_model, so the result can be added back to x via a residual connection
        return output


if __name__ == "__main__":
    x = torch.randn(batch_size, max_len, d_model)
    single_head = Head(d_q)
    output = single_head(x)
    print("Input shape:", x.shape)
    print("Output shape from a single head:", output.shape)
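

# --- Illustrative sketch (not part of the original code above): the comment on the
# output line notes that concatenating n_head single-head outputs restores d_model.
# This is a minimal example of how that concatenation is typically done, assuming a
# class named `MultiHeadAttention` with an output projection `self.proj`; both names
# are hypothetical additions for illustration, not the author's implementation.
class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, d_q):
        super().__init__()
        self.heads = nn.ModuleList([Head(d_q) for _ in range(n_head)])  # n_head independent attention heads
        self.proj = nn.Linear(n_head * d_q, d_model)  # Mix information across heads after concatenation (assumed, common practice)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, S, n_head * d_q) = (B, S, d_model)
        return self.dropout(self.proj(out))  # Still (B, S, d_model), so it can be added to x residually


if __name__ == "__main__":
    # Quick shape check for the sketch above, reusing x from the single-head demo.
    mha = MultiHeadAttention(n_head, d_q)
    print("Output shape from multi-head attention:", mha(x).shape)  # (64, 256, 384)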