Why Backprop Goes Backward

Backpropagation is an algorithm that computes the gradient of a neural network with respect to its weights, but it may not be obvious why the algorithm uses a backward pass. The answer lets us reconstruct backprop from first principles.

The usual explanation of backpropagation (Rumelhart et al., 1986), the algorithm used to train neural networks, is that it propagates errors backwards through the network, node by node. But when I first learned about the algorithm, I had a question that I could not find answered directly: why does it have to go backwards? A neural network is just a composite function, and we know how to compute the derivatives of composite functions using the chain rule. Why don’t we just compute the gradient in a forward pass? I found that answering this question strengthened my understanding of backprop.

I will assume the reader broadly understands neural networks and gradient descent and even has some familiarity with backprop. I’ll first set up backprop with some useful concepts and notation and then explain why a forward-propagation algorithm is suboptimal.

Setup

Recall that the goal of backprop is to efficiently compute $\partial f / \partial \theta_i$ for every weight $\theta_i$ in a neural network $f$. To frame the problem, let’s reason about an arbitrary weight $\theta_1$ and node $v$ somewhere in $f$:

To be clear, the node $v$ refers to the output value of the node after passing the weighted sum of its inputs through an activation function $\sigma$, i.e.:

$$
\begin{aligned}
u &= \theta_1 t_1 + \theta_2 t_2 + \dots + \theta_n t_n \\
v &= \sigma(u)
\end{aligned}
$$
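To make this concrete, here is a minimal Python sketch of the forward computation at a single node. The sigmoid activation and the names `sigma`, `node_forward`, `theta`, and `t` are my own assumptions for illustration, not part of any particular framework:

```python
import numpy as np

def sigma(u):
    # Assumed activation for this sketch: the logistic sigmoid.
    return 1.0 / (1.0 + np.exp(-u))

def node_forward(theta, t):
    """Forward computation at a single node: weighted sum of inputs, then activation."""
    u = np.dot(theta, t)  # u = theta_1 t_1 + theta_2 t_2 + ... + theta_n t_n
    v = sigma(u)          # v = sigma(u)
    return u, v
```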

Note that in a typical diagram, $u$, $\sigma$, and $v$ would all be a single node, denoted by the dashed line. In my mind, the most important observation needed to understand backprop is this: most of computing $\partial f / \partial \theta_1$ can be done locally at every node because of the chain rule:

$$
\frac{\partial f}{\partial \theta_1} = \frac{\partial f}{\partial v} \frac{\partial v}{\partial u} \frac{\partial u}{\partial \theta_1}
$$

We can compute $\partial v / \partial u$ analytically; it just depends on the definition of $\sigma$. And we know that $\partial u / \partial \theta_1 = t_1$. So at every node $v$, if we knew $\partial f / \partial v$, we could compute $\partial f / \partial \theta_1$.
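In code, the local pieces look like this. This is a sketch continuing the one above, reusing `sigma` and the `numpy` import; `df_dv` stands for the quantity $\partial f / \partial v$ that we do not yet know how to obtain:

```python
def node_weight_grads(theta, t, df_dv):
    """Given df/dv, compute df/dtheta_i for every weight of this node via the chain rule."""
    u = np.dot(theta, t)
    v = sigma(u)
    dv_du = v * (1.0 - v)             # sigmoid derivative, known analytically
    du_dtheta = np.asarray(t)         # du/dtheta_i = t_i
    return df_dv * dv_du * du_dtheta  # df/dtheta_i = (df/dv)(dv/du)(du/dtheta_i)
```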

The challenge with computing $\partial f / \partial v$ is that downstream nodes depend on the value of $v$. Thankfully, the multivariable chain rule has the answer. Given a multivariable function $g(w_1, w_2, \dots, w_m)$ in which each $w_i$ is a single-variable function $w_i(v)$, the multivariable chain rule says:

$$
\frac{\partial g}{\partial v} = \frac{\partial}{\partial v} g(w_1(v), w_2(v), \dots, w_m(v)) = \sum_{j} \frac{\partial g}{\partial w_j} \frac{\partial w_j}{\partial v}
$$
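As a quick sanity check, take my own toy example $g(w_1, w_2) = w_1 w_2$ with $w_1(v) = v^2$ and $w_2(v) = \sin v$, so that $g = v^2 \sin v$:

$$
\frac{\partial g}{\partial v} = \frac{\partial g}{\partial w_1}\frac{\partial w_1}{\partial v} + \frac{\partial g}{\partial w_2}\frac{\partial w_2}{\partial v} = w_2 \cdot 2v + w_1 \cdot \cos v = 2v \sin v + v^2 \cos v
$$

which matches differentiating $v^2 \sin v$ directly with the product rule.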

So we can compute $\partial f / \partial \theta_i$ for any weight $\theta_i$, meaning we have the necessary machinery to attempt to implement backprop in a forward rather than backward pass. Let’s see what happens.

Repeated terms

We want a forward-propagating algorithm that can compute the partial derivative $\partial f / \partial \theta_i$ for an arbitrary weight $\theta_i$. We showed above that at node $v$, this is equivalent to:

$$
\frac{\partial f}{\partial \theta_i} = \frac{\partial f}{\partial v} \frac{\partial v}{\partial \theta_i}
$$

Note that I’ve dropped the intermediate variable $u$ for ease of notation. To design our forward-propagating algorithm, let’s formalize an important fact: in a directed computational graph in which node $b$ depends upon node $a$, it is impossible to compute $\partial b / \partial a$ at any point before node $b$:

This claim should be obvious. If our computational graph represents a function $f(a) = b$, it is impossible to compute $f^{\prime}(a)$ without access to $f$ and therefore $b$.

In our setup, for every downstream node $w_j$ that depends on a node $v$, it is impossible to compute $\partial w_j / \partial v$ at node $v$. Therefore, in order to compute $\partial f / \partial v$, we must decompose the term using the multivariable chain rule and pass the other terms needed to compute $\partial f / \partial \theta_i$ forward to each node $w_j$ that depends on $v$:

$$
\frac{\partial f}{\partial \theta_i} = \Big( \sum_{j} \frac{\partial f}{\partial w_j} \underbrace{\frac{\partial w_j}{\partial v}}_{\text{Compute on $w_j$}} \Big) \overbrace{\frac{\partial v}{\partial \theta_i}}^{\text{Pass forward}}
$$

We can see that such an algorithm blows up computationally because we’re forward propagating the same message many times over. For example, if we want to compute $\partial f / \partial \theta_i$ and $\partial f / \partial \theta_k$ where $\theta_i$ and $\theta_k$ are different weights in the same layer, we need to compute $\partial v / \partial \theta_i$ and $\partial v / \partial \theta_k$ separately, but all the other terms are repeated:

$$
\begin{aligned}
\frac{\partial f}{\partial \theta_i} &= \overbrace{\Big( \sum_{j} \Big( \sum_{l} \frac{\partial f}{\partial z_l} \frac{\partial z_l}{\partial w_j} \Big) \frac{\partial w_j}{\partial v} \Big)}^{\text{Repeated terms}} \color{#11accd}{\frac{\partial v}{\partial \theta_i}} \\
\frac{\partial f}{\partial \theta_k} &= \Big( \sum_{j} \Big( \sum_{l} \frac{\partial f}{\partial z_l} \frac{\partial z_l}{\partial w_j} \Big) \frac{\partial w_j}{\partial v} \Big) \color{#bc2612}{\frac{\partial v}{\partial \theta_k}}
\end{aligned}
$$
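The two equations above differ only in their final factor. To see the blow-up concretely, here is a sketch of this forward-propagating approach (forward-mode differentiation) on a toy two-layer network. The network, its shapes, and the function names are my own assumptions for illustration, and it reuses `sigma` and the `numpy` import from the earlier sketch:

```python
def net_jvp(theta, dtheta, x):
    """One forward sweep of a toy two-layer net that carries both the value and its
    derivative along the direction dtheta (a forward-propagating pass)."""
    W1, dW1 = theta[:4].reshape(2, 2), dtheta[:4].reshape(2, 2)
    W2, dW2 = theta[4:].reshape(1, 2), dtheta[4:].reshape(1, 2)
    h = sigma(W1 @ x)
    dh = h * (1 - h) * (dW1 @ x)   # derivative of the hidden nodes along dtheta
    y = W2 @ h
    dy = dW2 @ h + W2 @ dh         # product rule at the output node
    return float(y), float(dy)

def forward_mode_gradient(theta, x):
    """Gradient via forward propagation: one full sweep of the network per weight,
    so with W weights and N nodes the cost is O(W * N) rather than O(N)."""
    grads = []
    for i in range(len(theta)):
        e_i = np.zeros_like(theta)
        e_i[i] = 1.0                      # seed the i-th weight direction
        _, dy_i = net_jvp(theta, e_i, x)  # the same downstream terms are recomputed
        grads.append(dy_i)                # on every one of these sweeps
    return np.array(grads)
```

Every call to `forward_mode_gradient` touches every node once per weight, which is exactly the repeated work highlighted in the equations above.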

Here is a diagram of the repeated terms being passed forward as messages:

I think the above diagram is the linchpin in understanding why backprop goes backwards. This is the key insight: if we already had access to downstream terms, for example $\partial w_j / \partial v$, then we could pass those terms backwards as messages to node $v$ in order to compute $\partial f / \partial v$. Since each node just passes its own local term, the backward pass can be done in linear time with respect to the number of nodes.

A backward pass

I hope this explanation clarifies how you might arrive at backprop from first principles when trying to compute derivatives in a directed acyclic graph. At a given node $b$ that depends on a node $a$, we simply pass the message $\frac{\partial f}{\partial b} \frac{\partial b}{\partial a}$ back to $a$. The multivariable chain rule helps prove the correctness of backprop: for any node $v$ with downstream nodes $w_j$, if $v$ simply sums the backward-propagating messages, it computes its desired derivative:

$$
\frac{\partial f}{\partial v} = \sum_{j} \frac{\partial f}{\partial w_j} \frac{\partial w_j}{\partial v}
$$
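For contrast, here is the reverse-mode counterpart for the same toy network from the earlier sketch: a single forward pass followed by a single backward pass of messages recovers every $\partial f / \partial \theta_i$ at once (again my own illustrative code, reusing `sigma` and `numpy`):

```python
def backprop_gradient(theta, x):
    """Backprop on the toy two-layer net: one forward pass, one backward pass,
    and every df/dtheta_i falls out -- O(N) work total."""
    W1, W2 = theta[:4].reshape(2, 2), theta[4:].reshape(1, 2)
    h = sigma(W1 @ x)                 # forward pass
    y = float(W2 @ h)

    df_dy = 1.0                       # seed: derivative of f with respect to its output
    df_dW2 = df_dy * h                # local gradient for the output weights
    df_dh = df_dy * W2.ravel()        # message passed back to the hidden nodes
    df_du = df_dh * h * (1 - h)       # chain through the sigmoid at each hidden node
    df_dW1 = np.outer(df_du, x)       # local gradients for the first-layer weights

    return np.concatenate([df_dW1.ravel(), df_dW2.ravel()])
```

On the toy network, `backprop_gradient(theta, x)` agrees with `forward_mode_gradient(theta, x)` above, but it does the work of a single sweep rather than one sweep per weight.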

Once you understand the main computational problem backprop solves, I think the standard explanation of backpropagating errors makes much more sense. The process can be viewed as a solution to a kind of credit assignment problem: each node tells its upstream neighbors what they did wrong. But the reason the algorithm works this way is that a naive, forward-propagating solution would have quadratic runtime in the number of nodes.

  1. Rumelhart, D. E., Hinton, G. E., & Williams, R. J. (1986). Learning representations by back-propagating errors. Nature, 323(6088), 533.