9. Tiled Execution IR

We have now assembled the tools needed to execute tensor operations efficiently. This chapter introduces the Tiled Execution Intermediate Representation (TEIR), which describes the execution of a tensor operation through primitives that operate on subtensors called tiles. The IR guides the backend implementation and controls when and where primitives are executed. Conceptually, the IR is closely related to tile-based programming models such as Triton, cuTile and Tile IR, Pallas, and TileIR.

9.1. Specification

Domains and Notation

\[\begin{split}\begin{aligned} D &\in \mathbb{N}^+ \\[2pt] \mathbf{dim\_types} &= (t_0,\ldots,t_{D-1}),\quad t_i \in \{\mathrm{C},\mathrm{M},\mathrm{N},\mathrm{K}\} \\ \mathbf{exec\_types} &= (e_0,\ldots,e_{D-1}),\quad e_i \in \{\mathrm{seq},\,\mathrm{shared},\,\mathrm{prim}\} \\ \mathbf{dim\_sizes} &\in (\mathbb{N}^+)^{D} \\ \mathbf{strides\_{in0}},\mathbf{strides\_{in1}},\mathbf{strides\_{out}} &\in \mathbb{N}^{D} \\[4pt] \mathrm{data\_type} &\in \{\mathrm{FP32},\mathrm{FP64}\} \\ \mathrm{prim\_first} &\in \{\mathrm{None},\mathrm{Zero},\mathrm{ReLU}\} \\ \mathrm{prim\_main} &\in \{\mathrm{None},\mathrm{Copy},\mathrm{GEMM},\mathrm{BRGEMM}\} \\ \mathrm{prim\_last} &\in \{\mathrm{None},\mathrm{ReLU}\} \end{aligned}\end{split}\]

Records

\[\begin{split}\begin{aligned} \mathrm{TEIR\mbox{-}Schedule} &= \langle \mathbf{dim\_types}, \mathbf{exec\_types}, \mathbf{dim\_sizes},\\ & \quad \;\;\; \mathbf{strides\_{in0}}, \mathbf{strides\_{in1}}, \mathbf{strides\_{out}} \rangle \\[4pt] \mathrm{TEIR\mbox{-}Primitives} &= \langle \mathrm{data\_type}, \mathrm{prim\_first}, \mathrm{prim\_main}, \mathrm{prim\_last} \rangle \end{aligned}\end{split}\]
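The two records translate directly into plain data structures. The sketch below is one possible Python transcription, not part of TEIR itself: the class names TeirSchedule and TeirPrimitives and the use of string literals for the enumerated values are assumptions made for illustration, and the Literal annotations document the domains without being enforced at runtime.

```python
from dataclasses import dataclass
from typing import Literal, Tuple

# Hypothetical transcription of the TEIR records; only the field names and
# the value domains come from the specification.
DimType = Literal["C", "M", "N", "K"]
ExecType = Literal["seq", "shared", "prim"]


@dataclass(frozen=True)
class TeirSchedule:
    dim_types: Tuple[DimType, ...]    # role of each of the D axes
    exec_types: Tuple[ExecType, ...]  # execution type of each axis
    dim_sizes: Tuple[int, ...]        # positive extent of each axis
    strides_in0: Tuple[int, ...]      # element strides of the first input
    strides_in1: Tuple[int, ...]      # element strides of the second input
    strides_out: Tuple[int, ...]      # element strides of the output


@dataclass(frozen=True)
class TeirPrimitives:
    data_type: Literal["FP32", "FP64"]
    prim_first: Literal["None", "Zero", "ReLU"]
    prim_main: Literal["None", "Copy", "GEMM", "BRGEMM"]
    prim_last: Literal["None", "ReLU"]


# Example: one 16x16x16 GEMM on column-major tiles, all three axes
# (order M, N, K) consumed by the primitive.
schedule = TeirSchedule(
    dim_types=("M", "N", "K"),
    exec_types=("prim", "prim", "prim"),
    dim_sizes=(16, 16, 16),
    strides_in0=(1, 0, 16),  # A[m, k] at m + 16*k
    strides_in1=(0, 16, 1),  # B[k, n] at k + 16*n
    strides_out=(1, 16, 0),  # C[m, n] at m + 16*n
)
prims = TeirPrimitives("FP32", "Zero", "GEMM", "None")
```

Freezing the dataclasses mirrors the fact that a TEIR configuration is a static description of the operation rather than mutable runtime state.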

Axis Roles

\[\begin{split}\mathbf m:\{\mathrm{C},\mathrm{M},\mathrm{N},\mathrm{K}\}\to\{0,1\}^3,\qquad \mathbf m(t)= \begin{cases} (1,1,1) & t=\mathrm{C},\\ (1,0,1) & t=\mathrm{M},\\ (0,1,1) & t=\mathrm{N},\\ (1,1,0) & t=\mathrm{K}, \end{cases} \quad\text{(order: }\mathrm{in0}, \mathrm{in1}, \mathrm{out}\text{).}\end{split}\]
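The map \(\mathbf m\) can also be read in reverse: for a contraction written in einsum-like index notation, the set of tensors in which an index appears determines its role. The einsum-style notation and the function name classify_axes are assumptions for illustration only; the sketch simply looks up each index's participation vector (in0, in1, out) in the table above and rejects indices whose vector is not one of the four listed, such as an index that appears only in the output.

```python
def classify_axes(in0: str, in1: str, out: str) -> dict:
    """Map every index of the contraction in0,in1->out to a TEIR axis role
    by inverting m: participation (in0, in1, out) -> role."""
    role_of = {(1, 1, 1): "C", (1, 0, 1): "M", (0, 1, 1): "N", (1, 1, 0): "K"}
    roles = {}
    for idx in dict.fromkeys(in0 + in1 + out):  # unique indices, first-seen order
        key = (int(idx in in0), int(idx in in1), int(idx in out))
        if key not in role_of:
            raise ValueError(f"index {idx!r} has participation {key}, no TEIR role")
        roles[idx] = role_of[key]
    return roles


# Batched matrix multiplication: C[b,m,n] = sum_k A[b,m,k] * B[b,k,n]
print(classify_axes("bmk", "bkn", "bmn"))
# {'b': 'C', 'm': 'M', 'k': 'K', 'n': 'N'}
```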

Well-Formedness

  • All schedule vectors have length \(D\):

\[\begin{split}\begin{aligned} D & = |\mathbf{dim\_types}|=|\mathbf{exec\_types}|=|\mathbf{dim\_sizes}| \\ & = |\mathbf{strides\_{in0}}|=|\mathbf{strides\_{in1}}|=|\mathbf{strides\_{out}}|. \end{aligned}\end{split}\]
  • If a tensor does not participate in axis \(i\) (the corresponding component of \(\mathbf m(t_i)\) equals \(0\)), its stride at \(i\) MUST be \(0\).
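Both well-formedness rules can be checked mechanically. The following is a minimal sketch under the assumption that a schedule is passed as six plain tuples; the function name well_formedness_errors is hypothetical, and it also checks that every entry of dim_sizes is positive, as required by the domain \((\mathbb{N}^+)^D\).

```python
# Participation vector m(t) in the order (in0, in1, out).
PARTICIPATION = {"C": (1, 1, 1), "M": (1, 0, 1), "N": (0, 1, 1), "K": (1, 1, 0)}


def well_formedness_errors(dim_types, exec_types, dim_sizes,
                           strides_in0, strides_in1, strides_out):
    """Return the violated well-formedness rules as messages (empty if none)."""
    vectors = (dim_types, exec_types, dim_sizes,
               strides_in0, strides_in1, strides_out)
    D = len(dim_types)
    if D == 0 or any(len(v) != D for v in vectors):
        return ["all schedule vectors must have the same length D >= 1"]
    errors = []
    for i, t in enumerate(dim_types):
        if dim_sizes[i] < 1:
            errors.append(f"dim_sizes[{i}] must be positive")
        # A tensor that does not participate in axis i must have stride 0 there.
        for name, strides, used in zip(("in0", "in1", "out"),
                                       (strides_in0, strides_in1, strides_out),
                                       PARTICIPATION[t]):
            if not used and strides[i] != 0:
                errors.append(f"strides_{name}[{i}] must be 0 "
                              f"({name} does not use {t} axis {i})")
    return errors
```

Returning a list of messages rather than raising on the first problem lets a caller report every violation of a schedule at once.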

Primitive-Specific Requirements

\[\begin{split}\begin{aligned} &P_{\mathrm{prim}}=\{\, i \mid e_i=\mathrm{prim}\,\}, \\ &P_{\mathrm{C}}=\{\, i \in P_{\mathrm{prim}} \mid t_i=\mathrm{C}\,\}, \quad P_{\mathrm{M}}=\{\, i \in P_{\mathrm{prim}} \mid t_i=\mathrm{M}\,\}, \\ &P_{\mathrm{N}}=\{\, i \in P_{\mathrm{prim}} \mid t_i=\mathrm{N}\,\}, \quad P_{\mathrm{K}}=\{\, i \in P_{\mathrm{prim}} \mid t_i=\mathrm{K}\,\}. \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} \textbf{R1:}\;& \bigl(\mathrm{prim\_main}=\mathrm{Copy}\ \lor\ \mathrm{prim\_first}\in\{\mathrm{Zero},\mathrm{ReLU}\}\ \lor\ \mathrm{prim\_last}\in\{\mathrm{ReLU}\}\bigr) \\ & \Rightarrow\ |P_{\mathrm{C}}| + |P_{\mathrm{M}}| + |P_{\mathrm{N}}| \ge 1.\\[4pt] \textbf{R2:}\;& \mathrm{prim\_main}=\mathrm{GEMM} \\ & \Rightarrow\ |P_{\mathrm{C}}| = 0 \ \land\ |P_{\mathrm{M}}| = 1 \ \land\ |P_{\mathrm{N}}| = 1 \ \land\ |P_{\mathrm{K}}| = 1.\\[4pt] \textbf{R3:}\;& \mathrm{prim\_main}=\mathrm{BRGEMM} \\ & \Rightarrow\ |P_{\mathrm{C}}| = 0 \ \land\ |P_{\mathrm{M}}| = 1 \ \land\ |P_{\mathrm{N}}| = 1 \ \land\ |P_{\mathrm{K}}| = 2. \end{aligned}\end{split}\]
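Rules R1 to R3 only depend on how many axes of each role are marked \(\mathrm{prim}\). A minimal checker, assuming the same tuple-based encoding as a hypothetical representation of the records, counts \(|P_{\mathrm{C}}|,\ldots,|P_{\mathrm{K}}|\) and compares them with the requirements:

```python
def primitive_errors(dim_types, exec_types, prim_first, prim_main, prim_last):
    """Check rules R1-R3 on the axes whose exec type is 'prim'."""
    prim_roles = [t for t, e in zip(dim_types, exec_types) if e == "prim"]
    n = {role: prim_roles.count(role) for role in "CMNK"}  # |P_C|, ..., |P_K|
    errors = []
    # R1: primitives that touch the output tile need a C, M or N prim axis.
    if (prim_main == "Copy" or prim_first in ("Zero", "ReLU")
            or prim_last == "ReLU"):
        if n["C"] + n["M"] + n["N"] < 1:
            errors.append("R1: need at least one prim axis of type C, M or N")
    # R2/R3: no C axis, exactly one M and one N axis, and one (GEMM)
    # or two (BRGEMM) K axes.
    required_k = {"GEMM": 1, "BRGEMM": 2}
    if prim_main in required_k:
        want = (0, 1, 1, required_k[prim_main])
        have = (n["C"], n["M"], n["N"], n["K"])
        if have != want:
            rule = "R2" if prim_main == "GEMM" else "R3"
            errors.append(f"{rule}: (|P_C|, |P_M|, |P_N|, |P_K|) = {have}, "
                          f"expected {want}")
    return errors
```

For example, a schedule with prim axes (M, N, K) passes for GEMM but fails R3 for BRGEMM, which needs a second prim K axis.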

Execution Semantics

Execution Types

If \(e_i=\mathrm{prim}\), axis \(i\) is consumed inside the primitive(s). The values \(\mathrm{seq}\) and \(\mathrm{shared}\) denote sequential and shared-memory parallel traversal of axis \(i\) by the schedule, respectively. The loop order of the schedule is the order in which the axes with \(e_i \neq \mathrm{prim}\) appear in the TEIR-Schedule: the first such axis is the outermost loop and the last such axis is the innermost. No ordering guarantees are imposed between multiple axes marked as \(\mathrm{shared}\).
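The traversal rule amounts to a loop nest over the non-\(\mathrm{prim}\) axes in schedule order, where each step locates one tile of each tensor through the strides. The sketch below is illustrative: the function name tile_offsets is hypothetical, offsets are counted in elements, and \(\mathrm{shared}\) axes are simply run sequentially, which is one possible order since none is guaranteed for them.

```python
from itertools import product


def tile_offsets(exec_types, dim_sizes, strides_in0, strides_in1, strides_out):
    """Yield the element offsets (in0, in1, out) of each tile in traversal
    order. Axes with exec type 'prim' stay inside the primitive; the others
    are visited in schedule order, the first one outermost."""
    loop_axes = [i for i, e in enumerate(exec_types) if e != "prim"]
    # itertools.product varies its last argument fastest, so the last loop
    # axis of the schedule becomes the innermost loop.
    for idx in product(*(range(dim_sizes[i]) for i in loop_axes)):
        yield tuple(
            sum(v * strides[i] for v, i in zip(idx, loop_axes))
            for strides in (strides_in0, strides_in1, strides_out)
        )
```

With exec_types ("seq", "prim", "seq"), dim_sizes (2, 8, 3), and output strides (1, 0, 10), the output offsets come out as 0, 10, 20, 1, 11, 21: axis 2 runs innermost, and axis 1 never appears because the primitive consumes it.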

First/Last-Access Primitives

Primitives \(\mathrm{prim\_first}\) and \(\mathrm{prim\_last}\) define initialization and finalization steps applied to output tiles:

  • \(\mathrm{prim\_first}\) is applied the first time an output tile is accessed in a given schedule.

  • \(\mathrm{prim\_last}\) is applied the last time an output tile is accessed.
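For a purely sequential traversal, the first and last accesses can be located without tracking any state. An output tile is fixed by the indices of its C, M, and N loop axes and is revisited once for every combination of K loop indices; because the loops run in lexicographic order, the first visit is the one where all K indices are zero and the last is the one where all are at their maximum. This is a sketch under the assumption that no K loop axis is marked \(\mathrm{shared}\) (that case is not covered here); the function name access_flags is hypothetical.

```python
from itertools import product


def access_flags(dim_types, exec_types, dim_sizes):
    """Yield (loop_indices, is_first, is_last) for each step of a sequential
    traversal; is_first/is_last mark where prim_first/prim_last would run."""
    loop_axes = [i for i, e in enumerate(exec_types) if e != "prim"]
    # Positions (within the loop nest) of the K axes, which out does not use.
    k_pos = [p for p, i in enumerate(loop_axes) if dim_types[i] == "K"]
    for idx in product(*(range(dim_sizes[i]) for i in loop_axes)):
        is_first = all(idx[p] == 0 for p in k_pos)
        is_last = all(idx[p] == dim_sizes[loop_axes[p]] - 1 for p in k_pos)
        yield idx, is_first, is_last


# K is the outer loop here, so each output tile is first touched in the
# first K iteration and last touched in the final one.
for step in access_flags(("K", "M"), ("seq", "seq"), (2, 2)):
    print(step)
# ((0, 0), True, False)
# ((0, 1), True, False)
# ((1, 0), False, True)
# ((1, 1), False, True)
```

If no K axis is traversed outside the primitive, every output tile is visited exactly once, and both flags are true at every step.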

9.2. Tensor Operation Configuration

Section 9.1 contains the formal specification of the Tiled Execution Intermediate Representation (TEIR). TEIR comprises two records: TEIR-Primitives, which specifies the primitives to be executed, and TEIR-Schedule, which defines how these primitives are applied to tiles of the tensors. This section describes the IR from the perspective of a user who configures a tensor operation using TEIR.

Table 9.2.1 TEIR Schedule — Axis Mapping

Field          Meaning                                 Domain & constraints
-------------  --------------------------------------  -------------------------------------------
dim_types      Axis roles across tensors (D axes)      {C, M, N, K} per axis
exec_types     Execution policy per axis (D axes)      {seq, shared, prim} per axis
dim_sizes      Positive extent per axis                array[D] of ℕ⁺
strides_in0    Strides of the first input tensor       array[D] of ℕ; 0 on axes not used by in0
strides_in1    Strides of the second input tensor      array[D] of ℕ; 0 on axes not used by in1
strides_out    Strides of the output tensor            array[D] of ℕ; 0 on axes not used by out

Table 9.2.1 gives a concise, informal form of TEIR-Schedule. The field dim_types states, for each axis, whether it appears in both inputs and the output (C), in one input and the output (M for in0, N for in1), or only in the two inputs (K). The field exec_types specifies the execution type of each axis: seq executes the axis sequentially, shared parallelizes it in shared memory, and prim leaves it to be consumed inside the primitives. The remaining fields give the size of each axis (dim_sizes) and the data layout of the input and output tensors (strides_in0, strides_in1, and strides_out).
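As a worked example, consider a blocked matrix multiplication C[M,N] += A[M,K]·B[K,N] in which each of M, N, and K is split into an outer loop axis and an inner prim axis. The sketch below assumes column-major storage with leading dimensions M, K, and M; both this layout and the helper name blocked_gemm_schedule are illustrative choices, not part of TEIR. Writing a split index as \(x = x_0 b_x + x_1\), the outer axis gets stride \(b_x s\) and the inner axis keeps the original stride \(s\).

```python
def blocked_gemm_schedule(M, N, K, bm, bn, bk):
    """TEIR-Schedule for C[M,N] += A[M,K] * B[K,N] with column-major A, B, C
    (leading dimensions M, K and M). Each of M, N, K is split into an outer
    loop axis and an inner prim axis of block size bm, bn, bk."""
    if M % bm or N % bn or K % bk:
        raise ValueError("block sizes must divide the matrix sizes")
    # Element strides of the unsplit axes in column-major storage.
    a_m, a_k = 1, M  # A[m, k] at m + k*M
    b_k, b_n = 1, K  # B[k, n] at k + n*K
    c_m, c_n = 1, M  # C[m, n] at m + n*M
    # x = x0*bx + x1: the outer axis x0 has stride bx*s, the inner axis x1
    # keeps stride s; axes a tensor does not use get stride 0.
    return {
        "dim_types":   ("M", "N", "K", "M", "N", "K"),
        "exec_types":  ("shared", "seq", "seq", "prim", "prim", "prim"),
        "dim_sizes":   (M // bm, N // bn, K // bk, bm, bn, bk),
        "strides_in0": (bm * a_m, 0, bk * a_k, a_m, 0, a_k),
        "strides_in1": (0, bn * b_n, bk * b_k, 0, b_n, b_k),
        "strides_out": (bm * c_m, bn * c_n, 0, c_m, c_n, 0),
    }


s = blocked_gemm_schedule(64, 48, 32, 16, 16, 16)
print(s["strides_in0"])  # (16, 0, 1024, 1, 0, 64)
```

Only the outer M axis is shared, so parallel workers update disjoint output tiles, while the outer K axis runs sequentially and accumulates into the same tile. The three prim axes (one each of M, N, K) satisfy R2, so this schedule pairs with prim_main = GEMM; adding prim_first = Zero would turn the accumulation C += A·B into C = A·B.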

Table 9.2.2 TEIR Primitives — Primitive Specification

Field         Meaning                           Allowed values
------------  --------------------------------  ---------------------------
data_type     Data type of inputs and output    {FP32, FP64}
prim_first    First-access primitive            {None, Zero, ReLU}
prim_main     Main primitive                    {None, Copy, GEMM, BRGEMM}
prim_last     Last-access primitive             {None, ReLU}

Table 9.2.2 gives a short form of TEIR-Primitives, which contains four fields. The field data_type determines the data type of the input and output tensors. The other three fields select up to three primitives. The first-access primitive (prim_first) is applied to a tile of the output tensor when that tile is accessed for the first time. Similarly, the last-access primitive (prim_last) is applied to an output tile when it is accessed for the last time. The possible types for prim_first and prim_last are:

  • None: no primitive is executed.

  • Zero: zero the output tile (prim_first only).

  • ReLU: apply ReLU to the output tile’s values.
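The semantics of these primitives on a single tile fit in a few lines. This sketch assumes, for readability, that a tile is a contiguous list of rows rather than a strided view into the tensor; the function name apply_first_last is hypothetical.

```python
def apply_first_last(kind, tile):
    """Apply a first- or last-access primitive in place to an output tile,
    represented here as a list of rows."""
    if kind == "None":
        return
    if kind not in ("Zero", "ReLU"):
        raise ValueError(f"unknown first/last-access primitive: {kind}")
    for row in tile:
        for j, value in enumerate(row):
            row[j] = 0.0 if kind == "Zero" else max(value, 0.0)


tile = [[-1.5, 2.0], [3.0, -4.0]]
apply_first_last("ReLU", tile)
print(tile)  # [[0.0, 2.0], [3.0, 0.0]]
```

Combining prim_first = Zero with an accumulating main primitive and prim_last = ReLU yields a fused ReLU(A·B), since every output tile is cleared before its first update and rectified after its last one.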

By contrast, the main primitive (prim_main) is executed for every valid combination of input and output tiles in the TEIR schedule. It has one of four types:

  • None: no primitive is executed.

  • Copy: copy the input tile’s values to the output tile.

  • GEMM: multiply two 2D input tiles and add the result to the 2D output tile.

  • BRGEMM: perform a batch-reduce GEMM on two 3D input tiles, i.e., multiply the corresponding pairs of 2D slices, sum these products, and add the result to the 2D output tile.
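A reference version of the three non-trivial main primitives makes the accumulation semantics explicit. As before, tiles are assumed to be nested Python lists rather than strided views, and the function names are hypothetical. The leading axis of the 3D BRGEMM tiles is the batch-reduce axis, which presumably corresponds to the second prim K axis required by R3.

```python
def prim_copy(inp, out):
    """Copy: out[i][j] = inp[i][j]."""
    for i, row in enumerate(inp):
        out[i][:] = row


def prim_gemm(a, b, c):
    """GEMM: c[m][n] += sum_k a[m][k] * b[k][n] for 2D tiles a (M x K),
    b (K x N) and c (M x N)."""
    for m in range(len(c)):
        for n in range(len(c[m])):
            c[m][n] += sum(a[m][k] * b[k][n] for k in range(len(b)))


def prim_brgemm(a, b, c):
    """BRGEMM: c[m][n] += sum_r sum_k a[r][m][k] * b[r][k][n]; the leading
    axis r of the 3D tiles is reduced into the single 2D output tile."""
    for a_r, b_r in zip(a, b):
        prim_gemm(a_r, b_r, c)


c = [[0.0, 0.0], [0.0, 0.0]]
prim_brgemm([[[1, 2], [3, 4]], [[1, 0], [0, 1]]],
            [[[5, 6], [7, 8]], [[1, 1], [1, 1]]], c)
print(c)  # [[20.0, 23.0], [44.0, 51.0]]
```

Expressing BRGEMM as a sequence of GEMMs into the same output tile shows why it needs no separate first-access handling: the output is only read and updated, and any initialization is left to prim_first.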