7. Primitives

Our overall plan is to build a tensor compiler based on a toolbox containing a few primitive operations. The primitives will play a special role in the construction of our compiler. In particular, we will instantiate the primitives with highly optimized kernels tailored to our target hardware and will not strive for portability at that level. Be aware of the implications of this choice: we are deliberately breaking with portability and automation, core ideas of compilers, and opting instead for hand-crafted optimizations. Therefore, we must choose our primitives very carefully: the fewer and simpler the primitives, the better.

7.1. Background

We distinguish three terms:

  1. A primitive is a high-level description of an operation that takes one or more input matrices and produces a single output. It has a set of parameters that allow the user to control the details of the operation.

  2. A kernel is a function that instantiates a primitive for some of the primitive’s parameters, while others are passed to the kernel as function parameters at runtime.

  3. A microkernel is a basic building block of a kernel. It consists of linear code, that is, code without loops, and does the heavy lifting through vector and matrix instructions.

Table 7.1.1 Summary of the primitives used in our tensor compiler.

BRGEMM: \(C \mathrel{+}= \sum_i A_i B_i\)

  Hardcoded parameters:
    • Matrix dimension sizes (M, N, K)
    • Batch size
    • Column- or row-major A, B and C
    • Datatype

  Runtime parameters:
    • Pointers to A, B and C
    • Leading dimensions of A, B and C
    • Batch strides for A and B

Unary: \(B := \text{op}(A)\)

  Hardcoded parameters:
    • Operation: Zero, Identity, ReLU
    • Number of rows, number of columns
    • Column-major A
    • Column- or row-major B
    • Datatype

  Runtime parameters:
    • Pointers to A and B
    • Leading dimensions of A and B

Table 7.1.1 summarizes the primitives that we will use in our tensor compiler. The BRGEMM primitive is the most complex one and the most critical to the compiler’s overall performance. It takes a batch of matrices \(A_i\) and a batch of matrices \(B_i\) as inputs and computes the batch-reduce matrix-matrix multiplication \(C \mathrel{+}= \sum_i A_i B_i\). The class of unary primitives includes the zero primitive, the identity primitive and the Rectified Linear Unit (ReLU) primitive. These primitives share a similar structure: they compute \(B:=0\), \(B:=A\), and \(B:=\text{ReLU}(A)\), respectively.
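Written out element by element, the definitions read as follows. For the sake of notation, we assume that each \(A_i\) is \(M \times K\), each \(B_i\) is \(K \times N\), and \(C\) is \(M \times N\). The batch index \(i\) runs over a batch of size \(\mathit{BS}\):

\[
C_{mn} \mathrel{+}= \sum_{i=0}^{\mathit{BS}-1} \sum_{k=0}^{K-1} (A_i)_{mk}\,(B_i)_{kn}, \qquad 0 \le m < M,\; 0 \le n < N,
\]

\[
\text{ReLU}(A)_{mn} = \max\left(A_{mn},\, 0\right).
\]

All \(\mathit{BS}\) products accumulate into the same matrix \(C\). This shared accumulation target distinguishes BRGEMM from a batch of independent GEMMs.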

When specifying the details of a primitive, we distinguish between two types of parameters:

Hardcoded parameters

These are parameters set when a kernel is instantiated from a primitive description. This means that primitive instantiations with different hardcoded parameters result in different kernels. Thus, the scope of an individual kernel is limited. The upside is that we know these parameters before executing a kernel. This enables us to optimize each kernel for its specific configuration without considering other cases.

Runtime parameters

These are parameters that are not known until a kernel is executed. Therefore, these parameters may change from one kernel execution to the next. We consider pointers to the data on which the kernels are executed, as well as strides, to be runtime parameters.
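To make the distinction concrete, the following C sketch mimics it at compile time. It is an illustration under assumptions, not how our compiler generates kernels; that is done with JIT code generation in Section 8. The hypothetical macro DEFINE_ZERO_KERNEL bakes the hardcoded parameters M and N into each generated function. The pointers and leading dimensions remain function arguments, following the unary interface of Listing 7.3.1. The lookup function get_zero_kernel is likewise hypothetical.

```c
#include <stddef.h>
#include <stdint.h>

/* Kernel signature: only runtime parameters appear as arguments. */
typedef void (*unary_kernel_t)( void const * a,
                                void       * b,
                                int64_t      ld_a,
                                int64_t      ld_b );

/* Hypothetical generator: M and N are hardcoded as compile-time
 * constants, so the C compiler can specialize each instance. */
#define DEFINE_ZERO_KERNEL( M, N )                                  \
  static void zero_##M##_##N( void const * a,                        \
                              void       * b,                        \
                              int64_t      ld_a,                     \
                              int64_t      ld_b ) {                  \
    (void) a; (void) ld_a; /* unused by the zero kernel */           \
    float * l_b = (float *) b;                                       \
    for( int64_t l_n = 0; l_n < (N); l_n++ )                         \
      for( int64_t l_m = 0; l_m < (M); l_m++ )                       \
        l_b[ l_m + l_n * ld_b ] = 0.0f; /* column-major B */         \
  }

/* "Instantiate" two kernels with different hardcoded parameters. */
DEFINE_ZERO_KERNEL( 4, 4 )
DEFINE_ZERO_KERNEL( 16, 6 )

/* Hypothetical lookup: returns NULL for configurations not built. */
unary_kernel_t get_zero_kernel( int64_t m, int64_t n ) {
  if( m == 4  && n == 4 ) return zero_4_4;
  if( m == 16 && n == 6 ) return zero_16_6;
  return NULL;
}
```

Each distinct (M, N) pair yields a separate function. This is the trade-off described above: each kernel has a limited scope, but its loop bounds are known before execution.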

This section discusses the efficient design of kernels that instantiate the primitives summarized in Table 7.1.1. We consider a small set of hardcoded parameter configurations, write all kernels in assembly language, and discuss the basic building blocks of a typical kernel. The focus is therefore on structures that streamline kernel development. Combining this knowledge with just-in-time (JIT) code generation, discussed in Section 8, allows us to generate efficient kernels on the fly for a wide range of hardcoded primitive parameters.

7.2. Batch-Reduce GEMM

We begin our discussion with the batch-reduce GEMM (BRGEMM) primitive. This primitive will be the main driver of performance in our tensor compiler. Therefore, achieving high performance for BRGEMM kernels is paramount. As shown in Table 7.1.1, the BRGEMM primitive has a number of hardcoded parameters, but leaves the matrix pointers and strides as runtime parameters.

Listing 7.2.1 BRGEMM runtime interface. The matrix dimensions M, N, and K, the batch size, the matrix layouts, and the datatype are hardcoded parameters.
/**
 * @brief Batch-reduce GEMM that computes: C+=sum(Ai*Bi) over a batch.
 * @param a           Pointer to first of a batch of A matrices.
 * @param b           Pointer to first of a batch of B matrices.
 * @param c           Pointer to C matrix.
 * @param ld_a        Leading dimension of A.
 * @param ld_b        Leading dimension of B.
 * @param ld_c        Leading dimension of C.
 * @param br_stride_a Stride (in elements, not bytes) between A matrices.
 * @param br_stride_b Stride (in elements, not bytes) between B matrices.
 **/
void brgemm( void    const * a,
             void    const * b,
             void          * c,
             int64_t         ld_a,
             int64_t         ld_b,
             int64_t         ld_c,
             int64_t         br_stride_a,
             int64_t         br_stride_b );

Listing 7.2.1 shows the signature of the function brgemm, which will serve as our BRGEMM interface at runtime. Our goal in this section is to derive a structured approach that allows us to instantiate BRGEMM kernels from the hardcoded parameters.
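When developing assembly kernels, it helps to have a slow but obviously correct reference to test against. The following is a minimal C sketch of such a reference for the interface in Listing 7.2.1. It assumes FP32 data and column-major A, B and C. The hypothetical macros BR_M, BR_N, BR_K and BR_BS stand in for the hardcoded dimension sizes and batch size, and their small values are for illustration only.

```c
#include <stdint.h>

/* Hardcoded parameters (small illustrative values). */
#define BR_M  2
#define BR_N  2
#define BR_K  2
#define BR_BS 2

/* Scalar reference: C += sum_i A_i * B_i, FP32, all column-major.
 * Leading dimensions and batch strides are given in elements. */
void brgemm( void    const * a,
             void    const * b,
             void          * c,
             int64_t         ld_a,
             int64_t         ld_b,
             int64_t         ld_c,
             int64_t         br_stride_a,
             int64_t         br_stride_b ) {
  float const * l_a = (float const *) a;
  float const * l_b = (float const *) b;
  float       * l_c = (float *) c;

  for( int64_t l_br = 0; l_br < BR_BS; l_br++ ) {
    float const * l_ai = l_a + l_br * br_stride_a; /* A_i */
    float const * l_bi = l_b + l_br * br_stride_b; /* B_i */
    for( int64_t l_n = 0; l_n < BR_N; l_n++ ) {
      for( int64_t l_k = 0; l_k < BR_K; l_k++ ) {
        float l_bkn = l_bi[ l_k + l_n * ld_b ];
        for( int64_t l_m = 0; l_m < BR_M; l_m++ ) {
          l_c[ l_m + l_n * ld_c ] += l_ai[ l_m + l_k * ld_a ] * l_bkn;
        }
      }
    }
  }
}
```

Setting the macros to the values of a row in Table 7.2.1 yields a reference for that kernel. The assembly kernel's output can then be compared against it.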

Table 7.2.1 Hardcoded parameters of the discussed BRGEMM kernels. Given are the datatype (Dtype), the format of the three matrices (A, B and C), which can be either column-major (C) or row-major (R), the matrix dimension sizes (M, N and K), and the batch size (BS).

ID  Dtype  A  B  C   M  N   K  BS
 0  FP32   C  C  C  16  6   1   1
 1  FP32   C  C  C  16  6  64   1

We will do this by writing fast Neon code that instantiates the BRGEMM primitive for a few sets of hardcoded parameters. The parameters of the examined kernels are listed in Table 7.2.1 and will be covered step by step in the following sections.

7.2.1. Neon Microkernel

Microkernels are the basis of every BRGEMM kernel. Depending on the hardcoded parameters, a BRGEMM kernel is built around one or more microkernels. A microkernel updates an accumulator and contains only linear code, that is, code without branches. An accumulator is a submatrix of the result matrix C that is held entirely in the vector registers.

In general, we choose the accumulator to be as large as possible and keep it in the vector registers until we have finished computing that submatrix of C. Only then do we store the accumulator to memory. A large accumulator maximizes the distance between instructions that update the same vector registers, which avoids stalls due to read-after-write dependencies. Processing each accumulator to completion avoids repeatedly loading and storing the matrix C. In the literature, this approach is also called register blocking.
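A rough count shows why processing the accumulator to completion matters. As a simplified model, assume FP32 data (4 bytes per value) and count only the traffic for C. Without register blocking, each of the K rank-one updates would load and store the \(M \times N\) block of C. With register blocking, the block is loaded and stored once. For M=16, N=6 and K=64 this gives:

\[
\underbrace{2 \cdot 4 \cdot M \cdot N \cdot K}_{\text{without}} = 2 \cdot 4 \cdot 16 \cdot 6 \cdot 64 = 49152 \text{ bytes},
\qquad
\underbrace{2 \cdot 4 \cdot M \cdot N}_{\text{with}} = 2 \cdot 4 \cdot 16 \cdot 6 = 768 \text{ bytes}.
\]

The ratio between the two is exactly K, so the benefit grows with the trip count of the K loop.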

Our first kernel has the hardcoded parameters of configuration 0 in Table 7.2.1. The kernel computes the FP32 matrix-matrix product of a 16x1 matrix A and a 1x6 matrix B and adds the result to the 16x6 matrix C, that is, C+=AB. In other words, we compute the outer product of a 16-element vector with a 6-element vector and add the result to C. This outer-product formulation will be the basis of our microkernels. The dimensions M=16 and N=6 were chosen to be consistent with many “standard” microkernels, which use 24 of the 32 vector registers for the accumulator.
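The register budget behind this choice can be counted directly. A 128-bit Neon vector register holds four FP32 values. Listing 7.2.2 places the accumulator in V0-V23, the 16 values of A in V24-V27, and the B values in V28-V31:

\[
\underbrace{\frac{16 \cdot 6}{4}}_{\text{accumulator}} + \underbrace{\frac{16}{4}}_{A} + \underbrace{4}_{B} = 24 + 4 + 4 = 32 \text{ vector registers}.
\]

In this listing, all 32 vector registers are in use. The four B registers come from the listing's choice to rotate through V28-V31 for its six B values.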

Listing 7.2.2 Microkernel used in the gemm_neon_16_6_1 kernel that computes C+=AB for column-major matrices A, B, and C with M=16, N=6, and K=1.
 39  // hold addresses to A, B, C in work registers
 40  mov x7, x0 // A
 41  mov x8, x1 // B
 42  mov x9, x2 // C
 43
 44  // convert strides to bytes
 45  lsl x3, x3, #2 // stride of A (unused)
 46  lsl x4, x4, #2 // stride of B
 47  lsl x5, x5, #2 // stride of C
 48
 49  /*
 50   * Part 1:
 51   * Load 16*6 accumulator.
 52   */
 53  ld1 { v0.4s,  v1.4s,  v2.4s,  v3.4s}, [x9]
 54  add x9, x9, x5
 55  ld1 { v4.4s,  v5.4s,  v6.4s,  v7.4s}, [x9]
 56  add x9, x9, x5
 57  ld1 { v8.4s,  v9.4s, v10.4s, v11.4s}, [x9]
 58  add x9, x9, x5
 59  ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9]
 60  add x9, x9, x5
 61  ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x9]
 62  add x9, x9, x5
 63  ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x9]
 64  mov x9, x2
 65
 66  /*
 67   * Part 2:
 68   * Stream A and B.
 69   * Execute fused-multiply-adds (FMAs).
 70   */
 71  // load 16 values of A
 72  ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7]
 73
 74  // load first value of B
 75  // each value is multiplied by 16 values of A
 76  ldr s28, [x8]
 77  add x8, x8, x4
 78
 79  // perform the fmas
 80  fmla  v0.4s, v24.4s, v28.s[0]
 81  fmla  v1.4s, v25.4s, v28.s[0]
 82  fmla  v2.4s, v26.4s, v28.s[0]
 83  fmla  v3.4s, v27.4s, v28.s[0]
 84
 85  // load second value of B
 86  ldr s29, [x8]
 87  add x8, x8, x4
 88
 89  // perform the fmas
 90  fmla  v4.4s, v24.4s, v29.s[0]
 91  fmla  v5.4s, v25.4s, v29.s[0]
 92  fmla  v6.4s, v26.4s, v29.s[0]
 93  fmla  v7.4s, v27.4s, v29.s[0]
 94
 95  // load third value of B
 96  ldr s30, [x8]
 97  add x8, x8, x4
 98
 99  // perform the fmas
100  fmla  v8.4s, v24.4s, v30.s[0]
101  fmla  v9.4s, v25.4s, v30.s[0]
102  fmla v10.4s, v26.4s, v30.s[0]
103  fmla v11.4s, v27.4s, v30.s[0]
104
105  // load fourth value of B
106  ldr s31, [x8]
107  add x8, x8, x4
108
109  // perform the fmas
110  fmla v12.4s, v24.4s, v31.s[0]
111  fmla v13.4s, v25.4s, v31.s[0]
112  fmla v14.4s, v26.4s, v31.s[0]
113  fmla v15.4s, v27.4s, v31.s[0]
114
115  // load fifth value of B
116  ldr s28, [x8]
117  add x8, x8, x4
118
119  // perform the fmas
120  fmla v16.4s, v24.4s, v28.s[0]
121  fmla v17.4s, v25.4s, v28.s[0]
122  fmla v18.4s, v26.4s, v28.s[0]
123  fmla v19.4s, v27.4s, v28.s[0]
124
125  // load sixth value of B
126  ldr s29, [x8]
127  add x8, x8, x4
128
129  // perform the fmas
130  fmla v20.4s, v24.4s, v29.s[0]
131  fmla v21.4s, v25.4s, v29.s[0]
132  fmla v22.4s, v26.4s, v29.s[0]
133  fmla v23.4s, v27.4s, v29.s[0]
134
135  /*
136   * Part 3:
137   * Store 16*6 accumulator.
138   */
139  st1 { v0.4s,  v1.4s,  v2.4s,  v3.4s}, [x9]
140  add x9, x9, x5
141  st1 { v4.4s,  v5.4s,  v6.4s,  v7.4s}, [x9]
142  add x9, x9, x5
143  st1 { v8.4s,  v9.4s, v10.4s, v11.4s}, [x9]
144  add x9, x9, x5
145  st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9]
146  add x9, x9, x5
147  st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x9]
148  add x9, x9, x5
149  st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x9]

Listing 7.2.2 shows a possible implementation of the microkernel in assembly code. In general, a microkernel consists of three parts:

  1. Load the accumulator from memory into vector registers (lines 49-64).

  2. Stream subcolumns of A and subrows of B into the remaining vector registers, and execute fused multiply-add vector instructions to update the accumulator with the outer product of the subcolumns and subrows (lines 66-133).

  3. Store the accumulator back into memory (lines 135-149).

The code for loading the accumulator in lines 49-64 uses 64-byte loads, that is, each LD1 (multiple structures) instruction loads 16 values of C. After each load, the scratch register X9 is updated. Initially, in line 53, X9 holds the address of the first value of C. Then, in line 54, we increment X9 by the stride in bytes from one column of C to the next. This means that X9 holds the address of the second column of C in line 55, the address of the third column of C in line 57, and so on. After the six loads, the vector registers V0-V23 hold the 16x6 accumulator. The mov x9, x2 instruction in line 64 resets the scratch register X9 so that it again holds the address of the first entry of C. Similarly, the code in lines 135-149 stores the accumulator to memory. The only difference is that we use ST1 (multiple structures) to store it in 64-byte chunks.

The code in lines 66-133 loads the A and B values and performs the fused multiply-add instructions. In line 72, we use a 64-byte load to put all 16 FP32 values of A into vector registers V24-V27. Then the ldr s28, [x8] instruction in line 76 performs a four-byte load to put the first value of B into the lower 32 bits of register V28. Line 77 increments the address in X8 by the B stride in bytes. Thus, after the increment, register X8 holds the address of the second column of B. The FMLA (by element) instructions in lines 80-83 multiply the 16 A values by the single B value and add the result to vector registers V0-V3. The rest of the code in lines 85-133 is similar. We keep the 16 A values in registers V24-V27 but load a new value of B each time, so that we can perform four additional fused multiply-add instructions. In summary, the code computes the outer product of the 16 A values with the 6 B values and updates the entire 16x6 accumulator in registers V0-V23.
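For readers who find C easier to follow than assembly, here is a scalar sketch that mirrors the three parts of Listing 7.2.2. It is an illustration, not a replacement for the Neon kernel. The name gemm_ref_16_6_1 is hypothetical, and strides are given in elements rather than bytes. The local array acc plays the role of registers V0-V23, although a C compiler is free to keep it in memory. Unlike FMLA, the C expression is not guaranteed to be fused, so for general inputs the results may differ from the Neon kernel in the last bits.

```c
#include <stdint.h>

/* Hypothetical scalar model of the 16x6x1 microkernel.
 * All matrices column-major; leading dimensions in elements. */
void gemm_ref_16_6_1( float const * a,
                      float const * b,
                      float       * c,
                      int64_t       ld_a,
                      int64_t       ld_b,
                      int64_t       ld_c ) {
  (void) ld_a;      /* K=1: A has a single column, stride unused */
  float acc[6][16]; /* stands in for V0-V23 */

  /* Part 1: load the 16x6 accumulator, one column of C at a time. */
  for( int64_t n = 0; n < 6; n++ )
    for( int64_t m = 0; m < 16; m++ )
      acc[n][m] = c[ m + n * ld_c ];

  /* Part 2: keep the 16 A values, stream one B value per column. */
  for( int64_t n = 0; n < 6; n++ ) {
    float b_val = b[ n * ld_b ];  /* ldr s28, [x8]; add x8, x8, x4 */
    for( int64_t m = 0; m < 16; m++ )
      acc[n][m] += a[m] * b_val;  /* four fmla per column in Neon */
  }

  /* Part 3: store the accumulator back to C. */
  for( int64_t n = 0; n < 6; n++ )
    for( int64_t m = 0; m < 16; m++ )
      c[ m + n * ld_c ] = acc[n][m];
}
```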

Benchmarking the kernel on an NVIDIA Grace core, we get a performance of 23.3 GFLOPS. This is 22% of the 105.6 GFLOPS execution throughput we benchmarked in Table 6.1.2. The main reason for the performance discrepancy is the low arithmetic intensity of the kernel. Specifically, we only perform \(2 \cdot M \cdot N \cdot K = 192\) FP32 operations, but have to load and store the entire accumulator. In detail, we load the 16-element matrix A (64 bytes), the 6-element matrix B (24 bytes), and load and store the 16x6 matrix C (384+384 bytes). This means that we perform only 0.2 FP32 operations for each transferred byte.

Section 7.2.2 shows that kernels with sufficiently large K dimensions can overcome this problem, since the accumulator of our microkernel remains in the vector registers for all K updates.

7.2.2. Loop over K

The kernel discussed in Section 7.2.1 can be extended to support arbitrary K dimensions. Essentially, we simply loop over the microkernel and adjust the addresses for the A and B loads in each iteration.

Listing 7.2.3 Loop over K using a 16x6 microkernel. The loop is part of the kernel gemm_neon_16_6_64, which computes C+=AB for column-major matrices A, B and C with M=16, N=6, and K=64.
 66    mov x10, #64
 67 loop_k:
 68    /*
 69     * Part 2:
 70     * Stream A and B.
 71     * Execute fused-multiply-adds (FMAs).
 72     */
 73    // load 16 values of A
 74    ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7]
 75
 76    // prep A ptr for next iteration
 77    add x7, x7, x3
 78
 79    // load first value of B
 80    // each value is multiplied by 16 values of A
 81    ldr s28, [x8]
 82    add x8, x8, x4
 83
 84    // perform the fmas
 85    fmla  v0.4s, v24.4s, v28.s[0]
 86    fmla  v1.4s, v25.4s, v28.s[0]
 87    fmla  v2.4s, v26.4s, v28.s[0]
 88    fmla  v3.4s, v27.4s, v28.s[0]
 89
 90    // load second value of B
 91    ldr s29, [x8]
 92    add x8, x8, x4
 93
 94    // perform the fmas
 95    fmla  v4.4s, v24.4s, v29.s[0]
 96    fmla  v5.4s, v25.4s, v29.s[0]
 97    fmla  v6.4s, v26.4s, v29.s[0]
 98    fmla  v7.4s, v27.4s, v29.s[0]
 99
100    // load third value of B
101    ldr s30, [x8]
102    add x8, x8, x4
103
104    // perform the fmas
105    fmla  v8.4s, v24.4s, v30.s[0]
106    fmla  v9.4s, v25.4s, v30.s[0]
107    fmla v10.4s, v26.4s, v30.s[0]
108    fmla v11.4s, v27.4s, v30.s[0]
109
110    // load fourth value of B
111    ldr s31, [x8]
112    add x8, x8, x4
113
114    // perform the fmas
115    fmla v12.4s, v24.4s, v31.s[0]
116    fmla v13.4s, v25.4s, v31.s[0]
117    fmla v14.4s, v26.4s, v31.s[0]
118    fmla v15.4s, v27.4s, v31.s[0]
119
120    // load fifth value of B
121    ldr s28, [x8]
122    add x8, x8, x4
123
124    // perform the fmas
125    fmla v16.4s, v24.4s, v28.s[0]
126    fmla v17.4s, v25.4s, v28.s[0]
127    fmla v18.4s, v26.4s, v28.s[0]
128    fmla v19.4s, v27.4s, v28.s[0]
129
130    // load sixth value of B
131    ldr s29, [x8]
132    add x8, x8, x4
133
134    // prep B ptr for next iteration
135    add x1, x1, #4
136    mov x8, x1
137
138    // perform the fmas
139    fmla v20.4s, v24.4s, v29.s[0]
140    fmla v21.4s, v25.4s, v29.s[0]
141    fmla v22.4s, v26.4s, v29.s[0]
142    fmla v23.4s, v27.4s, v29.s[0]
143
144    // decrement loop counter
145    sub x10, x10, #1
146    // perform more iterations if not zero
147    cbnz x10, loop_k

Listing 7.2.3 shows an implementation that uses the microkernel discussed in Section 7.2.1 to implement configuration 1 in Table 7.2.1, that is, the GEMM C+=AB with M=16, N=6, and K=64, where all matrices are stored in column-major order.

Comparing the two implementations, very little has changed. The first change is the added loop logic in lines 66-67 and 144-147: a loop counter in X10 is initialized with 64 in line 66. At the end of each iteration, the counter is decremented by 1 (line 145), and if it is not zero, the next iteration is performed (line 147).

The second change prepares the scratch registers, which hold the addresses to A and B, for the next loop iteration. Specifically, all 16 A values used in a loop iteration are loaded in line 74. Thus, in line 77, we prepare the A address held in X7 for the next loop iteration. This is done by adding the stride from one A column to the next to the current address in X7. The stride is stored in bytes in register X3.

In addition, lines 134-136 set the B address for the next iteration. To do this, we first update register X1, which holds the address of the first value in the current row of B. Since B is column-major, adding 4 (the size of one FP32 value in bytes) gives the address of the first value in the next row. Next, in line 136, we copy this address into the scratch register X8. The next loop iteration then uses and increments X8 in lines 79-82, 90-92, 100-102, 110-112, 120-122, and 130-132.
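The pointer bookkeeping of Listing 7.2.3 can be sketched in C as follows. As before, this is a hypothetical scalar model with strides in elements, and the name gemm_ref_16_6_64 is ours. Here a_col plays the role of X7 and advances by ld_a per iteration, while b_row plays the role of X1 and advances by one element. The indexing b_row[n * ld_b] corresponds to X8 walking along the current row of B.

```c
#include <stdint.h>

/* Hypothetical scalar model of the K loop in gemm_neon_16_6_64.
 * Column-major A (16x64), B (64x6) and C (16x6); strides in elements. */
void gemm_ref_16_6_64( float const * a,
                       float const * b,
                       float       * c,
                       int64_t       ld_a,
                       int64_t       ld_b,
                       int64_t       ld_c ) {
  float acc[6][16]; /* accumulator, kept across all K iterations */

  /* Part 1: load the accumulator once. */
  for( int64_t n = 0; n < 6; n++ )
    for( int64_t m = 0; m < 16; m++ )
      acc[n][m] = c[ m + n * ld_c ];

  float const * a_col = a; /* plays the role of X7 */
  float const * b_row = b; /* plays the role of X1 */
  for( int64_t k = 64; k != 0; k-- ) { /* X10 with sub + cbnz */
    /* Part 2: X8 walks along the current row of B with stride ld_b. */
    for( int64_t n = 0; n < 6; n++ ) {
      float b_val = b_row[ n * ld_b ];
      for( int64_t m = 0; m < 16; m++ )
        acc[n][m] += a_col[m] * b_val;
    }
    a_col += ld_a; /* next column of A, cf. line 77 */
    b_row += 1;    /* next row of B, cf. lines 135-136 */
  }

  /* Part 3: store the accumulator once. */
  for( int64_t n = 0; n < 6; n++ )
    for( int64_t m = 0; m < 16; m++ )
      c[ m + n * ld_c ] = acc[n][m];
}
```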

Running this kernel on an NVIDIA Grace core, we get a performance of 87.3 FP32 GFLOPS. This is a 3.7x speedup over the K=1 kernel and 82.7% of the microbenchmarked FP32 performance. The large improvement is due to the higher number of FP32 operations relative to the data transferred. In detail, we load the \(M \times K = 16 \times 64\) matrix A (4096 bytes), load the \(K \times N = 64 \times 6\) matrix B (1536 bytes), and load and store the \(M \times N = 16 \times 6\) matrix C (384+384 bytes). We perform \(2\cdot M \cdot N \cdot K = 12288\) FP32 operations. So we perform about 1.9 FP32 operations for each transferred byte, roughly 8.6 times as many as the K=1 kernel.
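The arithmetic above generalizes to any K. The following C sketch uses the same simple traffic model as the text: FP32 values, A and B loaded once, and C loaded and stored once. It computes the operations per transferred byte as a function of the dimensions. The function name ops_per_byte is hypothetical.

```c
#include <stdint.h>

/* FP32 operations per transferred byte for C+=AB, counting loads of
 * A (MxK) and B (KxN) plus one load and one store of C (MxN). */
double ops_per_byte( int64_t m, int64_t n, int64_t k ) {
  double flops = 2.0 * m * n * k;
  double bytes = 4.0 * ( m * k + k * n + 2 * m * n );
  return flops / bytes;
}
```

ops_per_byte(16, 6, 1) evaluates to roughly 0.22, and ops_per_byte(16, 6, 64) to roughly 1.92. As K grows, the traffic for C becomes negligible and the ratio approaches \(2MN / (4(M+N)) = 192/88 \approx 2.18\). In this model, the cost of streaming A and B alone sets that upper bound.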

7.3. Unary Primitives

This section discusses kernels that implement the three unary primitives listed in Table 7.1.1. All three primitives process the output matrix element by element. The identity and ReLU primitives also read the input matrix element by element, while the zero primitive simply sets the output matrix to zero. The BRGEMM primitive reduces over both the K dimension and the batch dimension; by comparison, this element-wise processing simplifies the kernel implementation.

Listing 7.3.1 Unary runtime interface. The matrix dimensions M and N, the datatype, and the data layout of B (column- or row-major) are hardcoded parameters.
/**
 * @brief Unary kernel that computes: B:=op(A).
 * @param a    Pointer to column-major matrix A, nullptr if zero kernel.
 * @param b    Pointer to matrix B.
 * @param ld_a Leading dimension of A, 0 if zero kernel.
 * @param ld_b Leading dimension of B.
 **/
void unary( void const * a,
            void       * b,
            int64_t      ld_a,
            int64_t      ld_b );

Listing 7.3.1 shows the signature of the unary function, which serves as the interface for all unary kernels. For the zero kernel, B is set to zero and the function parameters a and ld_a are expected to be zero.
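As with BRGEMM, a scalar reference is handy for testing. The following C sketch covers all three unary operations for FP32 data, column-major A, and either layout of B. Unlike the interface in Listing 7.3.1, it takes the hardcoded parameters (operation, M, N, layout of B) as ordinary arguments for brevity. The names unary_ref and unary_op_t are hypothetical.

```c
#include <stdint.h>

typedef enum { UNARY_ZERO, UNARY_IDENTITY, UNARY_RELU } unary_op_t;

/* Hypothetical scalar reference for B := op(A), FP32.
 * A is column-major (m x n); B is column- or row-major.
 * Leading dimensions are in elements; a and ld_a are ignored for ZERO. */
void unary_ref( unary_op_t op, int64_t m, int64_t n, int b_row_major,
                float const * a, float * b, int64_t ld_a, int64_t ld_b ) {
  for( int64_t j = 0; j < n; j++ ) {
    for( int64_t i = 0; i < m; i++ ) {
      float val = 0.0f;
      if( op != UNARY_ZERO ) {
        val = a[ i + j * ld_a ];
        if( op == UNARY_RELU && val < 0.0f ) val = 0.0f;
      }
      /* Row-major B stores element (i, j) at i * ld_b + j. */
      int64_t idx = b_row_major ? i * ld_b + j : i + j * ld_b;
      b[ idx ] = val;
    }
  }
}
```

With b_row_major set, the identity and ReLU operations read A column by column but write B row by row. This is the transposition that the Neon kernels in Section 7.3.1 implement.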

7.3.1. Neon Transpositions

First, we discuss identity and ReLU kernels that write a row-major output matrix B. Since the input matrix A is column-major, these kernels must transpose A.

Listing 7.3.2 Identity kernel that transposes a 4x4 block containing word-sized data.
34  // hold addresses to A and B in work registers
35  mov x4, x0 // A
36  mov x5, x1 // B
37
38  // convert strides to bytes
39  lsl x2, x2, #2 // stride of A
40  lsl x3, x3, #2 // stride of B
41
42  /*
43   * Part 1:
44   * Load 4x4 block of A.
45   */
46  ldr q0, [x4]
47  add x4, x4, x2
48  ldr q1, [x4]
49  add x4, x4, x2
50  ldr q2, [x4]
51  add x4, x4, x2
52  ldr q3, [x4]
53
54  /*
55   * Part 2:
56   * Transpose 4x4 block.
57   */
58  trn1 v4.4s, v0.4s, v1.4s
59  trn2 v5.4s, v0.4s, v1.4s
60  trn1 v6.4s, v2.4s, v3.4s
61  trn2 v7.4s, v2.4s, v3.4s
62
63  zip1  v8.2d, v4.2d, v6.2d
64  zip1  v9.2d, v5.2d, v7.2d
65  zip2 v10.2d, v4.2d, v6.2d
66  zip2 v11.2d, v5.2d, v7.2d
67
68  /*
69   * Part 3:
70   * Store 4x4 block of A into B.
71   */
72  str q8, [x5]
73  add x5, x5, x3
74  str q9, [x5]
75  add x5, x5, x3
76  str q10, [x5]
77  add x5, x5, x3
78  str q11, [x5]

Fig. 7.3.1 Illustration of an identity kernel that transposes a 4x4 matrix.

In Neon, the TRN1, TRN2, ZIP1, and ZIP2 instructions can be combined to transpose a 4x4 block of 32-bit elements. Listing 7.3.2 shows the code performing the transposition, and Fig. 7.3.1 illustrates the general structure of this approach. First, the block is loaded column by column into vector registers V0-V3 (lines 46-52).

We use TRN1 to interleave the even-indexed elements (lanes 0 and 2) of two vector registers. The elements in V4 are obtained from V0 and V1 (line 58), and the elements in V6 from V2 and V3 (line 60). Likewise, TRN2 interleaves the odd-indexed elements (lanes 1 and 3). V5 is obtained by applying TRN2 to V0 and V1 (line 59), and V7 by applying TRN2 to V2 and V3 (line 61).

At this point, the vector registers V4-V7 contain the targeted result in two-element chunks. For instance, the first row of the row-major 4x4 B block (equivalently, the first column of the transposed block) must contain the elements with IDs 0, 4, 8, and 12. Elements 0 and 4 are in the lower 64 bits of V4, and elements 8 and 12 are in the lower 64 bits of V6. We use ZIP1 to combine the lower 64 bits of V4 and V6, which writes this first row to register V8 (line 63). Combining the lower halves of V5 and V7 with ZIP1 gives us the second row in register V9 (line 64). Similarly, we use ZIP2 to combine the upper 64 bits of two vector registers, obtaining the third and fourth rows in registers V10 and V11 (lines 65 and 66).

Finally, we store vectors V8 through V11 (lines 72-78) to complete the transposition.
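To check the data movement without Arm hardware, the following C sketch emulates the four permutations on arrays of four 32-bit lanes. It then applies them in the same order as lines 58-66 of Listing 7.3.2. The type vec4s and all helper names are hypothetical. The lane semantics follow the Arm definitions of TRN1 and TRN2 on .4s arrangements. They also follow ZIP1 and ZIP2 on .2d arrangements, where each 64-bit lane is a pair of 32-bit elements.

```c
/* One 128-bit register viewed as four FP32 lanes. */
typedef struct { float s[4]; } vec4s;

/* TRN1 .4s: even lanes of x and y, interleaved: {x0, y0, x2, y2} */
static vec4s trn1_4s( vec4s x, vec4s y ) {
  vec4s r = { { x.s[0], y.s[0], x.s[2], y.s[2] } };
  return r;
}
/* TRN2 .4s: odd lanes of x and y, interleaved: {x1, y1, x3, y3} */
static vec4s trn2_4s( vec4s x, vec4s y ) {
  vec4s r = { { x.s[1], y.s[1], x.s[3], y.s[3] } };
  return r;
}
/* ZIP1 .2d: lower 64 bits of x, then lower 64 bits of y */
static vec4s zip1_2d( vec4s x, vec4s y ) {
  vec4s r = { { x.s[0], x.s[1], y.s[0], y.s[1] } };
  return r;
}
/* ZIP2 .2d: upper 64 bits of x, then upper 64 bits of y */
static vec4s zip2_2d( vec4s x, vec4s y ) {
  vec4s r = { { x.s[2], x.s[3], y.s[2], y.s[3] } };
  return r;
}

/* in[j] holds column j of the block (like V0-V3);
 * out[i] receives row i of the block (like V8-V11). */
void transpose_4x4( vec4s const in[4], vec4s out[4] ) {
  vec4s v4 = trn1_4s( in[0], in[1] );
  vec4s v5 = trn2_4s( in[0], in[1] );
  vec4s v6 = trn1_4s( in[2], in[3] );
  vec4s v7 = trn2_4s( in[2], in[3] );
  out[0] = zip1_2d( v4, v6 );
  out[1] = zip1_2d( v5, v7 );
  out[2] = zip2_2d( v4, v6 );
  out[3] = zip2_2d( v5, v7 );
}
```

Suppose element IDs 0-15 are stored column by column, as in Fig. 7.3.1. They come out as {0, 4, 8, 12}, {1, 5, 9, 13}, and so on, matching the contents of V8-V11 described above.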