Understanding Dirichlet–Multinomial Models

The Dirichlet distribution is really a multivariate beta distribution. I discuss this connection and then derive the posterior, marginal likelihood, and posterior predictive distributions for Dirichlet–multinomial models.

Multivariate beta distribution

The beta distribution is a family of continuous probability distributions on the interval $[0, 1]$. Because of this, it is often used as a prior for probabilities. For example, a draw from a beta distribution can be viewed as a random probability for a Bernoulli random variable. The beta distribution has two parameters that control the shape of the distribution (Figure 1). These are typically denoted with $\alpha$ and $\beta$, but, for reasons we’ll see in a moment, I’ll use $\alpha_1$ and $\alpha_2$ instead.

Figure 1. Six examples of beta distributions with different shape parameters.

The probability density function (PDF) of the beta distribution is:

$$
\text{beta}(x \mid \alpha_1, \alpha_2) = \frac{1}{\text{B}(\alpha_1, \alpha_2)} x^{\alpha_1 - 1} (1 - x)^{\alpha_2 - 1}, \quad \alpha_1, \alpha_2 > 0. \tag{1}
$$

For now, let’s ignore the normalizing constant, a beta function. How can we interpret the expression $x^{\alpha_1 - 1} (1 - x)^{\alpha_2 - 1}$? Here’s how I think about it. First, fix $\alpha_1 = \alpha_2 = 1$. Anything raised to the zero power is $1$, so the PDF is $1$ everywhere. Now imagine we increase $\alpha_1$ while keeping $\alpha_2$ fixed at $1$. We then have a power function of $x$ with exponent $\alpha_1 - 1$, which increases on $[0, 1]$. If we instead increase $\alpha_2$ while keeping $\alpha_1$ fixed at $1$, we have a decaying curve $(1 - x)^{\alpha_2 - 1}$. What happens if both $\alpha_1$ and $\alpha_2$ are greater than $1$? The two curves, one increasing and the other decaying, mix. If $\alpha_1 = \alpha_2$, their peak is at $x = 0.5$; otherwise, there is an asymmetry. Finally, if $\alpha_1$ and $\alpha_2$ are both less than one, the density is U-shaped: one curve decays before giving rise to an increasing curve (Figure 1).
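To make these shapes concrete, here is a minimal sketch, assuming NumPy and SciPy (neither of which this post otherwise relies on), that evaluates Eq. 1 at a few points for several shape-parameter settings; the specific settings are arbitrary illustrations of the increasing, decaying, humped, and U-shaped cases described above.

```python
import numpy as np
from scipy.stats import beta

# Evaluate beta(x | alpha_1, alpha_2), Eq. 1, at a few points for several
# arbitrary shape-parameter settings to see the qualitative shapes.
xs = np.array([0.1, 0.5, 0.9])
for a1, a2 in [(1, 1), (3, 1), (1, 3), (3, 3), (0.5, 0.5)]:
    print(a1, a2, np.round(beta.pdf(xs, a1, a2), 3))
```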

The value $x$ actually lives on a simplex, a generalization of the notion of a triangle to arbitrary dimensions. Geometrically, all points on a simplex are a convex combination of its vertices. In this case, the pair $(x, 1 - x)$ lives on the $1$-simplex, the set of convex combinations of the points $(0, 1)$ and $(1, 0)$.

Now imagine we could take the curves in Figure 1 and extrude them out of the screen to form a multivariate density. The support of this density should now be a $2$-simplex, or any point in the convex hull defined by the points $(0, 0, 1)$, $(0, 1, 0)$, and $(1, 0, 0)$. We can visualize this by indicating the height with a color gradient (Figure 2).

Figure 2. Four examples of Dirichlet (multivariate beta) distributions with different shape parameters. Color indicates the PDF. The parameters $\alpha_1$, $\alpha_2$, and $\alpha_3$ are labeled in that order above each figure.

Whatever distribution this is should have three parameters; let’s call them $\alpha_1$, $\alpha_2$, and $\alpha_3$. Without knowing the normalizing constant, we can guess that the PDF should look like this:

$$
(x_1)^{\alpha_1 - 1} (x_2)^{\alpha_2 - 1} (1 - x_1 - x_2)^{\alpha_3 - 1}. \tag{2}
$$

These shape parameters must also be positive. What we’re constructing is just a multivariate beta distribution, which has its own name: the Dirichlet distribution. Values from a $K$-dimensional Dirichlet distribution live on a $(K-1)$-simplex. The general PDF of the Dirichlet distribution is

$$
\text{Dir}(\mathbf{x} \mid \boldsymbol{\alpha}) = \frac{1}{\text{B}(\boldsymbol{\alpha})} x_1^{\alpha_1 - 1} \dots x_K^{\alpha_K - 1}, \quad \alpha_1, \dots, \alpha_K > 0. \tag{3}
$$

We now define the normalizing constant,

$$
\text{B}(\boldsymbol{\alpha}) = \frac{\Gamma(\alpha_1) \dots \Gamma(\alpha_K)}{\Gamma(\alpha_1 + \dots + \alpha_K)}. \tag{4}
$$

While this may look abstract, there is some intuition. The gamma function is an extension of the factorial function to complex numbers. This is why, for any positive integer $z$,

$$
\Gamma(z) = (z - 1)! \tag{5}
$$

I think of the normalizing constant, the beta function, as playing a similar role to the binomial coefficient (“$n$ choose $k$”) in the binomial distribution. When computing the probability of $k$ successes in $n$ coin flips, we must also do some bookkeeping, since there are possibly many different ways to get $k$ successes.
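As a quick numerical sanity check, the following sketch (again assuming NumPy and SciPy, with arbitrary parameter values) evaluates Eqs. 3 and 4 by hand at one point on the simplex and compares the result against SciPy’s Dirichlet density.

```python
import numpy as np
from scipy.special import gamma
from scipy.stats import dirichlet

alpha = np.array([2.0, 3.0, 4.0])   # arbitrary shape parameters
x = np.array([0.2, 0.3, 0.5])       # a point on the 2-simplex (sums to 1)

B = np.prod(gamma(alpha)) / gamma(alpha.sum())   # Eq. 4
pdf_manual = np.prod(x ** (alpha - 1)) / B       # Eq. 3
pdf_scipy = dirichlet.pdf(x, alpha)              # reference implementation

print(pdf_manual, pdf_scipy)   # the two should agree up to floating-point error
```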

Finally, one interesting fact about the Dirichlet distribution is that its marginals are beta distributions. At least visually, this is intuitive: take the $2$-simplex in Figure 2 and collapse it onto any side, and we get a $1$-simplex controlled by the associated endpoints. While tedious, it is fairly straightforward to compute this. First, let $A = \sum_k \alpha_k$, and write the Dirichlet-distributed vector as $\boldsymbol{\theta}$ (anticipating its later role as the parameters of a multinomial). Then we can write the Dirichlet PDF as:

$$
\begin{aligned}
&\frac{\Gamma(A)}{\Gamma(\alpha_1) \textcolor{#59a1cf}{\Gamma(A - \alpha_1)}} \theta_1^{\alpha_1 - 1} \textcolor{#59a1cf}{(1 - \theta_1)^{A - \alpha_1 - 1}}
\\
&\times \frac{\textcolor{#59a1cf}{\Gamma(A - \alpha_1)}}{\Gamma(\alpha_2) \textcolor{#59a1cf}{\Gamma(A - \alpha_1 - \alpha_2)}} \frac{\theta_2^{\alpha_2 - 1} \textcolor{#59a1cf}{(1 - \theta_1 - \theta_2)^{A - \alpha_1 - \alpha_2 - 1}}}{\textcolor{#59a1cf}{(1 - \theta_1)^{A - \alpha_1 - 1}}}
\\
&\times \frac{\textcolor{#59a1cf}{\Gamma(A - \alpha_1 - \alpha_2)}}{\Gamma(\alpha_3) \textcolor{#59a1cf}{\Gamma(A - \alpha_1 - \alpha_2 - \alpha_3)}} \frac{\theta_3^{\alpha_3 - 1} \textcolor{#59a1cf}{(1 - \theta_1 - \theta_2 - \theta_3)^{A - \alpha_1 - \alpha_2 - \alpha_3 - 1}}}{\textcolor{#59a1cf}{(1 - \theta_1 - \theta_2)^{A - \alpha_1 - \alpha_2 - 1}}}
\\
&\vdots
\\
&\times \frac{\textcolor{#59a1cf}{\Gamma(A - \alpha_1 - \dots - \alpha_{K-2})}}{\Gamma(\alpha_{K-1}) \Gamma(A - \alpha_1 - \dots - \alpha_{K-1})} \frac{\theta_{K-1}^{\alpha_{K-1} - 1} \theta_K^{\alpha_K - 1}}{\textcolor{#59a1cf}{(1 - \theta_1 - \dots - \theta_{K-2})^{\alpha_{K-1} + \alpha_K - 1}}}
\end{aligned} \tag{6}
$$

The terms colored in blue cancel, and since $A - \alpha_1 - \dots - \alpha_{K-1} = \alpha_K$, we are left with Eq. 3. However, notice that we can also write the Dirichlet joint distribution as

$$
p(\boldsymbol{\theta}) = p(\theta_1) p(\theta_2 \mid \theta_1) p(\theta_3 \mid \theta_1, \theta_2) \dots p(\theta_{K-1} \mid \theta_1, \dots, \theta_{K-2}). \tag{7}
$$

We don’t have to specify the final probability because it is fixed given the other values of $\boldsymbol{\theta}$. Matching terms in Eq. 6 with Eq. 7 (that is, looking at the first line of Eq. 6), we see that the marginal distribution $p(\theta_1)$ is a beta distribution:

$$
p(\theta_1) = \text{beta}(\alpha_1, A - \alpha_1). \tag{8}
$$

Clearly, we could rewrite Eq. 6 and factorize Eq. 7 to put the $j$-th term first. So in fact, the result in Eq. 8 holds for any parameter $\theta_j$: $p(\theta_j) = \text{beta}(\alpha_j, A - \alpha_j)$.
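Here is a minimal Monte Carlo check of this marginalization property, assuming NumPy and SciPy with arbitrary parameter values: draw Dirichlet samples, keep only the first coordinate, and compare it against $\text{beta}(\alpha_1, A - \alpha_1)$.

```python
import numpy as np
from scipy.stats import beta, kstest

rng = np.random.default_rng(0)
alpha = np.array([2.0, 3.0, 4.0])   # arbitrary shape parameters
A = alpha.sum()

theta = rng.dirichlet(alpha, size=100_000)   # samples on the 2-simplex
theta_1 = theta[:, 0]                        # first coordinate only

marginal = beta(alpha[0], A - alpha[0])      # claimed marginal, Eq. 8
print(theta_1.mean(), marginal.mean())            # both should be ~ alpha_1 / A
print(kstest(theta_1, marginal.cdf).statistic)    # small if the claim holds
```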

Without computation, I think we can intuit what the remainder (the rest of the joint distribution in Eq. 7) should be. Imagine we have a $K$-simplex and we could “remove” an edge. That edge, or marginal, would be beta distributed, but the remainder would be a $(K-1)$-simplex, that is, another Dirichlet distribution.

Multinomial–Dirichlet distribution

Now that we better understand the Dirichlet distribution, let’s derive the posterior, marginal likelihood, and posterior predictive distributions for a very popular model: a multinomial model with a Dirichlet prior. These derivations will be very similar to my post on Bayesian inference for beta–Bernoulli models. Why? Well, a Bernoulli random variable can be thought of as modeling one coin flip with bias $\theta$. A binomial random variable is then flipping the same coin $M$ times. And a multinomial random variable is rolling a $K$-sided die $M$ times. My point is that these models have a common structure. Furthermore, as we just saw, the Dirichlet distribution is just the multivariate beta distribution.

We assume our data $\mathbf{X} = \{\mathbf{x}_1, \dots, \mathbf{x}_N\}$ are multinomially distributed with a Dirichlet prior:

$$
\begin{aligned}
\mathbf{x}_n &\sim \text{multi}(M, \boldsymbol{\theta}), \quad \sum_{k=1}^K \theta_k = 1,
\\
\boldsymbol{\theta} &\sim \text{Dir}(\boldsymbol{\alpha}), \quad \alpha_1, \dots, \alpha_K > 0.
\end{aligned} \tag{9}
$$

Recall that, by definition of multinomial random variables, $\sum_{k} x_{n,k} = M$.
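To keep the pieces of Eq. 9 straight, here is a minimal sketch of the generative process, assuming NumPy; the parameter values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
K, M, N = 3, 10, 500
alpha = np.array([1.0, 2.0, 3.0])        # Dirichlet prior parameters

theta = rng.dirichlet(alpha)             # theta ~ Dir(alpha)
X = rng.multinomial(M, theta, size=N)    # rows x_n ~ multi(M, theta), shape (N, K)

assert (X.sum(axis=1) == M).all()        # each x_n sums to M, by definition
```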

Posterior. Showing conjugacy by deriving the posterior is relatively easy. The likelihood times prior is

$$
\begin{aligned}
p(\boldsymbol{\theta} \mid \mathbf{X}) &\propto \left[ \prod_{n=1}^{N} p(\mathbf{x}_n \mid \boldsymbol{\theta}) \right] p(\boldsymbol{\theta})
\\
&= \left[ \prod_{n=1}^{N} \frac{M!}{x_{n,1}! \dots x_{n,K}!} \theta_1^{x_{n,1}} \dots \theta_K^{x_{n,K}} \right] \left( \frac{1}{\text{B}(\boldsymbol{\alpha})} \theta_1^{\alpha_1 - 1} \dots \theta_K^{\alpha_K - 1} \right)
\\
&\propto \theta_1^{\left(\sum_n x_{n,1} + \alpha_1 - 1\right)} \dots \theta_K^{\left(\sum_n x_{n,K} + \alpha_K - 1 \right)}.
\end{aligned} \tag{10}
$$

So we see the posterior is proportional to a Dirichlet distribution:

$$
\begin{aligned}
p(\boldsymbol{\theta} \mid \mathbf{X}) &= \text{Dir}(\boldsymbol{\theta} \mid \boldsymbol{\alpha}_N),
\\
\alpha_{N,k} &= \sum_{n=1}^N x_{n,k} + \alpha_k.
\end{aligned} \tag{11}
$$

The other terms, namely the normalizers, do not matter because they do not depend on $\boldsymbol{\theta}$. This is just a multivariate generalization of our derivation for the beta–Bernoulli model. Furthermore, it has the same intuition: as we observe more counts in category $k$, the $k$-th component of $\boldsymbol{\alpha}_N$ increases, placing more posterior mass at that region of the simplex.
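In code, the conjugate update in Eq. 11 is one line: add the per-category counts to the prior parameters. A minimal sketch, assuming NumPy and arbitrary simulated data:

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([1.0, 2.0, 3.0])                 # prior parameters
theta_true = rng.dirichlet(alpha)                 # "true" theta for the simulation
X = rng.multinomial(10, theta_true, size=500)     # N = 500 observations with M = 10

alpha_N = alpha + X.sum(axis=0)                   # posterior parameters, Eq. 11
posterior_mean = alpha_N / alpha_N.sum()          # posterior mean of theta

print(np.round(theta_true, 3), np.round(posterior_mean, 3))  # should be close
```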

Marginal likelihood. To compute the marginal likelihood, we need the integral definition of the multivariate beta function:

$$
\text{B}(\boldsymbol{\alpha}) = \int_0^1 \int_0^{1 - \theta_1} \dots \int_0^{1 - \theta_1 - \dots - \theta_{K-2}} \theta_1^{\alpha_1 - 1} \theta_2^{\alpha_2 - 1} \dots \theta_{K-1}^{\alpha_{K-1} - 1} \theta_K^{\alpha_K - 1} \,\text{d}\theta_{K-1} \dots \text{d}\theta_1. \tag{12}
$$

If this definition is a lot to take in, I suggest first writing out the beta integral for $\text{B}(\alpha_1, \alpha_2)$ and then extending it to three dimensions, keeping in mind that the $\theta_k$ terms sum to one (so $\theta_K = 1 - \theta_1 - \dots - \theta_{K-1}$). With this definition in mind, let’s integrate the likelihood times the prior (Eq. 10) without dropping the normalizing constants,

$$
\begin{aligned}
p(\mathbf{X}) &= \int p(\mathbf{X} \mid \boldsymbol{\theta}) p(\boldsymbol{\theta}) \,\text{d} \boldsymbol{\theta}
\\
&= \int \left[ \prod_{n=1}^{N} p(\mathbf{x}_n \mid \boldsymbol{\theta}) \right] p(\boldsymbol{\theta}) \,\text{d}\boldsymbol{\theta}
\\
&= \int \left[ \prod_{n=1}^{N} \frac{M!}{x_{n,1}! \dots x_{n,K}!} \theta_1^{x_{n,1}} \dots \theta_K^{x_{n,K}} \right] \left( \frac{1}{\text{B}(\boldsymbol{\alpha})} \theta_1^{\alpha_1 - 1} \dots \theta_K^{\alpha_K - 1} \right) \text{d}\boldsymbol{\theta}
\\
&= \left[ \prod_{n=1}^{N} \frac{M!}{x_{n,1}! \dots x_{n,K}!} \right] \frac{1}{\text{B}(\boldsymbol{\alpha})} \int \theta_1^{\left(\sum_n x_{n,1} + \alpha_1 - 1\right)} \dots \theta_K^{\left(\sum_n x_{n,K} + \alpha_K - 1 \right)} \text{d}\boldsymbol{\theta}
\\
&= \left[ \prod_{n=1}^{N} \frac{M!}{x_{n,1}! \dots x_{n,K}!} \right] \frac{\text{B}(\boldsymbol{\alpha}_N)}{\text{B}(\boldsymbol{\alpha})}.
\end{aligned} \tag{13}
$$

We can verify this by comparing it to the probability mass function of the Dirichlet–multinomial compound distribution. Let $N = 1$. Then Eq. 13 is

$$
\begin{aligned}
p(\mathbf{x}) &= \frac{M!}{x_{1}! \dots x_{K}!} \frac{\text{B}(\boldsymbol{\alpha}_N)}{\text{B}(\boldsymbol{\alpha})}
\\
&= \frac{M!}{x_{1}! \dots x_{K}!} \frac{\Gamma\left(\sum_k \alpha_{k} \right)}{\prod_k \Gamma(\alpha_{k})} \frac{\prod_{k=1}^{K} \Gamma(x_k + \alpha_k)}{\Gamma\left( \sum_k (x_k + \alpha_k) \right)}
\\
&= \frac{M! \, \Gamma\left(\sum_k \alpha_{k} \right)}{\Gamma\left( M + \sum_k \alpha_k \right)} \prod_{k=1}^{K} \frac{\Gamma(x_k + \alpha_k)}{x_k! \, \Gamma(\alpha_{k})}.
\end{aligned} \tag{14}
$$

This is messy in that it is not a well-known or easy-to-reason-about distribution, at least for me. But again, we can imagine that as more observations contain the $k$-th component, the marginal likelihood will shift mass appropriately.
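One way to convince ourselves that Eq. 14 is a proper distribution is to evaluate it with log-gamma functions and check that it sums to one over all count vectors with $\sum_k x_k = M$. A minimal sketch, assuming NumPy and SciPy with arbitrary parameters:

```python
import itertools
import numpy as np
from scipy.special import gammaln

def log_dirmult_pmf(x, alpha, M):
    """Log of Eq. 14 for a single count vector x with sum(x) = M."""
    return (gammaln(M + 1) + gammaln(alpha.sum()) - gammaln(M + alpha.sum())
            + np.sum(gammaln(x + alpha) - gammaln(x + 1) - gammaln(alpha)))

alpha = np.array([1.0, 2.0, 3.0])
M, K = 4, 3

total = 0.0
for x in itertools.product(range(M + 1), repeat=K):
    x = np.array(x)
    if x.sum() == M:                    # only valid multinomial outcomes
        total += np.exp(log_dirmult_pmf(x, alpha, M))
print(total)                            # should be ~ 1.0
```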

Posterior predictive. This derivation is quite similar to the derivation for the marginal likelihood. To compute the posterior predictive over an unseen observation $\hat{\mathbf{x}}$, we need to integrate out our uncertainty about the parameters w.r.t. the inferred posterior:

$$
\begin{aligned}
p(\hat{\mathbf{x}} \mid \mathbf{X}) &= \int p(\hat{\mathbf{x}} \mid \boldsymbol{\theta}) p(\boldsymbol{\theta} \mid \mathbf{X}) \,\text{d}\boldsymbol{\theta}
\\
&= \int \text{multi}(\hat{\mathbf{x}} \mid \boldsymbol{\theta}) \, \text{Dir}(\boldsymbol{\theta} \mid \boldsymbol{\alpha}_N) \,\text{d}\boldsymbol{\theta}
\\
&= \int \left( \frac{M!}{\hat{x}_1! \dots \hat{x}_K!} \theta_1^{\hat{x}_1} \dots \theta_K^{\hat{x}_K} \right) \left( \frac{1}{\text{B}(\boldsymbol{\alpha}_N)} \theta_1^{\alpha_{N,1} - 1} \dots \theta_K^{\alpha_{N,K} - 1} \right) \text{d}\boldsymbol{\theta}
\\
&= \frac{M!}{\hat{x}_1! \dots \hat{x}_K!} \frac{1}{\text{B}(\boldsymbol{\alpha}_N)} \int \theta_1^{\hat{x}_1 + \alpha_{N,1} - 1} \dots \theta_K^{\hat{x}_K + \alpha_{N,K} - 1} \,\text{d}\boldsymbol{\theta}.
\end{aligned} \tag{15}
$$

Again, we use the integral definition of the multivariate beta function. The integral in the last line of Eq. 15 is equal to

$$
\text{B}\left(\hat{x}_1 + \alpha_{N,1}, \dots, \hat{x}_K + \alpha_{N,K} \right) = \text{B}(\hat{\mathbf{x}} + \boldsymbol{\alpha}_N). \tag{16}
$$

Putting everything together, we see that the posterior predictive is

$$
p(\hat{\mathbf{x}} \mid \mathbf{X}) = \frac{M!}{\hat{x}_{1}! \dots \hat{x}_{K}!} \frac{\text{B}(\hat{\mathbf{x}} + \boldsymbol{\alpha}_N)}{\text{B}(\boldsymbol{\alpha}_N)}. \tag{17}
$$
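Since Eq. 17 has the same form as Eq. 14 with $\boldsymbol{\alpha}$ replaced by $\boldsymbol{\alpha}_N$, it can be computed the same way. A minimal sketch, assuming NumPy and SciPy; `alpha_N` and `x_hat` below are hypothetical values:

```python
import numpy as np
from scipy.special import gammaln

def log_multivariate_beta(a):
    """log B(a), i.e. the log of Eq. 4."""
    return np.sum(gammaln(a)) - gammaln(np.sum(a))

def log_posterior_predictive(x_hat, alpha_N, M):
    """Log of Eq. 17 for a single future count vector x_hat with sum M."""
    log_coef = gammaln(M + 1) - np.sum(gammaln(x_hat + 1))   # log multinomial coefficient
    return (log_coef
            + log_multivariate_beta(x_hat + alpha_N)
            - log_multivariate_beta(alpha_N))

alpha_N = np.array([51.0, 102.0, 153.0])   # hypothetical posterior parameters
x_hat = np.array([2.0, 3.0, 5.0])          # candidate future observation with M = 10
print(np.exp(log_posterior_predictive(x_hat, alpha_N, M=10)))
```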

Posterior predictive (single trial). Finally, there is another way to write this posterior predictive when $M = 1$, i.e. when we roll a $K$-sided die exactly once rather than $M$ times. First, let $\boldsymbol{\theta}_{-j}$ denote all components of the vector $\boldsymbol{\theta}$ except the $j$-th. The marginal of $\theta_j$ is

$$
p(\theta_j) = \int_{\boldsymbol{\theta}_{-j}} p(\boldsymbol{\theta}_{-j}, \theta_j) \,\text{d}\boldsymbol{\theta}_{-j}. \tag{18}
$$

So the probability that $\hat{x} = j$ can be written as

$$
\begin{aligned}
p(\hat{x} = j \mid \mathbf{X}) &= \int p(\hat{x} = j \mid \boldsymbol{\theta}) p(\boldsymbol{\theta} \mid \mathbf{X}) \,\text{d} \boldsymbol{\theta}
\\
&= \int_{\theta_j} \int_{\boldsymbol{\theta}_{-j}} p(\hat{x} = j \mid \boldsymbol{\theta}) p(\boldsymbol{\theta} \mid \mathbf{X}) \,\text{d} \boldsymbol{\theta}_{-j} \,\text{d}\theta_j
\\
&= \int_{\theta_j} \theta_j \left( \int_{\boldsymbol{\theta}_{-j}} p(\boldsymbol{\theta} \mid \mathbf{X}) \,\text{d} \boldsymbol{\theta}_{-j} \right) \text{d}\theta_j
\\
&= \int_{\theta_j} \theta_j \, p(\theta_j \mid \mathbf{X}) \,\text{d}\theta_j
\\
&= \mathbb{E}[\theta_j \mid \mathbf{X}].
\end{aligned} \tag{19}
$$

What is this expectation? Well, the posterior $p(\boldsymbol{\theta} \mid \mathbf{X})$ is a Dirichlet distribution, and the marginal of the Dirichlet is a beta distribution:

$$
\begin{aligned}
p(\boldsymbol{\theta} \mid \mathbf{X}) &= \text{Dir}(\boldsymbol{\theta} \mid \boldsymbol{\alpha}_N)
\\
&\Downarrow
\\
p(\theta_j \mid \mathbf{X}) &= \text{beta}(\alpha_{N,j}, A_{N} - \alpha_{N,j}).
\end{aligned} \tag{20}
$$

Here, I use $A_N$ to denote the sum $\sum_{k} \alpha_{N,k}$. So the expectation is:

$$
\mathbb{E}[\theta_j \mid \mathbf{X}] = \frac{\sum_n x_{n,j} + \alpha_{j}}{\sum_{k} \left( \sum_{n=1}^{N} x_{n,k} + \alpha_k \right)}. \tag{21}
$$

It’s messy notation, but you can think of the numerator as counting “successes” for the $j$-th component or outcome, and the denominator as tracking all outcomes. A cleaner notation is as follows. Let $N_j$ denote the number of observations whose outcome is $j$, i.e. for which $x_{n,j} = 1$; then the left term in the numerator, $\sum_n x_{n,j}$, is just $N_j$. And the left term in the denominator can be written as $N$, since each of the $N$ observations $\mathbf{x}_n$ must sum to $M = 1$. Finally, writing the sum over the $\alpha_k$ as $A$ gives us

$$
\mathbb{E}[\theta_j \mid \mathbf{X}] = \frac{N_j + \alpha_j}{N + A}. \tag{21}
$$

Again, we see that as more observations take the value $j$, the marginal beta distribution (this expected value) shifts in favor of that outcome.
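As a check, here is a minimal Monte Carlo sketch, assuming NumPy with simulated data and arbitrary parameters, comparing the closed-form single-trial predictive (the posterior mean in Eq. 21) against an average over posterior draws:

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([1.0, 2.0, 3.0])                      # prior parameters
X = rng.multinomial(1, [0.2, 0.3, 0.5], size=1000)     # N one-hot observations (M = 1)

alpha_N = alpha + X.sum(axis=0)                        # posterior parameters
closed_form = alpha_N / alpha_N.sum()                  # (N_j + alpha_j) / (N + A)

theta_draws = rng.dirichlet(alpha_N, size=100_000)     # draws from the posterior
print(np.round(closed_form, 3), np.round(theta_draws.mean(axis=0), 3))
```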

Categorical likelihood

As a final note, observe that a categorical random variable is just a multinomial random variable with a single trial,

$$
\begin{aligned}
\mathbf{x} &\sim \text{multi}(M=1, \boldsymbol{\theta})
\\
&\Downarrow
\\
\mathbf{x} &\sim \text{cat}(\boldsymbol{\theta}).
\end{aligned} \tag{22}
$$

As such, all the derivations above carry over nearly unchanged to a categorical–Dirichlet model. The only things that change are the normalizing constants.
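For completeness, here is a tiny sketch of the equivalence in Eq. 22, assuming NumPy: a categorical draw is just a one-hot multinomial draw with a single trial.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = np.array([0.2, 0.3, 0.5])

one_hot = rng.multinomial(1, theta)   # e.g. array([0, 1, 0]); multi(M=1, theta)
category = int(np.argmax(one_hot))    # the same event written as a category index
print(one_hot, category)
```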