At the core of many modern AI systems lies a so-called
neural network. These are mathematical models, inspired by biological brains,
that can be trained on huge amounts of data to 'learn' almost anything (see the
universal approximation theorem): from evaluating chess positions in engines like
Stockfish (open-source) and recognizing handwritten digits, as demonstrated in this article, to natural
language prediction in chatbots like
ChatGPT (OpenAI). Whenever a task is extremely difficult, or even impossible, to describe using
classical code, neural networks might just be the solution.
To gain a deeper understanding of these models, I decided to implement them completely from scratch.
I can already hear some of you saying '
The model used above correctly identified an
of the MNIST test data, which consists of an additional 10,000 examples not included
in the training set. The evaluation produced by any properly trained neural network
reflects how closely a given input matches the patterns it has learned during its
training process. Therefore a misclassification neither invalidates the model's
performance nor your handwriting; it simply indicates the network wasn't sufficiently
exposed to that particular style of handwriting during training.
Creating a functional neural network from scratch involves two main steps. First, we formulate a mathematical model inspired by the structure of biological brains. Next, we define and solve the so-called training problem, which is what allows the model to 'learn'. Taking a look at biology, we see that the fundamental unit of the brain is the neuron. A neuron is a highly specialized cell capable of receiving, processing, and transmitting chemical and electrical signals to and from other neurons.
Following the signal flow inside a neuron leads us initially to the dendrites. Here,
special molecules released by other neurons, so-called neurotransmitters, bind
to receptors, triggering the movement of charged ions toward the soma (cell
body). Within the soma, the ions arriving from all dendrites accumulate and combine
their charge. Mathematically, this process can be represented as an input vector and a
summation over its components.
\[ \begin{array}{cc} \text{input signal} & x = \bigg(\begin{smallmatrix} x_1 \\ \vdots
\\ x_m \end{smallmatrix}\bigg) \in \mathbb{R}^m \\[5px] \text{accumulation} &
\displaystyle \sum_{i = 1}^{m}x_i \in \mathbb{R} \end{array} \]
Once the total charge inside the soma exceeds a certain threshold, the neuron
'activates', causing the ions to flow along the axon until they reach the neuron's
terminals. Mathematically, such thresholding behavior is well modeled by the sigmoid
function. However, a variety of activation functions, such as tanh, ReLU, or
softmax, can also be used. See this
Wikipedia article
for more details.
\[ \begin{array}{cc} \text{activation} & \displaystyle \sigma(x_i) = \frac{1}{1 +
e^{-x_i}},\ x_i \in \mathbb{R} \end{array} \]
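To see this thresholding behavior numerically, here is a tiny NumPy sketch (the function and variable names are purely illustrative):

```python
import numpy as np

def sigmoid(x):
    """Elementwise sigmoid: squashes any real input into (0, 1)."""
    return 1 / (1 + np.exp(-x))

# Large negative inputs stay near 0, large positive inputs saturate near 1,
# and the transition around 0 is smooth -- a soft version of a hard threshold.
z = np.array([-10.0, 0.0, 10.0])
print(sigmoid(z))  # approximately [0.0000454, 0.5, 0.9999546]
```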
Lastly, when the ions reach the terminals, the charge they carry causes the terminals to release neurotransmitters into the synapses. Synapses are small gaps that act as connection points between terminals and dendrites, enabling signal transfer between neurons. Depending on several factors, such as receptor density and sensitivity or the amount of neurotransmitters released, the strength of these connections can vary. Mathematically, this is easily represented by introducing weights to the input signals described above.
Each neuron can have thousands of dendrites and terminals, i.e. in- and outputs, forming a vast network of interconnected cells. In our model we can do the same: each neuron may receive an arbitrary number of inputs, and its output can be connected to as many other neurons as desired. When it comes to the exact wiring of a neural network, many different architectures are possible; however, we will focus on what is known as a multilayer perceptron (MLP).
Mathematically, an MLP is a function \(f(x;\vartheta)\) with parameters \(\vartheta\), mainly the connection weights, defined by a layered neural network and its corresponding forward-pass equation (more on this later). Each layer can contain any number of neurons with any activation function, and is fully connected to both the previous and the next layer. Furthermore, each connection is assigned a weight that modulates the signals passing through it.
In the case of classifying handwritten digits (MNIST) the input layer contains 784 neurons, one for each pixel in a 28x28 grayscale image. The output layer contains 10 neurons, one for each digit from 0 to 9. The number and size of the hidden layers is arbitrary; suitable configurations are typically found through trial and error. It is also common to use the identity function as the activation function for the input layer, allowing it to simply hold the given inputs.
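To make the dimensions concrete, here is a small sketch of the parameter shapes such an architecture implies; the hidden-layer sizes below are made up for illustration and are not the ones used by the model later:

```python
import numpy as np

layer_sizes = [784, 128, 64, 10]  # input, two hidden layers, output

rng = np.random.default_rng(0)
weights = [rng.standard_normal((m, n)) * 0.01          # W^(k) has shape m_k x m_{k-1}
           for n, m in zip(layer_sizes[:-1], layer_sizes[1:])]
biases = [np.zeros(m) for m in layer_sizes[1:]]        # b^(k) has shape m_k

n_params = sum(W.size + b.size for W, b in zip(weights, biases))
print(n_params)  # 109386 trainable parameters for this configuration
```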
Finally, in biological brains learning occurs through the formation and removal of synapses, strengthening or weakening the connections between neurons. We can mimic this process by adjusting the connection weights between neurons. To do so, we first gather a set of example input-output pairs \((x,t)\) such that \(f(x;\vartheta) \approx t\) is the desired outcome. We then define a so-called loss function \(\mathcal{L}\), which measures how far the network's output deviates from the corresponding target. Finding weights that minimize this loss across all possible data pairs is known as the training problem and is what enables the model to 'learn'. \[ \begin{array}{cc} \text{Euclidean loss} & \displaystyle \mathcal{L}(x,t) = \tfrac{1}{2} ||t-x||^2\\[5px] \text{training problem} & \displaystyle \min_\vartheta\, \mathbb{E}\Big[ \mathcal{L}\big(f(x;\vartheta),\,t\big) \Big] \end{array} \] Once again, there is a wide range of possible choices for the loss function; see this Wikipedia article for more details. While formulating the training problem is straightforward, actually solving it is much more challenging. For now, though, do not worry about the details; everything will be covered fully and rigorously in dedicated sections below. The goal of this section was simply to give a general picture of what is required to build a neural network.
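As a small sanity check, the Euclidean loss above is easy to evaluate by hand; the output vector and one-hot target in this sketch are made up:

```python
import numpy as np

def euclidean_loss(x, t):
    """L(x, t) = 0.5 * ||t - x||^2"""
    return 0.5 * np.sum((t - x) ** 2)

# Hypothetical network output and one-hot target for the digit '3'.
output = np.array([0.1, 0.0, 0.2, 0.6, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0])
target = np.zeros(10)
target[3] = 1.0
print(euclidean_loss(output, target))  # approximately 0.11
```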
We begin this section by giving formal definitions of activation functions and multilayer perceptrons. Next, we take a closer look at how the forward-pass equation is derived. Finally, we provide formal definitions of loss functions and the training problem.
Definition. Let \(\tilde{\sigma} : \mathbb{R} \to \mathbb{R}\) be a continuously differentiable function (denoted by \(\tilde{\sigma} \in \mathcal{C}^1)\). Then the elementwise extension \(\sigma: \mathbb{R}^n \to \mathbb{R}^n\), defined by \(\sigma(x)_i := \tilde{\sigma}(x_i)\), is called an activation function.
In principle, the definition of an activation function can be chosen differently. Here, we adopt one that includes the sigmoid function \(\tilde{\sigma}(x_i) := 1 / (1 + e^{-x_i})\) and works well (i.e. is \(\mathcal{C}^1\)) with the general theory we are about to present.
Definition. Let \(\sigma\) be an activation function and \(\vartheta := \big\{ W^{(k)},\, b^{(k)} \mid k \in [n]\big\}\) a set of parameters. Then a function \(f : \mathbb{R}^{m_0} \to \mathbb{R}^{m_n}\) mapping \(x^{(0)} \mapsto f\big(x^{(0)};\vartheta\big) := x^{(n)}\), defined by the forward-pass equation \[ x^{(k)} := \sigma\big( W^{(k)} x^{(k-1)} + b^{(k)} \big),\ k\in[n]\] is called an \(n\)-layer perceptron. The parameters \(W^{(k)} \in \mathbb{R}^{m_{k} \times m_{k-1}}\) and \(b^{(k)} \in \mathbb{R}^{m_{k}}\) are called weights and biases, respectively. We also use the shorthand \([n] := \{1,\ldots,n\}\).
Next, let us unravel this definition, starting with the parameters \(\vartheta\). Let \(\phi^{(k)}_j\) denote the \(j\)-th neuron of the \(k\)-th layer. Then the weight \(\omega_{i,j}^{(k)}\) modulates the connection \[\phi_i^{(k-1)} \overset{\omega_{i,j}^{(k)}}{\longrightarrow} \phi_j^{(k)}.\] Collecting all weights corresponding to layer \(k\) in the matrix \(W^{(k)} = \big( \omega_{i,j}^{(k)} \big)_{i \in [m_{k-1}],\, j \in [m_k]}^T\) provides an elegant way to represent the weighted signal accumulation described earlier. \[ \big( W^{(k)}x^{(k-1)} \big)_j = \sum_{i = 1}^{m_{k-1}} \omega_{i,j}^{(k)} x_i^{(k-1)} \] In addition, it is often useful to allow neurons to exhibit certain biases \(b^{(k)}\). Although this was not covered in the overview section, biases simply introduce a base signal to each neuron, leading to the pre-activation signal \[ z^{(k)} := W^{(k)}x^{(k-1)} + b^{(k)}. \] Applying the activation function \(\sigma\) to the pre-activation signal yields the forward-pass equation defined above and produces the post-activation signal \(x^{(k)} = \sigma(z^{(k)})\). The function \(f\) then propagates the initial signal \(x^{(0)}\) according to this equation until the final output \(f(x^{(0)};\vartheta) = x^{(n)}\) is obtained, matching the desired network behavior.
Notice that we defined a single global activation function; this will be sufficient for our purposes. Also note that the signals \(x^{(k)}\) are indexed starting from \(0\), whereas both \(W^{(k)}\) and \(b^{(k)}\) start from \(1\). Consequently, there is an initial layer that simply holds the input signal \(x^{(0)}\).
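Read directly off the definition, the forward pass takes only a few lines; here is a minimal sketch assuming the weights and biases are stored as lists of NumPy arrays (an illustration, not the article's actual implementation):

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def forward(x, weights, biases):
    """Propagate x through the network: x^(k) = sigma(W^(k) x^(k-1) + b^(k))."""
    for W, b in zip(weights, biases):
        x = sigmoid(W @ x + b)
    return x

# Tiny 2-layer example: 3 inputs -> 4 hidden neurons -> 2 outputs.
rng = np.random.default_rng(1)
weights = [rng.standard_normal((4, 3)), rng.standard_normal((2, 4))]
biases = [np.zeros(4), np.zeros(2)]
y = forward(np.array([1.0, -1.0, 0.5]), weights, biases)
print(y.shape)  # (2,)
```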
Now consider the training problem. Assume a set of input-output pairs \((x,t) \sim P_\text{data}\), sampled i.i.d. from some unknown distribution \(P_\text{data}\). Next, we need a way to measure the deviation between the network output \(f(x;\vartheta)\) and the target output \(t\) for a given input \(x\).
Definition. Let \(f\) be an \(n\)-layer perceptron with layer sizes \(m_k\) and parameter set \(\vartheta\). A function \(\mathcal{L} : \mathbb{R}^{m_n} \times \mathbb{R}^{m_n} \to \mathbb{R}^{+}\) mapping \((x,t) \mapsto \mathcal{L}(x,t)\) is called a loss function if \[ \vartheta \mapsto \mathcal{L}\big(f(x;\vartheta),\, t\big) \,\text{ is }\, \mathcal{C}^1. \]
The exact form of \(\mathcal{L}(x,t)\) is, as already mentioned, a matter of choice; we will work with the Euclidean loss \(\mathcal{L}(x,t) = \tfrac{1}{2} ||t-x||^2\). What is essential, however, is the requirement that \(\mathcal{L}\) must be \(\mathcal{C}^1\) with respect to \(\vartheta\), ensuring that the gradient \(\nabla_\vartheta\, \mathcal{L}\) exists. This gradient is the central object used in the stochastic gradient descent (SGD) algorithm, which forms the basis for solving the training problem.
Definition. Let \(f\) be an \(n\)-layer perceptron with layer sizes \(m_k\) and loss function \(\mathcal{L}\). Further, let \((x,t) \sim P_\text{data}\) be sampled i.i.d. from \(\mathbb{R}^{m_0} \times \mathbb{R}^{m_n}\). Then the training problem is the optimization problem \[ \min_\vartheta \underset{P_\text{data}}{\mathbb{E}} \Big[ \mathcal{L}\big( f(x;\vartheta),\, t \big) \Big]. \]
Here we minimize over the parameters \(\vartheta\), evaluating the expected loss \(\mathbb{E}_{P_\text{data}} \mathcal{L}\) of all possible data points. This distinction is quite important, as simply minimizing the loss of the observed dataset may lead to a phenomenon known as overfitting. In that case, the network performs very well on the training data but poorly on new, unseen data. Effectively, the network just 'memorized' the examples rather than 'learning' general patterns.
As mentioned before, the basis for solving the training problem is an algorithm called stochastic gradient descent. For this, let \[Q : \vartheta \mapsto \underset{P_\text{data}}{\mathbb{E}} \Big[ \mathcal{L}\big( f(x;\vartheta),\, t \big) \Big]\] where \((x,t) \sim P_\text{data}\) i.i.d. and \(P_\text{data}\) is unknown. Next, assume interchangeability of the \(\nabla\) and \(\mathbb{E}\) operators (in practice this usually holds and can easily be proven, for example, using the dominated convergence theorem or similar results). This yields \[ \nabla_\vartheta\, Q = \underset{P_\text{data}}{\mathbb{E}} \Big[ \nabla_\vartheta\, \mathcal{L}\big( f(x;\vartheta),\, t \big) \Big]. \] Without knowledge of the distribution \(P_\text{data}\), this expression cannot be evaluated further. Instead, we estimate \(\nabla_\vartheta\, Q\) using batches \(B_\nu\) of observed datapoints \((x,t)\), which is why the algorithm is called 'stochastic' gradient descent. It holds \[ \nabla_\vartheta\, Q \approx \frac{1}{|B_\nu|} \sum_{(x,t) \,\in\, B_\nu} \nabla_\vartheta\, \mathcal{L}\big( f(x;\vartheta),\, t \big) =: \nabla_\vartheta\, Q_\nu \] by applying the law of large numbers (LLN). Finally, this allows us to update the current model parameters according to \[ \vartheta_{\nu + 1} \gets \vartheta_\nu - \eta_\nu\, g\big(\nabla_\vartheta\, Q_\nu\big) \] where \(\eta_\nu > 0\) is the SGD step size, called the learning rate, and \(g\) determines the SGD update direction. As is common in optimization problems, \(\eta_\nu\) should either be chosen very small or gradually decrease toward zero in order to avoid overshooting minima. Furthermore, it can be beneficial not to strictly follow the direction of steepest descent \(-\nabla_\vartheta\, Q_\nu\), but instead to follow some modified direction \(-g(\nabla_\vartheta\, Q_\nu)\).
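The SGD loop itself is independent of neural networks; here is a toy sketch on a one-parameter problem (estimating the mean of noisy samples, which minimizes the expected squared loss; all values in this sketch are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, scale=1.0, size=10_000)  # samples from P_data

theta = 0.0                           # initial parameter
for nu in range(2000):
    eta = 0.5 / (1 + 0.01 * nu)       # decaying learning rate eta_nu
    batch = rng.choice(data, size=32)  # mini-batch B_nu
    grad = np.mean(theta - batch)      # batch estimate of grad Q for L = 0.5*(theta - x)^2
    theta -= eta * grad                # SGD update with g = id
print(theta)  # converges toward the true mean 3.0
```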
Considering the update rule above, there are three unknown components: the learning rate \(\eta_\nu\), the direction \(g\), and the gradient \(\nabla_\vartheta\, \mathcal{L}\). Each of these involves additional choices, either entirely (\(\eta_\nu,\, g\)) or partially (\(\mathcal{L}\)), and therefore deserves its own section. Despite this, given \(\nabla_\vartheta\, \mathcal{L}\), some simple choices for the learning rate and direction can already work in practice. For example, \(\eta_\nu \equiv \eta > 0\) small and \(g = \text{id}\), or, slightly better, \[ \eta_\nu = \frac{\eta_0}{1 + \alpha \nu},\quad \eta_0, \alpha > 0 \] which satisfies the Robbins-Monro conditions, leading to classical convergence results for SGD.
We begin by deriving the gradient \(\nabla_\vartheta\, \mathcal{L}(x^{(n)}, t)\) for the euclidean loss \(\mathcal{L}(x,t) = \tfrac{1}{2} || t-x||^2\), where \(x^{(n)} = f(x; \vartheta)\). Recall the definition of the parameter set \[ \vartheta = \big\{ W^{(k)},\, b^{(k)} \,|\, k \in [n] \big\}. \] Each weight matrix \(W^{(k)}\) has components \(\omega_{i,j}^{(k)}\), and each bias vector \(b^{(k)}\) has components \(b_j^{(k)}\). Therefore, to calculate \(\nabla_\vartheta\, \mathcal{L}\), we must determine the partial derivatives of \(\vartheta \mapsto \mathcal{L}\big( f(x;\vartheta),\, t \big)\) with respect to \(\omega_{i,j}^{(k)}\) and \(b_j^{(k)}\). This is precisely why we required \(\mathcal{L}(\vartheta)\) to be \(\mathcal{C}^1\) in the definition of loss functions. Using the chain rule for the derivative with respect to \(b^{(k)}\), we obtain \[ \frac{\partial \mathcal{L}(x^{(n)},t)}{\partial b_j^{(k)}} = \underbrace{\frac{\partial \mathcal{L}(x^{(n)},t)}{\partial z_j^{(k)}}}_{=\, \delta_j^{(k)}} \cdot \underbrace{\frac{\partial z_j^{(k)}}{\partial b_j^{(k)}}}_{=\, 1} = \delta_j^{(k)} \] where \(z^{(k)} = W^{(k)}x^{(k-1)} + b^{(k)}\) is the pre-activation signal. Take special note of the quantity \(\delta^{(k)} := \partial \mathcal{L} / \partial z^{(k)}\), which represents the change in loss with respect to the pre-activation signal. It will play an important role shortly. 
Using the same approach for the derivative with respect to \(W^{(k)}\), we obtain \[ \begin{align} \frac{\partial \mathcal{L}(x^{(n)}, t)}{\partial \omega_{i,j}^{(k)}} &= \underbrace{\frac{\partial \mathcal{L}(x^{(n)},t)}{\partial z_j^{(k)}}}_{=\, \delta_j^{(k)}} \cdot \frac{\partial z_j^{(k)}}{\partial \omega_{i,j}^{(k)}} \overset{\star}{=} \delta_j^{(k)}x_i^{(k-1)}\\[5px] \star: \frac{\partial z_j^{(k)}}{\partial \omega_{i,j}^{(k)}} &= \sum_{r=1}^{m_{k-1}} x_r^{(k-1)} \underbrace{\frac{\partial \omega_{r,j}^{(k)}}{\partial \omega_{i,j}^{(k)}}}_{=\, I\{r = i\}} +\, \underbrace{\frac{\partial b_j^{(k)}}{\partial \omega_{i,j}^{(k)}}}_{=\, 0} = x_i^{(k-1)}.\end{align} \] Therefore \(\nabla_{b^{(k)}}\, \mathcal{L} = \delta^{(k)}\) and \(\nabla_{W^{(k)}}\, \mathcal{L} = \delta^{(k)} \big(x^{(k-1)}\big)^T\); here \(\mathcal{L}\) always denotes \(\mathcal{L}(x^{(n)},t)\). To compute \(\delta^{(k)}\), it works best to use an iterative approach starting from \(\delta^{(n)}\). This is where the \(\mathcal{C}^1\) requirement of the activation function \(\sigma\) is needed. \[ \begin{align} \delta_j^{(n)} &= \frac{\partial \mathcal{L}(x^{(n)},t)}{\partial z_j^{(n)}} = \frac{1}{2} \sum_{r=1}^{m_n} \underbrace{\frac{\partial}{\partial z_j^{(n)}} \big(t_r - x_r^{(n)}\big)^2}_{=0\, \text{unless}\, r = j\, \text{since}\, x = \sigma(z)}\\[5px] &= \frac{1}{2} \frac{\partial}{\partial z_j^{(n)}} \big( t_j - x_j^{(n)} \big)^2 \overset{\star}{=} \big( x_j^{(n)} - t_j \big) \sigma^\prime(z_j^{(n)}) \end{align} \] At \(\star\) we use the chain rule \(\frac{\partial f(x)}{\partial z} = \frac{\partial f(x)}{\partial x} \frac{\partial x}{\partial z}\) together with the fact that \(\frac{\partial x}{\partial z} = \frac{\partial \sigma(z)}{\partial z}\). Denoting elementwise multiplication as \(\odot\), this gives \(\delta^{(n)} = (x^{(n)} - t) \odot \sigma^\prime(z^{(n)})\). Now assume \(\delta^{(k)}\) is known.
Using the multivariable chain rule, we obtain \[ \begin{align} \delta_j^{(k-1)} &= \frac{\partial \mathcal{L}(x^{(n)}, t)}{\partial z_j^{(k-1)}} = \sum_{s = 1}^{m_k}\ \underbrace{\frac{\partial \mathcal{L}(x^{(n)}, t)}{\partial z_s^{(k)}}}_{=\, \delta_s^{(k)}} \cdot \frac{\partial z_s^{(k)}}{\partial z_j^{(k-1)}} \\[0px] &\overset{\star}{=} \sum_{s = 1}^{m_k} \delta_s^{(k)} \omega_{j,s}^{(k)} \sigma^\prime(z_j^{(k-1)})\\[10px] \star : \frac{\partial z_s^{(k)}}{\partial z_j^{(k-1)}} &= \sum_{r = 1}^{m_{k-1}} \omega_{r,s}^{(k)} \underbrace{\frac{\partial x_r^{(k-1)}}{\partial z_j^{(k-1)}}}_{=\,0\,\text{unless}\,r=j} \hspace{-4px}+\ \underbrace{\frac{\partial b_s^{(k)}}{\partial z_j^{(k-1)}}}_{=\, 0} = \omega_{j,s}^{(k)} \sigma^\prime(z_j^{(k-1)}). \end{align} \] This yields the formula \(\delta^{(k-1)} = (W^{(k)})^T \delta^{(k)} \odot \sigma^\prime(z^{(k-1)})\), which describes the process known as backpropagation, since \(\delta^{(k)}\) is propagated from the last layer back toward the first. Finally, for our choice of the sigmoid activation function \(\sigma(x_i) = 1 / (1+e^{-x_i})\), the derivative can be calculated easily and is given by \(\sigma^\prime(x_i) = \sigma(x_i)(1-\sigma(x_i))\).
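Putting the three backpropagation formulas together yields a compact procedure; here is a sketch assuming the parameters are stored as lists of NumPy arrays (an illustration, not the repository code):

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def backprop(x0, t, weights, biases):
    """Return the gradients dL/dW^(k) and dL/db^(k) for the Euclidean loss."""
    # Forward pass, storing pre-activation (zs) and post-activation (xs) signals.
    xs, zs = [x0], []
    for W, b in zip(weights, biases):
        zs.append(W @ xs[-1] + b)
        xs.append(sigmoid(zs[-1]))
    sig_prime = lambda z: sigmoid(z) * (1 - sigmoid(z))
    # Output layer: delta^(n) = (x^(n) - t) * sigma'(z^(n)).
    delta = (xs[-1] - t) * sig_prime(zs[-1])
    grads_W, grads_b = [], []
    for k in reversed(range(len(weights))):
        grads_W.insert(0, np.outer(delta, xs[k]))  # delta^(k) (x^(k-1))^T
        grads_b.insert(0, delta)                   # nabla_b L = delta^(k)
        if k > 0:  # delta^(k-1) = W^(k)^T delta^(k) * sigma'(z^(k-1))
            delta = (weights[k].T @ delta) * sig_prime(zs[k - 1])
    return grads_W, grads_b
```

A finite-difference check on a single weight is a quick way to convince yourself that the formulas are implemented correctly.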
Now let us examine the learning rate \(\eta\) in more detail. Earlier, we presented some simple choices, for example \(\eta_\nu = \eta_0 / (1+\alpha\nu)\), which satisfies \(\eta_\nu \to 0\). In practice, a wide range of functions can be used; however, they are usually required to satisfy the following definition.
Definition. A continuously differentiable function \(\eta : [0,T] \to \mathbb{R}^+\) that is decreasing and satisfies \(\forall \nu : \eta_0 \geq \eta_\nu \geq \eta_T\), where \(\eta_\nu := \eta(\nu)\), is called a learning rate.
Notice that, instead of decreasing to zero, a minimal learning rate \(\eta_T\) is used. It is also not strictly necessary for \(\eta \in \mathcal{C}^1\); however, the additional smoothness is often desired compared to just \(\eta \in \mathcal{C}^0\). A popular choice for the learning rate is a function known as cosine decay, given by \[ \eta_\nu^\text{cos} := \eta_T + \tfrac{1}{2} \big(\eta_0 - \eta_T\big) \big(1 + \cos(\pi\nu/T) \big) .\]
The learning rate \(\eta^{\text{cos}}\) corresponds to the first half of a cosine period, scaled and shifted to lie in the interval \([\eta_T, \eta_0]\). Compared to other smooth curves within these bounds, it maintains a longer initial high phase and a longer final low phase, while keeping a moderate decay rate. This allows for both rapid progress early on and finer adjustments toward the end of the learning process.
Finally, one more idea is often incorporated into the learning rate: to prevent poor initial examples from steering the optimization in an unfavorable direction, a so-called warmup phase is introduced before the actual learning rate takes over. For this, we simply use the linear function \(\eta^\text{up}_\nu := \nu/T_w + \eta_T\), resulting in the final combined learning rate \[ \eta_\nu := \begin{cases} \eta_\nu^\text{up} & 0 \leq \nu \leq T_w\\[5px] \eta_\nu^\text{cos} & T_w \lt \nu \leq T\, .\end{cases} \]
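The combined schedule can be sketched in a few lines. Note that this version scales the warmup so the two pieces meet exactly at \(T_w\), a slight variant of the formula above, and all hyperparameter values are only examples:

```python
import math

def lr_schedule(nu, T=10_000, T_w=500, eta_0=1e-2, eta_T=1e-4):
    """Linear warmup to eta_0, then cosine decay from eta_0 down to eta_T."""
    if nu <= T_w:
        # Warmup: ramp linearly from eta_T up to eta_0 over T_w steps.
        return eta_T + (eta_0 - eta_T) * nu / T_w
    # Cosine decay over the remaining T - T_w steps.
    progress = (nu - T_w) / (T - T_w)
    return eta_T + 0.5 * (eta_0 - eta_T) * (1 + math.cos(math.pi * progress))
```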
The final component needed in the SGD update rule \(\vartheta_{\nu + 1} \gets \vartheta_\nu - \eta_\nu\,g\big( \nabla_\vartheta\, Q_\nu \big)\) is the update direction \(g\). For this, we use an algorithm called adaptive moment estimation, or Adam for short. More precisely, we follow the method described in this paper, which introduces an improvement over the base algorithm known as weight decay decoupling, resulting in the AdamW algorithm (more on this later).
The main idea of Adam is to maintain exponential moving averages of both the previous gradients \(g_\nu := \nabla_\vartheta\,Q_\nu\) and their magnitudes \(g_\nu^2\) (where squaring is applied elementwise). These averages can then be used to estimate the first and second moments, \(\mathbb{E}(g_\nu)\) and \(\mathbb{E}(g_\nu^2)\), which form the basis of Adam's update direction. Let \(m_0,v_0 := 0\) and define \[ \begin{array}{c} m_\nu := \beta_1 m_{\nu-1} + (1-\beta_1) g_\nu\\[5px] v_\nu := \beta_2 v_{\nu-1} + (1-\beta_2) g_\nu^2 \end{array} \] where \(\beta_1, \beta_2 \in (0,1)\) are tunable hyperparameters (see the referenced paper for value recommendations). The quantities \(m_\nu\) and \(v_\nu\) as defined above are called exponential moving averages, since recursively expanding them yields \[ \begin{align} m_\nu &= \beta_1 \Big[ \beta_1 m_{\nu - 2} + (1-\beta_1) g_{\nu-1} \Big] + (1 - \beta_1) g_\nu\\[7.5px] &= \beta_1^2m_{\nu-2} + \beta_1(1-\beta_1)g_{\nu-1} + (1-\beta_1)g_\nu\\[5px] &= \cdots = (1 - \beta_1) \sum_{k = 0}^{\nu-1} \beta_1^k\, g_{\nu - k}\,.\end{align} \] In essence, we are summing all previous gradients \(g_{\nu-k}\) while assigning exponentially decreasing weights \(\beta_1^k\) to older terms (larger \(k\)). An analogous expression holds for \(v_\nu\). In this form, the expected values of \(m_\nu\) and \(v_\nu\) can easily be analyzed for potential biases. \[ \begin{align} \mathbb{E}(m_\nu) &= (1 - \beta_1) \sum_{k=0}^{\nu-1} \beta_1^k\, \underbrace{\mathbb{E}(g_{\nu - k})}_{\overset{\star}{\approx}\, \mu\, const.}\\[2px] &\approx (1-\beta_1)\, \mu\, \underbrace{\sum_{k = 0}^{\nu-1} \beta_1^k}_{\text{geom.}\,\text{sum}} = \mu (1-\beta_1^\nu) \end{align} \] Note that \(g_\nu = \nabla_\vartheta\, Q_\nu\) depends on the parameters \(\vartheta\), which are updated before each new \(g_\nu\) is computed. Therefore, the expectations \(\mathbb{E}(g_{\nu-k})\) also change over time.
Still, at \(\star\) specifically, it is reasonable to assume some constant value \(\mu\), as sufficiently small learning rates force the parameters to only change slightly between updates. Moreover, earlier directions \(g_{\nu - k}\) contribute progressively less to the sum due to the exponentially decreasing weights \(\beta_1^k\). This leads to the unbiased estimators \[ \hat{m}_\nu := \frac{m_\nu}{1-\beta_1^\nu},\ \hat{v}_\nu := \frac{v_\nu}{1-\beta_2^\nu}\,. \] Let \(\alpha,\epsilon \in (0,1)\) be additional hyperparameters, then the final update direction in Adam is given by \(g_\nu := \alpha \hat{m}_\nu / (\sqrt{\hat{v}_\nu} + \epsilon)\), where all operations are applied elementwise. Here, \(\alpha\) acts as a global scaling factor, while \(\epsilon\) is meant to be some small value preventing division by zero (see the referenced paper for value recommendations). Based on \(\hat{v}_\nu\), the magnitude of the direction \(g_\nu\) adjusts automatically and independently of the learning rate \(\eta_\nu\). This behavior is precisely what the term adaptive in adaptive moment estimation (Adam) refers to.
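The moment estimates and the resulting direction can be sketched as follows (hyperparameter defaults follow common recommendations; the state-dict layout is my own choice for illustration):

```python
import numpy as np

def adam_direction(grad, state, beta1=0.9, beta2=0.999, alpha=1.0, eps=1e-8):
    """Adam's update direction with bias-corrected moment estimates."""
    state["nu"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad       # 1st-moment EMA
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2  # 2nd-moment EMA
    m_hat = state["m"] / (1 - beta1 ** state["nu"])            # bias correction
    v_hat = state["v"] / (1 - beta2 ** state["nu"])
    return alpha * m_hat / (np.sqrt(v_hat) + eps)

state = {"m": 0.0, "v": 0.0, "nu": 0}
g1 = adam_direction(np.array([0.5, -2.0]), state)
# On the first step m_hat = grad and v_hat = grad**2, so every component of the
# direction has magnitude close to alpha, regardless of the raw gradient size --
# exactly the adaptive scaling described above.
```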
While we have now technically covered everything required to implement a multilayer perceptron, it is beneficial to discuss one additional idea. This will also lead us to the previously mentioned improvement in AdamW over Adam. Once again, consider the forward-pass equation \[ x^{(k)} = \sigma\big( W^{(k)}x^{(k-1)} + b^{(k)} \big)\,. \] If the weights \(W^{(k)}\) become too large (either positive or negative), the pre-activation signal may also grow excessively, leading to what is known as oversaturation of the neurons. In this case, \(\sigma\) fails to model a smooth transition and instead produces only extreme values. To prevent this, we can introduce a parameter regularization term \(\mathcal{R}(\vartheta)\) into the loss function \[ \mathcal{L}_\text{reg}(x^{(n)},t) := \mathcal{L}(x^{(n)},t) + \mathcal{R}(\vartheta) \] which penalizes large parameters. Note that usually only the model weights, not the biases, are regularized, since forcing biases toward any particular value would actively contradict their purpose. It holds that \(\nabla \mathcal{L}_\text{reg} = \nabla\mathcal{L} + \nabla\mathcal{R}\), so this idea can be incorporated easily into our existing theory. Concretely, we use \(\mathcal{R}(\vartheta) := \frac{\lambda}{2}||\vartheta||^2\) for some \(\lambda > 0\), which is called \(L_2\) regularization due to the norm used. \[ \frac{\partial \mathcal{R}(\vartheta)}{\partial \vartheta_i} = \frac{\lambda}{2} \sum_j \frac{\partial \vartheta_j^2}{\partial \vartheta_i} \overset{\star}{=} \frac{\lambda}{2} \frac{\partial \vartheta_i^2}{\partial \vartheta_i} = \lambda \vartheta_i \] At \(\star\) we use, similarly to before, that the derivative \(\partial \vartheta_j^2 / \partial \vartheta_i\) is zero unless \(j=i\). Therefore, it follows that \(\nabla \mathcal{R} = \lambda \vartheta\).
Because the gradient operator is linear, we finally get (verifying this is left as an exercise to the reader) \[ \begin{align} (\nabla_\vartheta\,Q_\nu)_\text{reg} &:= \frac{1}{|B_\nu|} \sum_{(x,t) \in B_\nu} \nabla_\vartheta\, \mathcal{L}_\text{reg}(x^{(n)},\,t) \\[8px] &= \nabla_\vartheta\,Q_\nu + \lambda\vartheta \, . \end{align} \] In standard SGD, \(L_2\) regularization causes parameters to shrink by a uniform multiplicative factor, a behavior commonly referred to as weight decay. However, in adaptive methods such as Adam, the automatic gradient scaling interferes with the regularization term, breaking this uniform decay. AdamW fixes the issue by decoupling the regularization term from the gradient calculation and applying it directly to the parameter update, i.e. using \[ \nabla_\vartheta\,Q_\nu\, \text{ and }\, g_\nu := \alpha \hat{m}_\nu / \big(\sqrt{\hat{v}_\nu} + \epsilon\big) + \lambda \vartheta \] thereby restoring the intended behavior. Empirically, this has been shown to substantially improve the model's generalization performance (see the referenced paper for more details).
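The full AdamW parameter update can then be sketched under the same assumptions as before (the state-dict layout and hyperparameter values are illustrative):

```python
import numpy as np

def adamw_step(theta, state, grad, eta, beta1=0.9, beta2=0.999,
               alpha=1.0, eps=1e-8, lam=0.01):
    """One AdamW update: Adam's adaptive direction plus decoupled weight decay."""
    state["nu"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2
    m_hat = state["m"] / (1 - beta1 ** state["nu"])
    v_hat = state["v"] / (1 - beta2 ** state["nu"])
    # The decay term lam * theta is added to the update directly and never
    # enters the moving averages, so it is not rescaled by sqrt(v_hat).
    return theta - eta * (alpha * m_hat / (np.sqrt(v_hat) + eps) + lam * theta)

state = {"m": 0.0, "v": 0.0, "nu": 0}
theta = adamw_step(np.array([1.0]), state, grad=np.array([2.0]), eta=0.1)
```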
This final section describes the implementation of the mathematical framework developed above, as well as several practical details that have not been covered yet. Consider the following UML diagram; for clarity, only conceptually relevant components are shown. If you wish to explore the full codebase, you can do so on my GitHub.
Each component has a corresponding theoretical expression shown after the |. For
instance, the forward-pass equation is implemented in
```python
class Activation:
    _default_config = {"name": "sigmoid", "T": 1.0}

    def __init__(self, config: dict[str, Any] = _default_config):
        """
        Create a new Activation instance from given config.
        Defaults to sigmoid function 1 / (1 + exp(-x)).

        :param config: parameter object for activation factory
        """
        self._eval, self._rate = make(config)

    def eval(self, z: Vector) -> Vector:
        """
        Calculate the post-activation signal sigma(z).

        :param z: pre-activation signal
        :return: post-activation signal
        """
        return self._eval(z)

    def rate(self, v: Vector) -> Vector:
        """
        Calculate the elementwise rate of change
        (derivative) of the activation function.

        :param v: (any) input vector
        :return: derivative vector
        """
        return self._rate(v)
```
Both methods delegate to the callables produced by the corresponding factory:
```python
@factory
def make_sigmoid(T: float = 1.0) -> tuple[Callable, Callable]:
    # ensure temperature is greater than 0
    val.check_condition(T > 0, "T must be greater 0")

    def eval(x: Vector) -> Vector:
        """
        Elementwise sigmoid function with temperature.

        :param x: input vector x
        :return: output vector sigmoid(x)
        """
        return 1 / (1 + np.exp(-x / T))

    def rate(x: Vector) -> Vector:
        """
        Elementwise rate of change (derivative) of
        the sigmoid function with temperature.

        :param x: input vector x
        :return: derivative vector sigmoid'(x)
        """
        s = eval(x)
        return s * (1 - s) / T

    return eval, rate
```
To instantiate a sigmoid function and its derivative, simply call the factory above.
When using neural networks in practice, one is often limited by the quality and
quantity of available training data. To address this, two classes,
In our case, we start from the
MNIST dataset
of 60,000 examples of handwritten digits and apply a random affine transformation to
each of them four times. This produces 240,000 additional examples consisting of
slightly scaled, sheared, rotated, and translated digits for training. Finally, the
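The actual augmentation code lives in the repository; purely as an illustration, a simplified NumPy-only version of such a random affine transform (nearest-neighbor sampling, made-up parameter ranges) might look like this:

```python
import numpy as np

def random_affine(img, rng):
    """Apply a small random scale/shear/rotation/translation to an image,
    centered on the image; nearest-neighbor sampling for simplicity."""
    h, w = img.shape
    angle = rng.uniform(-0.2, 0.2)          # rotation in radians
    scale = rng.uniform(0.9, 1.1)
    shear = rng.uniform(-0.1, 0.1)
    shift = rng.uniform(-2.0, 2.0, size=2)  # translation in pixels
    c, s = np.cos(angle), np.sin(angle)
    A = scale * np.array([[c, -s], [s, c]]) @ np.array([[1.0, shear], [0.0, 1.0]])
    center = np.array([(h - 1) / 2, (w - 1) / 2])
    # Inverse-map every output pixel back to its source location.
    ys, xs = np.mgrid[0:h, 0:w]
    out_coords = np.stack([ys.ravel(), xs.ravel()], axis=1).astype(float)
    src = (out_coords - center - shift) @ np.linalg.inv(A).T + center
    src = np.rint(src).astype(int)
    valid = ((src[:, 0] >= 0) & (src[:, 0] < h) &
             (src[:, 1] >= 0) & (src[:, 1] < w))
    out = np.zeros(h * w, dtype=img.dtype)
    out[valid] = img.ravel()[src[valid, 0] * w + src[valid, 1]]
    return out.reshape(h, w)
```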
Plotted above is the accuracy of the exact neural network used in the demo at the top of this article, evaluated on the MNIST test dataset (10,000 additional examples not used in training) during the first epoch (60,000 batches of 5 examples each). The network architecture consists of two hidden layers with 1000 and 500 neurons, respectively. The weights were initialized from a normal distribution, while the biases were initialized from a uniform distribution. For more details, please refer to the train-configs.json file inside the final model .zip.
Note that the initial accuracy of the model is about 10%, which is expected since
there are 10 possible output digits, which the model, at that stage, is basically
picking at random. During the first epoch alone, the test accuracy increased quickly
to about 95%, already yielding strong performance. After training for 25 epochs, the
model ultimately achieved an
accuracy. Higher performance on MNIST is typically only achieved using specialized
image recognition architectures such as convolutional neural networks (CNNs).
I am honestly super happy with the final result of this project. I never had the opportunity to take any machine learning classes in university, and all I had to go off initially was this book. Unfortunately though, while it is an excellent introduction, it almost completely avoids the proper mathematical foundations. Therefore, I ended up searching for a bunch of additional information online, either deriving parts myself or following research papers (such as the one introducing AdamW).
In the end, though, I learned a lot and, most importantly, had a ton of fun working on
this project. I am now excited to try my code on some other machine learning
applications, since MNIST is often regarded as the 'Hello World' of machine learning.