Understanding Generators in PyTorch

Hussain Wali
Mar 17, 2023

PyTorch provides an easy and efficient way to build and train deep learning models. One of the key features of PyTorch is its support for generators, which are a powerful tool for handling large datasets.

In this article, we will provide a beginner-friendly guide to understanding generators in PyTorch. We will cover what generators are, how they work, and why they are useful. We will also include code examples and a real-world example of how generators can be used in practice.

What are Generators?

A generator is a special type of iterator that produces data on the fly instead of loading it all into memory at once. This is particularly useful when working with large datasets that cannot fit into memory, as it allows us to process the data in smaller batches.
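To make the idea concrete, here is a plain Python sketch of a generator (not PyTorch-specific) that yields one record at a time instead of building a full list in memory; the read_records function and records.txt file are just illustrative placeholders:

def read_records(path):
    """Yield one record at a time instead of reading the whole file into memory."""
    with open(path) as f:
        for line in f:
            yield line.strip()

# Only one record is ever held in memory at a time
for record in read_records('records.txt'):
    processed = record.upper()  # stand-in for real preprocessing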

In PyTorch, this behavior is provided by the DataLoader class. A DataLoader wraps a Dataset and, given a batch size (and other options), returns an iterable that yields batches of data one at a time.

How do Generators Work?

Generators produce data on the fly, as it is needed. When we iterate over a DataLoader, it pulls samples from the underlying dataset, applies any preprocessing, collates the samples into a batch, and yields one batch at a time until the whole dataset has been consumed. If the DataLoader is created with num_workers greater than zero, PyTorch spawns background worker processes that load and preprocess data ahead of time; with the default num_workers=0, everything happens in the main process.
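For example, here is a minimal sketch of enabling background workers; the dataset is a small synthetic one only to keep the snippet self-contained:

import torch
from torch.utils.data import DataLoader, TensorDataset

# A small in-memory dataset, just so the example runs on its own
dataset = TensorDataset(torch.randn(1000, 5), torch.randint(0, 2, (1000,)))

if __name__ == '__main__':  # needed where workers are started with 'spawn' (e.g. Windows/macOS)
    # num_workers > 0 starts background worker processes that prepare the
    # next batches while the main process consumes the current one.
    dataloader = DataLoader(dataset, batch_size=32, num_workers=4)

    for features, labels in dataloader:
        pass  # training / processing for the batch would go here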

Why are Generators Useful?

Generators are useful for several reasons:

  1. Memory efficiency: Generators allow us to process large datasets that cannot fit into memory by loading data in smaller batches.
  2. Time efficiency: Because data is generated on the fly, we can start processing before the entire dataset has been loaded. For example, if a dataset contains millions of records, reading all of them into memory before doing any work could take a long time (or be impossible if they do not fit); a generator lets us begin working on the first records immediately, which reduces overall processing time (see the sketch after this list).
  3. Randomization: When created with shuffle=True, a DataLoader reshuffles the dataset at the start of each epoch before generating batches, which can help prevent overfitting.

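As a small sketch of points 1 and 2, an IterableDataset can stream records one at a time, so iteration starts almost immediately and only one batch ever needs to fit in memory; the synthetic records generated below stand in for reading millions of records from disk:

import torch
from torch.utils.data import DataLoader, IterableDataset

class StreamingDataset(IterableDataset):
    """Streams records one at a time instead of materializing them all."""
    def __init__(self, num_records):
        self.num_records = num_records

    def __iter__(self):
        for _ in range(self.num_records):
            # Stand-in for reading and parsing one record from disk
            yield torch.randn(5)

# Batches are assembled lazily; the first one is available right away
stream_loader = DataLoader(StreamingDataset(1_000_000), batch_size=32)
first_batch = next(iter(stream_loader))
print(first_batch.shape)  # torch.Size([32, 5])
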
Let's take a look at a simple code example that demonstrates how to use generators in PyTorch:

import torch
from torch.utils.data import DataLoader, Dataset

# Define a custom dataset
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

# Define some dummy data
data = [torch.randn(5) for _ in range(1000)]

# Create a DataLoader
batch_size = 32
dataloader = DataLoader(CustomDataset(data), batch_size=batch_size)

# Iterate over the DataLoader to get batches of data
for batch in dataloader:
    print(batch.shape)

In this example, we define a custom dataset that wraps some dummy data. We then create a DataLoader with a batch size of 32 and iterate over it to get batches of data. The print(batch.shape) line prints the shape of each batch, which is (32, 5) for every batch except the last one: since 1000 is not divisible by 32, the final batch contains only the remaining 8 samples.
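If every batch must contain exactly batch_size samples, DataLoader's drop_last argument discards that final partial batch; a quick sketch reusing the dataset from the example above:

# Drop the incomplete final batch so every batch has exactly 32 samples
dataloader = DataLoader(CustomDataset(data), batch_size=batch_size, drop_last=True)

for batch in dataloader:
    print(batch.shape)  # always torch.Size([32, 5]); 31 batches instead of 32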

A common use case for generators in PyTorch is image classification. Let's say we have a large dataset of images that we want to use to train a neural network. Instead of loading all of the images into memory at once, we can use a generator to load them in smaller batches.

Here's an example of how this might look in practice:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define a transformation to apply to each image
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Create a DataLoader for the training set
train_dataset = datasets.ImageFolder('path/to/training/dataset', transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Create a DataLoader for the validation set
val_dataset = datasets.ImageFolder('path/to/validation/dataset', transform=transform)
val_dataloader = DataLoader(val_dataset, batch_size=32)

# Define a neural network, an optimizer, and the number of epochs
model = ...
optimizer = ...
num_epochs = ...

# Train the model using the training set
for epoch in range(num_epochs):
    # Iterate over the training set
    for x, y in train_dataloader:
        # Forward pass
        output = model(x)

        # Compute loss
        loss = ...

        # Backward pass
        loss.backward()

        # Update parameters
        optimizer.step()
        optimizer.zero_grad()

    # Evaluate the model on the validation set
    for x, y in val_dataloader:
        ...

Here we define a transformation to apply to each image (resizing it to 224x224 and converting it to a tensor), then create one DataLoader for the training set, with shuffle=True, and another for the validation set.

We then define a neural network and train it using train_dataloader: for each batch we compute the model's output and loss, run the backward pass, update the parameters, and zero the gradients for the next batch. After each epoch of training, we evaluate the model on the validation set using val_dataloader.
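The validation loop above is left as an ellipsis. Here is a minimal sketch of what it might contain, assuming a classification task with a cross-entropy loss; the loss function and accuracy computation are illustrative assumptions, not part of the original example:

# Hypothetical validation loop, run once per epoch
model.eval()                      # disable dropout / batch-norm updates
correct, total, val_loss = 0, 0, 0.0
with torch.no_grad():             # no gradients needed during evaluation
    for x, y in val_dataloader:
        output = model(x)
        val_loss += torch.nn.functional.cross_entropy(output, y, reduction='sum').item()
        correct += (output.argmax(dim=1) == y).sum().item()
        total += y.size(0)
print(f'val loss: {val_loss / total:.4f}, accuracy: {correct / total:.2%}')
model.train()                     # back to training mode for the next epoch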

Generators are a powerful tool for handling large datasets in PyTorch. By generating data on the fly instead of loading it all into memory at once, we can process large datasets more efficiently and with far less memory. The DataLoader class makes it easy to create these generator-style iterators, and they are commonly used in image classification and other deep learning tasks.
