9. Tiled Execution IR
We have assembled the tools needed to execute tensor operations efficiently. This chapter introduces the Tiled Execution Intermediate Representation, which describes the execution of tensor operations through primitives operating on subtensors, called tiles. The IR guides the backend implementation and controls when and where primitives are executed. Conceptually, the IR is closely related to tile-based programming models such as Triton, cuTile and Tile IR, Pallas, and TileIR.
9.1. Specification
Domains and Notation
Records
Axis Roles
Well-Formedness
All schedule vectors have length \(D\):
If a tensor does not participate in axis \(i\) (the corresponding component of \(\mathbf m(t_i)\) equals \(0\)), its stride at \(i\) MUST be \(0\).
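The two rules above can be checked mechanically. The following is a minimal sketch, assuming the schedule fields are stored as plain Python lists under the names used in Section 9.2. The helper name `check_schedule` and the participation table are hypothetical. In particular, mapping M to the first input and N to the second follows the usual GEMM convention and is an assumption, not something the specification text above states.

```python
# Hypothetical helper: not part of TEIR, only illustrates the two rules above.

# Assumed participation of each axis role in (in0, in1, out).
# C: all tensors, K: both inputs only. Mapping M to in0 and N to in1
# follows the usual GEMM convention and is an assumption.
PARTICIPATES = {
    "C": (True, True, True),
    "M": (True, False, True),
    "N": (False, True, True),
    "K": (True, True, False),
}


def check_schedule(dim_types, exec_types, dim_sizes,
                   strides_in0, strides_in1, strides_out):
    """Return a list of rule violations (empty if none were found)."""
    fields = {
        "dim_types": dim_types,
        "exec_types": exec_types,
        "dim_sizes": dim_sizes,
        "strides_in0": strides_in0,
        "strides_in1": strides_in1,
        "strides_out": strides_out,
    }
    errors = []

    # Rule 1: all schedule vectors have the same length D.
    D = len(dim_types)
    for name, vec in fields.items():
        if len(vec) != D:
            errors.append(f"{name} has length {len(vec)}, expected {D}")
    if errors:
        return errors

    # Rule 2: a tensor that does not participate in axis i has stride 0 there.
    stride_names = ("strides_in0", "strides_in1", "strides_out")
    for i, role in enumerate(dim_types):
        for name, used in zip(stride_names, PARTICIPATES[role]):
            if not used and fields[name][i] != 0:
                errors.append(f"{name}[{i}] must be 0 for a {role} axis")
    return errors
```

An empty return value only means that these two rules hold; the sketch does not cover the primitive-specific requirements.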
Primitive-Specific Requirements
Execution Semantics
- Execution Types
If \(e_i=\mathrm{prim}\), axis \(i\) is consumed inside the primitive(s). Values \(\mathrm{seq}\) and \(\mathrm{shared}\) denote sequential and shared-memory parallel traversal of axis \(i\) in the schedule. The overall schedule order is determined by the order of all axes with \(e_i \neq \mathrm{prim}\) as they appear in the TEIR-Schedule. Traversal proceeds from the first such axis (outermost) to the last (innermost). No ordering guarantees are imposed between multiple axes marked as \(\mathrm{shared}\).
- First/Last-Access Primitives
Primitives \(\mathrm{prim\_first}\) and \(\mathrm{prim\_last}\) define initialization and finalization steps applied to output tiles:
\(\mathrm{prim\_first}\) is applied the first time an output tile is accessed in a given schedule.
\(\mathrm{prim\_last}\) is applied the last time an output tile is accessed.
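The traversal order and the first/last-access rule can be made concrete with a small sequential simulation. The sketch below is an illustration under stated assumptions, not the backend's implementation. The function name `walk_schedule` is hypothetical, output tiles are identified by their offset in the output tensor, and axes marked shared are traversed like seq axes, which is only one of the orders the IR permits.

```python
from collections import Counter
from itertools import product


def walk_schedule(exec_types, dim_sizes, strides_out):
    """Yield (loop_indices, out_offset, is_first, is_last) per primitive call.

    Sequential model only: "shared" axes are treated like "seq" axes.
    """
    # Axes consumed inside the primitive do not appear as loops.
    loop_axes = [i for i, e in enumerate(exec_types) if e != "prim"]
    ranges = [range(dim_sizes[i]) for i in loop_axes]

    def out_offset(idx):
        return sum(j * strides_out[i] for i, j in zip(loop_axes, idx))

    # Pass 1: count how often each output tile is touched.
    total = Counter(out_offset(idx) for idx in product(*ranges))

    # Pass 2: product() varies the last axis fastest, so the first loop
    # axis is the outermost loop and the last one the innermost.
    seen = Counter()
    for idx in product(*ranges):
        off = out_offset(idx)
        seen[off] += 1
        yield idx, off, seen[off] == 1, seen[off] == total[off]


# Two loop axes (output stride 16 and 0) followed by three prim axes.
for step in walk_schedule(
    exec_types=["seq", "seq", "prim", "prim", "prim"],
    dim_sizes=[2, 2, 16, 16, 16],
    strides_out=[16, 0, 1, 32, 0],
):
    print(step)
# ((0, 0), 0, True, False)
# ((0, 1), 0, False, True)
# ((1, 0), 16, True, False)
# ((1, 1), 16, False, True)
```

In this example the second loop axis has output stride 0, so each output tile is visited twice: prim_first would run on the visit flagged `is_first` and prim_last on the visit flagged `is_last`.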
9.2. Tensor Operation Configuration
Section 9.1 contains the formal specification of the Tiled Execution Intermediate Representation (TEIR). TEIR comprises two records: TEIR-Primitives, which specifies the primitives to be executed, and TEIR-Schedule, which defines how these primitives are applied to tiles of the tensors. This section describes the IR from the perspective of a user who configures a tensor operation using TEIR.
Field | Meaning | Domain & constraints |
---|---|---|
dim_types | Axis roles across tensors (D axes) | {C, M, N, K} |
exec_types | Execution policy per axis (D axes) | {seq, shared, prim} |
dim_sizes | Positive extent per axis | |
strides_in0 | Strides of first input tensor | |
strides_in1 | Strides of second input tensor | |
strides_out | Strides of output tensor | |
Table 9.2.1 provides a concise informal form of TEIR-Schedule.
The field dim_types describes whether an axis is part of both inputs and the output (C), one of the inputs and the output (M or N), or only in the two inputs (K). The field exec_types specifies the execution type of each axis. Setting seq results in sequential execution of an axis. Shared-memory parallelization is achieved with shared. Axes with type prim are consumed inside the primitives. The remaining fields describe the sizes of the axes in field dim_sizes and the data layout of the input and output tensors in strides_in0, strides_in1, and strides_out.
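As a worked illustration of these fields, the sketch below configures a blocked matrix multiplication. The container class `TeirSchedule`, the problem sizes, the block size of 16, and the column-major layout are all assumptions chosen for the example. It also assumes that M axes belong to the first input and N axes to the second, following the usual GEMM convention.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TeirSchedule:
    """Hypothetical container mirroring the fields of Table 9.2.1."""
    dim_types: tuple
    exec_types: tuple
    dim_sizes: tuple
    strides_in0: tuple
    strides_in1: tuple
    strides_out: tuple


# out[m, n] += sum_k in0[m, k] * in1[k, n] with M = 64, N = 48, K = 32.
# All three matrices are assumed column-major. Each of M, N, K is split
# into an outer loop axis and an inner axis of B = 16 elements that is
# consumed by the primitive.
M, N, K, B = 64, 48, 32, 16

schedule = TeirSchedule(
    # axis order: M outer, N outer, K outer, M inner, N inner, K inner
    dim_types=("M", "N", "K", "M", "N", "K"),
    exec_types=("shared", "shared", "seq", "prim", "prim", "prim"),
    dim_sizes=(M // B, N // B, K // B, B, B, B),
    strides_in0=(B, 0, B * M, 1, 0, M),  # in0 is M x K, element (m, k) at m + M*k
    strides_in1=(0, B * K, B, 0, K, 1),  # in1 is K x N, element (k, n) at k + K*n
    strides_out=(B, B * M, 0, 1, M, 0),  # out is M x N, element (m, n) at m + M*n
)


def last_offset(sizes, strides):
    """Largest element offset reachable with the given sizes and strides."""
    return sum((s - 1) * st for s, st in zip(sizes, strides))


print(schedule.dim_sizes)  # (4, 3, 2, 16, 16, 16)
```

Each tensor has stride 0 on the axes it does not participate in, and `last_offset` gives a quick sanity check: for every tensor it should equal the number of elements minus one.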
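Field | Meaning | Allowed values |
---|---|---|
data_type | Data type of inputs and output | { |
prim_first | First-access primitive | {None, Zero, ReLU} |
prim_main | Main primitive | {None, Copy, GEMM, BRGEMM} |
prim_last | Last-access primitive | {None, Zero, ReLU} |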
Table 9.2.2 provides a short form of TEIR-Primitives.
TEIR-Primitives contains four fields. data_type determines the data type of the input and the output tensors. In addition, up to three primitives can be used in TEIR. The first-access primitive (prim_first) is applied to a tile of the output tensor when it is accessed for the first time. Similarly, the last-access primitive (prim_last) is applied to a tile of the output tensor when accessing it for the last time. The possible types for prim_first and prim_last are:
- None: No primitive is executed.
- Zero: Zero the output tile.
- ReLU: Apply ReLU to the output tile’s values.
By contrast, the main primitive (prim_main) is executed for every valid combination of input and output tiles in the TEIR schedule. The main primitive can have one of four types:
- None: No primitive is executed.
- Copy: Copy the input tile’s values to the output tile.
- GEMM: Multiply two 2D input tiles and add the result to the 2D output tile.
- BRGEMM: Perform a BRGEMM operation on two 3D input tiles and add the result to the 2D output tile.
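The effect of these primitives on a tile can be summarized with a reference sketch on nested Python lists. It is purely illustrative: the function names are hypothetical, tiles are assumed to be lists of rows, and BRGEMM is read as a batch-reduce GEMM whose batch dimension is the leading index of both 3D inputs. A real backend would use optimized kernels instead.

```python
def zero(out):
    """Zero: set every value of the 2D output tile to 0."""
    for row in out:
        for j in range(len(row)):
            row[j] = 0.0


def relu(out):
    """ReLU: clamp negative values of the 2D output tile to 0."""
    for row in out:
        for j in range(len(row)):
            row[j] = max(row[j], 0.0)


def copy(src, out):
    """Copy: write the values of the input tile into the output tile."""
    for i, row in enumerate(src):
        out[i][:] = row


def gemm(a, b, out):
    """GEMM: out += a @ b for 2D tiles a (m x k) and b (k x n)."""
    for i in range(len(a)):
        for j in range(len(b[0])):
            out[i][j] += sum(a[i][p] * b[p][j] for p in range(len(b)))


def brgemm(a_batch, b_batch, out):
    """BRGEMM (assumed semantics): out += sum_r a_batch[r] @ b_batch[r]."""
    for a, b in zip(a_batch, b_batch):
        gemm(a, b, out)


# One output tile with prim_first = Zero, prim_main = BRGEMM, prim_last = ReLU.
out = [[5.0, -1.0], [2.0, 3.0]]
a_batch = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]
b_batch = [[[1.0, -2.0], [3.0, 4.0]], [[1.0, 0.0], [0.0, -5.0]]]

zero(out)                      # first access: discard the old contents
brgemm(a_batch, b_batch, out)  # out is now [[2.0, -7.0], [4.0, -1.0]]
relu(out)                      # last access
print(out)                     # [[2.0, 0.0], [4.0, 0.0]]
```

Because GEMM and BRGEMM accumulate into the output tile, a Zero first-access primitive is what turns the accumulation into a plain product when the output buffer holds stale data.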