Back Propagation in Convolutional Neural Networks — Intuition and Code
Disclaimer: If you are not familiar with how back propagation operates on a computational graph, I recommend having a look at this lecture from the famous cs231n course.
I scratched my head for a long time wondering how the back propagation algorithm works for convolutions, and I could not find a simple and intuitive explanation of it online. So I decided to write one myself. Hope you enjoy!
Why Understand Back Propagation?
Andrej Karpathy wrote in his blog about the need to understand back propagation, calling it a Leaky Abstraction:
‘‘it is easy to fall into the trap of abstracting away the learning process — believing that you can simply stack arbitrary layers together and backprop will “magically make them work” on your data’’
The Chain Rule
The following figure summarises the use of the chain rule for the backward pass in computational graphs.
Here is another illustration, which focuses on the local gradients.
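As a tiny concrete example in the spirit of that lecture (the specific numbers are my own illustration), consider f = (x + y) * z. During the backward pass, every node simply multiplies the upstream gradient by its local gradient:

```python
# Forward pass for f = (x + y) * z
x, y, z = -2.0, 5.0, -4.0
q = x + y          # intermediate node, q = 3
f = q * z          # output, f = -12

# Backward pass: upstream gradient times local gradient at every node
df = 1.0           # gradient of f with respect to itself
dq = z * df        # local gradient of f w.r.t. q is z
dz = q * df        # local gradient of f w.r.t. z is q
dx = 1.0 * dq      # local gradient of q w.r.t. x is 1
dy = 1.0 * dq      # local gradient of q w.r.t. y is 1

print(dx, dy, dz)  # -4.0 -4.0 3.0
```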
If you understand the chain rule, you are good to go.
Let’s Begin
We will try to understand how the backward pass works for a single convolutional layer by taking a simple case where the number of channels is one across all computations. We will also dive into the code later.
The following convolution operation takes an input X of size 3x3 and convolves it with a single filter W of size 2x2, with no padding and stride = 1, generating an output H of size 2x2. Also note that, while performing the forward pass, we cache the input X and the filter W. This will help us while performing the backward pass.
Note: Here we perform the convolution operation without flipping the filter. This is also referred to as the cross-correlation operation in the literature. The above animation is provided just for the sake of clarity.
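Concretely, with stride 1 and no padding, each output pixel is a weighted sum over a 2x2 window of the input (written here with zero-based indices):

$$
h_{ij} = \sum_{m=0}^{1} \sum_{n=0}^{1} w_{mn}\, x_{i+m,\, j+n}, \qquad i, j \in \{0, 1\}
$$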
Backward Pass
Before moving further, take note of the following notation.
Now, to implement the back propagation step for the current layer, we can assume that we receive 𝜕h as input (from the backward pass of the next layer), and our aim is to calculate 𝜕w and 𝜕x. It is important to understand that 𝜕x (which plays the role of 𝜕h for the previous layer) becomes the input to the backward pass of the previous layer. This is the core principle behind the success of back propagation.
Each weight in the filter contributes to every pixel in the output map, so any change in a weight affects all of the output pixels. All of these contributions add up in the final loss, which means we can calculate the derivative of the loss with respect to each weight by summing over all output pixels, as follows.
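Written out for our 2x2 output (a sketch using the zero-based indices from the forward pass):

$$
\partial w_{mn} = \frac{\partial L}{\partial w_{mn}}
= \sum_{i=0}^{1} \sum_{j=0}^{1} \frac{\partial L}{\partial h_{ij}} \frac{\partial h_{ij}}{\partial w_{mn}}
= \sum_{i=0}^{1} \sum_{j=0}^{1} \partial h_{ij}\, x_{i+m,\, j+n}
$$

In other words, 𝜕w is itself a cross-correlation of the input X with 𝜕h.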
Similarly, we can derive 𝜕x: each input pixel contributes to several output pixels, each through a different filter weight, so its gradient sums those contributions.
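Following the same chain-rule argument (a sketch, with terms whose indices fall outside the output taken to be zero):

$$
\partial x_{ij} = \frac{\partial L}{\partial x_{ij}}
= \sum_{m=0}^{1} \sum_{n=0}^{1} \partial h_{i-m,\, j-n}\, w_{mn}
$$

This is a "full" convolution of 𝜕h with the 180°-flipped filter W. Moving further, let's see some code.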
Note: Much of the code is inspired by a programming assignment from the course Convolutional Neural Networks by deeplearning.ai, which is taught by Andrew Ng on Coursera.
Naive implementation of forward and backward pass for a convolution function
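Here is a minimal NumPy sketch of such a naive implementation, restricted to the single-channel, stride-1, no-padding case discussed above (the function and variable names are my own, not those from the assignment):

```python
import numpy as np

def conv_forward(X, W):
    """Forward pass: cross-correlate input X with filter W (stride 1, no padding)."""
    h_out = X.shape[0] - W.shape[0] + 1
    w_out = X.shape[1] - W.shape[1] + 1
    H = np.zeros((h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            H[i, j] = np.sum(X[i:i + W.shape[0], j:j + W.shape[1]] * W)
    cache = (X, W)  # cache the input and filter for the backward pass
    return H, cache

def conv_backward(dH, cache):
    """Backward pass: given dL/dH, return dL/dX and dL/dW."""
    X, W = cache
    dX = np.zeros_like(X)
    dW = np.zeros_like(W)
    for i in range(dH.shape[0]):
        for j in range(dH.shape[1]):
            # each output pixel spreads its gradient back over the input window it saw
            dX[i:i + W.shape[0], j:j + W.shape[1]] += dH[i, j] * W
            # ...and over the filter weights, scaled by that window of the input
            dW += dH[i, j] * X[i:i + W.shape[0], j:j + W.shape[1]]
    return dX, dW

# The 3x3 input and 2x2 filter from the example above
X = np.arange(9.0).reshape(3, 3)
W = np.array([[1.0, 2.0],
              [3.0, 4.0]])
H, cache = conv_forward(X, W)                   # H has shape (2, 2)
dX, dW = conv_backward(np.ones_like(H), cache)  # gradients w.r.t. X and W
```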
If you enjoyed this article, you might also want to check out the following articles to delve deeper into the mathematics:
- A Step by Step Backpropagation Example
- Derivation of Backpropagation in Convolutional Neural Network (CNN)
- Convolutional Neural Networks backpropagation: from intuition to derivation
- Backpropagation in Convolutional Neural Networks
I also found the Back propagation in Convnets lecture by Dhruv Batra very useful for understanding the concept.
I am not an expert on the topic, so if you find any mistakes in the article or have any suggestions for improvement, please mention them in the comments.