Architecture
The best way to understand any model structure is to step through the forward pass. I have implemented most of the model structure and training code from scratch following the CS336 curriculum, only using nn.Parameter() and the parent nn.Module for my classes.
The input to a transformer is a string like "What is the capital of India?". We first encode this input using the tokenizer we trained in the previous step to get a list of token IDs. Thanks to the transformer architecture and parallel computing we can process multiple inputs simultaneously in batches. We can represent this input as a batch x seq_len matrix where each row is a different sequence. Now let's implement each part of the model step by step and understand how it works.
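As a rough illustration of the batch x seq_len layout, the sketch below builds a batch by hand. The token IDs are invented for the example (they do not come from the tokenizer trained earlier), and both sequences are assumed to have the same length so they can be stacked into one matrix.

```python
import torch

# Invented token IDs standing in for two already-encoded sequences.
# In the real pipeline these would come from the trained tokenizer.
token_ids = torch.tensor([
    [17, 204, 9, 512, 33],   # sequence 0
    [88, 5, 301, 42, 7],     # sequence 1
])

batch, seq_len = token_ids.shape
print(batch, seq_len)  # 2 5
```

Every module that follows takes a tensor with these two leading dimensions and keeps them intact.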
Embedding Layer
While token IDs tell our model which tokens make up the input, they do not let the model learn what those tokens mean. This is why we convert each unique token to an embedding, which is a vector in d-dimensional space. This embedding is a tunable representation of the token that the model can update as it learns. Here is how we implement the module:
- We create a weight matrix of the shape vocab_size x d such that each row contains the embedding for a unique token. We then do a Xavier/Glorot initialization of the weights.
- During the forward pass, we simply index the matrix for the specific token IDs and return their embeddings. Our input is batch x seq_len and output is batch x seq_len x d.
- During the backward pass for each sequence, only the embeddings that were used get updated and if the same token was used multiple times, the gradients accumulate.
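The backward-pass behaviour in the last point can be checked with a small sketch. It uses a bare tensor instead of the module, and the table size and token IDs are made up for the example.

```python
import torch

# Tiny embedding table: 4 tokens, 3 dimensions each.
weights = torch.zeros(4, 3, requires_grad=True)

# Token 1 appears twice, token 2 once, tokens 0 and 3 never.
token_ids = torch.tensor([[1, 2, 1]])

out = weights[token_ids]      # shape (1, 3, 3) = batch x seq_len x d
out.sum().backward()          # every output element contributes a gradient of 1

print(weights.grad)
# Row 1 collects gradient from both of its occurrences, row 2 from one,
# and rows 0 and 3 stay at zero because they were never looked up.
```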
import math
import torch
import torch.nn as nn

class Embedding(nn.Module):
    def __init__(self, vocab_size, d_model, device=None, dtype=torch.float32):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        # One row per token in the vocabulary: vocab_size x d_model
        self.weights = nn.Parameter(torch.empty(vocab_size, d_model, device=device, dtype=dtype))
        # Xavier/Glorot-style std, sampled from a normal truncated at +/- 3 std
        std = math.sqrt(2 / (vocab_size + d_model))
        torch.nn.init.trunc_normal_(self.weights, mean=0.0, std=std, a=-3 * std, b=3 * std)

    def forward(self, token_ids):
        # (batch, seq_len) -> (batch, seq_len, d_model)
        return self.weights[token_ids]
Normalization
Before we implement this layer, it is important to understand the problem well. The output of every neuron in the linear layers we implement has the form $y = \sum_{i=1}^{n} w_i x_i$. We know that sums accumulate variance: if every weight has variance $\sigma_w^2$ and every input has variance $\sigma_x^2$ (both zero-mean and independent), then our neuron output will have variance $n \sigma_w^2 \sigma_x^2$. The Xavier initialization sets $\sigma_w^2 = 2 / (n_{in} + n_{out})$, which is $1/n$ for a square layer, to solve this before training starts, but once weights are updated we cannot ensure that stability. In practice each layer multiplies the variance by some factor that drifts away from one.
If this factor is greater than one, the outputs grow with depth and spread over an ever wider range, which makes downstream operations like softmax and activations less effective. If the factor is less than one, the outputs shrink toward zero and the signal fades in the same compounding way. This is also a problem for deeper layers: if the output distribution of the earlier layers keeps shifting, the later layers are essentially modeling a moving target. Fortunately we have a bunch of normalization techniques to solve this.
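To make the compounding effect concrete, here is a small sketch that pushes unit-variance inputs through a stack of random linear layers. The width, depth and batch size are arbitrary choices for the illustration, and `weight_scale` is a hypothetical knob that sets the per-layer variance factor.

```python
import torch

torch.manual_seed(0)

def output_std(weight_scale, n=256, depth=20):
    """Push unit-variance input through `depth` random linear layers.

    Each weight has variance weight_scale / n, so every layer multiplies
    the activation variance by roughly `weight_scale`.
    """
    x = torch.randn(512, n)
    for _ in range(depth):
        W = torch.randn(n, n) * (weight_scale / n) ** 0.5
        x = x @ W
    return x.std().item()

for scale in (0.5, 1.0, 2.0):
    print(scale, output_std(scale))
```

With a factor of 0.5 the activations collapse toward zero, with 2.0 they blow up, and only a factor near one keeps the scale roughly stable across depth.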
The original transformer uses Layer Normalization where we normalize across all features of a sample and then apply an affine transformation with learnable parameters. It is essentially re-centering the mean to zero and re-scaling the variance to one.
The authors of the RMSNorm paper asked which parts of layer normalization were contributing to the final effect and found that re-scaling was far more important than re-centering. They proposed RMSNorm, which only re-scales by the Root Mean Square of the features and so skips the mean and variance computation that layer norm needs. We can implement it as follows.
$$\text{RMSNorm}(x)_i = \frac{x_i}{\text{RMS}(x)} \, \gamma_i, \qquad \text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{j=1}^{d} x_j^2 + \epsilon}$$
$\gamma$ is the trainable parameter we initialize with all ones and $\epsilon$ is a hyperparameter we use to prevent division by 0. This layer does not change the shape of the input.
class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5, device=None, dtype=None):
        super().__init__()
        self.d_model = d_model
        # Learnable per-feature gain, initialized to ones
        self.gamma_weights = nn.Parameter(torch.ones(d_model, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, x):
        # Mean of squares over the feature dimension: (..., d_model) -> (..., 1)
        squared_mean = torch.mean(x**2, dim=-1, keepdim=True)
        RMS = torch.sqrt(squared_mean + self.eps)
        return (x / RMS) * self.gamma_weights
Positional Encoding
Self-attention by itself is insensitive to token order. This means that, without extra information, the model does not know where in the sequence each token sits. Previous methods to solve this include absolute positional embeddings, where we add a positional vector to each embedding vector in our sequence that encodes the position of that specific token. Essentially you tell the model the exact position of each token in the sequence and expect it to infer the relative position of any pair of tokens. Another method introduced after this was relative positional embeddings, where we learn a relative position weight for each pair of positions in the sequence. Rotary Positional Encoding (RoPE) was introduced in 2021 and has since become the default for adding positional information to the embeddings in our model.
The Objective
The authors of RoPE started with a simple goal. If the attention calculation is a function of tokens and their positions, how can we find a function $f$ that captures the difference between the positions of two tokens in their key and query representations?
$$\langle f(q, m), f(k, n) \rangle = g(q, k, n - m)$$
This is a more formal representation of the problem, where $q$ and $k$ are the query and key vectors and $m$ and $n$ are their positions in the sequence. Suppose that $R_m$ is a matrix transformation contingent on the position $m$ of the token, so that $f(q, m) = R_m q$ and $f(k, n) = R_n k$. Their inner product during attention calculation then becomes:
$$(R_m q)^\top (R_n k) = q^\top R_m^\top R_n k$$
For this product to depend on the difference between the two positions we need $R_m^\top R_n$ to be some function of the difference $n - m$. The authors of the paper found that such a condition is achieved with rotation matrices! The full derivation of this process is well documented here.
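The key property can be checked numerically in two dimensions. In the sketch below, `rot(m * theta)` plays the role of the rotation matrix for position m, and the check is that its transpose times the rotation for position n equals the rotation for the offset n - m. The angle and the two positions are arbitrary values picked for the example.

```python
import math

def rot(angle):
    """2x2 rotation matrix for `angle` radians, as nested lists."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s], [s, c]]

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)] for i in range(2)]

def transpose(A):
    return [[A[j][i] for j in range(2)] for i in range(2)]

theta = 0.3      # arbitrary rotation angle per position step
m, n = 5, 12     # two token positions

lhs = matmul(transpose(rot(m * theta)), rot(n * theta))   # R_m^T R_n
rhs = rot((n - m) * theta)                                # R_{n-m}

max_diff = max(abs(lhs[i][j] - rhs[i][j]) for i in range(2) for j in range(2))
print(max_diff)  # zero up to floating point error
```

Only the offset n - m survives in the product, which is exactly what the attention score needs.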
How it works
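RoPE encodes the positional information for each token in its key and query vectors by applying a rotation based on the token's position in the input sequence. Imagine a two dimensional query vector $q = (q_1, q_2)$ for a token at position $m$. The rotation matrix transformation for such a vector will look like the following.
$$R_m q = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} q_1 \\ q_2 \end{pmatrix}$$
How do we extend this from 2 dimensions to d dimensions? We divide the vector into d/2 two dimensional pairs and then apply a rotation to each of them. This can be represented as a block-diagonal matrix transformation of the form:
$$R_m = \begin{pmatrix} R_m^{(0)} & & \\ & \ddots & \\ & & R_m^{(d/2-1)} \end{pmatrix}, \qquad R_m^{(i)} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}$$
$\theta_i = \Theta^{-2i/d}$, where $\Theta$ is a fixed base (the theta argument in the code below, commonly 10000), essentially gives each pair a different frequency of rotation, which reduces as you move away from the first pair. Because the slow pairs take a very long time to complete a full turn, different positions within a sequence do not end up with the same combination of rotations. If we apply this to each query and key vector and take their dot product, we observe that the result depends only on the difference between the rotations applied to the two vectors. This is great for the model to be able to learn and reason through the relative positions of the two!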
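The relative-position property can be seen directly in a small sketch. The helper `rope` below is a stand-alone illustration for a single 1-D vector, not the module implemented later, and the base of 10000 and the dimension of 8 are assumptions for the example.

```python
import torch

def rope(x, position, base=10000.0):
    """Rotate consecutive pairs of a 1-D vector by position-dependent angles."""
    d = x.shape[-1]
    # One frequency per pair: base^(-2i/d) for i = 0 .. d/2 - 1
    freqs = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
    angles = position * freqs
    cos, sin = torch.cos(angles), torch.sin(angles)
    x_even, x_odd = x[0::2], x[1::2]
    out = torch.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out

torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)

# Same offset (n - m = 4) at two very different absolute positions.
score_near = torch.dot(rope(q, 3), rope(k, 7))
score_far = torch.dot(rope(q, 103), rope(k, 107))
# A different offset (n - m = 9) for comparison.
score_other = torch.dot(rope(q, 3), rope(k, 12))

print(score_near.item(), score_far.item(), score_other.item())
```

The first two scores agree up to floating point error because the offset is the same, while the third differs because the offset changed.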
How to implement
The full transformation matrix we saw above is sparse, and multiplying by it directly would waste compute on zeros. If we write out the first few terms of the transformed vector we find that the same thing can be written as an element-wise operation. We can first compute the angles $m\theta_i$ (the position $m$ times the frequency $\theta_i$ of pair $i$) for a specific token and find the sin and cos of each. We then split the input vector into its even and odd indices and perform the following operation on each:
$$x'_{\text{even}} = x_{\text{even}} \cos(m\theta_i) - x_{\text{odd}} \sin(m\theta_i), \qquad x'_{\text{odd}} = x_{\text{even}} \sin(m\theta_i) + x_{\text{odd}} \cos(m\theta_i)$$
We can then interleave them back together: we pair them up along a new last dimension of the tensor and flatten it back into one to give us the transformed vector. We do this for all tokens in our sequence across all batches. This layer does not change the shape of the input.
class RoPE(nn.Module):
    def __init__(self, theta, d_k, max_seq_len, device):
        super().__init__()
        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        self.device = device
        # One frequency per pair of dimensions: theta^(-2i/d_k) for i = 0 .. d_k/2 - 1
        powers = torch.arange(0, d_k, 2, device=device) / d_k
        freqs = 1.0 / (theta ** powers)
        # angles[m, i] = m * freq_i, shape (max_seq_len, d_k / 2)
        t = torch.arange(max_seq_len, device=device)
        angles = torch.outer(t, freqs)
        # Precomputed tables, kept out of the state dict
        self.register_buffer("cos", torch.cos(angles), persistent=False)
        self.register_buffer("sin", torch.sin(angles), persistent=False)

    def forward(self, x, token_positions):
        # Look up the rows for the requested positions
        cos = self.cos[token_positions]
        sin = self.sin[token_positions]
        if token_positions.ndim == 1:
            # Add a leading dimension so the tables broadcast over the batch
            cos = cos.unsqueeze(0)
            sin = sin.unsqueeze(0)
        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]
        x_rotated_even = x_even * cos - x_odd * sin
        x_rotated_odd = x_even * sin + x_odd * cos
        # Interleave even and odd results back into the original layout
        result = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
        return result.flatten(-2)
Multi-Head Attention
For a transformer to understand an input sequence it needs to understand how each token affects the meaning of other tokens. The classic example is the word "bank", which could mean a river bank or a financial institution. You can tell which one is meant from the words that surround it. This is exactly what attention helps with. Let's understand what happens in a single attention head:
We start by initializing three weight matrices W_Q, W_K and W_V with learnable parameters. They take our batch x seq_len x d input and transform it into batch x seq_len x d_k shaped collections of query and key vectors and a batch x seq_len x d_v shaped collection of value vectors. This gives us three representations for each token, each answering a question:
- Query: What am I looking for as a token?
- Key: What do I contain as a token?
- Value: What do I pass along if selected?
We then perform the attention calculation:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$
The term $QK^\top$ means we take the dot product of each query vector with all key vectors to give us a seq_len x seq_len shaped output. We then divide the values by $\sqrt{d_k}$: a dot product of two $d_k$-dimensional vectors with unit-variance entries has variance $d_k$, so this scaling brings the variance of the scores back to roughly that of the inputs. We then apply a softmax across each row to get values that add up to one. These are the attention weights (how much each query attends to each key). The causal mask is applied before softmax so future positions get zero weight. We then do a matrix multiplication with the value vectors $V$. This essentially means we rewrite each token's representation as a weighted sum of the value vectors of all tokens in the sequence, where the weights are the attention weights.
For next token prediction tasks, we use causal masking where we prevent tokens from attending to future tokens, because at inference time they will not see any future tokens. This is applied by setting all values above the diagonal of the attention score matrix to negative infinity so that softmax converts them to zeroes.
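A tiny sketch of the masking step on its own may help. The scores are all set to zero here purely so the effect of the mask is easy to read; real scores would come from the query and key vectors.

```python
import torch

seq_len = 4
scores = torch.zeros(seq_len, seq_len)  # pretend every query/key score is equal

# True above the diagonal marks the "future" positions to hide.
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
print(weights)
# Row i spreads its weight evenly over positions 0..i and gives 0 to the rest.
```

Because the diagonal is never masked, every row keeps at least one finite score and the softmax stays well defined.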
def scaled_dot_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    # (..., seq_len, d_k) @ (..., d_k, seq_len) -> (..., seq_len, seq_len)
    numerator = Q @ K.transpose(-2, -1)
    term = numerator / math.sqrt(d_k)
    if mask is not None:
        # Positions where the mask is True are hidden from the softmax
        term = term.masked_fill(mask, float('-inf'))
    softmax = torch.softmax(term, dim=-1)
    return softmax @ V
When we implement multi-head attention, we have multiple such attention heads and split the embedding of each token evenly between them. In this case we let d_k = d_v = d / num_heads (with d the model width), and each head does the exact same computation as above on its own slice of the projected query, key and value vectors. We then initialize a fourth weight matrix W_O that takes the concatenated head outputs, of shape batch x seq_len x (num_heads d_v), and projects them with learnable weights to a batch x seq_len x d output. We use multiple heads so that they can learn different kinds of relationships between tokens like semantic similarity, syntactic structure, etc.
In the following implementation we create the weights as instances of our Linear class. There are multiple ways to rearrange the projected tensors but I believe einops is one of the cleanest ones. We go from batch x seq_len x (num_heads d_k) to batch x num_heads x seq_len x d_k, and since PyTorch applies matrix multiplication to the last two dimensions, this lets us process all heads concurrently. If we are using RoPE, we apply its forward pass to the query and key vectors so that their inner product captures the relative position.
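The Linear class used by the attention code is not shown in this section. A minimal sketch that is consistent with how it is called (in_features, out_features, optional device and dtype) could look like the following. The missing bias and the truncated-normal initialization are assumptions modelled on the Embedding layer, not the original implementation.

```python
import math
import torch
import torch.nn as nn

class Linear(nn.Module):
    """Bias-free linear layer y = x @ W^T (a sketch under stated assumptions)."""

    def __init__(self, in_features, out_features, device=None, dtype=None):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))
        # Assumed init: Xavier/Glorot-style std with truncation at +/- 3 std.
        std = math.sqrt(2 / (in_features + out_features))
        nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)

    def forward(self, x):
        # Acts on the last dimension, so leading batch dimensions are preserved.
        return x @ self.weight.T

layer = Linear(in_features=8, out_features=4)
y = layer(torch.randn(2, 5, 8))
print(y.shape)  # torch.Size([2, 5, 4])
```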
from einops import rearrange

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, theta=None, max_seq_len=None, use_rope=False):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.theta = theta
        self.max_seq_len = max_seq_len
        self.use_rope = use_rope
        self.d_k = d_model // num_heads
        self.W_Q = Linear(in_features=d_model, out_features=(self.d_k * num_heads))
        self.W_K = Linear(in_features=d_model, out_features=(self.d_k * num_heads))
        self.W_V = Linear(in_features=d_model, out_features=(self.d_k * num_heads))
        self.W_O = Linear(in_features=d_model, out_features=d_model)
        if use_rope:
            # Buffers start on CPU and follow the module when it is moved with .to(device)
            self.rope = RoPE(theta, self.d_k, max_seq_len, device="cpu")

    def forward(self, x):
        seq_len = x.shape[-2]
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        # Split the last dimension into heads: (b, n, h * d_k) -> (b, h, n, d_k)
        Q = rearrange(Q, "b n (num_heads d_k) -> b num_heads n d_k", num_heads = self.num_heads)
        K = rearrange(K, "b n (num_heads d_k) -> b num_heads n d_k", num_heads = self.num_heads)
        V = rearrange(V, "b n (num_heads d_k) -> b num_heads n d_k", num_heads = self.num_heads)
        if self.use_rope:
            # Rotate queries and keys only; values carry no positional rotation
            token_positions = torch.arange(seq_len, device=x.device)
            Q = self.rope(Q, token_positions)
            K = self.rope(K, token_positions)
        # Causal mask: True above the diagonal marks future positions
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        attention = scaled_dot_attention(Q, K, V, mask)
        # Merge the heads back: (b, h, n, d_k) -> (b, n, h * d_k)
        attention = rearrange(attention, "b num_heads n d_k -> b n (num_heads d_k)", num_heads=self.num_heads)
        result = self.W_O(attention)
        return result
Feed Forward Layer
While attention produces a context-rich representation of each token, it is still just a linear combination of value vectors. This is why we have a Feed Forward Layer in every transformer block after the attention layer. The original transformer paper used a two layer MLP.
$$\text{FFN}(x) = W_2 \max(0, W_1 x + b_1) + b_2$$
This projects the inputs from $d$ to $d_{ff}$ dimensional space (with $d_{ff} = 4d$ in the original paper) and back, with a ReLU activation function in between. A good intuition for why FFNs are used is that they let us project the context-rich data coming from the attention layer to a higher dimension where the model can learn inherent patterns. Experiments with other activation functions and layer structures have improved the performance of FFNs. I have used the SwiGLU activation for my model, which combines two ideas:
- $\text{SiLU}(x) = x \cdot \sigma(x)$, which is a self-gated activation function with non-zero gradients for negative inputs.
- $\text{GLU}(x) = \sigma(W_1 x) \odot (W_3 x)$, which uses two branches instead of one branch with a nonlinearity. The $W_3$ branch is a plain linear transformation while the $W_1$ branch acts as a gate or mask with values between 0 and 1. The element-wise product then selectively lets information pass.
These combine to give us:
$$\text{FFN}(x) = W_2 \left( \text{SiLU}(W_1 x) \odot W_3 x \right)$$
This projects our vector embeddings from d dimensional space to 8/3 * d dimensions and back. Noam Shazeer's paper "GLU Variants Improve Transformer" compares all of these activation functions.
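The 8/3 factor can be motivated with a little arithmetic: SwiGLU has three weight matrices instead of two, so shrinking the hidden width from 4d to 8/3 * d keeps the parameter count the same as the ReLU MLP. The sketch below checks this for an example width of 768, which is chosen only for illustration, and ignores biases.

```python
d_model = 768                      # example width, chosen only for illustration

# Two-layer ReLU MLP with the usual 4x expansion: d -> 4d -> d.
d_ff_relu = 4 * d_model
relu_params = 2 * d_model * d_ff_relu

# SwiGLU uses three matrices (W1, W3: d -> d_ff and W2: d_ff -> d).
d_ff_swiglu = int(d_model * 8 / 3)
swiglu_params = 3 * d_model * d_ff_swiglu

print(d_ff_swiglu)                 # 2048
print(relu_params, swiglu_params)  # 4718592 4718592
```

When d_model is not divisible by 3 the int() truncation makes the match approximate rather than exact.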
def SiLU(x):
    return x * torch.sigmoid(x)

class FeedForward(nn.Module):
    def __init__(self, d_model, device=None, dtype=torch.float32):
        super().__init__()
        self.d_model = d_model
        # Hidden width of 8/3 * d_model, rounded down
        self.d_ff = int(d_model * 8 / 3)
        self.weight1 = Linear(in_features=d_model, out_features=self.d_ff, device=device, dtype=dtype)
        self.weight2 = Linear(in_features=self.d_ff, out_features=d_model, device=device, dtype=dtype)
        self.weight3 = Linear(in_features=d_model, out_features=self.d_ff, device=device, dtype=dtype)

    def forward(self, x):
        # Gate branch: SiLU(W1 x)
        first = self.weight1(x)
        first = SiLU(first)
        # Linear branch: W3 x
        second = self.weight3(x)
        intermediate = first * second
        # Project back to d_model: W2(...)
        return self.weight2(intermediate)
Transformer Block
Putting it all together we get a complete transformer block. Unlike the original transformer, which applies layer normalization after each attention and FFN layer (post-norm), we apply RMSNorm before each of the two layers (pre-norm). One possible reason this architecture works well is that we have an uninterrupted residual stream that maintains a gradient highway back to the input during backpropagation.
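The difference between the two orderings fits in two short functions. In this sketch `sublayer` stands in for attention or the FFN, and nn.LayerNorm is used as a generic stand-in for the normalization layer; both are placeholders for the illustration rather than the modules built in this post.

```python
import torch
import torch.nn as nn

def post_norm_step(x, sublayer, norm):
    # Original transformer: normalize after adding the residual.
    return norm(x + sublayer(x))

def pre_norm_step(x, sublayer, norm):
    # Pre-norm: only the sublayer input is normalized; x itself passes through untouched.
    return x + sublayer(norm(x))

torch.manual_seed(0)
d_model = 8
x = torch.randn(2, 5, d_model)
norm = nn.LayerNorm(d_model)
zero = lambda t: torch.zeros_like(t)   # a sublayer that contributes nothing

# With a silent sublayer, pre-norm is an exact identity while post-norm still rewrites x.
print(torch.equal(pre_norm_step(x, zero, norm), x))       # True
print(torch.allclose(post_norm_step(x, zero, norm), x))   # False
```

That untouched identity path in the pre-norm variant is the residual stream the paragraph above refers to.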
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, theta=None, max_seq_len=None, use_rope=False):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.theta = theta
        self.max_seq_len = max_seq_len
        self.use_rope = use_rope
        self.MHA = MultiHeadAttention(self.d_model, self.num_heads, self.theta, self.max_seq_len, self.use_rope)
        self.FFN = FeedForward(self.d_model)
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x):
        # Attention sub-block: normalize, attend, add back to the residual stream
        first_norm = self.norm1(x)
        MHA = self.MHA(first_norm)
        x = x + MHA
        # Feed-forward sub-block: same pattern
        second_norm = self.norm2(x)
        FFN = self.FFN(second_norm)
        return x + FFN
Transformer
We now take all of the components we have built to complete the transformer model. We start with the embedding module and then have num_layers transformer blocks. We have a normalization layer after these blocks and finally a linear layer that projects our processed inputs from batch x seq_len x d to batch x seq_len x vocab_size, so that each position in the input sequence now has logits for the model's prediction of the token that follows it. During generation we apply a softmax to these logits to get a probability distribution over the next token. This structure is specific to decoder-only models. When transformers are used for other tasks like sentiment analysis, the final linear layer can be structured differently.
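As a sketch of that last step, the snippet below turns a batch of logits into a next-token choice. The logits are random stand-ins with made-up sizes rather than real model output, and greedy selection and sampling are shown as two common options, not necessarily the decoding strategy used later.

```python
import torch

torch.manual_seed(0)
batch, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab_size)   # stand-in for the model's output

# Only the last position predicts the token that follows the whole prompt.
last_logits = logits[:, -1, :]                     # batch x vocab_size
probs = torch.softmax(last_logits, dim=-1)         # each row sums to 1

greedy_next = probs.argmax(dim=-1)                       # most likely token per sequence
sampled_next = torch.multinomial(probs, num_samples=1)   # or draw from the distribution

print(greedy_next.shape, sampled_next.shape)  # torch.Size([2]) torch.Size([2, 1])
```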
class Transformer(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, vocab_size, theta=None, max_seq_len=None, use_rope=False):
        super().__init__()
        self.num_layers = num_layers
        self.d_model = d_model
        self.num_heads = num_heads
        self.theta = theta
        self.max_seq_len = max_seq_len
        self.use_rope = use_rope
        self.vocab_size = vocab_size
        self.embedding = Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([TransformerBlock(d_model, num_heads, theta, max_seq_len, use_rope) for _ in range(num_layers)])
        self.norm = RMSNorm(d_model)
        self.linear = Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, token_list):
        # (batch, seq_len) -> (batch, seq_len, d_model)
        embeddings = self.embedding(token_list)
        for block in self.blocks:
            embeddings = block(embeddings)
        normalized = self.norm(embeddings)
        # (batch, seq_len, d_model) -> (batch, seq_len, vocab_size)
        logits = self.linear(normalized)
        return logits
Perfect! Now we have our architecture ready. The next section covers how we train the model: the MLE objective and cross-entropy loss.