Backpropagation & Gradient Descent
There are many approaches to making our machines learn. If you want to understand how current Large Language Models (LLMs) are trained, one of the important things you need to understand, at least to some degree, is backpropagation and gradient descent. These techniques are highly successful in optimizing neural networks and have been the driving force behind a lot of the "magic" we see today.
This article is my attempt to explain backpropagation and gradient descent in easy terms. I'll start with an explanation and then give simplified examples to introduce basic ideas, play around with terminology and explore the topic a bit further – all without having to do super complicated math things. The aim is to develop some basic understanding of the concepts involved. I found it really helps with deciphering other bits of AI discourse.
- An impossible challenge: Dog!
- Finding direction from randomness
- A greatly simplified example
- Problem 1: Overshooting
- Problem 2: Getting stuck at a local optimum/minimum
- Problem 3: Uninterpretability
- Problem 4: Human Extinction
- Summary: Terminology
- Final words
An impossible challenge: Dog!
Imagine you are in this situation:
- You have a goal: There is something you want to achieve with your neural network; let's say you want it to say "Dog!" when you show it pictures of a dog.
- You can measure success: You can tell how good of a job your network is doing at saying "Dog!" which allows you to compare performance for different configurations of your network.
- You have no idea how to make the network recognize dogs: Unfortunately, you have no way of knowing how to configure the network so that it is able to recognize dogs. Even with a supercomputer, if you just tried all possible configurations and checked which ones can recognize dogs, you would most likely still be sitting there waiting when the universe dies.
Finding direction from randomness
The brilliant idea behind backpropagation is to not even try to configure the network correctly or to understand how that might work. You just begin with any random configuration of the network's nodes. Let's call that config-1 – and voilà, you have built a network that is really bad at recognizing dogs. Next, you fiddle with some of the values in your configuration and test how well the network would do with these variations. Then you update config-1 to the variation that worked best. This is config-2. Starting from config-2, you then test new variations for the network parameters – but now you don't just do it randomly, you keep exploring in the direction of the changes you made in the last step; e.g., if you lowered a value, you prioritize lowering that value further. You are following a vector or "the gradient". Once you arrive at config-3 you repeat the process again. And again.
In practice, how many parameters you fiddle with before checking how well the network can say "Dog!", by how much you change the parameters at each step, and how much influence past steps have on new steps are things the clever math people have to figure out. But we don't need to understand the algorithms involved to get a better understanding of the overall process. Let's look at the same example in more detail:
- Start with a random network configuration: We start with a random network configuration, config-1, meaning all properties are just set to random numbers.
- Vary the network's properties randomly: To get to config-2 you change millions of random parameters by small increments in random directions, and you do that A LOT of times. That gives you A LOT OF variations of config-1, each with millions of random changes.
- Test the variations: Now you check for each variation how well it would recognize dogs. This is complicated. Luckily, for our purposes it is enough to accept that the clever math people can develop tests for this. What we are really interested in is this: Some of the variations will do better than others. And one of the variations is the best. In techno-babble: One set of parameters minimizes the loss function the most (in an example below we will start using a real loss function; don't worry for now).
- Update the network: Now you take a step, meaning you update the network's properties to the variation achieving the best result. We are now at config-2.
- Vary the network's properties again, but follow the gradient we discovered: We now create variations of config-2, only now the changes are not entirely random anymore. We emphasize the direction of the changes we took in the step before. We follow the gradient down – down, because we are minimizing a loss function; if we were maximizing a reward function, we'd be doing gradient ascent. How this is done exactly is (again) for the clever math people to figure out; for our understanding here, it is enough that we are zeroing in on the direction of steepest descent (to clarify the terminology: strictly speaking, the gradient points in the direction of steepest ascent, and descending means stepping in the opposite direction).
In summary: Backpropagation and gradient descent work by iteratively adjusting the model parameters to minimize the loss function as efficiently as possible. We give the model a measurable goal and we give it the algorithms it needs to figure out the fastest way to the desired state on its own. After enough iterations, our network says "Dog".
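If it helps to see the idea as code, here is a very rough Python sketch of that loop. To be clear, this is not how backpropagation is actually implemented – real training computes the gradient analytically instead of trying random variations – and every name and number in it is made up for illustration; it just mirrors the story above.

```python
import random

# Very rough sketch of the loop described above. This is NOT how backpropagation
# is actually implemented (real training computes gradients analytically instead
# of trying random variations); it just mirrors the story: try variations, keep
# the best one, and bias the next variations toward the last successful change.
def loss(params, target):
    # squared distance between the current parameters and the ones we want
    return sum((p - t) ** 2 for p, t in zip(params, target))

def train(params, target, rounds=200, candidates=20, step=0.1):
    momentum = [0.0] * len(params)          # direction of the last accepted step
    for _ in range(rounds):
        best, best_loss = None, loss(params, target)
        for _ in range(candidates):
            trial = [p + random.uniform(-step, step) + m
                     for p, m in zip(params, momentum)]
            trial_loss = loss(trial, target)
            if trial_loss < best_loss:
                best, best_loss = trial, trial_loss
        if best is not None:
            momentum = [b - p for b, p in zip(best, params)]  # remember direction
            params = best
        else:
            momentum = [0.0] * len(params)  # no improvement found: stop pushing
    return params

print(train([random.uniform(0, 1) for _ in range(3)], [0.2, 0.7, 0.4]))
```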
A greatly simplified example
Let's say we have a really simple system: it consists of one variable x that can take the value of one natural number from 0 to 15. Our goal is that the number is 10. We pick a random number for config-1, let's say 4. Below is our solution space, meaning all possible states our system could be in; config-1 and our goal are in bold:
0 - 1 - 2 - 3 - **4** - 5 - 6 - 7 - 8 - 9 - **10** - 11 - 12 - 13 - 14 - 15
Our step size, or learning rate, is 1, meaning we can either go one step to the left or one step to the right. We are bad at math, so we have no idea which direction is the right one. Luckily, we can use a ridiculously simplified version of backpropagation and gradient descent!
Our config-1 is configured as x=4. We want it to be at 10. Let's use a proper loss function in this example: f(x)=(goal-x)^2
The expression "goal-x" gives us the distance to the target; we square it, so the result is always positive. If we plug config-1 (x=4) into the loss function, we get (10-4)^2=36. Now let's try to find our way to somewhere where the loss is lower than 36.
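As a tiny sketch in code, using the example's numbers:

```python
# The toy loss function from the example: squared distance to the goal of 10.
def loss(x, goal=10):
    return (goal - x) ** 2

print(loss(4))  # config-1: (10 - 4)**2 = 36
```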
As the system is one-dimensional we only have two possible directions to explore. If we took a step to the left, the state of our system would be x=3 and the loss would be (10-3)^2 = 49. That's more than 36. So that must be the wrong direction! Let's take a step from config-1 to the right. The state of our system would be x=5, and the loss function would amount to (10-5)^2 = 25. Yes! 25 is lower than 36, we found a way to descend. As there are no other directions left to check in this system, we also found the steepest way to descend, the gradient. Therefore, we take the step to the right and update our model to config-2 (x=5).
0 - 1 - 2 - 3 - 4 - **5** - 6 - 7 - 8 - 9 - **10** - 11 - 12 - 13 - 14 - 15
From here, we look for the next step. This time, there is no need to experiment with taking a step to the left. The gradient was going to the right. So, let's continue in this direction and see what would happen. The loss function would be (10-6)^2=16. Great, 16 is smaller than 25. Let's take the step and update our state to config-3 and value 6.
0 - 1 - 2 - 3 - 4 - 5 - **6** - 7 - 8 - 9 - **10** - 11 - 12 - 13 - 14 - 15
This is huge! There was no need to go left, run the loss function, compare the result to the loss function of our current state at config-2, then simulate a step to the right and compare results again. We could just go to the right and perform one check, because we were following the vector of descent we identified before. If you imagine that process in a network with billions of nodes connected by weights expressed in floating point numbers, it becomes clear that a lot of compute can be saved that way. Let's end this example here; the way I set it up, the process will keep going to the right by 1 each step, minimizing the loss function until the system reaches x=10, the desired goal. Let's look at some potential problems with the whole approach now.
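For the curious, here is a small Python sketch of this whole toy process, assuming the same loss function, starting value and learning rate as above; the stopping rule is a simplification of my own:

```python
def loss(x, goal=10):
    return (goal - x) ** 2      # same toy loss function as above

x, step = 4, 1                  # config-1 and the learning rate of 1
direction = None                # we don't know which way is downhill yet
while loss(x) > 0:
    if direction is None:
        # first step: check both neighbours and pick the one with the lower loss
        direction = -step if loss(x - step) < loss(x + step) else step
    if loss(x + direction) < loss(x):
        x += direction          # follow the gradient: only one check per step
    else:
        break                   # this direction no longer improves things: stop
print(x, loss(x))               # 10 0 - we reached the goal
```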
Problem 1: Overshooting
There would be another way of saving compute and getting to our solution faster: increasing our learning rate, meaning the size of the steps we take. In our simple system above, we needed 6 steps to get from 4 to 10. That sounds pretty wasteful! Let's say we are in a rush and increase our learning rate to 2. So we go from x=4 to 6 to 8 and to 10. Now we arrive in just 3 steps! Great.
But, imagine we randomize our starting conditions again and config-1 starts at x=5 this time. We'd go to x=7, to 9 and then to 11. Uh, oh… when we had arrived at x=9, the loss function returned 1. We were so close! But with the next step we arrive at x=11; we overshot our goal, and the loss function returns 1 again. If we continue to the right, things get even worse: a step to x=13 and the loss function returns 9! There is nowhere to go. We will never arrive at the optimal solution. The best we can do is to remain at x=11 and minimize our loss function to 1. Our neural network might forever be unable to tell dogs from cats. Or think nuclear strikes are funny.
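In the sketch from before, the same behavior shows up if we set the step size to 2; the acceptance rule (keep stepping as long as the loss does not get worse) is again a simplification of my own to mirror the story above:

```python
def loss(x, goal=10):
    return (goal - x) ** 2

def walk(x, step):
    # keep stepping to the right as long as the loss does not get worse
    path = [x]
    while loss(x) > 0 and loss(x + step) <= loss(x):
        x += step
        path.append(x)
    return path, loss(x)

print(walk(4, 2))  # ([4, 6, 8, 10], 0) - lucky start: we hit the goal exactly
print(walk(5, 2))  # ([5, 7, 9, 11], 1) - we overshoot and get stuck at loss 1
```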
Problem 2: Getting stuck at a local optimum/minimum
We start again with config-1 at x=4 and choose a learning rate of 1. The desired state for our system is still x=10. Though, to understand this problem, we need to modify our solution space a little bit by adding another 9 between 6 and 8 (please ignore that this makes no sense mathematically, just follow what the system does):
0 - 1 - 2 - 3 - **4** - 5 - 6 - 9 - 8 - 9 - **10** - 11 - 12 - 13 - 14 - 15
This starts off like our first example. We quickly realize our vector is to the right and we start sliding down the gradient. After 2 steps, this is config-3:
0 - 1 - 2 - 3 - 4 - 5 - **6** - 9 - 8 - 9 - **10** - 11 - 12 - 13 - 14 - 15
The loss function is down to 16. Things are looking great. We continue to the right. And arrive at config-4:
0 - 1 - 2 - 3 - 4 - 5 - 6 - **9** - 8 - 9 - **10** - 11 - 12 - 13 - 14 - 15
Wow. The loss function dropped all the way to (10-9)^2=1, we must be really close. We simulate another step to the right and...
…the loss function goes up to (10-8)^2=4. That isn't smaller than 1! We are not at the optimal solution, but we cannot get out of the dip in the loss function with our learning rate. We are trapped in a so-called local optimum (or local minimum). If we plot out the loss function for all possible states in our solution space, we can even see the sad little machine learning expert trapped in the minimum.
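The same trap can be sketched in code by walking over the shoehorned solution space as a plain list (again, a simplification of my own):

```python
def loss(value, goal=10):
    return (goal - value) ** 2

# the modified solution space from above: a 9 is shoehorned in between 6 and 8
space = [0, 1, 2, 3, 4, 5, 6, 9, 8, 9, 10, 11, 12, 13, 14, 15]

pos = space.index(4)                 # config-1 sits on the value 4
while pos + 1 < len(space) and loss(space[pos + 1]) < loss(space[pos]):
    pos += 1                         # slide right while the loss keeps dropping
print(space[pos], loss(space[pos]))  # 9 1 - stuck in the local minimum
```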
Of course, the clever math people have all kinds of solutions for this. We could experiment with different learning rates. Or we could reset the system with a new random initial config-1 – maybe we'd avoid getting trapped that way, but for this example, we don't need to look into this any further.
There are some fundamental lessons here: High learning rates can get to the goal faster and save compute, but they risk overshooting. Lower learning rates make it less likely to overshoot, but need more steps (compute) and have a higher risk of getting stuck in local minima. Another lesson is that it is OK to shoehorn a 9 between a 6 and an 8, if it helps illustrate a point.
Problem 3: Uninterpretability
Optimization methods like backpropagation and gradient descent provide a big upside that comes with a big downside: They allow us to find solutions for problems that are so complicated we wouldn't know how to solve them by thinking them through. The downside is that we don't understand the solutions. We don't know how they work, we just know that they work. That was one of the first really big surprises for me when I started to look into how neural nets function – we really do not know. There is a subfield of AI research, called Interpretability, that tries to understand what is going on inside these systems, but progress is slow. Not understanding how the increasingly capable systems we build actually work is a good segue to the next potential issue...
Problem 4: Human Extinction
Yes, this is a bit clickbait-y, but many problems related to AI alignment are directly related to how we use optimization algorithms to essentially trade interpretability for capability. One of the core issues is this: We want to optimize the system so its output is as close as possible to our goals – but for that, we need to be able to specify our goals, measure the output, and reliably assess how close the outcome comes to what we really want, as opposed to all the other outcomes that would also produce a positive training signal. In other words: If you tell a powerful AI system to reduce unemployment, better keep it away from guns or at least make sure you know how to prevent it from getting certain ideas. One way to help would be the ability to know what is going on inside our systems. But currently, we don't have that ability. If you want to see examples of misaligned AI systems, this video has some really interesting ones.
Summary: Terminology
Learning about backpropagation and gradient descent taught me a lot more AI jargon, which is actually useful when trying to understand what people on the interwebs have to say. Here is a quick summary of terminology:
- Backpropagation & gradient descent: Test which variations work best, then take small steps in that direction until it stops being the best. Stop when you can find no direction that is better than where you are now. In other words: Take iterative steps down the gradient (direction of steepest descent) until you have minimized the loss function.
- Loss function: A function that tells you how big the difference between the predicted output and the desired output is.
- Gradient ascent: The same as gradient descent, but we use a reward function that we want to go up instead of a loss function we want to go down. Same thing, but the other way around.
- Learning rate: Determines the step size at which the model parameters are updated during gradient descent.
- Local optimum/minimum: A point in the solution space where the loss function value is locally smaller than all the neighboring variations that can be reached with the current learning rate, but not necessarily the best possible solution overall.
- Overshooting: Overshooting is a potential issue when the learning rate is too large. The algorithm takes steps that are too big, causing it to "overshoot" the optimal solution.
- Training signal: Your measure of success, telling the neural network how good of a job a step would do. Training signals can be produced in a great many ways, e.g. from techniques involving human feedback to fully automated training systems where an algorithm performs some kind of check on the model's output vs. the given goals.
Final words
It's really fascinating to me how good of a backdrop even this relatively shallow understanding of backpropagation and gradient descent provides when following AI discourse. RLHF? What a clever way to provide a continuous training signal to drive backpropagation & gradient descent. An AI agent misbehaving in a computer game? Clearly, the AI found behavior that is more effective at minimizing the loss function than the behavior we actually wanted to elicit. How similar is ChatGPT to human cognition? Who knows, but it was built using an optimization method drastically different from evolution's way of optimizing things.
There is lots more to learn. SGD, Momentum, Adagrad, RMSprop, Adaptive Moment Estimation – I have no idea how any of these things work in detail, but I know they are mathematical techniques to minimize the loss function. Often, that will be enough to understand what a conversation or article talks about. The basic idea of backpropagation and gradient descent provides a lot of context to otherwise opaque content. So, if you, dear reader, are like me – an outsider looking in, trying to understand what those AI people are doing – I hope you'll find this article as useful as I found writing it.