Multivariate beta distribution
The beta distribution is a family of continuous probability distributions on the interval [0,1]. Because of this, it is often used as a prior for probabilities. For example, a draw from a beta distribution can be viewed as a random probability for a Bernoulli random variable. The beta distribution has two parameters that control the shape of the distribution (Figure 1). These are typically denoted with α and β, but, for reasons we’ll see in a moment, I’ll use α1 and α2 instead.
The probability density function (PDF) of the beta distribution is:
$$
\text{beta}(x \mid \alpha_1, \alpha_2) = \frac{1}{B(\alpha_1, \alpha_2)} \, x^{\alpha_1 - 1} (1 - x)^{\alpha_2 - 1}, \qquad \alpha_1, \alpha_2 > 0.
\tag{1}
$$
For now, let’s ignore the normalizing constant, a beta function. How can we interpret the expression xα1−1(1−x)α2−1? Here’s how I think about it. First, fix α1=α2=1. Anything raised to the zero power is 1, so the PDF is 1 everywhere. Now imagine we increase α1 while keeping α2 fixed at 1. We are left with the power function xα1−1, an increasing curve on [0,1]. If we instead increase α2 while keeping α1 fixed at 1, we are left with (1−x)α2−1, a decaying curve. What happens if both α1 and α2 are greater than 1? The two curves, one increasing and the other decaying, mix. If α1=α2, their peak is at x=0.5; otherwise, there will be an asymmetry. Finally, if α1 and α2 are both less than one, the density decays and then rises again, forming a U shape (Figure 1).
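To make these shapes concrete, here is a minimal Python sketch, assuming NumPy, SciPy, and Matplotlib are available, that plots beta densities for a few hand-picked (α1, α2) pairs. The specific values are illustrative choices, not necessarily the ones behind Figure 1.

```python
import numpy as np
from scipy.stats import beta
import matplotlib.pyplot as plt

x = np.linspace(0.001, 0.999, 500)

# (alpha_1, alpha_2) pairs illustrating the shapes discussed above.
params = [
    (1, 1),      # flat: the PDF is 1 everywhere
    (5, 1),      # increasing curve
    (1, 5),      # decaying curve
    (5, 5),      # symmetric peak at x = 0.5
    (2, 5),      # asymmetric peak
    (0.5, 0.5),  # both less than one: U-shaped
]

fig, ax = plt.subplots()
for a1, a2 in params:
    ax.plot(x, beta.pdf(x, a1, a2), label=f"a1={a1}, a2={a2}")
ax.set_xlabel("x")
ax.set_ylabel("density")
ax.legend()
plt.show()
```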
The value x is really one coordinate of a point on a simplex, a generalization of the notion of a triangle to arbitrary dimensions. Geometrically, all points on a simplex are convex combinations of its vertices. In this case, the pair (x, 1−x) lives on a 1-simplex: the set of convex combinations of the points (0,1) and (1,0).
Now imagine we could take the curves in Figure 1 and extrude them out of the screen to form a multivariate density. The support of this density should now be a 2-simplex, or any point in the convex hull defined by points (0,0,1), (0,1,0), and (1,0,0). We can visualize this by indicating the height with a color gradient (Figure 2).
Whatever distribution this is should have three parameters; let’s call them α1, α2, and α3. Without knowing the normalizing constant, we can guess that the PDF should look like this:
$$
(x_1)^{\alpha_1 - 1} (x_2)^{\alpha_2 - 1} (1 - x_1 - x_2)^{\alpha_3 - 1}.
\tag{2}
$$
These shape parameters must also be positive. What we’re constructing is just a multivariate beta distribution, which has its own name: the Dirichlet distribution. Values from a K-dimensional Dirichlet distribution live on a (K−1)-simplex. The general PDF of the Dirichlet distribution is
$$
\text{Dir}(\mathbf{x} \mid \boldsymbol{\alpha}) = \frac{1}{B(\boldsymbol{\alpha})} \, x_1^{\alpha_1 - 1} \dots x_K^{\alpha_K - 1}, \qquad \alpha_1, \dots, \alpha_K > 0.
\tag{3}
$$
We now define the normalizing constant,
$$
B(\boldsymbol{\alpha}) = \frac{\Gamma(\alpha_1) \dots \Gamma(\alpha_K)}{\Gamma(\alpha_1 + \dots + \alpha_K)}.
\tag{4}
$$
While this may look abstract, there is some intuition. The gamma function is an extension of the factorial function to complex numbers. This is why, for any positive integer z,
$$
\Gamma(z) = (z - 1)!
\tag{5}
$$
I think of the normalizing constant, the beta function, as playing a similar role as the binomial coefficient (“n choose k”) in the binomial distribution. When computing the probability of k successes in n coin flips, we must also do some bookkeeping since there are possibly many different ways to get k successes.
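As a quick numerical sanity check of Eqs. 3 and 4, here is a small sketch, assuming SciPy is available, that builds the Dirichlet log-density by hand, normalizing the kernel with the beta function computed from log-gamma functions, and compares it against SciPy's own implementation. The parameter values are arbitrary.

```python
import numpy as np
from scipy.special import gammaln
from scipy.stats import dirichlet

def log_beta_fn(alpha):
    """Log of the multivariate beta function (Eq. 4), computed in log space."""
    alpha = np.asarray(alpha, dtype=float)
    return gammaln(alpha).sum() - gammaln(alpha.sum())

alpha = np.array([2.0, 3.0, 4.0])
theta = np.array([0.2, 0.3, 0.5])  # a point on the 2-simplex

# Unnormalized Dirichlet kernel from Eq. 3, then normalized by B(alpha).
log_kernel = np.sum((alpha - 1) * np.log(theta))
log_pdf = log_kernel - log_beta_fn(alpha)

# Should match SciPy's implementation of the Dirichlet log-density.
assert np.isclose(log_pdf, dirichlet.logpdf(theta, alpha))
print(np.exp(log_pdf))
```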
Finally, one interesting fact about the Dirichlet distribution is that its marginals are beta distributions. At least visually, this is intuitive. Take the 2-simplex in Figure 2 and collapse it to any side. We would get a 1-simplex controlled by the associated end points. While tedious, it is fairly straightforward to compute this. First, let A=∑kαk. Then we can write the Dirichlet PDF as:
$$
\begin{aligned}
&\frac{\Gamma(A)}{\Gamma(\alpha_1)\,\Gamma(A - \alpha_1)}\,\theta_1^{\alpha_1 - 1}\,{\color{blue}(1 - \theta_1)^{A - \alpha_1 - 1}}
\\
&\times \frac{\Gamma(A - \alpha_1)}{\Gamma(\alpha_2)\,\Gamma(A - \alpha_1 - \alpha_2)\,{\color{blue}(1 - \theta_1)^{A - \alpha_1 - 1}}}\,\theta_2^{\alpha_2 - 1}\,{\color{blue}(1 - \theta_1 - \theta_2)^{A - \alpha_1 - \alpha_2 - 1}}
\\
&\times \frac{\Gamma(A - \alpha_1 - \alpha_2)}{\Gamma(\alpha_3)\,\Gamma(A - \alpha_1 - \alpha_2 - \alpha_3)\,{\color{blue}(1 - \theta_1 - \theta_2)^{A - \alpha_1 - \alpha_2 - 1}}}\,\theta_3^{\alpha_3 - 1}\,{\color{blue}(1 - \theta_1 - \theta_2 - \theta_3)^{A - \alpha_1 - \alpha_2 - \alpha_3 - 1}}
\\
&\;\;\vdots
\\
&\times \frac{\Gamma(A - \alpha_1 - \dots - \alpha_{K-2})}{\Gamma(\alpha_{K-1})\,\Gamma(A - \alpha_1 - \dots - \alpha_{K-1})\,{\color{blue}(1 - \theta_1 - \dots - \theta_{K-2})^{\alpha_{K-1} + \alpha_K - 1}}}\,\theta_{K-1}^{\alpha_{K-1} - 1}\,\theta_K^{\alpha_K - 1}
\end{aligned}
\tag{6}
$$
The terms colored in blue will cancel, leaving us with Eq. 3. However, notice that we can also write the Dirichlet joint distribution as
$$
p(\boldsymbol{\theta}) = p(\theta_1) \, p(\theta_2 \mid \theta_1) \, p(\theta_3 \mid \theta_1, \theta_2) \dots p(\theta_{K-1} \mid \theta_1, \dots, \theta_{K-2}).
\tag{7}
$$
We don’t have to specify the final probability because it is fixed given the other values of θ. Matching terms in Eq. 6 with Eq. 7, in particular the first line of Eq. 6 with the first factor of Eq. 7, we see that the marginal distribution p(θ1) is a beta distribution:
$$
p(\theta_1) = \text{beta}(\alpha_1, A - \alpha_1).
\tag{8}
$$
Clearly, we could rewrite Eq. 6 and factorize Eq. 7 to put the j-th term first. So in fact, the result in Eq. 8 holds for any parameter θj.
Without computation, I think we can intuit what the remainder—the rest of the joint distribution in Eq. 7—should be. Imagine we have a K-simplex, and we could “remove” an edge. That edge or marginal would be beta distributed, but the remainder would be a (K−1)-simplex, or another Dirichlet distribution.
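We can also check the marginal property numerically. The sketch below, assuming SciPy is available, draws many samples from a Dirichlet, keeps only the first component, and compares it against the beta distribution in Eq. 8 with a Kolmogorov–Smirnov test. The parameters and sample size are arbitrary choices.

```python
import numpy as np
from scipy.stats import dirichlet, beta, kstest

alpha = np.array([2.0, 3.0, 4.0, 1.5])
A = alpha.sum()

# Draw from a K-dimensional Dirichlet and keep only the first component.
samples = dirichlet.rvs(alpha, size=100_000, random_state=0)
theta_1 = samples[:, 0]

# Eq. 8 says theta_1 should be beta(alpha_1, A - alpha_1) distributed. The
# KS statistic against that beta should be tiny, and the test should not reject.
stat, p_value = kstest(theta_1, beta(alpha[0], A - alpha[0]).cdf)
print(f"KS statistic: {stat:.4f}, p-value: {p_value:.3f}")
```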
Multinomial–Dirichlet distribution
Now that we better understand the Dirichlet distribution, let’s derive the posterior, marginal likelihood, and posterior predictive distributions for a very popular model: a multinomial model with a Dirichlet prior. These derivations will be very similar to my post on Bayesian inference for beta–Bernoulli models. Why? Well, a Bernoulli random variable can be thought of as modeling one coin flip with bias θ. A binomial random variable is then flipping the same coin M times. And a multinomial random variable is rolling a K-sided die M times. My point is that these models have a common structure. Furthermore, as we just saw, the Dirichlet distribution is just the multivariate beta distribution.
We assume our data X={x1,…,xN} are multinomially distributed with a Dirichlet prior:
$$
\begin{aligned}
\mathbf{x}_n &\sim \text{multi}(M, \boldsymbol{\theta}), \qquad \sum_{k=1}^{K} \theta_k = 1,
\\
\boldsymbol{\theta} &\sim \text{Dir}(\boldsymbol{\alpha}), \qquad \alpha_1, \dots, \alpha_K > 0.
\end{aligned}
\tag{9}
$$
Recall that by definition of multinomial random variables, ∑kxn,k=M.
Posterior. Showing conjugacy by deriving the posterior is relatively easy. The likelihood times prior is
$$
\begin{aligned}
p(\boldsymbol{\theta} \mid \mathbf{X}) &\propto \prod_{n=1}^{N} p(\mathbf{x}_n \mid \boldsymbol{\theta}) \, p(\boldsymbol{\theta})
\\
&= \prod_{n=1}^{N} \left( \frac{M!}{x_{n,1}! \dots x_{n,K}!} \, \theta_1^{x_{n,1}} \dots \theta_K^{x_{n,K}} \right) \left( \frac{1}{B(\boldsymbol{\alpha})} \, \theta_1^{\alpha_1 - 1} \dots \theta_K^{\alpha_K - 1} \right)
\\
&\propto \theta_1^{\left(\sum_n x_{n,1} + \alpha_1 - 1\right)} \dots \theta_K^{\left(\sum_n x_{n,K} + \alpha_K - 1\right)}.
\end{aligned}
\tag{10}
$$
So we see the posterior is proportional to a Dirichlet distribution:
$$
\begin{aligned}
p(\boldsymbol{\theta} \mid \mathbf{X}) &= \text{Dir}(\boldsymbol{\theta} \mid \boldsymbol{\alpha}_N),
\\
\alpha_{N,k} &= \sum_{n=1}^{N} x_{n,k} + \alpha_k.
\end{aligned}
\tag{11}
$$
The other terms, namely the normalizers, do not matter because they do not depend on θ. This is just a multivariate generalization of our derivation for the beta–Bernoulli model. Furthermore, it has the same intuition. As we observe more counts of outcome k, the k-th component of αN increases, placing more mass at that location on the simplex.
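Here is a minimal sketch of the conjugate update in Eq. 11, assuming NumPy is available. We simulate multinomial data with a known θ, add the column-wise counts to the prior parameters, and check that the posterior mean approaches the truth. All specific values are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

K, M, N = 4, 10, 500
alpha = np.ones(K)                            # symmetric Dirichlet prior
theta_true = np.array([0.1, 0.2, 0.3, 0.4])

# Simulate N multinomial observations, each with M trials.
X = rng.multinomial(M, theta_true, size=N)    # shape (N, K)

# Eq. 11: the posterior is Dirichlet with alpha_N = per-component counts + alpha.
alpha_N = X.sum(axis=0) + alpha

# Posterior mean of theta; it should approach theta_true as N grows.
print(alpha_N / alpha_N.sum())
```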
Marginal likelihood. To compute the marginal likelihood, we need the integral definition of the multivariate beta function:
$$
B(\boldsymbol{\alpha}) = \int_0^1 \int_0^{1 - \theta_1} \dots \int_0^{1 - \theta_1 - \dots - \theta_{K-2}} \theta_1^{\alpha_1 - 1} \theta_2^{\alpha_2 - 1} \dots \theta_{K-1}^{\alpha_{K-1} - 1} \theta_K^{\alpha_K - 1} \, d\theta_{K-1} \dots d\theta_1.
\tag{12}
$$
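Before using this, here is a quick numerical check of Eq. 12 for K=3, assuming SciPy is available: we integrate the Dirichlet kernel over the 2-simplex and compare the result to the closed-form beta function in Eq. 4. The parameter values are arbitrary.

```python
import numpy as np
from scipy.integrate import dblquad
from scipy.special import gammaln

# Check Eq. 12 for K = 3: integrate the Dirichlet kernel over the 2-simplex
# (with theta_3 = 1 - theta_1 - theta_2) and compare against Eq. 4.
alpha = np.array([2.0, 3.0, 1.5])

def kernel(theta_2, theta_1):
    theta_3 = 1.0 - theta_1 - theta_2
    return (theta_1 ** (alpha[0] - 1)
            * theta_2 ** (alpha[1] - 1)
            * theta_3 ** (alpha[2] - 1))

# Inner integral over theta_2 in [0, 1 - theta_1]; outer over theta_1 in [0, 1].
integral, _ = dblquad(kernel, 0.0, 1.0, lambda t1: 0.0, lambda t1: 1.0 - t1)
closed_form = np.exp(gammaln(alpha).sum() - gammaln(alpha.sum()))
print(integral, closed_form)  # should agree to several decimal places
```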
If this definition is a lot, I suggest first writing out the definition of the beta integral for B(α1,α2) and then extending it to three dimensions, keeping in mind that the sum of θk terms is one. With this definition in mind, let’s integrate over Eq. 10 without dropping the normalizing constants,
$$
\begin{aligned}
p(\mathbf{X}) &= \int p(\mathbf{X} \mid \boldsymbol{\theta}) \, p(\boldsymbol{\theta}) \, d\boldsymbol{\theta}
\\
&= \int \prod_{n=1}^{N} p(\mathbf{x}_n \mid \boldsymbol{\theta}) \, p(\boldsymbol{\theta}) \, d\boldsymbol{\theta}
\\
&= \int \prod_{n=1}^{N} \left( \frac{M!}{x_{n,1}! \dots x_{n,K}!} \, \theta_1^{x_{n,1}} \dots \theta_K^{x_{n,K}} \right) \left( \frac{1}{B(\boldsymbol{\alpha})} \, \theta_1^{\alpha_1 - 1} \dots \theta_K^{\alpha_K - 1} \right) d\boldsymbol{\theta}
\\
&= \left( \prod_{n=1}^{N} \frac{M!}{x_{n,1}! \dots x_{n,K}!} \right) \frac{1}{B(\boldsymbol{\alpha})} \int \theta_1^{\left(\sum_n x_{n,1} + \alpha_1 - 1\right)} \dots \theta_K^{\left(\sum_n x_{n,K} + \alpha_K - 1\right)} \, d\boldsymbol{\theta}
\\
&= \left( \prod_{n=1}^{N} \frac{M!}{x_{n,1}! \dots x_{n,K}!} \right) \frac{B(\boldsymbol{\alpha}_N)}{B(\boldsymbol{\alpha})}.
\end{aligned}
\tag{13}
$$
We can verify this by comparing it to the probability mass function of the Dirichlet–multinomial compound distribution. Let N=1, so that αN,k=xk+αk. Then Eq. 13 is
$$
\begin{aligned}
p(\mathbf{x}) &= \frac{M!}{x_1! \dots x_K!} \, \frac{B(\boldsymbol{\alpha}_N)}{B(\boldsymbol{\alpha})}
\\
&= \frac{M!}{x_1! \dots x_K!} \, \frac{\Gamma\left(\sum_k \alpha_k\right)}{\prod_k \Gamma(\alpha_k)} \, \frac{\prod_{k=1}^{K} \Gamma(x_k + \alpha_k)}{\Gamma\left(\sum_k (x_k + \alpha_k)\right)}
\\
&= \frac{(M!) \, \Gamma\left(\sum_k \alpha_k\right)}{\Gamma\left(M + \sum_k \alpha_k\right)} \prod_{k=1}^{K} \frac{\Gamma(x_k + \alpha_k)}{(x_k!) \, \Gamma(\alpha_k)}.
\end{aligned}
\tag{14}
$$
This is messy in that it is not an easy distribution to reason about, at least for me. But again, we can imagine that as more observations contain counts of the k-th outcome, the marginal likelihood will shift mass appropriately.
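To convince ourselves that the algebra works out, here is a small sketch, assuming SciPy is available, that computes the log marginal likelihood in Eq. 13 and then checks that, for N=1, the probabilities in Eq. 14 sum to one over all count vectors with ∑k xk = M. The helper names and parameter values are my own.

```python
import itertools
import numpy as np
from scipy.special import gammaln

def log_beta_fn(alpha):
    """Log of the multivariate beta function (Eq. 4)."""
    return gammaln(alpha).sum() - gammaln(np.sum(alpha))

def log_marginal_likelihood(X, alpha, M):
    """Log of Eq. 13 for count data X with shape (N, K)."""
    log_multi_coefs = gammaln(M + 1) - gammaln(X + 1).sum(axis=1)
    alpha_N = X.sum(axis=0) + alpha
    return log_multi_coefs.sum() + log_beta_fn(alpha_N) - log_beta_fn(alpha)

# Sanity check with N = 1: Eq. 14 is a PMF over count vectors x with
# sum_k x_k = M, so the probabilities should sum to one.
K, M = 3, 4
alpha = np.array([1.5, 2.0, 0.7])
total = 0.0
for x in itertools.product(range(M + 1), repeat=K):
    if sum(x) != M:
        continue
    total += np.exp(log_marginal_likelihood(np.array([x], dtype=float), alpha, M))
print(total)  # ~1.0
```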
Posterior predictive. This derivation is quite similar to the derivation for the marginal likelihood. To compute the posterior predictive over an unseen observation x^, we need to integrate out our uncertainty about the parameters w.r.t. the inferred posterior:
$$
\begin{aligned}
p(\hat{\mathbf{x}} \mid \mathbf{X}) &= \int p(\hat{\mathbf{x}} \mid \boldsymbol{\theta}) \, p(\boldsymbol{\theta} \mid \mathbf{X}) \, d\boldsymbol{\theta}
\\
&= \int \text{multi}(\hat{\mathbf{x}} \mid \boldsymbol{\theta}) \, \text{Dir}(\boldsymbol{\theta} \mid \boldsymbol{\alpha}_N) \, d\boldsymbol{\theta}
\\
&= \int \left( \frac{M!}{\hat{x}_1! \dots \hat{x}_K!} \, \theta_1^{\hat{x}_1} \dots \theta_K^{\hat{x}_K} \right) \left( \frac{1}{B(\boldsymbol{\alpha}_N)} \, \theta_1^{\alpha_{N,1} - 1} \dots \theta_K^{\alpha_{N,K} - 1} \right) d\boldsymbol{\theta}
\\
&= \frac{M!}{\hat{x}_1! \dots \hat{x}_K!} \, \frac{1}{B(\boldsymbol{\alpha}_N)} \int \theta_1^{\hat{x}_1 + \alpha_{N,1} - 1} \dots \theta_K^{\hat{x}_K + \alpha_{N,K} - 1} \, d\boldsymbol{\theta}.
\end{aligned}
\tag{15}
$$
Again, we use the integral definition of the multivariate beta function. The integral in the last line of Eq. 15 is equal to
$$
B\left( \hat{x}_1 + \sum_{n=1}^{N} x_{n,1} + \alpha_1, \; \dots, \; \hat{x}_K + \sum_{n=1}^{N} x_{n,K} + \alpha_K \right) = B(\hat{\mathbf{x}} + \boldsymbol{\alpha}_N).
\tag{16}
$$
Putting everything together, we see that the posterior predictive is
$$
p(\hat{\mathbf{x}} \mid \mathbf{X}) = \frac{M!}{\hat{x}_1! \dots \hat{x}_K!} \, \frac{B(\hat{\mathbf{x}} + \boldsymbol{\alpha}_N)}{B(\boldsymbol{\alpha}_N)}.
\tag{17}
$$
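The same kind of check works for the posterior predictive. The sketch below, assuming SciPy is available, computes Eq. 17 in log space and verifies that it sums to one over all count vectors x̂ with ∑k x̂k = M. The function name and the posterior parameters αN are made up for illustration.

```python
import itertools
import numpy as np
from scipy.special import gammaln

def log_beta_fn(alpha):
    """Log of the multivariate beta function (Eq. 4)."""
    return gammaln(alpha).sum() - gammaln(np.sum(alpha))

def log_posterior_predictive(x_hat, alpha_N, M):
    """Log of Eq. 17: multinomial coefficient times B(x_hat + alpha_N) / B(alpha_N)."""
    log_coef = gammaln(M + 1) - gammaln(x_hat + 1).sum()
    return log_coef + log_beta_fn(x_hat + alpha_N) - log_beta_fn(alpha_N)

# Eq. 17 is a PMF over count vectors x_hat with sum_k x_hat_k = M, so the
# probabilities should sum to one.
K, M = 3, 5
alpha_N = np.array([4.0, 2.5, 7.0])  # arbitrary "posterior" parameters
total = sum(
    np.exp(log_posterior_predictive(np.array(x, dtype=float), alpha_N, M))
    for x in itertools.product(range(M + 1), repeat=K)
    if sum(x) == M
)
print(total)  # ~1.0
```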
Posterior predictive (single trial). Finally, there is another way to write this posterior predictive when M=1, i.e. when we roll a K-sided die exactly once rather than M times. First, let θ−j denote all components of the vector θ except the j-th. The marginal of θj is
$$
p(\theta_j) = \int_{\boldsymbol{\theta}_{-j}} p(\boldsymbol{\theta}_{-j}, \theta_j) \, d\boldsymbol{\theta}_{-j}.
\tag{18}
$$
So the probability that x^=j can be written as
$$
\begin{aligned}
p(\hat{x} = j \mid \mathbf{X}) &= \int p(\hat{x} = j \mid \boldsymbol{\theta}) \, p(\boldsymbol{\theta} \mid \mathbf{X}) \, d\boldsymbol{\theta}
\\
&= \int_{\theta_j} \int_{\boldsymbol{\theta}_{-j}} p(\hat{x} = j \mid \boldsymbol{\theta}) \, p(\boldsymbol{\theta} \mid \mathbf{X}) \, d\boldsymbol{\theta}_{-j} \, d\theta_j
\\
&= \int_{\theta_j} \theta_j \left( \int_{\boldsymbol{\theta}_{-j}} p(\boldsymbol{\theta} \mid \mathbf{X}) \, d\boldsymbol{\theta}_{-j} \right) d\theta_j
\\
&= \int_{\theta_j} \theta_j \, p(\theta_j \mid \mathbf{X}) \, d\theta_j
\\
&= \mathbb{E}[\theta_j \mid \mathbf{X}].
\end{aligned}
\tag{19}
$$
What is this expectation? Well, the posterior p(θ∣X) is a Dirichlet distribution, and the marginal of the Dirichlet is a beta distribution:
$$
\begin{aligned}
p(\boldsymbol{\theta} \mid \mathbf{X}) &= \text{Dir}(\boldsymbol{\theta} \mid \boldsymbol{\alpha}_N)
\\
&\Downarrow
\\
p(\theta_j \mid \mathbf{X}) &= \text{beta}(\alpha_{N,j}, A_N - \alpha_{N,j}).
\end{aligned}
\tag{20}
$$
Here, I use AN to denote the sum ∑kαN,k. So the expectation is:
$$
\mathbb{E}[\theta_j \mid \mathbf{X}] = \frac{\sum_n x_{n,j} + \alpha_j}{\sum_k \left( \sum_{n=1}^{N} x_{n,k} + \alpha_k \right)}.
\tag{21}
$$
It’s messy notation, but you can think of the numerator as counting “successes” for the j-th component or outcome, and the denominator as tracking the total count across all outcomes. A cleaner notation is as follows. Let Nj denote the number of observations for which xn,j=1, i.e. the number of times we observe outcome j; this is the left term in the numerator. The left term in the denominator can be written as N, since each of the N observations xn must sum to M=1. Finally, writing the sum across αk as A gives us
$$
\mathbb{E}[\theta_j \mid \mathbf{X}] = \frac{N_j + \alpha_j}{N + A}.
\tag{22}
$$
Again, we see that as more observations take the value j, the marginal beta distribution—this expected value—shifts in favor of that outcome.
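Here is a small sketch of this single-trial result, assuming NumPy and SciPy are available. We compare the closed form in Eq. 22 against a Monte Carlo average of θ drawn from the posterior Dirichlet; the data-generating θ and sample sizes are arbitrary.

```python
import numpy as np
from scipy.stats import dirichlet

rng = np.random.default_rng(1)

K, N = 4, 200
alpha = np.array([1.0, 2.0, 0.5, 1.5])
theta_true = np.array([0.4, 0.3, 0.2, 0.1])

# N single-trial (M = 1) multinomial observations, i.e. one-hot vectors.
X = rng.multinomial(1, theta_true, size=N)
N_j = X.sum(axis=0)   # number of times each outcome occurred
A = alpha.sum()

# Closed form from Eq. 22: E[theta_j | X] = (N_j + alpha_j) / (N + A).
predictive_closed_form = (N_j + alpha) / (N + A)

# Monte Carlo estimate: average theta over draws from the posterior Dirichlet.
posterior_draws = dirichlet.rvs(N_j + alpha, size=200_000, random_state=2)
predictive_mc = posterior_draws.mean(axis=0)

print(predictive_closed_form)
print(predictive_mc)  # should agree to a few decimal places
```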
Categorical likelihood
As a final note, observe that a categorical random variable is just a multinomial random variable for a single trial,
$$
\begin{aligned}
\mathbf{x} &\sim \text{multi}(M = 1, \boldsymbol{\theta})
\\
&\Downarrow
\\
\mathbf{x} &\sim \text{cat}(\boldsymbol{\theta}).
\end{aligned}
\tag{23}
$$
As such, all the derivations above are essentially the same for a categorical–Dirichlet model. The only things that change are the normalizing constants.
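For completeness, here is a tiny sketch of this equivalence, assuming NumPy: a categorical draw is just a single-trial multinomial draw, read off as a one-hot vector.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = np.array([0.5, 0.3, 0.2])

# A categorical draw is a single-trial multinomial draw: the one-hot vector
# indicates which of the K outcomes occurred.
one_hot = rng.multinomial(1, theta)
category = int(np.argmax(one_hot))  # the equivalent categorical value
print(one_hot, category)
```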