Einsum Trees

The einsum summation convention, or simply einsum, is a popular formalism for tensor expressions. It is available in many scientific computing and machine learning libraries, including NumPy, PyTorch, TensorFlow and JAX. The beauty of einsum is that it allows us to formulate the contraction of multiple high-dimensional tensors in a simple and concise way. While einsum is widely used in user-facing frontends, many backends quickly lower it to a set of binary tensor contractions that are individually optimized without considering the big picture. In recent work, we have introduced the concept of einsum trees, which act as a high-level intermediate representation (IR) when compiling einsum expressions. Einsum trees provide a global view of the problem of contracting multiple input tensors to obtain a single output tensor. In this post, we cover the basics of how the IR works and how thinking in terms of einsum trees simplifies the optimization of high-dimensional dense linear algebra.

Contraction Trees

We start by discussing simple contraction trees that have only a few input tensors. One of the simplest examples is that of a matrix-matrix multiplication: Given the matrix $A$ and the matrix $B$, we may want to compute the matrix-matrix product $C = AB$. This formula already encodes quite a bit of information about the operation: We obtain a value of the output matrix $C$ by computing the dot product of a row of $A$ and a column of $B$. In index notation we could write: $C_{ij} = \sum_k A_{ik} B_{kj}$. Keeping only the multi-indices, we can shorten the equation to ik,kj->ij. Here, we assume that an index appearing in one or more inputs but not in the output tensor is summed over. In the matrix-matrix multiplication example, index k is summed over. Further, the arrow -> separates the indices of the input tensors from those of the output tensor. We can also write the expression ik,kj->ij as a binary tree with the two leaf nodes ik and kj, and the root node ij:

ij
├─ kj
└─ ik
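
In NumPy, for example, this expression maps directly onto a single einsum call. A minimal sketch with arbitrary sizes:

import numpy as np

A = np.random.rand(3, 4)  # multi-index ik
B = np.random.rand(4, 5)  # multi-index kj

# ik,kj->ij: index k appears only in the inputs and is summed over.
C = np.einsum('ik,kj->ij', A, B)
assert np.allclose(C, A @ B)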

The notation is extended easily to more complex examples. For example, we could shorten the batched matrix-matrix multiplication $C_{lij} = \sum_k A_{lik} B_{lkj}$ to lik,lkj->lij, or write it as a tree:

lij
├─ lkj
└─ lik
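
The batched version works the same way; a minimal NumPy sketch with arbitrary sizes, where the batch index l is simply carried through:

import numpy as np

A = np.random.rand(2, 3, 4)  # multi-index lik
B = np.random.rand(2, 4, 5)  # multi-index lkj

# lik,lkj->lij: l is a shared batch index, k is summed over.
C = np.einsum('lik,lkj->lij', A, B)
assert np.allclose(C, A @ B)  # np.matmul broadcasts over the batch index l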

Even more interesting are cases where more than two input tensors have to be contracted. For example, we might want to compute $E = (AB)C$ with input matrices $A$, $B$ and $C$. We can write the expression in index notation in two steps: First, an intermediate matrix $D$ is computed, which is then multiplied by $C$ to obtain the output matrix $E$:

$$D_{ij} = \sum_k A_{ik} B_{kj}$$
$$E_{il} = \sum_j D_{ij} C_{jl}$$

Once again, we can shorten the example to [ik,kj->ij],jl->il or write it as a tree where the three input matrices are leaf nodes:

il
├─ jl
└─ ij
   ├─ kj
   └─ ik
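
In code, we can either materialize the intermediate $D$ explicitly, as in the tree above, or hand the whole expression to einsum in a single call. A minimal NumPy sketch:

import numpy as np

A = np.random.rand(3, 4)  # ik
B = np.random.rand(4, 5)  # kj
C = np.random.rand(5, 6)  # jl

# Two-step evaluation with an explicit intermediate, following the tree.
D = np.einsum('ik,kj->ij', A, B)
E = np.einsum('ij,jl->il', D, C)

# Single-call evaluation; the backend chooses the contraction order.
assert np.allclose(E, np.einsum('ik,kj,jl->il', A, B, C))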

Intermediate Representation

The intermediate representation Einsum Tree IR relies on a few additional properties. First, it assumes that all tensors are stored in general row-major format. This means that the rightmost dimension in the multi-index of a tensor has unit stride and that the tensor is contiguous in memory. For example, in the case of the rank-3 tensor $A_{lik}$, or short lik, the dimensions may have sizes $|l| = 2$, $|i| = 3$ and $|k| = 4$. Now, the assumed general row-major format means that index k has unit stride, index i has stride 4, and index l has stride $3 \cdot 4 = 12$.
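
We can verify these strides with NumPy, whose default array layout is exactly this general row-major format. Note that NumPy reports strides in bytes, so we divide by the item size; a minimal sketch:

import numpy as np

A = np.empty((2, 3, 4))  # multi-index lik with |l|=2, |i|=3, |k|=4

# Convert byte strides to element strides: 12 for l, 4 for i, 1 for k.
element_strides = tuple(s // A.itemsize for s in A.strides)
assert element_strides == (12, 4, 1)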

Second, a fixed row-major memory layout still leaves us free to arrange the data of a tensor in different ways. For the lik example, we could alternatively store the data as lki, ilk, ikl, kil, or as kli. The IR supports changing the memory layout of the tensors through permutation nodes. Assuming that we would like to store the data of tensor lik as ikl, we simply write lik->ikl.
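
In NumPy terms, such a permutation node corresponds to a transpose followed by a copy into a new contiguous row-major array; a minimal sketch:

import numpy as np

A = np.random.rand(2, 3, 4)  # stored as lik

# lik->ikl: reorder the dimensions and materialize the new row-major layout.
B = np.ascontiguousarray(np.einsum('lik->ikl', A))
assert B.shape == (3, 4, 2) and B.flags['C_CONTIGUOUS']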

Consequently, an IR tree can now be transformed by inserting permutation nodes. For the batched matrix-matrix multiplication, we could, for example, permute one of the input tensors before computing the contraction: [lik->ikl],lkj->lij. Written as a tree, we obtain:

lij
├─ lkj
└─ ikl
   └─ lik
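
Since the permutation only changes the memory layout of the operand, the transformed tree still computes the same result. A quick NumPy check with made-up sizes:

import numpy as np

A = np.random.rand(2, 3, 4)  # lik
B = np.random.rand(2, 4, 5)  # lkj

# Evaluate the transformed tree: permute the first operand, then contract.
A_perm = np.ascontiguousarray(np.einsum('lik->ikl', A))
C = np.einsum('ikl,lkj->lij', A_perm, B)

assert np.allclose(C, np.einsum('lik,lkj->lij', A, B))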

Lowering

We now discuss our approach to lowering binary tensor contractions, i.e., bringing the contractions into a form that can be executed on hardware.

A simple approach to lowering is nested loops, where the innermost loop performs scalar operations. Using the expression trus,pqtu->pqrs as an example, we could write code similar to the following:

for p in |p|
  for q in |q|
    for r in |r|
      for s in |s|
        for t in |t|
          for u in |u|
            out[p][q][r][s] += in0[t][r][u][s] * in1[p][q][t][u]
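
As a sanity check, here is a direct (and deliberately slow) Python transcription of this loop nest with made-up sizes, verified against NumPy's einsum:

import numpy as np

# Made-up sizes for the indices t, r, u, s, p and q.
T, R, U, S, P, Q = 2, 3, 2, 4, 2, 3
in0 = np.random.rand(T, R, U, S)  # trus
in1 = np.random.rand(P, Q, T, U)  # pqtu
out = np.zeros((P, Q, R, S))      # pqrs

for p in range(P):
    for q in range(Q):
        for r in range(R):
            for s in range(S):
                for t in range(T):
                    for u in range(U):
                        out[p, q, r, s] += in0[t, r, u, s] * in1[p, q, t, u]

assert np.allclose(out, np.einsum('trus,pqtu->pqrs', in0, in1))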

If translated directly into machine code, this code will have poor performance due to missing optimizations for parallelization and data reuse. Instead, we could follow standard optimization steps and perform sophisticated transformations to obtain fast code. Example steps include loop reordering, loop tiling, loop fusion, loop unrolling, packing, and unpacking.

All of this is difficult to get right in software, and we would still lack the simplicity promised by thinking in terms of einsum trees. This is why our approach to lowering einsum trees is primitive-based, i.e., we lower to primitives instead of scalar operations. In particular, for the trus,pqtu->pqrs example, we can loop over matrix-matrix multiplications:

for r in |r|
  for t in |t|
    GEMM( A   = in0[t][r],
          B   = in1[0][0][t],
          C   = out[0][0][r],
          m   = |s|,
          n   = |p|*|q|,
          k   = |u|,
          ldA = |s|,
          ldB = |t| * |u|,
          ldC = |r| * |s| )

In the pseudocode, the naming of the GEMM parameters follows the convention of the BLAS GEMM routines. We see that many of the low-level optimizations, such as register blocking or vectorization, can now be done by the primitive-providing library.

Of course, there will be many cases where our primitive-based mapping to GEMMs fails when working with arbitrary tensor contractions. This means that we need a simple way to identify GEMMs in tensor contractions, so that we can guide our tree-level optimization accordingly. Here, thinking in terms of einsum becomes incredibly helpful. We need to identify three blocks in the einsum string: one corresponding to m in the GEMM call, one corresponding to n, and one corresponding to k. This boils down to finding the following pattern in the einsum expression:

[...]K[...]M,[...]N[...]K->[...]N[...]M

In the example, we have M=s, N=pq and K=u, i.e. we can rewrite the contraction as trKM,NtK->NrM. This allows us to clearly see where the GEMM operates in the tensors.
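
To make the mapping concrete, here is a NumPy sketch of the loop-over-GEMM scheme with made-up sizes. Since NumPy is row-major while the BLAS convention above is column-major, the two operands appear swapped relative to the pseudocode:

import numpy as np

T, R, U, S, P, Q = 2, 3, 2, 4, 2, 3
in0 = np.random.rand(T, R, U, S)  # trus = trKM
in1 = np.random.rand(P, Q, T, U)  # pqtu = NtK
out = np.zeros((P, Q, R, S))      # pqrs = NrM

# Loop over the remaining indices r and t; each iteration performs one
# matrix product over the blocks N = pq (flattened), K = u and M = s.
for r in range(R):
    for t in range(T):
        out[:, :, r, :] += (in1[:, :, t, :].reshape(P * Q, U)
                            @ in0[t, r]).reshape(P, Q, S)

assert np.allclose(out, np.einsum('trus,pqtu->pqrs', in0, in1))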

Optimization

So far, we have introduced Einsum Tree IR and derived a straightforward way to lower contraction nodes in the IR to loops over primitives. The question now is whether and how we can modify IR trees to make our primitive-based lowering work. Spoiler alert: it is possible, and can be done using a remarkably simple heuristic. We will write a tutorial on this later, but for now we refer interested readers to our ASPLOS 2025 presentation in the ML compilers session: Einsum Trees: An Abstraction for Optimizing the Execution of Tensor Expressions.

We also invite you to try thinking in einsum trees and to check out the Einsum Tree Visualizer developed by undergraduate student Thorsten Kröhl. To get you started, here is the trus,pqtu->pqrs example. And for a more complex tree, you can see the initial version and a corresponding optimized one by opening the links.

Hello SME!

Our lab has just started analyzing Apple’s M4 chip. It turns out that M4 supports the Scalable Matrix Extension (SME) of the Arm Architecture. This opens the way for open source developments that support M4’s matrix accelerator(s). We plan to add SME support to the JITter LIBXSMM in the next few weeks, which will allow us to integrate SME into upstream software. We are documenting these efforts on a dedicated homepage. Check it out!

[Image: the iPad Pro on which our M4 work is done.]

Guest Lecture: Paul Springer

Paul Springer, a senior software developer at NVIDIA, will give a guest lecture in the class Parallel Computing. The lecture will take place on Thursday, January 11, 2024, from 12:00 PM to 2:00 PM. The location is HS 235 (Fürstengraben 1). Interested students outside of the class are cordially invited to join. Please send a brief email to alex.breuer@uni-jena.de or a message in Matrix if you would like to attend and are not part of the Parallel Computing class.

About the speaker: Dr. Paul Springer is a senior software developer at NVIDIA with a strong interest in low-level kernel development and its applications to quantum circuit simulations, computational physics and machine learning. Before joining NVIDIA in 2018, he received his Ph.D. in computer science from RWTH Aachen University, where he focused on the development of high-performance tensor operations and dense linear algebra. Nowadays he is primarily tasked with the design of CUDA math libraries, most notably cuTENSOR.

