Understanding Backpropagation
Fast, efficient computation of gradients is a defining feature of modern machine learning frameworks such as PyTorch. The algorithm powering these frameworks, and the modern deep learning revolution, is known as backpropagation (or, for the more savvy, reverse-mode automatic differentiation). It’s surprisingly easy to implement, so that’s what we’re going to do in this course.
If the explanations below don’t click at first, we recommend skipping this section, reading the next chapter, and then coming back to it. Don’t spend too much time trying to understand this on the first pass!
What Problem’s Being Solved?
It would be easy to hand-wave away the purpose of this algorithm by saying that it lets us evaluate the gradients of massive functions with billions of parameters quickly and efficiently. To truly appreciate backpropagation, though, it helps to look at where other approaches fail.
Numerical Differentiation via Finite Differences
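A simple numerical approach is to take the limit definition of the derivative and drop the limit, using a small but finite step to approximate each partial derivative in the gradient:
$$\frac{\partial f}{\partial \theta_i}(\boldsymbol{\theta}) \approx \frac{f(\boldsymbol{\theta} + h\,\mathbf{e}_i) - f(\boldsymbol{\theta})}{h}$$
Here, $\boldsymbol{\theta}$ is the vector of all parameters, $\theta_i$ represents any individual parameter, and $\mathbf{e}_i$ is the $i$-th standard basis vector. With this method, we need to be careful when selecting a value for $h$. It must be small enough to give a good approximation (so that the optimization can eventually settle at a minimum), but not so small that we start running into floating-point precision errors.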
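The step-size trade-off can be seen in a minimal one-parameter sketch in Python. It uses `math.sin` as a stand-in for $f$ because its exact derivative, $\cos$, is known; the helper `forward_difference` and the list of step sizes are illustrative choices, not part of the course code.

```python
import math

def forward_difference(f, x, h):
    """Approximate f'(x) with the forward difference (f(x + h) - f(x)) / h."""
    return (f(x + h) - f(x)) / h

# Stand-in function with a known derivative: d/dx sin(x) = cos(x).
x = 1.0
exact = math.cos(x)

errors = {}
for h in [1e-1, 1e-4, 1e-8, 1e-12, 1e-15]:
    approx = forward_difference(math.sin, x, h)
    errors[h] = abs(approx - exact)  # distance from the true derivative
    print(f"h={h:.0e}  approx={approx:.12f}  error={errors[h]:.2e}")
```

Expect the error to shrink as `h` goes from `1e-1` toward roughly `1e-8`, then grow again for the smallest steps, where `f(x + h)` and `f(x)` agree in almost every digit and the subtraction throws away most of the remaining precision.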
While this approach is straightforward, it has a critical flaw: computational cost. For every single partial derivative we want to compute, we need another evaluation of the function to calculate $f(\boldsymbol{\theta} + h\,\mathbf{e}_i)$ (the function with only the $i$-th parameter nudged by $h$). This means that as the number of parameters grows, the number of times $f$ needs to be evaluated grows linearly with it. Recall that a single evaluation of $f$ can already be computationally expensive due to its size.
Modern neural networks can be thought of as functions with millions, billions, or even trillions of parameters. This approach breaks down long before we reach such scales.
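The linear cost can be made concrete by counting function calls. In this sketch, `finite_difference_gradient`, the call-counting wrapper `CountingFunction`, and the toy objective `sum_of_squares` are all hypothetical names chosen for illustration; the gradient uses the forward-difference formula, reusing $f(\boldsymbol{\theta})$ for every partial derivative.

```python
def finite_difference_gradient(f, theta, h=1e-6):
    """Forward-difference estimate of the gradient of f at theta (a list of floats)."""
    base = f(theta)                  # f(theta), evaluated once and reused
    grad = []
    for i in range(len(theta)):
        shifted = list(theta)        # theta + h * e_i
        shifted[i] += h
        grad.append((f(shifted) - base) / h)  # one extra evaluation per parameter
    return grad

class CountingFunction:
    """Wraps a function and counts how many times it is evaluated."""
    def __init__(self, fn):
        self.fn = fn
        self.calls = 0

    def __call__(self, theta):
        self.calls += 1
        return self.fn(theta)

def sum_of_squares(theta):
    """Toy objective whose exact gradient is 2 * theta."""
    return sum(t * t for t in theta)

evaluations = {}
for n in [10, 100, 1000]:
    f = CountingFunction(sum_of_squares)
    grad = finite_difference_gradient(f, [1.0] * n)
    evaluations[n] = f.calls
    print(f"{n} parameters -> {f.calls} evaluations of f")
# 10 parameters -> 11 evaluations of f
# 100 parameters -> 101 evaluations of f
# 1000 parameters -> 1001 evaluations of f
```

With $n$ parameters this always costs $n + 1$ evaluations of $f$, so a network with a billion parameters would need about a billion forward passes for a single gradient.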
Symbolic Differentiation
Another technique we could try is known as symbolic differentiation. This is an automated version of manual differentiation, and it’s done by using hard-coded mathematical rules to derive analytical expressions for the derivative.
Although this approach avoids the step-size and floating-point precision issues and allows us to compute the gradient in one pass, it suffers from a problem known as expression swell. When dealing with deeply nested compositions of functions, such as neural networks, the derivative can become exponentially larger than the original function. This explosion in size happens because derivative rules, particularly the chain rule and the product rule, copy subexpressions: $(uv)' = u'v + uv'$ repeats both $u$ and $v$, and $\frac{d}{dx}\,g(u) = g'(u)\,u'$ repeats $u$.
Storing the massive expressions that result from symbolic differentiation takes up an enormous amount of memory, and evaluating them is extremely computationally expensive. It can even be slower than numerical differentiation for large expressions.
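Expression swell can be observed with a toy symbolic differentiator. This is a minimal sketch, assuming a made-up representation where expressions are nested tuples such as `("sin", ("x",))`; `diff`, `size`, and `nested_sin` are hypothetical helpers, not a real computer algebra system, and no simplification is attempted. For a pure chain of `sin` calls the growth is quadratic rather than exponential; the faster blow-up described above needs compositions that reuse subexpressions, which this sketch doesn't attempt.

```python
def diff(e):
    """Differentiate expression e with respect to x using textbook rules, without simplifying."""
    op = e[0]
    if op == "x":
        return ("const", 1)
    if op == "const":
        return ("const", 0)
    if op == "add":
        return ("add", diff(e[1]), diff(e[2]))
    if op == "mul":
        # Product rule: (uv)' = u'v + uv', which copies both u and v.
        u, v = e[1], e[2]
        return ("add", ("mul", diff(u), v), ("mul", u, diff(v)))
    if op == "sin":
        # Chain rule: sin(u)' = cos(u) * u', which copies u.
        u = e[1]
        return ("mul", ("cos", u), diff(u))
    if op == "cos":
        # Chain rule: cos(u)' = -sin(u) * u'.
        u = e[1]
        return ("mul", ("mul", ("const", -1), ("sin", u)), diff(u))
    raise ValueError(f"unknown operation: {op}")

def size(e):
    """Count the nodes in the expression as if it were written out in full."""
    return 1 + sum(size(child) for child in e[1:] if isinstance(child, tuple))

def nested_sin(depth):
    """Build sin(sin(...sin(x)...)) with the given nesting depth."""
    e = ("x",)
    for _ in range(depth):
        e = ("sin", e)
    return e

sizes = {}
for depth in [1, 10, 100]:
    f = nested_sin(depth)
    sizes[depth] = (size(f), size(diff(f)))
    print(f"depth {depth}: f has {sizes[depth][0]} nodes, f' has {sizes[depth][1]} nodes")
# depth 1: f has 2 nodes, f' has 4 nodes
# depth 10: f has 11 nodes, f' has 76 nodes
# depth 100: f has 101 nodes, f' has 5251 nodes
```

Python actually shares the copied tuples in memory, so `size` counts what a symbolic tool would have to write out and evaluate. Reusing shared intermediate results instead of expanding them is exactly the kind of saving backpropagation exploits.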
Backpropagation
The key insight into backpropagation comes from recognizing two things. First, complex functions are composed of many simpler functions. This structure forms a natural sequence of dependencies that can be exploited using the chain rule.
This is much easier to understand if we take a look at a specific example such as . We can see that gets exponentiated, and also multiplied by to get us and respectively. We can do this with every operation and create a graph of dependencies as can be seen in the widget below.
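The dependency-graph idea can be sketched as a tiny reverse-mode automatic differentiation engine in Python. `Value`, `backward`, and the example function $f(x, y) = e^x + xy$ are hypothetical choices for illustration (this is not PyTorch's API, and not necessarily the example used above). Each node stores the nodes it was computed from and its local derivative with respect to each; `backward` visits the graph in reverse topological order and multiplies local derivatives along every path, which is the chain rule applied mechanically.

```python
import math

class Value:
    """A scalar node in a computational graph that records how it was produced."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # nodes this value was computed from
        self._local_grads = local_grads  # d(self)/d(parent) for each parent

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def exp(self):
        e = math.exp(self.data)
        return Value(e, (self,), (e,))  # d/dx e^x = e^x

    def backward(self):
        # Topological order: every node appears after the nodes it depends on.
        order, visited = [], set()

        def visit(node):
            if id(node) not in visited:
                visited.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)

        # Walk the graph from the output back to the inputs, applying the chain rule.
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in zip(node._parents, node._local_grads):
                parent.grad += node.grad * local  # += sums contributions from every path

x = Value(2.0)
y = Value(3.0)
f = x.exp() + x * y   # x feeds into two operations
f.backward()
print(x.grad, y.grad)  # df/dx = exp(x) + y, df/dy = x
```

Because `x` feeds into two operations, its gradient receives one contribution per path (`exp(x)` through the exponential and `y` through the product), which is why `backward` accumulates with `+=` rather than overwriting. Each node is visited once, so the backward pass costs roughly as much as building the graph, regardless of how many inputs need gradients.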