Conjugacy in Bayesian Inference

Conjugacy is an important property in exact Bayesian inference. I work through Bishop's example of a beta conjugate prior for the binomial distribution and explore why conjugacy is useful.

In Bayesian inference, a prior $p(\theta)$ is conjugate to the likelihood function $p(x \mid \theta)$ when the posterior has the same functional form as the prior. This means that the two boxed terms in Bayes’ formula below have the same functional form:

$$ \boxed{p(\theta \mid x)} = \frac{p(x \mid \theta) \, \boxed{p(\theta)}}{\int p(x \mid \theta') \, p(\theta') \, \text{d} \theta'} $$

The goal of this post is to work through an example of a conjugate prior to better understand why conjugacy is a useful property.

As a running example, imagine that we have a coin with an unknown bias. Estimating the bias $\mu$ is statistical inference; Bayesian inference is assuming a prior on $\mu$; and conjugacy is assuming the prior is conjugate to the likelihood. We will explore these ideas in order.

Modeling a Bernoulli process

We want to estimate the bias $\mu$ of our coin. First, we flip the coin $n$ times. The outcome of the $i$th coin toss is $x_i$, a Bernoulli random variable that takes value $1$ with probability $\mu$ and $0$ with probability $(1 - \mu)$. Without loss of generality, $1$ is a success (heads) and $0$ is a failure (tails). The sequence of coin flips is a Bernoulli process of i.i.d. Bernoulli random variables, $x_1, x_2, \dots, x_n$. Let $\mathcal{D}$ be these $n$ coin flips, $\mathcal{D} = \{x_1, x_2, \dots, x_n\}$, and let $m \geq 0$ be the number of successes.

The probability of $m$ successes in $n$ trials is binomially distributed:

$$ p(m \mid n, \mu) = {n \choose m} \mu^m (1 - \mu)^{n-m} $$

In words, $\mu^m (1 - \mu)^{n-m}$ is the probability of $m$ successes and $n-m$ failures, all independent, in a single sequence of coin flips, and the binomial coefficient ${n \choose m}$ is the number of sequences of coin flips that can have $m$ successes and $n-m$ failures.
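As a quick sanity check, this probability can be computed directly with Python's standard library (the helper name `binomial_pmf` is my own):

```python
from math import comb

def binomial_pmf(m, n, mu):
    """Probability of m successes in n independent Bernoulli(mu) trials."""
    return comb(n, m) * mu**m * (1 - mu)**(n - m)

# For a fair coin, three heads in three flips happens with probability 1/8.
print(binomial_pmf(3, 3, 0.5))  # 0.125

# Summing over all possible success counts gives 1, as a pmf should.
print(sum(binomial_pmf(m, 10, 0.3) for m in range(11)))
```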

Being statisticians, we estimate $\mu$ by maximizing the likelihood of the data given the parameter, i.e. by computing

$$ p(\mathcal{D} \mid \mu) = p(m \mid n, \mu) = {n \choose m} \mu^m (1 - \mu)^{n-m} \tag{1} $$

where $\mu^m (1 - \mu)^{n-m}$ is the product $\prod_{i=1}^{n} \mu^{x_i} (1 - \mu)^{1 - x_i}$ of the individual flip probabilities, which factorizes because of our modeling assumption that coin flips are independent. We then take the log of this—maximizing the log of a function is equivalent to maximizing the function itself, and logs turn products into sums, letting us leverage the linearity of differentiation—to get

$$ \log p(\mathcal{D} \mid \mu) = \log {n \choose m} + m \log \mu + (n - m) \log(1 - \mu) $$

Finally, to solve for the value of $\mu$ that maximizes this function, we compute the derivative of $\log p(\mathcal{D} \mid \mu)$ with respect to $\mu$, set it equal to $0$, and solve for $\mu$. The derivative is

$$ \begin{aligned} \frac{\partial}{\partial \mu} \log p(\mathcal{D} \mid \mu) &= \frac{\partial}{\partial \mu} \log {n \choose m} + \frac{\partial}{\partial \mu} m \log \mu + \frac{\partial}{\partial \mu} (n - m) \log(1 - \mu) \\ &= \frac{m}{\mu} - \frac{n - m}{1-\mu} \end{aligned} $$

Note that the normalizer ${n \choose m}$ disappears because it does not depend on $\mu$. Solving for $\mu$ when the derivative is equal to $0$, we get

$$ \mu_{\text{ML}} = \frac{m}{n} $$
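Spelling out the algebra: setting the derivative to zero and cross-multiplying,

```latex
\begin{aligned}
\frac{m}{\mu} &= \frac{n - m}{1 - \mu} \\
m (1 - \mu) &= (n - m) \mu \\
m - m \mu &= n \mu - m \mu \\
\mu_{\text{ML}} &= \frac{m}{n}
\end{aligned}
```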

Now this works. But imagine that we flip the coin three times and each time it comes up heads. Assuming the coin is actually fair, this happens with probability $1/8$, but our maximum likelihood estimate of $\mu$ is $\frac{3}{3} = 1$. In other words, we are overfitting. One way to address this is by being Bayesian, meaning we want to place a prior probability on $\mu$. Rather than maximizing $\log p(\mathcal{D} \mid \mu)$, we want to maximize

$$ \log p(\mathcal{D} \mid \mu) + \log p(\mu) $$

Intuitively, imagine that most coins are fair. Then even if we see three heads in a row, we want to incorporate this prior knowledge about what $\mu$ typically is into our model. This is the role of a Bayesian prior $p(\mu)$.

A beta prior

What sort of prior should we place on $\mu$? Let’s make two modeling assumptions. First, let’s assume that most coins are fair. This means that $\mu = 0.5$ should be the mode of the distribution. Second, let’s assume that biased coins do not favor heads over tails or vice versa. This means we want a symmetric distribution. One distribution that can have both properties is the beta distribution, given by

$$ \text{Beta}(\mu \mid a, b) = \frac{\Gamma(a + b)}{\Gamma(a) \Gamma(b)} \mu^{a-1} (1-\mu)^{b-1} \tag{2} $$

where $\Gamma(x)$ is the gamma function

$$ \Gamma(x) = \int_{0}^{\infty} u^{x-1} e^{-u} \, \text{d} u $$

and $\frac{\Gamma(a + b)}{\Gamma(a) \Gamma(b)}$ is the normalizing constant. The beta distribution integrates to one,

$$ \int_{0}^{1} \text{Beta}(\mu \mid a, b) \, \text{d} \mu = 1 $$

and has a mean and variance given by

$$ \begin{aligned} \mathbb{E}[\mu] &= \frac{a}{a+b} \\ \text{Var}(\mu) &= \frac{ab}{(a+b)^2 (a+b+1)} \end{aligned} $$

The hyperparameters $a$ and $b$ (so named because they are not learned like the parameter $\mu$) control the shape of the distribution (Figure 1). Given our modeling assumptions, the hyperparameters $a = b = 2$ seem reasonable.
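We can sanity-check these moments against samples drawn with Python's standard-library beta sampler; this is a quick sketch, and the sample size is arbitrary:

```python
import random

a, b = 2, 2
mean = a / (a + b)                            # analytic mean: 0.5
var = a * b / ((a + b) ** 2 * (a + b + 1))    # analytic variance: 0.05

random.seed(0)
samples = [random.betavariate(a, b) for _ in range(100_000)]
sample_mean = sum(samples) / len(samples)
sample_var = sum((x - sample_mean) ** 2 for x in samples) / len(samples)

print(f"mean: analytic={mean:.3f}, sampled={sample_mean:.3f}")
print(f"var:  analytic={var:.3f}, sampled={sample_var:.3f}")
```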

Figure 1: The beta distribution for a variety of hyperparameters $a$ and $b$.

But another useful property of the beta distribution—and the reason we picked it over the more obvious Gaussian distribution as our prior—is that it is conjugate to our likelihood function. Let’s see this. Let $l = n - m$ be the number of failures. If we multiply our likelihood (Equation 1) by our prior (Equation 2), we get a posterior that has the same functional form as the prior:

$$ \begin{aligned} p(\mu \mid m, l, a, b) &\propto \frac{\Gamma(a+b)}{\Gamma(a) \Gamma(b)} \mu^{a-1} (1-\mu)^{b-1} \prod_{i=1}^{n} \mu^{x_i} (1 - \mu)^{1 - x_i} \\ &\propto \mu^{m+a-1} (1-\mu)^{l+b-1} \end{aligned} $$

Note that we only care about proportionality because $\Gamma(a+b)\,/\,\Gamma(a) \Gamma(b)$ is a constant, independent of the parameter we want to learn and of our data. We can see that our posterior is another beta distribution, and we can easily normalize it:

$$ p(\mu \mid m, l, a, b) = \frac{\Gamma(m+a+l+b)}{\Gamma(m+a)\Gamma(l+b)} \mu^{m+a-1} (1-\mu)^{l+b-1} \tag{3} $$

Finally, we now want to maximize Equation 3 rather than Equation 1. Recall that this is maximum a posteriori (MAP) estimation because, unlike maximum likelihood estimation, we account for a prior. Because this new posterior is in a tractable form, it is straightforward to compute $\mu_{\text{MAP}}$. Let’s first compute the derivative of the log posterior with respect to $\mu$:

$$ \begin{aligned} \log p(\mu \mid m, l, a, b) &= \log(C) + (m + a - 1) \log(\mu) + (l + b - 1) \log(1 - \mu) \\ \frac{\partial}{\partial \mu} \log p(\mu \mid m, l, a, b) &= \frac{m + a - 1}{\mu} - \frac{l + b - 1}{1 - \mu} \end{aligned} $$

where $C$ is the normalizing constant from Equation 3, which, like ${n \choose m}$ in the previous section, does not depend on $\mu$. Setting this equal to $0$ and doing some algebra, we get

$$ \mu_{\text{MAP}} = \frac{m + a - 1}{n + a + b - 2} \tag{4} $$
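The algebra mirrors the maximum likelihood case; setting the derivative to zero, cross-multiplying, and substituting $n = m + l$,

```latex
\begin{aligned}
\frac{m + a - 1}{\mu} &= \frac{l + b - 1}{1 - \mu} \\
(m + a - 1)(1 - \mu) &= (l + b - 1) \mu \\
m + a - 1 &= (m + l + a + b - 2) \mu \\
\mu_{\text{MAP}} &= \frac{m + a - 1}{n + a + b - 2}
\end{aligned}
```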

Note that if $n = m = 0$, then $\mu_{\text{MAP}} = \frac{1}{2}$. In words, if we can’t flip a coin to estimate its bias, then the best we can do is assume the bias is the mode of our prior. And recall our small pathological example from before, the scenario when both $n = 3$ and $m = 3$. With our prior with hyperparameters $a = b = 2$, we have

$$ \mu_{\text{MAP}} = \frac{3 + 2 - 1}{3 + 2 + 2 - 2} = \frac{4}{5}. $$

This demonstrates why the prior is especially important for parameter estimation with small data and how it helps prevent overfitting.
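To make this concrete in code, here is a small sketch (the function names are my own) contrasting the two estimators on the three-heads example with $a = b = 2$:

```python
def mu_ml(m, n):
    """Maximum likelihood estimate of the coin's bias."""
    return m / n

def mu_map(m, n, a, b):
    """MAP estimate under a Beta(a, b) prior (Equation 4)."""
    return (m + a - 1) / (n + a + b - 2)

# Three flips, three heads: maximum likelihood overfits to 1,
# while the Beta(2, 2) prior pulls the estimate back toward 1/2.
print(mu_ml(3, 3))         # 1.0
print(mu_map(3, 3, 2, 2))  # 0.8

# With no data at all, the MAP estimate falls back to the prior's mode.
print(mu_map(0, 0, 2, 2))  # 0.5
```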

Benefits of conjugacy

I want to discuss two main benefits of conjugacy. The first is analytic tractability. Computing $\mu_{\text{MAP}}$ was easy because of conjugacy. Imagine if our prior on $\mu$ were a normal distribution. We would have had to optimize

$$ p(\mu \mid m, l, \sigma^2, \nu) \propto \frac{1}{\sqrt{2 \pi \sigma^2}} \exp \Big(\frac{-(\mu - \nu)^2}{2 \sigma^2}\Big) \prod_{i=1}^{n} \mu^{x_i} (1 - \mu)^{1 - x_i} $$

where $\sigma^2$ and $\nu$ are hyperparameters for the normal distribution. In the absence of techniques such as variational inference, conjugacy makes our lives easier.

The second benefit of conjugacy is that it lends itself nicely to sequential learning. In other words, as the model sees more data, the posterior at step $t$ becomes the prior at step $t+1$. We simply update our prior and re-normalize. For example, imagine we process individual coin flips one at a time. Every time we see a heads ($x_i = 1$), we increment $n$ and $m$. Otherwise, we increment $n$ and $l$. Alternatively, we could fix $n = m = l = 0$ and just increment $a$ for each head and $b$ for each tail (Equation 4).
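A minimal sketch of this sequential update, assuming a true bias of $0.8$ as in Figure 2 (the flip count and seed are arbitrary):

```python
import random

random.seed(1)
true_mu = 0.8    # bias of the coin generating the data
a, b = 2, 2      # Beta(2, 2) prior: most coins are roughly fair

for _ in range(30):
    head = random.random() < true_mu
    if head:
        a += 1   # one more observed success
    else:
        b += 1   # one more observed failure

# The posterior after 30 flips is Beta(a, b); its mode is the MAP estimate.
print(a, b)
print((a - 1) / (a + b - 2))
```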

Using this technique, we can visualize the posterior over sequential observations (Figure 2).

Figure 2: Visualizing the posterior distribution (Equation 3) for $n$ samples in $\{0, 10, 20, 30\}$. As the model sees more data, the posterior places more density around the true bias from the generative process. And after each data point, the posterior becomes the new prior.

The upshot is that the posterior distribution becomes more and more peaked around the true bias as our model sees more data. Note that the $y$-axes are at different scales, and therefore the last frame is even more peaked than the second-to-last frame. Also note that the posterior slightly underestimates the true parameter $\mu = 0.8$, possibly because of the influence of the prior.

Conclusion

Conjugate priors are an important concept in Bayesian inference. Especially when one wants to perform exact Bayesian inference, conjugacy ensures that the posterior remains tractable after multiplying the likelihood by the prior. Conjugate priors also allow for efficient inference algorithms because the posterior and prior share the same functional form.

As a final comment, note that conjugacy is with respect to a particular parameter. For example, the conjugate prior of a Gaussian with respect to its mean parameter $\mu$ is another Gaussian, but the conjugate prior with respect to its variance $\sigma^2$ is the inverse gamma. The conjugate prior with respect to the multivariate Gaussian’s covariance matrix is the inverse-Wishart, while the Wishart is the conjugate prior for its precision matrix (Wikipedia, 2019). In other words, the conjugate prior depends on the parameter of interest and what form that parameter takes.


Acknowledgements

I borrowed some of this post’s outline and notation from Bishop’s excellent introduction to conjugacy (Bishop, 2006). See pages 68–74 specifically.

  1. Wikipedia. (2019). Conjugate prior. URL: https://en.wikipedia.org/wiki/Conjugate_prior
  2. Bishop, C. M. (2006). Pattern Recognition and Machine Learning.