7. Linear Layers on Steroids

Let’s quickly summarize the progress made in Section 6 before moving any further: First, we drafted the required forward-pass and backward-pass operations for a linear layer and confirmed that PyTorch’s built-in torch.nn.Linear computes the values we expected from our small pen & paper example. We then successfully wrote our own linear layer by extending torch.nn.Module and torch.autograd.Function in Python. Next, we lowered this to C/C++ through pybind, first relying on ATen and eventually writing the source code for our own matrix-matrix multiplication function. Yet, we still relied on either PyTorch’s Python bindings, ATen or the compiler to use or generate fast code for us. As a rule of thumb, this typically works rather well for the Python bindings and ATen. The compiler, however, we should not trust with this job.

This lab harnesses Just-In-Time (JIT) code generation for our custom layers. Specifically, we’ll generate small matrix-matrix multiplication kernels to formulate linear layers in this lab and two-dimensional convolutions in Section 8. Through this we take full control of all performance-relevant pieces! Once accomplished, PyTorch’s job is reduced to that of a mere frontend, while we are in control of the enabling backend. We’ll use the library LIBXSMM to JIT machine code for our custom layers. LIBXSMM supports JITting fast matrix multiplication kernels for the most common vector instruction sets, e.g., AVX2, AVX512, ASIMD and SVE. We are interested in AVX512 kernels since our tests run on Intel Xeon Gold 6248R (Cascade Lake) processors. However, the library is versatile enough that our code would run out-of-the-box on other processors, e.g., those from AMD, NVIDIA or Intel.

Note

Interested in how one would JIT efficient machine code for small matrix kernels? Maybe the class High Performance Computing is something for you 😇.

In addition to small matrix kernels, LIBXSMM is also able to generate code for other Tensor Processing Primitives (TPPs), which are required outside of linear layers and convolutions. One example is activation functions, where the Rectified Linear Unit (ReLU) is commonly used. We’ll use the ReLU unary TPP when discussing the fusion of operators.
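
For reference, the ReLU acts element-wise and is defined as \(\text{ReLU}(x) = \max(0, x)\), i.e., it simply zeroes all negative entries of its input.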

Our exploration of accelerated deep learning workloads through JITted TPPs starts where we left off: linear layers, which are also called fully connected layers. As discussed in Section 6, the forward- and backward-pass operations of linear layers are dominated by matrix-matrix products. The sizes of the involved matrices depend on the batch size, the number of input features, and the number of output features. In this section we assume a batch size of 256 and a linear layer with 4096 input features and 4096 output features. The linear layer is directly followed by a ReLU activation function. These sizes are chosen to resemble one of the final layers of the VGG architecture described in the work Very Deep Convolutional Networks for Large-Scale Image Recognition. A PyTorch implementation is available as part of the torchvision library, where our targeted linear layer is part of the network’s classifier.
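
For later reference when reporting the sustained GFLOPS, note that the forward pass of this linear layer performs roughly \(2 \cdot N \cdot C \cdot K = 2 \cdot 256 \cdot 4096 \cdot 4096 \approx 8.6 \cdot 10^9\) floating point operations; the subsequent ReLU adds a comparatively negligible amount of work.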

From a performance standpoint we’ll likely compete with highly optimized implementations for two reasons. First, fully connected layers are mainstream deep learning: library maintainers have a high interest in making them fast on every supported architecture. Second, midsize and large matrix-matrix products have been optimized for a very long time as part of libraries implementing the Basic Linear Algebra Subprograms (BLAS) in the computational sciences.

To ease our development efforts we’ll stay within C++ for the time being. Since most of PyTorch is built on top of the tensor library ATen, we’ll simply call ATen directly when comparing the sustained performance of our custom implementations to PyTorch. The playground mini_dnn forms the basis for our tests. mini_dnn brings supporting boilerplate and contains the ATen-based reference implementation mini_dnn::backend::MatmulReluAten which simply calls at::matmul and at::relu.

7.1. Blocked Matrix Multiplications

[Image: ../_images/blocking.svg]

Fig. 7.1.1 Illustration of our memory layout for a linear layer accelerated through small matrix kernels. Shown are the blocked input batch \(X_b \in \mathbb{R}^{N_b \times C_b \times b_c \times b_n}\) on the bottom-left, the blocked weight matrix \(W_b \in \mathbb{R}^{K_b \times C_b \times b_k \times b_c}\) on the top-right, and the blocked output batch \(Y_b \in \mathbb{R}^{K_b \times N_b \times b_k \times b_n}\) on the bottom-right. Specifically, the blocking uses the block sizes \(b_n = 3\), \(b_k = 5\), \(b_c = 4\) and the block counts \(N_b = 2\), \(K_b = 3\) and \(C_b = 2\). Note that the chosen sizes are illustrative, i.e., one would not block such small matrices in the first place.

We use the sizes \(N\), \(K\) and \(C\) to describe the matrix-matrix product \(Y = XW\). \(X \in \mathbb{R}^{N \times C}\) is the input batch, \(W \in \mathbb{R}^{C \times K}\) the weight matrix, and \(Y \in \mathbb{R}^{N \times K}\) the output batch. Therefore, \(N\) is the batch size, \(C\) the number of input channels, and \(K\) the number of output channels. Fig. 7.1.1 illustrates these sizes for \(N=6\), \(K=15\) and \(C=8\).

We assume that we may use any suitable memory layout for storing the data. Since we will utilize small matrix-matrix kernels as our building block, we block the matrices \(X\), \(W\) and \(Y\) accordingly. Specifically, we choose the block sizes \(b_n\), \(b_k\) and \(b_c\) for \(N\), \(K\) and \(C\), respectively. The illustrative example shown in Fig. 7.1.1 uses \(b_n = 3\), \(b_k = 5\), and \(b_c = 4\). With \(N_b\), \(K_b\) and \(C_b\) describing the number of blocks per dimension, we therefore obtain \(N = N_b \cdot b_n\), \(K = K_b \cdot b_k\) and \(C = C_b \cdot b_c\). Fig. 7.1.1’s example thus uses \(N_b = 2\), \(K_b = 3\) and \(C_b = 2\). We use rank-4 tensors to store our blocked matrices \(X_b\), \(W_b\) and \(Y_b\), i.e., \(X_b \in \mathbb{R}^{N_b \times C_b \times b_c \times b_n}\), \(W_b \in \mathbb{R}^{K_b \times C_b \times b_k \times b_c}\) and \(Y_b \in \mathbb{R}^{K_b \times N_b \times b_k \times b_n}\). Here, we assume that the rightmost dimension of every tensor has stride 1 and that the remaining dimensions follow contiguously. For example, for \(X_b \in \mathbb{R}^{N_b \times C_b \times b_c \times b_n}\) we obtain the strides \((C_b \cdot b_c \cdot b_n, \; b_c \cdot b_n, \; b_n, \; 1)\).
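
As a quick sanity check of this layout: with these strides, the element \(X_b[i_{N_b}][i_{C_b}][i_c][i_n]\) of the blocked input resides at the flat offset \(i_{N_b} \cdot C_b \cdot b_c \cdot b_n + i_{C_b} \cdot b_c \cdot b_n + i_c \cdot b_n + i_n\), where the four indices address the tensor’s dimensions from left to right.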

Given our blocked memory layout, the actual matrix multiplication is split into a series of small matrix multiplications. We use three loops over \(K_b\), \(N_b\) and \(C_b\) to implement the blocked matrix multiplication using small matrix kernels:

for( int64_t l_kb = 0; l_kb < l_sizes.kb; l_kb++ ) {
  for( int64_t l_nb = 0; l_nb < l_sizes.nb; l_nb++ ) {
    for( int64_t l_cb = 0; l_cb < l_sizes.cb; l_cb++ ) {
      // TODO: execute small matrix kernel
    }
  }
}

Fig. 7.1.2 shows the first small matrix kernel, i.e., l_kb=0, l_nb=0 and l_cb=0, for our running example illustrated in Fig. 7.1.1. The first matrix multiplication computes the partial result \(Y^*_b[0,0] := \text{Matmul} \left( X_b[0,0], W_b[0,0] \right)\).

The cb loop is our innermost one w.r.t. the blocking. This means that we add the contribution of l_kb=0, l_nb=0, l_cb=1 to the partial result \(Y^*_b[0,0]\) next. As illustrated in Fig. 7.1.3, we now complete \(Y_b\)’s respective block: \(Y_b[0,0] := Y^*_b[0,0] + \text{Matmul} \left( X_b[0,1], W_b[0,1] \right)\).

Next, we compute the partial result \(Y^*_b[0,1] := \text{Matmul} \left( X_b[1,0], W_b[0,0] \right)\) which corresponds to l_kb=0, l_nb=1 and l_cb=0. This situation is illustrated in Fig. 7.1.4. The remaining small matrix products follow in a similar fashion.

[Image: ../_images/gemm_000.svg]

Fig. 7.1.2 Illustration of the first executed small matrix multiplication kernel when computing the blocked matrix multiplication for the blocking shown in Fig. 7.1.1.

[Image: ../_images/gemm_001.svg]

Fig. 7.1.3 Illustration of the second executed small matrix multiplication kernel when computing the blocked matrix multiplication for the blocking shown in Fig. 7.1.1.

[Image: ../_images/gemm_100.svg]

Fig. 7.1.4 Illustration of the third executed small matrix multiplication kernel when computing the blocked matrix multiplication for the blocking shown in Fig. 7.1.1.

Tasks

  1. Block the input batch \(X\) and the weight matrix \(W\) in the unit tests, i.e., in the file MatmulAtenBlocked.test.cpp of mini_dnn’s unfinished mini_dnn::backend::MatmulAtenBlocked implementation. Double-check that the sizes and strides of the obtained blocked matrices match your expectations.

  2. Pass the blocked matrices to mini_dnn::backend::MatmulAtenBlocked::forward and finish the forward-pass operation by calling ATen for every submatrix, i.e., you have to call at::matmul \(K_b \cdot N_b \cdot C_b\) times. Return the blocked output batch.

  3. In the unit tests, bring the blocked output back to the layout used before you blocked the matrices. Make sure that your computed result matches ATen’s unblocked version.

Hint

The blocking can be done by first creating a view which splits every dimension. For example, for the input \(X \in \mathbb{R}^{N \times C}\) , we would first create the view \(X_v \in \mathbb{R}^{N_b \times b_n \times C_b \times b_c}\). Next, we can reorder the dimensions using at::permute and finally call at::contiguous to obtain our desired memory layout.
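
A minimal sketch of this blocking for the input batch is given below. It uses illustrative sizes and ATen’s member functions; the variable names are placeholders and not part of mini_dnn.

#include <ATen/ATen.h>

void example_block_x() {
  int64_t l_n  = 128; // N:   batch size
  int64_t l_c  = 512; // C:   input features
  int64_t l_bn =  64; // b_n: blocking of N
  int64_t l_bc = 128; // b_c: blocking of C

  at::Tensor l_x = at::rand( {l_n, l_c} );

  // split every dimension: N x C  ->  N_b x b_n x C_b x b_c
  at::Tensor l_x_blocked = l_x.view( {l_n/l_bn, l_bn, l_c/l_bc, l_bc} );

  // reorder to N_b x C_b x b_c x b_n and enforce a contiguous memory layout
  l_x_blocked = l_x_blocked.permute( {0, 2, 3, 1} ).contiguous();
}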

Hint

ATen assumes a row-major ordering of the involved matrices when calling at::matmul. This means that if we’d like to compute \(C = AB\) for matrices \(A\) and \(B\) stored row-major in l_a and l_b, we’d write something similar to at::Tensor l_c = at::matmul(l_a, l_b);. Now, assume that l_a and l_b hold the matrices in column-major format, as is the case for our blocked layout shown in Fig. 7.1.1. In this case we simply compute the “transposed” product, i.e., \(C^T = B^T A^T\), through at::Tensor l_c = at::matmul(l_b, l_a);.
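
Putting the blocked layout and this “transposed” product together, one possible body for the blocking loops of the forward pass is sketched below. The names l_x, l_w, l_y and l_sizes are placeholders for the blocked tensors and the block counts, and l_y is assumed to be initialized with zeros; from ATen’s row-major point of view, a block of \(W_b\) is a \(b_k \times b_c\) matrix and a block of \(X_b\) a \(b_c \times b_n\) matrix, such that at::matmul directly yields the \(b_k \times b_n\) block of \(Y_b\).

for( int64_t l_kb = 0; l_kb < l_sizes.kb; l_kb++ ) {
  for( int64_t l_nb = 0; l_nb < l_sizes.nb; l_nb++ ) {
    for( int64_t l_cb = 0; l_cb < l_sizes.cb; l_cb++ ) {
      // “transposed” small matrix product on the blocks:
      // Y_b[kb][nb] += W_b[kb][cb] * X_b[nb][cb]
      l_y[l_kb][l_nb] += at::matmul( l_w[l_kb][l_cb],
                                     l_x[l_nb][l_cb] );
    }
  }
}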

7.2. Small JITted GEMMs

The blocking boilerplate is in place and tested using calls to ATen. Now, we replace the calls to ATen with JITted small matrix kernels. Once again, to accelerate your developments, a template is provided by mini_dnn’s unfinished class mini_dnn::backend::MatmulLibxsmm. Your job is to finish the implementation and measure its performance.
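
To give a rough idea of how such a kernel is generated and invoked, the sketch below uses LIBXSMM’s C++ wrapper libxsmm_mmfunction; the exact interface may differ between LIBXSMM versions, and the block sizes and pointer names are illustrative only. LIBXSMM follows the column-major convention of BLAS, which matches our blocked layout: the small GEMM computes a \(b_n \times b_k\) block of \(Y_b\) from a \(b_n \times b_c\) block of \(X_b\) and a \(b_c \times b_k\) block of \(W_b\).

#include <libxsmm.h>

void example_kernel_call( float const * i_block_x,
                          float const * i_block_w,
                          float       * io_block_y ) {
  libxsmm_blasint l_bn =  64; // m: b_n
  libxsmm_blasint l_bk =  32; // n: b_k
  libxsmm_blasint l_bc = 128; // k: b_c

  // JIT a single-precision kernel; with LIBXSMM's defaults (alpha = 1, beta = 1)
  // the kernel accumulates into the output block
  libxsmm_mmfunction< float > l_kernel( l_bn, l_bk, l_bc );

  if( l_kernel ) {
    l_kernel( i_block_x,    // A: block of X_b
              i_block_w,    // B: block of W_b
              io_block_y ); // C: block of Y_b
  }
}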

Further, OpenMP allows us to harness all available cores of a processor. We will parallelize the two outermost blocking loops, i.e., those over \(K_b\) and \(N_b\), by using the collapse clause, which is defined in Ch. 2.9.2 of the OpenMP Application Programming Interface. Be aware that we keep the \(C_b\) loop sequential w.r.t. OpenMP since its iterations accumulate the results of the small matrix kernels into the same block of \(Y_b\). This means that a read-after-write dependence exists which prohibits a straightforward parallelization of this loop.
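
A possible structure of the parallelized loop nest is sketched below, assuming a previously JITted kernel l_kernel, raw pointers l_x_ptr, l_w_ptr and l_y_ptr to the blocked tensors, and the block sizes l_bn, l_bk and l_bc; all names are illustrative and the pointer offsets follow the strides of Section 7.1.

#pragma omp parallel for collapse(2)
for( int64_t l_kb = 0; l_kb < l_sizes.kb; l_kb++ ) {
  for( int64_t l_nb = 0; l_nb < l_sizes.nb; l_nb++ ) {
    // the loop over C_b stays sequential: its iterations accumulate
    // into the same b_k x b_n block of Y_b
    for( int64_t l_cb = 0; l_cb < l_sizes.cb; l_cb++ ) {
      l_kernel( l_x_ptr + (l_nb*l_sizes.cb + l_cb) * l_bc * l_bn,
                l_w_ptr + (l_kb*l_sizes.cb + l_cb) * l_bk * l_bc,
                l_y_ptr + (l_kb*l_sizes.nb + l_nb) * l_bk * l_bn );
    }
  }
}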

Hint

Use OMP_PLACES={0}:48:1 OMP_NUM_THREADS=48 to spawn 48 threads and pin each of them to a single core of an Intel Xeon Gold 6248R processor. You may display the pinning by additionally setting OMP_DISPLAY_AFFINITY=true.

Tasks

  1. Finish the implementation of mini_dnn::backend::MatmulLibxsmm::forward by following the structure of your mini_dnn::backend::MatmulAtenBlocked::forward function. However, call LIBXSMM-generated kernels instead of ATen.

  2. Test your implementation thoroughly through unit tests. Make sure to include non-square matrices. For example, use \(N=128\) as your batch size, \(K=256\) output features and \(C=512\) input features. A corresponding blocking might set \(b_n = 64\), \(b_k=32\) and \(b_c=128\).

  3. Benchmark the performance of your implementation on a single core. Mimic a linear layer of the VGG architecture by using a batch size of 256, and 4,096 input and output features in your performance tests. Report the sustained GFLOPS.

  4. Add an OpenMP parallelization to your implementation. Benchmark and report the performance when running on all cores of a single node.

7.3. Operator Fusion

Feed-forward neural networks are typically composed of a series of layers. This composition is often static and might be formulated as a computational graph. When the graph is available to us in the backend, we may harness this information to make the entire workload fast, not just an isolated layer. In fact, we implicitly assumed such a situation already by employing a highly customized data layout for the inputs and outputs of our linear layers. The data layout greatly eased our job when formulating the blocked matrix multiplication. However, such an approach is only viable if we are in charge of the procedure which passes data from one layer to the next. Otherwise, the conversion overhead w.r.t. our custom data layout would most likely eliminate the performance gains made elsewhere.

Note

Let’s briefly change our viewpoint to the frontend. PyTorch and TensorFlow both support the so-called eager mode. While this was a somewhat unique feature of PyTorch in the past, eager execution became the default in TensorFlow 2.0. Eager mode means that operations are executed immediately. This is great for prototyping and debugging but hurts overall performance. In contrast, when running in graph mode, one assembles a computational graph first and then runs the operations collectively. Additional information on the topic is available from a blog post by Google AI and a PyTorch tutorial.

In this last part of the lab we assume that our linear layer is directly followed by a ReLU. A naive approach would compute the result of the linear layer first and then apply the ReLU element-wise to the output. If we assume sufficiently large inputs or weights, the linear layer’s result only fits in a high cache level or main memory. This means that for applying the ReLU we have to load all data back into the registers of our CPU, apply the ReLU, and then write the results back. Since there’s barely any computational work involved in applying a ReLU, the resulting procedure will be heavily bound by the bandwidth of the memory level in which the data resides.
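
To put this into perspective for our layer sizes: in single precision the linear layer’s output alone occupies \(256 \cdot 4096 \cdot 4\,\text{B} = 4\,\text{MiB}\), which is far larger than a core’s L1 data cache. A standalone ReLU over this tensor therefore streams every element from a lower memory level and back.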

A better approach applies the ReLU directly to the linear layer’s results while they are still in the L1 cache. The innermost loop over \(C_b\) of our blocked approach simply adds the intermediate results of the small GEMMs to a fixed block of the result tensor \(Y_b\). Thus, if we operate on sufficiently small blocks, this block remains in a core’s L1 data cache for the entire duration of the \(C_b\) loop. After the loop has finished, the respective block of \(Y_b\) is completely computed and we could move on to the next iteration of the \(N_b\) loop, which computes a different block. However, instead of continuing with the next small matrix multiplications, we first apply the ReLU to the just-computed block while the data is still hot. Such an approach is called operator fusion and, in our case, harnesses the high L1 bandwidth for the ReLU.
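
The sketch below outlines this fusion, reusing the illustrative names of Section 7.2. The scalar loop at the end only marks the spot where the JITted LIBXSMM ReLU TPP would be applied to the still-cached block of \(Y_b\).

#pragma omp parallel for collapse(2)
for( int64_t l_kb = 0; l_kb < l_sizes.kb; l_kb++ ) {
  for( int64_t l_nb = 0; l_nb < l_sizes.nb; l_nb++ ) {
    float * l_block_y = l_y_ptr + (l_kb*l_sizes.nb + l_nb) * l_bk * l_bn;

    // accumulate all C_b contributions into the block of Y_b
    for( int64_t l_cb = 0; l_cb < l_sizes.cb; l_cb++ ) {
      l_kernel( l_x_ptr + (l_nb*l_sizes.cb + l_cb) * l_bc * l_bn,
                l_w_ptr + (l_kb*l_sizes.cb + l_cb) * l_bk * l_bc,
                l_block_y );
    }

    // fused activation: apply the ReLU while the block is still hot in cache;
    // in the actual implementation this would be a JITted unary TPP rather than a scalar loop
    for( int64_t l_en = 0; l_en < l_bk * l_bn; l_en++ ) {
      if( l_block_y[l_en] < 0.0f ) l_block_y[l_en] = 0.0f;
    }
  }
}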

Tasks

  1. Implement the class mini_dnn::backend::ReluLibxsmm which simply applies the ReLU to the input tensor in the forward function. Generate a ReLU kernel through LIBXSMM and use the same blocking as before. Verify and benchmark your implementation!

  2. Benchmark the performance of computing a linear layer followed by a ReLU through the sequential application of your mini_dnn::backend::MatmulLibxsmm::forward and mini_dnn::backend::ReluLibxsmm::forward functions. Use a batch size of 256, and 4,096 input and output features in your performance tests.

  3. Implement the class mini_dnn::backend::MatmulReluLibxsmm which fuses the application of the ReLU into the blocking loops as discussed above. Verify and benchmark your implementation! Use a batch size of 256, and 4,096 input and output features in your performance tests.