9. GPUs

Most ML training and inference workloads are driven by dedicated accelerators. Accelerators have specialized compute engines for a limited set of use cases. The acceleration comes at the cost of generality: the devices perform very well for the targeted applications but poorly, or not at all, in many other cases.

In this lab we will use a Graphics Processing Unit (GPU) as our first ML accelerator. GPUs originate from the computer graphics and image processing domains. In recent years, GPUs have also proven to be extremely well suited for ML workloads; they have even been a key enabler of modern deep learning as a whole. Due to the importance of ML, GPUs were further extended with additional ML-specific features unrelated to their graphics origin. Modern ML-optimized server GPUs even lack the option to connect a display.

Fig. 9.1 Photo of the NVIDIA GeForce RTX 4070 Ti SUPER GPU used in this lab.

We will get our feet wet by programming a GeForce RTX 4070 Ti SUPER GPU through PyTorch. The GPU is powered by the Ada Lovelace microarchitecture and features 4th generation tensor cores supporting FP16, BF16, TF32, INT8 and INT4 tensor ops.

The lab is split into four parts: First, in Section 9.1, we will have a look at the basics of offloading tensor math to a CUDA-capable device. Section 9.2 studies the performance of the GPU by executing large matrix-matrix multiplications in different floating-point formats. In Section 9.3, we will deploy a vision model on the GPU and study its accuracy and end-to-end performance under different execution modes. Finally, Section 9.4 develops a custom implementation of the vision model in Triton.

9.1. Offloading Tensors

The CUDA ecosystem is tailored to NVIDIA GPUs and features a rich software stack for the acceleration of ML workloads. PyTorch exposes extensive CUDA semantics through the package torch.cuda. Together with PyTorch's asynchronous execution of CUDA operations, the package makes GPU acceleration of ML workloads rather simple.
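
For example, before offloading any work we can query the CUDA runtime. The following lines are a minimal check; the printed values depend on the machine:

    import torch

    # query the CUDA runtime before offloading any work
    print(torch.cuda.is_available())       # True if a CUDA-capable GPU is visible
    print(torch.cuda.device_count())       # number of visible GPUs
    print(torch.cuda.get_device_name(0))   # e.g. the RTX 4070 Ti SUPER used in this lab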

Tasks

  1. Create a small FP32 tensor with values in \((-1,1)\) on the CPU and offload it to the GPU. Apply the ReLU function on the GPU and show that you obtain the correct result by printing the tensor.

  2. Create two BF16 tensors on the CPU, transfer them to the GPU, perform a binary function and print the result.

  3. Illustrate the influence of reduced precision arithmetic on the GPU. Write examples for FP16, BF16 and TF32 where the result differs from FP32. Explain your observations theoretically by considering the numbers’ bit-wise representations.
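
The sketch below illustrates the workflow behind tasks 1 and 3; the tensor size and the sample value \(2^{-12}\) are arbitrary choices:

    import torch

    # Task 1: create an FP32 tensor on the CPU, offload it, apply ReLU on the GPU.
    x = torch.empty(8).uniform_(-1.0, 1.0)   # values in (-1, 1), resides on the CPU
    x_gpu = x.to('cuda')                     # host-to-device transfer
    print(torch.relu(x_gpu))                 # negative entries are clamped to zero

    # Task 3 (one possible example): BF16 keeps only 7 mantissa bits, so 1 + 2**-12
    # rounds back to 1, while FP32 (23 mantissa bits) represents the sum exactly.
    one = torch.tensor(1.0, device='cuda')
    eps = torch.tensor(2.0**-12, device='cuda')
    print((one + eps).item())                          # 1.000244140625 (FP32)
    print((one.bfloat16() + eps.bfloat16()).item())    # 1.0 (BF16)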

9.2. GEMMs

This part of the lab benchmarks the execution of General Matrix-Matrix Multiplications (GEMMs) on the GPU. We can expect to reach a performance close to peak if the GEMM is large enough and the involved matrices are readily available in GPU memory. We use the following benchmark settings:

  • All dimensions have the same size.

  • We only consider the following sizes: 512, 1024, 2048, 4096, 8192, 10240. For example, we would first benchmark the multiplication of a 512x512 matrix with a 512x512 matrix. Second comes the multiplication of a 1024x1024 matrix with a 1024x1024 matrix, etc.

  • All matrices in a single GEMM operation have the same data type.

  • The studied data types are FP64, FP32, TF32, BF16 and FP16.

  • We evaluate the GEMM performance on the GPU and the host CPU.

  • All performance runs are executed for at least 10 seconds, i.e., we repeatedly execute the GEMM when necessary.

  • We only measure the GEMM performance but exclude host-device and device-host data transfers from our measurements.
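
When reporting GFLOPS, note that multiplying two \(n \times n\) matrices is conventionally counted as \(2n^3\) floating-point operations (\(n^3\) multiplications and roughly \(n^3\) additions), so the observed throughput follows as \(2n^3 \cdot \text{reps} / \text{time} / 10^9\).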

Tasks

  1. Implement a benchmarking script that covers the settings above. The script should report the following metrics:

    • Current parameters (incl. the number of repetitions),

    • Required time for the entire measurement,

    • Observed floating-point throughput (GFLOPS).

  2. Visualize your results.

Hint

Be aware of asynchronous execution on the GPU and add synchronization calls through torch.cuda.synchronize() where necessary.
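
A minimal sketch of such a measurement loop is shown below; the function name bench_gemm, the warm-up step, and the reporting format are our own choices, and TF32 is handled through the global matmul switch rather than as a separate dtype:

    import time
    import torch

    def bench_gemm(n, dtype, device, min_seconds=10.0):
        # operands are allocated directly on the target device (transfers excluded)
        a = torch.randn(n, n, dtype=dtype, device=device)
        b = torch.randn(n, n, dtype=dtype, device=device)

        c = a @ b                      # warm-up run
        if device == 'cuda':
            torch.cuda.synchronize()   # wait for all queued GPU work

        reps = 0
        start = time.perf_counter()
        while True:
            c = a @ b
            reps += 1
            if device == 'cuda':
                torch.cuda.synchronize()
            elapsed = time.perf_counter() - start
            if elapsed >= min_seconds:
                break

        gflops = 2.0 * n**3 * reps / elapsed / 1e9   # 2n^3 ops per GEMM
        return reps, elapsed, gflops

    # TF32 is not a separate tensor dtype; it is enabled for FP32 GEMMs globally:
    # torch.backends.cuda.matmul.allow_tf32 = True
    for n in (512, 1024, 2048, 4096, 8192, 10240):
        reps, secs, gflops = bench_gemm(n, torch.float32, 'cuda')
        print(f'n={n:6d}  reps={reps:7d}  time={secs:6.2f}s  {gflops:10.1f} GFLOPS')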

9.3. MLP-Mixer

By now we have a good feel for offloading computations to the GPU. It is time to put these skills into practice. For this we will use trained MLP-Mixer models and execute them on the ILSVRC 2012 ImageNet-1K dataset. We have two goals:

  1. Benchmark the end-to-end performance of the Mixer B/16 and L/16 models on the CPU and GPU.

  2. Study the throughput and the accuracy of the models when using different floating-point arithmetic.

The archive mlp_mixer.tar.xz contains code to get you started. The entry point mixer.py sets up an appropriate data loader for the validation data of ImageNet-1K and contains performance benchmarking routines. Further, two ready-to-go implementations of mixer models are provided in models/torchfunc.py and models/torchnn.py. The first implementation (torchfunc) is a functional torch model while the second one (torchnn) extends torch.nn.Module. Both implementations provide the function load_parameters, which loads the parameters of trained mixer models.

All data (ImageNet-1K and mixer model parameters) are already downloaded on Murray and available in the directory /mnt/hd1/data.

Tasks

  1. Establish a performance and accuracy (Top-1 and Top-5) baseline of Mixer-B/16 on the CPU. Use a batch size of 64 and FP32 arithmetic. Be aware that the classification of all 50,000 validation images takes about one hour if you are the only user of the CPU.

  2. Enable the functional model (torchfunc) for execution on the GPU. Keep the entry point mixer.py dynamic, i.e., make the CPU/GPU decision at runtime through the parameter config['cuda'] (see the sketch after this task list).

  3. Adjust the GPU inference setting and document your observations (incl. impacts on the model accuracy). Try at least the following:

    • Different batch sizes,

    • TF32, BF16 and FP16.

  4. Enable the second implementation (torchnn) for GPU inference. How does the implementation perform? What about the L/16 model?

  5. Add GFLOPS as a metric to your analysis. How far are you off the results obtained in Section 9.2?

  6. Try model compilation through torch.compile.
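
The following fragments sketch the pieces needed for tasks 1, 2 and 6; config, model, logits and labels are placeholders whose exact form is determined by the provided mixer.py:

    import torch

    # Task 2: runtime device selection driven by config['cuda'] (structure assumed)
    device = torch.device('cuda' if config['cuda'] and torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Task 1: Top-1/Top-5 hits for one batch of logits (N, 1000) and labels (N,)
    top5 = logits.topk(5, dim=1).indices                           # (N, 5), sorted descending
    top1_hits = (top5[:, 0] == labels).sum().item()
    top5_hits = (top5 == labels.unsqueeze(1)).any(dim=1).sum().item()

    # Task 6: torch.compile wraps the model; the first call triggers compilation
    compiled_model = torch.compile(model)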

9.4. Triton

We have seen that offloading PyTorch-native operations to a CUDA-capable device is comparatively simple. This ease of programming is powered by the CUDA software stack. However, it might happen that we are unsatisfied with the obtained performance or that our targeted workload is simply not covered by PyTorch-native operations. In such a case, we require a new set of skills to accelerate the workload on the GPU.

In this part of the lab, we will have a closer look at Triton for the development of fast GPU kernels. Triton is a language and compiler for parallel programming and has a Python-like programming interface. Currently (24/05), Triton’s compiler has a backend for NVIDIA GPUs. Backends for CPUs and AMD GPUs are under active development.

We target the batched matrix multiplications in the MLP-Mixer network architecture (see Section 9.3) as our workload. Since we have a full view of the problem setting, we know that each matrix multiplication is followed by an additive term (bias). Further, depending on which matrix multiplication of the mixer network is executed, a Gaussian Error Linear Unit (GELU) might follow immediately. For the B/16 mixer model in the torchfunc implementation, the respective operations are given as follows, with L denoting the batch size:

  1. Stem: (L, 196, 768) = (L, 196, 768) x (1, 768, 768) + (1, 1, 768)

  2. Token-mixing MLP

    • (L, 384, 768) = GELU( (1, 384, 196) x (L, 196, 768) + (384, 1) )

    • (L, 196, 768) = (1, 196, 384) x (L, 384, 768) + (196, 1)

  3. Channel-mixing MLP

    • (L, 196, 3072) = GELU( (L, 196, 768) x (1, 768, 3072) + (1, 1, 3072) )

    • (L, 196, 768) = (L, 196, 3072) x (1, 3072, 768) + (1, 1, 768)

  4. Head: (L, 1000) = (L, 768) x (768, 1000) + (1, 1000)
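
To make the broadcasting of the weights along the batch dimension concrete, the token-mixing shapes can be checked directly in PyTorch (the batch size L=2 is chosen arbitrarily):

    import torch

    L = 2
    x = torch.randn(L, 196, 768)    # token-mixing input
    w1 = torch.randn(1, 384, 196)   # weight, broadcast along the batch dimension
    b1 = torch.randn(384, 1)        # per-token bias
    h = torch.nn.functional.gelu(torch.matmul(w1, x) + b1)
    print(h.shape)                  # torch.Size([2, 384, 768])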

We will exploit our knowledge of the network’s architecture by adding the bias term and applying the GELU activation function directly while blocks of the result matrix are still in fast SRAM. This technique is called “operator fusion” (see also Section 7.3): We fuse additive bias terms and the activation function into the matrix multiplication. In summary, we develop our Triton-accelerated MLP-Mixer through the following steps:

  1. Implement a fast batched matrix multiplication kernel in Triton.

  2. Fuse the application of an additive bias into the kernel.

  3. Optionally fuse the GELU activation function into the kernel.

  4. Replace the PyTorch-native operations in the provided torchfunc mixer model of Section 9.3 with calls to the newly developed Triton kernel.

Hint

The Triton documentation contains a set of tutorials to get you started. The Matrix Multiplication tutorial describes the required techniques for a fast baseline. Be aware that our use case is slightly different since we require a batched matrix multiplication where the weight matrix is broadcasted along the batch dimension. An AMD ROCm blog post describes the implementation of an approximate GELU kernel. You can use the described ideas for the fusion of GELU into your own kernel.
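
As a warm-up for the fusion, the tanh-approximate GELU can first be written as a standalone elementwise Triton kernel; the sketch below uses the identity tanh(z) = 2*sigmoid(2z) - 1 so that only tl.sigmoid is required, and the kernel body can later be folded into the matmul epilogue:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def gelu_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        # gelu(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
        #          = x * sigmoid(2*sqrt(2/pi)*(x + 0.044715*x^3))
        y = x * tl.sigmoid(1.5957691216057308 * (x + 0.044715 * x * x * x))
        tl.store(y_ptr + offsets, y, mask=mask)

    x = torch.randn(1 << 16, device='cuda')
    y = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    gelu_kernel[grid](x, y, x.numel(), BLOCK_SIZE=1024)
    print(torch.allclose(y, torch.nn.functional.gelu(x, approximate='tanh'), atol=1e-5))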

Tasks

  1. Implement a Triton-based batched matrix-matrix multiplication kernel. Ensure that your kernel supports at least the B/16 matrix multiplications described above. Verify your kernel by comparing it to torch.matmul in FP32 arithmetic. Benchmark the kernel’s performance.

  2. Fuse the application of bias terms into your kernel. Make this an optional feature. Verify your fused kernel.

  3. Fuse the application of the GELU activation. Make this an optional feature. Verify and benchmark your kernel.

  4. Replace all torch.matmul (and following biases/GELUs) in models/torchfunc.py by your Triton kernel. You may leave the layer norms, skip connections and mean reduction before the head of the network as PyTorch-native ops. Apply your adjusted mixer implementation to the ImageNet-1K validation data and report the end-to-end performance of your implementation. Include at least numbers for B/16 and FP32 arithmetic.
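
Verification against PyTorch for tasks 1 to 3 could look as follows; batched_matmul is a hypothetical Python wrapper around your Triton kernel, the shapes correspond to the first channel-mixing multiplication, and the reference uses the tanh-approximate GELU so that it matches a tanh-based fusion:

    import torch

    # hypothetical wrapper: batched_matmul(a, b, bias=None, gelu=False) -> tensor
    a = torch.randn(4, 196, 768, device='cuda')     # activations (batch size 4)
    b = torch.randn(1, 768, 3072, device='cuda')    # weight, broadcast along the batch
    bias = torch.randn(1, 1, 3072, device='cuda')

    ref = torch.nn.functional.gelu(torch.matmul(a, b) + bias, approximate='tanh')
    out = batched_matmul(a, b, bias=bias, gelu=True)

    # loose tolerances: tl.dot may use TF32 internally unless configured otherwise
    print(torch.allclose(out, ref, rtol=1e-3, atol=1e-3))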