Building Scalable PyTorch Models with Dask and Distributed Data Parallelism

Are you looking to train large-scale PyTorch models that can leverage distributed computing to achieve improved scalability and performance? In this article, we will explore how you can harness the power of Dask and PyTorch’s Distributed Data Parallel (DDP) to build scalable and performant PyTorch models.

Introducing dask-pytorch-ddp

At the heart of this endeavor is the dask-pytorch-ddp library. This Python package facilitates the training of PyTorch models on Dask clusters using distributed data parallelism. The main objectives of this project are:

  1. Bootstrapping PyTorch workers on top of a Dask cluster.
  2. Using distributed data stores, such as S3, as normal PyTorch datasets.
  3. Enabling tracking and logging of intermediate results, training statistics, and checkpoints.

While the initial focus of this library is on computer vision tasks, the underlying functionality is intended to be applicable to any PyTorch task. The S3ImageFolder dataset class, specific to image processing, can be easily extended to support other types of datasets.

Loading Data

In a typical non-Dask PyTorch workflow, loading data involves creating a dataset and wrapping it in a DataLoader. In the context of Dask and distributed data parallelism, the process is slightly different. With dask-pytorch-ddp, you can load datasets directly from S3 and explicitly set the multiprocessing context so that the DataLoader's worker processes (which fork) behave well inside Dask's worker processes.

```python
import multiprocessing as mp
import torch
from dask_pytorch_ddp.data import S3ImageFolder

# Build a dataset directly from images stored in S3
whole_dataset = S3ImageFolder(bucket, prefix, transform=transform)

# Use a fork context so the DataLoader workers are compatible with Dask
train_loader = torch.utils.data.DataLoader(
    whole_dataset, sampler=train_sampler, batch_size=batch_size,
    num_workers=num_workers, multiprocessing_context=mp.get_context("fork"),
)
```

Training in Parallel

Training a model in a parallelized environment involves wrapping the training loop in a function and converting the model into a PyTorch Distributed Data Parallel (DDP) model. The DDP model knows how to synchronize gradients across workers, which lets you run Stochastic Gradient Descent (SGD) with a larger effective batch size.

```python
import datetime
import json
import pickle
import uuid

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from dask_pytorch_ddp.results import DaskResultsHandler
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision import models

# Unique key so results from this run are grouped together by the handler
key = uuid.uuid4().hex
rh = DaskResultsHandler(key)


def run_transfer_learning(bucket, prefix, samplesize, n_epochs, batch_size,
                          num_workers, train_sampler):
    worker_rank = int(dist.get_rank())
    device = torch.device(0)

    # Wrap the model in DDP so gradients are synchronized across workers
    net = models.resnet18(pretrained=False)
    model = net.to(device)
    model = DDP(model, device_ids=[0])

    criterion = nn.CrossEntropyLoss().cuda()
    lr = 0.001
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    count = 0
    for epoch in range(n_epochs):
        model.train()
        # train_loader: the S3-backed DataLoader from the previous snippet;
        # in practice each worker builds its own loader inside this function
        for inputs, labels in train_loader:
            dt = datetime.datetime.now().isoformat()
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            count += 1

            # Stream per-batch metrics back to the client via the results handler
            rh.submit_result(
                f"worker/{worker_rank}/data-{dt}.json",
                json.dumps({'loss': loss.item(), 'epoch': epoch,
                            'count': count, 'worker': worker_rank})
            )
            # Periodically checkpoint from the rank-0 worker only
            if (count % 100) == 0 and worker_rank == 0:
                rh.submit_result(f"checkpoint-{dt}.pkl",
                                 pickle.dumps(model.state_dict()))
```
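To actually launch this training function on every worker in the cluster, dask-pytorch-ddp provides a dispatch helper. The sketch below is not from the article; it assumes an existing Dask cluster object named `cluster`, a connected `dask.distributed` Client, and the same argument names as above, so check the library's documentation for the exact call signature.

```python
from dask.distributed import Client
from dask_pytorch_ddp import dispatch

client = Client(cluster)  # `cluster` is your existing Dask cluster object

# Run the training function on every worker; extra arguments are passed through
futures = dispatch.run(
    client, run_transfer_learning,
    bucket, prefix, samplesize, n_epochs, batch_size, num_workers, train_sampler,
)
```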

How does it work?

dask-pytorch-ddp is essentially a wrapper around existing PyTorch functionality, specifically the torch.distributed module, which provides the infrastructure for Distributed Data Parallel (DDP) training. In DDP, each worker synchronizes buffers and gradients, with the 0th worker acting as the “master” that coordinates the synchronization.

To use DDP, the process group has to be configured: dask-pytorch-ddp sets environment variables identifying the master host and port, calls init_process_group before training, and calls destroy_process_group after training. This behind-the-scenes setup is handled for you seamlessly.
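For intuition, here is a rough sketch (not the library's actual code) of the kind of torch.distributed setup this amounts to on each worker; the address, port, backend, and variable names are illustrative assumptions.

```python
import os
import torch.distributed as dist

# Point every worker at the rank-0 ("master") worker for coordination
os.environ["MASTER_ADDR"] = "10.0.0.1"   # illustrative address of the rank-0 worker
os.environ["MASTER_PORT"] = "23456"      # illustrative coordination port

# Join the process group before training...
dist.init_process_group(backend="nccl", rank=worker_rank, world_size=n_workers)

# ... run the training loop ...

# ...and tear it down afterwards
dist.destroy_process_group()
```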

Multi-GPU Machines

For those working on multi-GPU machines, dask_cuda_worker automatically rotates CUDA_VISIBLE_DEVICES for each worker it creates, so each worker sees a different physical GPU as device 0. As a result, your PyTorch code can always target the 0th GPU, regardless of how many GPUs the machine has.
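Concretely, the device selection inside the training function never needs to change. A minimal illustration of the pattern (the placeholder model here is just for demonstration):

```python
import torch
import torch.nn as nn

# Each dask_cuda_worker process sees a different physical GPU as CUDA device 0,
# so indexing device 0 is always correct inside the training function.
device = torch.device(0)
model = nn.Linear(10, 2).to(device)  # placeholder model, just to show the pattern
```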

Additional Features

Apart from distributed data parallelism, dask-pytorch-ddp also includes an S3-based ImageFolder class, making it easy to work with image datasets stored in distributed data stores. The library also implements a basic results aggregation framework, with the DaskResultsHandler leveraging Dask’s pub-sub communication protocols for collecting training metrics across different workers. Future plans for the library include support for additional distributed-friendly datasets and result handlers.
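On the client side, the results handler can then collect everything the workers submit. A minimal sketch, assuming `futures` is the list returned by the dispatch helper and the output directory is a placeholder:

```python
# Blocks while training runs, writing each submitted result (metrics JSON,
# checkpoints) under the given directory as workers publish them.
rh.process_results("training_results/", futures, raise_errors=False)
```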

Some Notes

When working with Dask, keep in mind that Dask generally spawns processes while PyTorch generally forks. For compatibility, it is recommended to pass the fork multiprocessing context to the multiprocessing-enabled data loader, as shown in the data-loading snippet above.

In some Dask deployments, spawning processes may not be permitted. In such cases, you can override this behavior by changing the distributed.worker.daemon setting, either directly or through environment variables.

```bash
DASK_DISTRIBUTED__WORKER__DAEMON=False
```
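
Equivalently, the same setting can be changed directly from Python with Dask's standard configuration API (this uses `dask.config.set`, not anything specific to dask-pytorch-ddp):

```python
import dask

# Allow Dask workers to start non-daemonic subprocesses, matching the
# DASK_DISTRIBUTED__WORKER__DAEMON=False environment variable above.
dask.config.set({"distributed.worker.daemon": False})
```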

In Conclusion

By harnessing the power of Dask and PyTorch’s Distributed Data Parallel (DDP), you can build scalable and performant PyTorch models that can leverage distributed computing. The dask-pytorch-ddp library simplifies the process of training models on Dask clusters and provides support for distributed data parallelism, data loading from S3, and result aggregation. Start exploring the possibilities of scalable PyTorch models today!
