The Transformer
Understanding tokenization, the transformer architecture, training, and attention variants from scratch.
Tokenization
LLMs need to understand words and characters to model language. When we give them an input sequence like "What is the capital of India?" they break it down into smaller parts called tokens. These tokens are then converted to vector embeddings, which the model learns and refines during training. But how do we know what these tokens are? This is defined by the tokenizer, which is the first component we train for any language model today. Think of it like defining a dictionary for the language model so that it knows which tokens make up the input sequence and can look up their meaning. This doesn't sound that hard, and I'm sure you can think of obvious approaches like:
- Character level encoding: Why not just assign a unique integer ID to each character? The model can break words down into their constituent characters and map each one to an embedding. This gives us a small, well-defined vocabulary but comes with some problems. The computation cost of the attention layer in a transformer scales quadratically with the number of input tokens, so if every character is its own token, we will have to process a lot of tokens! We also want the tokens in our vocabulary to carry semantic information that the model can work with, and individual characters carry very little.
- Word level encoding: Why not just assign a unique integer ID to every unique word we come across? This gives us information-dense tokens with reasonable attention costs. But there are still problems here. The vocabulary would be immense once we account for every word variant across languages, and the model would be unable to tokenize any word it did not see during training.
We want a tokenizer that keeps the vocabulary size reasonable while producing information-dense tokens. Earlier work proposed rule-based and morphological methods, but the main breakthrough came with the byte-level BPE tokenizer, popularized by OpenAI's GPT-2.
UTF-8 and BPE
Before we can understand the tokenizer, we need to agree on how characters are represented. The global standard here is Unicode, which assigns each character a unique integer ID called a codepoint. The letter a has codepoint 97, and the letter ê (e with circumflex, as in French) has codepoint 234. Unicode defines 150K+ such characters across languages, including emojis and symbols. UTF-8 is an encoding that represents each codepoint as a sequence of one to four bytes, so a word or sentence encoded with UTF-8 becomes a list of byte values (integers from 0 to 255).
text = "hello"
print(list(text.encode("utf-8")))  # [104, 101, 108, 108, 111]
print(list("ê".encode("utf-8")))   # [195, 170]: one codepoint, two bytes

The Byte Pair Encoding algorithm was introduced in 1994 as a data compression technique. It finds the most frequent pair of adjacent bytes in the input and replaces its occurrences with a new symbol that is not in the initial set. In the example input aaabdaaabac, aa is the most frequent pair, so we replace it with a new symbol, say Z = aa, giving the compressed form ZabdZabac. We repeat this until we reach a desired vocabulary size.
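A single BPE step on the example above can be sketched in a few lines of Python. This is an illustration on characters rather than bytes, and the helper names most_frequent_pair and merge_pair are hypothetical, not part of the tokenizer built later.

```python
from collections import Counter

def most_frequent_pair(symbols):
    """Count adjacent pairs and return the most common one (first seen wins ties)."""
    counts = Counter(zip(symbols, symbols[1:]))
    return max(counts, key=counts.get)

def merge_pair(symbols, pair, new_symbol):
    """Replace non-overlapping occurrences of `pair`, scanning left to right."""
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            out.append(new_symbol)
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return out

symbols = list("aaabdaaabac")
pair = most_frequent_pair(symbols)               # ('a', 'a') occurs 4 times
print("".join(merge_pair(symbols, pair, "Z")))   # ZabdZabac
```

Scanning left to right means a run like aaa becomes Za rather than aZ, which is the same convention the training code below follows.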
Implementation
We start with a text corpus that will be our training dataset. I have used TinyStories and OpenWebText. You can download them as a .txt file. Before we get to the algorithm we do some pre-tokenization steps that are helpful during training.
You might have special tokens in your corpus like <|endoftext|> that mark the end of a document. You don't want them to be tokenized, because they serve a special purpose and also define boundaries in your dataset. We therefore split our text into chunks between such special tokens and remove the special tokens themselves from the corpus.
We then do pre-tokenization which applies a special regex pattern introduced in the GPT-2 codebase that splits our chunks into smaller chunks based on pre-defined rules. Here is the pattern we use:
'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
This splits words, punctuation, numbers, etc. into appropriate chunks for tokenization. We do this because we don't want merges to cross word boundaries. We also want contractions like 't or 's to remain separate tokens instead of merging with their root words, because they carry a meaning that generalizes across words. This video from Andrej Karpathy does a great job of explaining this part.
import re
import regex  # third-party module; stdlib `re` does not support \p{L} / \p{N}

GPT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

with open(input_path, "r", encoding="utf-8") as f:
    text = f.read()

# Split on special tokens (dropping them) so no merge ever crosses a document boundary
special_pattern = "|".join(re.escape(token) for token in special_tokens)
processed = re.split(special_pattern, text) if special_tokens else [text]

pre_tokenized = []
for chunk in processed:
    for match in regex.finditer(GPT_PATTERN, chunk):
        pre_tokenized.append(match.group())

At this stage we have a list of pre-tokenized chunks of our string input. We convert each chunk to bytes by encoding it with UTF-8. We then initialize our vocab as a dictionary mapping the IDs 0 to 255 to their single-byte values, and merges as an empty list that will record every byte pair merge our algorithm makes. Next, we iterate over all adjacent pairs in every chunk and build a dictionary that tracks their frequencies. We pick the most frequent pair (breaking ties in favor of the lexicographically greater pair of byte strings), add a new ID to our vocab whose value is the two byte strings concatenated, and append the pair to merges. We then walk through the sequences again and replace every occurrence of the merged pair with the new ID. We repeat this until the vocabulary reaches the desired size.
from time import time
from tqdm import tqdm

self.vocab = {i: bytes([i]) for i in range(256)}
self.merges = []
sequences = [list(chunk.encode("utf-8")) for chunk in pre_tokenized]

num_merges = vocab_size - len(special_tokens) - 256
pbar = tqdm(total=num_merges, desc="Training (simple)", unit="merge")
while len(self.vocab) < vocab_size - len(special_tokens):
    if tracker:
        tracker.on_merge_start()
    # Recount every adjacent pair from scratch
    pairs = {}
    for token_bytes in sequences:
        for a, b in zip(token_bytes, token_bytes[1:]):
            pairs[(a, b)] = pairs.get((a, b), 0) + 1
    if not pairs:  # nothing left to merge
        break
    # Most frequent pair; ties go to the lexicographically greater byte strings
    best_pair = max(pairs, key=lambda k: (pairs[k], self.vocab[k[0]], self.vocab[k[1]]))
    new_id = len(self.vocab)
    self.vocab[new_id] = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]
    self.merges.append((best_pair[0], best_pair[1]))
    # Replace every occurrence of best_pair with new_id, scanning left to right
    for token_bytes in sequences:
        i = 0
        while i < len(token_bytes) - 1:
            if token_bytes[i] == best_pair[0] and token_bytes[i + 1] == best_pair[1]:
                token_bytes[i] = new_id
                del token_bytes[i + 1]
            else:
                i += 1
    merge_time = time() - tracker.merge_start_time if tracker else 0.0
    merge_idx = len(self.merges) - 1
    pbar.update(1)
pbar.close()

We then add our special tokens to the vocab as individual tokens. This is the entire training process, which returns the vocab and merges that we use for encoding and decoding with this tokenizer. During encoding, we take a string input, apply the same pre-tokenization steps, and convert each chunk to its UTF-8 bytes. We then look up every adjacent pair of the chunk in merges, pick the one with the smallest index (i.e. the earliest merge learned during training), and replace it with the merged token ID, repeating until no remaining pair appears in merges. For decoding, we look up each token ID in vocab, concatenate the resulting bytes, and decode them as UTF-8.
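The encoding and decoding steps just described can be sketched as follows. This is a minimal sketch for a single pre-tokenized chunk, not the post's tokenizer class; the names apply_merge, encode_chunk, and decode are hypothetical. It assumes merge i produced token ID 256 + i, which holds for the training loop above because new IDs are assigned as len(self.vocab) starting at 256 and special tokens are only added after training.

```python
def apply_merge(ids: list[int], pair: tuple[int, int], new_id: int) -> list[int]:
    """Replace non-overlapping occurrences of `pair`, scanning left to right."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def encode_chunk(chunk: str, merges: list[tuple[int, int]]) -> list[int]:
    # Assumption: merge i created token ID 256 + i (special tokens added after training)
    ranks = {pair: i for i, pair in enumerate(merges)}
    ids = list(chunk.encode("utf-8"))
    while len(ids) >= 2:
        # The adjacent pair that was learned earliest during training
        pair = min(zip(ids, ids[1:]), key=lambda p: ranks.get(p, float("inf")))
        if pair not in ranks:
            break  # no remaining pair was ever merged
        ids = apply_merge(ids, pair, 256 + ranks[pair])
    return ids

def decode(ids: list[int], vocab: dict[int, bytes]) -> str:
    # A token may hold part of a multi-byte character, so join bytes before decoding
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")

# Usage with a hand-made two-merge tokenizer
vocab = {i: bytes([i]) for i in range(256)}
merges = [(104, 101), (256, 256)]  # "h"+"e" -> 256, then "he"+"he" -> 257
vocab[256], vocab[257] = b"he", b"hehe"
print(encode_chunk("heheh", merges))          # [257, 104]
print(decode(encode_chunk("heheh", merges), vocab))  # heheh
```

A full encode would run the pre-tokenization regex first, call encode_chunk on each piece, and map special tokens directly to their reserved IDs.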
Optimization
While this implementation works, it is slow on large datasets, so it is worth optimizing parts of the code. Here are a couple of things you can try:
Process pooling for pre-tokenization: Since regex-based splitting of a large text file is CPU bound, we can open the file in binary mode and split it into num_workers chunks whose boundaries fall on a special token, so no document is cut in half. Multiple processes then pre-tokenize one chunk each, and we concatenate the results.
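The pool code relies on two helpers, find_chunk_boundaries and _pretokenize_chunk, that this post doesn't show. Below is a hypothetical sketch that matches the call signatures used there; it is not the author's implementation. Boundaries start evenly spaced and are each moved forward to the next occurrence of split_token, so every chunk except the first begins exactly at a special token. The third-party regex module is imported inside the worker function because stdlib re cannot handle \p{L}.

```python
import os
import re

GPT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

def find_chunk_boundaries(f, num_chunks: int, split_token: bytes, read_size: int = 4096) -> list[int]:
    """Evenly spaced byte offsets, each moved forward to the next split_token."""
    f.seek(0, os.SEEK_END)
    file_size = f.tell()
    step = max(file_size // num_chunks, 1)
    boundaries = [min(i * step, file_size) for i in range(num_chunks)] + [file_size]
    for bi in range(1, len(boundaries) - 1):
        pos = boundaries[bi]
        f.seek(pos)
        while True:
            block = f.read(read_size)
            if not block:  # no later split_token: this chunk runs to the end of the file
                boundaries[bi] = file_size
                break
            found = block.find(split_token)
            if found != -1:
                boundaries[bi] = pos + found
                break
            pos += len(block)
    return sorted(set(boundaries))  # drop duplicates when chunks collapse

def _pretokenize_chunk(input_path: str, special_pattern: str, start: int, end: int) -> list[str]:
    import regex  # third-party; needed for \p{L} / \p{N}
    with open(input_path, "rb") as f:
        f.seek(start)
        text = f.read(end - start).decode("utf-8", errors="ignore")
    pieces = re.split(special_pattern, text) if special_pattern else [text]
    return [m.group() for piece in pieces for m in regex.finditer(GPT_PATTERN, piece)]
```

Because each boundary lands at the start of split_token, a boundary never falls in the middle of a multi-byte UTF-8 character.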
import multiprocessing as mp
import re

special_pattern = "|".join(re.escape(t) for t in special_tokens) if special_tokens else ""
split_token = special_tokens[0].encode("utf-8") if special_tokens else b"\n"

# Byte offsets that start at split_token, so no document straddles two chunks
with open(input_path, "rb") as f:
    boundaries = find_chunk_boundaries(f, num_workers, split_token)

chunk_args = [
    (input_path, special_pattern, boundaries[i], boundaries[i + 1])
    for i in range(len(boundaries) - 1)
    if boundaries[i] < boundaries[i + 1]
]

# Each worker reads its own byte range and returns its list of pre-tokens
with mp.Pool(num_workers) as pool:
    chunk_results = pool.starmap(_pretokenize_chunk, chunk_args)

pre_tokenized = []
for result in chunk_results:
    pre_tokenized.extend(result)

Efficient merging: The approach above is naive because it recounts every pair from scratch on every merge. Instead, we can maintain a dictionary of pair counts and update it incrementally. When we merge the pair (a, b) into a new token z, we decrement the counts of (a, b) and of its neighboring pairs (x, a) and (b, y), and increment the new neighboring pairs (x, z) and (z, y).
# Count every adjacent pair once, up front
pairs: dict[tuple[int, int], int] = {}
for seq in sequences:
    for a, b in zip(seq, seq[1:]):
        pairs[(a, b)] = pairs.get((a, b), 0) + 1

num_merges = vocab_size - len(special_tokens) - 256
pbar = tqdm(total=num_merges, desc="Training (efficient)", unit="merge")
while len(self.vocab) < vocab_size - len(special_tokens):
    if tracker:
        tracker.on_merge_start()
    best_pair = max(
        (p for p in pairs if pairs[p] > 0),
        key=lambda k: (pairs[k], self.vocab[k[0]], self.vocab[k[1]]),
        default=None,
    )
    if best_pair is None:  # every pair has been merged away
        break
    A, B = best_pair
    best_count = pairs[best_pair]
    new_id = len(self.vocab)
    self.vocab[new_id] = self.vocab[A] + self.vocab[B]
    self.merges.append((A, B))
    for seq in sequences:
        i = 0
        while i < len(seq) - 1:
            if seq[i] == A and seq[i + 1] == B:
                # Remove the pairs destroyed by this occurrence
                if i > 0:
                    pairs[(seq[i - 1], A)] -= 1
                if i + 2 < len(seq):
                    pairs[(B, seq[i + 2])] -= 1
                pairs[(A, B)] -= 1
                seq[i] = new_id
                del seq[i + 1]
                # Add the pairs created by the new token
                if i > 0:
                    pairs[(seq[i - 1], new_id)] = pairs.get((seq[i - 1], new_id), 0) + 1
                if i + 1 < len(seq):
                    pairs[(new_id, seq[i + 1])] = pairs.get((new_id, seq[i + 1]), 0) + 1
            else:
                i += 1
    pbar.update(1)
pbar.close()

When training on the TinyStories validation set (~22MB), the naive method took 373.58 seconds while the optimized code took 103.37 seconds, and both produced the same outputs. That is a 3.6x speedup, and it should grow with larger datasets, since the fixed cost of spinning up worker processes becomes negligible compared with the time each process spends working.
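One way to check the incremental bookkeeping is to compare the maintained counts against a full recount after each merge, including overlapping runs like aaaa. The sketch below mirrors the update logic above in a standalone function; count_pairs and merge_with_counts are hypothetical names used only for this check.

```python
from collections import Counter

def count_pairs(seqs):
    """Full recount of adjacent pairs across all sequences."""
    return Counter(p for s in seqs for p in zip(s, s[1:]))

def merge_with_counts(seqs, pairs, A, B, new_id):
    """Merge (A, B) -> new_id in place while keeping `pairs` up to date."""
    for seq in seqs:
        i = 0
        while i < len(seq) - 1:
            if seq[i] == A and seq[i + 1] == B:
                if i > 0:
                    pairs[(seq[i - 1], A)] -= 1
                if i + 2 < len(seq):
                    pairs[(B, seq[i + 2])] -= 1
                pairs[(A, B)] -= 1
                seq[i] = new_id
                del seq[i + 1]
                if i > 0:
                    pairs[(seq[i - 1], new_id)] += 1  # Counter defaults missing keys to 0
                if i + 1 < len(seq):
                    pairs[(new_id, seq[i + 1])] += 1
            else:
                i += 1

seqs = [list(b"aaaa"), list(b"xaaab"), list(b"abab")]
pairs = count_pairs(seqs)
merge_with_counts(seqs, pairs, ord("a"), ord("a"), 256)
print(seqs)  # [[256, 256], [120, 256, 97, 98], [97, 98, 97, 98]]
nonzero = {p: c for p, c in pairs.items() if c != 0}
print(nonzero == dict(count_pairs(seqs)))  # True
```

Filtering out zero counts before comparing matters because the incremental dictionary keeps pairs whose count has dropped to zero, while a fresh recount never contains them.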
Architecture
The best way to understand any model structure is to step through the forward pass. I have implemented most of the model structure and training code from scratch following the CS336 curriculum, only using nn.Parameter() and the parent nn.Module for my classes.
Our input to a transformer is a string like "What is the capital of India?", which we encode with the tokenizer we trained in the previous step to get a list of token IDs. Thanks to the transformer architecture and parallel hardware, we can process multiple inputs simultaneously in batches, so we represent the input as a batch x seq_len matrix where each row is a different sequence. Now let's implement each part of the model step by step and understand how it works.
Embedding Layer
While token IDs tell our model which tokens make up the input, they don't let the model learn what those tokens mean. This is why we map each unique token to an embedding, a vector in d-dimensional space. The embedding is a tunable representation of the token that the model updates as it learns. Here is how we implement the module:
- We create a weight matrix of shape vocab_size x d such that each row contains the embedding for a unique token, and initialize it with Xavier/Glorot initialization (here, a truncated normal).
- During the forward pass, we simply index the matrix with the token IDs and return their embeddings. Our input is batch x seq_len and the output is batch x seq_len x d.
- During the backward pass, only the rows for tokens that appear in the batch receive gradients, and if the same token appears multiple times, its gradients accumulate.
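The gradient behavior in the last bullet can be checked directly with PyTorch indexing. This small sketch uses a bare nn.Parameter rather than the Embedding class and sums the embeddings as a stand-in loss.

```python
import torch

torch.manual_seed(0)
vocab_size, d = 5, 3
weights = torch.nn.Parameter(torch.randn(vocab_size, d))

token_ids = torch.tensor([[1, 1, 3]])  # batch x seq_len, token 1 appears twice
embeddings = weights[token_ids]        # batch x seq_len x d
embeddings.sum().backward()            # stand-in loss: every entry has gradient 1

print(embeddings.shape)  # torch.Size([1, 3, 3])
print(weights.grad)      # row 1 is all 2s, row 3 all 1s, rows 0, 2, 4 are zero
```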
import math
import torch
import torch.nn as nn

class Embedding(nn.Module):
    def __init__(self, vocab_size, d_model, device=None, dtype=torch.float32):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.weights = nn.Parameter(torch.empty(vocab_size, d_model, device=device, dtype=dtype))
        # Xavier/Glorot variance 2 / (fan_in + fan_out), truncated at 3 standard deviations
        std = math.sqrt(2 / (vocab_size + d_model))
        torch.nn.init.trunc_normal_(self.weights, mean=0.0, std=std, a=-3 * std, b=3 * std)

    def forward(self, token_ids):
        # (batch, seq_len) integer IDs -> (batch, seq_len, d_model) embeddings
        return self.weights[token_ids]

Normalization
Before we implement this layer, it is important to understand the problem well. The output of every neuron in the linear layers we implement has the form $y = \sum_{i=1}^{n} w_i x_i$. Sums accumulate variance: if the weights and inputs are independent with zero mean, every weight has variance $\sigma_w^2$, and every input has variance $\sigma_x^2$, then the neuron output has variance $n \sigma_w^2 \sigma_x^2$. Xavier initialization sets $\sigma_w^2 = 2 / (n_{in} + n_{out})$ so that the factor $n \sigma_w^2$ starts out close to one, but once the weights are updated during training we can no longer guarantee that stability. In practice each layer multiplies the variance by some factor that drifts away from one.
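The attention and feed-forward code later in this post uses a Linear class that isn't shown in this section. Here is a hypothetical sketch consistent with its constructor calls (in_features, out_features, device, dtype), assuming no bias and the same truncated Xavier initialization as the Embedding module. Each output is exactly the neuron sum $y = \sum_i w_i x_i$ from above.

```python
import math
import torch
import torch.nn as nn

class Linear(nn.Module):
    """Bias-free linear layer y = x W^T with truncated Xavier init (a sketch)."""

    def __init__(self, in_features, out_features, device=None, dtype=torch.float32):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))
        std = math.sqrt(2 / (in_features + out_features))
        nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)

    def forward(self, x):
        # (..., in_features) @ (in_features, out_features) -> (..., out_features)
        return x @ self.weight.T

layer = Linear(64, 16)
print(layer(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 16])
```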
If this factor is greater than one, activations grow from layer to layer and can explode, which saturates downstream operations like softmax and activation functions. If the factor is less than one, activations shrink and gradients can vanish. Shifting distributions also hurt deeper layers: if the output distribution of earlier layers keeps changing, the later layers are essentially modeling a moving target. Fortunately, we have several normalization techniques to solve this.
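The compounding effect is easy to see numerically. This sketch pushes unit-variance inputs through five random bias-free layers whose weights have variance gain / n, so each layer multiplies the variance by roughly gain; variance_after is a hypothetical helper for illustration.

```python
import torch

def variance_after(num_layers: int, gain: float, n: int = 512, batch: int = 1024) -> float:
    """Push unit-variance inputs through random layers with Var(w) = gain / n."""
    torch.manual_seed(0)
    x = torch.randn(batch, n)
    for _ in range(num_layers):
        w = torch.randn(n, n) * (gain / n) ** 0.5
        x = x @ w
    return x.var().item()

for gain in (0.5, 1.0, 2.0):
    print(gain, round(variance_after(5, gain), 3))  # roughly gain ** 5: ~0.03, ~1, ~32
```

A factor of just 2 per layer already gives a 32x blow-up after five layers, which is why normalization inside every block is useful.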
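The original transformer uses Layer Normalization, which normalizes across all features of a sample and then applies an affine transformation with learnable parameters: $\mathrm{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \odot \gamma + \beta$, where $\mu$ and $\sigma^2$ are the mean and variance of the features of $x$. It essentially re-centers the mean to zero and re-scales the variance to one.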
The authors of this paper asked which parts of layer normalization actually contribute to its effect and argued that re-scaling matters far more than re-centering. They proposed RMSNorm, which only re-scales by the root mean square of the features and so skips computing the mean:
$$\mathrm{RMSNorm}(x)_i = \frac{x_i}{\mathrm{RMS}(x)}\,\gamma_i, \qquad \mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}$$
Here $\gamma$ is a trainable parameter initialized to all ones, and $\epsilon$ is a small hyperparameter that prevents division by zero. This layer does not change the shape of the input.
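A quick worked example of the formula, with $d = 2$, $x = (3, 4)$, $\gamma = (1, 1)$, and $\epsilon$ ignored:
$$\mathrm{RMS}(x) = \sqrt{\tfrac{3^2 + 4^2}{2}} = \sqrt{12.5} \approx 3.536, \qquad \mathrm{RMSNorm}(x) \approx (0.849,\ 1.131)$$
The output has a mean square of $(0.849^2 + 1.131^2)/2 \approx 1$, while its mean ($\approx 0.99$) is not re-centered to zero as it would be with LayerNorm.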
class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5, device=None, dtype=None):
        super().__init__()
        self.d_model = d_model
        self.gamma_weights = nn.Parameter(torch.ones(d_model, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, x):
        # Mean of squares over the feature dimension, kept for broadcasting
        squared_mean = torch.mean(x**2, dim=-1, keepdim=True)
        rms = torch.sqrt(squared_mean + self.eps)
        return (x / rms) * self.gamma_weights

Positional Encoding
Without positional information, self-attention is permutation equivariant: shuffle the input tokens and the outputs are shuffled in exactly the same way, so the model has no built-in notion of order. An earlier fix was absolute positional embeddings, where we add to each token embedding a vector that encodes that token's position. Essentially, you tell the model the exact position of each token and expect it to infer the relative position of any pair of tokens. A later method was relative positional embeddings, where we learn a representation for each relative distance between two positions. Rotary Positional Encoding (RoPE) was introduced in 2021 and has since become the default way to add positional information in most modern language models.
The Objective
The authors of RoPE started with a simple goal. If the attention calculation is a function of tokens and their positions, how can we find a function that captures the difference between positions of two tokens in their key and query representations?
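More formally, let $q_m = f_q(x_m, m)$ and $k_n = f_k(x_n, n)$ be the query and key for tokens $x_m$ and $x_n$, where $m$ and $n$ are their positions in the sequence. We want their inner product to depend only on the tokens and their relative position:
$$\langle f_q(x_m, m), f_k(x_n, n) \rangle = g(x_m, x_n, n - m)$$
Suppose that $R_m$ is a matrix transformation that depends on the position of the token, so that $q_m = R_m W_q x_m$ and $k_n = R_n W_k x_n$. Their inner product during attention calculation then becomes:
$$q_m^\top k_n = x_m^\top W_q^\top R_m^\top R_n W_k x_n$$
For this product to depend only on the difference between the two positions, we need $R_m^\top R_n$ to be some function of $n - m$. The authors of the paper found that rotation matrices satisfy this condition, since $R_m^\top R_n = R_{n-m}$! The full derivation of this process is well documented here.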
How it works
RoPE encodes the positional information for each token in its key and query vectors by rotating them by an angle that depends on the token's position in the input sequence. Imagine a two-dimensional query vector $q = (q_0, q_1)$ for a token at position $m$. The rotation for such a vector looks like the following:
$$R_m q = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} q_0 \\ q_1 \end{pmatrix}$$
How do we extend this from 2 dimensions to d dimensions? We divide the vector into d/2 two-dimensional pairs and rotate each pair by its own angle. This can be represented as a block-diagonal matrix transformation:
$$R_m = \begin{pmatrix} R(m\theta_0) & & \\ & \ddots & \\ & & R(m\theta_{d/2-1}) \end{pmatrix}, \qquad \theta_i = \Theta^{-2i/d}, \quad i = 0, \dots, d/2 - 1$$
where $R(\cdot)$ is the 2 x 2 rotation above and $\Theta$ is a base hyperparameter (commonly 10000; it is the theta argument in the code below).
$\theta_i$ gives each pair a different rotation frequency, which decreases as you move away from the first pair. The slowly rotating pairs take many positions to complete a full turn, so positions that are far apart don't end up with the same encoded rotation. If we apply this to each query and key vector and take their dot product, only the difference between the two rotations survives in the inner product. This is great for the model, because it can learn to reason about the relative positions of the two tokens!
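The relative-position property can be verified numerically by building $R_m$ explicitly. This sketch assumes base 10000 and float64 precision; rope_matrix and rope_score are hypothetical helpers, not the RoPE module below.

```python
import math
import torch

def rope_matrix(m: int, d: int, base: float = 10000.0) -> torch.Tensor:
    """Block-diagonal d x d matrix rotating pair i by angle m * base^(-2i/d)."""
    blocks = []
    for i in range(d // 2):
        angle = m * base ** (-2 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        blocks.append(torch.tensor([[c, -s], [s, c]], dtype=torch.float64))
    return torch.block_diag(*blocks)

def rope_score(q, k, m, n):
    """Dot product of q rotated to position m with k rotated to position n."""
    d = q.shape[-1]
    return (rope_matrix(m, d) @ q) @ (rope_matrix(n, d) @ k)

torch.manual_seed(0)
d = 8
q = torch.randn(d, dtype=torch.float64)
k = torch.randn(d, dtype=torch.float64)

# Same offset n - m = 4 at different absolute positions gives the same score
print(rope_score(q, k, 3, 7).item(), rope_score(q, k, 10, 14).item())
# R_m^T R_n = R_{n-m}
print(torch.allclose(rope_matrix(3, d).T @ rope_matrix(7, d), rope_matrix(4, d)))  # True
```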
How to implement
The full transformation matrix we saw above is sparse, so multiplying by it directly wastes computation. If we write out the first few terms of the transformed vector, we find that the same thing can be written as an element-wise operation. We first compute the angle $m\theta_i$ for every pair of a token at position $m$ and take its sin and cos. We then split the input vector into its even- and odd-indexed entries and compute, for each pair $i$:
$$x'_{2i} = x_{2i}\cos(m\theta_i) - x_{2i+1}\sin(m\theta_i), \qquad x'_{2i+1} = x_{2i}\sin(m\theta_i) + x_{2i+1}\cos(m\theta_i)$$
We then stack the two results along a new last dimension and flatten it, which interleaves them back into the original even/odd order and gives us the transformed vector. We do this for all tokens in our sequence across all batches. This layer does not change the shape of the input.
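To confirm that the element-wise form matches the sparse matrix, this sketch compares both on a random vector. It assumes base 10000; rope_elementwise and rope_matrix are hypothetical helpers written for the check, separate from the RoPE class.

```python
import math
import torch

def rope_elementwise(x: torch.Tensor, m: int, base: float = 10000.0) -> torch.Tensor:
    """Rotate the last dimension of x as position m using only element-wise ops."""
    d = x.shape[-1]
    i = torch.arange(d // 2, dtype=x.dtype)
    angle = m * base ** (-2 * i / d)  # m * theta_i for each pair
    cos, sin = torch.cos(angle), torch.sin(angle)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack([x_even * cos - x_odd * sin, x_even * sin + x_odd * cos], dim=-1)
    return rotated.flatten(-2)  # interleave back to (e0, o0, e1, o1, ...)

def rope_matrix(m: int, d: int, base: float = 10000.0) -> torch.Tensor:
    """Dense block-diagonal reference: the sparse matrix we want to avoid building."""
    blocks = []
    for i in range(d // 2):
        angle = m * base ** (-2 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        blocks.append(torch.tensor([[c, -s], [s, c]], dtype=torch.float64))
    return torch.block_diag(*blocks)

torch.manual_seed(0)
x = torch.randn(16, dtype=torch.float64)
print(torch.allclose(rope_elementwise(x, 5), rope_matrix(5, 16) @ x))  # True
```

The element-wise version needs O(d) work per token instead of O(d^2) for the dense matrix product.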
class RoPE(nn.Module):
    def __init__(self, theta, d_k, max_seq_len, device=None):
        super().__init__()
        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        self.device = device
        # theta_i = theta^(-2i/d_k) for i = 0, ..., d_k/2 - 1
        powers = torch.arange(0, d_k, 2, device=device) / d_k
        freqs = 1.0 / (theta ** powers)
        # angles[m, i] = m * theta_i, precomputed for every position
        t = torch.arange(max_seq_len, device=device)
        angles = torch.outer(t, freqs)
        self.register_buffer("cos", torch.cos(angles), persistent=False)
        self.register_buffer("sin", torch.sin(angles), persistent=False)

    def forward(self, x, token_positions):
        # x: (..., seq_len, d_k); token_positions broadcastable against x's leading dims
        cos = self.cos[token_positions]
        sin = self.sin[token_positions]
        if token_positions.ndim == 1:
            cos = cos.unsqueeze(0)
            sin = sin.unsqueeze(0)
        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]
        x_rotated_even = x_even * cos - x_odd * sin
        x_rotated_odd = x_even * sin + x_odd * cos
        # Stack into (..., d_k/2, 2) and flatten to interleave even/odd entries again
        result = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
        return result.flatten(-2)

Multi-Head Attention
For a transformer to understand an input sequence, it needs to understand how each token affects the meaning of the other tokens. The classic example is the word bank, which could mean a river bank or a financial institution; you can tell which one is meant from the words around it. This is exactly what attention helps with. Let's understand what happens in a single attention head:
We start by initializing three matrices of learnable parameters, $W_Q$, $W_K$, and $W_V$. They take our batch x seq_len x d input and project it to batch x seq_len x d_k queries, batch x seq_len x d_k keys, and batch x seq_len x d_v values. This gives us three representations of each token that answer the questions:
- Query: What am I looking for as a token?
- Key: What do I contain as a token?
- Value: What do I pass along if selected?
We then compute attention:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
The term $QK^\top$ takes the dot product of each query vector with every key vector, giving a seq_len x seq_len matrix of scores. We divide the scores by $\sqrt{d_k}$ because the dot product of two vectors with independent, unit-variance entries has variance $d_k$; scaling brings it back to one and keeps the softmax from saturating. We then apply a softmax across each row so that the values in every row add up to one. These are the attention weights (how much each query attends to each key). Finally, we multiply by the value vectors $V$. This rewrites each token's representation as a weighted sum of the value vectors of all tokens in the sequence, where the weights are the attention weights.
For next token prediction, we use causal masking, which prevents tokens from attending to future tokens, because no future tokens are available at inference time. We apply it before the softmax by setting every entry above the diagonal of the score matrix to negative infinity, so that the softmax turns them into zero weights.
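A small check of the masking step: this sketch builds the same upper-triangular boolean mask used in the attention code and applies it to random scores before the softmax.

```python
import torch

torch.manual_seed(0)
seq_len = 4
scores = torch.randn(seq_len, seq_len)
# True above the diagonal: position i may not look at positions j > i
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)

print(mask.int())
print(weights.sum(dim=-1))  # every row sums to 1
print(weights[0])           # the first token can only attend to itself, so this row is one-hot
```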
def scaled_dot_attention(Q, K, V, mask=None):
    # Q, K: (..., seq_len, d_k); V: (..., seq_len, d_v); mask: True where attention is blocked
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ V

When we implement multi-head attention, we have several such attention heads and split the embedding of each token evenly between them. We let d_k = d_v = d / num_heads (with d the model width), and each head does exactly the same computation as above on its own slice of the embedding. We then concatenate the heads' outputs back into a batch x seq_len x (num_heads · d_v) tensor and apply a fourth learnable weight matrix, $W_O$, to project it to a batch x seq_len x d output. Using multiple heads gives the model room for different heads to learn different kinds of relationships between tokens, like semantic similarity or syntactic structure.
In the following implementation, the weights are instances of our Linear class. There are multiple ways to rearrange the projected tensors, but I find einops one of the cleanest. We go from batch x seq_len x (num_heads · d_k) to batch x num_heads x seq_len x d_k, and since PyTorch's batched matrix multiplication operates on the last two dimensions and broadcasts over the leading ones, this lets us process all heads concurrently. If we are using RoPE, we apply it to the queries and keys (not the values) so that their inner product captures the relative position.
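For readers without einops, the head split can be written with plain view/transpose. This sketch shows the PyTorch equivalent of the two rearrange patterns used below; the equivalence follows from einops treating num_heads as the outer factor of the grouped (num_heads d_k) axis.

```python
import torch

b, n, num_heads, d_k = 2, 5, 4, 8
x = torch.randn(b, n, num_heads * d_k)

# Same as rearrange(x, "b n (num_heads d_k) -> b num_heads n d_k", num_heads=num_heads)
heads = x.view(b, n, num_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 5, 8])

# Head h sees the feature slice [h * d_k, (h + 1) * d_k) of every token
assert torch.equal(heads[:, 1], x[..., d_k:2 * d_k])

# Inverse, as in rearrange(..., "b num_heads n d_k -> b n (num_heads d_k)")
merged = heads.transpose(1, 2).reshape(b, n, num_heads * d_k)
assert torch.equal(merged, x)
```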
from einops import rearrange

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, theta=None, max_seq_len=None, use_rope=False):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.theta = theta
        self.max_seq_len = max_seq_len
        self.use_rope = use_rope
        self.d_k = d_model // num_heads
        self.W_Q = Linear(in_features=d_model, out_features=(self.d_k * num_heads))
        self.W_K = Linear(in_features=d_model, out_features=(self.d_k * num_heads))
        self.W_V = Linear(in_features=d_model, out_features=(self.d_k * num_heads))
        self.W_O = Linear(in_features=(self.d_k * num_heads), out_features=d_model)
        if use_rope:
            # Registered buffers follow the module when it is moved with .to(device)
            self.rope = RoPE(theta, self.d_k, max_seq_len, device="cpu")

    def forward(self, x):
        seq_len = x.shape[-2]
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        # Split the feature dimension into heads: (b, n, h * d_k) -> (b, h, n, d_k)
        Q = rearrange(Q, "b n (num_heads d_k) -> b num_heads n d_k", num_heads=self.num_heads)
        K = rearrange(K, "b n (num_heads d_k) -> b num_heads n d_k", num_heads=self.num_heads)
        V = rearrange(V, "b n (num_heads d_k) -> b num_heads n d_k", num_heads=self.num_heads)
        if self.use_rope:
            token_positions = torch.arange(seq_len, device=x.device)
            Q = self.rope(Q, token_positions)
            K = self.rope(K, token_positions)
        # True above the diagonal: position i cannot attend to positions j > i
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        attention = scaled_dot_attention(Q, K, V, mask)
        # Concatenate heads back: (b, h, n, d_k) -> (b, n, h * d_k)
        attention = rearrange(attention, "b num_heads n d_k -> b n (num_heads d_k)", num_heads=self.num_heads)
        return self.W_O(attention)

Feed Forward Layer
While attention produces a context-rich representation of each token, its output is still just a weighted (linear) combination of value vectors. This is why every transformer block has a Feed Forward Layer after the attention layer. The original transformer paper used a two-layer MLP:
$$\mathrm{FFN}(x) = \max(0,\ xW_1 + b_1)\,W_2 + b_2$$
This projects the inputs from $d$ to $d_{ff} = 4d$ dimensions and back, with a ReLU activation in between. A good intuition for why FFNs help is that they project the context-rich output of the attention layer into a higher-dimensional space, where the nonlinearity lets the model learn patterns that a linear map cannot. Experiments with other activation functions and layer structures have improved the performance of FFNs. I have used the SwiGLU activation for my model, which combines two ideas:
- $\mathrm{SiLU}(x) = x \cdot \sigma(x)$, a self-gated activation function that, unlike ReLU, has non-zero gradients for negative inputs.
- $\mathrm{GLU}(x) = \sigma(xW_1) \odot xW_2$, which uses two branches instead of a single branch with a nonlinearity. One branch is a plain linear transformation, while the other acts as a gate with values between 0 and 1, and their element-wise product selectively lets information pass.
These combine to give us:
$$\mathrm{FFN}_{\mathrm{SwiGLU}}(x) = W_2\left(\mathrm{SiLU}(W_1 x) \odot W_3 x\right)$$
This projects our embeddings from $d$ to $d_{ff} = \frac{8}{3}d$ dimensions and back. The factor 8/3 keeps the parameter count equal to the original 4d FFN: three $d \times \frac{8}{3}d$ matrices have $8d^2$ parameters, the same as two $d \times 4d$ matrices. There's a paper from Noam Shazeer, GLU Variants Improve Transformer, that compares all of these activation functions here.
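As a runnable stand-in for the class below, this sketch builds the same SwiGLU layer with torch.nn.Linear(bias=False) in place of the post's custom Linear, and checks the 8d² parameter count for d = 96 (where 8d/3 is a whole number).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """Sketch of the gated FFN using nn.Linear(bias=False) instead of the custom Linear."""

    def __init__(self, d_model: int):
        super().__init__()
        d_ff = int(d_model * 8 / 3)
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate branch, passed through SiLU
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # plain linear branch
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # project back down to d_model

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

d = 96
ffn = SwiGLU(d)
print(ffn(torch.randn(2, 7, d)).shape)                      # torch.Size([2, 7, 96])
print(sum(p.numel() for p in ffn.parameters()), 8 * d * d)  # 73728 73728
```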
def SiLU(x):
    # x * sigmoid(x), the same function as torch.nn.functional.silu
    return x * torch.sigmoid(x)

class FeedForward(nn.Module):
    def __init__(self, d_model, device=None, dtype=torch.float32):
        super().__init__()
        self.d_model = d_model
        self.d_ff = int(d_model * 8 / 3)
        self.weight1 = Linear(in_features=d_model, out_features=self.d_ff, device=device, dtype=dtype)
        self.weight2 = Linear(in_features=self.d_ff, out_features=d_model, device=device, dtype=dtype)
        self.weight3 = Linear(in_features=d_model, out_features=self.d_ff, device=device, dtype=dtype)

    def forward(self, x):
        # SwiGLU: W2(SiLU(W1 x) * W3 x)
        first = SiLU(self.weight1(x))
        second = self.weight3(x)
        intermediate = first * second
        return self.weight2(intermediate)

Transformer Block
Putting it all together, we get a complete transformer block. Unlike the original transformer, which applies layer normalization after each attention and FFN sublayer (post-norm), we apply RMSNorm before each sublayer (pre-norm): $y = x + \mathrm{MHA}(\mathrm{RMSNorm}(x))$ followed by $y + \mathrm{FFN}(\mathrm{RMSNorm}(y))$. One possible reason this works well is that the residual stream stays uninterrupted, which gives gradients a direct path back to the input during backpropagation.
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, theta=None, max_seq_len=None, use_rope=False):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.theta = theta
        self.max_seq_len = max_seq_len
        self.use_rope = use_rope
        self.MHA = MultiHeadAttention(self.d_model, self.num_heads, self.theta, self.max_seq_len, self.use_rope)
        self.FFN = FeedForward(self.d_model)
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x):
        # Pre-norm: normalize the sublayer input, then add its output back to the residual stream
        first_norm = self.norm1(x)
        MHA = self.MHA(first_norm)
        x = x + MHA
        second_norm = self.norm2(x)
        FFN = self.FFN(second_norm)
        return x + FFN

Transformer
We now combine all of the components we have built into the complete transformer model. We start with the embedding module, followed by num_layers transformer blocks. After the blocks comes a final normalization layer and then a linear layer that projects our processed inputs from batch x seq_len x d to batch x seq_len x vocab_size, so that each position in the input sequence has logits for the model's prediction of the next token. During generation, we apply a softmax to these logits to get a probability distribution over the next token. This structure is specific to decoder-only models; when transformers are used for other tasks like sentiment analysis, the final layer can look different.
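Generation itself is not shown in this section, so here is a minimal sketch of an autoregressive decoding loop for any model that maps (batch, seq_len) token IDs to (batch, seq_len, vocab_size) logits, as the Transformer below does. The generate function, its temperature, eos_id, and context_len parameters, and the toy CountingModel are all hypothetical, used only to make the loop testable.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens, temperature=1.0, eos_id=None, context_len=None):
    """Autoregressively append tokens to prompt_ids of shape (batch, seq_len)."""
    ids = prompt_ids
    for _ in range(max_new_tokens):
        window = ids if context_len is None else ids[:, -context_len:]  # respect max_seq_len
        logits = model(window)[:, -1, :]  # the last position predicts the next token
        if temperature == 0:
            next_id = logits.argmax(dim=-1, keepdim=True)  # greedy decoding
        else:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # sample one token per row
        ids = torch.cat([ids, next_id], dim=1)
        if eos_id is not None and bool((next_id == eos_id).all()):
            break
    return ids

class CountingModel(torch.nn.Module):
    """Toy stand-in for the Transformer: strongly predicts (previous token + 1) mod vocab."""

    def __init__(self, vocab_size=10):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, ids):
        return F.one_hot((ids + 1) % self.vocab_size, self.vocab_size).float() * 100.0

prompt = torch.tensor([[3]])
print(generate(CountingModel(), prompt, max_new_tokens=4, temperature=0))  # tensor([[3, 4, 5, 6, 7]])
```

Lower temperatures sharpen the distribution toward the argmax, and temperature 0 is treated here as plain greedy decoding.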
class Transformer(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, vocab_size, theta=None, max_seq_len=None, use_rope=False):
        super().__init__()
        self.num_layers = num_layers
        self.d_model = d_model
        self.num_heads = num_heads
        self.theta = theta
        self.max_seq_len = max_seq_len
        self.use_rope = use_rope
        self.vocab_size = vocab_size
        self.embedding = Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, theta, max_seq_len, use_rope) for _ in range(num_layers)]
        )
        self.norm = RMSNorm(d_model)
        self.linear = Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, token_list):
        # (batch, seq_len) token IDs -> (batch, seq_len, vocab_size) next-token logits
        embeddings = self.embedding(token_list)
        for block in self.blocks:
            embeddings = block(embeddings)
        normalized = self.norm(embeddings)
        logits = self.linear(normalized)
        return logits

Perfect! Now we have our architecture ready. The next section covers how we train the model: the MLE objective and cross-entropy loss.
Training
A probabilistic model like our language model defines a probability distribution $p_\theta$, parameterized by the parameters $\theta$ across all its layers. During training, it observes data sampled from a true distribution $p_{data}$ and minimizes the difference (KL divergence) between the modeled distribution and the true one. In practice we use the Maximum Likelihood Estimation (MLE) objective, which means maximizing the likelihood of the training data under our modeled distribution. The two are equivalent: $\mathrm{KL}(p_{data} \,\|\, p_\theta)$ differs from the negative expected log-likelihood $-\mathbb{E}_{p_{data}}[\log p_\theta]$ only by the entropy of $p_{data}$, which does not depend on $\theta$. This is exactly what we will do with our decoder-only transformer.
Training Objective
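Since we are predicting each token conditional on the tokens observed before it, we can write the MLE objective for a sequence $x_1, \dots, x_T$ as:
$$\theta^* = \arg\max_\theta \prod_{t=1}^{T} p_\theta(x_t \mid x_{<t})$$
Instead of maximizing this product, we minimize its negative log for numerical stability, which also turns the product into a sum. Averaging over tokens gives us the training objective:
$$\ell(\theta) = -\frac{1}{T}\sum_{t=1}^{T} \log p_\theta(x_t \mid x_{<t})$$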
Recall from the Architecture section that the transformer gives us, for each input position, a vector $z$ of length vocab_size containing logits for the model's next-token prediction. A softmax turns these into a probability distribution that we can plug into the objective above:
$$p_\theta(x_{t+1} = a \mid x_{\le t}) = \frac{\exp(z_a)}{\sum_{b} \exp(z_b)}$$
For numerical stability, we subtract the maximum logit from all logits before exponentiating. This leaves the softmax unchanged, because the common factor $\exp(-\max_b z_b)$ cancels between numerator and denominator.
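Here is why the max-subtraction matters in float32: exponentiating logits around 1000 overflows to infinity, while the shifted logits give the same softmax without overflow.

```python
import torch

z = torch.tensor([1000.0, 1001.0, 1002.0])

naive = torch.exp(z) / torch.exp(z).sum()  # exp(1000) overflows float32 to inf, and inf / inf is nan
shifted = z - z.max()                      # [-2, -1, 0]: same softmax, no overflow
stable = torch.exp(shifted) / torch.exp(shifted).sum()

print(naive)   # all nan
print(stable)  # approximately [0.0900, 0.2447, 0.6652]
```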
When implementing this in code, we take as inputs the logits with shape batch x seq_len x vocab_size and the targets with shape batch x seq_len, which are just the input tokens shifted by one position. This implementation is equivalent to PyTorch's torch.nn.CrossEntropyLoss applied to logits and targets flattened to (batch · seq_len) x vocab_size and (batch · seq_len).
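The shifted targets are easiest to see when sampling training windows from a tokenized corpus. This is a minimal sketch; get_batch is a hypothetical helper, and torch.arange stands in for real token IDs so the shift is visible. The last lines show the flattening that torch.nn.functional.cross_entropy expects.

```python
import torch
import torch.nn.functional as F

def get_batch(tokens: torch.Tensor, batch_size: int, context_length: int):
    """Random windows of a 1-D token tensor; targets are the same windows shifted by one."""
    starts = torch.randint(0, len(tokens) - context_length, (batch_size,))
    idx = starts[:, None] + torch.arange(context_length)  # (batch_size, context_length)
    return tokens[idx], tokens[idx + 1]

torch.manual_seed(0)
tokens = torch.arange(100)  # stand-in for a tokenized corpus
x, y = get_batch(tokens, batch_size=4, context_length=8)
print(x.shape, y.shape)                  # torch.Size([4, 8]) torch.Size([4, 8])
print(torch.equal(x[:, 1:], y[:, :-1]))  # True: y is x shifted by one

vocab_size = 100
logits = torch.randn(4, 8, vocab_size)   # what the model would output for x
loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
```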
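def cross_entropy_loss(logits, targets):
    # logits: (batch, seq_len, vocab_size); targets: (batch, seq_len) next-token IDs
    max_logit = logits.max(dim=-1, keepdim=True).values
    shifted = logits - max_logit  # largest logit becomes 0, so exp() cannot overflow
    log_sum_exp = torch.log(torch.exp(shifted).sum(dim=-1, keepdim=True))
    log_softmax = shifted - log_sum_exp
    # Pick out the log-probability assigned to each target token
    target_log_probs = log_softmax.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return -target_log_probs.mean()

AdamW Optimizer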
Once we have computed the loss, a simple loss.backward() backpropagates gradients to all parameters of the model. We then take an optimizer step that updates the parameters to reduce the loss. The simplest form of this is Stochastic Gradient Descent, which updates every parameter with a fixed learning rate: $\theta \leftarrow \theta - \alpha \nabla_\theta \ell$. Among the many techniques developed since, AdamW has become the most widely used optimizer for training transformers. It uses two crucial ideas that let us navigate the loss surface faster and more efficiently.
- Momentum: Instead of updating with only the most recent gradient, we keep an exponential moving average of past gradients, the first moment $m_t$, and step in its direction. If our recent gradients have pointed in the same direction, we take faster steps in that direction (accelerate), and if they have been noisy and pointed in opposite directions, our steps are dampened.
- Adaptive Learning Rate: Instead of updating every parameter with the same learning rate, we borrow the RMSProp idea of tracking an exponential moving average of the squared gradients, the second moment $v_t$. We divide each update by the square root of this term, so parameters with historically large gradients take smaller steps and those with small gradients take larger ones.
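$\beta_1$ and $\beta_2$ are hyperparameters, usually set to 0.9 and 0.999. Since the moments are initialized to zero, they are biased toward zero early in training, so we also apply a bias correction so that the first updates are not too small. This gives us the following expressions for the moments, where $g_t$ is the gradient at step $t$:
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
The last important component is the W in AdamW, which stands for weight decay. The classic approach is L2 regularization, which adds a penalty $\frac{\lambda}{2}\lVert\theta\rVert^2$ to the loss. In Adam, the penalty's gradient $\lambda\theta$ is folded into $g_t$ and then divided by $\sqrt{\hat{v}_t}$, so parameters with large historical gradients are decayed less. AdamW solves this by decoupling the decay from the gradient-based update, so every parameter decays at the same rate. This gives us the final expression for the weight update, where $\alpha$ is the learning rate and $\epsilon$ is a small constant for numerical stability:
$$\alpha_t = \alpha \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}, \qquad \theta_t = \theta_{t-1} - \alpha_t \frac{m_t}{\sqrt{v_t} + \epsilon} - \alpha\lambda\theta_{t-1}$$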
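One way to sanity-check the final expression is to run it on a plain tensor and compare against PyTorch's torch.optim.AdamW on the same problem. This sketch minimizes sum(θ²) for 10 steps; adamw_reference is a hypothetical helper. PyTorch adds ε to $\sqrt{\hat{v}_t}$ rather than to $\sqrt{v_t}$, so the two differ very slightly, which is negligible here because the gradients are far larger than ε.

```python
import math
import torch

def adamw_reference(theta, grad_fn, steps, lr=1e-2, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    """Apply the AdamW update above `steps` times to a plain tensor."""
    beta_1, beta_2 = betas
    m = torch.zeros_like(theta)
    v = torch.zeros_like(theta)
    for t in range(1, steps + 1):
        g = grad_fn(theta)
        m = beta_1 * m + (1 - beta_1) * g
        v = beta_2 * v + (1 - beta_2) * g**2
        alpha_t = lr * math.sqrt(1 - beta_2**t) / (1 - beta_1**t)
        # Adam step and decoupled decay both use theta_{t-1}
        theta = theta - alpha_t * m / (torch.sqrt(v) + eps) - lr * weight_decay * theta
    return theta

def grad_of_sum_of_squares(theta):
    return 2 * theta  # gradient of sum(theta ** 2)

start = torch.tensor([1.5, -0.5, 2.0, -1.0, 0.25], dtype=torch.float64)
ours = adamw_reference(start.clone(), grad_of_sum_of_squares, steps=10)

p = torch.nn.Parameter(start.clone())
opt = torch.optim.AdamW([p], lr=1e-2, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2)
for _ in range(10):
    opt.zero_grad()
    (p**2).sum().backward()
    opt.step()

print(torch.allclose(ours, p.detach(), atol=1e-6))  # True
```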
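We implement AdamW by subclassing torch.optim.Optimizer. Its param_groups attribute is a list of dictionaries, one per parameter group, and in our case there is a single group containing all parameters. The hyperparameters are passed in as a defaults dictionary and copied into each group as key-value pairs. Every parameter also has its own state dictionary, where we store the two moments and a counter of how many updates that parameter has received. We then update p.data (not p.grad) in place using the expression above, under torch.no_grad() so that the update does not become part of the computation graph. The code folds both bias corrections into a single step size $\alpha_t = \alpha \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$. This is equivalent to the expression above up to a rescaling of $\epsilon$.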
import math

import torch


class AdamW(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            # The closure re-runs forward/backward, so it needs autograd enabled.
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group["lr"]
            beta_1, beta_2 = group["betas"]
            epsilon = group["eps"]
            weight_decay = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    # Lazily create per-parameter state on the first update.
                    state["t"] = 1
                    state["m"] = torch.zeros_like(p)
                    state["v"] = torch.zeros_like(p)
                gradient = p.grad
                t = state["t"]
                # Exponential moving averages of the gradient and the squared gradient.
                state["m"].mul_(beta_1).add_(gradient, alpha=1 - beta_1)
                state["v"].mul_(beta_2).addcmul_(gradient, gradient, value=1 - beta_2)
                # Both bias corrections folded into one step size.
                adjusted_alpha = lr * math.sqrt(1 - beta_2**t) / (1 - beta_1**t)
                # Decoupled weight decay, applied to the pre-update parameters.
                p.mul_(1 - lr * weight_decay)
                # Adam step: theta -= alpha_t * m / (sqrt(v) + eps)
                p.addcdiv_(state["m"], state["v"].sqrt().add_(epsilon), value=-adjusted_alpha)
                state["t"] += 1
        return loss
LR Schedule and Gradient Clipping
I have also implemented a learning rate schedule, which sets the LR based on how many steps we are into training. It starts with a warmup that increases the LR linearly from 0 to $\alpha_{max}$, reaching the peak at step $T_w$. Warmup helps because early in training our gradient estimates are noisy and the optimizer’s moments have not accumulated enough history. A smaller learning rate in this phase keeps us from taking huge steps in potentially wrong directions on the loss surface. After the peak, we slowly decay the learning rate toward $\alpha_{min}$ along a cosine curve that changes slowly at both ends. The LR reaches $\alpha_{min}$ at step $T_c$ and stays there afterwards, which keeps steps small as the model nears convergence.
import math


def learning_rate_schedule(t, alpha_max, alpha_min, T_w, T_c):
    # Linear warmup from 0 to alpha_max over the first T_w steps.
    if t < T_w:
        return t * alpha_max / T_w
    # After the cosine cycle ends, hold at alpha_min.
    if t > T_c:
        return alpha_min
    # Cosine decay from alpha_max (t = T_w) down to alpha_min (t = T_c).
    cos_term = math.cos((t - T_w) / (T_c - T_w) * math.pi)
    return alpha_min + 0.5 * (1 + cos_term) * (alpha_max - alpha_min)
I also implemented gradient clipping, which guards against exploding gradients in deep networks. We set a max_norm upper bound. Whenever the combined L2 norm of all gradients exceeds it, we scale every gradient down uniformly by max_norm / (norm + eps), where eps is for numerical stability. Instead of concatenating all gradient values into one flat vector, we compute the norm of each gradient tensor, stack those norms into a new tensor, and take its L2 norm as the global norm. The result is the same, since the squared global norm is just the sum of the squared per-tensor norms.
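The stacking trick and the rescaling rule can both be checked directly. This sketch uses a small throwaway nn.Linear model, chosen only for illustration. It compares the stacked per-tensor norm with the norm of one flattened vector. It also uses PyTorch’s built-in torch.nn.utils.clip_grad_norm_ as a reference, checking that uniform rescaling lands the global norm on max_norm.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 4)
model(torch.randn(8, 16)).pow(2).sum().mul(100).backward()  # deliberately large gradients
grads = [p.grad for p in model.parameters()]

# Global norm two ways: stacked per-tensor norms vs. one flattened vector.
stacked_norm = torch.norm(torch.stack([torch.norm(g, 2) for g in grads]), 2)
flat_norm = torch.cat([g.flatten() for g in grads]).norm(2)
print(f"stacked: {stacked_norm.item():.4f}, flat: {flat_norm.item():.4f}")

# Reference clipping; it rescales the .grad tensors in place and returns the pre-clip norm.
max_norm, eps = 1.0, 1e-6
reference = [g.clone() for g in grads]
returned_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
clipped_norm = torch.norm(torch.stack([torch.norm(g, 2) for g in grads]), 2)
# The same uniform rescaling done by hand.
manual = [g * (max_norm / (stacked_norm + eps)) for g in reference]
print(f"norm before: {returned_norm.item():.4f}, after: {clipped_norm.item():.4f}")
```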
import torch


def gradient_clipping(parameters, max_norm, eps=1e-6):
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    # Global L2 norm = L2 norm of the per-tensor L2 norms.
    norm = torch.norm(torch.stack([torch.norm(g, 2) for g in grads]), 2)
    if norm > max_norm:
        scaling_factor = max_norm / (norm + eps)
        for g in grads:
            # Scale in place so the optimizer sees the clipped gradients.
            g.detach().mul_(scaling_factor)
    return norm.item()
Data Loading and Training
We start with our dataset. Most open-source datasets store each sample as a dictionary with a text field. We append the <|endoftext|> token to each sample, tokenize it, concatenate all the resulting token IDs into one continuous sequence, and save that sequence as a binary file. During training we lazily load this file with np.memmap, which does not read everything into RAM. Instead, it reads only the token IDs we index in the training loop, at the corresponding offsets in the file. We then randomly pick starting positions in this sequence and take context_length tokens from each one as inputs. The same window shifted by one position gives the targets. These form the X_batch, y_batch pair for the training loop, which we move to the same device as the model weights.
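The preprocessing half of this pipeline might look like the following sketch. It makes several assumptions. Each sample is a dict with a "text" key. A hypothetical ID of 256 stands for <|endoftext|>. A toy byte-level `ByteTokenizer` stands in for the trained tokenizer so the example runs on its own. IDs are stored as uint16, which is enough for vocabularies of up to 65,536 tokens.

```python
import os
import tempfile

import numpy as np

EOT_ID = 256  # hypothetical ID for <|endoftext|> in this toy setup


class ByteTokenizer:
    """Toy stand-in tokenizer: one token per UTF-8 byte (IDs 0-255)."""

    def encode(self, text):
        return list(text.encode("utf-8"))


def write_token_file(samples, tokenizer, path, eot_id=EOT_ID):
    # Tokenize each sample's "text" field, append <|endoftext|>, and concatenate everything.
    ids = []
    for sample in samples:
        ids.extend(tokenizer.encode(sample["text"]))
        ids.append(eot_id)
    np.array(ids, dtype=np.uint16).tofile(path)
    return len(ids)


samples = [{"text": "hello"}, {"text": "world!"}]
path = os.path.join(tempfile.mkdtemp(), "train.bin")
n_tokens = write_token_file(samples, ByteTokenizer(), path)

# Lazily map the file: nothing is read into RAM until we index into it.
tokens = np.memmap(path, dtype=np.uint16, mode="r")
print(n_tokens, tokens[:6])
```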
import numpy as np
import torch


def data_loading(x, batch_size, context_length, device):
    # x is a 1-D array of token IDs (e.g. an np.memmap); pick random window starts.
    starts = torch.randint(0, len(x) - context_length, (batch_size,)).tolist()
    sequences = torch.stack([torch.from_numpy(x[s:s + context_length].astype(np.int64)) for s in starts])
    # Targets are the inputs shifted one token to the right.
    targets = torch.stack([torch.from_numpy(x[s + 1:s + 1 + context_length].astype(np.int64)) for s in starts])
    return sequences.to(device), targets.to(device)
Training is then a loop. At each step we set the learning rate for the current step, load a batch of data, and run a forward pass. We compute the cross-entropy loss between the predictions and the targets, backpropagate with loss.backward(), and clip the gradients. Finally, we run the optimizer step to update the weights and clear the old gradients before the next iteration.
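Putting these steps together, a minimal training loop might look like the sketch below. To stay self-contained, it swaps in PyTorch’s built-in torch.optim.AdamW and torch.nn.utils.clip_grad_norm_ for the hand-written versions above, and it inlines the warmup-plus-cosine schedule. It trains a hypothetical TinyLM (an embedding followed by a linear head, not the transformer from this post) on a repeating toy sequence.

```python
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
np.random.seed(0)


class TinyLM(nn.Module):
    """Hypothetical stand-in model: embedding -> linear head (a bigram model)."""

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.head(self.embed(x))  # (batch, seq, vocab) logits


def lr_at(t, alpha_max, alpha_min, T_w, T_c):
    # Same shape as the schedule above: linear warmup, cosine decay, then a floor.
    if t < T_w:
        return t * alpha_max / T_w
    if t > T_c:
        return alpha_min
    return alpha_min + 0.5 * (1 + math.cos((t - T_w) / (T_c - T_w) * math.pi)) * (alpha_max - alpha_min)


vocab_size, context_length, batch_size, steps = 8, 16, 32, 200
data = np.tile(np.arange(vocab_size, dtype=np.uint16), 500)  # 0,1,...,7,0,1,... is fully predictable
model = TinyLM(vocab_size, d_model=32)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0, weight_decay=0.01)

losses = []
for t in range(steps):
    for group in optimizer.param_groups:  # inject this step's learning rate
        group["lr"] = lr_at(t, alpha_max=1e-2, alpha_min=1e-3, T_w=20, T_c=steps)
    starts = np.random.randint(0, len(data) - context_length, size=batch_size)
    x = torch.from_numpy(np.stack([data[s:s + context_length] for s in starts]).astype(np.int64))
    y = torch.from_numpy(np.stack([data[s + 1:s + 1 + context_length] for s in starts]).astype(np.int64))
    logits = model(x)
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```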
Generation
Once we have a trained model, we run inference on it to generate tokens. Generation is autoregressive: we process the prompt to predict the next token, append that token to the sequence, and repeat until we hit a max_tokens limit or produce the <|endoftext|> token. Along the way we can tweak a few knobs that change how the next token is sampled.
-
Temperature controls the sharpness of the probability distribution before we sample a token. Each logit is divided by a temperature $T$ before the softmax. When $T < 1$, this amplifies differences between logits and sharpens the distribution (unlikely tokens get less probability mass). When $T > 1$, it compresses the differences and flattens the distribution (unlikely tokens get more probability mass).
-
We also use nucleus (top-p) sampling. We sort tokens by probability in descending order and keep the smallest set whose cumulative probability reaches a threshold $p$. Every other token’s probability is set to zero, and the rest are renormalized before sampling.
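Both knobs are easy to see on a toy distribution. This sketch assumes a hypothetical four-token vocabulary. It applies the same filtering rule as the generate function below, where a token is dropped once the probability mass ranked above it already exceeds top_p.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # hypothetical next-token logits for a 4-token vocab


def top_p_filter(probs, top_p):
    """Zero out tokens outside the nucleus and renormalize (same rule as generate())."""
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[(cumsum - sorted_probs) > top_p] = 0.0
    # Scatter the filtered probabilities back to their original token positions.
    filtered = torch.zeros_like(probs).scatter(0, sorted_indices, sorted_probs)
    return filtered / filtered.sum()


for temperature in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(f"T={temperature}: {[round(p, 3) for p in probs.tolist()]}")

probs = torch.softmax(logits, dim=-1)
print("top_p=0.8:", [round(p, 3) for p in top_p_filter(probs, 0.8).tolist()])
```

At $T = 1$ the top two tokens hold about 83% of the mass, so with top_p = 0.8 only they survive and are renormalized.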
import torch


@torch.no_grad()
def generate(model, tokenizer, prompt, max_tokens=200, temperature=1.0, top_p=0.9, device=None):
    if device is None:
        device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    token_ids = tokenizer.encode(prompt)
    for _ in range(max_tokens):
        input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
        logits = model(input_ids)
        # Only the logits at the last position predict the next token.
        next_logits = logits[0, -1, :]
        if temperature > 0:
            scaled = next_logits / temperature
            probs = torch.softmax(scaled, dim=-1)
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumsum = torch.cumsum(sorted_probs, dim=-1)
            # Drop tokens whose preceding cumulative mass already exceeds top_p.
            mask = (cumsum - sorted_probs) > top_p
            sorted_probs[mask] = 0.0
            sorted_probs /= sorted_probs.sum()
            idx = torch.multinomial(sorted_probs, num_samples=1)
            next_id = sorted_indices[idx].item()
        else:
            # Temperature 0 means greedy decoding.
            next_id = next_logits.argmax().item()
        token_ids.append(next_id)
        if tokenizer.decode([next_id]) == "<|endoftext|>":
            break
    # Restore whichever mode the model was in before generation.
    model.train(was_training)
    return tokenizer.decode(token_ids)
Attention Variants
Since the original transformer paper, open-source models have introduced many variations on the core architecture to make training and inference cheaper. A large share of these gains comes from changes to the attention mechanism itself. Let’s look at some of them.
KV Cache
You’ll notice that at each generation step we recompute the key and value vectors for the entire sequence. This is wasteful: the weights are frozen and attention is causal, so the K and V vectors of past tokens do not change from step to step. KV caching stores these vectors in GPU memory and reuses them, which gives our forward pass two modes. In prefill mode, we run a forward pass over the entire user prompt at once and write the keys and values we compute into the cache. In decode mode, the input is just the newest token, so the QKV projections run on a batch x 1 x d input. We then append the new key and value to the cached ones, run attention, and keep the updated cache for the next step.
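import torch


# Hypothetical helper wrapping the cache update from the attention forward pass.
def append_to_kv_cache(K, V, past_kv=None):
    # K, V: (batch, num_heads, new_tokens, head_dim) for the tokens in this forward pass.
    if past_kv is not None:
        k_cache, v_cache = past_kv
        # Concatenate along the sequence dimension.
        K = torch.cat([k_cache, K], dim=2)
        V = torch.cat([v_cache, V], dim=2)
    new_kv_cache = (K, V)
    return K, V, new_kv_cache
During decode, most projections in the transformer block that would otherwise transform all seq_len tokens now transform only one token per sequence in the batch, so the QKV projections need far fewer FLOPs. The attention operation per step also goes from $O(n^2)$ to $O(n)$ for a sequence of length $n$, since we only multiply the current query token with the cached K matrix. The FFN layer sees similar gains.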
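A quick way to convince yourself that the cache is exact is to compare both paths. This sketch assumes a single attention head with hypothetical random projection matrices W_q, W_k, and W_v. It decodes a sequence one token at a time with a growing K/V cache and checks that every step’s output matches full causal attention recomputed over the whole prefix.

```python
import math

import torch

torch.manual_seed(0)
d, seq_len = 16, 10
W_q, W_k, W_v = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))  # hypothetical single-head weights
x = torch.randn(seq_len, d)                                           # embeddings of the full sequence


def full_causal_attention(x):
    # Recompute Q, K, V for every token and apply a causal mask: O(n^2) work per call.
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    scores = Q @ K.T / math.sqrt(d)
    mask = torch.ones(len(x), len(x), dtype=torch.bool).tril()
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ V


k_cache = torch.empty(0, d)
v_cache = torch.empty(0, d)
cached_outputs = []
for t in range(seq_len):
    x_t = x[t:t + 1]                                   # decode mode: only the newest token
    q_t = x_t @ W_q
    k_cache = torch.cat([k_cache, x_t @ W_k], dim=0)   # append this token's key and value
    v_cache = torch.cat([v_cache, x_t @ W_v], dim=0)
    # One query against all cached keys: O(n) work, and no mask is needed.
    weights = torch.softmax(q_t @ k_cache.T / math.sqrt(d), dim=-1)
    cached_outputs.append(weights @ v_cache)

cached = torch.cat(cached_outputs, dim=0)
print(torch.allclose(cached, full_causal_attention(x), atol=1e-5))
```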
MQA
A large KV cache imposes memory costs during decode. Every forward pass has to move the whole cache from HBM (the GPU’s high-bandwidth main memory) into the on-chip memory where the computation happens. Since the cache grows linearly with sequence length, this memory traffic adds latency during inference. Noam Shazeer introduced Multi-Query Attention (MQA) as a variation of standard multi-head attention (MHA) to address this problem. MQA shares a single K and V projection across all attention heads while keeping a separate Q projection for each head. The cache then holds only one head’s worth of keys and values, so each forward pass moves far less data. MQA does make inference faster, but it also degrades quality somewhat, because it takes representational capacity away from the model.
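The memory argument is simple arithmetic. For each token, each layer caches one key vector and one value vector per KV head. Per token, the cache therefore needs $2 \times \text{layers} \times \text{KV heads} \times \text{head dim} \times \text{bytes per value}$. The sketch below plugs in a hypothetical configuration that is illustrative only and not taken from any specific model: 32 layers, 32 heads of dimension 128, a 16-bit cache, and an 8,192-token sequence.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_value=2, batch_size=1):
    # Factor 2: one key vector and one value vector per KV head, per layer, per token.
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_value * seq_len * batch_size


config = dict(seq_len=8192, n_layers=32, head_dim=128)  # hypothetical model, 16-bit cache
mha = kv_cache_bytes(n_kv_heads=32, **config)          # one K/V head per query head
mqa = kv_cache_bytes(n_kv_heads=1, **config)           # one K/V head shared by all 32 query heads
print(f"MHA: {mha / 2**30:.2f} GiB, MQA: {mqa / 2**30:.3f} GiB, ratio {mha // mqa}x")
```

With these numbers, MHA needs 4 GiB of cache for a single sequence and MQA needs 128 MiB, a 32x reduction equal to the number of heads.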
GQA
Grouped-Query Attention (GQA) was introduced to navigate this tradeoff between speed and quality. The authors set the number of KV heads to some number between 1 and num_heads. This gives a spectrum with MQA at one end and standard MHA at the other: query heads are split into groups, and each group shares one KV head. The authors also showed that you don’t have to train such a model from scratch. You can take an existing MHA checkpoint, split its KV heads into the desired number of groups, and mean-pool the heads within each group. A small amount of additional training, on the order of 5% of the original pretraining compute, then recovers comparable performance. Leading open-source models adopted this architecture, including the larger Llama 2 models and Llama 3.
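Here is a minimal sketch of both ideas, using hypothetical helper names and PyTorch tensors shaped (batch, heads, seq, head_dim). `gqa_attention` repeats each KV head across its group of query heads, assuming the group size divides evenly. `mean_pool_kv_heads` converts an MHA key or value projection weight into grouped heads by averaging, following the GQA uptraining recipe described above. It assumes the weight’s rows are ordered head by head, as in the "(n d)" layout used elsewhere in this post.

```python
import torch
import torch.nn.functional as F


def gqa_attention(Q, K, V):
    """Q: (b, n_heads, s, d); K, V: (b, n_kv_heads, s, d). Causal attention with shared KV heads."""
    n_heads, n_kv_heads = Q.shape[1], K.shape[1]
    group_size = n_heads // n_kv_heads
    # Each KV head serves `group_size` consecutive query heads.
    K = K.repeat_interleave(group_size, dim=1)
    V = V.repeat_interleave(group_size, dim=1)
    return F.scaled_dot_product_attention(Q, K, V, is_causal=True)


def mean_pool_kv_heads(W, n_heads, n_kv_heads):
    """W: (n_heads * head_dim, d_model) K or V weight -> (n_kv_heads * head_dim, d_model)."""
    head_dim = W.shape[0] // n_heads
    groups = W.reshape(n_kv_heads, n_heads // n_kv_heads, head_dim, -1)
    return groups.mean(dim=1).reshape(n_kv_heads * head_dim, -1)


torch.manual_seed(0)
b, s, d_model, n_heads, head_dim = 2, 5, 64, 8, 8
W_k = torch.randn(n_heads * head_dim, d_model)
W_k_gqa = mean_pool_kv_heads(W_k, n_heads, n_kv_heads=2)  # 8 KV heads -> 2 groups of 4
Q = torch.randn(b, n_heads, s, head_dim)
K = torch.randn(b, 2, s, head_dim)
V = torch.randn(b, 2, s, head_dim)
out = gqa_attention(Q, K, V)
print(W_k_gqa.shape, out.shape)
```

Setting n_kv_heads equal to n_heads recovers plain MHA, and setting it to 1 gives MQA.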
MLA
Multi-Head Latent Attention (MLA) was introduced in the DeepSeek-V2 paper. Instead of reducing the number of KV heads, it compresses the keys and values into a lower-dimensional latent vector. Only this compressed latent is cached, which cuts memory requirements, and it is up-projected back into keys and values during the forward pass. Let’s see how this works.
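Our MLA block takes as input the hidden state $h_t$ of token $t$. A down-projection gives the latent $c_t^{KV} = W^{DKV} h_t$, which is appended to the cached latents of earlier tokens. The full sequence of latents is then up-projected with $W^{UK}$ and $W^{UV}$ to give the keys $k_t^C = W^{UK} c_t^{KV}$ and values $v_t^C = W^{UV} c_t^{KV}$. We can follow the same low-rank process for the queries, $c_t^Q = W^{DQ} h_t$ and $q_t^C = W^{UQ} c_t^Q$, to save parameters even though queries are not cached. Just like MHA, we then split these into heads, run attention, and apply a final output projection $W^O$. One concern here is the compute cost of up-projecting the cached latents at every step, but we can apply an absorption trick to the attention scores. For head $i$:
$$(q_{t,i}^C)^\top k_{j,i}^C = (W_i^{UQ} c_t^Q)^\top (W_i^{UK} c_j^{KV}) = (c_t^Q)^\top \left[ (W_i^{UQ})^\top W_i^{UK} \right] c_j^{KV}$$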
This lets us compute attention scores directly against the cached latents. More importantly, we can precompute the bracketed term, since it is the same for every token position in a given layer. One catch is that RoPE breaks the absorption trick. The rotation applied to queries and keys depends on token position, so it would sit between $W_i^{UQ}$ and $W_i^{UK}$, and the bracketed term would no longer be constant. The DeepSeek authors solve this with decoupled RoPE: each query and key is split into a content piece without RoPE and a smaller positional piece with RoPE. The positional key is shared across heads and cached alongside the latent. The two pieces go through the attention step independently, and their scores are added before the softmax and the weighted sum over the values.
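The absorption identity is easy to verify numerically. This sketch uses hypothetical random float64 matrices for a single head in place of $W^{DQ}$, $W_i^{UQ}$, $W^{DKV}$, and $W_i^{UK}$. It assumes bias-free projections and applies each projection as `W @ c`, following the math convention. It computes the scores two ways: by up-projecting the query and every cached latent, and by using the precomputed absorbed matrix on the latents directly.

```python
import math

import torch

torch.manual_seed(0)


def rand(*shape):
    # Float64 random matrices scaled by 1/sqrt(fan_in) to keep values O(1).
    return torch.randn(*shape, dtype=torch.float64) / math.sqrt(shape[-1])


d_model, q_latent, kv_latent, head_dim, seq_len = 64, 24, 16, 8, 10
W_DQ = rand(q_latent, d_model)
W_UQ_i = rand(head_dim, q_latent)    # head i's slice of W^UQ
W_DKV = rand(kv_latent, d_model)
W_UK_i = rand(head_dim, kv_latent)   # head i's slice of W^UK

h_t = rand(d_model)                  # current token's hidden state
h = rand(seq_len, d_model)           # hidden states of all cached tokens
c_q = W_DQ @ h_t                     # query latent c_t^Q
c_kv = h @ W_DKV.T                   # cached latents c_j^KV, one row per token

# Naive: up-project the query and every cached latent, then take dot products.
scores_naive = (c_kv @ W_UK_i.T) @ (W_UQ_i @ c_q)

# Absorbed: precompute (W_UQ_i)^T W_UK_i once, then work directly on the latents.
W_absorbed = W_UQ_i.T @ W_UK_i       # (q_latent, kv_latent), position-independent
scores_absorbed = c_kv @ (W_absorbed.T @ c_q)

print(torch.allclose(scores_naive, scores_absorbed))
```

The identity relies on the projections being purely linear with no biases. A position-dependent RoPE rotation inserted between $W_i^{UQ}$ and $W_i^{UK}$ would stop `W_absorbed` from being a single constant matrix.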
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class MultiHeadLatentAttention(nn.Module):
    def __init__(self, d_model, num_heads, q_latent_dim, kv_latent_dim, rope, rope_dim):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.rope = rope  # RoPE module from earlier; applied only to the decoupled positional pieces
        self.rope_dim = rope_dim
        # Bias-free projections so that the absorption trick holds exactly.
        self.W_DQ = nn.Linear(d_model, q_latent_dim, bias=False)
        self.W_UQ = nn.Linear(q_latent_dim, d_model, bias=False)
        self.W_DKV = nn.Linear(d_model, kv_latent_dim, bias=False)
        self.W_UK = nn.Linear(kv_latent_dim, d_model, bias=False)
        self.W_UV = nn.Linear(kv_latent_dim, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        self.W_QR = nn.Linear(q_latent_dim, num_heads * rope_dim, bias=False)
        self.W_KR = nn.Linear(d_model, rope_dim, bias=False)  # one positional key shared by all heads

    def forward(self, x, token_positions, cache=None):
        # x: (b, s, d_model); cache: (latents, rotated positional keys) of earlier tokens, or None.
        Q_down = self.W_DQ(x)
        Q_C = rearrange(self.W_UQ(Q_down), "b s (n d) -> b n s d", n=self.num_heads)
        Q_R = rearrange(self.W_QR(Q_down), "b s (n d) -> b n s d", n=self.num_heads)
        Q_R = self.rope(Q_R, token_positions)
        KV_latent = self.W_DKV(x)
        K_R = rearrange(self.W_KR(x), "b s d -> b 1 s d")
        K_R = self.rope(K_R, token_positions)
        if cache is not None:
            # Cache both the latent and the positional key so their lengths stay in sync.
            past_latent, past_K_R = cache
            KV_latent = torch.cat((past_latent, KV_latent), dim=1)
            K_R = torch.cat((past_K_R, K_R), dim=2)
        cache = (KV_latent, K_R)
        # Naive path: up-project the cached latents (absorption is not applied here).
        K_C = rearrange(self.W_UK(KV_latent), "b s (n d) -> b n s d", n=self.num_heads)
        V_C = rearrange(self.W_UV(KV_latent), "b s (n d) -> b n s d", n=self.num_heads)
        # Content and positional scores are added; DeepSeek-V2 scales by sqrt(d_k + rope_dim).
        scale = math.sqrt(self.d_k + self.rope_dim)
        scores = (Q_C @ K_C.transpose(-2, -1) + Q_R @ K_R.transpose(-2, -1)) / scale
        # Causal mask: the queries are the last s_q positions of the s_k-long sequence.
        s_q, s_k = x.shape[1], KV_latent.shape[1]
        causal = torch.ones(s_q, s_k, dtype=torch.bool, device=x.device).tril(diagonal=s_k - s_q)
        scores = scores.masked_fill(~causal, float("-inf"))
        attention_weights = F.softmax(scores, dim=-1)
        attention_output = rearrange(attention_weights @ V_C, "b n s d -> b s (n d)")
        return self.W_O(attention_output), cache
SWA
Sliding Window Attention (SWA) reduces the memory and compute cost of attention by restricting each token to a “window” of the $w$ most recent tokens, including itself. Irrespective of the sequence length $n$, the attention compute is bounded by $O(nw)$ rather than $O(n^2)$, and the KV cache per layer is bounded by $w$ tokens. The tradeoff is that a single layer cannot see the full context. As in a CNN, however, the receptive field grows with depth, since each additional layer lets information travel roughly $w$ more tokens back.
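A sliding-window causal mask is a small change to the usual causal mask. In this sketch, which uses a hypothetical helper name and window size, query $i$ may attend to key $j$ only when $i - w < j \le i$. The sketch also shows how the receptive field widens when two such layers are stacked.

```python
import torch


def sliding_window_mask(seq_len, window):
    """Boolean mask: True where query i may attend to key j, i.e. i - window < j <= i."""
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (j > i - window)


mask = sliding_window_mask(seq_len=8, window=3)
print(mask.int())

# Two stacked layers: token i reaches j if some path i -> k -> j exists (a boolean matrix product).
two_layers = (mask.float() @ mask.float()) > 0
print("last token sees", int(mask[-1].sum()), "tokens after 1 layer,", int(two_layers[-1].sum()), "after 2")
```

With $w = 3$, each layer extends the reach by $w - 1 = 2$ earlier tokens. The last token sees 3 tokens after one layer and 5 after two.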
DSA
DeepSeek Sparse Attention (DSA) also restricts which tokens each query attends to, but instead of a fixed local window it uses an indexer. The indexer projects queries and keys down to a small dimension and computes cheap relevance scores. From these scores we select the top-$k$ most relevant earlier tokens for the current token, and we mask out all other tokens in the main attention.
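The selection step can be sketched as a top-k mask. This is a simplified illustration, not DeepSeek’s implementation. It assumes a single hypothetical indexer that scores tokens with a plain dot product of small random projections (the real indexer uses its own scoring function). Each query is restricted to earlier tokens and keeps its k highest-scoring keys.

```python
import torch


def topk_sparse_mask(q_idx, k_idx, k):
    """q_idx, k_idx: (seq, d_index) low-dim indexer projections. Returns a bool (seq, seq) mask."""
    seq_len = q_idx.shape[0]
    scores = q_idx @ k_idx.T                                  # cheap relevance scores
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
    scores = scores.masked_fill(~causal, float("-inf"))       # only earlier tokens are candidates
    top = scores.topk(min(k, seq_len), dim=-1).indices
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool).scatter(1, top, True)
    return mask & causal                                      # drop -inf picks for early queries


torch.manual_seed(0)
seq_len, d_model, d_index, k = 12, 32, 4, 3
x = torch.randn(seq_len, d_model)
W_q_index = torch.randn(d_model, d_index)  # hypothetical indexer projections
W_k_index = torch.randn(d_model, d_index)
mask = topk_sparse_mask(x @ W_q_index, x @ W_k_index, k)
print(mask.sum(dim=1).tolist())            # number of keys each query keeps
```

Unlike a sliding window, the kept keys can sit anywhere in the prefix, because the choice depends on content rather than distance.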
There are many other attention variants that I’m not diving into here. Sebastian Raschka has written a ton of blog posts on this topic!