13. Distributed Deep Learning

Up to this point, we have run our machine learning workloads exclusively in the shared memory domain. This means that all communication between workers was done through a node’s memory subsystem. Often this is sufficient for inference and for training simple models, and we don’t need to go any further. However, modern machine learning is becoming increasingly hungry for computational power; models such as GPT-4 sit at the high end of this spectrum. So what do we do when one node isn’t enough for our needs? Well… we could just use more nodes 😁.

In this lab we’ll use PyTorch’s distributed memory backend(s) to overcome node boundaries. This allows workers to collaborate across the network. Effectively, this approach gives us the ability to greatly increase the computational power available to our workloads.

We’ll begin our journey into distributed memory computing by taking a close look at the relatively low-level c10d library. Next, we’ll take a look at PyTorch’s high-level Distributed Data-Parallel training approach.

13.1. The Scenic Route

The collective communication library c10d sits at the lowest level of PyTorch’s distributed stack. It supports point-to-point as well as collective communication. In general, the syntax closely follows the Message Passing Interface (MPI), which allows us to get started very quickly. For example, we can use torch.distributed.send and torch.distributed.recv for blocking sends and receives. The non-blocking counterparts, torch.distributed.isend and torch.distributed.irecv, return communication requests on which we wait for completion. Furthermore, collectives, e.g., torch.distributed.all_reduce, operate on process groups, which are similar to MPI communicators.

Further down the software stack, c10d supports a number of backends. One of these backends is MPI, which is the one we’re going to use. We’ll start by simply communicating some elementary data. From c10d’s point of view, data comes in PyTorch’s native currency, meaning we’ll be communicating tensors.
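
Before diving into the tasks, the following minimal sketch illustrates the overall pattern: initialize the process group with the MPI backend, query rank and size, and run a single collective. It assumes a PyTorch build with MPI support and a launch through mpiexec; the file name c10d_hello.py used in the comment is just a placeholder.

import torch
import torch.distributed as dist

# set up the process group on top of MPI; rank and size are taken from the
# MPI runtime, e.g. when launched via: mpiexec -n 4 python c10d_hello.py
dist.init_process_group( backend = "mpi" )

l_rank = dist.get_rank()
l_size = dist.get_world_size()
print( f"hello from rank {l_rank} of {l_size}" )

# every rank contributes a tensor filled with its own rank id; after the
# allreduce every rank holds the sum 0 + 1 + ... + (size-1) in every entry
l_data = torch.full( (3, 4), float(l_rank) )
dist.all_reduce( l_data, op = dist.ReduceOp.SUM )
print( f"rank {l_rank} result:\n{l_data}" )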

Tasks

  1. Write a simple program that initializes torch.distributed with the MPI backend. Print the rank and size on every process.

  2. Allocate a 3 \(\times\) 4 tensor on each rank. Initialize the tensor to ones on rank 0 and to zeros on all other ranks. Use blocking sends and receives to send rank 0’s tensor to rank 1.

  3. Repeat the previous task, but use non-blocking sends and receives.

  4. On every rank, allocate a 3 \(\times\) 4 tensor and initialize it to:

    \[\begin{split}\begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix}.\end{split}\]

    Perform an allreduce with the reduce operation SUM on the tensor. Print the result!

13.2. Vista Point: Data Parallelism

The idea of data parallelism is quite simple:

  • Replicate the model on every rank;

  • Perform training in parallel, i.e. partition each minibatch across the ranks and run the forward and backward passes in parallel; and

  • Keep the distributed replicas in sync, i.e. average the gradient updates over the entire minibatch and ensure that the same weight updates are applied everywhere.

In practice, this poses two challenges: First, the minibatch must be consistently partitioned with respect to the version without data parallelism. Second, we need to average the distributed gradients before updating the weights.

Partitioning Minibatches

At first glance, the partitioning task seems simple. We have \(n\) samples in our dataset and \(p\) processes: Each process simply gets \(\frac{n}{p}\) samples, problem solved 🤔?

Looking a bit closer, the situation is more complex. Our goal is to partition the minibatches as they appear in the sequential version. In particular, we need a consistent parallel implementation of commonly used shuffle operations on the entire dataset.

Fig. 13.2.1 Illustration of a dataset with 24 samples. Sequential training uses minibatches of eight samples. The second minibatch is currently being processed.

For example, consider the situation shown in Fig. 13.2.1. In this case we have \(n=24\) samples and a single process, i.e. \(p=1\). The size of a minibatch is \(m=8\). As highlighted in red, the single process is currently working on the second minibatch with associated sample ids 8-15.

Fig. 13.2.2 Illustration of a dataset with 24 samples replicated on two processes. Similar to the sequential example in Fig. 13.2.1, a minibatch size of eight is used. Currently, the processes work on the second minibatch in a data-parallel fashion. As highlighted in red, process 0 works on one half of the minibatch and process 1 works on the other half.

Now we’d like to use two processes, i.e. \(p=2\), in a data-parallel way. Equally partitioning a minibatch means that each process gets a microbatch with \(\frac{m}{p} = 4\) samples from each minibatch. One possible partitioning is shown in Fig. 13.2.2. Here, the first process works on half of the minibatch, i.e. it works on the samples with ids 8, 10, 12 and 14. Similarly, the second process works on the other half with sample ids 9, 11, 13 and 15.

Fig. 13.2.3 Illustration of a dataset with 24 samples replicated on two processes. The setting is largely identical to that shown in Fig. 13.2.2. However, the dataset has been shuffled at the beginning of the epoch.

The situation becomes more complex when support for shuffles is considered. An example is given in Fig. 13.2.3. Note that we need to do the same shuffling on both processes. This typically means using the same random seeds on all processes. Once we have implemented consistent shuffling of the sample ids, we obtain the partitioning of the minibatches by following our previous ideas. Specifically for the situation in Fig. 13.2.3, the first process again works on half of the minibatch, i.e. the samples with ids 10, 17, 1, and 23. Similarly, the second process works on the other half, i.e. samples 8, 15, 7 and 14.
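
The following sketch illustrates this indexing logic in isolation, i.e., without any real data: all ranks draw the same permutation from a shared seed and then take a strided slice of each minibatch. The variable names and the seed are illustrative assumptions, so the concrete shuffled ids differ from those shown in Fig. 13.2.3.

import torch

l_n_samples  = 24   # dataset size n
l_batch_size = 8    # minibatch size m
l_size       = 2    # number of processes p
l_rank       = 0    # rank of this process (0 or 1 in the example)

# identical seed on every rank -> identical permutation on every rank;
# without shuffling one would simply use torch.arange( l_n_samples )
l_gen = torch.Generator()
l_gen.manual_seed( 0 )
l_perm = torch.randperm( l_n_samples, generator = l_gen )

# sample ids of the second minibatch, i.e., positions 8-15 of the shuffled order
l_batch = l_perm[ 1*l_batch_size : 2*l_batch_size ]

# strided microbatch: rank 0 takes positions 0, 2, 4, 6 of the minibatch,
# rank 1 takes positions 1, 3, 5, 7
l_micro = l_batch[ l_rank::l_size ]
print( l_micro )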

Tasks

  1. Use torch.utils.data.distributed.DistributedSampler to distribute a dataset across the MPI processes. Illustrate the behavior of the sampler by passing it to torch.utils.data.DataLoader via the sampler parameter. In your tests, pass a dataset of type SimpleDataSet to the data loader through the dataset argument:

    # toy dataset: it has i_length samples and sample i simply returns 10*i
    class SimpleDataSet( torch.utils.data.Dataset ):
      def __init__( self,
                    i_length ):
        self.m_length = i_length

      def __len__( self ):
        return self.m_length

      def __getitem__( self,
                       i_idx ):
        return i_idx*10

    Highlight the influence of the sampler’s optional parameters num_replicas, rank, shuffle and drop_last.

  2. Wrap your DistributedSampler in a BatchSampler. Use SimpleDataSet in your tests. Highlight the influence of the batch_size and drop_last parameters. A starting-point sketch for both sampler tasks is given after this list.
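
The sketch below is one possible starting point for the two sampler tasks; it assumes an MPI launch as in Section 13.1 and that the SimpleDataSet class from the first task is in scope. It is meant as an illustration, not as a reference solution. Note that DistributedSampler queries num_replicas and rank from the default process group if they are omitted.

import torch
import torch.distributed as dist

dist.init_process_group( backend = "mpi" )

l_data = SimpleDataSet( 24 )

# task 1: one DistributedSampler per rank; num_replicas and rank default to
# the size and rank of the process group
l_sampler = torch.utils.data.distributed.DistributedSampler( l_data,
                                                             shuffle   = True,
                                                             drop_last = False )
l_loader = torch.utils.data.DataLoader( dataset    = l_data,
                                        batch_size = 4,
                                        sampler    = l_sampler )

l_sampler.set_epoch( 0 )  # reseeds the shuffle consistently on every rank
for l_batch in l_loader:
  print( dist.get_rank(), l_batch )

# task 2: wrap the sampler in a BatchSampler and pass it via batch_sampler
l_batch_sampler = torch.utils.data.BatchSampler( l_sampler,
                                                 batch_size = 4,
                                                 drop_last  = False )
l_loader_batched = torch.utils.data.DataLoader( dataset       = l_data,
                                                batch_sampler = l_batch_sampler )
for l_batch in l_loader_batched:
  print( dist.get_rank(), l_batch )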

Synchronizing Replicas

We found a way to consistently partition each minibatch, great! Let’s move forward: time to work on synchronizing the replicated model. We already know that PyTorch’s optimizers use the computed vector-Jacobian products to update a model’s parameters in each step. Typically, we set all gradients to zero before doing the backward pass of a minibatch. After the backward pass, the gradients of a single minibatch are distributed among the processes. For example, consider the situation in Fig. 13.2.2. Here, the first process holds the gradients with respect to samples 8, 10, 12, and 14. In contrast, the second process holds the gradients w.r.t. samples 9, 11, 13, and 15. Thus, to obtain the gradient of the entire minibatch, we have to average the distributed partial gradients. Once the gradient of the entire minibatch is available on all processes, we can update the weights in a consistent manner: the replicas stay in sync!

Fortunately, we already have all the tools we need. c10d’s allreduce together with the reduce operation SUM is sufficient. Furthermore, the individual gradients of a model’s parameters are easily obtained from PyTorch. Assuming that the replica on each process is stored in the variable io_model and the number of MPI processes in i_size_distributed, the desired allreduces are implemented in a few lines of code:

Listing 13.2.1 Averaging gradients for synchronous data parallel training.
# reduce gradients
for l_pa in io_model.parameters():
  torch.distributed.all_reduce( l_pa.grad.data,
                                op = torch.distributed.ReduceOp.SUM )
  l_pa.grad.data = l_pa.grad.data / float(i_size_distributed)

Optional Note

The code in Listing 13.2.1 looks deceptively simple. Only four lines of code, how hard can it be, right? Be aware that efficient implementations of the allreduce operation are very complex 😱. Our easy life is simply backed by very powerful libraries! Here, the MPI library used by PyTorch’s backend contains carefully implemented and highly optimized algorithms for allreduce operations.

Tasks

  1. Add the allreduce in Listing 13.2.1 to the training procedure of your MLP written in Section 4.

  2. Adjust any other parts that require distributed treatment, such as the computation of the training loss, the validation loss, or the accuracy (one possible way to aggregate such metrics is sketched after this list). Make sure that the results of your data-parallel training match those of the sequential training.

    Hint

    Floating-point arithmetic is not associative. This means that changing the order of summation will most likely change the result; in double precision, for example, (0.1 + 0.2) + 0.3 and 0.1 + (0.2 + 0.3) already differ in the last bits. So we have to accept some small differences in the solutions. Be careful: it is easy to confuse serious bugs with inaccuracies caused by floating-point math.

  3. Measure the speedup of your training as you increase the number of processes. Use one MPI process per NUMA domain and up to four nodes.

    Hint

    You can view the two NUMA domains of a node in the short queue by using the lscpu tool. The following snippet allows us to use a single node efficiently by using the OpenMPI --bind-to socket option and the OMP_NUM_THREADS environment variable:

    OMP_NUM_THREADS=24 mpiexec -n 2 --bind-to socket --report-bindings python mlp_fashion_mnist_distributed.py
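
Returning to the second task above: one possible way to aggregate validation metrics across the ranks, sketched below with placeholder counts, is to allreduce the local counts and divide only once globally. The variable names and values are illustrative assumptions; in the actual code the counts come from each rank’s validation loop.

import torch
import torch.distributed as dist

dist.init_process_group( backend = "mpi" )

# placeholder local counts; in the actual code these come from the
# validation loop of the respective rank
l_n_correct_local = 40.0
l_n_samples_local = 50.0

# sum the counts over all ranks, then divide once globally
l_counts = torch.tensor( [ l_n_correct_local, l_n_samples_local ] )
dist.all_reduce( l_counts, op = dist.ReduceOp.SUM )

l_accuracy = l_counts[0] / l_counts[1]
print( f"global accuracy: {l_accuracy}" )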
    

13.3. Distributed Data Parallel

In Section 13.2, we went through the steps necessary for data-parallel training in PyTorch. We saw that a lot of work is needed to get the inputs and outputs into the right form, while only the small piece of code in Listing 13.2.1 did the actual allreduces of the gradients and kept the replicas in sync. In this section we’ll extend our data-parallel training by using torch.nn.parallel.DistributedDataParallel (DDP). DDP has two main features. First, it handles the initial replication of the model by broadcasting the model weights. This step is necessary because the weights may be randomly initialized; in Section 13.2 we got around this by setting the same random seeds everywhere and hoping for the best 😇. Second, DDP takes care of efficiently averaging the gradients for us. Our c10d-based snippet was sufficient to do this correctly, but better implementations are possible. For example, DDP supports partial reductions of the first set of gradients while others are still being computed. Such an approach allows us to hide the communication behind the computation, whereas our code in Listing 13.2.1 has a pure communication phase.
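
A minimal sketch of the wrapping step is shown below, assuming CPU training over the MPI backend. The linear layer is only a stand-in for the MLP of Section 13.2 and the random data replaces the actual dataloader; the point is that the training loop itself stays unchanged once the model is wrapped.

import torch
import torch.distributed as dist

dist.init_process_group( backend = "mpi" )

# stand-in model for the MLP of Section 13.2; DDP broadcasts the weights of
# rank 0 to all other ranks when the wrapper is constructed
l_model = torch.nn.Linear( 4, 2 )
l_ddp   = torch.nn.parallel.DistributedDataParallel( l_model )

l_optim = torch.optim.SGD( l_ddp.parameters(), lr = 1e-2 )

# one training step on random placeholder data
l_x = torch.randn( 8, 4 )
l_y = torch.randn( 8, 2 )

l_optim.zero_grad()
l_loss = torch.nn.functional.mse_loss( l_ddp( l_x ), l_y )
l_loss.backward()   # DDP averages the gradients across the ranks
l_optim.step()      # identical weight update on every rank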

Tasks

  1. Read the paper PyTorch Distributed: Experiences on Accelerating Data Parallel Training. Explain the advantages and disadvantages of larger gradient buckets (Sections 3.2.2 and 3.2.3 of the paper).

  2. Integrate DDP into the data-parallel version of the Multilayer Perceptron from Section 13.2.

  3. Repeat the scaling experiment using four nodes. Did the time-to-solution of your DDP-enhanced version improve?