# Install import-ipynb for importing ipynb files.
# %pip install import-ipynb
Fine-tuning for classification
This notebook explores the fine-tuning process of LLMs with the purpose of creating a classification model based on Sebastian Raschka’s book (Chapter 6). In particular, it discusses the following:
- Introducing different LLM fine-tuning approaches
- Preparing a dataset for text classification
- Modifying a pretrained LLM for fine-tuning
- Fine-tuning an LLM to identify spam messages
- Evaluating the accuracy of a fine-tuned LLM classifier
- Using a fine-tuned LLM to classify new data
Instruction fine-tuning
- Instruction-tuned models can typically handle a broader range of tasks
- More general approach that can handle multiple tasks
- Best suited for models that need to handle a variety of tasks based on complex user instructions
- These models improve flexibility and interaction quality
- Instruction fine-tuning requires larger datasets and greater computational resources
Classification fine-tuning
- Ideal for projects requiring precise categorization into predefined classes (e.g. sentiment analysis or spam detection)
- Specialized approach targeted at outputting a specific set of labels
- The model is restricted to only the labels encountered during training
- Requires less data and compute power
Acknowledgment
All concepts, architectures, and implementation approaches are credited to Sebastian Raschka’s work. This repository serves as my personal implementation and notes while working through the book’s content.
Resources
from typing import Optional, Tuple
import urllib.request
import zipfile
import os
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import tiktoken
from tqdm.notebook import tqdm
import pandas as pd
# Import previous chapter dependencies.
# See https://stackoverflow.com/questions/44116194/import-a-function-from-another-ipynb-file
# NOTE: Importing these functions seems to run the entire cell the symbol is defined in, which would
# suggest that symbols should be defined in separate cells from the test code.
import import_ipynb
from gpt_download import download_and_load_gpt2
from chapter_02_dataset_creation import create_dataloader_v1
from chapter_04_gpt_from_scratch import (
GPTConfig,
GPTModel,
generate_text_simple,
)
# NOTE: Importing another ipynb file basically runs the entire imported notebook.
from chapter_05_pretraining_on_unlabeled_data import (
generate,
token_ids_to_text,
text_to_token_ids,
)
# Define the base config.
GPT_CONFIG_124M = GPTConfig(
    vocab_size=50257,  # as used by the BPE tokenizer for GPT-2.
    context_length=1024,
    emb_dim=768,
    n_heads=12,
    n_layers=12,
    dropout_rate=0.0,  # disable dropout for inference
    qkv_bias=False,
)
# Determine the device to run the model on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Stage 1: Preparing the dataset
This section covers stage 1 of the three-stage process from the book: preparing the dataset.
Download the dataset
= "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
url = Path("data/sms_spam_collection.zip")
zip_path = Path("data/sms_spam_collection")
extracted_path = extracted_path / "SMSSpamCollection.tsv"
data_file_path
def download_and_unzip_spam_data(
    url: str, zip_path: Path, extracted_path: Path, data_file_path: Path
):
    """Download and unzip the spam data from the UCI repository.

    Args:
        url: The URL of the zip file.
        zip_path: The path to save the zip file.
        extracted_path: The path to save the extracted files.
        data_file_path: The path to save the data file.
    """
    # Check if the file already exists.
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

    # Download the zip file.
    with urllib.request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())

    # Extract the zip file.
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # Add a .tsv extension to the file (tab-separated values).
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")

download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
Load dataset
# Load data into a pandas DataFrame.
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
# Show the label count.
print(df["Label"].value_counts())
# Show a few examples.
df.head()
Balancing the dataset
# Create a balanced dataset by undersampling the majority class.
def create_balanced_dataset(df: pd.DataFrame) -> pd.DataFrame:
    """Create a balanced dataset by undersampling the majority class.

    NOTE: This function can quite significantly reduce the size of the dataset.

    Args:
        df: The input DataFrame.

    Returns:
        A balanced DataFrame.
    """
    num_spam = len(df[df["Label"] == "spam"])
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
    return balanced_df.reset_index(drop=True)
# Create a balanced dataset.
balanced_df = create_balanced_dataset(df)
print(balanced_df["Label"].value_counts())

# Convert string labels to integers.
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
balanced_df
Splitting the dataset
def random_split(
    df: pd.DataFrame, train_frac: float, validation_frac: float
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split a DataFrame into train, validation, and test sets.

    NOTE: The size of the test set is implied to be the remainder of the train and validation
          fractions (all fractions should add up to 1).

    Args:
        df: The input DataFrame.
        train_frac: The fraction of the dataset to use for training.
        validation_frac: The fraction of the dataset to use for validation.

    Returns:
        A tuple of DataFrames for the train, validation, and test sets.
    """
    # Shuffle the entire DataFrame.
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    # Calculate split indices (for train and validation explicitly).
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    # Split the DataFrame.
    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]

    return train_df, validation_df, test_df
# Test size is implied to be 0.2 as the remainder.
train_df, validation_df, test_df = random_split(
    df=balanced_df, train_frac=0.7, validation_frac=0.1
)
print(f"Train set size: {len(train_df)}")
print(f"Validation set size: {len(validation_df)}")
print(f"Test set size: {len(test_df)}")
# Save the DataFrames to CSV files.
/ "train.csv", index=None)
train_df.to_csv(extracted_path / "validation.csv", index=None)
validation_df.to_csv(extracted_path / "test.csv", index=None) test_df.to_csv(extracted_path
Creating the datasets
Previously, we utilized a sliding window technique to generate uniformly sized text chunks, which we then grouped into batches for more efficient model training. Each chunk functioned as an individual training instance. However, we are now working with a spam dataset that contains text messages of varying lengths. To batch these messages as we did with the text chunks, we have two primary options:
- Truncate all messages to the length of the shortest message in the dataset or batch.
- Pad all messages to the length of the longest message in the dataset or batch.
The first option is computationally cheaper, but it may result in significant information loss if shorter messages are much smaller than the average or longest messages, potentially reducing model performance. So, we opt for the second option, which preserves the entire content of all messages.
To implement batching, where all messages are padded to the length of the longest message in the dataset, we add padding tokens to all shorter messages. For this purpose, we use “<|endoftext|>” as the padding token. However, instead of appending the string “<|endoftext|>” to each text message directly, we add the token ID corresponding to “<|endoftext|>” to the encoded text messages.
As an example of what a training batch looks like: a single training batch consists of eight text messages represented as token IDs, each padded to 120 token IDs. A class label array stores the eight class labels corresponding to the text messages, which can be either 0 (“not spam”) or 1 (“spam”).
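The minimal sketch below (my own illustration, not the book's code; the example messages and variable names are made up) shows the padding idea in isolation: tokenize each message and right-pad the shorter ones with token ID 50256 until they all have the same length.
# Minimal padding sketch (illustrative only; message strings are made up).
import tiktoken

sketch_tokenizer = tiktoken.get_encoding("gpt2")
pad_id = 50256  # token ID of "<|endoftext|>"
messages = ["WINNER!! Claim your free prize now", "Ok, see you later"]
encoded = [sketch_tokenizer.encode(msg) for msg in messages]
max_len = max(len(e) for e in encoded)
padded = [e + [pad_id] * (max_len - len(e)) for e in encoded]
print([len(p) for p in padded])  # both messages now have the same length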
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")
print(tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))
# Creating a Dataset class.
class SpamDataset(Dataset):
    """Dataset class for the spam dataset."""

    def __init__(
        self,
        csv_file: Path,
        tokenizer: tiktoken.Encoding,
        max_length: Optional[int] = None,
        pad_token_id: int = 50256,
    ):
        """Initializes the SpamDataset class.

        Args:
            csv_file: The path to the CSV file containing the data.
            tokenizer: The tokenizer to use.
            max_length: The maximum length of the encoded texts.
            pad_token_id: The ID of the padding token.
        """
        # Load the data from the CSV file.
        self.data = pd.read_csv(csv_file)

        # Pretokenize all texts.
        self.encoded_texts = [tokenizer.encode(text) for text in self.data["Text"]]

        if max_length is None:
            # If no maximum length is provided, use the longest encoded text.
            self.max_length = self._longest_encoded_length()
        else:
            # Truncate sequences if they are longer than max_length.
            self.max_length = max_length
            self.encoded_texts = [
                encoded_text[: self.max_length] for encoded_text in self.encoded_texts
            ]

        # Pad sequences to the longest sequence.
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long),
        )

    def __len__(self) -> int:
        return len(self.data)

    def _longest_encoded_length(self) -> int:
        """Determine the longest encoded text length."""
        return max(len(encoded_text) for encoded_text in self.encoded_texts)
# Load the training dataset.
train_dataset = SpamDataset(
    csv_file=extracted_path / "train.csv", max_length=None, tokenizer=tokenizer
)
# Load the validation and test sets (optionally, the max length could be limited to the same value as the training set).
# NOTE: Importantly, any validation and test set samples exceeding the length of the longest
# training example are truncated using encoded_text[:self.max_length] in the SpamDataset code
# we defined earlier. This truncation is optional; you can set max_length=None for both
# validation and test sets, provided there are no sequences exceeding 1,024 tokens in these
# sets.
val_dataset = SpamDataset(
    csv_file=extracted_path / "validation.csv",
    max_length=None,
    tokenizer=tokenizer,
)
test_dataset = SpamDataset(
    csv_file=extracted_path / "test.csv",
    max_length=None,
    tokenizer=tokenizer,
)
# Show the maximum length of the encoded texts.
print(f"Maximum length of the encoded texts: {train_dataset.max_length}")
print(f"Maximum length of the encoded texts: {val_dataset.max_length}")
print(f"Maximum length of the encoded texts: {test_dataset.max_length}")
# Verify that the maximum length does not exceed the context length.
assert train_dataset.max_length <= GPT_CONFIG_124M.context_length
assert val_dataset.max_length <= GPT_CONFIG_124M.context_length
assert test_dataset.max_length <= GPT_CONFIG_124M.context_length
Creating the data loaders
# This num_workers setting ensures compatibility with most computers.
num_workers = 0
batch_size = 8
torch.manual_seed(123)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)
# Show the number of batches in each data loader.
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")
# Show the first batch of the training set.
# NOTE: As we can see, the input batches consist of eight training examples with 120 tokens each,
# as expected. The label tensor stores the class labels corresponding to the eight training
# examples.
print("\nFirst training batch:")
for input_batch, target_batch in train_loader:
print("Input batch dimensions:", input_batch.shape)
print("Label batch dimensions", target_batch.shape)
break
Stage 2: Model Setup
Initializing a model with pretrained weights
import dataclasses
# Load the base config.
GPT_CONFIG_124M = GPTConfig(
    vocab_size=50257,  # as used by the BPE tokenizer for GPT-2.
    context_length=1024,
    emb_dim=768,
    n_heads=12,
    n_layers=12,
    dropout_rate=0.0,  # disable dropout for inference
    qkv_bias=False,
)
# Update the model configuration to conform to the model size.
model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
# Instantiate a base config.
tmp_config = dataclasses.asdict(GPT_CONFIG_124M)

# Load the overlay parameters.
model_name = "gpt2-small (124M)"
tmp_config.update(model_configs[model_name])

# Update the context length to match OpenAI's GPT-2 models.
tmp_config.update({"context_length": 1024})
# OpenAI used bias vectors in the multi-head attention module’s linear layers to implement the
# query, key, and value matrix computations. Bias vectors are not commonly used in LLMs anymore as
# they don’t improve the modeling performance and are thus unnecessary. However, since we are
# working with pretrained weights, we need to match the settings for consistency and enable these
# bias vectors.
"qkv_bias": True})
tmp_config.update({
# Instantiate the new configuration.
NEW_CONFIG = GPTConfig(**tmp_config)

# Initialize the model with the new configuration.
gpt = GPTModel(NEW_CONFIG)
gpt.eval()
# NOTE: This code is copied from chapter_05_pretraining_on_unlabeled_data.ipynb because the import
# from load_weights_into_gpt.py is not working.
def assign(left, right):
    """Safely assign the right weight tensor to the left layer.

    Checks whether two tensors or arrays (left and right) have the same dimensions or shape and
    returns the right tensor as trainable PyTorch parameters.
    """
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
    return torch.nn.Parameter(torch.tensor(right))
def load_weights_into_gpt(gpt: GPTModel, params: dict):
    # Set the model’s positional and token embedding weights to those specified in params.
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])
    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])

    # Iterate over each transformer block in the model.
    for b in range(len(params["blocks"])):
        # The np.split function is used to divide the attention weights and biases into three
        # equal parts for the query, key, and value components.
        q_w, k_w, v_w = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1
        )
        gpt.trf_blocks[b].mha.W_q.weight = assign(
            gpt.trf_blocks[b].mha.W_q.weight, q_w.T
        )
        gpt.trf_blocks[b].mha.W_k.weight = assign(
            gpt.trf_blocks[b].mha.W_k.weight, k_w.T
        )
        gpt.trf_blocks[b].mha.W_v.weight = assign(
            gpt.trf_blocks[b].mha.W_v.weight, v_w.T
        )

        q_b, k_b, v_b = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1
        )
        gpt.trf_blocks[b].mha.W_q.bias = assign(gpt.trf_blocks[b].mha.W_q.bias, q_b)
        gpt.trf_blocks[b].mha.W_k.bias = assign(gpt.trf_blocks[b].mha.W_k.bias, k_b)
        gpt.trf_blocks[b].mha.W_v.bias = assign(gpt.trf_blocks[b].mha.W_v.bias, v_b)

        gpt.trf_blocks[b].mha.out_proj.weight = assign(
            gpt.trf_blocks[b].mha.out_proj.weight,
            params["blocks"][b]["attn"]["c_proj"]["w"].T,
        )
        gpt.trf_blocks[b].mha.out_proj.bias = assign(
            gpt.trf_blocks[b].mha.out_proj.bias,
            params["blocks"][b]["attn"]["c_proj"]["b"],
        )

        gpt.trf_blocks[b].ff.layers[0].weight = assign(
            gpt.trf_blocks[b].ff.layers[0].weight,
            params["blocks"][b]["mlp"]["c_fc"]["w"].T,
        )
        gpt.trf_blocks[b].ff.layers[0].bias = assign(
            gpt.trf_blocks[b].ff.layers[0].bias, params["blocks"][b]["mlp"]["c_fc"]["b"]
        )
        gpt.trf_blocks[b].ff.layers[2].weight = assign(
            gpt.trf_blocks[b].ff.layers[2].weight,
            params["blocks"][b]["mlp"]["c_proj"]["w"].T,
        )
        gpt.trf_blocks[b].ff.layers[2].bias = assign(
            gpt.trf_blocks[b].ff.layers[2].bias,
            params["blocks"][b]["mlp"]["c_proj"]["b"],
        )

        gpt.trf_blocks[b].pre_attention_norm.scale = assign(
            gpt.trf_blocks[b].pre_attention_norm.scale, params["blocks"][b]["ln_1"]["g"]
        )
        gpt.trf_blocks[b].pre_attention_norm.shift = assign(
            gpt.trf_blocks[b].pre_attention_norm.shift, params["blocks"][b]["ln_1"]["b"]
        )
        gpt.trf_blocks[b].pre_ff_norm.scale = assign(
            gpt.trf_blocks[b].pre_ff_norm.scale, params["blocks"][b]["ln_2"]["g"]
        )
        gpt.trf_blocks[b].pre_ff_norm.shift = assign(
            gpt.trf_blocks[b].pre_ff_norm.shift, params["blocks"][b]["ln_2"]["b"]
        )

    gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
    gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])

    # The original GPT-2 model by OpenAI reused the token embedding weights in the output layer
    # to reduce the total number of parameters, which is a concept known as weight tying.
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
# Download the GPT-2 weights.
settings, params = download_and_load_gpt2(model_size="124M", models_dir="gpt2")

# Load the weights into the model.
load_weights_into_gpt(gpt, params)
gpt.to(device)
# Test the model to verify that it can generate coherent text.
= "Every effort moves you"
text_1 = generate_text_simple(
token_ids =gpt.to(device),
model=text_to_token_ids(text_1, tokenizer).to(device),
idx=15,
max_new_tokens=NEW_CONFIG.context_length,
context_size
)print(token_ids_to_text(token_ids, tokenizer))
# Check if the model is already capable of classifying spam and ham messages via instruction
# examples.
text_2 = (
    "Is the following text 'spam'? Answer with 'yes' or 'no':"
    " 'You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award.'"
)
token_ids = generate_text_simple(
    model=gpt.to(device),
    idx=text_to_token_ids(text_2, tokenizer).to(device),
    max_new_tokens=23,
    context_size=NEW_CONFIG.context_length,
)
print(token_ids_to_text(token_ids, tokenizer))
Modify model for fine-tuning (adding a classification head)
We adapt the GPT model for spam classification by altering its architecture: initially, the model’s linear output layer maps the 768 hidden units to a vocabulary of 50,257 tokens. To detect spam, we replace this layer with a new output layer that maps the same 768 hidden units to just two classes, representing “spam” and “not spam.”
Fine-tuning selected layers vs. all layers
Since we start with a pretrained model, it’s not necessary to fine-tune all model layers. In neural network-based language models, the lower layers generally capture basic language structures and semantics applicable across a wide range of tasks and datasets. So, fine-tuning only the last layers (i.e., layers near the output), which are more specific to nuanced linguistic patterns and task-specific features, is often sufficient to adapt the model to new tasks. A nice side effect is that it is computationally more efficient to fine-tune only a small number of layers.
# Prepare the model for fine-tuning.
# 1. Freeze all parameters in the model.
for param in gpt.parameters():
    param.requires_grad = False

# 2. Replace the final linear layer with a new one for the two classes.
torch.manual_seed(123)
num_classes = 2
gpt.out_head = nn.Linear(GPT_CONFIG_124M.emb_dim, num_classes)

# Mark additional layers as trainable, in particular the last transformer block as well as the
# final layer norm.
for param in gpt.trf_blocks[-1].parameters():
    param.requires_grad = True
for param in gpt.final_norm.parameters():
    param.requires_grad = True
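As a quick sanity check (my own addition, not from the book), we can count how many parameters will actually receive gradient updates after freezing; only the new classification head, the last transformer block, and the final layer norm should contribute.
# Count trainable vs. total parameters after freezing (sanity check, not from the book).
total_params = sum(p.numel() for p in gpt.parameters())
trainable_params = sum(p.numel() for p in gpt.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")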
# Try running the model with a random input to see that it is working.
= "Do you have time"
inputs_str = tokenizer.encode(inputs_str)
inputs = torch.tensor(inputs).unsqueeze(0)
inputs print("Inputs:", inputs_str)
print("Inputs dimensions:", inputs.shape) # B x T, i.e. batch size x sequence length
with torch.no_grad():
= gpt.to(device)(inputs.to(device))
outputs
# NOTE: The output shape is B x T x 2, i.e. batch size x sequence length x number of classes.
# The model produces logits for each class and for each token in the input sequence.
# NOTE: We are interested in fine-tuning this model to return a class label indicating whether a
# model input is “spam” or “not spam.” We don’t need to fine-tune all four output rows;
# instead, we can focus on a single output token. In particular, we will focus on the last
# row corresponding to the last output token.
print("Outputs:\n", outputs)
print(
"Outputs dimensions:", outputs.shape
# B x T x 2, i.e. batch size x sequence length x number of classes
) print("Last output token:", outputs[:, -1, :])
Selecting the right output for fine-tuning
To understand why we are particularly interested in the last output token only, let’s take a look at the attention mechanism. We have already explored the attention mechanism, which establishes a relationship between each input token and every other input token, and the concept of a causal attention mask. This mask restricts a token’s focus to its current position and those before it, ensuring that each token can only be influenced by itself and the preceding tokens (as shown below).
The empty cells indicate masked positions due to the causal attention mask, preventing tokens from attending to future tokens. The values in the cells represent attention scores; the last token, time, is the only one that computes attention scores for all preceding tokens.
The last token in a sequence accumulates the most information since it is the only token with access to data from all the previous tokens. Therefore, in our spam classification task, we focus on this last token during the fine-tuning process.
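To make this concrete, here is a small illustrative sketch (my own toy example with dummy values, not the model’s actual attention weights): with a causal mask applied, only the final row of the attention-weight matrix, i.e. the last token, distributes attention over every position in the sequence.
# Toy causal-mask demo (illustrative values only, not the model's real attention scores).
seq_len = 4  # e.g. the four tokens of "Do you have time"
dummy_scores = torch.ones(seq_len, seq_len)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
masked_scores = dummy_scores.masked_fill(causal_mask, float("-inf"))
attn_weights = torch.softmax(masked_scores, dim=-1)
print(attn_weights)  # only the last row assigns non-zero weight to all four positions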
Evaluation utilities
Similar to next-token prediction, we apply softmax to the output logits to obtain probabilities for each class (spam or not spam), as shown below.
# Compute the probabilities for the last output token.
probas = torch.softmax(outputs[:, -1, :], dim=-1)

# Compute the class label.
# NOTE: {"ham": 0, "spam": 1}
label = torch.argmax(probas)
print("Inputs:", inputs_str)
print("Class label:", label.item())
# A utility function for computing classification accuracy for a data loader.
def calc_accuracy_loader(
    data_loader: DataLoader,
    model: GPTModel,
    device: torch.device,
    num_batches: Optional[int] = None,
) -> float:
    """Compute the accuracy of a model on a data loader.

    Args:
        data_loader: The data loader to compute the accuracy on.
        model: The model to compute the accuracy on.
        device: The device to compute the accuracy on.
        num_batches: The number of batches to compute the accuracy on. Defaults to None.

    Returns:
        The accuracy of the model on the data loader.
    """
    # Set the model to evaluation mode (to disable dropout, etc.).
    model.eval()

    # Initialize the number of correct predictions and the number of examples.
    correct_predictions, num_examples = 0, 0

    # If the number of batches is not specified, use all batches.
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))

    # Iterate over the data loader.
    for i, (input_batch, target_batch) in enumerate(data_loader):
        # If the number of batches has not been reached, compute the accuracy.
        if i < num_batches:
            input_batch = input_batch.to(device)
            target_batch = target_batch.to(device)

            # Compute the logits for the last output token.
            with torch.no_grad():
                # NOTE: The output shape is B x T x 2, i.e. batch size x sequence length x number
                #       of classes. Here, we are only interested in the logits for the last output
                #       token.
                logits = model(input_batch)[:, -1, :]

            # Compute the predicted labels.
            # NOTE: dim=-1 computes the argmax over the classes.
            predicted_labels = torch.argmax(logits, dim=-1)

            # Update the number of examples and the number of correct predictions.
            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break

    return correct_predictions / num_examples
# Compute baseline accuracy for the not yet fine-tuned model.
# Move the model to the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpt.to(device)
torch.manual_seed(123)

# Compute the accuracy for the training, validation, and test sets.
train_accuracy = calc_accuracy_loader(train_loader, gpt, device, num_batches=10)
val_accuracy = calc_accuracy_loader(val_loader, gpt, device, num_batches=10)
test_accuracy = calc_accuracy_loader(test_loader, gpt, device, num_batches=10)

print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")
Define the loss function
Before we begin fine-tuning the model, we must define the loss function we will optimize during training. Our objective is to maximize the model’s spam classification accuracy, which means the code above should output the correct class labels: 0 for non-spam and 1 for spam. Because classification accuracy is not a differentiable function, we use cross-entropy loss as a proxy to maximize accuracy.
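The toy example below (my own, with made-up logits) illustrates why this proxy works: accuracy only changes when an argmax flips, whereas cross entropy decreases smoothly as the correct class’s logit grows, so it provides useful gradients at every training step.
# Toy comparison of accuracy (not differentiable) and cross entropy (differentiable proxy).
toy_logits = torch.tensor([[2.0, 0.5], [0.1, 1.5]])  # two examples, two classes (ham/spam)
toy_targets = torch.tensor([0, 1])
toy_accuracy = (torch.argmax(toy_logits, dim=-1) == toy_targets).float().mean()
toy_loss = torch.nn.functional.cross_entropy(toy_logits, toy_targets)
print(f"Accuracy: {toy_accuracy.item():.2f}, Cross-entropy loss: {toy_loss.item():.4f}")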
def calc_loss_batch(
    input_batch: torch.Tensor,
    target_batch: torch.Tensor,
    model: GPTModel,
    device: torch.device,
) -> torch.Tensor:
    """Compute the loss for a batch of inputs and targets.

    Args:
        input_batch: The input batch.
        target_batch: The target batch.
        model: The model.
        device: The device.

    Returns:
        The loss for the batch.
    """
    # Move the input and target batches to the device.
    input_batch = input_batch.to(device)
    target_batch = target_batch.to(device)

    # Compute the logits for the last output token.
    logits = model(input_batch)[:, -1, :]

    # Compute the loss (only for the last output token).
    loss = torch.nn.functional.cross_entropy(logits, target_batch)

    return loss
def calc_loss_loader(
    data_loader: DataLoader,
    model: GPTModel,
    device: torch.device,
    num_batches: Optional[int] = None,
) -> float:
    """Compute the loss for a data loader.

    Args:
        data_loader: The data loader.
        model: The model.
        device: The device.
        num_batches: The number of batches to compute the loss on.

    Returns:
        The average loss over the evaluated batches.
    """
    # Initialize the total loss.
    total_loss = 0.0

    # If the data loader is empty, return NaN.
    if len(data_loader) == 0:
        return float("nan")
    # If the number of batches is not specified, use all batches.
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))

    # Iterate over the data loader.
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break

    return total_loss / num_batches
# Compute the loss for the training, validation, and test sets.
# Disable gradient tracking for efficiency because we are not training yet.
with torch.no_grad():
    train_loss = calc_loss_loader(train_loader, gpt, device, num_batches=5)
    val_loss = calc_loss_loader(val_loader, gpt, device, num_batches=5)
    test_loss = calc_loss_loader(test_loader, gpt, device, num_batches=5)
print(f"Training loss: {train_loss:.3f}")
print(f"Validation loss: {val_loss:.3f}")
print(f"Test loss: {test_loss:.3f}")
Stage 3: Model fine-tuning and usage
def evaluate_model(
    model: GPTModel,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    eval_iter: int,
) -> tuple[float, float]:
    """Evaluate a model on the training and validation sets.

    Args:
        model: The model to evaluate.
        train_loader: The training data loader.
        val_loader: The validation data loader.
        device: The device to evaluate the model on.
        eval_iter: The number of batches to use per evaluation.

    Returns:
        train_loss: The training loss.
        val_loss: The validation loss.
    """
    # Set the model to evaluation mode.
    model.eval()

    # Compute the loss for the training and validation sets.
    with torch.no_grad():
        train_loss = calc_loss_loader(
            train_loader, model, device, num_batches=eval_iter
        )
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)

    # Reset the model to training mode.
    model.train()

    return train_loss, val_loss
def train_classifier_simple(
    model: GPTModel,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    num_epochs: int,
    eval_freq: int,
    eval_iter: int,
) -> tuple[list[float], list[float], list[float], list[float], int]:
    """Train a classifier model.

    Args:
        model: The model to train.
        train_loader: The training data loader.
        val_loader: The validation data loader.
        optimizer: The optimizer.
        device: The device to train the model on.
        num_epochs: The number of epochs to train the model.
        eval_freq: The number of steps between evaluations.
        eval_iter: The number of batches to use per evaluation.

    Returns:
        train_losses: The training losses.
        val_losses: The validation losses.
        train_accs: The training accuracies.
        val_accs: The validation accuracies.
        examples_seen: The number of examples seen.
    """
    # Initialize the lists for the training and validation losses and accuracies.
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # Main training loop.
    for epoch in range(num_epochs):
        # Set the model to training mode (to enable dropout, etc.).
        model.train()

        for input_batch, target_batch in train_loader:
            # Reset loss gradients from the previous batch iteration.
            optimizer.zero_grad()

            # Compute the loss for the current batch.
            loss = calc_loss_batch(input_batch, target_batch, model, device)

            # Calculate loss gradients.
            loss.backward()

            # Update the model parameters using the computed loss gradients.
            optimizer.step()

            # Update the number of examples seen and the global step.
            examples_seen += input_batch.shape[0]
            global_step += 1

            # Optional evaluation step.
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model=model,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    device=device,
                    eval_iter=eval_iter,
                )
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(
                    f"Ep {epoch+1} (Step {global_step:06d}): "
                    f"Train loss {train_loss:.3f}, "
                    f"Val loss {val_loss:.3f}"
                )

        # Calculate accuracy after each epoch.
        train_accuracy = calc_accuracy_loader(
            train_loader, model, device, num_batches=eval_iter
        )
        val_accuracy = calc_accuracy_loader(
            val_loader, model, device, num_batches=eval_iter
        )
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen
import time
# Set the random seed and track training time.
start_time = time.time()
torch.manual_seed(123)

# Initialize the optimizer.
optimizer = torch.optim.AdamW(gpt.parameters(), lr=5e-5, weight_decay=0.1)
num_epochs = 5

# Train the model.
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
    model=gpt,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device=device,
    num_epochs=num_epochs,
    eval_freq=50,
    eval_iter=5,
)

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")
Loss visualization
The plot below shows the model’s training and validation loss over the five training epochs. Both the training loss (solid line) and the validation loss (dashed line) decline sharply in the first epoch and gradually stabilize toward the fifth epoch. This pattern indicates good learning progress and suggests that the model learned from the training data while generalizing well to the unseen validation data.
# Plot the training and validation losses.
import matplotlib.pyplot as plt


def plot_values(
    epochs_seen: torch.Tensor,
    examples_seen: torch.Tensor,
    train_values: list[float],
    val_values: list[float],
    label: str = "loss",
):
    """Plot the training and validation values.

    Args:
        epochs_seen: The number of epochs seen.
        examples_seen: The number of examples seen.
        train_values: The training values.
        val_values: The validation values.
        label: The label for the plot.
    """
    # Create the plot.
    fig, ax1 = plt.subplots(figsize=(8, 6))

    # Plot the training and validation values against the epochs.
    ax1.plot(epochs_seen, train_values, label=f"Training {label}")
    ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")

    # Set the axis labels and the legend.
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel(label.capitalize())
    ax1.legend()

    # Create a second x-axis for examples seen.
    ax2 = ax1.twiny()

    # Invisible plot for aligning ticks.
    ax2.plot(examples_seen, train_values, alpha=0)
    ax2.set_xlabel("Examples seen")

    # Adjust layout to make room.
    fig.tight_layout()

    # Save the plot.
    plt.savefig(f"{label}-plot.pdf")

    # Show the plot.
    plt.show()


# Create the epochs and examples seen tensors.
epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))

# Plot the values.
plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses)
Classification accuracy plot
Both the training accuracy (solid line) and the validation accuracy (dashed line) increase substantially in the early epochs and then plateau, achieving almost perfect accuracy scores of 1.0. The close proximity of the two lines throughout the epochs suggests that the model does not overfit the training data very much.
epochs_tensor = torch.linspace(0, num_epochs, len(train_accs))
examples_seen_tensor = torch.linspace(0, examples_seen, len(train_accs))

plot_values(epochs_tensor, examples_seen_tensor, train_accs, val_accs, label="accuracy")
Performance metrics on all data sets
The training and test set performances are almost identical. The slight discrepancy between the training and test set accuracies suggests minimal overfitting of the training data. Typically, the validation set accuracy is somewhat higher than the test set accuracy because model development often involves tuning hyperparameters to perform well on the validation set, which might not generalize as effectively to the test set. This situation is common, but the gap could potentially be minimized by adjusting the model’s settings, such as increasing the dropout rate (dropout_rate in our config) or the weight_decay parameter in the optimizer configuration.
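A hedged sketch of the kind of adjustment meant here (the values below are illustrative guesses, not tuned settings): raise weight_decay in the optimizer, and/or rebuild the model with a non-zero dropout rate before fine-tuning.
# Illustrative regularization tweaks (values are guesses, not recommendations).
regularized_optimizer = torch.optim.AdamW(gpt.parameters(), lr=5e-5, weight_decay=0.3)
# The model could also be rebuilt with e.g. dropout_rate=0.1 in GPTConfig before fine-tuning.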
train_accuracy = calc_accuracy_loader(train_loader, gpt, device)
val_accuracy = calc_accuracy_loader(val_loader, gpt, device)
test_accuracy = calc_accuracy_loader(test_loader, gpt, device)
print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")
Using the model for classification
def classify_review(
    text: str,
    model: GPTModel,
    tokenizer: tiktoken.Encoding,
    device: torch.device,
    max_length: Optional[int] = None,
    pad_token_id: int = 50256,
) -> str:
    """Classify a review using a fine-tuned GPT model.

    Args:
        text: The review to classify.
        model: The fine-tuned GPT model.
        tokenizer: The tokenizer.
        device: The device to classify the review on.
        max_length: The maximum length of the review.
        pad_token_id: The padding token ID (defaults to the end-of-text token).

    Returns:
        The predicted label.
    """
    model.eval()

    # Prepare the inputs to the model.
    input_ids = tokenizer.encode(text)

    # Determine the maximum supported context length from the positional embedding table,
    # whose weight has shape (context_length, emb_dim).
    supported_context_length = model.pos_emb.weight.shape[0]

    # Determine the maximum sequence length.
    max_length = max_length if max_length is not None else supported_context_length

    # Truncate sequences if they are too long.
    input_ids = input_ids[: min(max_length, supported_context_length)]

    # Pad sequences to the maximum sequence length.
    input_ids += [pad_token_id] * (max_length - len(input_ids))

    # Add a batch dimension.
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)

    # Model inference without gradient tracking.
    with torch.no_grad():
        # Logits for the last output token.
        logits = model(input_tensor)[:, -1, :]

    predicted_label = torch.argmax(logits, dim=-1).item()

    # Return the predicted label.
    return "spam" if predicted_label == 1 else "not spam"
# Try classifying some examples.
text_1 = (
    "You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award."
)
print(
    classify_review(
        text=text_1,
        model=gpt,
        tokenizer=tokenizer,
        device=device,
        max_length=train_dataset.max_length,
    )
)
text_2 = (
    "Hey, just wanted to check if we're still on for dinner tonight? Let me know!"
)
print(
    classify_review(
        text=text_2,
        model=gpt,
        tokenizer=tokenizer,
        device=device,
        max_length=train_dataset.max_length,
    )
)
Save the model checkpoint to disk
# Save the model checkpoint to disk.
"review_classifier.pth") torch.save(gpt.state_dict(),
# Load the model checkpoint from disk.
= torch.load("review_classifier.pth", map_location=device)
model_state_dict gpt.load_state_dict(model_state_dict)
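For reference, here is a sketch (my own addition, not from the book) of how the checkpoint could be restored in a fresh session: the architecture must be rebuilt with the same two-class output head before calling load_state_dict, otherwise the out_head shapes will not match.
# Rebuild the architecture with the classification head, then load the saved weights.
restored_model = GPTModel(NEW_CONFIG)
restored_model.out_head = nn.Linear(NEW_CONFIG.emb_dim, 2)
restored_model.load_state_dict(torch.load("review_classifier.pth", map_location=device))
restored_model.to(device)
restored_model.eval()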