We have seen that the essence of training neural networks is minimizing a loss function, and that the core optimization algorithm, gradient descent, requires computing gradients.

The purpose of backpropagation is to make gradient computation efficient. It decomposes the end-to-end gradient (global, taken with respect to the original inputs) into per-hop gradients (local, taken by each subfunction with respect to its own inputs), and then uses dynamic programming to aggregate these local gradients into the global gradient.
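To make this concrete, here is a minimal sketch in Python of the idea for a simple chain of unary subfunctions: a forward sweep stores each hop's input, and a backward sweep multiplies the per-hop gradients together. The names (`hops`, `forward_and_backward`) and the particular chain are illustrative assumptions, not any library's API.

```python
import math

# Each "hop" pairs a subfunction with its local derivative rule.
# (Illustrative sketch only; real autodiff frameworks build such a graph automatically.)
hops = [
    (lambda x: x + 1,       lambda x: 1.0),          # z = x + 1
    (lambda z: z ** 2,      lambda z: 2 * z),        # h = z^2
    (lambda h: math.exp(h), lambda h: math.exp(h)),  # y = exp(h)
]

def forward_and_backward(x):
    # Forward sweep: evaluate each hop and store the input it received.
    inputs, value = [], x
    for f, _ in hops:
        inputs.append(value)
        value = f(value)

    # Backward sweep: multiply the local (per-hop) gradients,
    # reusing the stored inputs instead of recomputing anything end to end.
    grad = 1.0
    for (_, df), v in zip(reversed(hops), reversed(inputs)):
        grad *= df(v)
    return value, grad

value, grad = forward_and_backward(0.5)
print(value, grad)  # exp(2.25) and exp(2.25) * 2 * 1.5 * 1
```

Storing the intermediate inputs during the forward sweep is exactly what lets the backward sweep reuse them; that reuse of already-computed pieces is the dynamic-programming flavor mentioned above.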

NOTE: we first focus on computing the gradient of a general function rather than the gradient used in neural network training specifically. The same principles carry over to neural networks later, and starting from the general case also highlights the flexibility of neural network design, which can incorporate a wide range of functions.

Computation Graphs: A Tool for Visualizing the Chain Rule

To effectively grasp the application of the chain rule in the differentiation of composite functions, it's essential to integrate computation graphs with calculus. These visual representations not only clarify the process but also illuminate the path from input to output, making the abstract principles of calculus more tangible.

Let's dive into a practical example by dissecting the function $f(x) = \sin((x+1)^2)$. The beauty of computation graphs unfolds as we trace the function's transformation through a series of steps, each represented as a node within the graph.

[Figure: computation graph of $f(x) = \sin((x+1)^2)$, with nodes for $z = x + 1$, $h = z^2$, and $y = \sin(h)$.]

Initially, we assign a unique intermediate variable to the outcome of each computation stage, facilitating a clearer understanding of the derivative's progression:

  1. First Step: Define $z = x + 1$.
  2. Second Step: Square $z$, giving $h = z^2 = (x + 1)^2$.
  3. Final Step: Apply the sine function to $h$, yielding $y = \sin(h) = \sin((x + 1)^2)$.
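The forward pass through this graph can be written out directly. Below is a minimal sketch in Python, with the intermediates named `z`, `h`, `y` to mirror the steps above (the function name `forward` is just an illustrative choice):

```python
import math

def forward(x):
    z = x + 1        # first step:  z = x + 1
    h = z ** 2       # second step: h = z^2 = (x + 1)^2
    y = math.sin(h)  # final step:  y = sin(h) = sin((x + 1)^2)
    return z, h, y   # keep the intermediates; the backward pass will reuse them

print(forward(2.0))  # (3.0, 9.0, sin(9) ≈ 0.412)
```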

Having established the computation graph, we proceed to calculate the derivative of the entire function using the chain rule. The process unfolds as follows:

$$
\begin{aligned}
\frac{dy}{dx} &= \frac{dy}{dh}\cdot\frac{dh}{dz}\cdot\frac{dz}{dx} \\
&= \frac{d\sin((x+1)^2)}{d(x+1)^2}\cdot\frac{d(x+1)^2}{d(x+1)}\cdot\frac{d(x+1)}{dx}
\end{aligned}
$$

The computation breaks down into digestible parts:

  1. Outer derivative: $\frac{dy}{dh} = \frac{d\sin(h)}{dh} = \cos(h) = \cos((x+1)^2)$.
  2. Middle derivative: $\frac{dh}{dz} = \frac{dz^2}{dz} = 2z = 2(x+1)$.
  3. Inner derivative: $\frac{dz}{dx} = \frac{d(x+1)}{dx} = 1$.

Thus, the comprehensive derivative of $f(x)$ integrates these local derivatives into a cohesive whole:

$$ \frac{d\sin((x+1)^2)}{dx}=\cos((x+1)^2)\cdot2(x+1)\cdot1 $$
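Continuing the forward-pass sketch above, the backward pass is simply the product of the three local derivatives. A finite-difference check (a sanity test on my part, not a step of backpropagation itself) confirms the result:

```python
import math

def backward(x):
    z = x + 1              # reuse the intermediates from the forward pass
    h = z ** 2
    dy_dh = math.cos(h)    # local derivative of sin(h)
    dh_dz = 2 * z          # local derivative of z^2
    dz_dx = 1.0            # local derivative of x + 1
    return dy_dh * dh_dz * dz_dx

x, eps = 2.0, 1e-6
analytic = backward(x)
numeric = (math.sin((x + eps + 1) ** 2) - math.sin((x - eps + 1) ** 2)) / (2 * eps)
print(analytic, numeric)   # both ≈ cos(9) * 2 * 3 * 1 ≈ -5.467
```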

This stepwise multiplication not only reinforces the chain rule's application but also underscores the critical role of local derivatives in computing the global gradient. Through this example, the computation graph emerges as an invaluable tool, elucidating the intricate dance of calculus within composite functions.