LLMs From Scratch - Chapter 4: GPT from Scratch

llms
gpt
tutorial
Author

Daniel Pickem

Published

May 16, 2025

GPT from scratch

This notebook explores the full LLM architecture of GPT based on Chapter 4 of Sebastian Raschka’s book, implementing layer normalization, shortcut connections, and transformer blocks. It also shows how to compute the parameter count and the storage requirements of GPT-like models.

Acknowledgment

All concepts, architectures, and implementation approaches are credited to Sebastian Raschka’s work. This repository serves as my personal implementation and notes while working through the book’s content.

Resources

GPT architecture

The components of a GPT-like architecture are shown in the image below.

GPT components

The following figure shows the flow of data through the model and how the data are transformed at each stage.

GPT data flow
import dataclasses
from typing import Dict, List, Tuple

import torch
import torch.nn as nn

import tiktoken

Model configuration

@dataclasses.dataclass(frozen=True)
class GPTConfig:
    """Configuration for the GPT model.

    Attributes:
        vocab_size: The size of the vocabulary.
        context_length: The maximum number of tokens that the model can process in a single forward
            pass (i.e. the maximum sequence length). Also denotes the number of input tokens the
            model can handle via the positional embeddings (from chapter 2).
        emb_dim: The dimension of the token embeddings.
        n_heads: The number of attention heads.
        n_layers: The number of transformer layers.
        dropout_rate: The dropout rate for the transformer layers.
        qkv_bias: Whether to use bias in the QKV projections.
    """

    vocab_size: int
    context_length: int
    emb_dim: int
    n_heads: int
    n_layers: int
    dropout_rate: float
    qkv_bias: bool


# Instantiate the GPT-2 configuration.
GPT_CONFIG_124M = GPTConfig(
    vocab_size=50257,  # as used by the BPE tokenizer for GPT-2.
    context_length=1024,
    emb_dim=768,
    n_heads=12,
    n_layers=12,
    dropout_rate=0.1,
    qkv_bias=False,
)
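
Because GPTConfig is a frozen dataclass, configurations for larger GPT-2 variants can be derived with dataclasses.replace. The sketch below is my own convenience addition, not part of the book’s code; the values are the commonly cited GPT-2 medium settings.

# Derive a GPT-2 medium-style configuration from the base config (my own sketch).
# NOTE: emb_dim=1024, n_layers=24, and n_heads=16 are the commonly cited GPT-2 medium settings.
GPT_CONFIG_MEDIUM = dataclasses.replace(
    GPT_CONFIG_124M, emb_dim=1024, n_layers=24, n_heads=16
)
print(GPT_CONFIG_MEDIUM)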

Example data

tokenizer = tiktoken.get_encoding("gpt2")

batch = []
txt1 = "Every effort moves you"
txt2 = "Every day holds a"

batch.append(torch.tensor(tokenizer.encode(txt1)))
batch.append(torch.tensor(tokenizer.encode(txt2)))
batch = torch.stack(batch, dim=0)

print(batch)
tensor([[6109, 3626, 6100,  345],
        [6109, 1110, 6622,  257]])
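
As a quick round-trip check (my own addition, not in the book), the token IDs can be decoded back into the original strings:

# Decode each row of token IDs back into text (round-trip sanity check).
for row in batch:
    print(tokenizer.decode(row.tolist()))  # "Every effort moves you" / "Every day holds a"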

Placeholder architecture

class DummyLayerNorm(nn.Module):
    def __init__(self, emb_dim: int):
        super().__init__()

    def forward(self, x):
        return x


class DummyTransformerBlock(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()

    def forward(self, x):
        return x


class DummyGPTModel(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.emb_dim)
        self.pos_emb = nn.Embedding(cfg.context_length, cfg.emb_dim)
        self.drop_emb = nn.Dropout(cfg.dropout_rate)
        self.transformer_blocks = nn.Sequential(
            *[DummyTransformerBlock(cfg) for _ in range(cfg.n_layers)]
        )
        self.final_norm = DummyLayerNorm(cfg.emb_dim)
        self.out_head = nn.Linear(cfg.emb_dim, cfg.vocab_size, bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)
        x = self.transformer_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits


# Instantiate the placeholder model.
model = DummyGPTModel(GPT_CONFIG_124M)

# Feed example data through it.
torch.manual_seed(123)
logits = model(batch)
print("Output shape:", logits.shape)
print(logits)
Output shape: torch.Size([2, 4, 50257])
tensor([[[-0.5640, -0.8061,  0.5556,  ..., -0.0121, -0.1054, -0.6956],
         [ 0.7765, -1.0823,  1.1478,  ...,  0.3540,  0.9335, -0.1086],
         [-1.2292,  1.0940, -1.0646,  ..., -1.6435, -0.0090,  0.0854],
         [-1.6669, -1.4979,  0.6711,  ...,  0.4817, -0.1961, -0.8135]],

        [[-0.9549, -0.7139,  0.3969,  ...,  0.0884, -0.2823, -0.7652],
         [ 1.2106,  0.4950,  1.3336,  ...,  0.4385, -1.0175, -0.2729],
         [ 1.2351,  0.0197,  0.5197,  ..., -1.0924,  1.2276, -0.4298],
         [-0.6795, -1.1446,  1.1755,  ...,  0.7063, -0.5953, -1.7645]]],
       grad_fn=<UnsafeViewBackward0>)

Layer normalization

Training deep neural networks with many layers can sometimes prove challenging due to problems like vanishing or exploding gradients. These problems lead to unstable training dynamics and make it difficult for the network to effectively adjust its weights, meaning the learning process struggles to find a set of parameters (weights) for the neural network that minimizes the loss function.

The main idea behind layer normalization is to adjust the activations (outputs) of a neural network layer to have a mean of 0 and a variance of 1, also known as unit variance. This adjustment speeds up the convergence to effective weights and ensures consistent, reliable training.

Layer normalization

The example below illustrates how the dim argument of aggregation functions such as mean or std selects the dimension of a tensor to aggregate over.

Illustration of dim argument

Example

# Example of layer normalization.
torch.set_printoptions(sci_mode=False)
torch.manual_seed(123)

# Create two training examples with 5 features each.
batch_example = torch.randn(2, 5)

# Create a basic neural network layer consisting of a Linear layer followed by a non-linear
# activation function, ReLU (short for rectified linear unit).
layer = nn.Sequential(nn.Linear(5, 6), nn.ReLU())
out = layer(batch_example)

# Compute mean and variance along the feature dimension (i.e. the last dimension).
# NOTE: Using keepdim=True in operations like mean or variance calculation ensures that the output
#       tensor retains the same number of dimensions as the input tensor, even though the operation
#       reduces the tensor along the dimension specified via dim. Here, without keepdim=True, the
#       returned mean would be a one-dimensional tensor of shape [2] (e.g. [1, 2]) rather than a
#       2x1 matrix (e.g. [[1], [2]]).
mean = out.mean(dim=-1, keepdim=True)
var = out.var(dim=-1, keepdim=True)
print(f"\nRaw layer outputs:\n{'-' * 18}\n{out}")
print("Mean:\n", mean)
print("Variance:\n", var)

# Apply layer normalization.
out_norm = (out - mean) / torch.sqrt(var)
mean = out_norm.mean(dim=-1, keepdim=True)
var = out_norm.var(dim=-1, keepdim=True)
print(f"\nNormalized layer outputs:\n{'-' * 25}\n{out_norm}")
print(f"\nMean:\n{mean}")
print(f"\nVariance:\n{var}")

Raw layer outputs:
------------------
tensor([[0.2260, 0.3470, 0.0000, 0.2216, 0.0000, 0.0000],
        [0.2133, 0.2394, 0.0000, 0.5198, 0.3297, 0.0000]],
       grad_fn=<ReluBackward0>)
Mean:
 tensor([[0.1324],
        [0.2170]], grad_fn=<MeanBackward1>)
Variance:
 tensor([[0.0231],
        [0.0398]], grad_fn=<VarBackward0>)

Normalized layer outputs:
-------------------------
tensor([[ 0.6159,  1.4126, -0.8719,  0.5872, -0.8719, -0.8719],
        [-0.0189,  0.1121, -1.0876,  1.5173,  0.5647, -1.0876]],
       grad_fn=<DivBackward0>)

Mean:
tensor([[    0.0000],
        [    0.0000]], grad_fn=<MeanBackward1>)

Variance:
tensor([[1.0000],
        [1.0000]], grad_fn=<VarBackward0>)

A layer normalization class

class LayerNorm(nn.Module):
    """Layer normalization (https://arxiv.org/abs/1607.06450).

    This specific implementation of layer normalization operates on the last dimension of
    the input tensor x, which represents the embedding dimension (emb_dim).
    """

    def __init__(self, emb_dim: int):
        super().__init__()

        # Use a small constant which will be added to the variance to prevent division by zero.
        self.eps = 1e-5

        # The scale and shift are two trainable parameters (of the same dimension as the input)
        # that the LLM automatically adjusts during training.
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift


# Test the layer normalization class.
ln = LayerNorm(emb_dim=5)
out_ln = ln(batch_example)
mean = out_ln.mean(dim=-1, keepdim=True)
var = out_ln.var(dim=-1, unbiased=False, keepdim=True)
print("Mean:\n", mean)
print("Variance:\n", var)
Mean:
 tensor([[    -0.0000],
        [     0.0000]], grad_fn=<MeanBackward1>)
Variance:
 tensor([[1.0000],
        [1.0000]], grad_fn=<VarBackward0>)
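
As an additional sanity check (my own addition, not from the book), the custom LayerNorm can be compared against PyTorch’s built-in torch.nn.LayerNorm, which applies the same normalization with trainable scale and shift:

# Compare the custom LayerNorm against PyTorch's built-in implementation.
# NOTE: nn.LayerNorm also uses the biased variance and eps=1e-5 by default, so the
#       outputs should match up to floating-point precision.
torch_ln = nn.LayerNorm(5, eps=1e-5)
print(torch.allclose(ln(batch_example), torch_ln(batch_example), atol=1e-6))  # True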

GELU activation function

This section is skipped in this notebook since torch already implements the full version of GELU as well as the curve-fitting (tanh) approximation.

See GELU activation function

The below figure shows a comparison of GELU and ReLU.
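
To complement the graphical comparison (my own addition, not from the book), the exact GELU, its tanh approximation, and ReLU can be evaluated side by side on a few sample values:

# Compare the exact GELU, the tanh-based approximation, and ReLU (my own check).
gelu_exact = nn.GELU()                   # exact (erf-based) GELU
gelu_tanh = nn.GELU(approximate="tanh")  # curve-fitting approximation used below
relu = nn.ReLU()

sample = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
print("GELU (exact):", gelu_exact(sample))
print("GELU (tanh): ", gelu_tanh(sample))
print("ReLU:        ", relu(sample))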

Feed forward network

Feed Forward NN module

An overview of the connections between the layers of the feed forward neural network. This neural network can accommodate variable batch sizes and numbers of tokens in the input. However, the embedding size for each token is determined and fixed when initializing the weights.

Note how the size of the embedding dimension is increased by 4x before it is shrunk down to the input embedding dimension.

Feed forward network

The FeedForward module plays a crucial role in enhancing the model’s ability to learn from and generalize the data. Although the input and output dimensions of this module are the same, it internally expands the embedding dimension into a higher-dimensional space through the first linear layer (as shown below). This expansion is followed by a nonlinear GELU activation and then a contraction back to the original dimension with the second linear transformation. Such a design allows for the exploration of a richer representation space.

Feed forward network

Moreover, the uniformity in input and output dimensions simplifies the architecture by enabling the stacking of multiple layers, as we will do later, without the need to adjust dimensions between them, thus making the model more scalable.

class FeedForward(nn.Module):
    """Feed Forward Neural Network (FFNN) module.

    This module implements a feed-forward neural network with two linear layers and a GELU
    activation function.
    """

    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg.emb_dim, 4 * cfg.emb_dim),
            torch.nn.GELU(approximate="tanh"),
            nn.Linear(4 * cfg.emb_dim, cfg.emb_dim),
        )

    def forward(self, x):
        return self.layers(x)


# Test the feed forward neural network.
ffn = FeedForward(GPT_CONFIG_124M)
x = torch.rand(2, 3, 768)
out = ffn(x)
print(out.shape)
torch.Size([2, 3, 768])
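
To make the 4x expansion concrete (my own check, not from the book), the same input can be passed through only the first linear layer of the block to inspect the intermediate shape:

# Inspect the intermediate (expanded) dimension of the feed forward network.
# NOTE: ffn.layers[0] is the first nn.Linear of the Sequential defined above.
hidden = ffn.layers[0](x)
print(hidden.shape)  # torch.Size([2, 3, 3072]), i.e. 4 * 768 in the last dimension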

Skip connections

Shortcut connections are also known as skip or residual connections. Originally, shortcut connections were proposed for deep networks in computer vision (specifically, in residual networks) to mitigate the challenge of vanishing gradients. The vanishing gradient problem refers to the issue where gradients (which guide weight updates during training) become progressively smaller as they propagate backward through the layers, making it difficult to effectively train earlier layers.

They play a crucial role in preserving the flow of gradients during the backward pass in training.

Skip connections
class ExampleDeepNeuralNetwork(nn.Module):
    def __init__(self, layer_sizes, use_shortcut):
        super().__init__()
        self.use_shortcut = use_shortcut

        # Implement a deep neural network with len(layer_sizes) - 1 (here: 5) layers, each
        # consisting of a linear transformation followed by a GELU activation.
        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(layer_sizes[i], layer_sizes[i + 1]),
                    torch.nn.GELU(approximate="tanh"),
                )
                for i in range(len(layer_sizes) - 1)
            ]
        )

    def forward(self, x):
        # Compute the output of the current layer.
        for layer in self.layers:
            layer_output = layer(x)

            # Check if shortcut can be applied.
            if self.use_shortcut and x.shape == layer_output.shape:
                x = x + layer_output
            else:
                x = layer_output

        return x


# Implement a function for printing gradients.
def print_gradients(model, x):
    """Print gradients of the model."""
    # Forward pass.
    output = model(x)
    target = torch.tensor([[0.0]])

    # Calculates loss based on how close the target and output are
    loss = nn.MSELoss()
    loss = loss(output, target)

    # Backward pass to compute gradients.
    loss.backward()

    # Visualize the gradients of the model.
    # NOTE: This loop iterates over all of the model's parameters and prints the mean of the
    #       absolute mean values of the gradients for each parameter.
    # NOTE: Suppose we have a 3 × 3 weight parameter matrix for a given layer. In that case,
    #       this layer will have 3 × 3 gradient values, and we print the mean absolute gradient
    #       of these 3 × 3 gradient values to obtain a single gradient value per layer to compare
    #       the gradients between layers more easily.
    for name, param in model.named_parameters():
        if "weight" in name:
            print(f"{name} has gradient mean of {param.grad.abs().mean().item()}")


# Test the network with and without shortcut connections.

# Specifies random seed for the initial weights for reproducibility
torch.manual_seed(123)
layer_sizes = [3, 3, 3, 3, 3, 1]
sample_input = torch.tensor([[1.0, 0.0, -1.0]])

# Instantiate the model without shortcut connections.
# NOTE: The gradients in earlier layers become too small to be useful for training, which is due
#       to the vanishing gradient problem.
model_without_shortcut = ExampleDeepNeuralNetwork(layer_sizes, use_shortcut=False)
print("-" * 100)
print(f"Model without shortcut connections:\n{'-' * 100}")
print_gradients(model_without_shortcut, sample_input)
print("-" * 100)

# Instantiate the model with shortcut connections.
# NOTE: The gradients in earlier layers are now more stable and useful for training, which is due
#       to the preservation of the flow of gradients during the backward pass.
torch.manual_seed(123)
model_with_shortcut = ExampleDeepNeuralNetwork(layer_sizes, use_shortcut=True)
print(f"Model with shortcut connections:\n{'-' * 100}")
print_gradients(model_with_shortcut, sample_input)
----------------------------------------------------------------------------------------------------
Model without shortcut connections:
----------------------------------------------------------------------------------------------------
layers.0.0.weight has gradient mean of 0.0002017359365709126
layers.1.0.weight has gradient mean of 0.00012011162471026182
layers.2.0.weight has gradient mean of 0.0007152041071094573
layers.3.0.weight has gradient mean of 0.0013988733990117908
layers.4.0.weight has gradient mean of 0.005049645435065031
----------------------------------------------------------------------------------------------------
Model with shortcut connections:
----------------------------------------------------------------------------------------------------
layers.0.0.weight has gradient mean of 0.22169798612594604
layers.1.0.weight has gradient mean of 0.20694109797477722
layers.2.0.weight has gradient mean of 0.3289699852466583
layers.3.0.weight has gradient mean of 0.26657330989837646
layers.4.0.weight has gradient mean of 1.3258543014526367

Transformer block

The operations within the transformer block, including multi-head attention and feed forward layers, are designed to transform the input token embeddings in a way that preserves their dimensionality. The idea is that the self-attention mechanism in the multi-head attention block identifies and analyzes relationships between elements in the input sequence. In contrast, the feed forward network modifies the data individually at each position. This combination not only enables a more nuanced understanding and processing of the input but also enhances the model’s overall capacity for handling complex data patterns.

Transformer block

As the example below demonstrates, the transformer block maintains the input dimensions in its output, indicating that the transformer architecture processes sequences of data without altering their shape throughout the network.

The preservation of shape throughout the transformer block architecture is not incidental but a crucial aspect of its design. This design enables its effective application across a wide range of sequence-to-sequence tasks, where each output vector directly corresponds to an input vector, maintaining a one-to-one relationship. However, the output is a context vector that encapsulates information from the entire input sequence. This means that while the physical dimensions of the sequence (length and feature size) remain unchanged as it passes through the transformer block, the content of each output vector is re-encoded to integrate contextual information from across the entire input sequence.

# NOTE: This is the MultiHeadAttention class from chapter 3 of the book
#       (see 03_attention_mechanisms.ipynb).
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_in: int,
        d_out: int,
        context_length: int,
        dropout: float,
        num_heads: int,
        qkv_bias: bool = False,
    ):
        """Initialize the multi-head attention class.

        Args:
            d_in: The dimension of the input.
            d_out: The dimension of the output.
            context_length: The length of the context. This argument sets the length of the causal
                mask, i.e. the maximum supported sequence length.
            dropout: The dropout probability.
            num_heads: The number of attention heads.
            qkv_bias: Whether to use a bias in the query, key, and value projections.
        """
        super().__init__()

        # Verify that the output dimension is divisible by the number of heads, which is required
        # for splitting the output into the specified number of heads.
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        # Cache the output dimension and the number of heads for later use.
        # NOTE: This reduces the projection dim to match the desired output dim.
        # NOTE: The key operation is to split the d_out dimension into num_heads and head_dim,
        #       where head_dim = d_out / num_heads. This splitting is then achieved using the .view
        #       method: a tensor of dimensions (b, num_tokens, d_out) is reshaped to dimension
        #       (b, num_tokens, num_heads, head_dim).
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Initialize the weight matrices.
        # NOTE: A single weight matrix is initialized per projection (query, key, and value) for
        #       all heads; the projected output is later split into the specified number of heads
        #       via reshaping.
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Initialize the output projection layer.
        # NOTE: This implementation uses a Linear layer to combine all head outputs.
        self.out_proj = nn.Linear(d_out, d_out)

        # Initialize the dropout layer.
        self.dropout = nn.Dropout(dropout)

        # Register a buffer for the mask.
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Extract input dimensions.
        # NOTE: The customary notation for these dimensions is [B, T, D] where:
        #
        #   B = batch size
        #   T = sequence length (num_tokens)
        #   D = model dimension (d_in)
        b, num_tokens, d_in = x.shape

        # Project the input into query, key, and value vectors.
        # NOTE: The shape of the projected query, key, and value vectors is [B, T, D]
        keys = self.W_k(x)  # [B, T, D]
        queries = self.W_q(x)  # [B, T, D]
        values = self.W_v(x)  # [B, T, D]

        # Split the query, key, and value vectors into the specified number of heads.
        # NOTE: We implicitly split the matrix by adding a num_heads dimension. Then we unroll the
        #       last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim).
        #
        # The resulting shape is [B, T, H, D_h] where:
        #
        #   B = batch size
        #   T = sequence length (num_tokens)
        #   H = number of heads
        #   D_h = dimension per head (head_dim)
        #
        # NOTE: See section "A note on views" for more details.
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose from [B, T, H, D_h] to [B, H, T, D_h], i.e. from shape
        # (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
        #
        # NOTE: This transposition is crucial for correctly aligning the queries, keys, and values
        #       across the different heads and performing batched matrix multiplications
        #       efficiently.
        # NOTE: This reshaping results in each head having access to the full sequence of tokens
        #       (i.e. a tensor of shape T x D_h).
        keys = keys.transpose(1, 2)  # shape [B, H, T, D_h]
        queries = queries.transpose(1, 2)  # shape [B, H, T, D_h]
        values = values.transpose(1, 2)  # shape [B, H, T, D_h]

        # Compute the unnormalized attention scores via a dot product for each head.
        # NOTE: We transpose the T and D_h dimension (i.e. num_tokens and head_dim) just like we
        #       have already done in the "Trainable self-attention" section.
        # NOTE: Change key shape from [B, H, T, D_h] to [B, H, D_h, T]
        # NOTE: The output shape of attn_scores is [B, H, T, T], i.e. a square matrix of size T
        #       (i.e. sequence length) for each head.
        # NOTE: The following operation does a batched matrix multiplication between queries and
        #       keys. In this case, the matrix multiplication implementation in PyTorch handles the
        #       four-dimensional input tensor so that the matrix multiplication is carried out
        #       between the two last dimensions (num_tokens, head_dim) and then repeated for the
        #       individual heads.
        # NOTE: See section "A note on batched matrix multiplications" for more details.
        attn_scores = queries @ keys.transpose(2, 3)  # shape [B, H, T, T]

        # The mask is truncated to the number of tokens in the input sequence (i.e. sequence
        # length T)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]  # shape [T, T]

        # Apply the mask to the attention scores.
        attn_scores.masked_fill_(mask_bool, -torch.inf)  # shape [B, H, T, T]

        # Compute the normalized attention weights (as before)
        # NOTE: The scaling factor is the square root of the size of the last dimension of the
        #       keys tensor (i.e. sqrt(head_dim)).
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)  # shape [B, H, T, T]

        # Compute the context vectors.
        # NOTE: The shapes of the individual tensors are:
        #       - attn_weights: [B, H, T, T]
        #       - values: [B, H, T, D_h]
        #       - context_vec: [B, H, T, D_h]
        #       - context_vec.transposed: [B, T, H, D_h]
        # NOTE: The context vectors from all heads are transposed back to the shape
        #       (b, num_tokens, num_heads, head_dim).
        context_vec = (attn_weights @ values).transpose(1, 2)  # shape [B, T, H, D_h]

        # Reshape/flatten the context vectors from [B, T, H, D_h] to [B, T, D_out], i.e. combine all
        # individual attention heads (d_out = H * D_h).
        # NOTE: Combines heads, where self.d_out = self.num_heads * self.head_dim.
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

        # Optionally project the context vectors to the output dimension.
        # NOTE: This output projection layer is not strictly necessary (see appendix B of the book
        #       for more details), but it is commonly used in many LLM architectures.
        context_vec = self.out_proj(context_vec)
        return context_vec
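
Before wiring the attention module into a transformer block, a quick shape check (my own addition; the book exercises this class in chapter 3) confirms that it preserves the [B, T, D] shape of its input:

# Sanity check: multi-head attention preserves the input shape.
torch.manual_seed(123)
mha = MultiHeadAttention(
    d_in=768, d_out=768, context_length=1024, dropout=0.1, num_heads=12
)
x = torch.rand(2, 4, 768)
print(mha(x).shape)  # torch.Size([2, 4, 768])
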
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # Initialize the multi-head attention module.
        self.mha = MultiHeadAttention(
            d_in=cfg.emb_dim,
            d_out=cfg.emb_dim,
            context_length=cfg.context_length,
            num_heads=cfg.n_heads,
            dropout=cfg.dropout_rate,
            qkv_bias=cfg.qkv_bias,
        )

        # Initialize the feed forward module.
        self.ff = FeedForward(cfg)

        # Initialize the layer normalization modules.
        self.pre_attention_norm = LayerNorm(cfg.emb_dim)
        self.pre_ff_norm = LayerNorm(cfg.emb_dim)

        # Initialize the dropout module.
        self.drop_shortcut = nn.Dropout(cfg.dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for the transformer block.

        NOTE: Layer normalization (LayerNorm) is applied before each of these two components,
              and dropout is applied after them to regularize the model and prevent overfitting.
              This is also known as Pre-LayerNorm. Older architectures, such as the original
              transformer model, applied layer normalization after the self-attention and feed
              forward networks instead, known as Post-LayerNorm, which often leads to worse
              training dynamics.
        """
        # Shortcut connection for attention block (this is the original input).
        shortcut = x

        # Layer norm + attention + dropout.
        x = self.pre_attention_norm(x)
        x = self.mha(x)
        x = self.drop_shortcut(x)

        # Add the original input back.
        x = x + shortcut

        # Shortcut connection for feed forward block (this is the output of the attention block).
        shortcut = x

        # Layer norm + feed forward + dropout.
        x = self.pre_ff_norm(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)

        # Add the post-attention shortcut output back.
        x = x + shortcut

        return x


# Test the transformer block.
torch.manual_seed(123)

# Creates sample input of shape [batch_size, num_tokens, emb_dim].
# NOTE: In more standard notation, this is [B, T, D] where:
#
#   B = batch size
#   T = sequence length (num_tokens)
#   D = model dimension (emb_dim)
x = torch.rand(2, 4, 768)

block = TransformerBlock(GPT_CONFIG_124M)
output = block(x)
print("Input shape:", x.shape)
print("Output shape:", output.shape)
Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

The full GPT model

An overview of the GPT model architecture showing the flow of data through the GPT model. Starting from the bottom, tokenized text is first converted into token embeddings, which are then augmented with positional embeddings. This combined information forms a tensor that is passed through a series of transformer blocks shown in the center (each containing multi-head attention and feed forward neural network layers with dropout and layer normalization), which are stacked on top of each other and repeated 12 times.

GPT model
class GPTModel(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()

        # Initialize the token and positional embedding layers.
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.emb_dim)
        self.pos_emb = nn.Embedding(cfg.context_length, cfg.emb_dim)

        # Initialize the dropout layer.
        self.drop_emb = nn.Dropout(cfg.dropout_rate)

        # Initialize the transformer blocks.
        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg.n_layers)]
        )

        # Initialize the final layer normalization.
        self.final_norm = LayerNorm(cfg.emb_dim)
        self.out_head = nn.Linear(cfg.emb_dim, cfg.vocab_size, bias=False)

    def forward(self, in_idx: torch.Tensor) -> torch.Tensor:
        # Extract input dimensions.
        # NOTE: The more standard notation for these dimensions is [B, T] where:
        #
        #   B = batch size
        #   T = sequence length (num_tokens)
        batch_size, seq_len = in_idx.shape

        # Embed the input tokens and add positional encodings.
        # NOTE: The token embeddings are a tensor of shape [B, T, D] where:
        #
        #   B = batch size
        #   T = sequence length (num_tokens)
        #   D = model dimension (emb_dim)
        tok_embeds = self.tok_emb(in_idx)
        # NOTE: The positional encodings are a tensor of shape [T, D] where:
        #
        #   T = sequence length (num_tokens)
        #   D = model dimension (emb_dim)
        #
        # Note that the positional embeddings are limited in length to the context length, which
        # is set to 1024 for this GPT configuration.
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds

        # Dropout the embedded tokens.
        x = self.drop_emb(x)

        # Pass through the transformer blocks.
        x = self.trf_blocks(x)

        # Final layer normalization.
        x = self.final_norm(x)

        # Project the output to the vocabulary size.
        logits = self.out_head(x)
        return logits


# Test the GPT model.
torch.manual_seed(123)

model = GPTModel(GPT_CONFIG_124M)
out = model(batch)

print("Input batch:\n", batch)
print("\nOutput shape:", out.shape)
print(out)
Input batch:
 tensor([[6109, 3626, 6100,  345],
        [6109, 1110, 6622,  257]])

Output shape: torch.Size([2, 4, 50257])
tensor([[[ 0.1381,  0.0077, -0.1963,  ..., -0.0222, -0.1060,  0.1717],
         [ 0.3865, -0.8408, -0.6564,  ..., -0.5163,  0.2369, -0.3357],
         [ 0.6989, -0.1829, -0.1631,  ...,  0.1472, -0.6504, -0.0056],
         [-0.4290,  0.1669, -0.1258,  ...,  1.1579,  0.5303, -0.5549]],

        [[ 0.1094, -0.2894, -0.1467,  ..., -0.0557,  0.2911, -0.2824],
         [ 0.0882, -0.3552, -0.3527,  ...,  1.2930,  0.0053,  0.1898],
         [ 0.6091,  0.4702, -0.4094,  ...,  0.7688,  0.3787, -0.1974],
         [-0.0612, -0.0737,  0.4751,  ...,  1.2463, -0.3834,  0.0609]]],
       grad_fn=<UnsafeViewBackward0>)

Analyzing the model architecture

# NOTE: The reason the total number of parameters is 163M instead of 124M is a concept called
#       weight tying, which was used in the original GPT-2 architecture: GPT-2 reuses the weights
#       of the token embedding layer in its output layer, whereas the model above uses a separate
#       output head.
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")
print("Token embedding layer shape:", model.tok_emb.weight.shape)
print("Output layer shape:", model.out_head.weight.shape)
Total number of parameters: 163,009,536
Token embedding layer shape: torch.Size([50257, 768])
Output layer shape: torch.Size([50257, 768])
total_params_gpt2 = total_params - sum(p.numel() for p in model.out_head.parameters())
print(
    f"Number of trainable parameters "
    f"considering weight tying: {total_params_gpt2:,}"
)
Number of trainable parameters considering weight tying: 124,412,160
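
As a side note (my own sketch, not something the book does here), weight tying can be emulated on a copy of the model by pointing the output head's weight at the token embedding matrix; PyTorch's parameters() then counts the shared tensor only once:

import copy

# Emulate weight tying on a deep copy of the model (my own sketch).
# NOTE: Both tensors have shape [50257, 768], so the assignment is shape-compatible. The deep
#       copy temporarily doubles memory usage but leaves the original model untouched.
tied_model = copy.deepcopy(model)
tied_model.out_head.weight = tied_model.tok_emb.weight
tied_params = sum(p.numel() for p in tied_model.parameters())
print(f"Total parameters with weight tying: {tied_params:,}")  # 124,412,160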
# Calculates the total size in bytes (assuming float32, 4 bytes per parameter).
total_size_bytes = total_params * 4
total_size_mb = total_size_bytes / (1024 * 1024)
print(f"Total size of the model: {total_size_mb:.2f} MB")
Total size of the model: 621.83 MB
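
For a finer-grained view (my own breakdown, not from the book), the parameters of a single transformer block can be split into its attention and feed forward parts, which shows where most of the per-block capacity sits:

# Break down where the parameters live, using the module names defined above (my own addition).
block = model.trf_blocks[0]
mha_params = sum(p.numel() for p in block.mha.parameters())
ff_params = sum(p.numel() for p in block.ff.parameters())
print(f"Attention parameters per block:    {mha_params:,}")  # 2,360,064
print(f"Feed forward parameters per block: {ff_params:,}")   # 4,722,432
print(f"Token embedding parameters:        {model.tok_emb.weight.numel():,}")
print(f"Positional embedding parameters:   {model.pos_emb.weight.numel():,}")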

Generating text

Starting with an initial input context (“Hello, I am”), the model predicts a subsequent token during each iteration, appending it to the input context for the next round of prediction.

Steps:

  1. Decode the output tensors.
  2. Select tokens based on a probability distribution.
  3. Convert these tokens to human-readable text.

Text generation process

The detailed steps for generating text are shown in the figure below:

  1. In each step, the model outputs a matrix with vectors representing potential next tokens.
  2. The vector corresponding to the next token is extracted and converted into a probability distribution via the softmax function.
  3. Within the vector containing the resulting probability scores, the index of the highest value is located, which translates to the token ID. Selecting the highest probability is also called “greedy decoding” but other strategies are also possible (top-k sampling, top-p sampling, etc.).
  4. This token ID is then decoded back into text, producing the next token in the sequence.
  5. Finally, this token is appended to the previous inputs, forming a new input sequence for the subsequent iteration.

This step-by-step process enables the model to generate text sequentially, building coherent phrases and sentences from the initial input context. In practice, we repeat this process over many iterations until we reach a user-specified number of generated tokens (or the model produces an EOS - end of sequence - token).

Text generation process

TODO: It seems wasteful to generate output tokens for each input token when we are really only interested in the last output token (i.e. the next token).

def generate_text_simple(
    model: GPTModel, idx: torch.Tensor, max_new_tokens: int, context_size: int
) -> torch.Tensor:
    """Generate text using a simple greedy decoding strategy.

    NOTE: This function requires the model to be put into eval mode. This has to be done by the
          caller.

    Args:
        model: The GPT model.
        idx: The input tensor of shape [B, T] containing the tokenized input text, i.e. a
            (batch, n_tokens) array of indices in the current context.
        max_new_tokens: The maximum number of new tokens to generate.
        context_size: The size of the context window.

    Returns:
        A tensor of shape [B, T + max_new_tokens] containing the generated text.
    """
    for _ in range(max_new_tokens):
        # Crop the current context if it exceeds the supported context size, e.g., if the LLM
        # supports only 5 tokens and the current context is 10 tokens long, then only the last
        # 5 tokens are used as context.
        idx_cond = idx[:, -context_size:]

        # Disable gradient calculation since they are not needed for inference.
        with torch.no_grad():
            # Forward pass through the model.
            # NOTE: In more standard notation, this is [B, T, V] where:
            #
            #   B = batch size
            #   T = sequence length (num_tokens)
            #   V = vocabulary size
            logits = model(idx_cond)  # shape [B, T, V]

            # Focus only on the last time step, so that (batch, n_token, vocab_size) becomes
            # (batch, vocab_size).
            logits = logits[:, -1, :]  # shape [B, V]

            # Convert the logits to probabilities.
            # NOTE: The softmax function is applied over the vocabulary dimension (i.e. last
            #       dimension). That results in a probability distribution over the vocabulary such
            #       that the most likely token can be selected.
            probas = torch.softmax(logits, dim=-1)  # shape [B, V]

            # Select the token with the highest probability.
            idx_next = torch.argmax(probas, dim=-1, keepdim=True)  # shape [B, 1]

            # Append the selected / sampled index / token to the input sequence.
            idx = torch.cat((idx, idx_next), dim=1)  # shape [B, T + 1]

    return idx


# Test the greedy decoding strategy.
start_context = "Hello, I am"
encoded = tokenizer.encode(start_context)

# NOTE: The unsqueeze function adds a new dimension at the specified index. In this case, we add a
#       new dimension at index 0, which is the batch dimension.
encoded_tensor = torch.tensor(encoded).unsqueeze(0)
print(f"Encoded input tensor: {encoded}")
print(f"Encoded input tensor shape: {encoded_tensor.shape}")

# Disables dropout since we are not training the model.
model.eval()
out = generate_text_simple(
    model=model,
    idx=encoded_tensor,
    max_new_tokens=6,
    context_size=GPT_CONFIG_124M.context_length,
)
print(f"Output tensor: {out}")
print(f"Output length: {len(out.squeeze(0))}")

# Decodes the generated text.
decoded_text = tokenizer.decode(out.squeeze(0).tolist())
print(f"Decoded output text: {decoded_text}")
Encoded input tensor: [15496, 11, 314, 716]
Encoded input tensor shape: torch.Size([1, 4])
Output tensor: tensor([[15496,    11,   314,   716, 27018, 24086, 47843, 30961, 42348,  7267]])
Output length: 10
Decoded output text: Hello, I am Featureiman Byeswickattribute argue
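
Note that the generated continuation is gibberish: the model still has random, untrained weights. Producing coherent text requires the training procedure the book turns to next.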