.. _ch:custom_extensions:

Custom Extensions
=================

PyTorch is a popular machine learning framework with a rich collection of ready-to-use features.
In most cases, when harnessing commonly used machine learning techniques, it's sufficient to stay at the library's high level of abstraction.
However, machine learning research occasionally requires custom features, e.g., a special layer which is not built in.
Obtaining high computational performance for these features is a highly nontrivial task.
This challenge might hinder innovation in machine learning research and is recognized by the research community.
One discussion of current issues in machine learning systems and possible future directions is given in the paper `Machine Learning Systems are Stuck in a Rut `__.

This part of the class takes a look into PyTorch's guts.
The in-depth knowledge gained will allow us to understand the respective challenges and to implement *efficient* custom operators.

.. warning::

   Keep in mind that we are learning about the backends of machine learning frameworks.
   Knowing the details of how these frameworks operate internally is extremely helpful.
   Most of the time, though, a machine learning practitioner would stay within the comfort zone of the framework, i.e., stick to PyTorch's built-in layers in our case, and not go down the route of custom extensions.
   Abstractions are a good thing after all. 😏

From PyTorch to Machine Code
----------------------------

In this exercise, we'll have an in-depth look at linear layers.
Specifically, given a weight matrix :math:`W \in \mathbb{R}^{\text{n}_\text{out} \times \text{n}_\text{in}}`, the studied layer derives the outputs :math:`y(x) \in \mathbb{R}^{\text{n}_\text{out}}` from the inputs :math:`x \in \mathbb{R}^{\text{n}_\text{in}}` through the following linear transformation:

.. math::

   y = x W^T

This transformation is implemented in the PyTorch layer `torch.nn.Linear `__, where we assume a zero-valued bias, i.e., ``bias = False`` in the layer's constructor and :math:`b = 0` in the respective formula of the documentation.
Using the forward and backward pass of ``torch.nn.Linear`` is simple.
However, understanding what's happening under the hood is more demanding but crucial for our custom PyTorch extensions.

In the first part of this exercise, we'll start by applying the forward pass to example input data.
Next, we'll have a look at the backward pass, print the respective vector-Jacobian products, and derive formulas for their computation.
Once we have a solid understanding of the basics, we'll implement our own operators for the linear layer in the following sections.
Again, we start simple and stay in Python before moving to `PyTorch's C++ API `__.
:numref:`ch:tpps_linear` and :numref:`ch:tpps_conv2d` go completely custom by using `Tensor Processing Primitives `__ (TPPs), which allow us to harness just-in-time-generated machine code for our workload.

.. admonition:: Tasks

   #. Assume the following example data for the input batch :math:`X`, the weights :math:`W`, and the vector-Jacobian product batch of the loss function :math:`L` w.r.t. the output batch :math:`Y`:

      .. math::
         :nowrap:

         \begin{aligned}
           W &= \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix},\\
           X &= \begin{bmatrix} \vec{x}_1 \\ \vec{x}_2 \end{bmatrix} = \begin{bmatrix} 7 & 8 & 9 \\ 10 & 11 & 12 \end{bmatrix},\\
           \frac{ \partial L }{ \partial Y } &= \begin{bmatrix} 1 & 2 \\ 2 & 3 \end{bmatrix}.
         \end{aligned}

      Derive the output of the linear layer :math:`Y(X) = \begin{bmatrix} y(\vec{x}_1) \\ y(\vec{x}_2) \end{bmatrix}` on paper.
      Manually derive the vector-Jacobian products :math:`\frac{ \partial L }{ \partial X }` and :math:`\frac{ \partial L }{ \partial W }`.

   #. Construct a ``torch.nn.Linear`` layer with the parameters ``in_features = 3``, ``out_features = 2``, and ``bias = False``.
      Initialize the weights :math:`W` of the layer with the example values given above.

   #. Apply the layer's forward function to the given input batch :math:`X` and print the result.

   #. Run the backward pass for the given :math:`\frac{ \partial L }{ \partial Y }` and print the gradients w.r.t. the input :math:`X` and the weights :math:`W`.

.. hint::

   Use ``requires_grad = True`` when constructing the PyTorch tensors containing the input :math:`X` and the weights :math:`W`.
   You may adjust the weights of ``torch.nn.Linear`` by setting the member variable ``weight``.
   For example, if your linear layer is stored in ``l_linear_torch`` and your local weight tensor in ``l_weights``, use:

   .. code-block:: python

      l_linear_torch.weight = torch.nn.Parameter( l_weights )

Python Extensions
-----------------

So far, we haven't gained anything.
We managed to print the vector-Jacobian products but still relied on the predefined high-level layer ``torch.nn.Linear``.
Time for our custom implementations!
We'll start in Python, as this simplifies our lives a bit and outlines the steps required for the targeted C++ and TPP versions.

PyTorch's documentation provides `detailed information `__ on how to extend the library.
In this task, we'll implement our new linear layer in the class ``eml.ext.linear_python.Layer``, which derives from the base class `torch.nn.Module `__.
This also requires us to implement a function which performs the forward and backward passes.
We'll implement this function in ``eml.ext.linear_python.Function``, which subclasses `torch.autograd.Function `__.

.. admonition:: Tasks

   #. Implement the class ``eml.ext.linear_python.Layer``.
      The constructor should have the following signature:

      .. code-block:: python

         def __init__( self, i_n_features_input, i_n_features_output ):

      For the time being, implement the forward-pass operation exclusively in the member function ``forward`` of the class.

   #. Add the class ``eml.ext.linear_python.Function``.
      Move the details of the forward function to ``eml.ext.linear_python.Function``.
      For the time being, skip the function ``backward``, which is called in the backward pass.

   #. Add the backward function to ``eml.ext.linear_python.Function``.
      Test your implementation, e.g., by comparing its results to those of ``torch.nn.Linear``.

.. hint::

   Use the function `save_for_backward `__ to transfer the required data from the forward to the backward pass.

C++ Extensions
--------------

Now that we have an overview of how to implement a PyTorch extension in Python, we'll use C++ and move one step closer to the hardware.
The general roadmap is the same:

1. Implement a custom layer by subclassing ``torch.nn.Module``; the layer stores internal data and acts as our entry point.
2. Implement a custom function by subclassing ``torch.autograd.Function``; the function contains the actual implementation of the forward and backward passes.

Compared to the pure Python version, we'll now compute the forward and backward passes in C++.
This means that we require a way to call C++ functions from Python.
Once again, PyTorch's documentation provides detailed information on `custom C++ extensions `__.
In short, we'll use `pybind11 `__ to glue the two worlds together.
`setuptools `__ will build our extension ahead of time so that we can simply import the pre-built extension in Python.

.. admonition:: Tasks

   #. Implement the C++ function ``hello()``, which simply prints "Hello World!".
      Use pybind11 and setuptools to make your function available in Python.
      Call it!

   #. Implement the two Python classes ``eml.linear_cpp.Layer`` and ``eml.linear_cpp.Function``.
      Call the two C++ functions ``forward`` and ``backward`` from ``eml.linear_cpp.Function`` for the forward- and backward-pass ops.
      Implement the two functions in the file ``eml/linear_cpp/FunctionCpp.cpp`` and use the following declarations:

      .. code-block:: cpp

         #include <torch/extension.h>
         #include <vector>

         torch::Tensor forward( torch::Tensor i_input,
                                torch::Tensor i_weight );

         std::vector< torch::Tensor > backward( torch::Tensor i_grad,
                                                torch::Tensor i_input,
                                                torch::Tensor i_weight );

      Use the `ATen tensor library `__ for the actual matrix-matrix multiplications.

   #. Adjust the implementation of the functions ``forward`` and ``backward`` by replacing the ATen matrix-matrix multiplications with your own C++-only implementation.

.. hint::

   You may access the raw data of an ATen tensor by calling `torch.Tensor.data_ptr `__.
   Make sure to call `torch.Tensor.contiguous `__ on the Python side before assuming memory layouts for the raw data!
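For reference while implementing ``backward``, here is a short index-level derivation of the two vector-Jacobian products.
It assumes the batched row layout of the first section (one sample per row of :math:`X`, so :math:`Y = X W^T`) and writes :math:`G = \frac{\partial L}{\partial Y}`.
Since :math:`\partial Y_{bo} / \partial X_{bi} = W_{oi}` and :math:`\partial Y_{bo} / \partial W_{oi} = X_{bi}`, the chain rule gives:

.. math::
   :nowrap:

   \begin{aligned}
     Y_{bo} &= \sum_{i} X_{bi} W_{oi}, \\
     \left( \frac{\partial L}{\partial X} \right)_{bi} &= \sum_{o} G_{bo} W_{oi}
       \quad \Rightarrow \quad \frac{\partial L}{\partial X} = G \, W, \\
     \left( \frac{\partial L}{\partial W} \right)_{oi} &= \sum_{b} G_{bo} X_{bi}
       \quad \Rightarrow \quad \frac{\partial L}{\partial W} = G^T X.
   \end{aligned}

Because ``torch.autograd.Function.backward`` returns one gradient per input of ``forward``, in the same order, returning the input gradient first and the weight gradient second in the C++ ``std::vector`` is a natural choice, assuming ``forward`` takes the input before the weights.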
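The third task replaces the ATen calls with hand-written loops over raw data.
As a hedged, pure-Python sketch of that index arithmetic, the hypothetical helpers ``linear_forward`` and ``linear_backward`` below operate on flat, row-major lists that stand in for the contiguous buffers behind ``data_ptr``.
They are illustrations only and not part of the exercise's interface.
The backward sweep accumulates both vector-Jacobian products in a single pass over the incoming gradient.

.. code-block:: python

   def linear_forward(x, w, n_batch, n_in, n_out):
       """Computes Y = X W^T on flat row-major buffers.

       x holds n_batch * n_in values, w holds n_out * n_in values.
       Returns a flat list with n_batch * n_out values.
       """
       y = [0.0] * (n_batch * n_out)
       for b in range(n_batch):
           for o in range(n_out):
               acc = 0.0
               for i in range(n_in):
                   # X[b][i] * W[o][i], addressed through row-major offsets
                   acc += x[b * n_in + i] * w[o * n_in + i]
               y[b * n_out + o] = acc
       return y


   def linear_backward(grad_y, x, w, n_batch, n_in, n_out):
       """Returns (dL/dX, dL/dW) for Y = X W^T, given grad_y = dL/dY.

       dL/dX = dL/dY W        (n_batch x n_in)
       dL/dW = (dL/dY)^T X    (n_out x n_in)
       """
       grad_x = [0.0] * (n_batch * n_in)
       grad_w = [0.0] * (n_out * n_in)
       for b in range(n_batch):
           for o in range(n_out):
               g = grad_y[b * n_out + o]
               for i in range(n_in):
                   grad_x[b * n_in + i] += g * w[o * n_in + i]
                   grad_w[o * n_in + i] += g * x[b * n_in + i]
       return grad_x, grad_w


   # Example data from the first section, flattened row by row.
   W = [1, 2, 3,
        4, 5, 6]
   X = [7, 8, 9,
        10, 11, 12]
   G = [1, 2,
        2, 3]  # dL/dY

   Y = linear_forward(X, W, n_batch=2, n_in=3, n_out=2)
   grad_X, grad_W = linear_backward(G, X, W, n_batch=2, n_in=3, n_out=2)
   print(Y, grad_X, grad_W)

A C++ version would run the same loop nest on ``float`` pointers obtained from contiguous tensors.
The loop order shown favors readability and is not tuned for performance.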