The KL Divergence: From Information to Density Estimation

The KL divergence, also known as "relative entropy", is a commonly used measure of the difference between two probability distributions and a standard objective in density estimation. I re-derive the relationships between probability, entropy, and relative entropy for quantifying the similarity between distributions.

In statistics, the Kullback–Leibler (KL) divergence measures how similar two probability distributions are. A standard formulation, and the one I encountered first, is the following. Given two probability distributions $P$ and $Q$ with densities $p(x)$ and $q(x)$, the KL divergence is the integral

$$
D_{\text{KL}}[P \lVert Q] = \int_{-\infty}^{\infty} p(x) \log \frac{p(x)}{q(x)} dx \tag{1}
$$

In this post, I want to show how Eq. 1 is both a measure of relative information entropy and a reasonable way to compare densities.
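To make Eq. 1 concrete, here is a minimal sketch (my own illustration, assuming NumPy is available) that estimates the integral by Monte Carlo: sample from $p$, average $\log p(x) - \log q(x)$, and compare against the closed-form KL divergence between two univariate Gaussians.

```python
import numpy as np

def log_gaussian_pdf(x, mu, sigma):
    """Log density of a univariate Gaussian N(mu, sigma^2)."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu)**2 / (2 * sigma**2)

# Illustrative choice of densities (not from the post): p = N(0, 1), q = N(1, 1.5^2).
mu_p, sigma_p = 0.0, 1.0
mu_q, sigma_q = 1.0, 1.5

# Monte Carlo estimate of Eq. 1: E_{p(x)}[log p(x) - log q(x)].
rng = np.random.default_rng(0)
x = rng.normal(mu_p, sigma_p, size=1_000_000)
kl_mc = np.mean(log_gaussian_pdf(x, mu_p, sigma_p) - log_gaussian_pdf(x, mu_q, sigma_q))

# Closed-form KL divergence between two univariate Gaussians.
kl_exact = (np.log(sigma_q / sigma_p)
            + (sigma_p**2 + (mu_p - mu_q)**2) / (2 * sigma_q**2)
            - 0.5)

print(f"Monte Carlo: {kl_mc:.4f}, exact: {kl_exact:.4f}")  # both approximately 0.35
```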

Relative information

Let us reconstruct the notion of information entropy (abbr. information) from first principles. The information received from a random variable taking a particular value can be viewed as a measure of our “surprise” about that value. A value with low information is not surprising; a value with high information is. Now imagine you didn’t know or couldn’t remember the equation for information. What would be a sensible formulation? First, it makes sense that information is a monotonic function of probability. Higher probability means strictly lower information. For example, rolling a die and getting an even number should be less surprising than rolling a die and getting a 2.

It would also make sense that information should be additive for independent events. If I roll a die and flip a coin, my total information should be an additive combination of the information from the two events. This thought exercise is useful because, at least for me, it makes clear why the information about a random variable $X$ taking on a value $x$, denoted $h(x)$, is defined as it is:

$$
h(x) = - \log p(x) \quad\quad H(X) = \mathbb{E}[h(x)]
$$

The negative sign is because higher probabilities result in less information. And the $\log$ is a monotonically increasing function of probabilities that has the useful property that two independent random events have additive information:

$$
\begin{aligned}
h(x, y) &= - \log p(x, y) \\
&= - \log \{ p(x) p(y) \} \\
&= - \log p(x) - \log p(y) \\
&= h(x) + h(y)
\end{aligned}
$$
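As a quick numerical illustration of this additivity (my own sketch), consider rolling an even number with a fair die and flipping heads with a fair coin:

```python
import numpy as np

def information(p):
    """Information ("surprise") of an event with probability p, in nats."""
    return -np.log(p)

# Illustrative events: an even roll of a fair die and heads on a fair coin.
p_even_roll = 3 / 6
p_heads = 1 / 2

# The joint probability of independent events is the product of probabilities,
# so the joint information is the sum of the individual informations.
h_joint = information(p_even_roll * p_heads)
h_sum = information(p_even_roll) + information(p_heads)

print(h_joint, h_sum)  # both equal log(4), about 1.386 nats
```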

What does that mean for the KL divergence? Let’s consider Eq. 1 again, but now write it in terms of information:

$$
\begin{aligned}
D_{\text{KL}}[P \lVert Q] &= \int_{-\infty}^{\infty} p(x) \log \frac{p(x)}{q(x)} dx \\
&= \mathbb{E}_{p(x)} \big[ \log \frac{p(x)}{q(x)} \big] \\
&= \mathbb{E}_{p(x)} \big[ \log p(x) - \log q(x) \big] \\
&= \mathbb{E}_{p(x)} \big[ - \log q(x) \big] - \mathbb{E}_{p(x)} \big[ - \log p(x) \big] \\
&= H(P, Q) - H(P)
\end{aligned} \tag{2}
$$

In other words, one interpretation of the KL divergence is that it captures the relative information or relative entropy between two distributions $P$ and $Q$: the first term, $H(P, Q) = \mathbb{E}_{p(x)}[-\log q(x)]$, is the cross-entropy, the expected information when we model samples from $P$ using $Q$, and the second term, $H(P)$, is the entropy of $P$ itself. Also note that the KL divergence is not symmetric, i.e. $D_{\text{KL}}[P \lVert Q] \neq D_{\text{KL}}[Q \lVert P]$ in general.
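Here is a small sketch (my own, assuming NumPy) that checks Eq. 2 numerically for a pair of randomly generated discrete distributions and also confirms the asymmetry:

```python
import numpy as np

rng = np.random.default_rng(1)

# Two illustrative discrete distributions over five outcomes.
p = rng.random(5); p /= p.sum()
q = rng.random(5); q /= q.sum()

kl = np.sum(p * np.log(p / q))           # Eq. 1, discrete form
cross_entropy = np.sum(p * -np.log(q))   # H(P, Q) = E_{p(x)}[-log q(x)]
entropy_p = np.sum(p * -np.log(p))       # H(P)    = E_{p(x)}[-log p(x)]

print(np.isclose(kl, cross_entropy - entropy_p))  # True: Eq. 2 holds

# The divergence is not symmetric in general.
kl_reverse = np.sum(q * np.log(q / p))
print(np.isclose(kl, kl_reverse))  # almost certainly False for random p and q
```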

At this point, it makes sense that the KL divergence might be a good measure of how similar two distributions are. But why does minimizing the KL divergence between two densities, as in variational inference, guarantee that our optimization objective is performing density estimation? The answer relies on the convexity of the negative logarithm and Jensen’s inequality.

Nonnegativity of the KL divergence

First, the big picture. We want to use the notion of convexity to prove Jensen’s inequality. Jensen’s inequality will allow us to move the logarithm in Eq. 1 outside the integral. Since the integral of a density is $1$, the log of the integral is $0$. This will provide a lower bound on the KL divergence; formally, $D_{\text{KL}} \geq 0$ with equality when $p(x) = q(x)$. With that in mind, let’s move forward.

A function $f$ is convex if the following holds

$$
f(\lambda a + (1 - \lambda) b) \leq \lambda f(a) + (1 - \lambda) f(b)
$$

for all $0 \leq \lambda \leq 1$. This is a common formulation, and the reader can find numerous explanations and visualizations of what it means. Intuitively, the function evaluated at any point between $a$ and $b$ inclusive lies on or below the chord connecting $f(a)$ and $f(b)$. Draw a few functions on a piece of paper and see which ones are convex.
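A quick numerical sanity check of this definition (my own sketch; the grid of test points and weights is an arbitrary choice) is to test the inequality for many pairs $(a, b)$ and weights $\lambda$. It also previews the function we will need later: $-\log x$ is convex, while $\log x$ is not.

```python
import numpy as np

def is_convex_on_grid(f, points, lambdas):
    """Check f(lam*a + (1-lam)*b) <= lam*f(a) + (1-lam)*f(b) on a grid of points and weights."""
    for a in points:
        for b in points:
            for lam in lambdas:
                lhs = f(lam * a + (1 - lam) * b)
                rhs = lam * f(a) + (1 - lam) * f(b)
                if lhs > rhs + 1e-12:  # small tolerance for floating-point error
                    return False
    return True

# Arbitrary grid of positive test points and weights in [0, 1].
points = np.linspace(0.1, 5.0, 25)
lambdas = np.linspace(0.0, 1.0, 11)

print(is_convex_on_grid(lambda x: -np.log(x), points, lambdas))  # True: -log is convex
print(is_convex_on_grid(lambda x: np.log(x), points, lambdas))   # False: log is concave
```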

An aside: proof of Jensen’s inequality

But at this point, I think many explanations of the KL divergence skip a step. They say something like, “And by Jensen’s inequality…” without proving Jensen’s inequality. Let’s actually do that. Jensen’s inequality is

$$
f \Big( \sum_{i=1}^{N} \lambda_i x_i \Big) \leq \sum_{i=1}^{N} \lambda_i f(x_i) \tag{3}
$$

where $\lambda_i \geq 0$ and $\sum_i \lambda_i = 1$. The proof is by induction. Let $f$ be a convex function. Now consider the base case:

$$
f (\lambda_1 x_1 + \lambda_2 x_2) \leq \lambda_1 f(x_1) + \lambda_2 f(x_2)
$$

This clearly holds because $f$ is convex and $\lambda_1 + \lambda_2 = 1 \iff (1 - \lambda_1) = \lambda_2$, so the base case is just the definition of convexity with $\lambda = \lambda_1$. Now for the inductive case, we want to show that

$$
f \Big( \sum_{i=1}^{K} \lambda_i x_i \Big) \leq \sum_{i=1}^{K} \lambda_i f(x_i) \implies f \Big( \sum_{i=1}^{K+1} \lambda_i x_i \Big) \leq \sum_{i=1}^{K+1} \lambda_i f(x_i)
$$

First, let’s start with our inductive hypothesis and add $\lambda_{K+1} f(x_{K+1})$ to both sides:

$$
f \Big( \sum_{i=1}^{K} \lambda_i x_i \Big) + \lambda_{K+1} f(x_{K+1}) \leq \sum_{i=1}^{K+1} \lambda_i f(x_i)
$$

Now the $\lambda$s on the right-hand side no longer sum to $1$. Let’s normalize both sides of the equation by multiplying by $\frac{1}{1 + \lambda_{K+1}}$:

$$
\overbrace{\frac{1}{1 + \lambda_{K+1}}}^{A} f \Big( \sum_{i=1}^{K} \lambda_i x_i \Big) + \overbrace{\frac{\lambda_{K+1}}{1 + \lambda_{K+1}}}^{B} f(x_{K+1}) \leq \frac{1}{1 + \lambda_{K+1}} \sum_{i=1}^{K+1} \lambda_i f(x_i)
$$

This normalization constant makes sense because $\sum_{i=1}^{K} \lambda_i = 1 \iff \sum_{i=1}^{K+1} \lambda_i = 1 + \lambda_{K+1}$. Now note that the terms labeled $A$ and $B$ above sum to $1$. And since $f$ is convex, we can say

$$
f \Big( \frac{1}{1 + \lambda_{K+1}} \sum_{i=1}^{K} \lambda_i x_i + \frac{\lambda_{K+1}}{1 + \lambda_{K+1}} x_{K+1} \Big) \leq \frac{1}{1 + \lambda_{K+1}} f \Big( \sum_{i=1}^{K} \lambda_i x_i \Big) + \frac{\lambda_{K+1}}{1 + \lambda_{K+1}} f(x_{K+1})
$$

At this point, we’re basically done. The left-hand side of the above inequality can be simplified to

$$
f \Big( \frac{1}{1 + \lambda_{K+1}} \sum_{i=1}^{K+1} \lambda_i x_i \Big)
$$

which we have already shown is less than or equal to $\frac{1}{1 + \lambda_{K+1}} \sum_{i=1}^{K+1} \lambda_i f(x_i)$, as desired. Since the rescaled weights $\frac{\lambda_i}{1 + \lambda_{K+1}}$ are nonnegative and sum to $1$, this is exactly Jensen’s inequality for $K+1$ terms.
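As a sanity check of Eq. 3 (again a sketch of my own), we can draw random nonnegative weights that sum to one and verify the inequality for a convex function:

```python
import numpy as np

rng = np.random.default_rng(2)

f = lambda x: -np.log(x)  # a convex function

# Random nonnegative weights summing to one, and random positive points.
N = 10
lam = rng.random(N); lam /= lam.sum()
x = rng.uniform(0.1, 10.0, size=N)

lhs = f(np.sum(lam * x))   # f of the weighted average
rhs = np.sum(lam * f(x))   # weighted average of f
print(lhs <= rhs)          # True, by Jensen's inequality (Eq. 3)
```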

Jensen’s inequality for distributions

Now consider this: since $\lambda_i \geq 0$ and $\sum_i \lambda_i = 1$, we can interpret $\lambda_i$ as the probability of our random variable $X$ taking on a specific value $x_i$, giving us

$$
f \big( \mathbb{E}[X] \big) \leq \mathbb{E} \big[ f(X) \big]
$$

which for continuous densities is equivalent to

$$
f \Big( \int x p(x) dx \Big) \leq \int f(x) p(x) dx
$$
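Here is a small Monte Carlo illustration of this expectation form (my own sketch; the exponential random variable and $f(x) = -\log x$ are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(3)

f = lambda x: -np.log(x)                         # a convex function
x = rng.exponential(scale=1.0, size=1_000_000)   # X ~ Exp(1), so E[X] = 1

lhs = f(np.mean(x))   # f(E[X]) is approximately -log(1) = 0
rhs = np.mean(f(x))   # E[f(X)] is approximately 0.577 for Exp(1)
print(lhs <= rhs)     # True, by Jensen's inequality
```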

Since the negative logarithm is a convex function, we can apply Jensen’s inequality to the KL divergence to prove a lower bound:

$$
\begin{aligned}
D_{\text{KL}}[P \lVert Q] &= \int_{-\infty}^{\infty} p(x) \log \frac{p(x)}{q(x)} dx \\
&= - \int_{-\infty}^{\infty} p(x) \log \frac{q(x)}{p(x)} dx \\
&\geq - \log \int_{-\infty}^{\infty} p(x) \frac{q(x)}{p(x)} dx \\
&= - \log \int_{-\infty}^{\infty} q(x) dx \\
&= 0
\end{aligned}
$$

We first flip the fraction and pull out a negative sign, then apply Jensen’s inequality with the convex function $-\log$, which lets the $p(x)$ terms cancel inside the integral, and finally use the fact that $\log(1) = 0$.
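Finally, a numerical sketch of the nonnegativity result (my own, assuming NumPy): the discrete KL divergence is nonnegative across many random pairs of distributions and exactly zero when the two distributions coincide.

```python
import numpy as np

rng = np.random.default_rng(4)

def kl(p, q):
    """Discrete KL divergence D_KL[P || Q] in nats."""
    return np.sum(p * np.log(p / q))

# D_KL >= 0 for many randomly generated pairs of discrete distributions.
for _ in range(1000):
    p = rng.random(10); p /= p.sum()
    q = rng.random(10); q /= q.sum()
    assert kl(p, q) >= 0

# ...and D_KL = 0 when the distributions are identical.
print(kl(p, p))  # 0.0
```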

Conclusion

In summary, we used convexity to prove Jensen’s inequality, and Jensen’s inequality to prove that the KL divergence is always nonnegative, with equality when the two densities are equal. If we minimize the KL divergence between two densities, we are minimizing the relative information between the two distributions, which is why it is a sensible objective for density estimation.