Getting Started:¶
What you need to do:¶
- Download and extract the assignment file (mp06.zip)
- Implement the required methods in submitted.py: build_attention_mask(batch), batch_samples(batch, pad_token_id), create_causal_mask(seq_len), LinearFunction, AttentionFunction
- Test locally with python grade.py
- Submit submitted.py to Gradescope!
CS440/ECE448¶
MP06: Autoregressive Transformers and Reasoning Traces¶
After unpacking the mp06.zip file you will find the following content:
- submitted.py: Your homework. Edit and then submit to Gradescope
- mp06_notebook.ipynb: This is a Jupyter Notebook to help you debug. You can completely ignore it if you want, although you might find that it gives you useful instructions.
- grade.py: Once your homework seems to be working, you can test it by typing python grade.py, which will run the visible tests.
- tests/test_visible.py: This file contains visible unit tests for the main assignment
- utils.py: This is an auxiliary program that you can use to split data, parse generations, etc.
- requirements.txt: Use pip install -r requirements.txt to install all the necessary dependencies
GPU Access¶
You DO NOT need any GPU access for this MP. Everything in submitted.py can easily be evaluated on CPU alone. Training will also progress normally on a CPU, although more slowly. For free GPU access you can use Google Colaboratory, to which you would upload your submitted.py, utils.py, and mp06_notebook.ipynb. At the top of the toolbar you should see Runtime. Clicking on this produces a dropdown where you should see Change runtime type. Under the list of Hardware Accelerators select T4 GPU and then click Save! This will provide GPU access.
To verify GPU access you can run:
import torch
torch.cuda.is_available() # this should produce True
What Are We Building?¶
Today we will be building a tiny language model to add numbers together! But we will test it in two different ways:
Standard Addition:¶
82+78<assistant_start>=160<assistant_end>
Reasoning Addition:¶
82+78<assistant_start>
<start_think>
=2+8=10<start_carry>1<end_carry>
=8+7+1=16<start_carry>1<end_carry><end_think>
=160
<assistant_end>
What you see above are two ways we can ask an LLM to perform the addition task. In the first method (standard addition) we only ask the model for the final answer: the model sees the prompt 82+78 and is trained to emit 160 directly, without thinking about how it arrives there. In the second method (reasoning addition) we ask the model to first produce some intermediate thinking text that reasons out how we typically perform a summation, and then produce the final answer. This is often known as a reasoning trace.
Reasoning Models¶
Reasoning models have become more popular lately; they are trained to produce intermediate structured steps instead of jumping straight to an answer. This can (potentially) help with model performance, as it forces the model to work through a verifiable process, but at the cost of longer inference-time compute, since we now have to generate more tokens!
Special Tokens¶
Special tokens are common in LLMs as well and act as delimiters. The special tokens we will encounter in this MP are:
- <assistant_start>: Tells the LLM this is where it should start generating
- <assistant_end>: Lets the LLM tell us that it's done generating so we can stop
- <start_carry>/<end_carry>: Helper tokens that identify the digit we will carry over to the next place
- <start_think>/<end_think>: Between these tokens the model produces its thought tokens; once it's done thinking it emits the final solution
- <pad>: Not all sequences are of the same length, so we must pad shorter sequences to match the longer ones
What you will be doing¶
The reasoning model is just an example task for you to train on. The main focus of the MP is writing the core portions of the Transformer architecture; the methods you will implement are:
- Causal Masking
- Padding Masking
- Data Collation (batching samples together)
- Custom Linear Layers (Forward/Backward)
- Custom Attention Module (Forward/Backward)
All other code regarding data production, the model architecture, training/evaluation will be provided and annotated throughout this notebook!
Table of Contents¶
- Building the Dataset
- Creating a Dataset
- Tokenization
- Tokens to Embeddings
- Problem 1: Padded Batching
- Problem 2: Attention Masking
- Data Collation
- Attention Mechanism and Masking
- Problem 3: Causal Masking
- Linear Layer Overview
- Multidimensional Matrix Multiplication
- Problem 4: Custom Linear Layer
- Problem 5: Custom Attention Implementation
- Attention Block
- Multilayer Perceptron
- Transformer Block
- Tiny GPT2 Model
- Positional Embeddings
- Training Loop
- Compare Reasoning vs Non-Reasoning Models
- Data Balancing
Let's Take a Look at the Data¶
We will be generating a synthetic dataset today with optional reasoning traces. This code takes in two integer values and produces a string that will later act as a sample for our LLM to train on!
import re
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils import random_split_dataset, evaluate_model, extract_answer
import submitted
import importlib
def generate_addition_example(
a: int,
b: int,
carry_start_token: str = "<start_carry>",
carry_end_token: str = "<end_carry>",
think_start_token: str = "<start_think>",
think_end_token: str = "<end_think>",
assistant_start_token: str = "<assistant_start>",
assistant_end_token: str = "<assistant_end>",
add_reasoning_trace=True
):
if add_reasoning_trace:
sa = str(a)
sb = str(b)
i = len(sa) - 1
j = len(sb) - 1
carry = 0
steps = []
# Process digits from right to left
while i >= 0 or j >= 0:
da = int(sa[i]) if i >= 0 else 0
db = int(sb[j]) if j >= 0 else 0
total = da + db + carry
new_carry = total // 10
if carry == 0:
step = (
f"={da}+{db}={total}"
f"{carry_start_token}{new_carry}{carry_end_token}\n"
)
else:
step = (
f"={da}+{db}+{carry}={total}"
f"{carry_start_token}{new_carry}{carry_end_token}\n"
)
steps.append(step)
carry = new_carry
i -= 1
j -= 1
result = str(a + b)
# Build final example
example = [
f"{a}+{b}{assistant_start_token}\n",
think_start_token, "\n",
*steps,
think_end_token, "\n",
f"={result}\n{assistant_end_token}"
]
return "".join(example)
else:
return f"{a}+{b}{assistant_start_token}={a+b}{assistant_end_token}"
a = 68
b = 72
print()
print('\033[95m' + "Input Without Reasoning" + '\033[0m')
print(generate_addition_example(a,b, add_reasoning_trace=False))
print("\n" + '\033[95m' + "Input With Reasoning" + '\033[0m')
print(generate_addition_example(a,b, add_reasoning_trace=True))
Input Without Reasoning 68+72<assistant_start>=140<assistant_end> Input With Reasoning 68+72<assistant_start> <start_think> =8+2=10<start_carry>1<end_carry> =6+7+1=14<start_carry>1<end_carry> <end_think> =140 <assistant_end>
Creating a Dataset¶
Of course we don't want just one sample, but rather a large collection of samples to train our model on! So let's create a list of samples that we can process later on.
def build_dataset(
num_places: int = 3,
reasoning: bool = True,
num_samples: int = 100_000,
):
"""
Build a randomly sampled dataset of addition problems.
Args:
num_places: Maximum number of digits (e.g., 2 means 0-99, 3 means 0-999, etc.)
reasoning: Whether to include reasoning traces
num_samples: Number of (i, j) pairs to sample
Returns:
List of formatted training examples
"""
### If we want up to 2 places, that means a and b can be any value between 0-99
### so we have 100 * 100 (10000) possible combinations!
max_value = 10 ** num_places
dataset = []
assert num_samples <= max_value ** 2, "Increase num places or you will repeat samples!"
print("Generating Random Dataset...")
### We randomly sample our dataset by repeatedly grabbing different a,b values ###
### and producing our text (with or without reasoning trace) ###
for _ in tqdm(range(num_samples)):
i = random.randrange(max_value)
j = random.randrange(max_value)
sample = generate_addition_example(
i, j, add_reasoning_trace=reasoning
)
# Verify correctness
assert int(
sample.split("=")[-1]
.replace("<assistant_end>", "")
.strip()
) == i + j
dataset.append(sample)
return dataset
dataset = build_dataset(num_places=2, num_samples=100, reasoning=True)
for sample in dataset[-4:]:
print(sample + "\n")
Generating Random Dataset...
100%|█████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 414456.92it/s]
89+53<assistant_start> <start_think> =9+3=12<start_carry>1<end_carry> =8+5+1=14<start_carry>1<end_carry> <end_think> =142 <assistant_end> 34+90<assistant_start> <start_think> =4+0=4<start_carry>0<end_carry> =3+9=12<start_carry>1<end_carry> <end_think> =124 <assistant_end> 97+73<assistant_start> <start_think> =7+3=10<start_carry>1<end_carry> =9+7+1=17<start_carry>1<end_carry> <end_think> =170 <assistant_end> 62+86<assistant_start> <start_think> =2+6=8<start_carry>0<end_carry> =6+8=14<start_carry>1<end_carry> <end_think> =148 <assistant_end>
Tokenization¶
Of course, neural networks don't process text but rather tokenized text. This means the string must first be converted into discrete symbols that the model can then embed and process numerically. Modern LLMs have complex tokenization schemes such as BPE (Byte Pair Encoding), WordPiece, or others, and they do exactly what is described: convert text to integers, where those integers map to specific words or subwords. For example, you can play with the OpenAI Tokenizer and see the following conversion:
$$\text{UIUC was referenced in the 2001 Space Odyssey} = \text{[3892, 25949, 673, 48585, 306, 290, 220, 1179, 16, 19468, 119230]}$$
Each token index represents the following:
- $\text{3892}=\text{UI}$
- $\text{25949}=\text{UC}$
- $\text{673}=\text{ was}$
- $\text{48585}=\text{ referenced}$
- $\text{306}=\text{ in}$
- $\text{290}=\text{ the}$
- $\text{220}=\text{<space>}$
- $\text{1179}=\text{200}$
- $\text{16}=\text{ 1}$
- $\text{19468}=\text{ Space}$
- $\text{119230}=\text{ Odyssey}$
So tokenizers basically provide two things:
- A mapping from token indexes to corresponding text
- A maximum number of tokens (the vocabulary size of the tokenizer)
Simplified Tokenizer¶
Taking a look at our problem, we don't really need a fancy tokenizer. The only tokens we have in our data are:
- Digits 0-9
- Special tokens: <start_carry>, <end_carry>, <start_think>, <end_think>, <assistant_start>, <assistant_end>, <pad>
- Operators: =, +
- Formatting: \n
So let's build a tokenizer that can take our dataset text, encode it into tokens, and decode back to text again!
class Tokenizer:
"""Simple character-level tokenizer with special tokens for reasoning."""
def __init__(self):
### Define all the possible unique vocabulary in our dataset ###
self.vocab = (
[str(i) for i in range(10)] + # Digits 0-9
["<start_carry>", "<end_carry>", "<start_think>", "<end_think>", "<assistant_start>", "<assistant_end>", "<pad>"] + # Special tokens
["=", "+", "\n"] # Operators and formatting
)
### Get the number of unique tokens in the tokenizer (the vocab size) ###
self.vocab_size = len(self.vocab)
### Create dictionary mappings from our vocabulary to a unique index representing it ###
self.char2idx = {c: i for i, c in enumerate(self.vocab)}
self.idx2char = {i: c for c, i in self.char2idx.items()}
### Store important token IDs for easy access later
self.start_carry_id = self.char2idx["<start_carry>"]
self.end_carry_id = self.char2idx["<end_carry>"]
self.start_think_id = self.char2idx["<start_think>"]
self.end_think_id = self.char2idx["<end_think>"]
self.assistant_start_id = self.char2idx["<assistant_start>"]
self.assistant_end_id = self.char2idx["<assistant_end>"]
self.pad_id = self.char2idx["<pad>"]
### Regex pattern to split our text into tokens
### this will convert:
### 68+72<assistant_start><start_think>=8+2=10<start_carry>1<end_carry>=6+7+1=14<start_carry>1<end_carry><end_think>=140<assistant_end>
### to
### ['6','8','+','7','2','<assistant_start>','<start_think>','=','8', ...]
self.pattern = re.compile(r"<[^>]+>|[\s\S]")
def split(self, text):
"""
Split text into tokens, keeping <...> intact.
"""
return self.pattern.findall(text)
def encode(self, text, is_prompt=False):
"""
Encode text into token IDs.
text: string to encode
is_prompt: whether to append <assistant_start>. During training time, all our special tokens are already
added to the raw text as defined in our `generate_addition_example` method. But during inference
we just provide something like 25+32, which means we need to manually add in the <assistant_start>
token so our input to the LLM becomes 25+32<assistant_start> as that is the special token we defined
that indicates to the LLM to start generating!
"""
### Use our REGEX to separate our text into characters and special tokens ###
string_tokens = self.split(text)
### Loop through and encode each of them ###
token_ids = []
for tok in string_tokens:
if tok not in self.char2idx:
raise ValueError(f"Unknown token: {tok}")
token_ids.append(self.char2idx[tok])
### Manually add on the assistant_start_id to the end if we are using a prompt ###
if is_prompt:
token_ids.append(self.assistant_start_id)
return token_ids
def decode(self, tokens, strip_pad_tokens = True):
"""
Decode token IDs back into text.
tokens: List of token IDs or tensor
strip_pad_tokens: Whether to remove padding tokens
"""
if isinstance(tokens, torch.Tensor):
assert tokens.dim() == 1
tokens = tokens.tolist()
if strip_pad_tokens:
tokens = [i for i in tokens if i != self.pad_id]
decoded = "".join([self.idx2char[i] for i in tokens])
return decoded
sample = dataset[0]
print('\033[95m' + "Original Sample" + '\033[0m')
print(sample)
print("\n" + '\033[95m' + "Tokenized Sample" + '\033[0m')
tokenizer = Tokenizer()
encoded = tokenizer.encode(sample)
print(encoded)
print("\n" + '\033[95m' + "Decoded Tokens" + '\033[0m')
decoded = tokenizer.decode(encoded)
print(decoded)
assert sample == decoded, "Original and Decoded should be the same, or something has gone wrong!"
Original Sample 47+54<assistant_start> <start_think> =7+4=11<start_carry>1<end_carry> =4+5+1=10<start_carry>1<end_carry> <end_think> =101 <assistant_end> Tokenized Sample [4, 7, 18, 5, 4, 14, 19, 12, 19, 17, 7, 18, 4, 17, 1, 1, 10, 1, 11, 19, 17, 4, 18, 5, 18, 1, 17, 1, 0, 10, 1, 11, 19, 13, 19, 17, 1, 0, 1, 19, 15] Decoded Tokens 47+54<assistant_start> <start_think> =7+4=11<start_carry>1<end_carry> =4+5+1=10<start_carry>1<end_carry> <end_think> =101 <assistant_end>
What Do Tokens Mean? We Use the Embedding Matrix!¶
Each token was assigned a unique number, but these numbers don't mean anything on their own. How can a neural network learn from them? At the start of our LLM we have something called the Embedding Matrix, which will be instantiated via nn.Embedding from PyTorch.

This matrix has the shape (Vocab Size x Embedding Dimension). This means each entry in our tokenizer's vocabulary has a vector of size Embedding Dimension representing it! So the unique number we assigned to each word in our vocabulary is nothing more than the index of the row of the embedding matrix that the word corresponds to. The parameters of this matrix are trained with gradient descent just like all the other parameters of the model.
Let's quickly see this in action. nn.Embedding expects indexes and grabs the corresponding vectors from the embedding matrix in the forward pass!
### Define an Embedding layer with a vocab size of 4, and each vocab entry represented by 5 numbers (embed dim)
embeds = nn.Embedding(4,5)
print('\033[95m' + "Embedding Matrix:" + '\033[0m')
print(embeds.weight)
### Pass Tokens Through
tokens = torch.tensor([0,1,2], dtype=torch.long)
token_embeds = embeds(tokens)
print("\n" + '\033[95m' + "Embedded Tokens No Batch:" + '\033[0m')
print(token_embeds)
print(token_embeds.shape) # -> should be (3,5) as we had 3 tokens in our example and then
# each were indexed from an embed matrix that gave every token an embed dim of 5
### Pass w/ batch dimension ###
tokens = torch.tensor([[1,3,0], [2,1,1]])
token_embeds = embeds(tokens)
print("\n" + '\033[95m' + "Embedded Tokens w/ Batch:" + '\033[0m')
print(token_embeds)
print(token_embeds.shape) # -> should be (2,3,5) as we had 2 samples, each with 3 tokens in our example
# and then each were indexed from an embed matrix that gave every token an embed dim of 5
Embedding Matrix: Parameter containing: tensor([[-0.7128, -0.5692, -0.1855, 0.5567, 0.4218], [ 1.0083, -0.6225, 1.7885, -1.0122, 0.0139], [-0.3570, 1.9469, 1.2911, -0.5108, -1.6280], [ 1.5158, 0.2416, -0.5650, 0.9368, -0.3103]], requires_grad=True) Embedded Tokens No Batch: tensor([[-0.7128, -0.5692, -0.1855, 0.5567, 0.4218], [ 1.0083, -0.6225, 1.7885, -1.0122, 0.0139], [-0.3570, 1.9469, 1.2911, -0.5108, -1.6280]], grad_fn=<EmbeddingBackward0>) torch.Size([3, 5]) Embedded Tokens w/ Batch: tensor([[[ 1.0083, -0.6225, 1.7885, -1.0122, 0.0139], [ 1.5158, 0.2416, -0.5650, 0.9368, -0.3103], [-0.7128, -0.5692, -0.1855, 0.5567, 0.4218]], [[-0.3570, 1.9469, 1.2911, -0.5108, -1.6280], [ 1.0083, -0.6225, 1.7885, -1.0122, 0.0139], [ 1.0083, -0.6225, 1.7885, -1.0122, 0.0139]]], grad_fn=<EmbeddingBackward0>) torch.Size([2, 3, 5])
Tokenize Dataset¶
Now that we have our Tokenizer, we can go ahead and actually tokenize our data!
def tokenize_dataset(dataset):
"""Tokenize an entire dataset."""
tokenizer = Tokenizer()
print("Tokenizing Dataset...")
return [tokenizer.encode(sample, is_prompt=False) for sample in dataset]
tokenized_dataset = tokenize_dataset(dataset)
for sample in tokenized_dataset[-10:]:
print("Num Tokens:", len(sample), sample)
Tokenizing Dataset... Num Tokens: 36 [2, 0, 18, 5, 8, 14, 19, 12, 19, 17, 0, 18, 8, 17, 8, 10, 0, 11, 19, 17, 2, 18, 5, 17, 7, 10, 0, 11, 19, 13, 19, 17, 7, 8, 19, 15] Num Tokens: 38 [8, 1, 18, 5, 5, 14, 19, 12, 19, 17, 1, 18, 5, 17, 6, 10, 0, 11, 19, 17, 8, 18, 5, 17, 1, 3, 10, 1, 11, 19, 13, 19, 17, 1, 3, 6, 19, 15] Num Tokens: 38 [8, 8, 18, 7, 1, 14, 19, 12, 19, 17, 8, 18, 1, 17, 9, 10, 0, 11, 19, 17, 8, 18, 7, 17, 1, 5, 10, 1, 11, 19, 13, 19, 17, 1, 5, 9, 19, 15] Num Tokens: 41 [7, 6, 18, 9, 5, 14, 19, 12, 19, 17, 6, 18, 5, 17, 1, 1, 10, 1, 11, 19, 17, 7, 18, 9, 18, 1, 17, 1, 7, 10, 1, 11, 19, 13, 19, 17, 1, 7, 1, 19, 15] Num Tokens: 35 [8, 18, 4, 1, 14, 19, 12, 19, 17, 8, 18, 1, 17, 9, 10, 0, 11, 19, 17, 0, 18, 4, 17, 4, 10, 0, 11, 19, 13, 19, 17, 4, 9, 19, 15] Num Tokens: 38 [7, 0, 18, 5, 1, 14, 19, 12, 19, 17, 0, 18, 1, 17, 1, 10, 0, 11, 19, 17, 7, 18, 5, 17, 1, 2, 10, 1, 11, 19, 13, 19, 17, 1, 2, 1, 19, 15] Num Tokens: 41 [8, 9, 18, 5, 3, 14, 19, 12, 19, 17, 9, 18, 3, 17, 1, 2, 10, 1, 11, 19, 17, 8, 18, 5, 18, 1, 17, 1, 4, 10, 1, 11, 19, 13, 19, 17, 1, 4, 2, 19, 15] Num Tokens: 38 [3, 4, 18, 9, 0, 14, 19, 12, 19, 17, 4, 18, 0, 17, 4, 10, 0, 11, 19, 17, 3, 18, 9, 17, 1, 2, 10, 1, 11, 19, 13, 19, 17, 1, 2, 4, 19, 15] Num Tokens: 41 [9, 7, 18, 7, 3, 14, 19, 12, 19, 17, 7, 18, 3, 17, 1, 0, 10, 1, 11, 19, 17, 9, 18, 7, 18, 1, 17, 1, 7, 10, 1, 11, 19, 13, 19, 17, 1, 7, 0, 19, 15] Num Tokens: 38 [6, 2, 18, 8, 6, 14, 19, 12, 19, 17, 2, 18, 6, 17, 8, 10, 0, 11, 19, 17, 6, 18, 8, 17, 1, 4, 10, 1, 11, 19, 13, 19, 17, 1, 4, 8, 19, 15]
First Hurdle: Samples are of Different Lengths¶

You might notice an issue now. We typically train neural networks in batches, but our samples here are of different lengths, and we can't batch together tensors of different sizes! The way we typically deal with this is padding.
Let's say we have the following 3 samples:
[10, 22, 53]
[16, 65, 43, 73]
[64, 23]
What we want is to concatenate them together so we have a batch size of 3, but the sequence lengths differ, and that's why concatenation won't work. To help out, we can add a special pad token to the data, padding the shorter samples to the length of the longest one!
[10, 22, 53, <P>]
[16, 65, 43, 73]
[64, 23, <P>, <P>]
This is the first step of the data collation process! Let's implement it.
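As a point of comparison (not required for your solution, which should build the tensor itself), PyTorch ships a helper that performs exactly this right-padding. A minimal sketch with the three toy samples above, using our tokenizer's pad id of 16:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

samples = [torch.tensor([10, 22, 53]),
           torch.tensor([16, 65, 43, 73]),
           torch.tensor([64, 23])]
# Right-pad every sample to the longest length (here 4) with the pad id
batch = pad_sequence(samples, batch_first=True, padding_value=16)
print(batch.shape)  # torch.Size([3, 4])
```

You can use this as a reference oracle to sanity-check your own batch_samples output.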
Problem 1: Implement Padded Batching¶
Remember, we had a special token in the tokenizer called <pad> that we kept exactly for this reason! The batch_samples method you are about to implement has a function argument pad_token_id that we will pass it in through later, but assume it's just an integer value.
In submitted.py solve the problem batch_samples
importlib.reload(submitted)
from submitted import batch_samples
print("Batched Data")
padded_batch = batch_samples(tokenized_dataset[:4], tokenizer.pad_id)
print(padded_batch)
print("Shape:", padded_batch.shape)
Batched Data
tensor([[ 4, 7, 18, 5, 4, 14, 19, 12, 19, 17, 7, 18, 4, 17, 1, 1, 10, 1,
11, 19, 17, 4, 18, 5, 18, 1, 17, 1, 0, 10, 1, 11, 19, 13, 19, 17,
1, 0, 1, 19, 15],
[ 3, 6, 18, 7, 5, 14, 19, 12, 19, 17, 6, 18, 5, 17, 1, 1, 10, 1,
11, 19, 17, 3, 18, 7, 18, 1, 17, 1, 1, 10, 1, 11, 19, 13, 19, 17,
1, 1, 1, 19, 15],
[ 5, 18, 8, 3, 14, 19, 12, 19, 17, 5, 18, 3, 17, 8, 10, 0, 11, 19,
17, 0, 18, 8, 17, 8, 10, 0, 11, 19, 13, 19, 17, 8, 8, 19, 15, 16,
16, 16, 16, 16, 16],
[ 2, 6, 18, 2, 4, 14, 19, 12, 19, 17, 6, 18, 4, 17, 1, 0, 10, 1,
11, 19, 17, 2, 18, 2, 18, 1, 17, 5, 10, 0, 11, 19, 13, 19, 17, 5,
0, 19, 15, 16, 16]])
Shape: torch.Size([4, 41])
Problem 2: Attention Mask¶
Padding is necessary, but these pad tokens are not actually part of the data. We don't want the attention mechanism we implement later to pay attention to them. So for now, let's create a binary mask that is True for non-pad tokens and False otherwise!
In submitted.py solve the problem build_attention_mask
importlib.reload(submitted)
from submitted import build_attention_mask
print("Attention Mask")
attention_mask = build_attention_mask(tokenized_dataset[:4])
print(attention_mask)
print("Shape:", attention_mask.shape)
Attention Mask
tensor([[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True],
[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True],
[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, False, False, False, False, False,
False],
[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, False,
False]])
Shape: torch.Size([4, 41])
Collate Function¶
Now that we have our mechanisms to batch data and create an attention mask, we can go ahead and write our collate function. The collate function will be used later in the DataLoader and provides the logic for taking separate individual samples and batching them together. We have handled the issue of different sequence lengths, so we are good to go!
Setting up the Autoregressive Task¶
How do we set up the autoregressive task? Models like GPT (Transformer decoders) are next-token predictors. So the idea is: given a sequence of tokens, we want to predict the one that comes after it! For example, let's say we have the following token sequence:
$$\text{[ 2, 7, 18, 7, 6, 14, 12, 17, 7, 18, 6, 17]}$$
This means we set our model inputs and targets up as:
$$\text{Input: [ 2, 7, 18, 7, 6, 14, 12, 17, 7, 18, 6]}$$ $$\text{Target: [ 7, 18, 7, 6, 14, 12, 17, 7, 18, 6, 17]}$$
Notice that from the start, the first input is 2 and the model has to predict the target 7 (which is the next input!). Then, given 2 and 7, it has to predict 18 (the input after that!). The targets are basically just the inputs shifted over by 1.
def collate_fn(batch):
"""
Collate function for DataLoader - prepares batches for training.
Creates input/target pairs by shifting sequences by one position.
Args:
batch: List of tokenized sequences
Returns:
Dictionary with input_ids, targets, and attention_mask
"""
tokenizer = Tokenizer()
### Pad Sequences and Batch Together ###
data = batch_samples(batch, tokenizer.pad_id)
### Create input (all but last token) and target (all but first token)
### This way they are offset and each input token has to predict the corresponding (next) target token
inputs = data[:, :-1].clone()
targets = data[:, 1:].clone()
### Create attention mask on the inputs (ignoring the last position as that is the final target) ###
attention_mask = build_attention_mask([seq[:-1] for seq in batch])
return {
"input_ids": inputs,
"targets": targets,
"attention_mask": attention_mask
}
batch = collate_fn(tokenized_dataset[:3])
print(batch)
{'input_ids': tensor([[ 4, 7, 18, 5, 4, 14, 19, 12, 19, 17, 7, 18, 4, 17, 1, 1, 10, 1,
11, 19, 17, 4, 18, 5, 18, 1, 17, 1, 0, 10, 1, 11, 19, 13, 19, 17,
1, 0, 1, 19],
[ 3, 6, 18, 7, 5, 14, 19, 12, 19, 17, 6, 18, 5, 17, 1, 1, 10, 1,
11, 19, 17, 3, 18, 7, 18, 1, 17, 1, 1, 10, 1, 11, 19, 13, 19, 17,
1, 1, 1, 19],
[ 5, 18, 8, 3, 14, 19, 12, 19, 17, 5, 18, 3, 17, 8, 10, 0, 11, 19,
17, 0, 18, 8, 17, 8, 10, 0, 11, 19, 13, 19, 17, 8, 8, 19, 15, 16,
16, 16, 16, 16]]), 'targets': tensor([[ 7, 18, 5, 4, 14, 19, 12, 19, 17, 7, 18, 4, 17, 1, 1, 10, 1, 11,
19, 17, 4, 18, 5, 18, 1, 17, 1, 0, 10, 1, 11, 19, 13, 19, 17, 1,
0, 1, 19, 15],
[ 6, 18, 7, 5, 14, 19, 12, 19, 17, 6, 18, 5, 17, 1, 1, 10, 1, 11,
19, 17, 3, 18, 7, 18, 1, 17, 1, 1, 10, 1, 11, 19, 13, 19, 17, 1,
1, 1, 19, 15],
[18, 8, 3, 14, 19, 12, 19, 17, 5, 18, 3, 17, 8, 10, 0, 11, 19, 17,
0, 18, 8, 17, 8, 10, 0, 11, 19, 13, 19, 17, 8, 8, 19, 15, 16, 16,
16, 16, 16, 16]]), 'attention_mask': tensor([[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True],
[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True],
[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, False, False, False, False, False, False]])}
Attention and Masking¶
We have two types of masks:
- Padding Mask – prevents attention from computing on <pad> tokens
- Causal Mask – enforces autoregressive behavior by blocking access to future tokens
In this problem, we focus on the causal mask. But to understand causal masking we first need to talk about the Attention Mechanism! We will be implementing this later, but we will walk through it here!

- Deep can only look at itself, and not any future words
- Learning can look at itself and the past word Deep
- is can look at itself and the past words Deep and Learning
- fun can look at itself and all the past words.
Without the masking, this attention matrix would be that of a regular encoder. If Deep could look at the future words Learning, is, and fun, then we break causality. We want to ensure that for every word in the sentence, we are only looking at the past and present to predict the next word (i.e., at word n we want to predict word n+1 based on words 1 to n). Unfortunately, seeing the future is cheating: the whole purpose of the model is to predict future words, so if we allow the model to see the future words that it needs to predict, the model won't learn anything except to copy the future words as its predictions.
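The allowed/blocked pattern described above for the 4-token sentence can be sketched as a lower-triangular boolean mask (a quick illustration using torch.tril; we will formalize this below):

```python
import torch

# Lower-triangular keep-mask for "Deep Learning is fun" (4 tokens):
# row i is True at columns 0..i, i.e. each word sees itself and the past only
mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
print(mask)
```

Row 0 (Deep) is True only at column 0, while row 3 (fun) is True everywhere, matching the bullet points above.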
How to do Attention Masking¶
So how do we actually perform attention masking? Let's first remind ourselves what attention is doing. Here is the equation for attention as a reminder:
$$\text{Attention}(Q,K,V) = \text{Softmax}(\frac{QK^T}{\sqrt{d_e}})V$$
So the first step is computing $QK^T$, where $Q$ and $K$ both have the shape (Sequence Length x Embedding Dimension). The output of this computation will be (Sequence Length x Sequence Length). This is what it looks like!

In the image above, I also applied the softmax (not shown for simplicity), so each row of the attention matrix adds up to 1 (like probabilities).
Now we can multiply our output of $QK^T$ with our $V$. This is what a regular encoder (bidirectional) attention will look like. Remember, $V_1, ... V_4$ are the projection vectors (Values) of the data, and the attention matrix is a weighted average of all of these vectors.
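As a sanity check for later, here is a minimal sketch of this unmasked (bidirectional) attention computation, cross-checked against PyTorch's fused F.scaled_dot_product_attention (available in PyTorch >= 2.0); the shapes are assumptions for illustration:

```python
import math
import torch
import torch.nn.functional as F

T, d = 4, 8
Q = torch.randn(1, T, d)  # (batch, seq_len, embed_dim)
K = torch.randn(1, T, d)
V = torch.randn(1, T, d)

scores = Q @ K.transpose(-2, -1) / math.sqrt(d)  # (1, T, T) query-key similarities
weights = F.softmax(scores, dim=-1)              # each row sums to 1
out = weights @ V                                # each row: weighted average of the V vectors

# Cross-check against PyTorch's built-in implementation
ref = F.scaled_dot_product_attention(Q, K, V)
print(torch.allclose(out, ref, atol=1e-5))
```

The same built-in (with is_causal=True) can also serve as an oracle when you debug your own AttentionFunction.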
Note
In transformers, our input $X$ goes through 3 different linear projections to create $Q$, $K$, and $V$. Initially, the vector for each word in $X$ is the embedding vector representing that word. After the attention computation, the vector for each word isn't just the embedding of that word but rather a weighted average of all the vectors in the sequence, according to how they relate to the word of interest.
What's the problem above? The first row vector in the output, $0.2V_1 + 0.1V_2 + 0.4V_3 + 0.3V_4$, is a weighted average of the entire sequence (therefore getting information from future vectors). This is again cheating, so we need to mask our attention matrix: the weights of future words, relative to the word of interest, are set to 0, and the weights of the word of interest and the previous words are then renormalized to add up to 1. Therefore, we are only learning how every word is related to itself and the past!
Computing the Reweighted Causal Attention Mask¶
Let's pretend the raw outputs of $QK^T$, before the softmax, are below:
\begin{equation} \begin{bmatrix} 7 & -8 & 6 \\ -3 & 2 & 4 \\ 1 & 6 & -2 \\ \end{bmatrix} \end{equation}
Remember, the equation for softmax is:
$$\text{Softmax}(\vec{x})_i = \frac{e^{x_i}}{\sum_{j=1}^N{e^{x_j}}}$$
Then, we can compute the softmax of each row of the matrix above:
\begin{equation} \text{Softmax} \begin{bmatrix} 7 & -8 & 6 \\ -3 & 2 & 4 \\ 1 & 6 & -2 \\ \end{bmatrix} = \begin{bmatrix} \frac{e^{7}}{e^{7}+e^{-8}+e^{6}} & \frac{e^{-8}}{e^{7}+e^{-8}+e^{6}} & \frac{e^{6}}{e^{7}+e^{-8}+e^{6}} \\ \frac{e^{-3}}{e^{-3}+e^{2}+e^{4}} & \frac{e^{2}}{e^{-3}+e^{2}+e^{4}} & \frac{e^{4}}{e^{-3}+e^{2}+e^{4}} \\ \frac{e^{1}}{e^{1}+e^{6}+e^{-2}} & \frac{e^{6}}{e^{1}+e^{6}+e^{-2}} & \frac{e^{-2}}{e^{1}+e^{6}+e^{-2}} \\ \end{bmatrix} = \begin{bmatrix} 0.73 & 0.0000002 & 0.27 \\ 0.0008 & 0.12 & 0.88 \\ 0.007 & 0.99 & 0.003 \\ \end{bmatrix} \end{equation}
But what we want is for the top triangle to have weights of 0, and the rest to add up to 1. So let's take the second row of the matrix above to see how we can do that.
$$x_2 = [-3, 2, 4]$$
Because this is the second row, we need to zero out the softmax output for everything after the second index (in our case, just the last value). Let's replace the value 4 with $-\infty$. Then we can write it as:
$$x_2 = [-3, 2, -\infty]$$
Let's now take the softmax of this vector!
$$\text{Softmax}(x_2) = [\frac{e^{-3}}{e^{-3}+e^{2}+e^{-\infty}}, \frac{e^{2}}{e^{-3}+e^{2}+e^{-\infty}}, \frac{e^{-\infty}}{e^{-3}+e^{2}+e^{-\infty}}]$$
Remember, $e^{-\infty}$ is equal to 0, so we can solve this!
$$\text{Softmax}(x_2) = [\frac{e^{-3}}{e^{-3}+e^{2}+0}, \frac{e^{2}}{e^{-3}+e^{2}+0}, \frac{0}{e^{-3}+e^{2}+0}] = [0.0067, 0.9933, 0.0000]$$
So we have exactly what we want! The attention weight of the last value is set to 0, so when we are on the second row $x_2$ we cannot look forward to the future value vector $v_3$, and the remaining entries add up to 1 so it is still a probability vector! To do this for the entire matrix, we can just substitute $-\infty$ into the top triangle of $QK^T$. This would look like:
\begin{equation} \begin{bmatrix} 7 & -\infty & -\infty \\ -3 & 2 & -\infty \\ 1 & 6 & -2 \\ \end{bmatrix} \end{equation}
Taking the softmax of the rows of this matrix then gives:
\begin{equation} \text{Softmax} \begin{bmatrix} 7 & -\infty & -\infty \\ -3 & 2 & -\infty \\ 1 & 6 & -2 \\ \end{bmatrix} = \begin{bmatrix} 1 & 0 & 0 \\ 0.0067 & 0.9933 & 0 \\ 0.0067 & 0.9930 & 0.0003 \\ \end{bmatrix} \end{equation}
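We can check these numbers with a few lines of PyTorch (a quick sanity sketch; `scores` is the example matrix from above):

```python
import torch

# The example score matrix from above
scores = torch.tensor([[7., -8., 6.],
                       [-3., 2., 4.],
                       [1., 6., -2.]])
# Causal mask: True where attention is allowed (the lower triangle)
mask = torch.tril(torch.ones(3, 3, dtype=torch.bool))
# Substitute -inf into the upper triangle, then softmax each row
weights = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
print(weights)
print(weights.sum(dim=-1))  # every row still sums to 1
```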
What about Attention Padding Masking?¶

The same principle applies to padding masking. We want to insert $-\infty$ at padded positions that are not valid and that we don't want to compute attention for, so that when we matmul with the values, we place 0 emphasis on them! The key detail is that our padding mask is passed in as (batch, seq_len): for every sample in the batch we have a True/False vector
that is True for positions we want attention and False for positions we don't. In our (seq_len, seq_len) attention matrix, the first seq_len dimension indexes the query vectors and the second indexes the key vectors. Every position where a query vector attends to a padding position in the keys/values should be set to 0 (or $-\infty$ before the softmax). This works out to setting the columns of the (seq_len, seq_len) matrix (for a single sample in the batch) that correspond to padding positions to 0, just as you saw above!
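Here is a minimal sketch of that column-masking idea, assuming a tiny batch where sample 0 has one padded position and sample 1 has two (uniform zero scores just to make the effect easy to see):

```python
import torch

batch_size, seq_len = 2, 4
# Padding mask: True for real tokens, False for padding
padding_mask = torch.tensor([[True, True, True, False],
                             [True, True, False, False]])
# Broadcast over the query dimension so every query ignores padded keys:
# (batch, seq_len) -> (batch, 1, seq_len) -> (batch, seq_len, seq_len)
key_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)
scores = torch.zeros(batch_size, seq_len, seq_len)
weights = torch.softmax(scores.masked_fill(~key_mask, float("-inf")), dim=-1)
print(weights[0])  # the last column is all zeros for sample 0
```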
That's It!¶
For the most part, this is what powers GPT (along with buckets of data and giant models). Encoder-type models (ViT, RoBERTa, etc.) do not have the causal mask, but may still have attention padding masking.
Problem 3: Attention Causal Masking¶
What we want to do is write a method that produces this causal mask given a sequence length. It will be a boolean tensor of shape (1, seq_len, seq_len).
The (seq_len, seq_len) part should be clear: the attention matrix we produce has shape (seq_len, seq_len), this mask will cover it, and it is True for positions that are valid (we want to compute attention) and False for positions that are invalid (non-causal).
What about the extra (1,)? When we do our attention operation (we will look at this more closely later), our queries, keys, and values will have shape
(batch_size, seq_len, embed_dim). When we compute $QK^T$ we get a tensor of shape (batch_size, seq_len, seq_len), where for every sample in the batch we have a (seq_len, seq_len) attention matrix. The extra (1,) is just a placeholder dimension that broadcasts over the batch dimension.
HINT:¶
torch.tril should make this really easy
importlib.reload(submitted)
from submitted import create_causal_mask
create_causal_mask(5)
tensor([[[ True, False, False, False, False],
[ True, True, False, False, False],
[ True, True, True, False, False],
[ True, True, True, True, False],
[ True, True, True, True, True]]])
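To see why the leading (1,) is enough, here is a small sketch of such a mask broadcasting over a batch of attention scores:

```python
import torch

scores = torch.randn(8, 5, 5)  # (batch_size, seq_len, seq_len) attention scores
# (1, seq_len, seq_len) causal mask: the leading 1 broadcasts over the batch
mask = torch.tril(torch.ones(5, 5, dtype=torch.bool)).unsqueeze(0)
weights = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
print(weights.shape)  # torch.Size([8, 5, 5])
```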
What is the Linear Layer?¶
The linear layer is the bedrock of the Transformer model, so we will be implementing our own! This means we will write the forward and backward methods for this operation! This is also a good time to learn torch.autograd.Function, which lets us write functions with custom gradients!
As implemented in PyTorch, the Linear layer has the form:
$$y = xW^T + b$$
Multidimensional Matrix Multiplication:¶
The forward method should perform this operation and return the final output. But there is a caveat: in the matrix multiplication $xW^T$, we are typically used to operating on 2D arrays.
For example, in the matrix multiplication of $A$ and $B$, $A$ has some shape (p x q) and $B$ has some shape (q x r), and the final output of the matrix multiplication has shape (p x r).
But in $xW^T$, $x$ is a tensor of shape (batch, seq_len, embed_dim) and $W^T$ is a tensor of shape (out_dim, embed_dim). This is a 3D tensor doing a matrix multiplication with a 2D tensor! How does this work?
In PyTorch, matmul is assumed to always be along the inner dimension. So what actually happens internally is that $x$ is first flattened from (batch, seq_len, embed_dim) to (batch*seq_len, embed_dim). Now that it has been flattened to a 2D array, we have a matmul between (batch*seq_len, embed_dim) and the transpose of (out_dim, embed_dim). So we will have:
(batch*seq_len, embed_dim) @ (out_dim, embed_dim).T = (batch*seq_len, embed_dim) @ (embed_dim, out_dim) = (batch*seq_len, out_dim)
Then PyTorch will unflatten the dimensions that it had flattened earlier like so:
(batch*seq_len, out_dim) => (batch, seq_len, out_dim)
So our final result essentially took the embed_dim dimension and projected it to some other out_dim.
Let's do an example!
x = torch.randn(2,3,4)
w = torch.randn(8,4)
print("PyTorch Multidimensional Matmul:")
out_multidim = x@w.T
print(out_multidim)
print(out_multidim.shape)
print("\nManual Reshape for Matmul:")
out_reshape_before = x.reshape(2*3,4)@w.T
print(out_reshape_before)
print(out_reshape_before.shape)
### Perform the reshape back to the original data shape ourselves
out_reshape_after = out_reshape_before.reshape(2,3,8)
print(out_reshape_after)
print(out_reshape_after.shape)
assert torch.allclose(out_reshape_after, out_multidim)
PyTorch Multidimensional Matmul:
tensor([[[-1.8542, 0.8069, 1.2487, -1.0750, -1.6353, 0.4645, -2.8518,
-0.4310],
[ 2.8592, 1.3956, -1.9772, 0.6518, 1.2616, -0.1656, 0.9610,
0.8828],
[ 0.2141, -2.4457, 0.3667, -3.0267, 0.1998, -1.8024, -2.4556,
-0.2906]],
[[-2.5685, -2.5846, 2.3751, -1.1941, -1.3559, -0.2207, -0.7792,
-0.7120],
[-1.0390, 0.6818, 1.2398, 0.7814, -1.5063, 1.1745, 0.1704,
0.2016],
[-0.1497, -3.5993, 0.4502, -0.8169, 0.9915, -1.3686, 1.4652,
-0.3970]]])
torch.Size([2, 3, 8])
Manual Reshape for Matmul:
tensor([[-1.8542, 0.8069, 1.2487, -1.0750, -1.6353, 0.4645, -2.8518, -0.4310],
[ 2.8592, 1.3956, -1.9772, 0.6518, 1.2616, -0.1656, 0.9610, 0.8828],
[ 0.2141, -2.4457, 0.3667, -3.0267, 0.1998, -1.8024, -2.4556, -0.2906],
[-2.5685, -2.5846, 2.3751, -1.1941, -1.3559, -0.2207, -0.7792, -0.7120],
[-1.0390, 0.6818, 1.2398, 0.7814, -1.5063, 1.1745, 0.1704, 0.2016],
[-0.1497, -3.5993, 0.4502, -0.8169, 0.9915, -1.3686, 1.4652, -0.3970]])
torch.Size([6, 8])
tensor([[[-1.8542, 0.8069, 1.2487, -1.0750, -1.6353, 0.4645, -2.8518,
-0.4310],
[ 2.8592, 1.3956, -1.9772, 0.6518, 1.2616, -0.1656, 0.9610,
0.8828],
[ 0.2141, -2.4457, 0.3667, -3.0267, 0.1998, -1.8024, -2.4556,
-0.2906]],
[[-2.5685, -2.5846, 2.3751, -1.1941, -1.3559, -0.2207, -0.7792,
-0.7120],
[-1.0390, 0.6818, 1.2398, 0.7814, -1.5063, 1.1745, 0.1704,
0.2016],
[-0.1497, -3.5993, 0.4502, -0.8169, 0.9915, -1.3686, 1.4652,
-0.3970]]])
torch.Size([2, 3, 8])
Problem 4: Custom Linear Layer¶
We will now be implementing the forward and backward pass of the linear layer!
Forward:¶
The forward method is exactly what we just outlined!
$$y = xW^T + b$$
Backward:¶
The backward method should return $\frac{dL}{dx}$, $\frac{dL}{dW}$, $\frac{dL}{db}$ given the upstream gradient $\frac{dL}{dy}$
An important part to remember is that each gradient must have the same shape as the tensor it goes with! For example, $W$ has shape (out_features, in_features), so its gradient must have that same shape. Similarly, $x$ can have any number of dimensions as long as the last one is in_features, so the output gradient for $x$ must have exactly the same shape as $x$!
NOTE: What is ctx¶
You will notice a ctx being passed in. The ctx (or context) is just a container inside which we can store information in the forward pass that we need again in the backward pass. For example, we need $x$ and $W$ to compute our gradients, so we will store those in the forward pass and access them again in the backward pass! This part of the code is provided!
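If torch.autograd.Function is new to you, here is a minimal unrelated example (a cube function, not the Linear layer you will write) showing how ctx carries tensors from the forward pass to the backward pass:

```python
import torch

class CubeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Store x so backward can use it to compute the gradient
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        # d(x^3)/dx = 3x^2, chained with the upstream gradient
        (x,) = ctx.saved_tensors
        return grad_output * 3 * x ** 2

x = torch.tensor([2.0], requires_grad=True)
y = CubeFunction.apply(x)
y.backward()
print(x.grad)  # tensor([12.])
```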
Verification¶
Code is provided here that compares your implementation to that of PyTorch Linear layers. A very similar setup will be tried in the tests as well!
importlib.reload(submitted)
from submitted import LinearFunction
### Create a Layer ###
layer = nn.Linear(6,8)
### Create copies of the weight/bias ###
weight_ = layer.weight.detach().clone().requires_grad_()
bias_ = layer.bias.detach().clone().requires_grad_()
### Create an Input ###
x = torch.randn(2,3,4,6, requires_grad=True)
### Create a Clone ###
x_ = x.detach().clone().requires_grad_()
### PyTorch Linear Output ###
torch_linear_out = layer(x)
### Our Linear Layer ###
custom_linear_out = LinearFunction.apply(x_, weight_, bias_)
### Check Outputs are Equivalent ###
print("Max Diff Output:", torch.max(torch.abs(torch_linear_out - custom_linear_out)).item())
assert torch.allclose(torch_linear_out, custom_linear_out, rtol=1e-3, atol=1e-3)
### Backpropagate ###
upstream_grad = torch.randn_like(torch_linear_out)
torch_linear_out.backward(upstream_grad)
custom_linear_out.backward(upstream_grad)
### Check Gradients ###
print("Max Diff dW:", torch.max(torch.abs(layer.weight.grad - weight_.grad)).item())
assert torch.allclose(layer.weight.grad, weight_.grad, rtol=1e-3, atol=1e-3)
print("Max Diff dB:", torch.max(torch.abs(layer.bias.grad - bias_.grad)).item())
assert torch.allclose(layer.bias.grad, bias_.grad, rtol=1e-3, atol=1e-3)
print("Max Diff dX:", torch.max(torch.abs(x.grad - x_.grad)).item())
assert torch.allclose(x.grad, x_.grad, rtol=1e-3, atol=1e-3)
Max Diff Output: 0.0 Max Diff dW: 0.0 Max Diff dB: 0.0 Max Diff dX: 0.0
Converting to a Module¶
Let's wrap your LinearFunction in a module for easy use! This way we can use it as a standard module later on.
import math
class MyLinear(nn.Module):
"""
weight: (D_out, D_in)
bias: (D_out,)
"""
def __init__(self, in_features, out_features, bias=True):
super().__init__()
self.in_features = int(in_features)
self.out_features = int(out_features)
### Define the Weights of the model
self.weight = nn.Parameter(torch.empty(self.out_features, self.in_features))
### Define the Bias of the model
self.bias = nn.Parameter(torch.empty(self.out_features))
self.reset_parameters()
def reset_parameters(self):
### standard weight initialization (uniform in +/- 1/sqrt(fan_in), like nn.Linear)
k = 1.0 / math.sqrt(self.in_features)
nn.init.uniform_(self.weight, -k, k)
nn.init.uniform_(self.bias, -k, k)
def forward(self, x):
### This is where your method LinearFunction will be utilized
return LinearFunction.apply(x, self.weight, self.bias)
Problem 5: Attention Mechanism¶
We will now be implementing the forward/backward pass of the attention mechanism!
The forward pass is given by:
$$\text{Attention}(Q,K,V) = \text{Softmax}(\frac{QK^T}{\sqrt{d_e}})V$$
The main challenge will be to think through what the backward pass looks like! For simplicity, we will be assuming a single head of attention rather than multiheaded attention!
HINT: Draw It Out¶
Remember, all you really have here is a couple of matmuls and a softmax. You know the derivative of a matmul (you just completed the Linear layer above) and you know the derivative of a softmax. It will be easier to draw out the operation and take it step by step!
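One identity you will likely derive along the way: if $p = \text{Softmax}(z)$ row-wise and $g = \frac{dL}{dp}$ is the upstream gradient, then $\frac{dL}{dz} = p \odot (g - \sum_j g_j p_j)$. A quick autograd sanity check of that identity (a sketch, not your implementation):

```python
import torch

z = torch.randn(4, 6, requires_grad=True)
p = torch.softmax(z, dim=-1)
g = torch.randn_like(p)  # stand-in for the upstream gradient dL/dp
p.backward(g)

# Manual softmax backward: dL/dz = p * (g - sum_j(g_j * p_j))
manual = p.detach() * (g - (g * p.detach()).sum(dim=-1, keepdim=True))
print(torch.allclose(z.grad, manual, atol=1e-6))  # True
```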
Verification¶
Code is provided here that compares your implementation to that of PyTorch Scaled Dot Product Attention. A very similar setup will be tried in the tests as well!
importlib.reload(submitted)
from submitted import AttentionFunction
### Create Input Shapes ###
batch_size = 32
seq_len = 128
embed_dim = 256
### Create input tensors
Q = torch.randn(batch_size, seq_len, embed_dim, requires_grad=True)
K = torch.randn(batch_size, seq_len, embed_dim, requires_grad=True)
V = torch.randn(batch_size, seq_len, embed_dim, requires_grad=True)
dO = torch.randn(batch_size, seq_len, embed_dim)
causal_mask = ~torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
attention_mask[0, -(seq_len//2):] = False
attention_mask[1, -(seq_len//3):] = False
### Create Clone for Testing ###
Q_ = Q.detach().clone().requires_grad_(True)
K_ = K.detach().clone().requires_grad_(True)
V_ = V.detach().clone().requires_grad_(True)
dO_ = dO.detach().clone()
causal_mask_ = causal_mask.detach().clone()
attention_mask_ = attention_mask.detach().clone()
### Our Attention Implementation
output = AttentionFunction.apply(Q, K, V, causal_mask, attention_mask)
### Comparison to torch SDPA
causal_mask_ = causal_mask_.unsqueeze(0)
attention_mask_ = attention_mask_.unsqueeze(1).repeat(1,seq_len,1)
mask = causal_mask_ & attention_mask_
output_ = torch.nn.functional.scaled_dot_product_attention(Q_, K_, V_, attn_mask=mask)
### Check Output ###
print("Max Diff Attention Output:", torch.max(torch.abs(output - output_)).item())
assert torch.allclose(output, output_, rtol=1e-2, atol=1e-2)
### Check Gradients ###
output.backward(dO)
output_.backward(dO_)
print("Max Diff dQ:", torch.max(torch.abs(Q.grad - Q_.grad)).item())
assert torch.allclose(Q.grad, Q_.grad, rtol=1e-2, atol=1e-2)
print("Max Diff dK:", torch.max(torch.abs(K.grad - K_.grad)).item())
assert torch.allclose(K.grad, K_.grad, rtol=1e-2, atol=1e-2)
print("Max Diff dV:", torch.max(torch.abs(V.grad - V_.grad)).item())
assert torch.allclose(V.grad, V_.grad, rtol=1e-2, atol=1e-2)
Max Diff Attention Output: 0.0 Max Diff dQ: 2.980232238769531e-07 Max Diff dK: 4.76837158203125e-07 Max Diff dV: 0.0
That's It!¶
The remaining portion of the MP will take everything you built so far, wrap it into a GPT Model and train it on our addition task! We did our best to annotate the code so you can learn how it all works!
Attention Block¶
The first piece we need is the actual attention block. The method you wrote earlier takes in $Q,K,V$, but what we really need is to start from some input tensor $x$, first project it to $Q,K,V$, and then perform the attention operation!
class AttentionBlock(nn.Module):
def __init__(self, embed_dim):
super().__init__()
### Q,K,V projections using the MyLinear method based on the LinearFunction you wrote!
self.q_proj = MyLinear(embed_dim, embed_dim)
self.k_proj = MyLinear(embed_dim, embed_dim)
self.v_proj = MyLinear(embed_dim, embed_dim)
def forward(self, x, attention_mask=None):
batch_size, seq_len, embed_dim = x.shape
### Three projections to get Q,K,V ###
Q = self.q_proj(x)
K = self.k_proj(x)
V = self.v_proj(x)
### Create a causal mask (based on the incoming seq_len) and correct device placement ###
causal_mask = create_causal_mask(seq_len).to(x.device)
### Perform Attention Operation ###
output = AttentionFunction.apply(Q, K, V, causal_mask, attention_mask)
return output
MultiLayerPerceptron¶
Each attention block in the Transformer is followed by a stack of linear layers (known as the multilayer perceptron or feedforward layers). These layers typically expand the embedding dimension by some constant factor (typically 4) and then compress it back down to the original embedding dimension. The activation function used between the layers in the original paper was ReLU, but more modern implementations leverage other activation functions; we will be using GELU.
class MLP(nn.Module):
"""Feed-forward network with GELU activation."""
def __init__(
self,
in_features,
feature_proj_multiplier,
):
super().__init__()
hidden_features = in_features * feature_proj_multiplier
self.fc1 = MyLinear(in_features, hidden_features)
self.act = nn.GELU()
self.fc2 = MyLinear(hidden_features, in_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
The Transformer Block¶

As you can see, the transformer block is simply a stack of alternating attention and MLP blocks. You also see the term "add and norm": this refers to including a residual connection in every block, which helps with deep models (check out ResNet for more details), plus a normalization. There are a variety of normalization methods you can use; one of the most popular for LLMs today is LayerNorm.
LayerNorm¶

LayerNorm normalizes each timestep (token) independently, across its feature dimension (the embedding dimension). The importance of this is that it controls the outputs of our attention and our perceptron and prevents them from exploding or vanishing. Essentially, it rescales each token to have a mean of 0 and a variance of 1!
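A quick sketch of what LayerNorm computes per token (ignoring the learnable scale and shift, which default to 1 and 0):

```python
import torch
from torch import nn

x = torch.randn(2, 5, 8)  # (batch, seq_len, embed_dim)
ln = nn.LayerNorm(8)
out = ln(x)

# Manual version: normalize each token across its embedding dimension
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + ln.eps)
print(torch.allclose(out, manual, atol=1e-6))  # True
```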
Residual Connections¶
Instead of learning $y = F(x)$, we instead learn $y = x + F(x)$. The interpretation is that the function $F$ learns an update to the input rather than a full transformation. Additionally, during the backward pass, there are two paths for the upstream gradient $\frac{dL}{dY}$ to reach $x$: one path goes through the function, and the other goes directly to $x$. Because the upstream gradient is copied directly to $x$, we can prevent vanishing gradients in deeper models.
In the same way, in our transformer we will add the output of attention to the input, and again add the output of the feedforward to the input.
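We can verify the two gradient paths with a tiny autograd check (here $F$ is an arbitrary linear map, purely for illustration):

```python
import torch
from torch import nn

x = torch.randn(3, 4, requires_grad=True)
f = nn.Linear(4, 4)  # stand-in for F(x)
y = x + f(x)         # residual connection: y = x + F(x)
y.sum().backward()

# dL/dx = identity path (ones) + path through F (dL/dy @ W)
through_f = torch.ones(3, 4) @ f.weight
print(torch.allclose(x.grad, torch.ones(3, 4) + through_f, atol=1e-6))
```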
class TransformerBlock(nn.Module):
def __init__(
self,
embed_dim: int = 768,
mlp_ratio: float = 4.0,
):
super().__init__()
### Define the first Layernorm (we provide the dim of each token embedding it will normalize)
self.norm1 = nn.LayerNorm(embed_dim)
### Our AttentionBlock based on the AttentionFunction and LinearFunction you wrote! ###
self.attn = AttentionBlock(embed_dim)
### Second Layernorm
self.norm2 = nn.LayerNorm(embed_dim, eps=1e-6)
### FeedForward Layers based on the LinearFunction you wrote!
self.mlp = MLP(embed_dim, mlp_ratio)
def forward(
self,
x,
attention_mask = None
):
# Residual Connections and Normalization
x = self.norm1(x + self.attn(x, attention_mask))
x = self.norm2(x + self.mlp(x))
return x
Tiny GPT¶
Now we have all our pieces to write our GPT model! This model is simply just a stack of our transformer blocks.
Positional Embeddings¶
Transformers have one notable property: they are Permutation Invariant. This means the transformer doesn't actually care about the order of the tokens being passed in. But this is an issue because we are solving a sequence problem, where the order of tokens matters a lot! For example:
- the cat chased the mouse
- the mouse chased the cat
These have the same words but mean completely different things! So each token embedding we pass to the transformer has to contain two things: the meaning of the token + the position of the token.
There are a ton of ways this is done today; the original method proposed in the Attention Is All You Need paper was sinusoidal position embeddings:
$$\text{PE}(t, 2i) = \sin\left(\frac{t}{10000^{\frac{2i}{d}}}\right)$$
$$\text{PE}(t, 2i+1) = \cos\left(\frac{t}{10000^{\frac{2i}{d}}}\right)$$
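The two formulas above can be sketched directly in code (a minimal version; real implementations typically precompute this once and register it as a buffer):

```python
import torch

def sinusoidal_positions(max_len, d):
    t = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # positions t
    i = torch.arange(0, d, 2, dtype=torch.float32)               # pair index 2i
    angles = t / (10000 ** (i / d))
    pe = torch.zeros(max_len, d)
    pe[:, 0::2] = torch.sin(angles)  # even dimensions use sin
    pe[:, 1::2] = torch.cos(angles)  # odd dimensions use cos
    return pe

pe = sinusoidal_positions(100, 16)
print(pe.shape)  # torch.Size([100, 16])
```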
But another way they included in the paper was learnable positional embeddings. If I know that I have a maximum sequence length of 100 tokens, then why not create an embedding matrix (just like for the vocabulary embeddings) that has 100 positions, each with the vocabulary's embedding dimension? Then I can index out the positions I need and add them in!
Both are valid methods, and neither performs meaningfully better than the other. There are some limitations of positional embeddings when it comes to length extrapolation (applying your model to sequences longer than those it was trained on), but we won't worry about those details today!
class GPT(nn.Module):
"""Small GPT-style transformer for addition."""
def __init__(
self,
max_seq_len: int = 512, # What is the maximum number of tokens this model can process?
vocab_size: int = 100, # we will get our vocab size from the tokenizer
embed_dim: int = 384, # What is the embedding dimension of each word
depth: int = 6, # How many transformer blocks do we want?
mlp_ratio: float = 4.0, # MLP projection ratio in the feedforward
):
super().__init__()
self.max_seq_len = max_seq_len
### Token and position embeddings
self.embeddings = nn.Embedding(vocab_size, embed_dim) # Each vocab as a vector of embed dim
self.pos_embed = nn.Embedding(max_seq_len, embed_dim) # Each position (up to the max length of data we want) has a vector of embed dim
### Stack of Transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(
embed_dim=embed_dim,
mlp_ratio=mlp_ratio,
)
for _ in range(depth)
])
### Output Layernorm
self.norm = nn.LayerNorm(embed_dim)
### Each token (of dimension embed_dim) needs to predict the NEXT token (one of the tokens out of vocab_size tokens)
### so LLM training is basically a classification task!
self.head = MyLinear(embed_dim, vocab_size)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module: nn.Module):
"""Initialize model weights. This is the standard init for most LLMs out there!"""
if isinstance(module, (MyLinear, nn.Linear)):
torch.nn.init.trunc_normal_(module.weight, std=0.02, a=-2, b=2)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.trunc_normal_(module.weight, std=0.02, a=-2, b=2)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, x, attention_mask=None):
### We now pass our token indexes through our model! ###
device = x.device
batch_size, seq_len = x.shape
### Embed the Tokens ###
tok_emb = self.embeddings(x)
### Get Positional Embedding
pos_indices = torch.arange(0, seq_len, dtype=torch.long, device=device)
pos_emb = self.pos_embed(pos_indices)
### Add positional information to the token embeddings ###
x = tok_emb + pos_emb
### Apply transformer blocks
for block in self.blocks:
x = block(x, attention_mask)
# Output projection
x = self.norm(x)
x = self.head(x)
return x
@torch.no_grad()
def generate(
self,
input_tokens,
max_new_tokens=None,
temperature=0.3,
sample=True,
eos_token_id = None,
):
"""
Generate tokens autoregressively. This is what we will use during inference time!
Args:
input_tokens: Starting tokens of shape (1, seq_len)
max_new_tokens: Maximum number of tokens to generate
temperature: Sampling temperature (higher = more random)
sample: If True, sample from distribution w/ multinomial sampling; if False, use argmax
eos_token_id: Optional end-of-sequence token ID to stop generation (this is our <assistant_end> token)
Returns:
Generated tokens of shape (1, original_len + generated_len)
"""
### If not provided set the model context as the max tokens ###
if max_new_tokens is None:
max_new_tokens = self.max_seq_len
for _ in range(max_new_tokens):
# Truncate to max sequence length to make sure we don't pass in more tokens than our model can process
idx_cond = (
input_tokens
if input_tokens.shape[1] <= self.max_seq_len
else input_tokens[:, -self.max_seq_len:]
)
### Pass input tokens through model
logits = self.forward(idx_cond)
### We only need to predict the next token, so index out the last token (and scale by temperature to increase/decrease stochasticity)
logits = logits[:, -1, :] / temperature
### Convert to probabilities, to see the distribution over which token is most likely next
probs = F.softmax(logits, dim=-1)
### Sample or take argmax
if sample:
### Stochastic, can randomly select a token (with likelihood of the probs as computed above)
idx_next = torch.multinomial(probs, num_samples=1)
else:
### Deterministic, will always give the same results
idx_next = torch.argmax(probs, dim=-1, keepdim=True)
### Append new token to sequence
input_tokens = torch.cat([input_tokens, idx_next], dim=1)
### If our <assistant_end> token is generated we are done and can stop generating!
if eos_token_id is not None and idx_next.item() == eos_token_id:
break
return input_tokens
Training Loop¶
We can now train the model!
def train_model(
num_places: int = 2,
iterations: int = 5000,
batch_size: int = 16,
learning_rate: float = 0.0005,
embed_dim: int = 512,
depth: int = 8,
eval_interval: int = 1000,
device: str = None,
reasoning=True,
num_test_samples=3
):
"""
Train the addition model.
Args:
num_places: Number of digits in training data (e.g., 2 = 0-99)
iterations: Number of training iterations
batch_size: Batch size for training
learning_rate: Learning rate for optimizer
embed_dim: Embedding dimension
depth: Number of transformer layers
eval_interval: How often to evaluate and print samples
device: Device to train on ('cuda' or 'cpu')
reasoning: Do you want data to be reasoning or not?
num_test_samples: Number of test prompts sampled for generation during evaluation
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize tokenizer
tokenizer = Tokenizer()
# Prepare dataset
print(f"Building dataset with {num_places}-digit numbers...")
dataset = tokenize_dataset(build_dataset(num_places, reasoning, num_samples=int(batch_size * iterations)))
print(f"Dataset size: {len(dataset)} examples")
max_len_sample = max([len(i) for i in dataset])
print(f"Longest Sample in Dataset:", max_len_sample)
### Initialize Model
model = GPT(
max_seq_len=max_len_sample, # Set our context length of this LLM to whatever is longest in the dataset!
vocab_size=tokenizer.vocab_size,
embed_dim=embed_dim,
depth=depth,
)
### Move to correct device
model = model.to(device)
### Print model size
num_params = sum(p.numel() for p in model.parameters())
print(f"Model has {num_params:,} parameters")
### Initialize optimizer and loss
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss() # standard classification loss
# Train Test Split
trainset, testset = random_split_dataset(dataset, test_samples=500)
trainloader = DataLoader(
trainset,
shuffle=True,
collate_fn=collate_fn,
batch_size=batch_size
)
testloader = DataLoader(
testset,
shuffle=True,
collate_fn=collate_fn,
batch_size=batch_size
)
### Sample Some Text for Eval from the testset
if reasoning:
sample_texts = [tokenizer.decode(i).split("<start_think>")[0].strip().replace("<assistant_start>", "") for i in random.sample(testset, num_test_samples)]
else:
sample_texts = [tokenizer.decode(i).split("=")[0].strip().replace("<assistant_start>", "") for i in random.sample(testset, num_test_samples)]
print("Samples Kept for Inference:", sample_texts)
### Tokenize our Samples (is_prompt is True as we are inferencing these so we need to manually add on that <assistant_start> token)
sample_tokens = [
torch.tensor(tokenizer.encode(text, is_prompt=True)).unsqueeze(0).to(device)
for text in sample_texts
]
# Training loop
model.train()
pbar = tqdm(total=iterations, desc="Training")
completed_steps = 0
total_loss = 0.0
loss_log = []
while completed_steps < iterations:
for batch in trainloader:
### Move batch to device
input_ids = batch["input_ids"].to(device)
targets = batch["targets"].to(device)
attention_mask = batch["attention_mask"].to(device)
### Forward pass
### logits: (batch_size, seq_len, vocab_size)
logits = model(input_ids, attention_mask)
### Reshape logits to (batch_size*seq_len, vocab_size)
logits = logits.reshape(-1, logits.shape[-1])
### Reshape targets from (batch_size, seq_len) -> (batch_size*seq_len)
### so each token in batch_size*seq_len in our (batch_size*seq_len, vocab_size) flattened logits matrix
### has a single target in our targets! This reshape is necessary for the CrossEntropyLoss input expectations
targets_flat = targets.reshape(-1)
### Compute loss
loss = loss_fn(logits, targets_flat)
loss_log.append(loss.item())
### Backward pass (compute gradients)
loss.backward()
### Clip gradients (to avoid explosions)
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
### Update Model
optimizer.step()
### Zero out gradients for next iteration ###
optimizer.zero_grad(set_to_none=True)
### Update progress
total_loss += loss.item()
completed_steps += 1
pbar.update(1)
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
### Evaluation
if completed_steps % eval_interval == 0:
### Put in eval mode
model.eval()
### Average loss for this evaluation on the trainset
avg_loss = total_loss / eval_interval
print(f"Step {completed_steps}/{iterations} | Avg Loss: {avg_loss:.4f}")
### Generate samples for each of our test cases we stored ahead of time
for sample_text, sample_token in zip(sample_texts, sample_tokens):
generated = model.generate(
sample_token,
max_new_tokens=max_len_sample,
temperature=0.3, # Low temperature for more deterministic output
sample=True,
eos_token_id=tokenizer.assistant_end_id,
)
generated_text = tokenizer.decode(generated[0])
print(f"\nInput: {sample_text}")
print(f"Correct Answer: {eval(sample_text)}")
print(f"Predicted Answer: {extract_answer(generated_text)}")
print(f"Raw Output: \n{generated_text}")
total_loss = 0.0
model.train()
if completed_steps >= iterations:
break
pbar.close()
print("\nTraining complete!")
### Do our Eval on the Test Set ###
model.eval()
results = evaluate_model(
model=model,
tokenizer=tokenizer,
testloader=testloader,
device=device,
max_new_tokens=max_len_sample,
temperature=0.1, # Low temperature for more deterministic output
sample=True, # Use greedy decoding or sampling for evaluation
reasoning=reasoning
)
return model, tokenizer, loss_log, results
Train 2 Models¶
We will train one model on reasoning data and one on non-reasoning data to see whether there is really a benefit here! By increasing num_places you increase the number of digits you can add together, which will increase the length of the reasoning traces; have fun and try different lengths if you want! By default I keep it at training on up to 4-digit numbers for each. As the traces get longer, you can assume that you will also need to boost the model capacity (more layers, or more complex architectures)!
print('\033[95m' + "Training a Reasoning Model..." + '\033[0m')
reasoning_model, reasoning_tokenizer, reasoning_loss_log, reasoning_results = train_model(
num_places=4,
iterations=5000,
batch_size=32,
learning_rate=0.0005,
embed_dim=384,
depth=6,
eval_interval=5000,
num_test_samples=1,
reasoning=True,
)
print("\n" + '\033[95m' + "Training a Non-Reasoning Model..." + '\033[0m')
no_reasoning_model, no_reasoning_tokenizer, no_reasoning_loss_log, no_reasoning_results = train_model(
num_places=4,
iterations=5000,
batch_size=32,
learning_rate=0.0005,
embed_dim=384,
depth=6,
eval_interval=5000,
num_test_samples=1,
reasoning=False
)
print("\n" + '\033[91m' + "RESULTS" + '\033[0m')
print(f"Reasoning Model Accuracy:", reasoning_results["accuracy"])
print(f"Non-Reasoning Model Accuracy:", no_reasoning_results["accuracy"])
Training a Reasoning Model...
Building dataset with 4-digit numbers...
Generating Random Dataset...
100%|███████████████████████████████████████████████████| 160000/160000 [00:00<00:00, 403215.35it/s]
Tokenizing Dataset... Dataset size: 160000 examples Longest Sample in Dataset: 73 Model has 9,803,924 parameters Samples Kept for Inference: ['3153+9776']
Training: 100%|████████████████████████████████████| 5000/5000 [01:19<00:00, 63.02it/s, loss=0.2279]
Step 5000/5000 | Avg Loss: 0.2537 Input: 3153+9776 Correct Answer: 12929 Predicted Answer: 12929 Raw Output: 3153+9776<assistant_start> <start_think> =3+6=9<start_carry>0<end_carry> =5+7=12<start_carry>1<end_carry> =1+7+1=9<start_carry>0<end_carry> =3+9=12<start_carry>1<end_carry> <end_think> =12929 <assistant_end> Training complete! Evaluating model on test set...
Evaluating: 100%|███████████████████████████████████████████████████| 16/16 [00:52<00:00, 3.31s/it]
Training a Non-Reasoning Model...
Building dataset with 4-digit numbers...
Generating Random Dataset...
100%|██████████████████████████████████████████████████| 160000/160000 [00:00<00:00, 1151832.37it/s]
Tokenizing Dataset... Dataset size: 160000 examples Longest Sample in Dataset: 17 Model has 9,782,420 parameters Samples Kept for Inference: ['2530+7982']
Training: 100%|███████████████████████████████████| 5000/5000 [00:43<00:00, 115.88it/s, loss=1.5214]
Step 5000/5000 | Avg Loss: 1.5388 Input: 2530+7982 Correct Answer: 10512 Predicted Answer: 10917 Raw Output: 2530+7982<assistant_start>=10917<assistant_end> Training complete! Evaluating model on test set...
Evaluating: 100%|███████████████████████████████████████████████████| 16/16 [00:05<00:00, 2.72it/s]
RESULTS
Reasoning Model Accuracy: 0.952
Non-Reasoning Model Accuracy: 0.002
It is pretty clear that reasoning traces help a lot for the LLM to produce a correct output!! But you should also notice that inference-time generation for the model with reasoning traces takes quite a bit longer than for the model that directly attempts to emit the output.
Loss Curves¶
Let's take a look at our loss curves for reasoning vs. non-reasoning:
plt.title("Reasoning vs Non-Reasoning Loss Curves")
plt.plot(reasoning_loss_log, label="Reasoning")
plt.plot(no_reasoning_loss_log, label="Non-Reasoning")
plt.xlabel("Training Step")
plt.ylabel("Loss")
plt.legend()
plt.show()
Some Caveats¶
If you start playing with this model, you will notice that it performs better on 4-digit summations than on anything else! This is a data balance issue: the majority of the data is dominated by 4-digit summations, so most of the errors you see come from summations with fewer than 4 digits.
# Count numbers by digit length in range(0, 10000)
counts = {
1: 10, # 0-9
2: 90, # 10-99
3: 900, # 100-999
4: 9000 # 1000-9999
}
total_numbers = sum(counts.values()) # 10000
total_pairs = total_numbers ** 2 # Total (i, j) pairs
# Same-digit pairs
same_digit_counts = {k: counts[k] ** 2 for k in counts}
# Mixed-digit pairs
mixed_pairs = total_pairs - sum(same_digit_counts.values())
# Prepare data for plotting
labels = [
"Both 1-digit",
"Both 2-digit",
"Both 3-digit",
"Both 4-digit",
"Mixed"
]
values = [
same_digit_counts[1],
same_digit_counts[2],
same_digit_counts[3],
same_digit_counts[4],
mixed_pairs
]
# Plot
plt.figure(figsize=(8,5))
plt.bar(labels, values, color=['skyblue','lightgreen','orange','red','purple'])
plt.xticks(rotation=45)
plt.title("Distribution of Digit-Length Pair Types (0–9999)")
plt.ylabel("Number of (i, j) Pairs")
plt.tight_layout()
plt.show()
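To put a number on the imbalance: since 9,000 of the 10,000 possible operands are 4-digit numbers, the both-4-digit bucket alone accounts for 81% of all (i, j) pairs:

```python
counts = {1: 10, 2: 90, 3: 900, 4: 9000}
total_pairs = sum(counts.values()) ** 2      # 10_000 ** 2 = 100M pairs
frac_both_4 = counts[4] ** 2 / total_pairs   # 81M / 100M
print(f"Both-4-digit share: {frac_both_4:.0%}")  # Both-4-digit share: 81%
```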
Accuracy by Digit Count¶
Let's go ahead and sample some summations from each of these categories and see how the model performs!
device = "cuda" if torch.cuda.is_available() else "cpu"
single_digit_sums = []
double_digit_sums = []
triple_digit_sums = []
quadruple_digit_sums = []
mixed_sums = []
single_digit_results = []
double_digit_results = []
triple_digit_results = []
quadruple_digit_results = []
mixed_results = []
# Single digit: exhaustive (100 examples)
for a in range(10):
for b in range(10):
single_digit_sums.append(f"{a}+{b}")
# Double digit: sample 150
for _ in range(150):
a = random.randint(10, 99)
b = random.randint(10, 99)
double_digit_sums.append(f"{a}+{b}")
# Triple digit: sample 150
for _ in range(150):
a = random.randint(100, 999)
b = random.randint(100, 999)
triple_digit_sums.append(f"{a}+{b}")
# Quadruple digit: sample 150
for _ in range(150):
a = random.randint(1000, 9999)
b = random.randint(1000, 9999)
quadruple_digit_sums.append(f"{a}+{b}")
# Mixed Digit
digit_ranges = [(0, 9), (10, 99), (100, 999), (1000, 9999)]  # randint is inclusive on both ends
for _ in range(150):
# Pick two different digit ranges
range1, range2 = random.sample(digit_ranges, 2)
a = random.randint(*range1)
b = random.randint(*range2)
mixed_sums.append(f"{a}+{b}")
for input_strings, results_store in zip(
[single_digit_sums, double_digit_sums, triple_digit_sums, quadruple_digit_sums, mixed_sums],
[single_digit_results, double_digit_results, triple_digit_results, quadruple_digit_results, mixed_results]
):
for input_string in tqdm(input_strings):
input_tokens = torch.tensor(reasoning_tokenizer.encode(input_string, is_prompt=True), dtype=torch.long).to(device).unsqueeze(0)
generated = reasoning_model.generate(
input_tokens,
temperature=0.3, # Low temperature for more deterministic output
sample=True,
eos_token_id=reasoning_tokenizer.assistant_end_id,
)
decoded = reasoning_tokenizer.decode(generated[0])
pred_answer = extract_answer(decoded)
correct = (pred_answer == str(eval(input_string)))
results_store.append(correct)
def get_acc(l):
return round((sum(l) / len(l)) * 100, 2)
### Just some plotting stuff
# Labels for the x-axis
categories = [
"Single Digit",
"Double Digit",
"Triple Digit",
"Quadruple Digit",
"Mixed Digit"
]
# Accuracies in percentages
accuracies = [
get_acc(single_digit_results),
get_acc(double_digit_results),
get_acc(triple_digit_results),
get_acc(quadruple_digit_results),
get_acc(mixed_results)
]
# Create the bar plot
plt.figure(figsize=(8, 5))
bars = plt.bar(categories, accuracies, color='skyblue', edgecolor='black')
# Add value labels on top of each bar
for bar in bars:
yval = bar.get_height()
plt.text(bar.get_x() + bar.get_width()/2, yval + 0.5, f"{yval:.2f}%", ha='center', va='bottom')
# Titles and labels
plt.ylabel("Accuracy (%)")
plt.title("Accuracy per Digit Category")
plt.ylim(0, 105) # leave space for labels on top
plt.xticks(rotation=15)
plt.show()
100%|█████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 11.09it/s]
100%|█████████████████████████████████████████████████████████████| 150/150 [00:13<00:00, 10.92it/s]
100%|█████████████████████████████████████████████████████████████| 150/150 [00:12<00:00, 12.23it/s]
100%|█████████████████████████████████████████████████████████████| 150/150 [00:16<00:00,  9.36it/s]
100%|█████████████████████████████████████████████████████████████| 150/150 [00:13<00:00, 10.80it/s]
Data Balance¶
What we show is that the data is clearly imbalanced. We could easily do some type of stratified sampling, oversampling the smaller groups and undersampling the larger ones, and our LLM should perform better across the board! BUT!!! What about actual LLMs training on terabytes of raw text? If you have trillions of words in front of you, how do you balance them? What does it even mean to have balance? There are a ton of ways you can define it:
- Topic balancing
- Fiction vs Non-Fiction
- Natural Text vs Programming Languages
- English vs Non-English languages
- Types of Complexity
- Conversational Dataset vs Academic Texts
- Poetry vs Textbooks
- Legal Contracts vs Mathematics
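For our little adder, the stratified-sampling fix is straightforward: sample each operand's digit length uniformly first, then sample a number within that length. A minimal sketch (an assumption for illustration, not how the assignment's dataset builder actually works):

```python
import random

# Inclusive value ranges for each digit length
DIGIT_RANGES = {1: (0, 9), 2: (10, 99), 3: (100, 999), 4: (1000, 9999)}

def balanced_pair() -> str:
    # Pick digit lengths uniformly so each length bucket is equally
    # represented, instead of 4-digit operands dominating the data.
    la = random.choice([1, 2, 3, 4])
    lb = random.choice([1, 2, 3, 4])
    a = random.randint(*DIGIT_RANGES[la])
    b = random.randint(*DIGIT_RANGES[lb])
    return f"{a}+{b}"
```

Under this scheme each of the 16 (length, length) buckets receives ~1/16 of the pairs, rather than the both-4-digit bucket taking 81%.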
This is a really challenging part of LLM training. Most SOTA LLMs involve massive capital expenditures because they are trained on enormous corpora, and a huge effort goes into properly balancing the dataset towards the goals of the LLM. This is why most LLMs have multiple stages of training:
- Pretraining: Train on a ton of data (balanced as well as you can manage)
- Midtraining: Start to bias the model towards your target domain (conversational, math, code, legal, etc..)
- SFT: A small subset of extremely high-quality data to continue improving your model (and getting it into the typical instruction format of users and assistants, like we have in our data)
- RL Preference Optimization: Optimizes the "behavior" of the LLM towards what humans consider high quality
Play with the Adder Yourself!¶
Even as a bad calculator this is fun, so try different things!
input_string = "1234+5678"
input_tokens = torch.tensor(reasoning_tokenizer.encode(input_string, is_prompt=True), dtype=torch.long).to(device).unsqueeze(0)
generated = reasoning_model.generate(
input_tokens,
temperature=0.3, # Low temperature for more deterministic output
sample=True,
eos_token_id=reasoning_tokenizer.assistant_end_id,
)
decoded = reasoning_tokenizer.decode(generated[0])
pred_answer = extract_answer(decoded)
print("Predicted Answer:", pred_answer)
print("Correct Answer:", eval(input_string))
Predicted Answer: 6912 Correct Answer: 6912
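For reference, extract_answer (provided in utils.py) just needs to pull the final answer out of the decoded string. A hypothetical re-implementation (an assumption for illustration; the real helper in utils.py may differ) that handles both the direct and the reasoning-trace output formats shown above:

```python
import re

def extract_answer_sketch(decoded: str):
    # Hypothetical stand-in for utils.extract_answer: grab the digits of
    # the final "=..." answer right before the <assistant_end> token.
    m = re.search(r"=(\d+)\s*<assistant_end>", decoded)
    return m.group(1) if m else None

print(extract_answer_sketch("82+78<assistant_start>=160<assistant_end>"))  # 160
```

Because the regex anchors on `<assistant_end>`, the intermediate `=...` steps inside the think block are skipped and only the final answer is returned.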