7. Linear Layers on Steroids
Let’s briefly summarize the progress we made in Section 6 before moving on:
First, we implemented the necessary forward and backward pass operations for a linear layer. PyTorch’s built-in torch.nn.Linear computes the values we expected from our little pen-and-paper example. We then successfully wrote our own linear layer by extending torch.nn.Module and torch.autograd.Function in Python. Next, we lowered our implementation to C/C++ via pybind, where we first used ATen and eventually wrote source code for our own matrix-matrix multiplication function. However, we still relied on PyTorch’s Python bindings, ATen, or the compiler to call or generate fast code for us. As a rule of thumb, this works pretty well for the Python bindings and ATen, but we should not trust the compiler with this job.
This lab follows two closely related approaches to formulating linear layers that give us full control over all performance-relevant parts! Once this is done, PyTorch’s role is reduced to that of a mere frontend, while we control the backend that does the heavy lifting. First, we’ll use a hand-written assembly kernel for our custom layers. Specifically, we’ll use a fixed-size small matrix-matrix multiplication kernel. The kernel issues Neon vector instructions and is therefore limited to Arm processors. We are interested in Neon kernels because our tests run on NVIDIA Grace processors (Arm Neoverse V2 cores).
In a second step, we’ll use just-in-time (JIT) code generation for the kernels in our custom layers. This overcomes two limitations: 1) we can run our code on different processors with different vector extensions, and 2) we can generate kernels with different matrix shapes as needed by our custom layers. We’ll use the LIBXSMM library to generate machine code for our custom layers. LIBXSMM supports JITting fast matrix multiplication kernels for the most common vector instruction sets, e.g., AVX2, AVX-512, Neon, and SVE. In this lab, we are only running on NVIDIA Grace, but the library is versatile enough that our code would run out of the box on other processors, such as those from AMD, Apple, Intel, or Qualcomm.
Note
Interested in how to JIT efficient machine code for small matrix kernels? Maybe the High Performance Computing class is for you 😇.
In addition to small matrix kernels, LIBXSMM can generate code for other Tensor Processing Primitives (TPPs) that are needed outside of linear layers or convolutions. Activation functions are one example, with the Rectified Linear Unit (ReLU) being a common choice. We’ll use the ReLU unary TPP when discussing operator fusion.
Our exploration of accelerated deep learning workloads through a hand-written kernel and JITted TPPs begins where we left off: linear layers, also known as fully connected layers. As discussed in Section 6, the forward and backward pass operations of linear layers are dominated by matrix-matrix products. The sizes of the matrices involved depend on the batch size used, the number of input features, and the number of output features. In this section, we assume a batch size of 256 and a linear layer with 4096 input features and 4096 output features. The linear layer is immediately followed by a ReLU activation function. These specifications are chosen to resemble one of the final layers of the VGG architecture described in the paper Very Deep Convolutional Networks for Large-Scale Image Recognition. A PyTorch implementation is available as part of the torchvision library, where our desired linear layer is part of the network’s classifier.
From a performance perspective, we’re likely competing against highly optimized implementations, for two reasons. First, fully connected layers are a mainstream deep learning building block, so library maintainers likely have a strong interest in making them fast on every supported architecture. Second, medium- and large-scale matrix-matrix products have long been optimized in libraries that implement the Basic Linear Algebra Subprograms (BLAS) for the computational sciences.
To simplify our development efforts, we’ll stay within C++ for the time being. Since most of PyTorch is built on top of the ATen tensor library, we’ll just call ATen directly when comparing the sustained performance of our custom implementations to PyTorch. The mini_dnn playground forms the basis for our tests. mini_dnn provides supporting boilerplate and contains the ATen-based reference implementation mini_dnn::backend::MatmulReluAten, which simply calls at::matmul and at::relu.
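For readers without a libtorch build at hand, the following plain C++ sketch shows what the reference computes. The function name matmul_relu_ref and the row-major std::vector storage are illustrative assumptions, not part of mini_dnn. The sketch performs the matrix multiplication \(Y = XW\) first and then applies the ReLU in a separate element-wise pass, which mirrors calling at::matmul followed by at::relu.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical dependency-free stand-in for the reference computation
// Y = ReLU(X W) with row-major X (N x C), W (C x K) and Y (N x K).
std::vector<float> matmul_relu_ref(const std::vector<float>& x,
                                   const std::vector<float>& w,
                                   int64_t n, int64_t k, int64_t c) {
  std::vector<float> y(n * k, 0.0f);
  // Matrix multiplication: Y[i][j] += X[i][l] * W[l][j].
  for (int64_t i = 0; i < n; i++) {
    for (int64_t l = 0; l < c; l++) {
      const float x_il = x[i * c + l];
      for (int64_t j = 0; j < k; j++) {
        y[i * k + j] += x_il * w[l * k + j];
      }
    }
  }
  // Element-wise ReLU as a separate pass over Y.
  for (float& v : y) v = std::max(v, 0.0f);
  return y;
}
```

For example, \(X = [[1, 2], [3, 4]]\) and \(W = [[1, -2], [0, 1]]\) give \(XW = [[1, 0], [3, -2]]\), so the result is \([[1, 0], [3, 0]]\).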
7.1. Blocked Matrix Multiplications
We use the sizes \(N\), \(K\) and \(C\) to describe the matrix-matrix product \(Y = XW\). \(X \in \mathbb{R}^{N \times C}\) is the input batch, \(W \in \mathbb{R}^{C \times K}\) the weight matrix, and \(Y \in \mathbb{R}^{N \times K}\) the output batch. Therefore, \(N\) is the batch size, \(C\) the number of input channels, and \(K\) the number of output channels. Fig. 7.1.1 illustrates these sizes for \(N=6\), \(K=15\) and \(C=8\).
We assume that we can use any suitable memory layout to store the data. Since we will be using small matrix-matrix kernels as our building block, we will block the matrices \(X\), \(W\), and \(Y\) accordingly. Specifically, we choose block sizes \(b_n\), \(b_k\) and \(b_c\) for \(N\), \(K\) and \(C\), respectively. The illustrative example in Fig. 7.1.1 uses \(b_n = 3\), \(b_k = 5\) and \(b_c = 4\). With \(N_b\), \(K_b\) and \(C_b\) describing the number of blocks per dimension, we get \(N = N_b \cdot b_n\), \(K = K_b \cdot b_k\) and \(C = C_b \cdot b_c\). The example in Fig. 7.1.1 thus uses \(N_b = 2\), \(K_b = 3\), and \(C_b = 2\). We use rank-4 tensors to store our blocked matrices \(X_b\), \(W_b\) and \(Y_b\), i.e., \(X_b \in \mathbb{R}^{N_b \times C_b \times b_c \times b_n}\), \(W_b \in \mathbb{R}^{K_b \times C_b \times b_k \times b_c}\) and \(Y_b \in \mathbb{R}^{K_b \times N_b \times b_k \times b_n}\). Here we assume that the rightmost dimension of each tensor has stride 1 and that the remaining dimensions follow contiguously from right to left. For example, for \(X_b \in \mathbb{R}^{N_b \times C_b \times b_c \times b_n}\) we get the strides \((C_b \cdot b_c \cdot b_n, \; b_c \cdot b_n, \; b_n, \; 1)\). Consider the entry of \(X\) in row \(n_1 b_n + n_0\) and column \(c_1 b_c + c_0\), with \(0 \le n_0 < b_n\) and \(0 \le c_0 < b_c\). It is stored at \(X_b[n_1, c_1, c_0, n_0]\), so each block \(X_b[n_1, c_1]\) holds a \(b_n \times b_c\) submatrix of \(X\) in column-major order.
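The stride rule applies to all three blocked tensors. A minimal sketch with hypothetical helper names computes the strides of a contiguous rank-4 tensor and the linear offset of one of its entries. For the running example, \(X_b\) has sizes \((2, 2, 4, 3)\) and thus strides \((24, 12, 3, 1)\). \(W_b\) has sizes \((3, 2, 5, 4)\) and strides \((40, 20, 4, 1)\).

```cpp
#include <array>
#include <cstdint>

// Strides of a contiguous rank-4 tensor whose rightmost dimension has stride 1,
// e.g., X_b with sizes (N_b, C_b, b_c, b_n).
std::array<int64_t, 4> contiguous_strides(const std::array<int64_t, 4>& sizes) {
  std::array<int64_t, 4> strides{};
  strides[3] = 1;
  for (int d = 2; d >= 0; d--) {
    strides[d] = strides[d + 1] * sizes[d + 1];
  }
  return strides;
}

// Linear offset of element (i0, i1, i2, i3) for the given strides.
int64_t linear_offset(const std::array<int64_t, 4>& strides,
                      int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
  return i0 * strides[0] + i1 * strides[1] + i2 * strides[2] + i3 * strides[3];
}
```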
Given our blocked memory layout, the actual matrix multiplication is broken into a series of small matrix multiplications. We use three loops over \(K_b\), \(N_b\) and \(C_b\) to implement the blocked matrix multiplication using small matrix kernels:
for( int64_t l_kb = 0; l_kb < l_sizes.kb; l_kb++ ) {
  for( int64_t l_nb = 0; l_nb < l_sizes.nb; l_nb++ ) {
    for( int64_t l_cb = 0; l_cb < l_sizes.cb; l_cb++ ) {
      // TODO: execute small matrix kernel, i.e.,
      //       Y_b[l_kb][l_nb] += X_b[l_nb][l_cb] * W_b[l_kb][l_cb]
    }
  }
}
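One possible way to fill in the TODO is sketched below in plain C++. It assumes float buffers in the blocked layouts above and uses a naive loop-based small kernel. The names Sizes, small_gemm and blocked_matmul are hypothetical; the struct mirrors l_sizes from the snippet. In the lab, the call to small_gemm is where ATen, the assembly kernel, or a LIBXSMM kernel would go. Each block is treated as a column-major matrix: X_b blocks are \(b_n \times b_c\), W_b blocks are \(b_c \times b_k\), and Y_b blocks are \(b_n \times b_k\).

```cpp
#include <cstdint>
#include <vector>

// Number of blocks and block sizes per dimension (hypothetical struct).
struct Sizes {
  int64_t nb, kb, cb;  // N_b, K_b, C_b
  int64_t bn, bk, bc;  // b_n, b_k, b_c
};

// Naive small kernel: C += A * B with column-major A (m x k), B (k x n), C (m x n).
void small_gemm(const float* a, const float* b, float* c,
                int64_t m, int64_t n, int64_t k) {
  for (int64_t j = 0; j < n; j++)
    for (int64_t l = 0; l < k; l++)
      for (int64_t i = 0; i < m; i++)
        c[j * m + i] += a[l * m + i] * b[j * k + l];
}

// Blocked Y_b = X_b W_b with X_b: (N_b, C_b, b_c, b_n), W_b: (K_b, C_b, b_k, b_c)
// and Y_b: (K_b, N_b, b_k, b_n), all contiguous with stride 1 on the right.
void blocked_matmul(const std::vector<float>& x_b, const std::vector<float>& w_b,
                    std::vector<float>& y_b, const Sizes& s) {
  const int64_t x_blk = s.bc * s.bn;  // elements per X_b block
  const int64_t w_blk = s.bk * s.bc;  // elements per W_b block
  const int64_t y_blk = s.bk * s.bn;  // elements per Y_b block
  y_b.assign(s.kb * s.nb * y_blk, 0.0f);
  for (int64_t l_kb = 0; l_kb < s.kb; l_kb++) {
    for (int64_t l_nb = 0; l_nb < s.nb; l_nb++) {
      float* y = y_b.data() + (l_kb * s.nb + l_nb) * y_blk;
      for (int64_t l_cb = 0; l_cb < s.cb; l_cb++) {
        const float* x = x_b.data() + (l_nb * s.cb + l_cb) * x_blk;
        const float* w = w_b.data() + (l_kb * s.cb + l_cb) * w_blk;
        small_gemm(x, w, y, s.bn, s.bk, s.bc);  // Y_b[kb,nb] += X_b[nb,cb] W_b[kb,cb]
      }
    }
  }
}
```

Zeroing Y_b up front lets the first \(C_b\) iteration use the same accumulate operation as all later ones.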
Fig. 7.1.2 shows the first small matrix kernel, i.e., l_kb=0, l_nb=0 and l_cb=0, for our running example shown in Fig. 7.1.1. The first matrix multiplication computes the partial result \(Y^*_b[0,0] := \text{Matmul} \left( X_b[0,0], W_b[0,0] \right)\). The cb loop is our innermost loop in terms of the blocking. This means that we next add the contribution of l_kb=0, l_nb=0, l_cb=1 to the partial result \(Y^*_b[0,0]\). As shown in Fig. 7.1.3, we now complete the corresponding block of \(Y_b\): \(Y_b[0,0] := Y^*_b[0,0] + \text{Matmul} \left( X_b[0,1], W_b[0,1] \right)\). Next, we compute the partial result \(Y^*_b[0,1] := \text{Matmul} \left( X_b[1,0], W_b[0,0] \right)\), which corresponds to l_kb=0, l_nb=1 and l_cb=0. This situation is illustrated in Fig. 7.1.4. The remaining small matrix products follow in a similar way.
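The order in which the loop nest visits the small matrix products can be listed with a few lines of C++. This sketch only records the (l_kb, l_nb, l_cb) triples and prints the corresponding block operation. The helper names visit_order and print_order are hypothetical.

```cpp
#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

// Records the (l_kb, l_nb, l_cb) triples in the order visited by the loop nest.
std::vector<std::array<int64_t, 3>> visit_order(int64_t kb, int64_t nb, int64_t cb) {
  std::vector<std::array<int64_t, 3>> order;
  for (int64_t l_kb = 0; l_kb < kb; l_kb++)
    for (int64_t l_nb = 0; l_nb < nb; l_nb++)
      for (int64_t l_cb = 0; l_cb < cb; l_cb++)
        order.push_back({l_kb, l_nb, l_cb});
  return order;
}

// Prints each step as Y_b[kb,nb] += Matmul(X_b[nb,cb], W_b[kb,cb]).
void print_order(int64_t kb, int64_t nb, int64_t cb) {
  for (const auto& t : visit_order(kb, nb, cb))
    std::printf("Y_b[%lld,%lld] += Matmul(X_b[%lld,%lld], W_b[%lld,%lld])\n",
                (long long)t[0], (long long)t[1], (long long)t[1], (long long)t[2],
                (long long)t[0], (long long)t[2]);
}
```

For the running example, print_order(3, 2, 2) prints twelve lines. The first three correspond to Figs. 7.1.2 to 7.1.4, where the first accumulation into a zero-initialized block is the ":=" step:

Y_b[0,0] += Matmul(X_b[0,0], W_b[0,0])
Y_b[0,0] += Matmul(X_b[0,1], W_b[0,1])
Y_b[0,1] += Matmul(X_b[1,0], W_b[0,0])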
Tasks
1. Block the input batch \(X\) and the weight matrix \(W\) in the unit tests, i.e., in the file MatmulAtenBlocked.test.cpp of mini_dnn’s unfinished mini_dnn::backend::MatmulAtenBlocked implementation. Check that the sizes and strides of the resulting blocked matrices match your expectations.
2. Pass the blocked matrices to mini_dnn::backend::MatmulAtenBlocked::forward and finish the forward pass operation by calling ATen for each submatrix, i.e., you have to call at::matmul \(K_b \cdot N_b \cdot C_b\) times. Return the blocked output batch.
3. In the unit tests, bring the blocked layout back to the one before you blocked the matrices. Make sure your computed result matches the unblocked version of ATen.
Hint
Blocking can be done by first creating a view that splits each dimension. For example, for the input \(X \in \mathbb{R}^{N \times C}\), we would first create the view \(X_v \in \mathbb{R}^{N_b \times b_n \times C_b \times b_c}\). Next, we can reorder the dimensions using at::permute and finally call the tensor’s contiguous() method to get our desired memory layout.
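To see what the view, permute and contiguous sequence does to the data, the sketch below emulates it on a flat buffer in plain C++. It treats the row-major \(X\) as a rank-4 view with sizes \((N_b, b_n, C_b, b_c)\) and reorders the dimensions to \((N_b, C_b, b_c, b_n)\), which corresponds to the permutation {0, 2, 3, 1} in at::permute. It then copies the result into contiguous storage. The function permute_contiguous is a hypothetical helper, not an ATen API.

```cpp
#include <array>
#include <cstdint>
#include <vector>

// Emulates view(sizes) -> permute(dims) -> contiguous() on a row-major buffer:
// output dimension d corresponds to input dimension dims[d].
std::vector<float> permute_contiguous(const std::vector<float>& in,
                                      const std::array<int64_t, 4>& sizes,
                                      const std::array<int, 4>& dims) {
  // Strides of the contiguous input view.
  std::array<int64_t, 4> in_strides{};
  in_strides[3] = 1;
  for (int d = 2; d >= 0; d--) in_strides[d] = in_strides[d + 1] * sizes[d + 1];

  // Sizes of the permuted view and the input strides it walks with.
  std::array<int64_t, 4> out_sizes{}, perm_strides{};
  for (int d = 0; d < 4; d++) {
    out_sizes[d] = sizes[dims[d]];
    perm_strides[d] = in_strides[dims[d]];
  }

  // Traverse the permuted view in row-major order and write contiguously.
  std::vector<float> out;
  out.reserve(in.size());
  for (int64_t i0 = 0; i0 < out_sizes[0]; i0++)
    for (int64_t i1 = 0; i1 < out_sizes[1]; i1++)
      for (int64_t i2 = 0; i2 < out_sizes[2]; i2++)
        for (int64_t i3 = 0; i3 < out_sizes[3]; i3++)
          out.push_back(in[i0 * perm_strides[0] + i1 * perm_strides[1] +
                           i2 * perm_strides[2] + i3 * perm_strides[3]]);
  return out;
}
```

For the running example, permute_contiguous(x, {2, 3, 2, 4}, {0, 2, 3, 1}) blocks \(X\). By the same reasoning, \(W \in \mathbb{R}^{C \times K}\) would be viewed as \((C_b, b_c, K_b, b_k)\) and permuted with {2, 0, 3, 1}. Applying {1, 3, 0, 2} to \(Y_b\) yields \((N_b, b_n, K_b, b_k)\), whose contiguous storage is the row-major \(Y\).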
Hint
ATen assumes row-major order of the matrices involved when calling at::matmul. That is, if we want to compute \(C = AB\) for the matrices \(A\) and \(B\) stored in row-major order in l_a and l_b, we’d implement something similar to at::Tensor l_c = at::matmul(l_a, l_b). Now let us assume that l_a and l_b hold the matrices in column-major format, as is the case for our blocked layout shown in Fig. 7.1.1. Read as row-major, a column-major matrix is simply its transpose, i.e., l_a holds \(A^T\) and l_b holds \(B^T\). In this case, we compute the “transposed” product \(C^T = B^T A^T\) by at::Tensor l_c = at::matmul(l_b, l_a). The row-major result \(C^T\) is exactly \(C\) stored in column-major format.
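The transposition argument can be checked without ATen. The sketch below defines a hypothetical row-major routine matmul_row_major that stands in for at::matmul. It then feeds column-major buffers to that routine with the arguments swapped, which yields \(C\) in column-major order.

```cpp
#include <cstdint>
#include <vector>

// Row-major C (m x n) = A (m x k) * B (k x n); stands in for at::matmul.
std::vector<float> matmul_row_major(const std::vector<float>& a,
                                    const std::vector<float>& b,
                                    int64_t m, int64_t n, int64_t k) {
  std::vector<float> c(m * n, 0.0f);
  for (int64_t i = 0; i < m; i++)
    for (int64_t l = 0; l < k; l++)
      for (int64_t j = 0; j < n; j++)
        c[i * n + j] += a[i * k + l] * b[l * n + j];
  return c;
}

// Column-major C (m x n) = A (m x k) * B (k x n) via C^T = B^T A^T.
// Read as row-major, the column-major buffer of B is B^T (n x k) and that of A is A^T (k x m).
std::vector<float> matmul_col_major(const std::vector<float>& a_cm,
                                    const std::vector<float>& b_cm,
                                    int64_t m, int64_t n, int64_t k) {
  // B^T (n x k) times A^T (k x m) gives C^T (n x m) in row-major order,
  // which is C (m x n) in column-major order.
  return matmul_row_major(b_cm, a_cm, n, m, k);
}
```

For \(A = [[1, 2, 3], [4, 5, 6]]\) and \(B = [[7, 8], [9, 10], [11, 12]]\), the column-major inputs {1, 4, 2, 5, 3, 6} and {7, 9, 11, 8, 10, 12} produce {58, 139, 64, 154}. This is \(C = [[58, 64], [139, 154]]\) in column-major order.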
7.2. Assembly Kernel
The blocking boilerplate is in place and tested with calls to ATen. Now we replace the ATen calls with a hand-written small matrix kernel. The kernel performs the operation \(C \mathrel{+}= AB\) on \(64 \times 64\) matrices, i.e., it corresponds to the block sizes \(b_n = b_k = b_c = 64\) in our notation. You can find it in the gemm_neon.tar.xz archive, which also contains a small example to illustrate its use. Again, to speed up your development, mini_dnn provides a template in its unfinished class mini_dnn::backend::MatmulAsmNeon. Your job is to finish the implementation and measure its performance.
Furthermore, OpenMP allows us to use all available cores of a processor. We will parallelize the two outermost blocking loops, i.e., those over \(K_b\) and \(N_b\), using the collapse clause defined in Sec. 2.9.2 of the OpenMP Application Programming Interface. Note that we keep the \(C_b\) loop sequential with respect to OpenMP, since all of its iterations accumulate the results of the small matrix kernels into the same block of \(Y_b\). This read-after-write dependency prevents a straightforward parallelization of the \(C_b\) loop.
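A sketch of the parallel loop nest follows. The exact signature of gemm_asm_asimd_64_64_64 is defined in the provided archive and is not reproduced here. Instead, kernel_64_ref is a hypothetical portable stand-in with the same semantics (\(C \mathrel{+}= AB\) on \(64 \times 64\) column-major blocks). If the code is compiled without OpenMP support, the pragma is ignored and the loops run sequentially.

```cpp
#include <cstdint>
#include <vector>

constexpr int64_t kB = 64;  // b_n = b_k = b_c = 64

// Hypothetical stand-in for the hand-written Neon kernel:
// C += A B on column-major 64 x 64 blocks.
void kernel_64_ref(const float* a, const float* b, float* c) {
  for (int64_t j = 0; j < kB; j++)
    for (int64_t l = 0; l < kB; l++)
      for (int64_t i = 0; i < kB; i++)
        c[j * kB + i] += a[l * kB + i] * b[j * kB + l];
}

// Blocked forward pass with X_b: (N_b, C_b, 64, 64), W_b: (K_b, C_b, 64, 64)
// and Y_b: (K_b, N_b, 64, 64). The K_b and N_b loops are distributed over threads.
// The C_b loop stays sequential since all its iterations update the same Y_b block.
void forward_omp(const float* x_b, const float* w_b, float* y_b,
                 int64_t nb, int64_t kb, int64_t cb) {
  const int64_t blk = kB * kB;
#pragma omp parallel for collapse(2)
  for (int64_t l_kb = 0; l_kb < kb; l_kb++) {
    for (int64_t l_nb = 0; l_nb < nb; l_nb++) {
      float* y = y_b + (l_kb * nb + l_nb) * blk;
      for (int64_t i = 0; i < blk; i++) y[i] = 0.0f;  // zero this output block
      for (int64_t l_cb = 0; l_cb < cb; l_cb++) {
        kernel_64_ref(x_b + (l_nb * cb + l_cb) * blk,
                      w_b + (l_kb * cb + l_cb) * blk, y);
      }
    }
  }
}
```

Each thread owns whole Y_b blocks, so no synchronization is needed inside the loop nest. For a non-square unit test with \(N = 128\), \(K = 256\) and \(C = 512\), call forward_omp with \(N_b = 2\), \(K_b = 4\) and \(C_b = 8\).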
Hint
Use OMP_PLACES={0}:72:1 OMP_NUM_THREADS=72 to spawn 72 threads and pin each of them to a single core of an NVIDIA Grace processor. You can display the pinning by additionally setting OMP_DISPLAY_AFFINITY=true.
Tasks
1. Finish the implementation in mini_dnn::backend::MatmulAsmNeon::forward by following the structure of your mini_dnn::backend::MatmulAtenBlocked::forward function. However, call the hand-written kernel gemm_asm_asimd_64_64_64 instead of ATen.
2. Test your implementation thoroughly with unit tests. Be sure to include non-square matrices. For example, use \(N=128\) for your batch size, \(K=256\) for output features and \(C=512\) for input features.
3. Benchmark the performance of your implementation on a single core. Mimic a linear layer of the VGG architecture by using a batch size of 256 and 4,096 input and output features in your performance tests. Report the sustained GFLOPS.
4. Add an OpenMP parallelization to your implementation. Benchmark and report the measured performance when running on all cores of a node.
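To report sustained performance, note that the forward pass of a linear layer performs \(2 \cdot N \cdot K \cdot C\) floating-point operations: one multiplication and one addition per term of each dot product. For \(N = 256\) and \(C = K = 4096\), this amounts to 8,589,934,592 operations, or about 8.59 GFLOP. A minimal timing sketch using std::chrono follows. The helper names, the benchmarked callable and the repetition count are placeholders, and one warm-up call is excluded from the measurement.

```cpp
#include <chrono>
#include <cstdint>
#include <functional>

// Floating-point operations of Y = X W with X: N x C and W: C x K.
double matmul_flops(int64_t n, int64_t k, int64_t c) {
  return 2.0 * double(n) * double(k) * double(c);
}

// Times `reps` calls of `run` after one untimed warm-up call
// and returns the sustained GFLOPS.
double sustained_gflops(const std::function<void()>& run, int reps,
                        int64_t n, int64_t k, int64_t c) {
  run();  // warm-up, e.g., first-touch page faults or cold caches
  const auto start = std::chrono::steady_clock::now();
  for (int r = 0; r < reps; r++) run();
  const auto end = std::chrono::steady_clock::now();
  const double seconds = std::chrono::duration<double>(end - start).count();
  return reps * matmul_flops(n, k, c) / seconds * 1.0e-9;
}
```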
7.3. Small JITted GEMMs
The overall structure of using the JITter LIBXSMM for our custom layers is very similar to that of the previous section. However, LIBXSMM generates the requested kernel at runtime, which means that neither the instruction set architecture nor the kernel shape is hardwired into our compiled program. Since JITting requires some additional boilerplate, mini_dnn provides the unfinished class mini_dnn::backend::MatmulLibxsmm to get you started. Again, your job is to finish the implementation and measure its performance.
Tasks
1. Finish the implementation in mini_dnn::backend::MatmulLibxsmm::forward by calling LIBXSMM-generated kernels.
2. Test your implementation thoroughly with unit tests.
3. Benchmark the performance of your implementation on a single core. Try different sizes for the small matrix kernel. Report the sustained GFLOPS.
4. Add an OpenMP parallelization to your implementation. Benchmark and report the performance when running on all cores of a node.
7.4. Operator Fusion
Feed-forward neural networks are typically composed of a number of layers. This composition is often static and can be formulated as a computational graph. If the graph is available to us in the backend, we can exploit this information to speed up the entire workload, not just an isolated layer. In fact, we implicitly assumed such a situation by using a highly customized data layout for the inputs and outputs of our linear layers. The data layout made it much easier for us to formulate the blocked matrix multiplication. However, such an approach is only viable if we are responsible for the procedure that passes data from one layer to another. Otherwise, the overhead of converting to and from our custom data layout would most likely eliminate the performance gains made elsewhere.
Note
Let’s shift our focus to the frontend for a moment. PyTorch and TensorFlow both support what is known as eager mode. While this was a somewhat unique feature of PyTorch in the past, eager execution has been the default in TensorFlow since version 2.0. Eager mode means that operations are executed immediately. This is great for prototyping and debugging, but can hurt overall performance. In contrast, when running in graph mode, one first assembles a computational graph and then executes its operations together, which allows the framework to optimize across operations. For more information, see the Google AI blog post on eager execution and the PyTorch tutorial on graph mode.
In this final part of the lab, we will assume that our linear layer is immediately followed by a ReLU. A naive approach would compute the result of the linear layer first and then apply the ReLU to the output element-wise. For a sufficiently large batch size or number of output features, the output of the linear layer only fits into a higher cache level or main memory. This means that to apply the ReLU, we have to load all the data back into the registers of our CPU, apply the ReLU, and then write it back. Since there’s very little computation involved in applying a ReLU, the resulting procedure is severely limited by the bandwidth of the particular memory level where the data resides.
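A back-of-the-envelope estimate for our VGG-like layer illustrates the gap. It assumes FP32 data and one load and one store per output entry for the unfused ReLU. For the matrix multiplication, it assumes the best case in which each matrix moves through memory exactly once.

```latex
\begin{aligned}
NK &= 256 \cdot 4096 = 1\,048\,576 \text{ entries of } Y, \text{ i.e., } 4\,\text{MiB in FP32},\\
\text{ReLU:}\quad \frac{NK \text{ ops}}{2 \cdot 4 \cdot NK \text{ bytes}} &= \frac{1}{8}\ \text{op/byte},\\
\text{Matmul:}\quad \frac{2NKC \text{ FLOPs}}{4\,(NC + CK + NK) \text{ bytes}} &= \frac{8\,589\,934\,592}{75\,497\,472} \approx 114\ \text{FLOP/byte}.
\end{aligned}
```

Under these assumptions, the separate ReLU pass performs about 900 times fewer operations per byte than the matrix multiplication. This is why it is bound by memory bandwidth.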
A better approach is to apply the ReLU directly to the results of the linear layer while they are still in the L1 cache. The innermost loop over \(C_b\) of our blocked approach simply adds the intermediate results of the small GEMMs to a fixed block of the result tensor \(Y_b\). Thus, if we operate on sufficiently small blocks, this block is likely to remain in the L1 data cache of a core for the entire duration of the \(C_b\) loop. After the loop completes, this block of \(Y_b\) is fully computed, and we could move on to the next iteration of the \(N_b\) loop, which computes a different block. However, instead of immediately continuing with our small matrix multiplications, we first apply the ReLU to the completed block while its data is still hot. Such an approach is called operator fusion; in our case, it lets the ReLU benefit from the high L1 bandwidth.
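The fusion can be sketched by extending the blocked loop nest. Once the \(C_b\) loop has finished a block of \(Y_b\), the ReLU is applied to that block before the loops move on. The sketch uses plain loops in place of the LIBXSMM GEMM and ReLU TPPs, and the names small_gemm and matmul_relu_fused are hypothetical.

```cpp
#include <algorithm>
#include <cstdint>

// Naive small kernel: C += A * B with column-major A (m x k), B (k x n), C (m x n).
void small_gemm(const float* a, const float* b, float* c,
                int64_t m, int64_t n, int64_t k) {
  for (int64_t j = 0; j < n; j++)
    for (int64_t l = 0; l < k; l++)
      for (int64_t i = 0; i < m; i++)
        c[j * m + i] += a[l * m + i] * b[j * k + l];
}

// Fused Y_b = ReLU(X_b W_b) using the blocked layouts of Sec. 7.1:
// X_b: (N_b, C_b, b_c, b_n), W_b: (K_b, C_b, b_k, b_c), Y_b: (K_b, N_b, b_k, b_n).
void matmul_relu_fused(const float* x_b, const float* w_b, float* y_b,
                       int64_t nb, int64_t kb, int64_t cb,
                       int64_t bn, int64_t bk, int64_t bc) {
  const int64_t x_blk = bc * bn, w_blk = bk * bc, y_blk = bk * bn;
  for (int64_t l_kb = 0; l_kb < kb; l_kb++) {
    for (int64_t l_nb = 0; l_nb < nb; l_nb++) {
      float* y = y_b + (l_kb * nb + l_nb) * y_blk;
      std::fill(y, y + y_blk, 0.0f);
      for (int64_t l_cb = 0; l_cb < cb; l_cb++) {
        small_gemm(x_b + (l_nb * cb + l_cb) * x_blk,
                   w_b + (l_kb * cb + l_cb) * w_blk, y, bn, bk, bc);
      }
      // Fused ReLU: this block was just completed and is likely still in L1.
      for (int64_t i = 0; i < y_blk; i++) y[i] = std::max(y[i], 0.0f);
    }
  }
}
```

Compared with a separate ReLU pass over all of \(Y_b\), the only change is the element-wise loop placed directly after the \(C_b\) loop.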
Tasks
1. Implement the class mini_dnn::backend::ReluLibxsmm, which simply applies the ReLU to the input tensor in the forward function. Generate a ReLU kernel through LIBXSMM and use the same blocking as before. Verify and benchmark your implementation!
2. Benchmark the performance of computing a linear layer followed by a ReLU by sequentially applying your mini_dnn::backend::MatmulLibxsmm::forward and mini_dnn::backend::ReluLibxsmm::forward functions. Use a batch size of 256 and 4,096 input and output features in your performance tests.
3. Implement the class mini_dnn::backend::MatmulReluLibxsmm, which fuses the application of the ReLU into the blocking loops as discussed above. Verify and benchmark your implementation! Use a batch size of 256 and 4,096 input and output features in your performance tests.