In PyTorch, a wide variety of classification models are available, ranging from image classification and object detection to architectures built on recurrent and self-attention mechanisms. These models predominantly use cross-entropy as their training loss. In this tutorial, we explore the rationale behind the widespread adoption of cross-entropy as a loss function and examine its relationship with negative log-likelihood loss.

We will explore the theory and applications of Kullback-Leibler (KL) divergence and cross-entropy in classification, emphasizing their roles in model optimization and the importance of loss functions in enhancing prediction accuracy.

Problem Statement

In classification tasks, it is crucial to align the model's output $\hat{y}^{(i)}$ distribution with the target label ${y}^{(i)}$ distribution for each individual prediction $i$. This alignment is fundamental in assessing the model's capability to accurately predict labels for unseen data.

Key Point: The distribution we are talking about is not the distribution of the samples, but the distribution of output classes for each sample.

Example: If our model's output for sample $i$ includes 10 classes, this length-10 output vector can be viewed as a multinomial distribution, i.e., the probability of occurrence for each category. Thus, the distribution comparison here is between this output's multinomial distribution and the multinomial distribution of the sample's label. For the sample's label's multinomial distribution, the probability is 1 at the label's position and 0 elsewhere.
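
To make this concrete, here is a minimal PyTorch sketch that builds both distributions for a single sample; the logits and the label index are made up purely for illustration:

```python
import torch
import torch.nn.functional as F

# Hypothetical raw model outputs (logits) for one sample and 10 classes.
logits = torch.randn(10)

# Softmax turns the logits into a probability distribution over the 10 classes.
q = F.softmax(logits, dim=0)   # model's predicted distribution; entries sum to 1

# Suppose the true class of this sample is index 3 (an arbitrary choice).
label = torch.tensor(3)
p = F.one_hot(label, num_classes=10).float()  # probability 1 at the label, 0 elsewhere

print(q)  # 10 probabilities that sum to 1
print(p)  # tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
```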

Prior Knowledge - Entropy

Entropy, denoted as $H(p)$, measures the average amount of information produced by a stochastic source of data. For a discrete random variable with probability mass function $p(x)$, entropy is defined as:

$$ H(p) = -\sum_{x} p(x) \log p(x) $$

The concept of entropy is central in information theory, capturing the unpredictability or uncertainty of a random variable. Higher entropy means more unpredictability.
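
As a quick numerical check of this definition, the sketch below computes the entropy of two made-up distributions, one uniform and one nearly deterministic:

```python
import torch

def entropy(p: torch.Tensor) -> torch.Tensor:
    """H(p) = -sum_x p(x) log p(x); assumes p is a valid probability vector."""
    mask = p > 0                       # 0 * log(0) is treated as 0 by convention
    return -(p[mask] * p[mask].log()).sum()

uniform = torch.tensor([0.25, 0.25, 0.25, 0.25])  # maximally unpredictable over 4 outcomes
peaked  = torch.tensor([0.97, 0.01, 0.01, 0.01])  # almost deterministic

print(entropy(uniform))  # ~1.386 nats (= log 4): high uncertainty
print(entropy(peaked))   # ~0.168 nats: low uncertainty
```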

Kullback-Leibler Divergence

Kullback-Leibler (KL) divergence measures the dissimilarity between two probability distributions $p$ and $q$: it quantifies the extra bits needed, on average, to encode samples drawn from $p$ using a code optimized for $q$. A lower value implies more similar distributions. For a discrete variable $x$, the KL divergence of an approximating distribution $q$ from a target distribution $p$ is mathematically defined as:

$$ D_{KL}(p \parallel q) = \sum_{x} p(x) \log \frac{p(x)}{q(x)} = \sum_{x} p(x) \log p(x) - \sum_{x} p(x) \log q(x) $$

Here, the summation iterates over all possible values of $x$, computing the product of the probability $p(x)$ and the logarithm of the ratio $\frac{p(x)}{q(x)}$ at each $x$. In the context of classification, $p$ represents the true label distribution, often a one-hot vector indicating the correct class, while $q$ embodies the model's predicted probabilities across classes.

Optional: KL divergence is not symmetric; in general $D_{KL}(p \parallel q) \neq D_{KL}(q \parallel p)$, so swapping $p$ and $q$ changes the measured discrepancy. This asymmetry does not undermine its usefulness for evaluating and optimizing prediction accuracy: a lower KL divergence still indicates a model that more closely approximates the true label distribution.
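
The following sketch computes $D_{KL}(p \parallel q)$ directly from the definition and illustrates the asymmetry numerically; the two distributions are made up for illustration:

```python
import torch

def kl_div(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """D_KL(p || q) = sum_x p(x) log(p(x) / q(x)); assumes q > 0 wherever p > 0."""
    mask = p > 0                       # terms with p(x) = 0 contribute nothing
    return (p[mask] * (p[mask] / q[mask]).log()).sum()

p = torch.tensor([0.70, 0.20, 0.10])  # "true" target distribution
q = torch.tensor([0.50, 0.30, 0.20])  # model's approximating distribution

print(kl_div(p, q))  # ~0.085
print(kl_div(q, p))  # ~0.092, a different value: KL divergence is asymmetric
```

In practice you would more likely call torch.nn.functional.kl_div, which expects the prediction argument as log-probabilities; the manual version above simply makes the definition explicit.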

Cross-Entropy

Observe that the first term, $\sum_{x} p(x) \log p(x)$, is the negative of the entropy of $p$, i.e., $-H(p)$. In machine learning, $p$ is the given distribution of label categories, so this term is a constant with respect to the model and can be ignored during optimization. Minimizing $D_{KL}(p \parallel q)$ over $q$ is therefore equivalent to minimizing the remaining term:

$$ H(p, q)=- \sum_{x} p(x) \log q(x) $$

This expression is the definition of cross-entropy, denoted $H(p, q)$.

Cross-entropy effectively quantifies the inefficiency of using a predicted probability distribution $q(x)$ to encode events compared to the true distribution $p(x)$. Imagine you have a codebook designed based on $q(x)$ for encoding messages, but the real messages come from $p(x)$. The cross-entropy $H(p, q)$ measures, on average, how many extra bits you would need per message using this mismatched codebook compared to an ideal one tailored to $p(x)$. If $q(x)$ perfectly matches $p(x)$, there's no inefficiency, and the cross-entropy equals the true entropy of $p(x)$. However, any discrepancy between $q(x)$ and $p(x)$ increases this value, reflecting additional bits needed due to the prediction error in $q(x)$.

Trivia: Cross-entropy measures the total number of bits required on average, while KL divergence specifies the extra bits needed beyond what is optimal; in other words, $H(p, q) = H(p) + D_{KL}(p \parallel q)$.
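
Reusing the made-up distributions and helpers from the earlier sketches, we can verify this relationship numerically:

```python
import torch

def entropy(p):
    mask = p > 0
    return -(p[mask] * p[mask].log()).sum()

def kl_div(p, q):
    mask = p > 0
    return (p[mask] * (p[mask] / q[mask]).log()).sum()

def cross_entropy(p, q):
    """H(p, q) = -sum_x p(x) log q(x)."""
    mask = p > 0
    return -(p[mask] * q[mask].log()).sum()

p = torch.tensor([0.70, 0.20, 0.10])  # true label distribution
q = torch.tensor([0.50, 0.30, 0.20])  # predicted distribution

print(cross_entropy(p, q))        # ~0.887: total average cost with the mismatched code
print(entropy(p) + kl_div(p, q))  # ~0.887: optimal cost plus extra cost, matching the line above
```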

Written in terms of $p$ and $q$, this equation may look unfamiliar next to classification outputs $\hat{y}$ and labels $y$. By substituting the predictions and labels into the expression, we obtain: