
Debugging PyTorch Machine Learning Models: A Step-by-Step Guide
Introduction
Debugging a machine learning model entails inspecting its internal mechanisms, discovering possible errors, and fixing them. Although debugging is essential to ensure a model works correctly and efficiently, it is often challenging. Fortunately, this article is here to help by walking you through the steps to debug machine learning models written in Python using the PyTorch library.
To illustrate how to debug PyTorch machine learning models, we will consider a simple neural network model for classification, concretely for recognizing (classifying) handwritten digits from 0 to 9, using the well-known MNIST dataset.
Preparation
First, we ensure PyTorch and other necessary dependencies are installed and imported.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
Aided by PyTorch’s nn package for building neural network models, concretely via the nn.Module class, we will define a quite simple neural network architecture. Building a neural network in PyTorch involves establishing its architecture in the __init__ constructor method and overriding the forward method to define the activation functions and other calculations performed over the data as they pass through the layers of the network.
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)  # Flatten the input
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
The neural network we just built has two fully connected linear layers, with a ReLU (rectified linear unit) activation function in between. In the forward method, each 28×28 pixel handwritten digit image is first flattened into a vector of 784 features (one per pixel); the first linear layer then maps these 784 inputs to 128 hidden units. The output layer has 10 neurons, one for each possible classification output: remember, we are classifying images into one out of 10 possible classes.
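Before moving on, an optional sanity check (not part of the original walkthrough, and using hypothetical names like dummy_model and dummy_input) is to pass a random fake batch through a throwaway instance of the network and confirm the output shape matches expectations:
# Optional sanity check on the architecture, assuming SimpleNN is defined as above
dummy_model = SimpleNN()
dummy_input = torch.randn(4, 1, 28, 28)  # a fake batch of 4 grayscale 28x28 images
dummy_output = dummy_model(dummy_input)
print(dummy_output.shape)  # expected: torch.Size([4, 10])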
Next, we load the MNIST dataset. This is an easy endeavor, since PyTorch’s torchvision package provides it as one of its built-in sample datasets, so there is no need to obtain it from an external source. As part of the process of loading the data, we need to ensure it is stored as tensors, the data structure internally managed by PyTorch models.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
Next, we initialize the model by instantiating the class defined earlier, establish the optimization criterion or loss function that will guide the training process on the data, and choose the Adam optimizer, with a moderate learning rate of 0.001, to further steer this process.
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
Step-by-Step Debugging
Now, assuming we suspect something might be wrong with the model (it is not, just supposing!), let’s get into the core debugging steps. The first is simple: printing the model itself to ensure it is correctly defined.
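This amounts to a single call on the model instance created earlier:
# Print the model to inspect its layer structure
print(model)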
Output:
SimpleNN(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
That looked right. Next, let’s inspect the shape of the data (input images and output labels) by using this instruction:
for images, labels in train_loader:
    print("Input batch shape:", images.shape)
    print("Labels batch shape:", labels.shape)
    break
Output:
Input batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Since we earlier specified a batch size of 64, this also looks like it makes sense.
The next natural step in debugging is checking that the model produces outputs of the expected shape without raising errors. This process is called forward pass debugging, and it can be performed by using the train_loader instance where we loaded the dataset earlier, as follows:
images, labels = next(iter(train_loader))
outputs = model(images)
print("Output shape:", outputs.shape)
If no errors are raised, the output per data batch should look like:
Output shape: torch.Size([64, 10])
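To turn this visual check into an automatic one, a small assertion could be added (a hypothetical extra step, not part of the original walkthrough), reusing the images and outputs variables from the snippet above:
# Hypothetical automated check: fail loudly if the output shape is unexpected
assert outputs.shape == (images.shape[0], 10), f"Unexpected output shape: {outputs.shape}"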
A common cause of a malfunctioning machine learning model is an unstable training process, in which case training loss values or model parameters often become NaN or infinite. A way to check for this is the code below, which prints no warning message if no such problem appears to exist.
def check_nan(tensor, name):
    if torch.isnan(tensor).any():
        print(f"Warning: NaN detected in {name}")
    if torch.isinf(tensor).any():
        print(f"Warning: Inf detected in {name}")

for param in model.parameters():
    check_nan(param, "Model Parameter")
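The same helper can also be pointed at intermediate tensors, such as the model outputs or the loss, to narrow down where instability first appears. A minimal sketch, reusing the objects defined above:
# Sketch: check one batch's outputs and loss for NaN/Inf using check_nan
images, labels = next(iter(train_loader))
outputs = model(images)
loss = criterion(outputs, labels)
check_nan(outputs, "Model outputs")
check_nan(loss, "Loss")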
Finally, for more in-depth debugging, here’s a debug training loop that monitors loss and gradients during the training process.
for epoch in range(1):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        for name, param in model.named_parameters():
            if param.grad is not None:
                print(f"Gradient for {name}: {param.grad.norm()}")

        optimizer.step()
        print("Loss:", loss.item())
        break
The steps involved here are:
- Clearing old gradients to prevent accumulation across batches
- Applying a forward pass to get the model’s predictions
- Computing the loss, given by the deviation between predictions and actual (ground-truth) labels
- Performing the backward pass: computing gradients via backpropagation for the later adjustment of the neural network weights
- Printing the gradient norms per layer to identify issues like exploding or vanishing gradients (see the sketch after this list)
- Updating the weights or parameters by calling optimizer.step()
- Monitoring the loss: the final print instruction helps track model performance over iterations
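If watching the printed norms by eye becomes impractical, the gradient check can be made programmatic. The sketch below is a hypothetical helper (the name flag_gradient_issues and the thresholds are assumptions, not from the original loop) that flags layers whose gradient norm looks suspiciously small or large:
# Hypothetical helper: flag possibly vanishing/exploding gradients after loss.backward()
def flag_gradient_issues(model, low=1e-6, high=1e3):
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        norm = param.grad.norm().item()
        if norm < low:
            print(f"Possible vanishing gradient in {name}: norm={norm:.2e}")
        elif norm > high:
            print(f"Possible exploding gradient in {name}: norm={norm:.2e}")
It would be called right after loss.backward(), replacing or complementing the per-layer print in the loop above.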
Wrapping Up
This article provided, through a neural network-based example, a set of steps and resources to consider when debugging machine learning models in PyTorch. Applying these debugging methods can sometimes be a model life-saver, helping identify issues that would otherwise be hard to spot.