8. Convolutions on Steroids

The examples in Section 7 gave us a solid understanding of how PyTorch’s linear layers can be lowered to JITted machine code. However, at the end of the day, linear layers are typically just large matrix multiplications. Accelerating these is “easy”: one can just call an xGEMM implementation of a BLAS library. Examples are OpenBLAS, BLIS, cuBLAS or oneMKL.

Now, what should we do for other types of layers? One answer to this is: Write optimized implementations for each and every one of them! Libraries dedicated to machine learning provide us with many of these implementations. Examples are cuDNN or oneDNN. Unfortunately, as discussed in Section 6, we are out of luck if a particular layer is not supported.

This part optimizes the application of a two-dimensional convolution, as done in torch.nn.functional.conv2d. We’ll see how, once again, we can get an efficient custom implementation using a single versatile building block: small matrix-matrix multiplications. Of course, highly optimized implementations for two-dimensional convolutions are available in all machine learning backends. However, it is instructive to see how one would write such code and possibly compete with a vendor-optimized implementation.

We’ll stay in C++ for this exercise and use a conv2d version of the mini_dnn playground. This version includes supporting boilerplate for 2D convolutions and an ATen reference implementation that simply calls at::conv2d.

A two-dimensional convolution \(\mathrm{Y} = \mathrm{X} * \mathrm{W}\) operates on the rank-4 tensors \(\mathrm{X}\), \(\mathrm{W}\) and \(\mathrm{Y}\). The input is given by \(\mathrm{X} \in \mathbb{R}^{N \times C \times H \times W}\), the weights by \(\mathrm{W} \in \mathbb{R}^{K \times C \times R \times S}\), and the output by \(\mathrm{Y} \in \mathbb{R}^{N \times K \times P \times Q}\). Fig. 8.1 shows an illustration of the dimensions for a tiny example.


Fig. 8.1 Illustration of the dimensions for a two-dimensional convolution operation. Shown is an example with \(N=3\), \(C=6\), \(K=4\), \(H=8\), \(W=12\), \(R=S=3\), \(P=6\) and \(Q=10\).

In machine learning, two-dimensional convolutions are often applied to visual data. Thus, the dimension \(H\) often refers to the height of the input images in pixels and \(W\) refers to the width. Similarly, \(P\) and \(Q\) typically describe the height and width of the output images. The dimension \(N\) is the batch size, \(C\) the number of input channels and \(K\) the number of output channels. Finally, \(R\) and \(S\) are the sizes of the convolving kernel.

Hint

For more details on the dimensions and parameters involved in a two-dimensional convolution see the torch.nn.Conv2d documentation.

Depending on the details of the applied convolution operator, the sizes of the input images may differ from those of the output images, i.e., \(H \ne P\) or \(W \ne Q\), or both. In our implementations we assume the default parameters of torch.nn.functional.conv2d, i.e., bias=None, stride=1, padding=0, dilation=1 and groups=1. In this case \(P = H - R + 1\) and \(Q = W - S + 1\); for example, the setting of Fig. 8.1 gives \(P = 8 - 3 + 1 = 6\) and \(Q = 12 - 3 + 1 = 10\).

8.1. Im2col

Since the “early” days of efficient machine learning, the image-to-column (im2col) transformation has been used to rewrite convolutions as large matrix-matrix multiplications. This approach has the advantage that we can simply call a high-performance xGEMM implementation for the convolution. Once the result is computed, we can simply change its view and get the original data layout. At the same time, the intermediate data structure and the reliance on an xGEMM call are also the major drawbacks of this approach. First, we artificially increase our memory requirements due to the temporary storage required for the rearranged input data. Second, the convolution is isolated behind the xGEMM library call: its results end up in a slow memory layer before any subsequent operator can touch them, which prevents operator fusion.

Nevertheless, due to the simplicity of the approach, we’ll set a baseline by using the ATen functions of Table 8.1.1 in our first conv2d implementation.

Table 8.1.1 Set of functions that allow us to implement 2D convolutions without calling conv2d directly.

Operator   ATen         Python
im2col     at::im2col   torch.nn.functional.unfold
xGEMM      at::matmul   torch.matmul
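
To make the data flow concrete, here is a minimal sketch of the approach in ATen. The function and variable names are placeholders rather than the mini_dnn interface, and the sketch assumes the default conv2d parameters discussed above.

#include <ATen/ATen.h>

// hedged sketch: conv2d via im2col + matmul
// assumes bias=None, stride=1, padding=0, dilation=1 and groups=1
at::Tensor conv2d_im2col( at::Tensor i_x,    // input:   N x C x H x W
                          at::Tensor i_w ) { // weights: K x C x R x S
  int64_t l_n = i_x.size(0);
  int64_t l_k = i_w.size(0);
  int64_t l_r = i_w.size(2);
  int64_t l_s = i_w.size(3);
  int64_t l_p = i_x.size(2) - l_r + 1;
  int64_t l_q = i_x.size(3) - l_s + 1;

  // rearrange the input patches: N x (C*R*S) x (P*Q)
  at::Tensor l_cols = at::im2col( i_x,
                                  {l_r, l_s},  // kernel size
                                  {1, 1},      // dilation
                                  {0, 0},      // padding
                                  {1, 1} );    // stride

  // flatten the weights to K x (C*R*S) and multiply: N x K x (P*Q)
  at::Tensor l_y = at::matmul( i_w.view( {l_k, -1} ), l_cols );

  // change the view to obtain the conv2d output layout: N x K x P x Q
  return l_y.view( {l_n, l_k, l_p, l_q} );
}

In mini_dnn::backend::Conv2dIm2col the same two ATen calls would live in the forward function, and the result can be verified against the at::conv2d reference implementation.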

Tasks

  1. Implement the class mini_dnn::backend::Conv2dIm2col, which uses ATen to perform the operations im2col and SGEMM in its forward function.

  2. Verify your implementation and benchmark its performance.

8.2. Blocking

For our targeted efficient implementation, let’s follow the same approach we used to formulate the linear layer in Section 7. First, we set up blocked data structures for the inputs and filters of the convolution. These blocked data structures allow us to assemble the conv2d operation using small GEMMs. Then we take our blocking for a spin by calling ATen on the resulting small matrix multiplications in the blocked implementation. Once we’ve ironed out all the kinks and tested our implementation, we’ll move on and infuse high-performance matrix kernels via LIBXSMM. Finally, we’ll fuse a ReLU activation into our implementation to accelerate the combined Conv2d + ReLU operation.

In our implementation, we block the input channels \(C\) and the output channels \(K\). Similar to the linear layers in Section 7, \(b_c\) and \(b_k\) refer to the block sizes, and \(C_b\) and \(K_b\) refer to the number of blocks, i.e. \(C = C_b \cdot b_c\) and \(K = K_b \cdot b_k\). Again, \(X_b\), \(W_b\) and \(Y_b\) describe our blocked and reordered data. We use a rank-5 tensor to store the input, i.e. \(X_b \in \mathbb{R}^{N \times C_b \times H \times W \times b_c}\). The blocked weights are given by a rank-6 tensor: \(W_b \in \mathbb{R}^{K_b \times C_b \times R \times S \times b_c \times b_k}\). Accordingly, the blocked outputs are stored as a rank-5 tensor: \(Y_b \in \mathbb{R}^{N \times K_b \times P \times Q \times b_k}\).
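
As a rough sketch, the blocked layouts can be derived from the original tensors with ATen views and permutes, e.g., in the unit tests. The variable names below are placeholders, and \(b_c\) and \(b_k\) are assumed to divide \(C\) and \(K\), respectively.

// hedged sketch: reorder X (N x C x H x W) and W (K x C x R x S) into the
// blocked layouts; the block sizes are assumed to divide C and K evenly
at::Tensor l_x_blocked = i_x.view( {l_n, l_cb, l_bc, l_h, l_w} )        // split C
                            .permute( {0, 1, 3, 4, 2} )                 // N x Cb x H x W x bc
                            .contiguous();

at::Tensor l_w_blocked = i_w.view( {l_kb, l_bk, l_cb, l_bc, l_r, l_s} ) // split K and C
                            .permute( {0, 2, 4, 5, 3, 1} )              // Kb x Cb x R x S x bc x bk
                            .contiguous();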

The resulting blocked data layout allows us to implement the convolution operator through six nested loops, with the innermost loop performing small matrix multiplications:

for( int64_t l_n = 0; l_n < l_sizes.n; l_n++ ) {
  for( int64_t l_kb = 0; l_kb < l_sizes.kb; l_kb++ ) {
    for( int64_t l_p = 0; l_p < l_sizes.p; l_p++ ) {
      for( int64_t l_cb = 0; l_cb < l_sizes.cb; l_cb++ ) {
        for( int64_t l_r = 0; l_r < l_sizes.r; l_r++ ) {
          for( int64_t l_s = 0; l_s < l_sizes.s; l_s++ ) {
            // TODO: execute small matrix kernel
          }
        }
      }
    }
  }
}
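
For orientation, a hedged ATen version of the innermost loop body is sketched below; all variable names are placeholders. The row-major slices seen by ATen are exactly the column-major \(b_c \times Q\), \(b_k \times b_c\) and \(b_k \times Q\) matrices of the small GEMM discussed in the tasks.

// hedged sketch of the innermost loop body,
// assuming l_y_blocked has been zero-initialized;
// input row h = p + r, input columns q + s for q = 0, ..., Q-1
at::Tensor l_x_slice = l_x_blocked[l_n][l_cb][l_p + l_r]
                                  .narrow( 0, l_s, l_sizes.q );           //  Q x bc
at::Tensor l_w_slice = l_w_blocked[l_kb][l_cb][l_r][l_s];                 // bc x bk
// accumulate the contribution into the Q x bk output block
l_y_blocked[l_n][l_kb][l_p].add_( at::matmul( l_x_slice, l_w_slice ) );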

Tasks

  1. Block the input batch \(X\) and the weight matrix \(W\) in the unit tests of a new backend mini_dnn::backend::Conv2dAtenBlocked. Use the file name Conv2dAtenBlocked.test.cpp for your unit tests.

  2. Pass the blocked tensors to mini_dnn::backend::Conv2dAtenBlocked::forward and implement the forward pass by calling ATen for the small matrix multiplications. This means that you need to call at::matmul to sum the individual contributions of the \((b_k \times Q) = (b_k \times b_c) (b_c \times Q)\) matrix multiplications, as sketched after the loop nest above. Note that we write our efficient implementations assuming column-major storage of all involved small matrices. Make sure that you precisely understand the individual contribution of each matrix multiplication to the overall Conv2d operation before you start with the implementation.

  3. In the unit tests, revert the blocked layout back to the one before you blocked the matrices. Make sure your computed result matches the unblocked version of ATen.

8.3. Small JITted GEMMs and Operator Fusion

We are now at a similar point for 2D convolutions as we were for linear layers in Section 7.2. The rest of the procedure is obvious 😎. First, we replace the ATen calls in Conv2dAtenBlocked with calls to JITted LIBXSMM kernels in a new backend called mini_dnn::backend::Conv2dLibxsmm. Second, we fuse the ReLU into our implementation by implementing the backend mini_dnn::backend::Conv2dReluLibxsmm. Again, performance is our top priority, and we deliver on the promise of efficient convolutions by thoroughly benchmarking our implementations.
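
For reference, a minimal sketch of obtaining and calling such a kernel with LIBXSMM's classic dispatch interface is given below; newer LIBXSMM releases expose a revised dispatch API, so adapt the calls to the version you use. The sizes describe the column-major small GEMM \((b_k \times Q) = (b_k \times b_c) (b_c \times Q)\) with \(\beta = 1\), and the variable names are placeholders.

#include <libxsmm.h>

// JIT a single FP32 kernel once, outside of the loop nest
libxsmm_blasint l_m   = l_bk;  // rows of A (weight block) and C (output block)
libxsmm_blasint l_n   = l_q;   // columns of B (input slice) and C
libxsmm_blasint l_k   = l_bc;  // columns of A, rows of B
libxsmm_blasint l_lda = l_bk;  // W_b blocks are contiguous
libxsmm_blasint l_ldb = l_bc;  // distance between neighboring pixels in X_b
libxsmm_blasint l_ldc = l_bk;  // distance between neighboring pixels in Y_b
float l_alpha = 1.0f;
float l_beta  = 1.0f;          // accumulate over C_b, R and S

libxsmm_smmfunction l_kernel = libxsmm_smmdispatch( l_m, l_n, l_k,
                                                    &l_lda, &l_ldb, &l_ldc,
                                                    &l_alpha, &l_beta,
                                                    nullptr, nullptr );

// inside the loop nest: raw pointers into the blocked tensors
l_kernel( l_ptr_w, l_ptr_x, l_ptr_y );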

Tasks

  1. Implement the mini_dnn::backend::Conv2dLibxsmm::forward function by following the structure of mini_dnn::backend::Conv2dAtenBlocked::forward. Call JITted matrix kernels instead of at::matmul.

  2. Test your implementation thoroughly with appropriate unit tests.

  3. Benchmark the performance of your implementation on a single core. Use the following sizes in your tests: \(N=144\), \(H=34\), \(W=34\), \(C=512\), \(K=512\), \(R=3\), \(b_c=128\) and \(b_k=64\). Report the sustained FP32 GFLOPS; the convolution performs \(2 \cdot N \cdot K \cdot C \cdot P \cdot Q \cdot R \cdot S\) floating-point operations.

  4. Add an OpenMP parallelization over \(N\) to your implementation. Benchmark and report the performance when running on all compute cores of a node.

  5. Implement the mini_dnn::backend::Conv2dReluLibxsmm class which fuses the application of the ReLU into your loops. Verify and benchmark your implementation. Use the same sizes as for the convolution-only tests. A possible fusion strategy is sketched after this task list.
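
One option for the fusion is sketched below: apply \(\max(0, \cdot)\) to a finished \(Q \times b_k\) output block right after the loops over \(C_b\), \(R\) and \(S\) have completed for a given \((n, k_b, p)\), i.e., while the block is still in cache. The field and pointer names are placeholders.

// hedged sketch: in-cache ReLU over one finished Q x bk output block
for( int64_t l_q = 0; l_q < l_sizes.q; l_q++ ) {
  for( int64_t l_bk = 0; l_bk < l_sizes.bk; l_bk++ ) {
    float &l_val = l_ptr_y[ l_q * l_sizes.bk + l_bk ];
    l_val = ( l_val > 0.0f ) ? l_val : 0.0f;
  }
}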