5. Autograd

All major machine learning frameworks include automatic differentiation engines. These engines are used heavily, for example, to realize backpropagation when training neural networks. Details are available in the respective frameworks’ documentation: PyTorch, TensorFlow, JAX.

In this lab we’ll have a closer look at reverse-mode automatic differentiation by implementing our own (tiny) engine. The goal is to get a high-level idea of how the respective pieces are connected and what information from the forward pass is required in the backward pass. This understanding lays the groundwork for working with PyTorch’s automatic differentiation package torch.autograd and for defining our own custom extensions under torch.autograd’s umbrella in Section 6.

5.1. Examples

This part of the lab decomposes functions with increasing complexity into simpler ones. These simple functions are then used to formulate the forward pass and the backward pass. For now, we will perform the decompositions manually and recursively apply the chain rule in the backward pass. In Section 5.2 we will then change our strategy by defining more complex functions out of elementary building blocks which we can handle automatically.

Our first function \(f\) is a very simple one:

\[f(x,y,z) = x ( y + z )\]

We identify two expressions which allow us to compute the forward pass: \(a=y+z\) and \(b=xa\). Therefore, a piece of Python code realizing the forward pass of function \(f\) could read as follows:

def forward( i_x,
             i_y,
             i_z ):
   l_a = i_y + i_z
   l_b = i_x * l_a

   return l_b
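
Calling forward(2.0, 3.0, 4.0), for example, returns \(2 \cdot (3 + 4) = 14\).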

The backward pass is a bit more challenging: We are interested in computing the partial derivatives for the three inputs, i.e., \(\frac{\partial f}{\partial x}\), \(\frac{\partial f}{\partial y}\) and \(\frac{\partial f}{\partial z}\). Function \(f\) is simple enough such that we see the solution without thinking too much about it:

\[\frac{\partial f}{\partial x} = y+z, \; \frac{\partial f}{\partial y} = x, \; \frac{\partial f}{\partial z} = x\]

The approach of formulating the derivatives explicitly becomes increasingly tedious as the functions grow more complex. In software we therefore follow a different idea and formulate the partial derivatives by means of the chain rule. This only requires us to code the partial derivatives of the identified building blocks, which we then combine in a structured procedure for the composed function (enabling automation down the road).

In the example, we start with \(b\) and observe the following:

\[\begin{split}\begin{aligned} \frac{\partial b}{\partial x} = a, \; \frac{\partial b}{\partial a} = x, \\ \frac{\partial a}{\partial y} = 1, \; \frac{\partial a}{\partial z} = 1. \end{aligned}\end{split}\]

Application of the chain rule for \(\frac{\partial b}{\partial y}\) and \(\frac{\partial b}{\partial z}\) then gives us:

\[\begin{split}\begin{aligned} \frac{\partial f}{\partial x} = \frac{\partial b}{\partial x} = a = y + z \\ \frac{\partial f}{\partial y} = \frac{\partial b}{\partial y} = \frac{\partial b}{\partial a} \frac{\partial a}{\partial y} = x \cdot 1 = x \\ \frac{\partial f}{\partial z} = \frac{\partial b}{\partial z} = \frac{\partial b}{\partial a} \frac{\partial a}{\partial z} = x \cdot 1 = x \end{aligned}\end{split}\]

Once again, we can formulate the backward pass in a single piece of Python code:

def backward( i_x,
              i_y,
              i_z ):
  # recompute the intermediate result a of the forward pass
  l_a = i_y + i_z

  # partial derivatives of b = x*a
  l_dbda = i_x
  l_dbdx = l_a

  # partial derivatives of a = y+z
  l_dady = 1
  l_dadz = 1

  # chain rule: db/dy = db/da * da/dy, db/dz = db/da * da/dz
  l_dbdy = l_dbda * l_dady
  l_dbdz = l_dbda * l_dadz

  return l_dbdx, l_dbdy, l_dbdz
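
We can already check the two functions against the hand-derived results, e.g., at the point \((x,y,z) = (2,3,4)\). The following sketch uses the library unittest and assumes that forward and backward are importable from a (hypothetical) module named simple_function; adjust the import to your own layout:

import unittest
import simple_function  # assumed module containing forward and backward

class TestSimpleFunction( unittest.TestCase ):
  ## Tests the forward pass of f(x,y,z) = x*(y+z).
  def test_forward( self ):
    self.assertAlmostEqual( simple_function.forward( 2.0, 3.0, 4.0 ),
                            14.0 )

  ## Tests the backward pass against the hand-derived partial derivatives.
  def test_backward( self ):
    l_dfdx, l_dfdy, l_dfdz = simple_function.backward( 2.0, 3.0, 4.0 )
    self.assertAlmostEqual( l_dfdx, 7.0 )  # y + z
    self.assertAlmostEqual( l_dfdy, 2.0 )  # x
    self.assertAlmostEqual( l_dfdz, 2.0 )  # x

if __name__ == '__main__':
  unittest.main()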

Tasks

  1. Implement the two given methods forward and backward for \(f(x,y,z) = x ( y + z )\). Test your implementation in appropriate unit tests using the library unittest.

  2. Implement the forward and backward pass for the following function:

    \[g(w_0,w_1,w_2,x_0,x_1) = \frac{1}{1 + e^{-(w_0 x_0 + w_1 x_1 + w_2)} }\]

    Follow the described procedure, i.e., harness the chain rule in the backward pass! Test your implementation in appropriate unit tests.

  3. Implement the forward and backward pass for the following function:

    \[h(x,y) = \frac{ \sin(xy) + \cos(x+y) }{ e^{x-y} }\]

    Follow the described procedure, i.e., harness the chain rule in the backward pass! Test your implementation in appropriate unit tests.

5.2. Engine

In Section 5.1 we formulated the forward pass for “complex” functions by splitting them into simpler expressions. The chain rule then allowed us to formulate the respective backward pass. This part of the lab defines a set of building blocks which can then be used to assemble complex composite functions.

Roughly following the approach taken in torch.autograd, we first define modules for our building blocks. A single module has a forward and a backward method. For example, to realize scalar additions we could define the module Add.py as follows:

 1## Forward method which computes a+b.
 2# @param io_ctx context object.
 3# @param i_a node a.
 4# @param i_b node b.
 5# @return result a+b.
 6def forward( io_ctx,
 7             i_a,
 8             i_b ):
 9  l_result = i_a + i_b
10  return l_result
11
12## Backward method.
13# @param i_ctx context object.
14# @param i_grad gradient w.r.t. the output of the forward method.
15# @return gradients w.r.t. the inputs of the forward method.
16def backward( i_ctx,
17              i_grad ):
18  l_grad_a = i_grad
19  l_grad_b = i_grad
20  return [ l_grad_a, l_grad_b ]

Similarly, for scalar multiplications we could define Mul.py:

 1## Forward method which computes a*b.
 2# @param io_ctx context object.
 3# @param i_a node a.
 4# @param i_b node b.
 5# @return result a*b.
 6def forward( io_ctx,
 7             i_a,
 8             i_b ):
 9  io_ctx.save_for_backward( i_a,
10                            i_b )
11  l_result = i_a * i_b
12  return l_result
13
14## Backward method.
15# @param i_ctx context object.
16# @param i_grad gradient w.r.t. the output of the forward method.
17# @return gradients w.r.t. the inputs of the forward method.
18def backward( i_ctx,
19              i_grad ):    
20  l_a, l_b = i_ctx.m_saved_data
21  l_grad_a = l_b * i_grad
22  l_grad_b = l_a * i_grad
23  return [ l_grad_a, l_grad_b ]

As done in PyTorch, we use a context object to pass data from the forward method to the backward method. For example, in Mul.py (lines 9 and 10), the two input values i_a and i_b are temporarily stored as part of the context object and then read (line 20) and used in the backward method. Add.py, in contrast, does not require any data from the forward pass: the local derivatives of \(a+b\) w.r.t. \(a\) and \(b\) are simply one, such that the incoming gradient is passed through unchanged.
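
The context object itself is not shown above. A minimal sketch of what a class in eml/autograd/context.py could look like is given below; it only assumes the save_for_backward method and the m_saved_data member used in Mul.py, so the actual class in the code frame may differ.

## Context object which passes data from the forward method to the backward method.
#  Minimal sketch: only the interface used above is assumed.
class Context:
  ## Initializes an empty context.
  def __init__( self ):
    self.m_saved_data = ()

  ## Saves the given values for the backward method.
  # @param i_args values which are required in the backward pass.
  def save_for_backward( self,
                         *i_args ):
    self.m_saved_data = i_args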

We continue the development of our tiny autograd engine by embedding the function modules in a Node class. Node keeps track of the elementary functions used in the forward pass and allows our users to conveniently trigger the backward pass for the composite function. A first version of the class could read as follows:

 1from . import context
 2from . import functions
 3
 4class Node:
 5  ## Initializes a node of the computation graph.
 6  # @param i_value optional value used for leaf nodes.
 7  def __init__( self,
 8                i_value = None ):
 9    self.m_value = i_value
10    self.m_grad = 0
11    self.m_grad_fn = functions.Nop.backward
12    self.m_children = []
13    self.m_ctx = context.Context()
14
15  ## String representation of a node.
16  # @return newline separated string with value, grad and grad_fn.
17  def __str__( self ):
18    l_string  = "node:\n"
19    l_string += "  value: "   + str( self.m_value ) + "\n"
20    l_string += "  grad: "    + str( self.m_grad ) + "\n"
21    l_string += "  grad_fn: " + str( self.m_grad_fn )
22    return l_string
23
24  ## Backward pass of the node.
25  # @param i_grad grad w.r.t. the output of the forward pass.
26  def backward( self,
27                i_grad ):
28    self.m_grad += i_grad
29    l_grad_children = self.m_grad_fn( self.m_ctx,
30                                      i_grad )
31
32    for l_ch in range( len(self.m_children) ):
33      self.m_children[l_ch].backward( l_grad_children[l_ch] )
34
35  ## Zeroes the grad of the node and all children.
36  def zero_grad( self ):
37    self.m_grad = 0
38    for l_ch in self.m_children:
39      l_ch.zero_grad()
40
41  ## Returns a new node which represents the addition of the two input nodes.
42  # @param self first input node.
43  # @param i_other second input node.
44  # @return node representing the addition.
45  def __add__( self,
46               i_other ):
47    l_node = Node()
48    l_node.m_grad_fn = functions.Add.backward
49    l_node.m_children = [self, i_other]
50    l_node.m_value = functions.Add.forward( l_node.m_ctx,
51                                            self.m_value,
52                                            i_other.m_value )
53    return l_node
54
55  ## Returns a new node which represents the multiplication of the two input nodes.
56  # @param self first input node.
57  # @param i_other second input node.
58  # @return node representing the multiplication.
59  def __mul__( self,
60               i_other ):
61    l_node = Node()
62    l_node.m_grad_fn = functions.Mul.backward
63    l_node.m_children = [self, i_other]
64    l_node.m_value = functions.Mul.forward( l_node.m_ctx,
65                                            self.m_value,
66                                            i_other.m_value )
67    return l_node

We see that Node defines the two methods __add__ and __mul__. This emulates a numeric object and allows our users to simply add two nodes using the binary + operator and multiply two nodes using the binary * operator. However, in the background, much more than simply combining two numeric values is done. First, in lines 48 and 62 the two methods store the functions which have to be executed in the backward pass. Second, in lines 49 and 63, the child nodes which require the gradient of the newly created l_node object in the backward pass are stored. Last, in lines 50-52 and 64-66, the forward method is executed and the result is stored as part of the new Node object.

The stored information, i.e., the children and the respective function for the backward pass, is then used in the method backward defined in lines 26-33. If backward is called on a Node object, the input gradient is added to the node’s member variable m_grad in line 28. Next, the (previously stored) gradient function m_grad_fn is executed in lines 29-30. The output of this function is then passed on recursively to the children, for which the backward method is called in lines 32-33.

In summary, calling backward on a node object traverses the computation graph which was assembled in the forward pass in reverse order. For each node, the respective derivatives are computed and passed on to the node’s children. The procedure completes once all leaf nodes have been reached. Such leaf nodes, which were used to initiate the forward pass, do not have any children. For these the dummy function functions.Nop.backward is called, which is set in the constructor (line 11).
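
The module Nop.py is not shown above. A minimal sketch, assuming that its only purpose is to terminate the recursion at leaf nodes, could read as follows:

## Backward method of the no-op assigned to leaf nodes.
# @param i_ctx context object (unused).
# @param i_grad gradient w.r.t. the output of the forward pass (unused).
# @return empty list since leaf nodes have no children.
def backward( i_ctx,
              i_grad ):
  return []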

Note that we add to a node’s member variable m_grad in line 28. This means that we could, e.g., initiate the backward pass multiple times and accumulate the gradients internally. To reset the gradients, one has to call the method zero_grad defined in lines 36-39.
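
Putting the pieces together, the function \(f(x,y,z) = x(y+z)\) of Section 5.1 could, for example, be realized through the engine as sketched below. The import assumes that the class is available as eml.autograd.node.Node (as in the tasks below); adjust it to your code frame if necessary.

from eml.autograd.node import Node  # assumed location of the Node class

# leaf nodes holding the input values
l_x = Node( 2.0 )
l_y = Node( 3.0 )
l_z = Node( 4.0 )

# forward pass: assembles the computation graph and computes f = x*(y+z)
l_f = l_x * ( l_y + l_z )
print( l_f )  # value: 14.0

# backward pass: seed the gradient w.r.t. the output with 1.0
l_f.backward( 1.0 )
print( l_x.m_grad )  # 7.0 = y + z
print( l_y.m_grad )  # 2.0 = x
print( l_z.m_grad )  # 2.0 = x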

Tasks

  1. Make yourself familiar with the code frame. Add unit tests for the module eml/autograd/functions/Mul.py. Implement a unit test in the file eml/autograd/functions/test_node.py which realizes the function \(f(x,y,z)\) of Section 5.1.

  2. Extend the code by adding the modules Reciprocal.py, i.e., \(\frac{1}{x}\) for a scalar \(x\), and Exp.py, i.e., \(e^x\), to eml/autograd/functions. Define the respective methods __truediv__ and exp in eml.autograd.node.Node. Write appropriate unit tests! Test your extended code by realizing the function \(g(w_0,w_1,w_2,x_0,x_1)\) of Section 5.1.

  3. Proceed similarly for \(\sin\) and \(\cos\). Test your implementations and realize the function \(h(x,y)\) of Section 5.1.