Tokenization
LLMs need to understand words and characters to model language. When we give them an input sequence like "What is the capital of India?", they break it down into smaller parts called tokens. These tokens are then converted into vector embeddings that let the model represent and manipulate their meaning. But how do we decide what these tokens are? That is defined by the tokenizer, which is the first component we train for any language model today. Think of it as defining a dictionary for the language model, so that it knows what tokens make up the input sequence and can look up their meaning. This doesn't sound that hard, and you can probably think of obvious approaches like:
- Character level encoding: Why not just assign a unique integer ID to each character? The model can break words down into their constituent characters and map them to embeddings. This gives us a small, well-defined vocabulary but comes with problems. The computation cost in the attention layers of a transformer scales quadratically with the number of input tokens, and if every character is an independent token, we have to process a lot of tokens! We also want tokens in our vocabulary to carry semantic meaning or information the model can work with, which is not possible at the character level.
- Word level encoding: Why not just assign a unique integer ID to every unique word we come across? This gives us tokens that are information dense, with reasonable attention costs. But there are still problems. The vocabulary would be immense if we accounted for every unique word variant across languages, and the model would be incapable of tokenizing words it was not trained on.
We want a tokenizer that keeps a reasonable vocabulary size while having information dense tokens. Papers proposed rule-based and morphological methods, but the main breakthrough came with the byte-level BPE tokenizer popularized by OpenAI's GPT-2. Let's understand it step by step...
UTF-8 and BPE
Before we can understand the tokenizer, we need to agree on our representation of characters. The global standard here is Unicode, which assigns each character a unique integer ID called a codepoint. The letter a has codepoint 97, and the letter ê (e with a circumflex, as in French) has codepoint 234. Unicode defines 150K+ such characters across languages, including emojis and symbols. UTF-8 is an encoding standard that represents each codepoint as a variable-length sequence of one to four bytes. A word or sentence encoded with UTF-8 is a list of the byte values of its constituent characters.
text = "hello"
print(list(text.encode("utf-8")))  # [104, 101, 108, 108, 111]

The Byte Pair Encoding (BPE) algorithm was introduced in 1994 to compress such byte sequences. It takes an input, finds the most frequent contiguous byte pair, and replaces it with a new byte value that was not in the initial set. In the example input aaabdaaabac, aa is the most frequent pair; replacing it with a new symbol, say Z = aa, gives us the compressed form ZabdZabac. This can be repeated until a desired vocabulary size is reached.
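That merge step can be sketched in a few lines. This is a toy illustration over characters rather than bytes, and the helper name is mine:

```python
from collections import Counter

def merge_most_frequent(seq: list, new_symbol):
    # Count all contiguous pairs and replace the most frequent one.
    counts = Counter(zip(seq, seq[1:]))
    best, _ = counts.most_common(1)[0]
    out, i = [], 0
    while i < len(seq):
        if i < len(seq) - 1 and (seq[i], seq[i + 1]) == best:
            out.append(new_symbol)
            i += 2  # skip both members of the merged pair
        else:
            out.append(seq[i])
            i += 1
    return out, best

compressed, pair = merge_most_frequent(list("aaabdaaabac"), "Z")
print("".join(compressed), pair)  # ZabdZabac ('a', 'a')
```

Repeating this, each time minting a fresh symbol, is the whole compression loop.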
Implementation
We start with a text corpus that will be our training dataset. I have used TinyStories and OpenWebText, which you can download as .txt files. Before we get to the algorithm, we perform some pre-tokenization steps that help during training.
You might have special tokens in your corpus like <|endoftext|> that mark the end of documents or lines. You don't want them to be broken apart by the tokenizer, because they serve a special purpose and define boundaries in your dataset. We therefore split our text into chunks at such special tokens and remove the tokens themselves from the corpus.
We then do pre-tokenization which applies a special regex pattern introduced in the GPT-2 codebase that splits our chunks into smaller chunks based on pre-defined rules. Here is the pattern we use:
'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
This will split words, punctuation, numbers etc. into appropriate chunks for tokenization. We do this because we don't want merges to cross word boundaries. We also want punctuation and suffixes like 't or 's to remain independent tokens instead of merging with their root words, because they carry a generalizable meaning. This video from Andrej Karpathy does a great job of explaining this part.
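To get a feel for what the pattern does, here is an ASCII-only approximation using the standard library's re module (the real pattern relies on the Unicode classes \p{L} and \p{N}, which require the third-party regex package; the [A-Za-z] and [0-9] stand-ins below are my simplification):

```python
import re

# ASCII approximation of the GPT-2 pre-tokenization pattern:
# contractions, space-prefixed words, space-prefixed numbers,
# space-prefixed punctuation runs, then whitespace handling.
PATTERN = r"'(?:[sdmt]|ll|ve|re)| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+"

print(re.findall(PATTERN, "Hello's world 123!"))
# ['Hello', "'s", ' world', ' 123', '!']
```

Note how the contraction 's and the punctuation are kept as separate chunks, and how words keep their leading space, so merges never cross these boundaries.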
with open(input_path, "r", encoding="utf-8") as f:
    text = f.read()
special_pattern = "|".join(re.escape(token) for token in special_tokens)
processed = re.split(special_pattern, text)
pre_tokenized = []
for chunk in processed:
    tokenized = re.finditer(GPT_PATTERN, chunk)
    for token in tokenized:
        pre_tokenized.append(token.group())

At this stage we have a collection of string chunks. We now convert each chunk to its byte representation by encoding it with UTF-8. We initialize our vocab as a dictionary mapping the index keys 0 through 255 to their single-byte values, and merges as a list of all byte-pair merges our algorithm will make. We then iterate over all adjacent pairs across the chunks and build a dictionary tracking their frequency. We pick the most frequent pair, add a new index to vocab with the two byte sequences concatenated as its value, and append the pair to merges. We then iterate through the sequences again and, wherever the merged pair occurs, replace it with the new token ID. We repeat this until the desired vocabulary size is reached.
self.vocab = {i: bytes([i]) for i in range(256)}
self.merges = []
sequences = []
for chunk in pre_tokenized:
    sequences.append(list(chunk.encode("utf-8")))
num_merges = vocab_size - len(special_tokens) - 256
pbar = tqdm(total=num_merges, desc="Training (simple)", unit="merge")
while len(self.vocab) < vocab_size - len(special_tokens):
    if tracker:
        tracker.on_merge_start()
    pairs = {}
    for token_bytes in sequences:
        for a, b in zip(token_bytes, token_bytes[1:]):
            if (a, b) in pairs:
                pairs[(a, b)] += 1
            else:
                pairs[(a, b)] = 1
    best_pair = max(pairs, key=lambda k: (pairs[k], self.vocab[k[0]], self.vocab[k[1]]))
    new_id = len(self.vocab)
    self.vocab[new_id] = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]
    self.merges.append((best_pair[0], best_pair[1]))
    for token_bytes in sequences:
        i = 0
        while i < len(token_bytes) - 1:
            if token_bytes[i] == best_pair[0] and token_bytes[i + 1] == best_pair[1]:
                token_bytes[i] = new_id
                del token_bytes[i + 1]
            else:
                i += 1
    merge_time = time() - tracker.merge_start_time if tracker else 0.0
    merge_idx = len(self.merges) - 1

We then add our special tokens to vocab as individual tokens. This is the entire training process, which returns the vocab and merges that we can use for encoding and decoding with this tokenizer in the future. During encoding we take a string input, perform the same pre-tokenization steps, and convert the chunks to their UTF-8 byte representations. We then look up all adjacent byte pairs from the input in merges, find the one with the smallest index (i.e. the earliest merge during training), replace it with the corresponding token ID, and repeat until no more merges are found on lookup. For decoding, our input is a list of token IDs that we simply look up in vocab, concatenate, and decode back to a string.
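A minimal sketch of that encode/decode loop, assuming IDs 0-255 are the raw bytes and merged IDs were assigned in training order (the function names are mine, and pre-tokenization is skipped for brevity):

```python
def bpe_encode(text: str, merges: list[tuple[int, int]]) -> list[int]:
    # rank = position in training order; earlier merges apply first
    ranks = {pair: i for i, pair in enumerate(merges)}
    ids = list(text.encode("utf-8"))
    while len(ids) > 1:
        # pair with the lowest (earliest) merge rank currently in the sequence
        best = min(zip(ids, ids[1:]), key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break  # no known merge applies anymore
        new_id = 256 + ranks[best]  # the ID this merge received during training
        merged, i = [], 0
        while i < len(ids):
            if i < len(ids) - 1 and (ids[i], ids[i + 1]) == best:
                merged.append(new_id)
                i += 2
            else:
                merged.append(ids[i])
                i += 1
        ids = merged
    return ids

def bpe_decode(ids: list[int], vocab: dict[int, bytes]) -> str:
    # every ID maps to a byte string; concatenate and decode
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")

merges = [(97, 97)]                          # suppose "aa" was learned first
vocab = {i: bytes([i]) for i in range(256)}
vocab[256] = b"aa"
print(bpe_encode("aaab", merges))            # [256, 97, 98]
print(bpe_decode([256, 97, 98], vocab))      # aaab
```

Applying merges by ascending rank reproduces the training-time merge order, which is what guarantees encode and decode round-trip exactly.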
Optimization
While this implementation works in theory, when dealing with large datasets it is important to optimize parts of the code. Here are a couple of things you can try:
Process pooling for pre-tokenization: Since we are doing CPU bound regex based splitting on a large text file, we can load the file in binary and split it into num_workers chunks. We can then have multiple processes pre-tokenizing each chunk and concatenate the results.
special_pattern = "|".join(re.escape(t) for t in special_tokens) if special_tokens else ""
split_token = special_tokens[0].encode("utf-8") if special_tokens else b"\n"
with open(input_path, "rb") as f:
    boundaries = find_chunk_boundaries(f, num_workers, split_token)
chunk_args = [
    (input_path, special_pattern, boundaries[i], boundaries[i + 1])
    for i in range(len(boundaries) - 1)
    if boundaries[i] < boundaries[i + 1]
]
with mp.Pool(num_workers) as pool:
    chunk_results = pool.starmap(_pretokenize_chunk, chunk_args)
pre_tokenized = []
for result in chunk_results:
    pre_tokenized.extend(result)

Efficient merging: The above approach was naive. Instead of rebuilding pair counts from scratch each iteration, we can maintain a dictionary of pairs and their counts. Whenever we merge a and b, we decrement the affected neighbor pairs (x, a) and (b, x) and increment the new neighbor pairs formed with the merged token.
pairs: dict[tuple[int, int], int] = {}
for seq in sequences:
    for a, b in zip(seq, seq[1:]):
        pairs[(a, b)] = pairs.get((a, b), 0) + 1
num_merges = vocab_size - len(special_tokens) - 256
pbar = tqdm(total=num_merges, desc="Training (efficient)", unit="merge")
while len(self.vocab) < vocab_size - len(special_tokens):
    if tracker:
        tracker.on_merge_start()
    best_pair = max(
        (p for p in pairs if pairs[p] > 0),
        key=lambda k: (pairs[k], self.vocab[k[0]], self.vocab[k[1]]),
    )
    A, B = best_pair
    best_count = pairs[best_pair]
    new_id = len(self.vocab)
    self.vocab[new_id] = self.vocab[A] + self.vocab[B]
    self.merges.append((A, B))
    for seq in sequences:
        i = 0
        while i < len(seq) - 1:
            if seq[i] == A and seq[i + 1] == B:
                if i > 0:
                    pairs[(seq[i - 1], A)] -= 1
                if i + 2 < len(seq):
                    pairs[(B, seq[i + 2])] -= 1
                pairs[(A, B)] -= 1
                seq[i] = new_id
                del seq[i + 1]
                if i > 0:
                    pairs[(seq[i - 1], new_id)] = pairs.get((seq[i - 1], new_id), 0) + 1
                if i + 1 < len(seq):
                    pairs[(new_id, seq[i + 1])] = pairs.get((new_id, seq[i + 1]), 0) + 1
            else:
                i += 1

When training on the TinyStories validation set (~22MB), the naive method took 373.58 seconds while the optimized code took 103.37 seconds, and both reached the same outputs. This is a 3.6x speedup, and it will be higher on larger datasets, where the overhead of spinning up processes becomes negligible compared to the time each process spends working.