Evidence lower-bound: two derivations

This post is about a fundamental inequality in variational inference.

Intuition

In order to fit a model \(p_\theta(\cdot)\) to datapoints \(\mathrm{x}_1, \dots, \mathrm{x}_n\), one has to compute the marginal likelihood \(p_\theta(\mathrm{x})\) of the data (also called the evidence) under the parameters \(\theta\). In variational methods, one usually models a generative process \(p_\theta(\mathrm{x}\mid \mathrm{z})\) which maps a latent \(\mathrm{z}\) to an estimate of the data \(\mathrm{x}\). If one has a simple distribution \(p(\mathrm{z})\) over the latents, then one has the equality:

$$ p_\theta(\mathrm{x}) = \int_\mathcal{Z}p_\theta(\mathrm{x}\mid \mathrm{z})p(\mathrm{z})\,\mathrm{dz} $$

Therefore, one can approximate the marginal likelihood by sampling points \(\mathrm{z}_1, \dots, \mathrm{z}_m\) from \(p(\mathrm{z})\) and using Monte Carlo integration:

$$p_\theta(\mathrm{x}) \approx \dfrac1m \sum_{i=1}^mp_\theta(\mathrm{x}\mid \mathrm{z}_i)$$

However, picking the \(\mathrm{z}_i\) at random from the prior results in poor estimates: a random \(\mathrm{z}_i\) is unlikely to have generated \(\mathrm{x}\), so the value \(p_\theta(\mathrm{x}\mid \mathrm{z}_i)\) will typically be very small. A good approximation of the integral therefore requires latents for which \(p_\theta(\mathrm{x}\mid \mathrm{z})\) is large.
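
To make this concrete, here is a minimal numerical sketch of the naive Monte Carlo estimator. The one-dimensional Gaussian model, the value of \(\sigma^2\), and the observation \(x\) are assumptions chosen purely for illustration (they are not from the post); they make the exact evidence available so the estimator can be judged.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy model: prior z ~ N(0, 1), decoder p(x | z) = N(x; z, sigma2).
# Then p(x) = N(x; 0, 1 + sigma2) is available in closed form.
sigma2 = 0.1
x = 3.0  # an observation in the tail of the prior predictive

def log_normal_pdf(x, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

exact = np.exp(log_normal_pdf(x, 0.0, 1.0 + sigma2))

# Naive Monte Carlo: sample z_i from the prior and average p(x | z_i).
m = 1000
z = rng.standard_normal(m)
estimate = np.mean(np.exp(log_normal_pdf(x, z, sigma2)))

print(f"exact p(x)     = {exact:.6f}")
print(f"naive estimate = {estimate:.6f}")
# Most z_i land far from x, so p(x | z_i) is tiny; the average is dominated
# by the rare z_i that happen to fall near x, hence the high variance.
```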

Thus, instead of sampling from \(p(\mathrm{z})\) which is a fixed, simple and suboptimal distribution, we wish to sample from a parametric distribution \(\mathrm{z} \sim q_\phi(\mathrm{z}\mid \mathrm{x})\) called an encoder. This name comes from the fact that the datapoints \(\mathrm{x}\) are often high-dimensional, and the latent variables (called codes) are much simpler: \(\dim(\mathrm{x}) \gg \dim(\mathrm{z})\). The reason we wish to sample from \(q_\phi\) is that by optimizing over the parameters \(\phi\), we can put more mass on codes \(\mathrm{z}^\star\) that result in high likelihoods \(p_\theta(\mathrm{x}\mid \mathrm{z}^\star)\).
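
To see which quantity replaces the naive average once we sample from \(q_\phi\) instead of \(p(\mathrm{z})\), note that the integral of the previous section can be rewritten as an expectation under \(q_\phi\) (wherever \(q_\phi(\mathrm{z}\mid \mathrm{x}) > 0\)); the same ratio reappears in the second derivation below:

$$ p_\theta(\mathrm{x}) = \int_\mathcal{Z} q_\phi(\mathrm{z}\mid \mathrm{x})\,\dfrac{p_\theta(\mathrm{x}\mid \mathrm{z})\,p(\mathrm{z})}{q_\phi(\mathrm{z}\mid \mathrm{x})}\,\mathrm{dz} = \mathbb{E}_{\mathrm{z}\sim q_\phi}\!\left[\dfrac{p_\theta(\mathrm{x}\mid \mathrm{z})\,p(\mathrm{z})}{q_\phi(\mathrm{z}\mid \mathrm{x})}\right] $$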

The variational equality

Recall the Kullback-Leibler divergence between two distributions \(q\) and \(p\): \(\mathrm{KL}(q \parallel p) \triangleq \mathbb{E}_{\mathrm{z}\sim q}[\log q(\mathrm{z}) - \log p(\mathrm{z})]\), which is always nonnegative and vanishes exactly when \(q = p\).

Since we sample from \(q_\phi(\mathrm{z}\mid \mathrm{x})\) rather than from the true posterior \(p_\theta(\mathrm{z}\mid \mathrm{x})\), it is natural to measure how far apart the two distributions are:

\begin{align} \mathrm{KL}(q_\phi(\mathrm{z}\mid \mathrm{x})\parallel p_\theta(\mathrm{z}\mid \mathrm{x})) &= \mathbb{E}_{\mathrm{z}\sim q_\phi}[\log q_\phi(\mathrm{z}\mid \mathrm{x}) - \log p_\theta(\mathrm{z}\mid \mathrm{x})] \\ & = \log p_\theta(\mathrm{x}) + \mathbb{E}_{\mathrm{z}\sim q_\phi}[\log q_\phi(\mathrm{z}\mid \mathrm{x}) - \log p_\theta(\mathrm{x}\mid \mathrm{z}) - \log p(\mathrm{z})] \\ & = \log p_\theta(\mathrm{x}) + \mathrm{KL}(q_\phi(\mathrm{z}\mid \mathrm{x})\parallel p(\mathrm{z})) - \mathbb{E}_{\mathrm{z}\sim q_\phi}[\log p_\theta(\mathrm{x}\mid \mathrm{z})] \end{align}
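
The second equality uses Bayes' rule to express the (intractable) posterior in terms of the decoder and the prior; since \(\log p_\theta(\mathrm{x})\) does not depend on \(\mathrm{z}\), it comes out of the expectation:

\begin{align} \log p_\theta(\mathrm{z}\mid \mathrm{x}) = \log p_\theta(\mathrm{x}\mid \mathrm{z}) + \log p(\mathrm{z}) - \log p_\theta(\mathrm{x}) \end{align}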

Therefore, denoting

\begin{align} E(\theta, \phi, \mathrm{x}) &\triangleq \mathrm{KL}(q_\phi(\mathrm{z}\mid \mathrm{x})\parallel p_\theta(\mathrm{z}\mid \mathrm{x})) \\ \Omega(\phi) &\triangleq \mathrm{KL}(q_\phi(\mathrm{z}\mid \mathrm{x})\parallel p(\mathrm{z})) \end{align}

we have the equality \begin{align} \log p_\theta(\mathrm{x}) - E(\theta, \phi, \mathrm{x}) = \mathbb{E}_{\mathrm{z}\sim q_\phi}[\log p_\theta(\mathrm{x}\mid \mathrm{z})] - \Omega(\phi) \end{align}

which reads: “Evidence \(-\) error \(=\) approximation \(-\) regularizer”. The regularizer term \(\Omega(\phi)\) keeps the distribution \(q_\phi\) from straying too far from the simple distribution \(p(\mathrm{z})\), and the error term \(E(\cdot)\) quantifies how far \(q_\phi\) is from the true posterior \(p_\theta(\mathrm{z}\mid \mathrm{x})\), i.e. the price paid for sampling from \(q_\phi\) rather than from the posterior when estimating the likelihood \(\log p_\theta(\mathrm{x}\mid \mathrm{z})\). The regularizer enters with a minus sign because we wish to minimize \(\Omega(\phi)\) while maximizing the approximate likelihood.

A first derivation of the inequality

Because the KL divergence is always nonnegative, \(E(\cdot) \geq 0\), and thus

$$\log p_\theta(\mathrm{x}) \geq \mathbb{E}_{\mathrm{z}\sim q_\phi}[\log p_\theta(\mathrm{x}\mid \mathrm{z})] - \Omega(\phi) \triangleq \mathcal{L}(\theta, \phi, \mathrm{x}) $$

Thus, to estimate the intractable evidence \(\log p_\theta(\mathrm{x})\), we can maximize the evidence lower bound \(\mathcal{L}(\theta, \phi, \mathrm{x})\) over \(\phi\), tightening the bound so that \(\mathcal{L}(\theta, \phi, \mathrm{x}) \approx \log p_\theta(\mathrm{x})\).

The prior \(p(\mathrm{z})\) and the encoder \(q_\phi(\mathrm{z}\mid\mathrm{x})\) are often chosen to be multivariate Gaussians so that \(\Omega(\phi)\) has a closed form. The expectation term \(\mathbb{E}_{\mathrm{z}\sim q_\phi}[\log p_\theta(\mathrm{x}\mid \mathrm{z})]\) is typically estimated with Monte Carlo samples from \(q_\phi\).
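
As a rough sketch of how this estimate can be assembled (the linear-Gaussian encoder and decoder below are placeholders assumed purely for illustration, standing in for neural networks), one can combine the closed-form Gaussian KL with a Monte Carlo estimate of the reconstruction term:

```python
import numpy as np

rng = np.random.default_rng(0)
dim_x, dim_z = 8, 2

# Placeholder "networks": linear maps, assumed only to make the sketch runnable.
W_enc = rng.normal(size=(dim_z, dim_x)) * 0.1
W_dec = rng.normal(size=(dim_x, dim_z)) * 0.1
decoder_var = 1.0  # p_theta(x | z) = N(x; W_dec @ z, decoder_var * I)

def encode(x):
    """q_phi(z | x) = N(mu, diag(exp(logvar)))."""
    mu = W_enc @ x
    logvar = np.zeros(dim_z)  # fixed unit variance, for simplicity
    return mu, logvar

def log_p_x_given_z(x, z):
    """Gaussian decoder log-likelihood log p_theta(x | z)."""
    mean = W_dec @ z
    return -0.5 * np.sum(np.log(2 * np.pi * decoder_var)
                         + (x - mean) ** 2 / decoder_var)

def elbo(x, n_samples=16):
    mu, logvar = encode(x)
    std = np.exp(0.5 * logvar)
    # Monte Carlo estimate of E_q[log p_theta(x | z)] with z ~ q_phi(z | x).
    eps = rng.standard_normal((n_samples, dim_z))
    z_samples = mu + std * eps
    recon = np.mean([log_p_x_given_z(x, z) for z in z_samples])
    # Closed-form KL(q_phi(z | x) || p(z)) between diagonal Gaussians, p(z) = N(0, I).
    kl = 0.5 * np.sum(np.exp(logvar) + mu ** 2 - 1.0 - logvar)
    return recon - kl

x = rng.normal(size=dim_x)
print(f"ELBO estimate: {elbo(x):.3f}")
```

Drawing \(\mathrm{z}\) as \(\mu + \sigma\,\varepsilon\) with \(\varepsilon \sim \mathcal{N}(0, I)\) is a common choice that also keeps the estimate differentiable in \(\phi\).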

A second derivation

A second derivation, found in this famous paper on normalizing flows, uses Jensen's inequality. By the concavity of the logarithm, for any distribution \(p\) and positive random variable \(Y\) we have \(\log(\mathbb{E}_p[Y]) \geq \mathbb{E}_p[\log(Y)]\). Applying this inequality to the evidence gives

\begin{align} \log p_\theta(\mathrm{x}) &= \log\int_\mathcal{Z} p_\theta(\mathrm{x}\mid \mathrm{z})p(\mathrm{z})\,\mathrm{dz} \\ &= \log\int_\mathcal{Z} \dfrac{q_\phi(\mathrm{z}\mid \mathrm{x})}{q_\phi(\mathrm{z}\mid \mathrm{x})}\, p_\theta(\mathrm{x}\mid \mathrm{z})p(\mathrm{z})\,\mathrm{dz} \\ &\geq \int_\mathcal{Z} q_\phi(\mathrm{z}\mid \mathrm{x})\log\dfrac{p_\theta(\mathrm{x}\mid \mathrm{z})\, p(\mathrm{z})}{q_\phi(\mathrm{z}\mid \mathrm{x})}\,\mathrm{dz} \\ & = \mathbb{E}_{\mathrm{z}\sim q_\phi}[\log p_\theta(\mathrm{x}\mid \mathrm{z})] - \mathrm{KL}(q_\phi(\mathrm{z}\mid \mathrm{x})\parallel p(\mathrm{z})) \end{align}

This proof is elegant but does not directly give an analytic expression for the Jensen gap \(\log p_\theta(\mathrm{x}) - \mathbb{E}_{\mathrm{z}\sim q_\phi}[\log p_\theta(\mathrm{x}\mid \mathrm{z})] + \Omega(\phi)\). From the first derivation, this gap is exactly the error term \(E(\theta, \phi, \mathrm{x}) = \mathrm{KL}(q_\phi(\mathrm{z}\mid \mathrm{x})\parallel p_\theta(\mathrm{z}\mid \mathrm{x}))\).
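
As a closing sanity check, here is a small numerical experiment (the one-dimensional conjugate Gaussian model and the particular choice of \(q\) are assumptions for illustration) in which the evidence, the ELBO, and the KL to the true posterior are all available in closed form, so one can verify both the inequality and that the gap equals \(E(\cdot)\):

```python
import numpy as np

# Hypothetical conjugate model, chosen so everything is tractable:
# prior z ~ N(0, 1), decoder x | z ~ N(z, sigma2).
sigma2 = 0.5
x = 2.0

# Exact evidence: p(x) = N(x; 0, 1 + sigma2).
log_evidence = -0.5 * (np.log(2 * np.pi * (1 + sigma2)) + x**2 / (1 + sigma2))

# Exact posterior: p(z | x) = N(post_mu, post_var).
post_var = sigma2 / (1 + sigma2)
post_mu = x / (1 + sigma2)

# A (deliberately imperfect) Gaussian approximation q(z | x) = N(m, s2).
m, s2 = 1.0, 0.3

# ELBO = E_q[log p(x | z)] - KL(q || p(z)), both in closed form here.
expected_recon = -0.5 * (np.log(2 * np.pi * sigma2) + ((x - m)**2 + s2) / sigma2)
kl_to_prior = 0.5 * (s2 + m**2 - 1.0 - np.log(s2))
elbo = expected_recon - kl_to_prior

# Error term E(.) = KL(q || p(z | x)) between two univariate Gaussians.
kl_to_posterior = 0.5 * (s2 / post_var + (m - post_mu)**2 / post_var
                         - 1.0 - np.log(s2 / post_var))

print(f"log p(x)           = {log_evidence:.6f}")
print(f"ELBO               = {elbo:.6f}")        # always <= log p(x)
print(f"gap                = {log_evidence - elbo:.6f}")
print(f"KL(q || posterior) = {kl_to_posterior:.6f}")  # equals the gap
```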