Training
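A probabilistic machine learning model defines a probability distribution $p_\theta$ parameterized by $\theta$, the parameters across all of its layers. During training, the model observes data sampled from a true distribution $p_{\text{data}}$ and minimizes the difference between the modeled distribution $p_\theta$ and $p_{\text{data}}$, measured by the KL divergence $D_{\mathrm{KL}}(p_{\text{data}} \,\|\, p_\theta)$. In practice we use the Maximum Likelihood Estimation (MLE) objective, which means maximizing the likelihood of the observed training data under the modeled distribution. Because $p_{\text{data}}$ is only available through samples, minimizing the KL divergence to the empirical distribution of the training set and maximizing the likelihood lead to the same optimum. This is exactly what we will do with our decoder-only transformer.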
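The equivalence between the KL objective and MLE takes only a couple of lines. The entropy of the data distribution does not depend on $\theta$, so it drops out of the optimization. In this sketch, $\hat p_{\text{data}}$ denotes the empirical distribution of $N$ training samples $x^{(1)}, \dots, x^{(N)}$. This notation is introduced here for the derivation and does not appear elsewhere in the text.

```latex
\begin{aligned}
D_{\mathrm{KL}}(\hat p_{\text{data}} \,\|\, p_\theta)
  &= \mathbb{E}_{x \sim \hat p_{\text{data}}}\big[\log \hat p_{\text{data}}(x)\big]
   - \mathbb{E}_{x \sim \hat p_{\text{data}}}\big[\log p_\theta(x)\big] \\
  &= \underbrace{-H(\hat p_{\text{data}})}_{\text{independent of } \theta}
   \;-\; \frac{1}{N}\sum_{i=1}^{N} \log p_\theta\big(x^{(i)}\big) \\
\arg\min_\theta D_{\mathrm{KL}}(\hat p_{\text{data}} \,\|\, p_\theta)
  &= \arg\max_\theta \sum_{i=1}^{N} \log p_\theta\big(x^{(i)}\big)
   = \arg\max_\theta \prod_{i=1}^{N} p_\theta\big(x^{(i)}\big)
\end{aligned}
```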
Training Objective
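Since we predict each token conditioned on the tokens observed before it, the likelihood of a sequence factorizes into a product of next-token probabilities. For a batch of $B$ sequences, each providing $T$ next-token predictions, the MLE objective is:
$$\theta^* = \arg\max_\theta \prod_{i=1}^{B} \prod_{t=1}^{T} p_\theta\big(x^{(i)}_{t+1} \mid x^{(i)}_{\le t}\big)$$
where $x^{(i)}_{\le t}$ denotes the first $t$ tokens of sequence $i$.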
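Instead of maximizing this product directly, we minimize its negative logarithm. The logarithm turns the product into a sum, which avoids the numerical underflow caused by multiplying many small probabilities. The minus sign turns maximization into minimization. Averaging over all $B \cdot T$ next-token predictions in a batch of $B$ sequences gives the training objective:
$$\ell(\theta) = -\frac{1}{BT} \sum_{i=1}^{B} \sum_{t=1}^{T} \log p_\theta\big(x^{(i)}_{t+1} \mid x^{(i)}_{\le t}\big)$$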
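Recall from the Architecture section that, for each input token, the transformer outputs a tensor $z$ of length vocab_size containing the logits for the next-token prediction. Applying softmax to these logits gives a probability distribution that we can plug into the objective above. For numerical stability, we subtract the maximum logit from all logits before exponentiating. This does not change the result, because softmax is invariant to adding the same constant to every logit:
$$\mathrm{softmax}(z)_j = \frac{\exp(z_j)}{\sum_{k} \exp(z_k)} = \frac{\exp(z_j - \max_k z_k)}{\sum_{k} \exp(z_k - \max_k z_k)}$$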
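To see why the shift matters, here is a small sketch comparing a naive softmax with the max-shifted version on deliberately large logits. The values 1000, 1001, and 1002 are made up for illustration. In float32, `exp(1000)` overflows to `inf`, so the naive version produces NaN, while the shifted version matches PyTorch's built-in `torch.softmax`.

```python
import torch

def naive_softmax(z):
    # exp() of large logits overflows to inf, and inf / inf gives nan.
    e = torch.exp(z)
    return e / e.sum(dim=-1, keepdim=True)

def stable_softmax(z):
    # Subtracting the max makes the largest exponent exp(0) = 1, so nothing overflows.
    shifted = z - z.max(dim=-1, keepdim=True).values
    e = torch.exp(shifted)
    return e / e.sum(dim=-1, keepdim=True)

logits = torch.tensor([1000.0, 1001.0, 1002.0])
print(naive_softmax(logits))   # tensor([nan, nan, nan])
print(stable_softmax(logits))  # tensor([0.0900, 0.2447, 0.6652])
```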
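When implementing this in code, we take two inputs. The first is the logits for every position, with shape batch x seq_len x vocab_size. The second is the targets, with shape batch x seq_len, which are the input tokens shifted left by one position: the token at position $t+1$ is the target for position $t$. The implementation computes the same value as PyTorch's torch.nn.CrossEntropyLoss (or torch.nn.functional.cross_entropy) with the default mean reduction, provided the logits are flattened to (batch * seq_len) x vocab_size and the targets to batch * seq_len before calling the built-in.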
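Here is a minimal sketch of that equivalence check, using random logits and targets of arbitrary sizes. The manual path follows the same steps as the implementation that follows, using the built-in `torch.log_softmax` plus `gather`. The built-in `F.cross_entropy` needs the flattened shapes because it expects the class dimension at index 1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, vocab_size = 2, 5, 11  # arbitrary sizes for the check

logits = torch.randn(batch, seq_len, vocab_size)
targets = torch.randint(0, vocab_size, (batch, seq_len))

# Manual path: log-softmax over the vocabulary, then pick each target's log-probability.
log_probs = torch.log_softmax(logits, dim=-1)
manual = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1).mean()

# Built-in path: merge (batch, seq_len) into one axis so the classes sit on dim 1.
builtin = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))

print(torch.allclose(manual, builtin))  # True
```

Passing `logits.transpose(1, 2)` with the unflattened targets would also work, since `F.cross_entropy` accepts extra dimensions after the class dimension.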
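```python
import torch

def cross_entropy_loss(logits, targets):
    # logits: (batch, seq_len, vocab_size); targets: (batch, seq_len) token IDs.
    # Shift by the per-position max logit so exp() cannot overflow.
    max_logit = logits.max(dim=-1, keepdim=True).values
    shifted = logits - max_logit
    # log-sum-exp over the vocabulary gives the softmax normalizer in log space.
    log_sum_exp = torch.log(torch.exp(shifted).sum(dim=-1, keepdim=True))
    log_softmax = shifted - log_sum_exp
    # Log-probability of each target token, averaged over all positions.
    target_log_probs = log_softmax.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return -target_log_probs.mean()
```
AdamW Optimizer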
Once we have computed the loss, calling loss.backward() backpropagates gradients to every parameter of the model. We then take an optimizer step, which updates the parameters to reduce the loss. The simplest approach is stochastic gradient descent (SGD), which moves every parameter against its gradient with one fixed learning rate: $\theta_t = \theta_{t-1} - \alpha g_t$. Many refinements have been proposed since, and AdamW is now the most widely used optimizer for training transformers. It combines two key ideas that let it navigate the loss surface faster and more efficiently.
- Momentum: Instead of updating with only the most recent gradient, we keep an exponential moving average of past gradients, the first moment $m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$, where $g_t$ is the gradient at step $t$. We use this average, which already includes the current gradient, for the update. If recent gradients point in the same direction, the average grows and we take larger steps in that direction (we accelerate). If they are noisy and point in opposite directions, they partly cancel and each step is damped.
- Adaptive learning rate: Instead of updating every parameter with the same learning rate, we borrow the idea from RMSProp and keep an exponential moving average of the squared gradients, the second moment $v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$. We divide the update by the square root of this term. As a result, parameters whose gradients have historically been large get smaller steps, and those with small gradients get larger ones.
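Here is a small numerical sketch of both moment estimates for a single parameter, using two made-up gradient sequences. The first sequence always points the same way; the second alternates sign with the same magnitude. For the consistent sequence, the first moment builds up toward the gradient. For the alternating sequence, it stays near zero. The second moment is identical for both, because it only sees squared magnitudes.

```python
def moments(grads, beta_1=0.9, beta_2=0.999):
    # Exponential moving averages of g (first moment) and g^2 (second moment), started at zero.
    m, v = 0.0, 0.0
    for g in grads:
        m = beta_1 * m + (1 - beta_1) * g
        v = beta_2 * v + (1 - beta_2) * g * g
    return m, v

consistent = [1.0] * 50          # every gradient points the same way
alternating = [1.0, -1.0] * 25   # same magnitude, sign flips every step

m_c, v_c = moments(consistent)
m_a, v_a = moments(alternating)
print(round(m_c, 3), round(m_a, 3))  # 0.995 -0.052
print(v_c == v_a)                    # True
```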
$\beta_1$ and $\beta_2$ are hyperparameters, usually set to 0.9 and 0.999. Because both moments are initialized to zero, they are biased toward zero during the first steps. To correct this, we divide each moment by the total weight its moving average has accumulated so far:
$$\hat m_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat v_t = \frac{v_t}{1 - \beta_2^t}$$
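A worked first step makes the bias correction concrete. Take $t = 1$ with a single gradient $g$, $\beta_1 = 0.9$, and $\beta_2 = 0.999$:

```latex
\begin{aligned}
m_1 &= (1 - 0.9)\,g = 0.1\,g, & \hat m_1 &= \frac{0.1\,g}{1 - 0.9} = g, \\
v_1 &= (1 - 0.999)\,g^2 = 0.001\,g^2, & \hat v_1 &= \frac{0.001\,g^2}{1 - 0.999} = g^2.
\end{aligned}
```

Without the correction, $m_1$ is ten times too small and $v_1$ a thousand times too small. Their ratio $m_1 / \sqrt{v_1} = 0.1 / \sqrt{0.001} \approx 3.16\,\operatorname{sign}(g)$ is therefore distorted. After correction, $\hat m_1 / \sqrt{\hat v_1} = \operatorname{sign}(g)$.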
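The last important component is the W in AdamW, which stands for (decoupled) weight decay. The traditional approach is to add an L2 penalty $\frac{\lambda}{2}\lVert\theta\rVert^2$ to the loss, which adds $\lambda\theta$ to every gradient. In Adam, that penalty term is then divided by $\sqrt{\hat v_t}$ along with the rest of the gradient, so parameters with large historical gradients are regularized less than the others. AdamW fixes this by decoupling the decay from the gradient-based update and applying it directly to the weights, so every parameter decays at the same rate $\alpha\lambda$. This gives the final update, where $\alpha$ is the learning rate, $\lambda$ the weight decay coefficient, and $\epsilon$ a small constant for numerical stability:
$$\theta_t = \theta_{t-1} - \alpha \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} - \alpha \lambda \theta_{t-1}$$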
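The difference is easiest to see by isolating the decay term for two parameters with the same value but very different gradient histories. In this sketch, the $\sqrt{\hat v}$ values are made up. The "coupled" number is a simplification: it tracks only the part of the Adam step that comes from the $\lambda\theta$ gradient term and ignores how that term mixes into the first moment over time.

```python
def decay_terms(lr, lam, theta, sqrt_v_hat, eps=1e-8):
    # L2 in the loss: lambda * theta joins the gradient, so Adam rescales it by 1 / (sqrt(v_hat) + eps).
    coupled = lr * lam * theta / (sqrt_v_hat + eps)
    # AdamW: the decay is applied to the weights directly and ignores v_hat.
    decoupled = lr * lam * theta
    return coupled, decoupled

# Hypothetical sqrt(v_hat) values for two parameters that share theta = 1.0.
for name, s in [("large-gradient param", 10.0), ("small-gradient param", 0.1)]:
    coupled, decoupled = decay_terms(lr=1e-3, lam=0.1, theta=1.0, sqrt_v_hat=s)
    print(f"{name}: coupled={coupled:.1e}, decoupled={decoupled:.1e}")
# large-gradient param: coupled=1.0e-05, decoupled=1.0e-04
# small-gradient param: coupled=1.0e-03, decoupled=1.0e-04
```

With the coupled penalty, the small-gradient parameter is decayed about 100 times more strongly than the large-gradient one. With decoupled decay, both shrink by the same amount.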
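We implement AdamW by subclassing torch.optim.Optimizer. Its param_groups attribute is a list of dictionaries, one per parameter group; in our case there is a single group containing all parameters. The hyperparameters are passed to the parent constructor as a defaults dictionary, and these key-value pairs are copied into every parameter group. Each parameter also has its own state dictionary, where we store the two moments and a counter of how many updates that parameter has received. We then update the parameter values according to the expression above, without recording the update in the autograd graph. Rather than dividing the moments explicitly, the code folds both bias corrections into a single step size $\alpha_t = \alpha \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$. This matches the expression above up to a rescaling of $\epsilon$.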
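Because the constructor fills missing keys from the defaults, an optimizer built this way also accepts several parameter groups with different hyperparameters. A common pattern, which the text above does not use, is to exclude biases from weight decay. The sketch below demonstrates the mechanism with PyTorch's built-in `torch.optim.AdamW`, which shares the same `param_groups` behavior. The toy `nn.Linear` model is only for illustration.

```python
import torch
from torch import nn

model = nn.Linear(4, 3)  # toy model: one weight matrix and one bias vector

# Two groups: weights get weight decay, biases do not.
decay = [p for name, p in model.named_parameters() if not name.endswith("bias")]
no_decay = [p for name, p in model.named_parameters() if name.endswith("bias")]

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.1},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,  # not set per group, so it is copied from the defaults into both groups
)

for group in optimizer.param_groups:
    print(len(group["params"]), group["weight_decay"], group["lr"])
# 1 0.1 0.001
# 1 0.0 0.001
```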
```python
import torch

class AdamW(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
        # These defaults are copied into every parameter group.
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}
        super().__init__(params, defaults)

    @torch.no_grad()  # parameter updates must not be tracked by autograd
    def step(self, closure=None):
        loss = None
        if closure is not None:
            # The closure re-evaluates the model, so gradients are re-enabled for it.
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group["lr"]
            beta_1, beta_2 = group["betas"]
            epsilon = group["eps"]
            lamda = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    # Lazy initialization: step counter and zero-initialized moments.
                    state["t"] = 1
                    state["m"] = torch.zeros_like(p)
                    state["v"] = torch.zeros_like(p)
                t = state["t"]
                gradient = p.grad
                # Update the biased first and second moment estimates in place.
                state["m"].mul_(beta_1).add_(gradient, alpha=1 - beta_1)
                state["v"].mul_(beta_2).addcmul_(gradient, gradient, value=1 - beta_2)
                # Fold both bias corrections into the step size.
                adjusted_alpha = lr * ((1 - beta_2**t) ** 0.5) / (1 - beta_1**t)
                # Decoupled weight decay on the pre-update weights: theta -= lr * lambda * theta.
                p.mul_(1 - lr * lamda)
                # Adam step: theta -= adjusted_alpha * m / (sqrt(v) + eps).
                p.addcdiv_(state["m"], state["v"].sqrt().add_(epsilon), value=-adjusted_alpha)
                state["t"] = t + 1
        return loss
```
LR Schedule and Gradient Clipping
I also implemented learning rate scheduling, which sets the LR based on how many steps into training we are. We start with a warmup that increases the LR linearly from 0 to its peak $\alpha_{\max}$, reached at step $T_w$. Warmup helps because early in training the gradient estimates are noisy and the optimizer's moment estimates have not yet accumulated enough history. A small learning rate keeps us from taking huge steps in potentially wrong directions on the loss surface. After the peak, a cosine schedule decays the LR from $\alpha_{\max}$ to $\alpha_{\min}$ at step $T_c$, changing slowly near both ends, and the LR stays at $\alpha_{\min}$ afterwards. This keeps the steps small as the model nears convergence.
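Written out, the schedule is the following piecewise function of the step $t$. It is a transcription of the learning_rate_schedule code, with $\alpha_{\max}$, $\alpha_{\min}$, $T_w$, and $T_c$ corresponding to alpha_max, alpha_min, T_w, and T_c:

```latex
\alpha_t =
\begin{cases}
\dfrac{t}{T_w}\,\alpha_{\max}, & t < T_w, \\[6pt]
\alpha_{\min} + \dfrac{1}{2}\left(1 + \cos\!\left(\dfrac{t - T_w}{T_c - T_w}\,\pi\right)\right)(\alpha_{\max} - \alpha_{\min}), & T_w \le t \le T_c, \\[6pt]
\alpha_{\min}, & t > T_c.
\end{cases}
```

At $t = T_w$ the cosine term is $\cos 0 = 1$, giving $\alpha_{\max}$. At $t = T_c$ it is $\cos \pi = -1$, giving $\alpha_{\min}$. The pieces therefore join continuously.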
```python
import math

def learning_rate_schedule(t, alpha_max, alpha_min, T_w, T_c):
    # Linear warmup from 0 to alpha_max over the first T_w steps.
    if t < T_w:
        return t * alpha_max / T_w
    # After the cosine phase ends, hold the minimum learning rate.
    if t > T_c:
        return alpha_min
    # Cosine decay from alpha_max (t = T_w) down to alpha_min (t = T_c).
    cos_term = math.cos((t - T_w) / (T_c - T_w) * math.pi)
    return alpha_min + 0.5 * (1 + cos_term) * (alpha_max - alpha_min)
```
I also implemented gradient clipping, which guards against exploding gradients in deep networks. We set an upper bound max_norm. Whenever the combined L2 norm of all gradients exceeds it, we scale every gradient down uniformly by the factor max_norm / (norm + eps), where eps is a small constant for numerical stability. Instead of concatenating all gradient values into one flat vector, we compute the L2 norm of each gradient tensor, stack these norms into a new tensor, and take its L2 norm as the global norm. This gives the same result, because the squared global norm is the sum of the squared per-tensor norms.
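Here is a quick check of that equivalence, using a few random tensors of arbitrary shapes as stand-ins for real gradients:

```python
import torch

torch.manual_seed(0)
grads = [torch.randn(3, 4), torch.randn(5), torch.randn(2, 2, 2)]  # stand-ins for p.grad

# Global norm from the flattened concatenation of every gradient value.
flat_norm = torch.cat([g.reshape(-1) for g in grads]).norm(2)
# Global norm as the L2 norm of the per-tensor L2 norms.
stacked_norm = torch.stack([g.norm(2) for g in grads]).norm(2)

print(torch.allclose(flat_norm, stacked_norm))  # True
```

The stacked version avoids allocating a copy of every gradient value, which matters for large models.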
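```python
import torch

def gradient_clipping(parameters, max_norm, eps=1e-6):
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return 0.0  # nothing to clip
    # Global L2 norm = L2 norm of the per-tensor L2 norms.
    norm = torch.norm(torch.stack([torch.norm(g, 2) for g in grads]), 2)
    if norm > max_norm:
        scaling_factor = max_norm / (norm + eps)
        for g in grads:
            # Scale in place so the optimizer sees the clipped gradients.
            g.detach().mul_(scaling_factor)
    return norm.item()
```
Data Loading and Training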
We start with the dataset. Most open-source text datasets store each sample as a dictionary with the text in one of its fields. We tokenize each sample with our tokenizer after appending the <|endoftext|> token to it, concatenate the resulting token IDs into one continuous sequence, and save it as a binary file. During training we load this file lazily with np.memmap, which maps the file into memory without reading it all into RAM. Only the slices we index in the training loop are read from disk, at the corresponding offsets. We then pick random starting positions in this sequence and take context_length tokens from each as the inputs. The targets are the same windows shifted by one position. These form the input and target batches for the training loop, which we move to the same device as the model weights.
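Here is a minimal sketch of the preprocessing and lazy loading, under several assumptions. The byte-level `encode` function stands in for the real tokenizer. The end-of-text ID 256 is made up. The samples are hard-coded, and the file name `train.bin` is hypothetical. Token IDs are stored as `uint16`, which only works while the vocabulary has at most 65,536 entries.

```python
import os
import tempfile
import numpy as np

EOT_ID = 256  # hypothetical ID for <|endoftext|>, just past the 256 byte values

def encode(text):
    # Stand-in tokenizer: one token per UTF-8 byte.
    return list(text.encode("utf-8"))

samples = [{"text": "hello"}, {"text": "world!"}]  # made-up dataset rows

# Tokenize each sample, append <|endoftext|>, and concatenate everything.
ids = []
for sample in samples:
    ids.extend(encode(sample["text"]))
    ids.append(EOT_ID)

path = os.path.join(tempfile.mkdtemp(), "train.bin")  # hypothetical output file
np.array(ids, dtype=np.uint16).tofile(path)

# Lazily map the file; slicing reads only the requested tokens from disk.
data = np.memmap(path, dtype=np.uint16, mode="r")
print(len(data), data[:6].tolist())  # 13 [104, 101, 108, 108, 111, 256]
```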
```python
import numpy as np
import torch

def data_loading(x, batch_size, context_length, device):
    # Max start is len(x) - context_length - 1, so each shifted target window fits.
    starts = torch.randint(0, len(x) - context_length, (batch_size,)).tolist()
    # Inputs: context_length tokens per start; targets: the same window shifted by one.
    sequences = torch.stack([torch.from_numpy(x[s:s + context_length].astype(np.int64)) for s in starts])
    targets = torch.stack([torch.from_numpy(x[s + 1:s + 1 + context_length].astype(np.int64)) for s in starts])
    return sequences.to(device), targets.to(device)
```
Each training step then sets the learning rate for the current step, loads a batch of data, and runs a forward pass. We compute the cross-entropy loss between the predictions and the targets, backpropagate with loss.backward(), and clip the gradients. Finally, we take an optimizer step to update the weights and clear the old gradients before the next iteration.
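Putting the steps together, a minimal training-loop sketch might look like the following. To stay self-contained, it swaps in PyTorch's built-in `torch.optim.AdamW`, `F.cross_entropy`, and `torch.nn.utils.clip_grad_norm_` for the hand-written versions. It also uses a hypothetical `TinyLM` (an embedding followed by a linear layer) in place of the transformer, random tokens instead of a real dataset, and made-up hyperparameters.

```python
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
np.random.seed(0)
vocab_size, context_length, batch_size = 50, 16, 8
alpha_max, alpha_min, T_w, T_c, max_steps = 1e-2, 1e-3, 10, 90, 100

class TinyLM(nn.Module):
    # Hypothetical stand-in for the transformer: maps each token to next-token logits.
    def __init__(self, vocab_size, d_model=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, ids):
        return self.out(self.embed(ids))  # (batch, seq_len, vocab_size)

def lr_at(t):
    # Same warmup + cosine shape as the schedule described above.
    if t < T_w:
        return t * alpha_max / T_w
    if t > T_c:
        return alpha_min
    cos_term = math.cos((t - T_w) / (T_c - T_w) * math.pi)
    return alpha_min + 0.5 * (1 + cos_term) * (alpha_max - alpha_min)

def get_batch(data):
    # Random windows of context_length tokens; targets are shifted by one.
    starts = np.random.randint(0, len(data) - context_length, size=batch_size)
    x = np.stack([data[s:s + context_length] for s in starts])
    y = np.stack([data[s + 1:s + 1 + context_length] for s in starts])
    return torch.from_numpy(x.astype(np.int64)), torch.from_numpy(y.astype(np.int64))

data = np.random.randint(0, vocab_size, size=10_000).astype(np.uint16)  # fake token stream
model = TinyLM(vocab_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=alpha_max, weight_decay=0.01)

for step in range(max_steps):
    for group in optimizer.param_groups:
        group["lr"] = lr_at(step)  # inject the scheduled learning rate
    inputs, targets = get_batch(data)
    logits = model(inputs)
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)  # clear gradients before the next step

print(f"final loss: {loss.item():.3f}")
```

The order inside the loop matters. Clipping must run after `loss.backward()` fills the gradients and before `optimizer.step()` consumes them.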
Generation
During generation we take an input prompt and repeatedly predict the next token until we either reach the max_tokens limit or generate the <|endoftext|> token. At every step we append the newly predicted token to the sequence and feed the whole sequence back into the model. Two knobs control how we sample from the distribution the model gives us:
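- Temperature controls the sharpness of the probability distribution before we sample a token. Each logit is divided by the temperature $\tau$ before the softmax. This amplifies the differences between tokens when $\tau < 1$ and compresses them when $\tau > 1$; a flatter distribution gives unlikely tokens more probability mass. In the code below, a temperature of 0 falls back to greedy decoding, which always picks the highest-scoring token.
- Nucleus (top-p) sampling sorts the tokens by probability in descending order and keeps the top tokens until their cumulative probability reaches the threshold top_p. It then discards the rest, renormalizes, and samples from what remains.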
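Here is a small worked sketch of both knobs on made-up logits for a four-token vocabulary. The filtering step uses the same mask as the generate function.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # made-up next-token logits

# Lower temperature sharpens the distribution; higher temperature flattens it.
for temperature in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, [round(p, 3) for p in probs.tolist()])
# 0.5 [0.842, 0.114, 0.042, 0.002]
# 1.0 [0.609, 0.224, 0.136, 0.03]
# 2.0 [0.434, 0.263, 0.205, 0.097]

# Nucleus filtering with top_p = 0.8 at temperature 1.0.
top_p = 0.8
probs = torch.softmax(logits, dim=-1)
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumsum = torch.cumsum(sorted_probs, dim=-1)
mask = (cumsum - sorted_probs) > top_p  # drop tokens once the mass before them exceeds top_p
sorted_probs[mask] = 0.0
sorted_probs /= sorted_probs.sum()      # renormalize over the kept tokens
print(sorted_indices[~mask].tolist())   # [0, 1]
```

The top two tokens carry about 0.83 of the mass, so the third token is already past the 0.8 threshold and is discarded along with the fourth.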
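```python
import torch

@torch.no_grad()
def generate(model, tokenizer, prompt, max_tokens=200, temperature=1.0, top_p=0.9, device=None):
    if device is None:
        device = next(model.parameters()).device
    model.eval()  # disable dropout and similar training-only behavior
    token_ids = list(tokenizer.encode(prompt))
    for _ in range(max_tokens):
        # Re-run the model on the full sequence; this assumes it fits in the context window.
        input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
        logits = model(input_ids)
        next_logits = logits[0, -1, :]  # logits for the token after the last position
        if temperature > 0:
            # Temperature scaling, then softmax to get probabilities.
            scaled = next_logits / temperature
            probs = torch.softmax(scaled, dim=-1)
            # Nucleus (top-p) filtering: drop tokens once the mass before them exceeds top_p.
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumsum = torch.cumsum(sorted_probs, dim=-1)
            mask = (cumsum - sorted_probs) > top_p
            sorted_probs[mask] = 0.0
            sorted_probs /= sorted_probs.sum()
            # Sample a position in the sorted order and map it back to a token ID.
            idx = torch.multinomial(sorted_probs, num_samples=1)
            next_id = sorted_indices[idx].item()
        else:
            # Temperature 0: greedy decoding.
            next_id = next_logits.argmax().item()
        token_ids.append(next_id)
        # Stop once the model emits the end-of-text token.
        if tokenizer.decode([next_id]) == "<|endoftext|>":
            break
    model.train()  # restore training mode for the caller
    return tokenizer.decode(token_ids)
```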