The full guide to creating custom datasets and dataloaders for different models in PyTorch
Before you can build a machine learning model, you need to load your data into a dataset. Luckily, PyTorch has many commands to help with this entire process (if you are not familiar with PyTorch I recommend refreshing on the basics here).
PyTorch has good documentation to help with this process, but I have not found any comprehensive documentation or tutorials on custom datasets. I’ll start with basic premade datasets and then work my way up to creating datasets from scratch for different models!
Before we dive into code for different use cases, let’s understand the difference between a dataset and a dataloader. Generally, you first create your dataset and then create a dataloader. A dataset contains the features and labels from each data point that will be fed into the model. A dataloader is a PyTorch iterable that wraps a dataset and makes it easy to load data with added features such as batching, shuffling, and parallel loading.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
The most common arguments in the dataloader are batch_size, shuffle (usually only for the training data), num_workers (to multi-process loading the data), and pin_memory (to put the fetched data Tensors in pinned memory and enable faster data transfer to CUDA-enabled GPUs).
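The PyTorch documentation advises against returning CUDA tensors from worker processes when num_workers > 0, because sharing CUDA tensors across processes has many subtleties. Instead, set pin_memory = True and move each batch to the GPU inside the training loop.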
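To make the dataset/dataloader split concrete, here is a minimal sketch that wraps random tensors in `TensorDataset` (a built-in helper from `torch.utils.data`) and batches them with a `DataLoader`. The tensor shapes, batch size, and variable names are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 samples with 4 features each, plus binary labels.
features = torch.randn(10, 4)
labels = torch.randint(0, 2, (10,))

# The dataset pairs each feature row with its label.
dataset = TensorDataset(features, labels)

# The dataloader batches (and optionally shuffles) what the dataset returns.
loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,                          # usually only for training data
    num_workers=0,                         # 0 = load in the main process
    pin_memory=torch.cuda.is_available(),  # pinning only helps with a GPU
    drop_last=False,                       # keep the final, smaller batch
)

batch_sizes = [x.shape[0] for x, y in loader]
print(batch_sizes)  # [4, 4, 2]
```

With drop_last=True the final batch of 2 would be skipped, which can be useful when a model expects a fixed batch size.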
If you are using a premade dataset, whether it is downloaded or already stored locally, creating the dataset is extremely simple. I think PyTorch has good documentation on this, so I will be brief.
If you know the dataset is either from PyTorch or PyTorch-compatible, simply call the necessary imports and the dataset of choice:
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

data = datasets.CIFAR10('path', train=True, transform=ToTensor())
Each dataset will have unique arguments to pass into it (found here). In general, it will be the path the dataset is stored at, a boolean indicating if it needs to be downloaded or not (conveniently called download), whether it is training or testing, and if transforms need to be applied.
I mentioned at the end of the last section that transforms can be applied to a dataset, but what actually is a transform?
A transform is a way of manipulating data during preprocessing, most often an image. There are many different facets to transforms. The most common transform, ToTensor(), converts each PIL image or numpy array to a tensor (needed to input into any model). Other transforms built into torchvision (torchvision.transforms) include flipping, rotating, cropping, normalizing, and shifting images. These are typically used so the model can generalize better and doesn’t overfit to the training data. Data augmentations can also be used to artificially increase the size of the dataset if needed.
Beware that most torchvision transforms only accept Pillow images or tensors (not numpy arrays). To convert from numpy, either create a torch tensor or use the following:
from PIL import Image
# assume arr is a numpy array
# you may need to normalize and cast arr to np.uint8 depending on format
img = Image.fromarray(arr)
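The comment about normalizing and casting deserves a concrete case. Here is a minimal sketch assuming the array is a float RGB image with values in [0, 1]; other formats (for example 16-bit or already 0 to 255 floats) would need a different scaling step.

```python
import numpy as np
from PIL import Image

# Hypothetical input: a float RGB image with values in [0, 1].
arr = np.random.rand(32, 48, 3)          # height, width, channels

# Scale to [0, 255] and cast, since Image.fromarray expects uint8 for RGB.
arr_uint8 = (arr * 255).round().astype(np.uint8)

img = Image.fromarray(arr_uint8)
print(img.mode)   # RGB
print(img.size)   # (48, 32): PIL reports (width, height)
```

Note the axis order: numpy uses (height, width, channels), while PIL's `size` is (width, height).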
Transforms can be chained using torchvision.transforms.Compose, which applies them in the order listed. You can combine as many transforms as needed for the dataset. An example is shown below:
from torchvision import transforms

dataset_transform = transforms.Compose([
    transforms.RandomResizedCrop(256),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
Be sure to pass the saved transform as an argument into the dataset for it to be applied in the dataloader.
In most cases of developing your own model, you will need a custom dataset. A common use case would be transfer learning to apply your own dataset on a pretrained model.
There are 3 required parts to a PyTorch dataset class: initialization, length, and retrieving an element.
__init__: To initialize the dataset, pass in the raw and labeled data. The best practice is to pass in the raw image data and labeled data separately.
__len__: Return the length of the dataset. Before creating the dataset, the raw and labeled data should be checked to be the same size.
__getitem__: This is where all the data handling occurs to return a given index (idx) of the raw and labeled data. If any transforms need to be applied, the data must be converted to a tensor and transformed. If the initialization contained a path to the dataset, the path must be opened and data accessed/preprocessed before it can be returned.
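The path-based case mentioned under __getitem__ can be sketched as follows. This is only an illustration: `NpyFileDataset`, the `.npy` file layout, and the temporary files are all hypothetical, chosen so the snippet runs on its own.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset


class NpyFileDataset(Dataset):
    """Hypothetical dataset that stores file paths and loads arrays lazily."""

    def __init__(self, paths, labels):
        # Check up front that every file has a label.
        if len(paths) != len(labels):
            raise ValueError("paths and labels must have the same length")
        self.paths = paths
        self.labels = labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # The file is only opened when this index is requested.
        array = np.load(self.paths[idx])
        features = torch.from_numpy(array).float()
        label = torch.tensor(self.labels[idx])
        return features, label


# Write three small arrays to a temporary directory to have files to read.
tmp_dir = tempfile.mkdtemp()
paths = []
for i in range(3):
    path = os.path.join(tmp_dir, f"sample_{i}.npy")
    np.save(path, np.full((2, 2), i, dtype=np.float32))
    paths.append(path)

dataset = NpyFileDataset(paths, labels=[0, 1, 0])
features, label = dataset[2]
print(len(dataset))      # 3
print(features.shape)    # torch.Size([2, 2])
print(label.item())      # 0
```

Loading inside __getitem__ keeps memory use low for large datasets, at the cost of reading from disk on every access.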
Example dataset for a semantic segmentation model:
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class ExampleDataset(Dataset):
    """Example dataset"""

    def __init__(self, raw_img, data_mask, transform=None):
        self.raw_img = raw_img
        self.data_mask = data_mask
        self.transform = transform

    def __len__(self):
        return len(self.raw_img)

    def __getitem__(self, idx):
        # Samplers may pass the index as a tensor; convert it to a plain value
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.raw_img[idx]
        mask = self.data_mask[idx]
        sample = {'image': image, 'mask': mask}
        # The transform receives the whole sample dict, not just the image
        if self.transform:
            sample = self.transform(sample)
        return sample
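Because this dataset hands the whole sample dict to its transform, the transform has to accept a dict rather than a single image. Below is a rough usage sketch under that assumption: `SegmentationPairs` is a condensed copy of the class above (repeated so the snippet runs on its own), and `FlipBoth` and the random tensors are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SegmentationPairs(Dataset):
    """Condensed copy of the dataset above, repeated so this runs on its own."""

    def __init__(self, raw_img, data_mask, transform=None):
        self.raw_img = raw_img
        self.data_mask = data_mask
        self.transform = transform

    def __len__(self):
        return len(self.raw_img)

    def __getitem__(self, idx):
        sample = {'image': self.raw_img[idx], 'mask': self.data_mask[idx]}
        if self.transform:
            sample = self.transform(sample)
        return sample


class FlipBoth:
    """Hypothetical transform: flips image and mask together, left to right."""

    def __call__(self, sample):
        return {
            'image': torch.flip(sample['image'], dims=[-1]),
            'mask': torch.flip(sample['mask'], dims=[-1]),
        }


# Made-up data: 6 RGB images of 8x8 pixels with one integer mask each.
images = torch.rand(6, 3, 8, 8)
masks = torch.randint(0, 2, (6, 8, 8))

dataset = SegmentationPairs(images, masks, transform=FlipBoth())
loader = DataLoader(dataset, batch_size=2)

batch = next(iter(loader))
print(batch['image'].shape)  # torch.Size([2, 3, 8, 8])
print(batch['mask'].shape)   # torch.Size([2, 8, 8])
```

The default collate function batches dicts key by key, so each batch is itself a dict of stacked tensors. Applying the same flip to image and mask keeps the two aligned, which matters for segmentation.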