A Fast and Numerically Stable Implementation of the Multivariate Normal PDF
Naively computing the probability density function for the multivariate normal can be slow and numerically unstable. I work through SciPy's implementation.
Consider the multivariate normal probability density function (PDF) for $\mathbf{x} \in \mathbb{R}^{D}$ with parameters $\boldsymbol{\mu}$ (mean) and $\boldsymbol{\Sigma}$ (covariance):

$$
f(\mathbf{x}; \boldsymbol{\mu}, \boldsymbol{\Sigma}) = \frac{1}{\sqrt{(2\pi)^{D} \det(\boldsymbol{\Sigma})}} \exp\!\left( -\frac{1}{2} (\mathbf{x} - \boldsymbol{\mu})^{\top} \boldsymbol{\Sigma}^{-1} (\mathbf{x} - \boldsymbol{\mu}) \right).
$$
Computing this is tricky because we need to compute both the inverse and determinant of the covariance matrix $\boldsymbol{\Sigma}$. Each of these operations has runtime complexity $\mathcal{O}(D^3)$. Furthermore, computing matrix inverses and determinants can be numerically unstable if done naively. SciPy has a fast and numerically stable implementation that is worth understanding. The big idea is to do one intensive operation, eigenvalue decomposition, and then use that decomposition to compute the matrix inverse and determinant cheaply.
Matrix inverse
Since $\boldsymbol{\Sigma}$ is Hermitian, it has an eigendecomposition

$$
\boldsymbol{\Sigma} = \mathbf{Q} \boldsymbol{\Lambda} \mathbf{Q}^{\top},
$$

where $\mathbf{Q}$ is an orthogonal matrix whose columns are the eigenvectors of $\boldsymbol{\Sigma}$ and where $\boldsymbol{\Lambda}$ is a diagonal matrix of the associated eigenvalues. Since $(\mathbf{A}\mathbf{B})^{-1} = \mathbf{B}^{-1}\mathbf{A}^{-1}$ in general and since $\mathbf{Q}^{-1} = \mathbf{Q}^{\top}$ due to orthogonality, we can easily compute the inverse of $\boldsymbol{\Sigma}$,

$$
\boldsymbol{\Sigma}^{-1} = (\mathbf{Q} \boldsymbol{\Lambda} \mathbf{Q}^{\top})^{-1} = \mathbf{Q} \boldsymbol{\Lambda}^{-1} \mathbf{Q}^{\top}.
$$

Since $\boldsymbol{\Lambda}$ is diagonal, $\boldsymbol{\Lambda}^{-1}$ is just one divided by each value along the diagonal.
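This is not SciPy's code, just a small NumPy sketch to sanity-check the identity: one eigendecomposition, and the inverse reassembled from it matches a direct matrix inverse on a small, well-conditioned covariance matrix.

import numpy as np

# A small symmetric positive definite covariance matrix.
cov = np.array([[2.0, 0.3],
                [0.3, 1.0]])

# The one expensive operation: eigendecomposition (`eigh` assumes symmetry).
vals, vecs = np.linalg.eigh(cov)

# Sigma^{-1} = Q Lambda^{-1} Q^T, where inverting Lambda is just inverting
# each value on the diagonal.
inv_from_eig = vecs @ np.diag(1.0 / vals) @ vecs.T

assert np.allclose(inv_from_eig, np.linalg.inv(cov))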
Determinant
To compute the determinant of $\boldsymbol{\Sigma}$, SciPy uses the following mathematical fact: for any square matrix $\mathbf{A}$ with eigenvalues $\lambda_1, \dots, \lambda_D$, the determinant is the product of the eigenvalues, or

$$
\det(\mathbf{A}) = \prod_{d=1}^{D} \lambda_d.
$$

However, computing the determinant this way is numerically unstable, as I’ve written about before. The upshot is that if each $\lambda_d$ is small, the computed determinant might be zero due to machine precision. So SciPy computes the log of the PDF, so that computing the determinant amounts to computing the log-determinant,

$$
\log \det(\boldsymbol{\Sigma}) = \sum_{d=1}^{D} \log \lambda_d.
$$
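As a toy illustration of the failure mode (my example, not SciPy's): take 100 eigenvalues of $10^{-4}$ each. The true determinant is $10^{-400}$, which is far below the smallest positive double (roughly $10^{-308}$), so the naive product underflows to zero, while the sum of logs is perfectly well behaved.

import numpy as np

# 100 eigenvalues of 1e-4 each: the true determinant is 1e-400.
vals = np.full(100, 1e-4)

naive_det = np.prod(vals)          # 0.0 due to underflow
logdet = np.sum(np.log(vals))      # approximately -921.03

print(naive_det, logdet)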
To compute the PDF, SciPy first computes the log PDF and then exponentiates that quantity. For completeness, the log PDF for the multivariate normal is

$$
\log f(\mathbf{x}; \boldsymbol{\mu}, \boldsymbol{\Sigma}) = -\frac{1}{2} \left[ D \log(2\pi) + \log \det(\boldsymbol{\Sigma}) + (\mathbf{x} - \boldsymbol{\mu})^{\top} \boldsymbol{\Sigma}^{-1} (\mathbf{x} - \boldsymbol{\mu}) \right].
$$
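Working in log space matters for the same underflow reason. A rough illustration (using scipy.stats directly, not the code below): at a point far from the mean, the density underflows to zero while the log density stays finite.

import numpy as np
from scipy.stats import multivariate_normal

mean = np.zeros(2)
cov = np.eye(2)
x = np.full(2, 30.0)   # a point roughly 30 standard deviations from the mean

print(multivariate_normal.pdf(x, mean, cov))     # 0.0 due to underflow
print(multivariate_normal.logpdf(x, mean, cov))  # roughly -901.8, still finite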
Implementation
Below is an abbreviated version of SciPy’s implementation of multivariate_normal.pdf. For clarity, I’ve removed any code for modularity or standardization. SciPy’s implementation makes one additional optimization worth mentioning. Rather than computing the quadratic form $(\mathbf{x} - \boldsymbol{\mu})^{\top} \boldsymbol{\Sigma}^{-1} (\mathbf{x} - \boldsymbol{\mu})$ naively, it computes the vector $(\mathbf{x} - \boldsymbol{\mu})^{\top} \mathbf{Q} \boldsymbol{\Lambda}^{-1/2}$ and then squares and sums its components, since

$$
(\mathbf{x} - \boldsymbol{\mu})^{\top} \boldsymbol{\Sigma}^{-1} (\mathbf{x} - \boldsymbol{\mu}) = \left\| \boldsymbol{\Lambda}^{-1/2} \mathbf{Q}^{\top} (\mathbf{x} - \boldsymbol{\mu}) \right\|_2^2.
$$
import numpy as np

def pdf(x, mean, cov):
    return np.exp(logpdf(x, mean, cov))

def logpdf(x, mean, cov):
    # `eigh` assumes the matrix is Hermitian.
    vals, vecs = np.linalg.eigh(cov)
    logdet = np.sum(np.log(vals))
    valsinv = np.array([1./v for v in vals])
    # `vecs` is D times R while `vals` is an R-vector, where R is the matrix
    # rank. The asterisk performs element-wise multiplication.
    U = vecs * np.sqrt(valsinv)
    rank = len(vals)
    dev = x - mean
    # "maha" for "Mahalanobis distance".
    maha = np.square(np.dot(dev, U)).sum()
    log2pi = np.log(2 * np.pi)
    return -0.5 * (rank * log2pi + maha + logdet)
SciPy performs a bunch of other checks such as thresholding based on eigenvalues, raising an exception for singular matrices if specified, and so forth.
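As a quick check, here is a usage sketch of my own (not part of SciPy) that compares the abbreviated implementation above against scipy.stats on a random, well-conditioned problem; it also confirms that the squared-and-summed Mahalanobis trick gives the same result as the standard log PDF.

import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
D = 5
mean = rng.normal(size=D)
A = rng.normal(size=(D, D))
cov = A @ A.T + D * np.eye(D)   # symmetric positive definite
x = rng.normal(size=D)

assert np.isclose(logpdf(x, mean, cov),
                  multivariate_normal.logpdf(x, mean, cov))
assert np.isclose(pdf(x, mean, cov),
                  multivariate_normal.pdf(x, mean, cov))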