A Python Demonstration that Mutual Information Is Symmetric
I provide a numerical demonstration that the mutual information between the observations and the latent variables of a Gaussian mixture model is symmetric.
In a previous post, I showed mathematically that mutual information (MI) is a symmetric quantity. In this post, I want to work through a complete example of computing MI: a Gaussian mixture model with observations $\mathbf{x}$ and latent variables $\mathbf{z}$. I'll show that $I(\mathbf{x}; \mathbf{z}) = I(\mathbf{z}; \mathbf{x})$ using quadrature.
Gaussian mixture model
Recall that the probabilistic model for a mixture of Gaussians is

$$
p(\mathbf{x}) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k), \tag{1}
$$

where $\pi_1, \dots, \pi_K$ are mixture weights, and $\boldsymbol{\mu}_k$ and $\boldsymbol{\Sigma}_k$ are means and covariances respectively. In words, our data are modeled as a linear superposition of $K$ Gaussian distributions. See (Bishop, 2006) for details. To index mixture components, we introduce a one-hot $K$-vector $\mathbf{z}$ such that

$$
p(z_k = 1) = \pi_k, \qquad p(\mathbf{x} \mid z_k = 1) = \mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k). \tag{2}
$$
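To make the generative story concrete, here is a minimal sketch of ancestral sampling from a one-dimensional mixture. The weights, means, and standard deviations below are made up for illustration, not fitted to any data: first draw the hot index of $\mathbf{z}$ from the weights, then draw $x$ from the selected Gaussian.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical mixture parameters, for illustration only.
pis  = np.array([0.5, 0.3, 0.2])   # mixture weights (sum to 1)
mus  = np.array([-2.0, 0.0, 3.0])  # component means
sigs = np.array([0.5, 1.0, 0.8])   # component standard deviations

def sample_gmm(n):
    # Draw the latent component index (the "hot" entry of z)...
    ks = rng.choice(len(pis), size=n, p=pis)
    # ...then draw x from the chosen Gaussian component.
    return ks, rng.normal(mus[ks], sigs[ks])

ks, xs = sample_gmm(10_000)
# Empirical component frequencies should approximate the weights.
freqs = np.bincount(ks) / len(ks)
print(freqs)
```

With enough samples, the histogram of `xs` approximates the mixture density in Eq. 1.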
The goal of this post is to empirically demonstrate that the mutual information between $\mathbf{x}$ and $\mathbf{z}$ is symmetric, i.e. that Eqs. 3 and 4 are equal:

$$
I(\mathbf{x}; \mathbf{z}) = H(\mathbf{x}) - H(\mathbf{x} \mid \mathbf{z}), \tag{3}
$$

$$
I(\mathbf{z}; \mathbf{x}) = H(\mathbf{z}) - H(\mathbf{z} \mid \mathbf{x}), \tag{4}
$$

where $H(\cdot)$ denotes entropy. Notice that we can compute two of these terms exactly: $H(\mathbf{z})$, because $\mathbf{z}$ is a discrete random variable (so expectations are sums), and $H(\mathbf{x} \mid \mathbf{z})$, because we have a closed-form solution for the entropy of the Gaussian $\mathcal{N}(\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)$. However, we have to approximate the other two terms, $H(\mathbf{x})$ and $H(\mathbf{z} \mid \mathbf{x})$, which we will do using numerical quadrature.
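For reference, the closed-form entropy mentioned above is, for a univariate Gaussian, $H = \frac{1}{2} \log(2 \pi e \sigma^2)$. A quick sketch checking this formula against SciPy's built-in entropy, with an arbitrary $\sigma$ chosen for illustration:

```python
import numpy as np
from scipy.stats import norm

sigma = 1.7  # arbitrary standard deviation, for illustration only

# Closed-form differential entropy of N(mu, sigma^2).
closed_form = 0.5 * np.log(2 * np.pi * np.e * sigma ** 2)

# SciPy's built-in computation for a frozen normal distribution.
scipy_entropy = norm(0.0, sigma).entropy()

print(closed_form, scipy_entropy)  # the two values agree
```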
Let's approximate Eq. 3 and then Eq. 4 and show that we get the same result.
Approximating Eq. 3
Before we get started, let's load some data and fit our mixture model. We have to fit a model because we're conditioning on $\mathbf{z}$:
import numpy as np
from sklearn.datasets import load_iris
from sklearn.mixture import GaussianMixture

X, _ = load_iris(return_X_y=True)
X = X[:, 0][:, None]  # Let's stick to 1-dimensional data.

K = 3
gmm = GaussianMixture(n_components=K)
gmm = gmm.fit(X)

mus = gmm.means_
sigs = np.sqrt(gmm.covariances_)
pis = gmm.weights_
Z = gmm.predict(X)
To compute the right-hand side (RHS) of Eq. 3,

$$
H(\mathbf{x} \mid \mathbf{z}) = \sum_{k=1}^{K} \pi_k \, H[\mathcal{N}(\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)],
$$

we simply sum over the entropies of each Gaussian likelihood, weighted by the mixture weights. Notice that $p(z_k = 1)$ is just the inferred $\pi_k$, and the entropy of $p(\mathbf{x} \mid z_k = 1)$ is just the entropy of the Gaussian $\mathcal{N}(\mu_k, \sigma_k)$:
from scipy.stats import norm
def rhs3_exact():
    return np.sum([
        pis[k] * norm(mus[k], sigs[k]).entropy()
        for k in range(K)
    ])
However, we need to approximate the left-hand side (LHS) of Eq. 3, the marginal entropy $H(\mathbf{x})$, which we'll do using quadrature. Notice that the marginal distribution $p(\mathbf{x})$ is just Eq. 1. We then plug this value into the formula for differential entropy, $H(\mathbf{x}) = -\int p(x) \log p(x) \, dx$. In code, this is:
from scipy.integrate import quad
def lhs3_quadrature():

    def marginal(x):
        return np.sum([
            pis[k] * norm(mus[k], sigs[k]).pdf(x)
            for k in range(K)
        ])

    def marginal_entropy(x):
        px = marginal(x)
        return -1 * px * np.log(px + 1e-9)

    # Integrate over every possible value of x.
    lhs, _ = quad(marginal_entropy, -np.inf, np.inf)
    return lhs
The plus 1e-9 in the log calculation is for numerical stability when px is very small.
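Before trusting the entropy quadrature, a cheap sanity check is that the mixture marginal itself integrates to one over the real line. Here is a self-contained sketch of that check, using made-up parameters rather than the fitted gmm values:

```python
import numpy as np
from scipy.integrate import quad
from scipy.stats import norm

# Hypothetical mixture parameters, for illustration only.
pis  = np.array([0.5, 0.3, 0.2])
mus  = np.array([-2.0, 0.0, 3.0])
sigs = np.array([0.5, 1.0, 0.8])

def marginal(x):
    # p(x) = sum_k pi_k N(x | mu_k, sig_k), as in Eq. 1.
    return np.sum([
        pis[k] * norm(mus[k], sigs[k]).pdf(x)
        for k in range(len(pis))
    ])

# A valid density must integrate to 1 over the whole real line.
total, _ = quad(marginal, -np.inf, np.inf)
print(total)
```

If this integral is far from one, the entropy estimate cannot be trusted either.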
Approximating Eq. 4
For Eq. 4, the LHS is easy; it is the entropy of a categorical random variable, $H(\mathbf{z}) = -\sum_{k=1}^{K} \pi_k \log \pi_k$:
from scipy.stats import multinomial
def lhs4_exact():
    return multinomial(1, pis).entropy()
Finally, we just have to compute the RHS of Eq. 4, the expected entropy of the posterior over $\mathbf{z}$,

$$
H(\mathbf{z} \mid \mathbf{x}) = \int p(x) \, H[p(\mathbf{z} \mid x)] \, dx.
$$

Notice that we can easily compute the posterior of $\mathbf{z}$ for a single observation $x$ using Bayes' rule:

$$
p(z_k = 1 \mid x) = \frac{\pi_k \, \mathcal{N}(x \mid \mu_k, \sigma_k)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x \mid \mu_j, \sigma_j)}.
$$

The only tricky part is that when performing quadrature, small values of $p(x \mid z_k = 1)$ will result in very small probabilities that do not normalize properly. Instead, let's work in log space and then use the log-sum-exp trick to normalize the posterior:
from scipy.special import logsumexp
def rhs4_quadrature():

    def marginal(x):
        return np.sum([
            pis[k] * norm(mus[k], sigs[k]).pdf(x)
            for k in range(K)
        ])

    def z_posterior(x):
        log_p = np.empty(K)
        for k in range(K):
            log_p[k] = np.log(pis[k]) + norm(mus[k], sigs[k]).logpdf(x)
        # Log-sum-exp trick.
        z_post = np.exp(log_p - logsumexp(log_p))
        assert np.isclose(z_post.sum(), 1)
        return z_post

    def exp_entropy_z(x):
        mx = marginal(x)
        zp = z_posterior(x)
        return -1 * mx * np.sum(zp * np.log(zp + 1e-9))

    # Integrate over every possible value of x.
    rhs, _ = quad(exp_entropy_z, -np.inf, np.inf)
    return rhs
Putting everything together, we see that the two approximations produce very similar mutual information estimates. On a single run on my machine, I get:
print(lhs3_quadrature() - rhs3_exact()) # 0.969247379948003
print(lhs4_exact() - rhs4_quadrature()) # 0.9692473862194966
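As an extra, independent sanity check (not part of the quadrature comparison above), Eq. 4 can also be estimated by Monte Carlo: sample $x$ from the mixture and average the entropy of the posterior $p(\mathbf{z} \mid x)$. This self-contained sketch uses made-up parameters rather than the fitted model:

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

# Hypothetical mixture parameters, for illustration only.
pis  = np.array([0.5, 0.3, 0.2])
mus  = np.array([-2.0, 0.0, 3.0])
sigs = np.array([0.5, 1.0, 0.8])

# Ancestral sampling: draw the component index, then x given the component.
ks = rng.choice(len(pis), size=50_000, p=pis)
xs = rng.normal(mus[ks], sigs[ks])

# Posterior p(z | x) for every sample (rows), normalized row-wise.
dens = pis * norm(mus, sigs).pdf(xs[:, None])   # shape (n, K)
post = dens / dens.sum(axis=1, keepdims=True)

h_z = -np.sum(pis * np.log(pis))                # H(z), exact
h_z_given_x = np.mean(-np.sum(post * np.log(post + 1e-12), axis=1))

mi_mc = h_z - h_z_given_x                       # Monte Carlo estimate of Eq. 4
print(mi_mc)
```

The Monte Carlo estimate should land close to what quadrature gives for the same parameters, and it is bounded between 0 and $H(\mathbf{z})$.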
Conclusion
The symmetry of mutual information is nice because it may be easier to compute one formulation than the other. In this example, we might prefer to integrate over the predictive distribution in Eq. 3 or over the posterior in Eq. 4, depending on which is cheaper. This is the main idea behind predictive entropy search (Hernández-Lobato et al., 2014), which I've discussed previously.
- Bishop, C. M. (2006). Pattern Recognition and Machine Learning.
- Hernández-Lobato, J. M., Hoffman, M. W., & Ghahramani, Z. (2014). Predictive entropy search for efficient global optimization of black-box functions. Advances in Neural Information Processing Systems, 918–926.