.. _ch:distributed_ml:

Distributed Deep Learning
=========================
Until this point, we ran our machine learning workloads exclusively in the shared memory domain.
This means that all communication between the workers happened through a node's memory subsystem.
Often this is sufficient for simple models and one doesn't have to push further.
Yet, state-of-the-art machine learning is becoming increasingly hungry for compute power.
`GPT-3 `__ is one such model at the upper end of the spectrum.
So, what should we do if a single node isn't sufficient for our needs?
Well, we could just use more nodes 😁.

In this lab we'll harness PyTorch's `distributed memory backend(s) `__ to overcome node boundaries.
This allows workers to collaborate over the network.
Effectively, this approach gives us the ability to greatly increase the compute power available to our workloads.
We'll start our distributed memory journey by having an in-depth look at the comparatively low-level `c10d library `__.
Next, we'll have a look at PyTorch's high-level approach `Distributed Data-Parallel Training `__.

.. _ch:distributed_ml_scenic:

The Scenic Route
----------------
The collective communication library c10d forms the lowest level of PyTorch's distributed stack.
It supports peer-to-peer and collective communication.
In general, the syntax is very close to that of the `Message Passing Interface `__ (MPI), which allows us to get started very quickly.
For example, one uses ``torch.distributed.send`` and ``torch.distributed.recv`` for blocking sends and receives.
The non-blocking counterparts are accessed through ``torch.distributed.isend`` and ``torch.distributed.irecv``, together with waits on the returned communication requests.
Further, collectives, e.g., ``torch.distributed.all_reduce``, operate on groups similar to MPI's communicators.

Going further down the software stack, c10d supports a series of `backends `__.
One of these backends is MPI, which is the one we'll use.
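
To make the MPI-like flavor of the calls named above a bit more tangible, here is a minimal sketch.
It is not part of the lab's reference material: the function names ``ring_exchange_and_sum`` and ``main`` are hypothetical, and the ring pattern is only one of many possible illustrations.
The sketch assumes a PyTorch build with MPI support and uses a non-blocking send and receive plus one collective.

```python
import torch
import torch.distributed as dist


def ring_exchange_and_sum():
    """Pass a small tensor to the right neighbor, then sum up all ranks.

    Assumes that the default process group was initialized beforehand.
    """
    rank = dist.get_rank()
    size = dist.get_world_size()

    # every rank sends a tensor filled with its own rank
    send_buf = torch.full((4,), float(rank))
    recv_buf = torch.zeros(4)

    if size > 1:
        right = (rank + 1) % size
        left = (rank - 1 + size) % size
        # non-blocking calls return request objects right away
        reqs = [dist.isend(send_buf, dst=right),
                dist.irecv(recv_buf, src=left)]
        # the buffers are only safe to use again after the waits
        for req in reqs:
            req.wait()
    else:
        # a single process has no neighbor: simply copy the data
        recv_buf.copy_(send_buf)

    # collective: afterwards every rank holds 0 + 1 + ... + (size-1)
    rank_sum = torch.tensor([float(rank)])
    dist.all_reduce(rank_sum, op=dist.ReduceOp.SUM)

    return recv_buf, rank_sum


def main(backend="mpi"):
    # with the MPI backend, rank and size are taken from the MPI runtime
    dist.init_process_group(backend=backend)
    recv_buf, rank_sum = ring_exchange_and_sum()
    print(f"rank {dist.get_rank()}/{dist.get_world_size()}: "
          f"received {recv_buf.tolist()}, rank sum {rank_sum.item()}")
    dist.destroy_process_group()
```

One would call ``main()`` at the end of a script and launch it through the MPI launcher, e.g., ``mpiexec -n 4 python ring_sketch.py`` (the file name is made up).
Posting both non-blocking requests before waiting avoids the ordering concerns that a ring of blocking sends and receives would raise.
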
We start simple by communicating some elementary data.
In c10d, data comes in PyTorch's currency, meaning that we'll communicate tensors.

.. admonition:: Tasks

   #. Write a simple program which initializes ``torch.distributed`` with the MPI backend.
      Print the rank and size on every process.
   #. Allocate a 3 :math:`\times` 4 tensor on every rank.
      Initialize the tensor to ones on rank 0 and to zeroes on all other ranks.
      Use blocking sends and receives to send rank 0's tensor to rank 1.
   #. Repeat the previous task but use non-blocking sends and receives.
   #. On every rank, allocate a 3 :math:`\times` 4 tensor and initialize it to:

      .. math::

         \begin{bmatrix}
           1 & 2 & 3 & 4 \\
           5 & 6 & 7 & 8 \\
           9 & 10 & 11 & 12
         \end{bmatrix}.

      Perform an `allreduce `__ with the `reduce operation `__ ``SUM`` on the tensor.
      Print the result!

.. _ch:distributed_ml_data_parallel:

Vista Point: Data Parallelism
-----------------------------
The idea of data parallelism is a rather simple one:

* Replicate the model on every rank;
* Perform the training in parallel, i.e., partition every minibatch among the ranks and run the forward and backward pass in parallel; and
* Keep the distributed replicas in sync, i.e., average the gradient updates over the entire minibatch and make sure to apply the same weight updates everywhere.

In practice, this poses two challenges:
First, the minibatch has to be partitioned consistently w.r.t. the version without data parallelism.
Second, we have to average the distributed gradients before updating the weights.

Partitioning Minibatches
^^^^^^^^^^^^^^^^^^^^^^^^
At first, the partitioning task might appear simple.
We have :math:`n` samples in our dataset and :math:`p` processes:
Every process simply gets :math:`\frac{n}{p}` samples, problem solved 🤔?

When looking a bit closer, the situation is more complex.
Our goal is to partition the minibatches as they appear in the sequential version.
In particular, we require a consistent parallel implementation of commonly used shuffle operations on the entire dataset.

.. figure:: /chapters/data_distributed_ml/ordered_batch_1.svg
   :name: fig:ordered_batch_1
   :width: 100%

   Illustration of a dataset with 24 samples.
   The sequential training uses minibatches with eight samples.
   Currently, the second minibatch is being processed.

For example, assume the situation shown in :numref:`fig:ordered_batch_1`.
In this case, we have :math:`n=24` samples and a single process, i.e., :math:`p=1`.
The size of a minibatch is :math:`m=8`.
As highlighted in red, the single process currently works on the second minibatch with associated sample-ids 8-15.

.. figure:: /chapters/data_distributed_ml/ordered_dist_batch_1.svg
   :name: fig:ordered_dist_batch_1
   :width: 100%

   Illustration of a dataset with 24 samples which is replicated on two processes.
   Analogously to the sequential example in :numref:`fig:ordered_batch_1` a minibatch size of eight is used.
   Currently, the processes work on the second minibatch in a data-parallel fashion.
   As highlighted in red, process 0 works on one half of the minibatch and process 1 on the other half.

Now, we'd like to harness two processes, i.e., :math:`p=2`, in a data-parallel way.
Equally partitioning a minibatch means that every process gets a microbatch with :math:`\frac{m}{p} = 4` samples of every minibatch.
One possible partitioning is illustrated in :numref:`fig:ordered_dist_batch_1`.
Here, the first process works on one half of the minibatch, i.e., it works on the samples with ids 8, 10, 12 and 14.
Similarly, the second process works on the other half with sample-ids 9, 11, 13 and 15.

.. figure:: /chapters/data_distributed_ml/shuffled_dist_batch_1.svg
   :name: fig:shuffled_dist_batch_1
   :width: 100%

   Illustration of a dataset with 24 samples which is replicated on two processes.
   The setting is largely identical to the one shown in :numref:`fig:ordered_batch_1`.
   However, the dataset was shuffled at the beginning of the epoch.

The situation gets more complex if support for shuffles is considered.
An example of this is given in :numref:`fig:shuffled_dist_batch_1`.
Be aware that we have to do the very same shuffling on both processes.
This typically means using the same random seeds on all processes.
Once we have implemented a consistent shuffling of the sample ids, we obtain the partitioning of the minibatches by following our previous ideas.
Specifically for the situation in :numref:`fig:shuffled_dist_batch_1`, the first process once again works on one half of the minibatch, i.e., the samples with ids 10, 17, 1 and 23.
Analogously, the second process works on the other half, i.e., samples 8, 15, 7 and 14.

.. admonition:: Tasks

   #. Use `torch.utils.data.distributed.DistributedSampler `__ to derive the distribution of a dataset across MPI processes.
      Illustrate the behavior of the sampler by passing it to `torch.utils.data.DataLoader `__ through the parameter ``sampler``.
      Pass the dataloader a dataset of type ``SimpleDataSet`` through the argument ``dataset`` in your tests:

      .. code-block:: python

         import torch

         class SimpleDataSet( torch.utils.data.Dataset ):
           def __init__( self, i_length ):
             # number of samples in the dataset
             self.m_length = i_length

           def __len__( self ):
             return self.m_length

           def __getitem__( self, i_idx ):
             # a sample is its id times ten, which makes it easy to trace
             return i_idx*10

      Highlight the influence of the sampler's optional parameters ``num_replicas``, ``rank``, ``shuffle`` and ``drop_last``.

   #. Wrap your ``DistributedSampler`` in a `BatchSampler `__.
      Once again use ``SimpleDataSet`` in your tests.
      Highlight the influence of the parameters ``batch_size`` and ``drop_last``.

Synchronizing the Replicas
^^^^^^^^^^^^^^^^^^^^^^^^^^
We have found a way to consistently partition every minibatch, great!
Let's move forward: Time to work on the synchronization of the replicated model.

We already know that PyTorch's optimizers use the computed vector-Jacobian products to update a model's parameters in every step.
Typically, we zero all gradients before doing the backward pass of a minibatch.
After the backward pass, the gradients of a single minibatch are distributed among the processes.
For example, consider the situation in :numref:`fig:ordered_dist_batch_1`.
Here, the first process holds the gradients w.r.t. samples 8, 10, 12 and 14.
In contrast, the second process holds those w.r.t. samples 9, 11, 13 and 15.
Thus, to get the gradient of the entire minibatch, we have to average all distributed partial gradients.
Once the gradient of the entire minibatch is available on all processes, we are able to update the weights in a consistent way:
The replicas stay in sync!

Luckily we already have all required tools at hand.
c10d's allreduce together with the reduce operation ``SUM`` is sufficient.
Further, the individual gradients of a model's parameters are easily obtained from PyTorch.
Assume that the replicas are stored in the variable ``io_model`` on every process and the number of MPI processes in ``i_size_distributed``, then the desired allreduces are implemented in a few lines of code:

.. code-block:: python
   :linenos:
   :caption: Averaging of the gradients for synchronous data parallel training.
   :name: lst:allreduce_avg

   # reduce gradients
   for l_pa in io_model.parameters():
     torch.distributed.all_reduce( l_pa.grad.data, op = torch.distributed.ReduceOp.SUM )
     l_pa.grad.data = l_pa.grad.data / float(i_size_distributed)

.. admonition:: Optional Note

   The code in :numref:`lst:allreduce_avg` appears to be super simple.
   Just four lines of code, how bad can it be, right?
   Be aware that `efficient implementations `__ of the allreduce operation are very complex 😱.
   Our easy life is simply backed by very powerful libraries!
   Here, the MPI library used in PyTorch's backend contains carefully implemented and highly advanced algorithms for the allreduce operations.

.. admonition:: Tasks

   #. Add the allreduce in :numref:`lst:allreduce_avg` to your MLP's training procedure written in :numref:`ch:pytorch_mlp`.
   #. Adjust all other parts which are affected by the distribution, e.g., the computation of the training loss, the validation loss, or the accuracy.
      Make sure that the results of your data-parallel training match those of the sequential training.

      .. hint::

         Floating point math is not associative.
         This means that changing the summation order will most likely change the result!
         Thus, we have to accept some small differences in the solutions.
         Be on your guard though, it's easy to confuse serious bugs with inaccuracies originating from floating point math.

   #. Measure the speedup of your training when increasing the number of processes.
      Use one MPI process per NUMA domain and up to four nodes.

      .. hint::

         You may view the two NUMA domains of a node in the ``short`` queue by using the tool ``lscpu``.
         The following snippet allows us to efficiently use a single node through the OpenMPI option ``--bind-to socket`` and the environment variable ``OMP_NUM_THREADS``:

         .. code-block:: bash

            OMP_NUM_THREADS=24 mpiexec -n 2 --bind-to socket --report-bindings python mlp_fashion_mnist_distributed.py

Distributed Data Parallel
-------------------------
:numref:`ch:distributed_ml_data_parallel` went through the steps required for data-parallel training in PyTorch.
We have seen that a substantial amount of work is required to get the inputs and outputs into the right form.
Only the small piece of code in :numref:`lst:allreduce_avg` did the actual allreduces of the gradients and kept the replicas in sync.

In this section we'll extend our data-parallel training by using `torch.nn.parallel.DistributedDataParallel (DDT) `__.
DDT has two key features.
First, it takes care of the initial replication of the model by broadcasting the model's weights.
This step is required since the weights might be initialized randomly.
In :numref:`ch:distributed_ml_data_parallel` we worked around this by setting the same random seeds everywhere and hoping for the best 😇.
Second, DDT takes care of efficiently averaging the gradients for us.
Our c10d-based snippet was sufficient to get this done correctly, but better implementations are possible.
For example, DDT `supports `__ partial reductions of the first set of gradients while others are still being computed.
Such an approach allows us to hide communication behind computation, whereas our code in :numref:`lst:allreduce_avg` has a communication-only phase.

.. admonition:: Tasks

   #. Read the paper `PyTorch Distributed: Experiences on Accelerating Data Parallel Training `__.
      Explain the advantages and disadvantages of larger gradient buckets (Sections 3.2.2 and 3.2.3).
   #. Integrate DDT into :numref:`ch:distributed_ml_data_parallel`'s data-parallel version of the Multilayer Perceptron.
   #. Repeat the scaling experiment on four nodes.
      Did the time-to-solution of your DDT-enhanced version improve?
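
As a rough orientation for the two features described in this section, the following sketch wraps a model in ``torch.nn.parallel.DistributedDataParallel`` and performs a single training step.
The helper names ``build_ddp_model`` and ``train_step`` are hypothetical and not taken from the lab's code.
The sketch assumes CPU training and a process group which was initialized beforehand, e.g., with the MPI backend.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def build_ddp_model(model):
    """Wrap a model for data-parallel training.

    Assumes an initialized default process group. The constructor
    broadcasts rank 0's parameters so that all replicas start identical.
    """
    return DistributedDataParallel(model)


def train_step(ddp_model, optimizer, loss_fn, inputs, targets):
    """Run one optimizer step on this rank's part of the minibatch."""
    optimizer.zero_grad()
    loss = loss_fn(ddp_model(inputs), targets)
    # the wrapper averages the gradients over all ranks during backward,
    # so no manual allreduce over the parameters is needed here
    loss.backward()
    # every rank applies the same averaged gradients
    optimizer.step()
    return loss.item()
```

The optimizer is created from ``ddp_model.parameters()`` after wrapping, and the original module remains reachable through ``ddp_model.module``.
Note that the returned loss is still local to each rank, so reporting a global training loss requires an additional reduction.
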