A Fully Annotated Autograd Tutorial

In this tutorial, you will learn the theory behind reverse mode automatic differentiation, and see how to implement a real autograd system in practice. We will start with a theoretical overview and then implement the modules you'd typically see in an autograd framework, from basic functions like addition and matrix multiplication to linear and convolution layers. This tutorial assumes the reader is familiar with multivariable calculus and comfortable with NumPy. Familiarity with basic PyTorch modules is also recommended, as this system will have a similar interface. The source code for this tutorial is available here, and the full documentation for the system is available here.

Finally, you'll see how to use this system to train a ResNet50 on the CIFAR10 dataset! Let's get started.

The Chain Rule, Backpropagation, and Autodiff

From the point of view of an autograd system, there are two main entities: variables and functions. A variable can be a scalar, a matrix, or a tensor of arbitrary dimensions. A function \(f\) is a mapping

\[ f(V_1, V_2, \ldots,V_k)\rightarrow V_o \]

from a fixed-size list of variables to a single variable. However, for now we will assume that all variables are real valued scalars, and that all functions output scalars, in order to develop the needed theory. Then, we will expand the theory to include the general formulation of variables and functions above.

Functions can be composed in the sense that the output variable of one function is used as an input variable of another. In this way one can construct a computation graph, where the nodes are variables, and for each variable pair \(x,y,\) there exists an edge if \(x\) is an input to a function whose output is \(y\). An example of such a graph is depicted below. \[ \begin{align} y_1 &= f_1(x_1, x_2, x_3, x_4) = x_1+ x_2 +x_3 + x_4\\ y_2 &= f_2(x_4) = x_4^2\\ z &= f_3(x_1, y_1, y_2) = x_1 \cdot y_1 + y_2\\ \end{align} \]

Figure 1: Computation graph example.

There are two key takeaways from this example. Notice first that there are no cycles in the graph. All compositions of functions must respect this rule, meaning that their graphs are directed acyclic graphs (DAGs). Second, as demonstrated by \(x_1\) and \(x_4\), variables are allowed to be inputs to more than just one function.

We're interested in developing a system that, given a computation graph, can calculate the partial derivative of any variable in the graph with respect to any of its ancestors. For this, we must study the inductively defined chain rule.

Theorem (Inductive Chain Rule on a Computation Graph)

Let \( G = (V, E) \) be a computation graph, meaning that:

  1. \(V\) is the set of all defined variables.
  2. For all pairs of variables \( x,y \in V\), if a function takes \(x\) as an input and outputs \(y\), then \((x,y) \in E\).

Then, let \( x, z \in V \) be two variables such that \( z \) is a descendant of \( x \). That is, there exists a directed path from \( x \) to \( z \) in the computation graph. Finally, let \(c(x)\) denote the set of children of \(x\).

Then the derivative of \( z \) with respect to \( x \) is given by:

\[ \frac{dz}{dx} = \sum_{y \in c(x)} \frac{dz}{dy} \cdot \frac{\partial y}{\partial x} \tag{1} \]

where we take the derivative of a variable with respect to itself to be \(1\).

Let's think about what this theorem means for us. Suppose we are attempting to calculate the derivative of some variable \(z\) with respect to an ancestor \(x\), and furthermore we have already calculated \(\frac{dz}{dy}\) for all \(y \in c(x)\). Then all that remains in order to obtain \(\frac{dz}{dx}\) is to calculate \(\frac{\partial y}{\partial x}\) for all children \(y \in c(x)\) and invoke Equation (1). This is the heart of backpropagation. If we know the derivative of \(z\) with respect to all the children of \(x\), then the derivative of \(z\) with respect to \(x\) is straightforward to calculate.

Going back to the example in Figure 1, let's think about how we would obtain \(\frac{dz}{dx_1}\). The children of \(x_1\) are \(c(x_1) = \{y_1, z\}\), and thus the inductive chain rule gives us

\[ \begin{align} \frac{dz}{dx_1} &= \frac{dz}{dz} \cdot \frac{\partial z}{\partial x_1} + \frac{dz}{dy_1} \cdot \frac{\partial y_1}{\partial x_1}\tag{2} \\ &= 1 \cdot y_1 + x_1 \cdot 1\\ &= y_1 + x_1 \end{align} \]

A word on notation: the difference between \(\frac{dz}{dx_1}\) and \(\frac{\partial z}{\partial x_1}\) is important. The former is the total derivative, accounting for every path from \(x_1\) to \(z\) in the graph, whereas the latter is the partial derivative through a direct functional dependence, holding the other inputs fixed.

Since this is a simple example, we were able to replace each derivative in Equation (2) with a directly computed value, and our invocation of the inductive chain rule was unsystematic, but the example shows a glimpse of its power. In the next subsection, we'll explore the algorithm at the heart of automatic differentiation.

Backpropagation

Suppose instead we had an elaborate computation graph with thousands of variables and functions. Suppose also that \(z\) is some variable in the computation graph and that we'd like to calculate the derivative of \(z\) with respect to every single ancestor of \(z\). We'd need a method that systematically applies the chain rule to make these calculations. This is exactly what backpropagation accomplishes.

Recall the earlier statement: if we have already calculated \(\frac{dz}{dy}\) for all children \(y\) of \(x\), then the inductive chain rule can be invoked to calculate \(\frac{dz}{dx}\). What we'd like, then, is an ordering of all ancestors of \(z\) such that if we calculated the derivatives in that order, each ancestor would have its children's derivatives calculated before its own.

The algorithm that returns this ordering is called topological sort. If you are not familiar with it, I recommend learning about it first; there are many resources online. By running topological sort from \(z\), treating edges as dependencies, we obtain an ordering such that every ancestor \(x\) of \(z\) appears in the ordering after its own children. Backpropagation then calculates derivatives in this order, keeping track of all computed derivatives in a dictionary. The pseudocode is shown below.

BACKPROPAGATION(Computation graph G, node z in G):

Let topo_order = TopologicalSort(G, z)

Let computed = {dz/dz : 1}

For each node x in topo_order (excluding z):

    Let dz/dx = 0

    # Invoke Inductive Chain Rule
    For each child y of x:
        Let partial_x = calculate(∂y/∂x)
        Let total_y = computed.get(dz/dy)
        dz/dx += total_y * partial_x

    insert dz/dx into computed
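
To make the pseudocode concrete, here is a small, self-contained Python sketch that runs it on the Figure 1 graph. The hard-coded values, the children dictionary, and the variable names are mine, chosen purely for illustration; the real system we build later discovers the graph structure automatically.

# Concrete values for the leaves of the Figure 1 graph (arbitrary choice).
values = {"x1": 2.0, "x2": 3.0, "x3": 1.0, "x4": 4.0}
values["y1"] = values["x1"] + values["x2"] + values["x3"] + values["x4"]  # 10
values["y2"] = values["x4"] ** 2                                          # 16
values["z"] = values["x1"] * values["y1"] + values["y2"]                  # 36

# children[x] maps each child y of x to the partial derivative dy/dx,
# evaluated at the values above.
children = {
    "x1": {"y1": 1.0, "z": values["y1"]},
    "x2": {"y1": 1.0},
    "x3": {"y1": 1.0},
    "x4": {"y1": 1.0, "y2": 2 * values["x4"]},
    "y1": {"z": values["x1"]},
    "y2": {"z": 1.0},
}

# One valid ordering (every node appears after its children), as in the text.
topo_order = ["z", "y2", "y1", "x2", "x1", "x3", "x4"]

computed = {"z": 1.0}  # dz/dz = 1
for x in topo_order[1:]:
    # Inductive chain rule: dz/dx = sum over children y of (dz/dy * dy/dx).
    computed[x] = sum(computed[y] * partial for y, partial in children[x].items())

print(computed)
# dz/dx1 = y1 + x1 = 12.0 and dz/dx4 = x1 + 2*x4 = 10.0, matching the theory.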

Returning to the computation graph in Figure 1, the ordering returned by topological sort might be \((z, y_2, y_1, x_2, x_1, x_3, x_4)\). Other orderings are possible too, as multiple orderings satisfy the dependencies in this case. From here, the main loop of BACKPROPAGATION would proceed as follows:

  1. current node: \(y_2\)

    computed = \(\{\frac{dz}{dz} : 1\}\)

    calculate: \(\frac{\partial z}{\partial y_2} = 1\)

    invoke chain rule: \(\frac{dz}{dy_2} = \frac{dz}{dz} \cdot \frac{\partial z}{\partial y_2} = 1 \cdot 1 = 1\)

  2. current node: \(y_1\)

    computed = \(\{\frac{dz}{dz} : 1, \frac{dz}{dy_2} : 1\}\)

    calculate: \(\frac{\partial z}{\partial y_1} = x_1\)

    invoke chain rule: \(\frac{dz}{dy_1} = \frac{dz}{dz} \cdot \frac{\partial z}{\partial y_1} = 1 \cdot x_1 = x_1\)

  3. current node: \(x_2\)

    computed = \(\{\frac{dz}{dz} : 1, \frac{dz}{dy_2} : 1, \frac{dz}{dy_1} : x_1\}\)

    calculate: \(\frac{\partial y_1}{\partial x_2} = 1\)

    invoke chain rule: \(\frac{dz}{dx_2} = \frac{dz}{dy_1} \cdot \frac{\partial y_1}{\partial x_2} = x_1 \cdot 1 = x_1\)

  4. current node: \(x_1\)

    computed = \(\{\frac{dz}{dz} : 1, \frac{dz}{dy_2} : 1, \frac{dz}{dy_1} : x_1, \frac{dz}{dx_2} : x_1\}\)

    calculate: \(\frac{\partial z}{\partial x_1} = y_1\) and \(\frac{\partial y_1}{\partial x_1} = 1\)

    invoke chain rule: \(\frac{dz}{dx_1} = \frac{dz}{dz} \cdot \frac{\partial z}{\partial x_1} + \frac{dz}{dy_1} \cdot \frac{\partial y_1}{\partial x_1} = 1 \cdot y_1 + x_1 \cdot 1 = y_1 + x_1\)

  5. current node: \(x_3\)

    computed = \(\{\frac{dz}{dz} : 1, \frac{dz}{dy_2} : 1, \frac{dz}{dy_1} : x_1, \frac{dz}{dx_2} : x_1, \frac{dz}{dx_1} : y_1 + x_1\}\)

    calculate: \(\frac{\partial y_1}{\partial x_3} = 1\)

    invoke chain rule: \(\frac{dz}{dx_3} = \frac{dz}{dy_1} \cdot \frac{\partial y_1}{\partial x_3} = x_1 \cdot 1 = x_1\)

  6. current node: \(x_4\)

    computed = \(\{\frac{dz}{dz} : 1, \frac{dz}{dy_2} : 1, \frac{dz}{dy_1} : x_1, \frac{dz}{dx_2} : x_1, \frac{dz}{dx_1} : y_1 + x_1, \frac{dz}{dx_3} : x_1\}\)

    calculate: \(\frac{\partial y_1}{\partial x_4} = 1\) and \(\frac{\partial y_2}{\partial x_4} = 2x_4\)

    invoke chain rule: \(\frac{dz}{dx_4} = \frac{dz}{dy_1} \cdot \frac{\partial y_1}{\partial x_4} + \frac{dz}{dy_2} \cdot \frac{\partial y_2}{\partial x_4} = x_1 \cdot 1 + 1 \cdot 2x_4 = x_1 + 2x_4\)

  7. Final output: \(\{\frac{dz}{dz} : 1, \frac{dz}{dy_2} : 1, \frac{dz}{dy_1} : x_1, \frac{dz}{dx_2} : x_1, \frac{dz}{dx_1} : y_1 + x_1, \frac{dz}{dx_3} : x_1, \frac{dz}{dx_4} : x_1 + 2x_4 \}\)

At each iteration of backpropagation, for the current node \(x\), we invoke the chain rule using the previously calculated total derivatives of \(z\) with respect to children of \(x\), combined with the newly calculated partial derivatives of children of \(x\) with respect to \(x\) itself. At each iteration, the dictionary of computed total derivatives grows by one.

Extending Backpropagation to Tensors

We just saw the inductive chain rule and backpropagation in the context of scalar variables and functions. We will now extend it to tensor valued variables and functions. This will allow us to express variables and compositions of functions in a more compact way, and it will allow us to invoke the chain rule "in bulk". We'll start with an example.

Let \(A,B \in \mathbb{R}^{2 \times 2}\) be two matrices, and suppose we let \(z = \text{sum}(A \times B)\), meaning \(z\) is the scalar obtained by summing the entries of the matrix \(A \times B\). It makes perfect sense to be interested, for example, in the derivative of \(z\) with respect to the entry \(a_{11}\) of \(A\). In fact, \(z = \text{sum}(A \times B)\) can secretly be expressed as a computation graph with scalar valued functions and variables, like we saw before:

\[ A = \begin{bmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{bmatrix} = \begin{bmatrix} 3 & 7 \\ 2 & 5 \end{bmatrix} \] \[ B = \begin{bmatrix} b_{11} & b_{12} \\ b_{21} & b_{22} \end{bmatrix} = \begin{bmatrix} 2 & 0 \\ 0 & 4 \end{bmatrix} \] \[ C = A \times B = \begin{bmatrix} c_{11} & c_{12} \\ c_{21} & c_{22} \end{bmatrix} = \begin{bmatrix} 3*2 + 7*0 & 3*0+7*4 \\ 2*2+5*0 & 2*0 + 5*4 \end{bmatrix} = \begin{bmatrix} 6 & 28 \\ 4 & 20 \end{bmatrix} \]

where the computation graph is given by:

  1. \(c_{11} = a_{11}*b_{11} + a_{12}*b_{21}\)
  2. \(c_{12} = a_{11}*b_{12} + a_{12}*b_{22}\)
  3. \(c_{21} = a_{21}*b_{11} + a_{22}*b_{21}\)
  4. \(c_{22} = a_{21}*b_{12} + a_{22}*b_{22}\)
  5. \(z = c_{11} + c_{12} + c_{21} + c_{22}\)

We can invoke backpropagation on this computation graph just like before, and arrive at the derivatives (try and calculate these for yourself):

\(\frac{dz}{dc_{11}} = \frac{dz}{dc_{12}} = \frac{dz}{dc_{21}} = \frac{dz}{dc_{22}} = 1, \frac{dz}{da_{11}} = \frac{dz}{da_{21}} = b_{11} + b_{12}, \frac{dz}{da_{12}} = \frac{dz}{da_{22}} = b_{21} + b_{22}, \) \(\frac{dz}{db_{11}} = \frac{dz}{db_{12}} = a_{11} + a_{21}, \text{ and } \frac{dz}{db_{21}} = \frac{dz}{db_{22}} = a_{12} + a_{22}.\)

It gets cumbersome to keep track of this many individual derivatives, thus introducing the need for a new kind of notation. Let's define the derivative of a scalar \(z\) with respect to a whole tensor \(T\) as \(\frac{dz}{dT}\), where \(\frac{dz}{dT}\) has the same shape as \(T\) itself, and where each entry in \(\frac{dz}{dT}\) is the derivative of \(z\) with respect to the corresponding entry in \(T\). So all the above derivatives can now be expressed more compactly:

\[ \begin{align} \frac{dz}{dA} &= \begin{bmatrix} \frac{dz}{da_{11}} & \frac{dz}{da_{12}} \\ \frac{dz}{da_{21}} & \frac{dz}{da_{22}} \end{bmatrix} = \begin{bmatrix} b_{11} + b_{12} & b_{21} + b_{22} \\ b_{11} + b_{12} & b_{21} + b_{22} \end{bmatrix},\\ \frac{dz}{dB} &= \begin{bmatrix} \frac{dz}{db_{11}} & \frac{dz}{db_{12}} \\ \frac{dz}{db_{21}} & \frac{dz}{db_{22}} \end{bmatrix} = \begin{bmatrix} a_{11} + a_{21} & a_{11} + a_{21} \\ a_{12} + a_{22} & a_{12} + a_{22} \end{bmatrix},\\ \frac{dz}{dC} &= \begin{bmatrix} \frac{dz}{dc_{11}} & \frac{dz}{dc_{12}} \\ \frac{dz}{dc_{21}} & \frac{dz}{dc_{22}} \end{bmatrix} = \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix}.\\ \end{align} \]

Now you may have noticed that these derivatives follow a pattern. Whenever the entries of a matrix are sums of products like this, your instinct should be that the matrix is the result of multiplying two matrices, which is in fact the case here.

\[ \begin{align} \frac{dz}{dA} &= \begin{bmatrix} \frac{dz}{da_{11}} & \frac{dz}{da_{12}} \\ \frac{dz}{da_{21}} & \frac{dz}{da_{22}} \end{bmatrix}\\ &= \begin{bmatrix} b_{11} + b_{12} & b_{21} + b_{22} \\ b_{11} + b_{12} & b_{21} + b_{22} \end{bmatrix}\\ &= \begin{bmatrix} 1 * b_{11} + 1 * b_{12} & 1 * b_{21} + 1 * b_{22} \\ 1 * b_{11} + 1 * b_{12} & 1 * b_{21} + 1 * b_{22} \end{bmatrix}\\ &= \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix} \times \begin{bmatrix} b_{11} & b_{21} \\ b_{12} & b_{22} \end{bmatrix}\\ &= \frac{dz}{dC} \times B^\mathsf{T} \end{align} \]

and similarly \(\frac{dz}{dB} = A^\mathsf{T} \times \frac{dz}{dC}\). I won't prove this here, but this result is no coincidence. It is a special case of a more general result that compactly summarizes the derivative of a scalar with respect to a matrix that was multiplied with another matrix.

Result (Gradient of Scalar Function of Matrix Product)

Let \( A \in \mathbb{R}^{m \times n} \), \( B \in \mathbb{R}^{n \times p} \) and \( C = A \times B \). Furthermore let \( z = f(C) \) where \(f\) is a scalar valued differentiable function. Then

\[ \frac{dz}{dA} = \frac{dz}{dC} \times B^{\mathsf{T}} \quad \text{and} \quad \frac{dz}{dB} = A^{\mathsf{T}} \times \frac{dz}{dC}. \tag{3} \]

This is a very exciting result, and it hints at how autodiff systems actually perform their computations. The pseudocode backpropagation algorithm shown earlier is theoretically sound, and it is the most general form of backpropagation, where every edge is explicitly present in the computation graph, but it is not how backpropagation is implemented in practice. Instead, autodiff uses formulas like those shown above to apply the scalar chain rule "in bulk". The computation graph then does not need to keep track of every edge and every scalar entry in every tensor. Our autodiff system will express the example in this section like so:

\[ \begin{array}{ccccccccc} A& \searrow & & &&\\ B & \rightarrow & \text{matmul} & \rightarrow & C & \rightarrow & \text{sum} & \rightarrow & z\\ \end{array} \]

Then knowing \(z = \text{sum}(C)\) it will reason that \(\frac{dz}{dC} = \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix}\) , and finally knowing \(C = \text{matmul}(A,B)\), it will apply Equations (3) thus obtaining \(\frac{dz}{dA}\) and \(\frac{dz}{dB}\).
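
If you'd like to verify this numerically, here is a quick plain-NumPy check of Equations (3) on the concrete matrices from this section. This is not the autodiff system yet, just a sanity check; the finite-difference step size is an arbitrary choice.

import numpy as np

A = np.array([[3.0, 7.0], [2.0, 5.0]])
B = np.array([[2.0, 0.0], [0.0, 4.0]])

C = A @ B
dz_dC = np.ones_like(C)   # z = sum(C), so dz/dC is the all-ones matrix
dz_dA = dz_dC @ B.T       # Equations (3)
dz_dB = A.T @ dz_dC

print(dz_dA)   # [[2. 4.], [2. 4.]]   = [[b11+b12, b21+b22], [b11+b12, b21+b22]]
print(dz_dB)   # [[5. 5.], [12. 12.]] = [[a11+a21, a11+a21], [a12+a22, a12+a22]]

# Finite-difference check of a single entry: dz/da11 should be b11 + b12 = 2.
eps = 1e-6
A_pert = A.copy()
A_pert[0, 0] += eps
print((np.sum(A_pert @ B) - np.sum(A @ B)) / eps)   # approximately 2.0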

More generally, from the point of view of our autodiff system, any function \(f\) takes a list of tensors \(X_1, X_2, \ldots, X_k\) as input and produces a tensor \(Y\) as output, like so:

\[f(X_1, X_2, \ldots X_k) = Y.\]

If autodiff knows the derivative of a scalar \(z\) with respect to \(f\)'s output (also known as the upstream gradient, denoted \(\frac{dz}{dY}\)), then autodiff uses what's called the "backward function" of \(f\) to compute the derivative of \(z\) with respect to each of \(f\)'s inputs. Every function \(f\) in the system must have a backward function implemented. This backward function is typically a clever result that implicitly invokes the inductive chain rule upon all scalars in the input tensors. In the example above, the backwards function receives \(\frac{dz}{dC}\) as the upstream gradient and invokes Equations (3) to calculate the derivative of \(z\) with respect to both \(A\) and \(B\).

So far, so good: given the upstream gradient for a function's output, its backward function produces the gradients for each of the function's inputs. There is, however, one remaining subtlety to address before we can formalize the full system.

This approach to autodiff introduces a problem. The issue is that tensors can be used as input to more than just one function. Suppose we augmented the example by letting \( D = \begin{bmatrix} 5 & 3 \\ 1 & 9 \end{bmatrix} , E = B + D, W = C + E,\) and finally \(z = \text{sum}(W)\) . The computation graph would change to

\[ \begin{array}{ccccccccc} A& \searrow & & &&\\ B & \rightarrow & \text{matmul} & \rightarrow & C & \searrow & \\ & \searrow&&&&&\text{add} & \rightarrow & W & \rightarrow& \text{sum} & \rightarrow & z\\ D & \rightarrow &\text{add} & \rightarrow &E &\nearrow \\ \end{array} \]

Now \(B\) is an input to both matmul and add. So autodiff as described so far, given upstream gradients \(\frac{dz}{dC}\) and \(\frac{dz}{dE}\), would calculate two different gradients with respect to \(B\): one from matmul's backward function and one from add's backward function. Let's call these \(G_{BC}\) and \(G_{BE}\), respectively. What should autodiff do with these values?

Neither \(G_{BC}\) nor \(G_{BE}\) equals \(\frac{dz}{dB}\) on its own; it turns out that \(\frac{dz}{dB} = G_{BC} + G_{BE}\), a result that follows from the inductive chain rule. Let's see why. Remember, the inductive chain rule dictates that for any scalar \(b_{ij} \in B\),

\[ \frac{dz}{db_{ij}} = \sum_{y \in c(b_{ij})} \frac{dz}{dy} \cdot \frac{\partial y}{\partial b_{ij}}, \]

which we can break down further by letting \(c_{C}(b_{ij})\) denote the children of \(b_{ij}\) that are in \(C\), and similarly defining \(c_{E}(b_{ij})\),

\[ \frac{dz}{db_{ij}} = \sum_{y \in c_{C}(b_{ij})} \frac{dz}{dy} \cdot \frac{\partial y}{\partial b_{ij}} + \sum_{y \in c_{E}(b_{ij})} \frac{dz}{dy} \cdot \frac{\partial y}{\partial b_{ij}}. \]

Recalling that \( G_{BC} = A^{\mathsf{T}} \times \frac{dz}{dC}\),

\[ \begin{align} A^\mathsf{T} \times \frac{dz}{dC} &= \begin{bmatrix} a_{11} & a_{21} \\ a_{12} & a_{22} \end{bmatrix} \times \begin{bmatrix} \frac{dz}{dc_{11}} & \frac{dz}{dc_{12}} \\ \frac{dz}{dc_{21}} & \frac{dz}{dc_{22}} \end{bmatrix}\\ &= \begin{bmatrix} \frac{dz}{dc_{11}} \cdot a_{11} + \frac{dz}{dc_{21}} \cdot a_{21} & \frac{dz}{dc_{12}} \cdot a_{11} + \frac{dz}{dc_{22}} \cdot a_{21} \\ \frac{dz}{dc_{11}} \cdot a_{12} + \frac{dz}{dc_{21}} \cdot a_{22} & \frac{dz}{dc_{12}} \cdot a_{12} + \frac{dz}{dc_{22}} \cdot a_{22} \end{bmatrix}\\ &= \begin{bmatrix} \frac{dz}{dc_{11}} \cdot \frac{\partial c_{11}}{\partial b_{11}} + \frac{dz}{dc_{21}} \cdot \frac{\partial c_{21}}{\partial b_{11}} & \frac{dz}{dc_{12}} \cdot \frac{\partial c_{12}}{\partial b_{12}} + \frac{dz}{dc_{22}} \cdot \frac{\partial c_{22}}{\partial b_{12}} \\ \frac{dz}{dc_{11}} \cdot \frac{\partial c_{11}}{\partial b_{21}} + \frac{dz}{dc_{21}} \cdot \frac{\partial c_{21}}{\partial b_{21}} & \frac{dz}{dc_{12}} \cdot \frac{\partial c_{12}}{\partial b_{22}} + \frac{dz}{dc_{22}} \cdot \frac{\partial c_{22}}{\partial b_{22}} \end{bmatrix}\\ &= \begin{bmatrix} \sum_{y \in c_{C}(b_{11})} \frac{dz}{dy} \cdot \frac{\partial y}{\partial b_{11}} & \sum_{y \in c_{C}(b_{12})} \frac{dz}{dy} \cdot \frac{\partial y}{\partial b_{12}} \\ \sum_{y \in c_{C}(b_{21})} \frac{dz}{dy} \cdot \frac{\partial y}{\partial b_{21}} & \sum_{y \in c_{C}(b_{22})} \frac{dz}{dy} \cdot \frac{\partial y}{\partial b_{22}} \\ \end{bmatrix}\\ \end{align} \]

and so the \(ij\) entry of \(G_{BC}\) is

\[ (G_{BC})_{ij} = \sum_{y \in c_{C}(b_{ij})} \frac{dz}{dy} \cdot \frac{\partial y}{\partial b_{ij}} . \]

So we can see that \((G_{BC})_{ij}\) is exactly the part of the sum in \(\frac{dz}{db_{ij}}\) that corresponds to the children of \(b_{ij}\) in \(C\). Similarly, you could conclude that

\[ (G_{BE})_{ij} = \sum_{y \in c_{E}(b_{ij})} \frac{dz}{dy} \cdot \frac{\partial y}{\partial b_{ij}}, \]

and finally

\[ (G_{BC})_{ij} + (G_{BE})_{ij} = \sum_{y \in c_{C}(b_{ij})} \frac{dz}{dy} \cdot \frac{\partial y}{\partial b_{ij}} + \sum_{y \in c_{E}(b_{ij})} \frac{dz}{dy} \cdot \frac{\partial y}{\partial b_{ij}} = \frac{dz}{db_{ij}}, \]

thus establishing \(\frac{dz}{dB} = G_{BC} + G_{BE}\).
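
As a sanity check, here is the augmented example verified with plain NumPy, again outside the autodiff system; the forward helper below is mine, used only for finite-difference checking.

import numpy as np

A = np.array([[3.0, 7.0], [2.0, 5.0]])
B = np.array([[2.0, 0.0], [0.0, 4.0]])
D = np.array([[5.0, 3.0], [1.0, 9.0]])

def forward(B):
    C = A @ B
    E = B + D
    W = C + E
    return np.sum(W)

# Each child of B contributes one gradient through its backward function.
dz_dC = np.ones((2, 2))    # upstream gradient arriving at matmul
dz_dE = np.ones((2, 2))    # upstream gradient arriving at add
G_BC = A.T @ dz_dC         # contribution through C
G_BE = dz_dE               # contribution through E (add passes the gradient through)
dz_dB = G_BC + G_BE

# Finite-difference check of the (1,1) entry: both prints should be about 6.0.
eps = 1e-6
B_pert = B.copy()
B_pert[0, 0] += eps
print(dz_dB[0, 0], (forward(B_pert) - forward(B)) / eps)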

This result generalizes beyond just this example. Suppose some tensor \(X\) is the input to several functions, so

\[ \begin{align} f_1(\ldots, X, \ldots) &= Y_1,\\ f_2(\ldots, X, \ldots) &= Y_2,\\ \vdots\\ f_k(\ldots, X,\ldots) &= Y_k,\\ \end{align} \]

and \(\frac{dz}{dY_1}, \frac{dz}{dY_2}, \ldots, \frac{dz}{dY_k}\) have already been calculated. Then autodiff will invoke the backwards function of each of \(f_1, f_2, \ldots, f_k\), thus obtaining \(G_{XY_1}, G_{XY_2}, \ldots, G_{XY_k}\), and finally set \(\frac{dz}{dX} = G_{XY_1} + G_{XY_2} + \ldots + G_{XY_k}\). Remember, \((G_{XY_n})_{ij}\) is the portion of the sum defining \(\left(\frac{dz}{dX}\right)_{ij}\) that corresponds to the children of \(x_{ij}\) present in \(Y_n\).

Backpropagation via Autodiff System

Finally, we have all the theory needed for the automatic differentiation system that we will implement.

Autodiff System Formalization

Suppose \( G = (V, E) \) is a computation graph such that

  1. \(V\) is a set of tensors.
  2. For any pair \( X,Y \in V\), if a function takes \(X\) as an input and outputs \(Y\), then \((X,Y) \in E\).

Furthermore, let \( z \in V \) be a \(0\)-dimensional tensor (i.e., \(z\) is a scalar).

Then the autodiff system calculates \(\frac{dz}{dX}\) for all ancestors \(X\) of \(z\) by

  1. Obtaining an ordering \(z, X_1, X_2, \ldots, X_N\) of all ancestors via topological sort.
  2. For each tensor \(X\) in the ordering, calling its backwards function using the already computed upstream gradient \(\frac{dz}{dX}\), where
    • \(\frac{dz}{dX} = G_{XY_1} + G_{XY_2} + \ldots + G_{XY_k}, \) where \(Y_1, Y_2, \ldots, Y_k\) are the direct children of \(X\), which therefore appeared before \(X\) in the ordering and already had their backwards functions called. This means that \(G_{XY_1}, G_{XY_2}, \ldots, G_{XY_k} \) have already been computed and \(\frac{dz}{dX}\) is available before \(X\) calls its own backwards function.
    • The backwards function of \(z\) requires no upstream gradient.

Autodiff Implementation

With the autodiff algorithm formalized, we will now implement it on top of NumPy. In order to represent variables (tensors) in the context of a computation graph, we extend the np.ndarray class to include a "Node" attribute. It is through these node objects that the computation graph will be constructed.

import numpy as np

class Variable(np.ndarray):
    def __new__(cls, input_array, keep_grad=False):
        # input_array is an already formed ndarray instance.
        # We first cast it to our class type.
        obj = np.asarray(input_array).view(cls)

        # Add the Node attribute to the created instance.
        obj.node = Node(keep_grad=keep_grad)

        # Finally, return the newly created object.
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return


# Creating a Variable
X = Variable(np.array([1, 2, 3, 4, 5]))

In order to create Variable instances, existing NumPy arrays are cast to the Variable type, which equips them with the Node attribute. Variable instances can either be created directly like in the example, or be the output of a function. We will need to define every single function used in the system (addition, ReLU, softmax, matmul, etc.), as well as their corresponding backwards functions.
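
For example, continuing with the Variable class above (the shapes and names here are arbitrary):

# A leaf we want gradients for (e.g. a weight matrix) versus one we don't.
W = Variable(np.random.randn(3, 3), keep_grad=True)
data = Variable(np.random.randn(3, 4))   # keep_grad defaults to False

print(W.node.keep_grad, data.node.keep_grad)   # True False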

I'll introduce the Node class and follow up by explaining each element of it.

class Node:
    def __init__(self, propagate_grad=False, keep_grad=False):

        self.grad = None
        self.keep_grad = keep_grad
        self.propagate_grad = propagate_grad

    def propagate(self):

        # Obtain a list of input gradients, corresponding to the list of inputs.
        input_grads = self.backward_fn(
            params=self.backward_fn_params, upstream=self.grad
        )

        # Accumulate grads in input_nodes.
        for input_node, input_grad in zip(self.input_nodes, input_grads):
            if input_node.grad is None:
                input_node.grad = input_grad
            else:
                input_node.grad += input_grad

        # Disconnect the node from the graph.
        if self.propagate_grad:
            del self.input_nodes
            del self.backward_fn_params
            del self.backward_fn
            del self.topo_visited
            self.propagate_grad = False

        if not self.keep_grad:
            self.grad = None

    # Connects self to the computation graph. Used by the forward method of a function/module.
    def connect(self, input_nodes, backward_fn, backward_fn_params):
        self.propagate_grad = len(input_nodes) > 0

        if self.propagate_grad:
            self.input_nodes = input_nodes
            self.backward_fn_params = backward_fn_params
            self.backward_fn = backward_fn
            self.topo_visited = False

    # The calling node must have a backward_fn that accepts None as upstream.
    # The Variable that this node represents must be a scalar.
    def backward(self):

        self.ordering = []
        toposort_nodes(self, self.ordering)

        while self.ordering:
            self.ordering.pop().propagate()
        del self.ordering

    def clear_grad(self):
        self.grad = None


def toposort_nodes(node, ordering):
    # Depth-first search over the input nodes: a node is appended only after
    # every node it depends on has been appended (post-order).
    for input_node in node.input_nodes:
        if input_node.propagate_grad and not input_node.topo_visited:
            toposort_nodes(input_node, ordering)

    node.topo_visited = True
    ordering.append(node)

For the rest of this section, assume that \(z\) is some scalar variable in the computation graph, and that our goal is to calculate the derivative of \(z\) with respect to each of its ancestors in the graph. For every variable \(X\), the .grad attribute of its Node represents \(\frac{dz}{dX}\). For now, ignore keep_grad and propagate_grad, as these are used for optimizations later on.

Whenever a new variable \(Y\) is created through a function \(f(X_1,X_2, \ldots, X_k)\), the function itself calls the connect method on the Node attribute of \(Y\). The connect method adds the nodes of \(X_1,X_2, \ldots, X_k\) to the list input_nodes, thus adding \(Y\) to the computation graph. Furthermore, the connect method equips \(Y\)'s node with the correct backward function and the parameters needed to execute that backward function, with the exception of the upstream gradient. Let's revisit the computation graph below as a concrete example, where \(C = \text{matmul}(A, B)\), along with the implementation of the matrix multiplication function.

\[ \begin{array}{ccccccccc} A& \searrow & & &&\\ B & \rightarrow & \text{matmul} & \rightarrow & C & \rightarrow & \text{sum} & \rightarrow & z\\ \end{array} \]

def matmul(A, B):
    # Compute output of module.
    output = np.matmul(A, B)

    # Create node in computation graph.
    output.node = Node()

    input_nodes = []
    backward_fn_params = {}

    if A.node.propagate_grad or A.node.keep_grad:
        input_nodes.append(A.node)
        backward_fn_params["B"] = B

    if B.node.propagate_grad or B.node.keep_grad:
        input_nodes.append(B.node)
        backward_fn_params["A"] = A

    output.node.connect(input_nodes, matmul_backward, backward_fn_params)

    return output


# Backward definition.
def matmul_backward(params, upstream=None):
    grads = []

    # dA
    if "B" in params:
        grads.append(upstream @ params["B"].T)

    # dB
    if "A" in params:
        grads.append(params["A"].T @ upstream)

    return grads

As you can see, the underlying np.matmul is used to actually calculate \(C\), and then the connect method is called in order to add \(C\) to the computation graph. That's how we will implement every function: by using its underlying NumPy implementation, while handling the autodiff bookkeeping ourselves.
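
To illustrate the pattern on a second function, here is a sketch of elementwise addition written in the same style. This is my own illustration rather than the library's actual implementation, and for simplicity it assumes both inputs are Variables of the same shape, so broadcasting is not handled.

def add(A, B):
    # Compute output of module. The result is a Variable because A and B are.
    output = np.add(A, B)

    # Create node in computation graph.
    output.node = Node()

    input_nodes = []
    backward_fn_params = {}

    if A.node.propagate_grad or A.node.keep_grad:
        input_nodes.append(A.node)
        backward_fn_params["needs_A"] = True

    if B.node.propagate_grad or B.node.keep_grad:
        input_nodes.append(B.node)
        backward_fn_params["needs_B"] = True

    output.node.connect(input_nodes, add_backward, backward_fn_params)

    return output


# Backward definition.
def add_backward(params, upstream=None):
    grads = []

    # d(A+B)/dA and d(A+B)/dB are identity maps, so the upstream gradient is
    # passed through to each input that needs it. Copies are returned so that
    # later in-place accumulation (+=) on one input's grad cannot affect another.
    if "needs_A" in params:
        grads.append(np.array(upstream))

    if "needs_B" in params:
        grads.append(np.array(upstream))

    return grads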

The propagate method is what actually calls the backwards function for a variable. Suppose that the Node of \(C\) already has \(\frac{dz}{dC}\) inside its .grad attribute, meaning its upstream gradient has been calculated. Then propagate calls the backwards function of matmul, which calculates \(G_{AC} = \frac{dz}{dC} \times B^{\mathsf{T}}\) and \(G_{BC} = A^{\mathsf{T}} \times \frac{dz}{dC}\), so that input_grads is the list \([G_{AC}, G_{BC}]\). The goal is for the .grad attributes of \(A\) and \(B\) to end up containing \(\frac{dz}{dA}\) and \(\frac{dz}{dB}\); the accumulation loop in propagate adds \(G_{AC}\) into \(A\)'s .grad and \(G_{BC}\) into \(B\)'s .grad, and since \(A\) and \(B\) have no other children here, the result is \(\frac{dz}{dA} = G_{AC}\) and \(\frac{dz}{dB} = G_{BC}\).

The backward method is called only once, upon the node of \(z\) itself. This method immediately triggers topological sort on the computation graph in order to get an ordering of \(z\)'s ancestors. By calling propagate upon each Node in the ordering one at a time, the system calculates every gradient.
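
Putting the pieces together, here is a hypothetical end-to-end run, assuming the Variable, Node, and matmul code above. The sum_all function below is my own sketch of a sum-to-scalar function in the same style as matmul, not necessarily how the real system implements summation, but it is enough to drive a full backward pass; the printed gradients should match the \(\frac{dz}{dA}\) and \(\frac{dz}{dB}\) computed by hand earlier.

def sum_all(X):
    # Forward pass: collapse X to a 0-dimensional Variable.
    output = Variable(np.asarray(X).sum())

    input_nodes = []
    backward_fn_params = {"shape": X.shape}

    if X.node.propagate_grad or X.node.keep_grad:
        input_nodes.append(X.node)

    output.node.connect(input_nodes, sum_all_backward, backward_fn_params)

    return output


def sum_all_backward(params, upstream=None):
    # z itself receives no upstream gradient, in which case dz/dz = 1.
    if upstream is None:
        upstream = 1.0

    # Every entry of X contributes once to the sum, so dz/dX = upstream * ones.
    return [upstream * np.ones(params["shape"])]


# Build the graph z = sum(matmul(A, B)) and run backpropagation.
A = Variable(np.array([[3.0, 7.0], [2.0, 5.0]]), keep_grad=True)
B = Variable(np.array([[2.0, 0.0], [0.0, 4.0]]), keep_grad=True)

C = matmul(A, B)    # connect() links C.node to A.node and B.node
z = sum_all(C)      # connect() links z.node to C.node
z.node.backward()   # topological sort, then propagate() on every node

print(A.node.grad)  # dz/dA = dz/dC @ B.T = [[2. 4.], [2. 4.]]
print(B.node.grad)  # dz/dB = A.T @ dz/dC = [[5. 5.], [12. 12.]]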

With the overall system explained, we can move on to two optimizations: the keep_grad/propagate_grad attributes, and disconnecting Nodes from the graph. As a user of the system, you set keep_grad = True for any variable whose gradient you want to keep; it defaults to False for variables whose gradients you are not interested in. Each function adds an input's node to input_nodes only if that input has keep_grad = True, or has propagate_grad = True, meaning it in turn depends on a variable with keep_grad = True; connect then marks the output with propagate_grad = True only if input_nodes is non-empty. This logic prevents the system from needlessly calculating and storing gradients. The disconnect logic inside propagate removes the Node's references to the backwards parameters, readying them for garbage collection. These backwards parameters are expensive to keep around.
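
A small illustration of this logic, with hypothetical names: here only P's gradient is requested, so matmul connects only P.node and skips saving what would be needed for Q's gradient.

P = Variable(np.random.randn(2, 2), keep_grad=True)   # gradient wanted
Q = Variable(np.random.randn(2, 2))                   # gradient not wanted

R = matmul(P, Q)
print(R.node.propagate_grad)              # True: P requires a gradient
print(len(R.node.input_nodes))            # 1: only P.node was connected
print("A" in R.node.backward_fn_params)   # False: Q's gradient will never be computed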

Conclusion

This concludes the tutorial! We started with the inductive chain rule for scalars, the theorem at the heart of backpropagation. We then saw how to represent computation graphs more compactly by allowing functions to have tensor inputs and outputs. Finally, we saw how to implement backpropagation via an autodiff system that uses backwards functions to implicitly invoke the chain rule. With this tutorial understood, head on over to the documentation website, where you can see how each function is implemented, and which introduces two new concepts: modules and optimizers for model training.