LLMs From Scratch - Chapter 6: Fine-tuning for Classification

llms
fine-tuning
tutorial
Author

Daniel Pickem

Published

May 16, 2025

Fine-tuning for classification

This notebook explores fine-tuning an LLM into a classification model, following Chapter 6 of Sebastian Raschka’s book. In particular, it covers the following:

  • Introducing different LLM fine-tuning approaches
  • Preparing a dataset for text classification
  • Modifying a pretrained LLM for fine-tuning
  • Fine-tuning an LLM to identify spam messages
  • Evaluating the accuracy of a fine-tuned LLM classifier
  • Using a fine-tuned LLM to classify new data

Instruction fine-tuning

  • Instruction-tuned models can typically handle a broader range of tasks
  • More general approach that can handle multiple tasks
  • Best suited for models that need to handle a variety of tasks based on complex user instructions
  • These models improve flexibility and interaction quality
  • Instruction fine-tuning requires larger datasets and greater computational resources

Classification fine-tuning

  • Ideal for projects requiring precise categorization into predefined classes (e.g. sentiment analysis or spam detection)
  • Specialized approach targeted at outputting a specific set of labels
  • The model is restricted to only the labels encountered during training
  • Requires less data and compute power

Acknowledgment

All concepts, architectures, and implementation approaches are credited to Sebastian Raschka’s work. This repository serves as my personal implementation and notes while working through the book’s content.

Resources

Topic overview
# Install import-ipynb for importing ipynb files.
# %pip install import-ipynb
from typing import Optional, Tuple
import urllib.request
import zipfile
import os
from pathlib import Path


import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import tiktoken
from tqdm.notebook import tqdm

import pandas as pd

# Import previous chapter dependencies.
# See https://stackoverflow.com/questions/44116194/import-a-function-from-another-ipynb-file
# NOTE: Importing these functions seems to run the entire cell the symbol is defined in, which would
#       suggest that symbols should be defined in separate cells from the test code.
import import_ipynb
from gpt_download import download_and_load_gpt2
from chapter_02_dataset_creation import create_dataloader_v1
from chapter_04_gpt_from_scratch import (
    GPTConfig,
    GPTModel,
    generate_text_simple,
)

# NOTE: Importing another ipynb file basically runs the entire imported notebook.
from chapter_05_pretraining_on_unlabeled_data import (
    generate,
    token_ids_to_text,
    text_to_token_ids,
)

# Define the base config.
GPT_CONFIG_124M = GPTConfig(
    vocab_size=50257,  # as used by the BPE tokenizer for GPT-2.
    context_length=1024,
    emb_dim=768,
    n_heads=12,
    n_layers=12,
    dropout_rate=0.0,  # disable dropout for inference
    qkv_bias=False,
)

# Determine the device to run the model on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Stage 1: Preparing the dataset

This section follows stage 1 in the following figure:

Dataset preparation

Download the dataset

url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = Path("data/sms_spam_collection.zip")
extracted_path = Path("data/sms_spam_collection")
data_file_path = extracted_path / "SMSSpamCollection.tsv"


def download_and_unzip_spam_data(
    url: str, zip_path: Path, extracted_path: Path, data_file_path: Path
):
    """Download and unzip the spam data from the UCI repository.

    Args:
        url: The URL of the zip file.
        zip_path: The path to save the zip file.
        extracted_path: The path to save the extracted files.
        data_file_path: The path to save the data file.
    """
    # Check if the file already exists.
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download " "and extraction.")
        return

    # Download the zip file.
    with urllib.request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())

    # Extract the zip file.
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # Add a .tsv extension to the file (tab-separated values).
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")


download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)

Load dataset

# Load data into a pandas DataFrame.
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])

# Show the label count.
print(df["Label"].value_counts())

# Show a few examples.
df.head()

Balancing the dataset

# Create a balanced dataset by undersampling the majority class.
def create_balanced_dataset(df: pd.DataFrame) -> pd.DataFrame:
    """Create a balanced dataset by undersampling the majority class.

    NOTE: This function can quite significantly reduce the size of the dataset.

    Args:
        df: The input DataFrame.

    Returns:
        A balanced DataFrame.
    """
    num_spam = len(df[df["Label"] == "spam"])
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
    return balanced_df.reset_index(drop=True)


# Create a balanced dataset.
balanced_df = create_balanced_dataset(df)
print(balanced_df["Label"].value_counts())

# Convert string labels to integers.
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
balanced_df

Splitting the dataset

def random_split(
    df: pd.DataFrame, train_frac: float, validation_frac: float
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split a DataFrame into train, validation, and test sets.

    NOTE: The size of the test set is implied to be the remainder of train and validation fraction
          (all fractions should add up to 1).

    Args:
        df: The input DataFrame.
        train_frac: The fraction of the dataset to use for training.
        validation_frac: The fraction of the dataset to use for validation.

    Returns:
        A tuple of DataFrames for train, validation, and test sets.
    """
    # Shuffle the entire DataFrame
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    # Calculate split indices (for train and validation explicitly.)
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    # Split the DataFrame.
    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]

    return train_df, validation_df, test_df


# Test size is implied to be 0.2 as the remainder.
train_df, validation_df, test_df = random_split(
    df=balanced_df, train_frac=0.7, validation_frac=0.1
)
print(f"Train set size: {len(train_df)}")
print(f"Validation set size: {len(validation_df)}")
print(f"Test set size: {len(test_df)}")

# Save the DataFrames to CSV files.
train_df.to_csv(extracted_path / "train.csv", index=None)
validation_df.to_csv(extracted_path / "validation.csv", index=None)
test_df.to_csv(extracted_path / "test.csv", index=None)

Creating the datasets

Previously, we utilized a sliding window technique to generate uniformly sized text chunks, which we then grouped into batches for more efficient model training. Each chunk functioned as an individual training instance. However, we are now working with a spam dataset that contains text messages of varying lengths. To batch these messages as we did with the text chunks, we have two primary options:

  • Truncate all messages to the length of the shortest message in the dataset or batch.
  • Pad all messages to the length of the longest message in the dataset or batch.

The first option is computationally cheaper, but it discards everything beyond the shortest message’s length, which can mean significant information loss when messages vary widely in length and may reduce model performance. So, we opt for the second option, which preserves the entire content of all messages.

To implement batching, where all messages are padded to the length of the longest message in the dataset, we add padding tokens to all shorter messages. For this purpose, we use “<|endoftext|>” as a padding token. However, instead of appending the string “<|endoftext|>” to each of the text messages directly, we add the token ID corresponding to “<|endoftext|>” to the encoded text messages.
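To see the padding mechanics in isolation, here is a minimal sketch using hypothetical token IDs (only pad_token_id = 50256 comes from the actual GPT-2 vocabulary):

# Pad shorter token ID sequences up to the length of the longest one.
seqs = [[5, 12, 7], [9, 3], [42]]  # hypothetical token ID sequences
pad_token_id = 50256  # token ID of "<|endoftext|>" in the GPT-2 vocabulary
max_len = max(len(s) for s in seqs)
padded = [s + [pad_token_id] * (max_len - len(s)) for s in seqs]
print(padded)  # [[5, 12, 7], [9, 3, 50256], [42, 50256, 50256]]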

Padding approach

The example below shows what a training batch looks like: a single batch consists of eight text messages represented as token IDs, each padded to 120 token IDs. A class label array stores the eight class labels corresponding to the text messages, which can be either 0 (“not spam”) or 1 (“spam”).

Training batch example
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")
# Prints [50256], the token ID that the GPT-2 BPE tokenizer assigns to "<|endoftext|>".
print(tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))
# Creating a Dataset class.
class SpamDataset(Dataset):
    """Dataset class for the spam dataset."""

    def __init__(
        self,
        csv_file: Path,
        tokenizer: tiktoken.Encoding,
        max_length: Optional[int] = None,
        pad_token_id: int = 50256,
    ):
        """
        Initializes the SpamDataset class.

        Args:
            csv_file: The path to the CSV file containing the data.
            tokenizer: The tokenizer to use.
            max_length: The maximum length of the encoded texts.
            pad_token_id: The ID of the padding token.
        """
        # Load the data from the CSV file.
        self.data = pd.read_csv(csv_file)

        # Pretokenize all texts.
        self.encoded_texts = [tokenizer.encode(text) for text in self.data["Text"]]

        if max_length is None:
            # If no maximum length is provided, use the longest encoded text.
            self.max_length = self._longest_encoded_length()
        else:
            # Truncate sequences if they are longer than max_length.
            self.max_length = max_length
            self.encoded_texts = [
                encoded_text[: self.max_length] for encoded_text in self.encoded_texts
            ]

        # Pads sequences to the longest sequence
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long),
        )

    def __len__(self) -> int:
        return len(self.data)

    def _longest_encoded_length(self) -> int:
        """Determine the longest encoded text length."""
        return max(len(encoded_text) for encoded_text in self.encoded_texts)


# Load the training dataset.
train_dataset = SpamDataset(
    csv_file=extracted_path / "train.csv", max_length=None, tokenizer=tokenizer
)

# Load the validation and test sets and limit the max length to the same value as the training set.
# NOTE: Importantly, any validation and test set samples exceeding the length of the longest
#       training example are truncated using encoded_text[:self.max_length] in the SpamDataset code
#       we defined earlier. This truncation is optional; you can set max_length=None for both
#       validation and test sets, provided there are no sequences exceeding 1,024 tokens in these
#       sets.
val_dataset = SpamDataset(
    csv_file=extracted_path / "validation.csv",
    max_length=train_dataset.max_length,
    tokenizer=tokenizer,
)
test_dataset = SpamDataset(
    csv_file=extracted_path / "test.csv",
    max_length=train_dataset.max_length,
    tokenizer=tokenizer,
)

# Show the maximum length of the encoded texts.
print(f"Maximum length of the encoded training texts: {train_dataset.max_length}")
print(f"Maximum length of the encoded validation texts: {val_dataset.max_length}")
print(f"Maximum length of the encoded test texts: {test_dataset.max_length}")

# Verify that the maximum length does not exceed the context length.
assert train_dataset.max_length <= GPT_CONFIG_124M.context_length
assert val_dataset.max_length <= GPT_CONFIG_124M.context_length
assert test_dataset.max_length <= GPT_CONFIG_124M.context_length

Creating the data loaders

# This num_worker setting ensures compatibility with most computers.
num_workers = 0
batch_size = 8
torch.manual_seed(123)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)

# Show the number of batches in each data loader.
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

# Show the first batch of the training set.
# NOTE: As we can see, the input batches consist of eight training examples with 120 tokens each,
#       as expected. The label tensor stores the class labels corresponding to the eight training
#       examples.
print("\nFirst training batch:")
for input_batch, target_batch in train_loader:
    print("Input batch dimensions:", input_batch.shape)
    print("Label batch dimensions", target_batch.shape)
    break

Stage 2: Model Setup

Model initialization

Initializing a model with pretrained weights

import dataclasses

# Load the base config.
GPT_CONFIG_124M = GPTConfig(
    vocab_size=50257,  # as used by the BPE tokenizer for GPT-2.
    context_length=1024,
    emb_dim=768,
    n_heads=12,
    n_layers=12,
    dropout_rate=0.0,  # disable dropout for inference
    qkv_bias=False,
)

# Update the model configuration to conform to the model size.
model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

# Instantiate a base config.
tmp_config = dataclasses.asdict(GPT_CONFIG_124M)

# Load the overlay parameters.
model_name = "gpt2-small (124M)"
tmp_config.update(model_configs[model_name])

# Update the context length to match OpenAI's GPT-2 models.
tmp_config.update({"context_length": 1024})

# OpenAI used bias vectors in the multi-head attention module’s linear layers to implement the
# query, key, and value matrix computations. Bias vectors are not commonly used in LLMs anymore as
# they don’t improve the modeling performance and are thus unnecessary. However, since we are
# working with pretrained weights, we need to match the settings for consistency and enable these
# bias vectors.
tmp_config.update({"qkv_bias": True})

# Instantiate the new configuration.
NEW_CONFIG = GPTConfig(**tmp_config)

# Initialize the model with the new configuration.
gpt = GPTModel(NEW_CONFIG)
gpt.eval()
# NOTE: This code is copied from chapter_05_pretraining_on_unlabeled_data.ipynb because the import
#       from load_weights_into_gpt.py is not working.


def assign(left, right):
    """Safely assign the right weight tensor to the left layer.

    Checks whether two tensors or arrays (left and right) have the same dimensions or shape and
    returns the right tensor as trainable PyTorch parameters.
    """
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")

    return torch.nn.Parameter(torch.tensor(right))


def load_weights_into_gpt(gpt: GPTModel, params: dict):
    # Sets the model’s positional and token embedding weights to those specified in params.
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])
    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])

    # Iterates over each transformer block in the model.
    for b in range(len(params["blocks"])):
        # The np.split function is used to divide the attention and bias weights into three equal
        # parts for the query, key, and value components.
        q_w, k_w, v_w = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1
        )
        gpt.trf_blocks[b].mha.W_q.weight = assign(
            gpt.trf_blocks[b].mha.W_q.weight, q_w.T
        )
        gpt.trf_blocks[b].mha.W_k.weight = assign(
            gpt.trf_blocks[b].mha.W_k.weight, k_w.T
        )
        gpt.trf_blocks[b].mha.W_v.weight = assign(
            gpt.trf_blocks[b].mha.W_v.weight, v_w.T
        )
        q_b, k_b, v_b = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1
        )
        gpt.trf_blocks[b].mha.W_q.bias = assign(gpt.trf_blocks[b].mha.W_q.bias, q_b)
        gpt.trf_blocks[b].mha.W_k.bias = assign(gpt.trf_blocks[b].mha.W_k.bias, k_b)
        gpt.trf_blocks[b].mha.W_v.bias = assign(gpt.trf_blocks[b].mha.W_v.bias, v_b)
        gpt.trf_blocks[b].mha.out_proj.weight = assign(
            gpt.trf_blocks[b].mha.out_proj.weight,
            params["blocks"][b]["attn"]["c_proj"]["w"].T,
        )
        gpt.trf_blocks[b].mha.out_proj.bias = assign(
            gpt.trf_blocks[b].mha.out_proj.bias,
            params["blocks"][b]["attn"]["c_proj"]["b"],
        )
        gpt.trf_blocks[b].ff.layers[0].weight = assign(
            gpt.trf_blocks[b].ff.layers[0].weight,
            params["blocks"][b]["mlp"]["c_fc"]["w"].T,
        )
        gpt.trf_blocks[b].ff.layers[0].bias = assign(
            gpt.trf_blocks[b].ff.layers[0].bias, params["blocks"][b]["mlp"]["c_fc"]["b"]
        )
        gpt.trf_blocks[b].ff.layers[2].weight = assign(
            gpt.trf_blocks[b].ff.layers[2].weight,
            params["blocks"][b]["mlp"]["c_proj"]["w"].T,
        )
        gpt.trf_blocks[b].ff.layers[2].bias = assign(
            gpt.trf_blocks[b].ff.layers[2].bias,
            params["blocks"][b]["mlp"]["c_proj"]["b"],
        )
        gpt.trf_blocks[b].pre_attention_norm.scale = assign(
            gpt.trf_blocks[b].pre_attention_norm.scale, params["blocks"][b]["ln_1"]["g"]
        )
        gpt.trf_blocks[b].pre_attention_norm.shift = assign(
            gpt.trf_blocks[b].pre_attention_norm.shift, params["blocks"][b]["ln_1"]["b"]
        )
        gpt.trf_blocks[b].pre_ff_norm.scale = assign(
            gpt.trf_blocks[b].pre_ff_norm.scale, params["blocks"][b]["ln_2"]["g"]
        )
        gpt.trf_blocks[b].pre_ff_norm.shift = assign(
            gpt.trf_blocks[b].pre_ff_norm.shift, params["blocks"][b]["ln_2"]["b"]
        )

    # NOTE: The following assignments apply to the model as a whole, so they belong outside the
    #       per-block loop.
    gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
    gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])

    # The original GPT-2 model by OpenAI reused the token embedding weights in the output layer
    # to reduce the total number of parameters, which is a concept known as weight tying.
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
# Download the GPT-2 weights.
settings, params = download_and_load_gpt2(model_size="124M", models_dir="gpt2")

# Load the weights into the model.
load_weights_into_gpt(gpt, params)
gpt.to(device)
# Test the model to verify that it can generate coherent text.
text_1 = "Every effort moves you"
token_ids = generate_text_simple(
    model=gpt.to(device),
    idx=text_to_token_ids(text_1, tokenizer).to(device),
    max_new_tokens=15,
    context_size=NEW_CONFIG.context_length,
)
print(token_ids_to_text(token_ids, tokenizer))
# Check if the model is already capable of classifying spam and ham messages via instruction
# examples.
text_2 = (
    "Is the following text 'spam'? Answer with 'yes' or 'no':"
    " 'You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award.'"
)
token_ids = generate_text_simple(
    model=gpt.to(device),
    idx=text_to_token_ids(text_2, tokenizer).to(device),
    max_new_tokens=23,
    context_size=NEW_CONFIG.context_length,
)
print(token_ids_to_text(token_ids, tokenizer))

Modify model for fine-tuning (adding a classification head)

We adapt the GPT model for spam classification by altering its architecture. Initially, the model’s linear output layer maps 768 hidden units to a vocabulary of 50,257 tokens. To detect spam, we replace this layer with a new output layer that maps the same 768 hidden units to just two classes, representing “spam” and “not spam.”

Fine-tuning selected layers vs. all layers

Since we start with a pretrained model, it’s not necessary to fine-tune all model layers. In neural network-based language models, the lower layers generally capture basic language structures and semantics applicable across a wide range of tasks and datasets. So, fine-tuning only the last layers (i.e., layers near the output), which are more specific to nuanced linguistic patterns and task-specific features, is often sufficient to adapt the model to new tasks. A nice side effect is that it is computationally more efficient to fine-tune only a small number of layers.

Model modification
# Prepare the model for fine-tuning.

# 1. Freeze all parameters in the model.
for param in gpt.parameters():
    param.requires_grad = False

# 2. Replace the final linear layer with a new one for the two classes.
torch.manual_seed(123)
num_classes = 2
gpt.out_head = nn.Linear(GPT_CONFIG_124M.emb_dim, num_classes)

# Mark additional layers as trainable, in particular the last transformer block as well as the
# final layer norm.
for param in gpt.trf_blocks[-1].parameters():
    param.requires_grad = True
for param in gpt.final_norm.parameters():
    param.requires_grad = True
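As a quick sanity check (not part of the book’s listing), we can count how many parameters remain trainable after freezing:

# Count trainable vs. total parameters after freezing.
trainable_params = sum(p.numel() for p in gpt.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in gpt.parameters())
print(f"Trainable parameters: {trainable_params:,} of {total_params:,}")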
# Try running the model with a random input to see that it is working.
inputs_str = "Do you have time"
inputs = tokenizer.encode(inputs_str)
inputs = torch.tensor(inputs).unsqueeze(0)
print("Inputs:", inputs_str)
print("Inputs dimensions:", inputs.shape)  # B x T, i.e. batch size x sequence length

with torch.no_grad():
    outputs = gpt.to(device)(inputs.to(device))

# NOTE: The output shape is B x T x 2, i.e. batch size x sequence length x number of classes.
#       The model produces logits for each class and for each token in the input sequence.
# NOTE: We are interested in fine-tuning this model to return a class label indicating whether a
#       model input is “spam” or “not spam.” We don’t need to fine-tune all four output rows;
#       instead, we can focus on a single output token. In particular, we will focus on the last
#       row corresponding to the last output token.
print("Outputs:\n", outputs)
print(
    "Outputs dimensions:", outputs.shape
)  # B x T x 2, i.e. batch size x sequence length x number of classes
print("Last output token:", outputs[:, -1, :])

Selecting the right output for fine-tuning

Output selection

To understand why we are particularly interested in only the last output token, let’s take a look at the attention mechanism. We have already explored the attention mechanism, which establishes a relationship between each input token and every other input token, and the concept of a causal attention mask. This mask restricts a token’s focus to its current position and those before it, ensuring that each token can only be influenced by itself and the preceding tokens (as shown below).

The empty cells indicate masked positions due to the causal attention mask, preventing tokens from attending to future tokens. The values in the cells represent attention scores; the last token, “time”, is the only one that computes attention scores for all preceding tokens.

The last token in a sequence accumulates the most information since it is the only token with access to data from all the previous tokens. Therefore, in our spam classification task, we focus on this last token during the fine-tuning process.
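For a concrete picture of the mask itself, here is a minimal sketch showing the lower-triangular structure a causal mask takes for a four-token input:

import torch

# 1s mark positions a token may attend to; 0s mark masked (future) positions.
seq_len = 4
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
print(causal_mask)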


Evaluation utilities

As with next-token prediction, we apply softmax to the output logits, here yielding a probability for each class (spam, not spam), as shown below.

Computing classification probabilities
# Compute the probabilities for the last output token.
# NOTE: Softmax is monotonic, so taking the argmax of the raw logits would give the same label;
#       softmax is only needed if we want interpretable class probabilities.
probas = torch.softmax(outputs[:, -1, :], dim=-1)

# Compute the class label.
# NOTE: {"ham": 0, "spam": 1}
label = torch.argmax(probas)
print("Inputs:", inputs_str)
print("Class label:", label.item())
# A utility function for computing classification accuracy for a data loader.
def calc_accuracy_loader(
    data_loader: DataLoader,
    model: GPTModel,
    device: torch.device,
    num_batches: Optional[int] = None,
) -> float:
    """Compute the accuracy of a model on a data loader.

    Args:
        data_loader: The data loader to compute the accuracy on.
        model: The model to compute the accuracy on.
        device: The device to compute the accuracy on.
        num_batches: The number of batches to compute the accuracy on. Defaults to None.

    Returns:
        The accuracy of the model on the data loader.
    """
    # Set the model to evaluation mode (to avoid tracking gradients).
    model.eval()

    # Initialize the number of correct predictions and the number of examples.
    correct_predictions, num_examples = 0, 0

    # If the number of batches is not specified, use all batches.
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))

    # Iterate over the data loader.
    for i, (input_batch, target_batch) in enumerate(data_loader):
        # If the number of batches has not been reached, compute the accuracy.
        if i < num_batches:
            input_batch = input_batch.to(device)
            target_batch = target_batch.to(device)

            # Compute the logits for the last output token.
            with torch.no_grad():
                # NOTE: The output shape is B x T x 2, i.e. batch size x sequence length x number
                #       of classes. Here, we are only interested in the logits for the last output
                #       token.
                logits = model(input_batch)[:, -1, :]

            # Compute the predicted labels.
            # NOTE: dim=-1 computes the argmax over the classes.
            predicted_labels = torch.argmax(logits, dim=-1)

            # Update the number of examples and the number of correct predictions.
            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break

    return correct_predictions / num_examples
# Compute baseline accuracy for the not yet fine-tuned model.

# Move the model to the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpt.to(device)
torch.manual_seed(123)

# Compute the accuracy for the training, validation, and test sets.
train_accuracy = calc_accuracy_loader(train_loader, gpt, device, num_batches=10)
val_accuracy = calc_accuracy_loader(val_loader, gpt, device, num_batches=10)
test_accuracy = calc_accuracy_loader(test_loader, gpt, device, num_batches=10)

print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")

Define the loss function

Before we begin fine-tuning the model, we must define the loss function to optimize during training. Our objective is to maximize the model’s spam classification accuracy, which means the preceding code should output the correct class labels: 0 for non-spam and 1 for spam. Because classification accuracy is not a differentiable function, we use cross-entropy loss as a differentiable proxy for maximizing accuracy.
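To make the proxy concrete, here is a tiny sketch (with made-up logits) of the cross-entropy computation we will apply to the last-token logits:

import torch

# Hypothetical last-token logits for a batch of two messages and their true labels.
logits = torch.tensor([[2.0, -1.0], [0.5, 1.5]])  # shape: (batch, num_classes)
targets = torch.tensor([0, 1])  # 0 = ham, 1 = spam
loss = torch.nn.functional.cross_entropy(logits, targets)
print(loss)  # a small loss, since both predictions favor the correct class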

def calc_loss_batch(
    input_batch: torch.Tensor,
    target_batch: torch.Tensor,
    model: GPTModel,
    device: torch.device,
) -> torch.Tensor:
    """Compute the loss for a batch of inputs and targets.

    Args:
        input_batch: The input batch.
        target_batch: The target batch.
        model: The model.
        device: The device.

    Returns:
        The loss for the batch.
    """
    # Move the input and target batches to the device.
    input_batch = input_batch.to(device)
    target_batch = target_batch.to(device)

    # Compute the logits for the last output token.
    logits = model(input_batch)[:, -1, :]

    # Compute the loss (only for the last output token).
    loss = torch.nn.functional.cross_entropy(logits, target_batch)

    return loss


def calc_loss_loader(
    data_loader: DataLoader,
    model: GPTModel,
    device: torch.device,
    num_batches: Optional[int] = None,
) -> float:
    """Compute the loss for a data loader.

    Args:
        data_loader: The data loader.
        model: The model.
        device: The device.
        num_batches: The number of batches to compute the loss on.

    Returns:
        The loss for the data loader.
    """
    # Initialize the total loss.
    total_loss = 0.0

    # If the data loader is empty, return NaN.
    if len(data_loader) == 0:
        return float("nan")
    # If the number of batches is not specified, use all batches.
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))

    # Iterate over the data loader.
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break

    return total_loss / num_batches
# Compute the loss for the training, validation, and test sets.

# Disables gradient tracking for efficiency because we are not training yet
with torch.no_grad():
    train_loss = calc_loss_loader(train_loader, gpt, device, num_batches=5)
    val_loss = calc_loss_loader(val_loader, gpt, device, num_batches=5)
    test_loss = calc_loss_loader(test_loader, gpt, device, num_batches=5)

print(f"Training loss: {train_loss:.3f}")
print(f"Validation loss: {val_loss:.3f}")
print(f"Test loss: {test_loss:.3f}")

Stage 3: Model fine-tuning and usage

Fine-tuning on supervised data
def evaluate_model(
    model: GPTModel,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    eval_iter: int,
) -> tuple[float, float]:
    """Evaluate a model on the training and validation sets.

    Args:
        model: The model to evaluate.
        train_loader: The training data loader.
        val_loader: The validation data loader.
        device: The device to evaluate the model on.
        eval_iter: The number of batches to evaluate on.

    Returns:
        train_loss: The training loss.
        val_loss: The validation loss.
    """
    # Set the model to evaluation mode.
    model.eval()

    # Compute the loss for the training and validation sets.
    with torch.no_grad():
        train_loss = calc_loss_loader(
            train_loader, model, device, num_batches=eval_iter
        )
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)

    # Reset the model to training mode.
    model.train()

    return train_loss, val_loss


def train_classifier_simple(
    model: GPTModel,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    num_epochs: int,
    eval_freq: int,
    eval_iter: int,
) -> tuple[list[float], list[float], list[float], list[float], int]:
    """Train a classifier model.

    Args:
        model: The model to train.
        train_loader: The training data loader.
        val_loader: The validation data loader.
        optimizer: The optimizer.
        device: The device to train the model on.
        num_epochs: The number of epochs to train the model.
        eval_freq: The frequency of evaluation, in training steps.
        eval_iter: The number of batches to use per evaluation.

    Returns:
        train_losses: The training losses.
        val_losses: The validation losses.
        train_accs: The training accuracies.
        val_accs: The validation accuracies.
        examples_seen: The number of examples seen.
    """
    # Initialize the lists for the training and validation losses and accuracies.
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # Main training loop.
    for epoch in range(num_epochs):
        # Sets model to training mode (to enable gradient tracking, drop out, etc.).
        model.train()
        for input_batch, target_batch in train_loader:
            # Resets loss gradients from the previous batch iteration.
            optimizer.zero_grad()

            # Computes the loss for the current batch.
            loss = calc_loss_batch(input_batch, target_batch, model, device)

            # Calculates loss gradients.
            loss.backward()

            # Updates the model parameters using the computed loss gradients.
            optimizer.step()

            # Updates the number of examples seen and the global step.
            examples_seen += input_batch.shape[0]
            global_step += 1

            # Optional evaluation step.
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model=model,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    device=device,
                    eval_iter=eval_iter,
                )
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(
                    f"Ep {epoch+1} (Step {global_step:06d}): "
                    f"Train loss {train_loss:.3f}, "
                    f"Val loss {val_loss:.3f}"
                )

        # Calculates accuracy after each epoch
        train_accuracy = calc_accuracy_loader(
            train_loader, model, device, num_batches=eval_iter
        )
        val_accuracy = calc_accuracy_loader(
            val_loader, model, device, num_batches=eval_iter
        )
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen
import time

# Set the random seed and track training time.
start_time = time.time()
torch.manual_seed(123)

# Initialize the optimizer.
optimizer = torch.optim.AdamW(gpt.parameters(), lr=5e-5, weight_decay=0.1)
num_epochs = 5

# Train the model.
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
    model=gpt,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device=device,
    num_epochs=num_epochs,
    eval_freq=50,
    eval_iter=5,
)

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")

Loss visualization

The plot below shows the model’s training and validation loss over the five training epochs. Both the training loss (solid line) and the validation loss (dashed line) decline sharply in the first epoch and gradually stabilize toward the fifth epoch. This pattern indicates good learning progress and suggests that the model learned from the training data while generalizing well to the unseen validation data.

# Plot the training and validation losses.
import matplotlib.pyplot as plt


def plot_values(
    epochs_seen: torch.Tensor,
    examples_seen: torch.Tensor,
    train_values: list[float],
    val_values: list[float],
    label: str = "loss",
):
    """Plot the training and validation losses.

    Args:
        epochs_seen: The number of epochs seen.
        examples_seen: The number of examples seen.
        train_values: The training values.
        val_values: The validation values.
        label: The label for the plot.
    """
    # Create the plot.
    fig, ax1 = plt.subplots(figsize=(8, 6))

    # Plot the training and validation losses against the epochs.
    ax1.plot(epochs_seen, train_values, label=f"Training {label}")
    ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")

    # Set the x-axis label.
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel(label.capitalize())
    ax1.legend()

    # Creates a second x-axis for examples seen
    ax2 = ax1.twiny()

    # Invisible plot for aligning ticks
    ax2.plot(examples_seen, train_values, alpha=0)
    ax2.set_xlabel("Examples seen")

    # Adjusts layout to make room.
    fig.tight_layout()

    # Save the plot.
    plt.savefig(f"{label}-plot.pdf")

    # Show the plot.
    plt.show()


# Create the epochs and examples seen tensors.
epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))

# Plot the values.
plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses)

Classification accuracy plot

Both the training accuracy (solid line) and the validation accuracy (dashed line) increase substantially in the early epochs and then plateau, achieving almost perfect accuracy scores of 1.0. The close proximity of the two lines throughout the epochs suggests that the model does not overfit the training data very much.

epochs_tensor = torch.linspace(0, num_epochs, len(train_accs))
examples_seen_tensor = torch.linspace(0, examples_seen, len(train_accs))

plot_values(epochs_tensor, examples_seen_tensor, train_accs, val_accs, label="accuracy")

Performance metrics on all data sets

The training and test set performances are almost identical. The slight discrepancy between the training and test set accuracies suggests minimal overfitting of the training data. Typically, the validation set accuracy is somewhat higher than the test set accuracy because model development often involves tuning hyperparameters to perform well on the validation set, which might not generalize as effectively to the test set. This situation is common, but the gap could potentially be minimized by adjusting the model’s settings, such as increasing the dropout rate (dropout_rate) or the weight_decay parameter in the optimizer configuration.
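For instance, a stronger regularization setup might look like the following sketch (the values here are illustrative, not tuned settings from the book):

# Illustrative only: a larger weight decay to narrow the train/test gap.
regularized_optimizer = torch.optim.AdamW(gpt.parameters(), lr=5e-5, weight_decay=0.2)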

train_accuracy = calc_accuracy_loader(train_loader, gpt, device)
val_accuracy = calc_accuracy_loader(val_loader, gpt, device)
test_accuracy = calc_accuracy_loader(test_loader, gpt, device)

print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")

Using the model for classification

Model usage
def classify_review(
    text: str,
    model: GPTModel,
    tokenizer: tiktoken.Encoding,
    device: torch.device,
    max_length: Optional[int] = None,
    pad_token_id: int = 50256,
):
    """Classify a review using a fine-tuned GPT model.

    Args:
        text: The review to classify.
        model: The fine-tuned GPT model.
        tokenizer: The tokenizer.
        device: The device to classify the review on.
        max_length: The maximum length of the review.
        pad_token_id: The padding token ID (defaults to the end-of-text token).

    Returns:
        The predicted label.
    """
    model.eval()
    # Prepares inputs to the model
    input_ids = tokenizer.encode(text)

    # Determine the maximum supported context length.
    supported_context_length = model.pos_emb.weight.shape[1]

    # Determine the maximum sequence length, defaulting to the supported context length.
    # NOTE: Resolving max_length first avoids a TypeError in min() when max_length is None.
    max_length = max_length if max_length is not None else supported_context_length

    # Truncate sequences if they are too long.
    input_ids = input_ids[: min(max_length, supported_context_length)]

    # Pad sequences to the maximum sequence length.
    input_ids += [pad_token_id] * (max_length - len(input_ids))

    # Add a batch dimension.
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)

    # Model inference without gradient tracking.
    with torch.no_grad():
        # Logits for the last output token.
        logits = model(input_tensor)[:, -1, :]

    predicted_label = torch.argmax(logits, dim=-1).item()

    # Return the predicted label.
    return "spam" if predicted_label == 1 else "not spam"
# Try classifying some examples.
text_1 = (
    "You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award."
)
print(
    classify_review(
        text=text_1,
        model=gpt,
        tokenizer=tokenizer,
        device=device,
        max_length=train_dataset.max_length,
    )
)

text_2 = (
    "Hey, just wanted to check if we're still on for dinner tonight? Let me know!"
)
print(
    classify_review(
        text=text_2,
        model=gpt,
        tokenizer=tokenizer,
        device=device,
        max_length=train_dataset.max_length,
    )
)

Save the model checkpoint to disk

# Save the model checkpoint to disk.
torch.save(gpt.state_dict(), "review_classifier.pth")
# Load the model checkpoint from disk.
model_state_dict = torch.load("review_classifier.pth", map_location=device)
gpt.load_state_dict(model_state_dict)
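After restoring the weights, switch the model back to evaluation mode before classifying new messages (classify_review also does this internally, so this is just a precaution):

# Disable dropout and other training-only behavior before inference.
gpt.eval()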