8. Convolutions on Steroids

The examples in Section 7 gave us a solid understanding of how PyTorch’s linear layers may be lowered to JITted machine code. However, at the end of the day, linear layers are essentially just large matrix multiplications. Accelerating these is “simple”: one can simply call an xGEMM implementation of a BLAS library, e.g., OpenBLAS, BLIS, cuBLAS or oneMKL.

Now, what should we do for other types of layers? One answer to this is: write optimized implementations for every single one of them! Libraries specifically targeting machine learning provide many of these implementations for us; examples are cuDNN and oneDNN. Unfortunately, as already discussed in Section 6, we are out of luck if a specific layer is not supported.

This part optimizes the forward pass of torch.nn.functional.conv2d. We’ll see how we can once again obtain an efficient custom implementation from a single versatile building block: small matrix-matrix multiplications. Of course, highly optimized implementations of two-dimensional convolutions are available in all machine learning backends. Yet, it is enlightening to see how one would write such code and possibly compete with a vendor-optimized implementation.

We’ll remain within C++ for this exercise and use a conv2d version of the mini_dnn playground. This version contains supporting boilerplate for 2D convolutions and an ATen reference implementation which simply calls at::conv2d.

A two-dimensional convolution \(\mathrm{Y} = \mathrm{X} * \mathrm{W}\) operates on the rank-4 tensors \(\mathrm{X}\), \(\mathrm{W}\) and \(\mathrm{Y}\). The input is given by \(\mathrm{X} \in \mathbb{R}^{N \times C \times H \times W}\), the weights by \(\mathrm{W} \in \mathbb{R}^{K \times C \times R \times S}\), and the output by \(\mathrm{Y} \in \mathbb{R}^{N \times K \times P \times Q}\). Fig. 8.1 illustrates the respective dimensions for a tiny example.

[Figure: conv2d_n3c5k4.svg]

Fig. 8.1 Illustration of the dimensions of a two-dimensional convolution operation. Shown is an example with \(N=3\), \(C=5\), \(K=4\), \(H=8\), \(W=12\), \(R=S=3\), \(P=6\) and \(Q=10\).

In machine learning, two-dimensional convolutions are commonly applied to visual data. Thus, dimension \(H\) often refers to the height of the input images in pixels and \(W\) to the respective width. Similarly, \(P\) and \(Q\) typically describe the height and width of the output images. Dimension \(N\) is the batch size, \(C\) the number of input channels and \(K\) the number of output channels. Lastly, \(R\) and \(S\) are the height and width of the convolving kernel.

Hint

Additional details on the dimensions and parameters involved in a two-dimensional convolution are available from the documentation of torch.nn.Conv2d.

Depending on the details of the applied convolution operator, the sizes of the input images might differ from those of the output images, i.e., \(H \ne P\) or \(W \ne Q\), or both. In our implementations we assume the default parameters of torch.nn.functional.conv2d, i.e., bias=None, stride=1, padding=0, dilation=1 and groups=1. In this case \(P = H - R + 1\) and \(Q = W - S + 1\).
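As a quick sanity check of these shape relations, the following hedged sketch calls at::conv2d with its default parameters on the tensor sizes of Fig. 8.1; the variable names and the standalone main function are only illustrative:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // input X: (N, C, H, W) = (3, 5, 8, 12) and weights W: (K, C, R, S) = (4, 5, 3, 3)
  at::Tensor l_x = at::rand( {3, 5, 8, 12} );
  at::Tensor l_w = at::rand( {4, 5, 3, 3} );

  // default parameters: bias=None, stride=1, padding=0, dilation=1, groups=1
  at::Tensor l_y = at::conv2d( l_x, l_w );

  // prints [3, 4, 6, 10], i.e., P = H - R + 1 = 6 and Q = W - S + 1 = 10
  std::cout << l_y.sizes() << std::endl;

  return 0;
}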

8.1. Im2col

Since the “early” efficient machine learning days, the image-to-column (im2col) transformation has been used to rewrite convolutions as large matrix-matrix multiplications. This approach has the advantage that it allows us to simply call a high-performing xGEMM implementation for the convolution. Once the result is computed, we can simply change its view and obtain the original data layout. At the same time, the intermediate data structure and the reliance on xGEMM also represent the approach’s biggest disadvantages. First, we artificially increase our memory requirements through the temporary storage required for the rearranged input data. Second, the convolution operation is isolated behind the xGEMM library call, which means that the op’s results end up in a slow memory layer; this prohibits operator fusion.

Nevertheless, due to the simplicity of the approach, we’ll set a baseline by using Table 8.1.1’s ATen functions in our first conv2d implementation.

Table 8.1.1 Set of functions which allow us to implement 2D convolutions without calling conv2d directly.

Operator | ATen       | Python
im2col   | at::im2col | torch.nn.functional.unfold
xGEMM    | at::matmul | torch.matmul
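To see how the two building blocks fit together, here is a minimal sketch of an im2col-based forward pass under the default parameters discussed above. The function and variable names are illustrative and not part of mini_dnn; the actual Conv2dIm2col backend has to follow the playground’s interfaces:

#include <ATen/ATen.h>

at::Tensor forward_im2col( at::Tensor i_x,    // input:   (N, C, H, W)
                           at::Tensor i_w ) { // weights: (K, C, R, S)
  int64_t l_k = i_w.size(0);
  int64_t l_r = i_w.size(2);
  int64_t l_s = i_w.size(3);
  int64_t l_p = i_x.size(2) - l_r + 1;
  int64_t l_q = i_x.size(3) - l_s + 1;

  // rearrange the input patches: (N, C*R*S, P*Q)
  at::Tensor l_cols = at::im2col( i_x,
                                  {l_r, l_s},  // kernel size
                                  {1, 1},      // dilation
                                  {0, 0},      // padding
                                  {1, 1} );    // stride

  // flatten the weights to (K, C*R*S) and batch-multiply: (N, K, P*Q)
  at::Tensor l_y = at::matmul( i_w.reshape( {l_k, -1} ),
                               l_cols );

  // change the view to obtain the conv2d output layout (N, K, P, Q)
  return l_y.reshape( {i_x.size(0), l_k, l_p, l_q} );
}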

Tasks

  1. Implement the class mini_dnn::backend::Conv2dIm2col which uses ATen to perform the im2col and SGEMM operations in its forward function.

  2. Verify your implementation and benchmark its performance.

8.2. Blocking

For our targeted efficient implementation, let’s follow the same approach that we used when formulating the linear layer in Section 7. We get started by setting up blocked data structures for the inputs and filters of the convolution. These blocked data structures allow us to assemble the conv2d operation from small GEMMs. Then we take our blocking for a spin by calling ATen for the resulting small matrix multiplications in the blocked implementation. Once we have ironed out all issues and tested our implementation, we’ll move on and infuse high-performance matrix kernels through LIBXSMM. Finally, we’ll fuse a ReLU activation into our implementation to accelerate the combined execution of a Conv2d layer followed by a ReLU.

In our implementation we block the input channels \(C\) and the output channels \(K\). Similarly to the linear layers in Section 7, \(b_c\) and \(b_k\) refer to the block sizes, and \(C_b\) and \(K_b\) to the number of blocks, i.e., \(C = C_b \cdot b_c\) and \(K = K_b \cdot b_k\). Once again, \(X_b\), \(W_b\) and \(Y_b\) describe our blocked and reordered data. We use a rank-5 tensor to store the input, i.e., \(X_b \in \mathbb{R}^{N \times C_b \times H \times W \times b_c}\). The blocked weights are given by a rank-6 tensor: \(W_b \in \mathbb{R}^{K_b \times C_b \times R \times S \times b_c \times b_k}\). Accordingly, the blocked outputs are stored as a rank-5 tensor: \(Y_b \in \mathbb{R}^{N \times K_b \times P \times Q \times b_k}\).
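One possible way to derive the blocked tensors from the original layouts is a reshape followed by a permute, e.g., when preparing the inputs in the unit tests. The following sketch assumes illustrative variable names: l_x holds \(X\), l_weight holds \(W\), the integer sizes follow the notation above, and the trailing contiguous() calls are an assumption to obtain densely stored tensors:

// X: (N, C, H, W) -> X_b: (N, C_b, H, W, b_c)
at::Tensor l_x_b = l_x.reshape( {l_n, l_cb, l_bc, l_h, l_w} )
                      .permute( {0, 1, 3, 4, 2} )
                      .contiguous();

// W: (K, C, R, S) -> W_b: (K_b, C_b, R, S, b_c, b_k)
at::Tensor l_w_b = l_weight.reshape( {l_kb, l_bk, l_cb, l_bc, l_r, l_s} )
                           .permute( {0, 2, 4, 5, 3, 1} )
                           .contiguous();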

The obtained blocked data layout allows us to implement the convolution operator through six nested loops where the innermost loop body performs small matrix multiplications:

for( int64_t l_n = 0; l_n < l_sizes.n; l_n++ ) {
  for( int64_t l_kb = 0; l_kb < l_sizes.kb; l_kb++ ) {
    for( int64_t l_p = 0; l_p < l_sizes.p; l_p++ ) {
      for( int64_t l_cb = 0; l_cb < l_sizes.cb; l_cb++ ) {
        for( int64_t l_r = 0; l_r < l_sizes.r; l_r++ ) {
          for( int64_t l_s = 0; l_s < l_sizes.s; l_s++ ) {
            // TODO: execute small matrix kernel for the contribution of
            //       input-channel block l_cb and filter entry (l_r, l_s)
          }
        }
      }
    }
  }
}

Tasks

  1. Block the input batch \(X\) and the weight tensor \(W\) in the unit tests of a new backend mini_dnn::backend::Conv2dAtenBlocked. Use the file name Conv2dAtenBlocked.test.cpp for your unit tests.

  2. Pass the blocked tensors to mini_dnn::backend::Conv2dAtenBlocked::forward and implement the forward pass by calling ATen for the small matrix multiplications. This means that you have to call at::matmul to sum the individual contributions of the \((b_k \times Q) = (b_k \times b_c) (b_c \times Q)\) matrix multiplications. Keep in mind that we are writing our efficient implementations assuming column-major storage of all involved small matrices. Make sure that you precisely understand the individual contribution of one of the matrix multiplications to the overall Conv2d operation before tackling the implementation; a hedged sketch of one such contribution follows this task list.

  3. In the unit tests, bring the blocked result back to the layout used before you blocked the tensors. Make sure that your computed result matches ATen’s unblocked version.
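For orientation, a hedged sketch of a single contribution in the innermost loop body of the nest shown above is given here. The tensor names i_x_b, i_w_b and io_y_b are illustrative; since ATen stores its tensors row-major, the column-major \((b_k \times Q) = (b_k \times b_c) (b_c \times Q)\) product appears in transposed form, \((Q \times b_k) = (Q \times b_c) (b_c \times b_k)\):

// one iteration of the innermost loop over (l_n, l_kb, l_p, l_cb, l_r, l_s)
// input row of X_b belonging to output row l_p and filter row l_r,
// shifted by l_s and truncated to Q entries: shape (Q, b_c)
at::Tensor l_in  = i_x_b[l_n][l_cb][l_p + l_r].narrow( 0, l_s, l_sizes.q );

// filter block for (l_kb, l_cb, l_r, l_s): shape (b_c, b_k)
at::Tensor l_fil = i_w_b[l_kb][l_cb][l_r][l_s];

// accumulate the contribution into the corresponding output row of Y_b: shape (Q, b_k)
io_y_b[l_n][l_kb][l_p] += at::matmul( l_in, l_fil );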

8.3. Small JITted GEMMs and Operator Fusion

We are now at a similar point for 2D convolutions as we have been for the linear layers in Section 7.2. The remaining procedure is obvious 😎. First, we replace Conv2dAtenBlocked’s ATen calls with JITted LIBXSMM kernels in a new backend called mini_dnn::backend::Conv2dLibxsmm. Second, we fuse the ReLU into our implementation by implementing the backend mini_dnn::backend::Conv2dReluLibxsmm. Once again the obtained performance is our top priority, and we deliver on the promise of efficient convolutions by thoroughly benchmarking our implementations.
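As an illustration of the kernel generation, the following sketch dispatches one JITted FP32 kernel for the small column-major \((b_k \times Q) = (b_k \times b_c) (b_c \times Q)\) product through LIBXSMM’s classic dispatch interface and calls it on raw data pointers. This is only a sketch under the assumption that this interface is available; depending on the LIBXSMM version used by the playground, a different dispatch API and explicit leading dimensions might be required, and the pointer names are hypothetical:

#include <libxsmm.h>

// column-major GEMM: C(b_k x Q) += A(b_k x b_c) * B(b_c x Q)
libxsmm_blasint l_gemm_m = l_bk;       // rows of A and C
libxsmm_blasint l_gemm_n = l_sizes.q;  // columns of B and C
libxsmm_blasint l_gemm_k = l_bc;       // columns of A, rows of B

float l_alpha = 1.0f;
float l_beta  = 1.0f; // accumulate into C

// JIT-compile (or fetch from the code cache) the small GEMM kernel;
// NULL leading dimensions default to tight lda = m, ldb = k, ldc = m
libxsmm_smmfunction l_kernel = libxsmm_smmdispatch( l_gemm_m, l_gemm_n, l_gemm_k,
                                                    NULL, NULL, NULL,
                                                    &l_alpha, &l_beta,
                                                    NULL, NULL );

// inside the loop nest: pointers into the raw data of the blocked tensors
l_kernel( l_ptr_w,   // A: (b_k x b_c) block of W_b
          l_ptr_x,   // B: (b_c x Q)  slice of X_b
          l_ptr_y ); // C: (b_k x Q)  slice of Y_b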

Tasks

  1. Implement the function mini_dnn::backend::Conv2dLibxsmm::forward by following the structure of mini_dnn::backend::Conv2dAtenBlocked::forward. Call JITted matrix kernels instead of at::matmul.

  2. Test your implementation thoroughly through respective unit tests.

  3. Benchmark the performance of your implementation on a single core. Use the following sizes in your tests: \(N=48\), \(H=34\), \(W=34\), \(C=512\), \(K=512\), \(R=3\), \(b_c=128\) and \(b_k=64\). Report the sustained FP32 GFLOPS; a short note on the FLOP count follows this task list.

  4. Add an OpenMP parallelization over \(N\) to your implementation. Benchmark and report the performance when running on all compute cores of a node.

  5. Implement the class mini_dnn::backend::Conv2dReluLibxsmm which fuses the application of the ReLU into your loops. Verify and benchmark your implementation. Use the same sizes as for the convolution-only tests.
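For the GFLOPS numbers, one may use the standard FLOP count of a direct convolution, i.e., \(2 \cdot N \cdot K \cdot C \cdot P \cdot Q \cdot R \cdot S\) floating point operations per forward pass; for the sizes above (and assuming \(S=3\), thus \(P=Q=32\)) this amounts to roughly 232 GFLOPs. A hedged sketch of the corresponding benchmark arithmetic is shown below; the timing loop, the forward call and all variable names are only illustrative:

#include <chrono>
#include <cstdint>
#include <iostream>

// FLOPs of one forward pass: one multiply and one add per contraction entry
double l_gflops_per_pass = 2.0 * l_n * l_k * l_c * l_p * l_q * l_r * l_s * 1.0E-9;

auto l_tp0 = std::chrono::steady_clock::now();
for( int64_t l_re = 0; l_re < l_n_repetitions; l_re++ ) {
  l_conv.forward( l_x_b, l_w_b ); // hypothetical call into the benchmarked backend
}
auto l_tp1 = std::chrono::steady_clock::now();

double l_duration = std::chrono::duration<double>( l_tp1 - l_tp0 ).count();
std::cout << "sustained FP32 GFLOPS: "
          << l_gflops_per_pass * l_n_repetitions / l_duration
          << std::endl;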