1. Learning Energy-Based Models by Diffusion Recovery Likelihood
Ruiqi Gao, Yang Song, Ben Poole, Ying Nian Wu, Diederik P. Kingma
While energy-based models (EBMs) exhibit a number of desirable properties, training and sampling on high-dimensional datasets remain challenging. Inspired by recent progress on diffusion probabilistic models, we present a diffusion recovery likelihood method to tractably learn and sample from a sequence of EBMs trained on increasingly noisy versions of a dataset. Each EBM is trained by maximizing the recovery likelihood: the conditional probability of the data at a certain noise level given their noisy versions at a higher noise level. The recovery likelihood objective is more tractable than the marginal likelihood objective, since it only requires MCMC sampling from a relatively concentrated conditional distribution. Moreover, we show that this estimation method is theoretically consistent: it learns the correct conditional and marginal distributions at each noise level, given sufficient data. After training, synthesized images can be generated efficiently by a sampling process that initializes from a spherical Gaussian distribution and progressively samples the conditional distributions at successively lower noise levels. Our method generates high-fidelity samples on various image datasets. On unconditional CIFAR-10, our method achieves FID 9.58 and inception score 8.30, superior to the majority of GANs. Finally, we demonstrate that, unlike previous work on EBMs, our long-run MCMC samples from the conditional distributions do not diverge and still represent realistic images, allowing us to accurately estimate the normalized density of data even for high-dimensional datasets.
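To make the recovery-likelihood objective concrete, here is a minimal PyTorch sketch of one training step, assuming the simple noising scheme x̃ = x + σε and a scalar-output, per-level energy network f(x, t). All names here (`EnergyNet`, `sample_conditional`, `recovery_loss`), the network size, and the Langevin step settings are illustrative assumptions, not the paper's actual code.

```python
import torch

class EnergyNet(torch.nn.Module):
    # Tiny illustrative energy network; the paper uses a much larger CNN.
    def __init__(self, dim, n_levels):
        super().__init__()
        self.emb = torch.nn.Embedding(n_levels, dim)
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, 256), torch.nn.SiLU(), torch.nn.Linear(256, 1))

    def forward(self, x, t):
        h = x.flatten(1) + self.emb(torch.full((x.shape[0],), t, dtype=torch.long))
        return self.net(h).squeeze(-1)  # per-example scalar energy f(x, t)

def conditional_energy(f, x, x_tilde, sigma, t):
    # -log p(x | x_tilde), up to a constant: -f(x, t) + ||x_tilde - x||^2 / (2 sigma^2)
    sq = ((x_tilde - x) ** 2).flatten(1).sum(dim=1)
    return -f(x, t) + sq / (2 * sigma ** 2)

def sample_conditional(f, x_tilde, sigma, t, n_steps=30, step_size=1e-4):
    # Short-run Langevin dynamics on p(x | x_tilde), initialized at the noisy
    # observation; the quadratic term keeps the chain concentrated near x_tilde.
    x = x_tilde.clone().requires_grad_(True)
    for _ in range(n_steps):
        grad, = torch.autograd.grad(
            conditional_energy(f, x, x_tilde, sigma, t).sum(), x)
        x = (x - 0.5 * step_size * grad
             + step_size ** 0.5 * torch.randn_like(x)).detach().requires_grad_(True)
    return x.detach()

def recovery_loss(f, x, sigma, t):
    # Contrastive surrogate for the negative recovery log-likelihood: the
    # quadratic term does not depend on the parameters, so only f appears.
    x_tilde = x + sigma * torch.randn_like(x)
    x_neg = sample_conditional(f, x_tilde, sigma, t)
    return f(x_neg, t).mean() - f(x, t).mean()
```

Minimizing `recovery_loss` pushes the energy down on real data and up on conditional samples, which matches the contrastive form of the recovery log-likelihood gradient ∇f(x) − E_{p(x′|x̃)}[∇f(x′)].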
Pleased to share our new work on learning energy-based models: https://t.co/fX50RBGXn0 By maximizing recovery likelihoods on increasingly noisy data, the MCMC becomes more tractable. We achieve (1) high quality samples, (2) stable long-run chains, (3) estimated likelihoods. (1/n) pic.twitter.com/qQpnYbpfMn

— Ruiqi Gao (@RuiqiGao) December 16, 2020
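The generation procedure the tweet alludes to, sampling conditionals at progressively lower noise levels starting from a spherical Gaussian, can be sketched the same way. This reuses the hypothetical `sample_conditional` from the training sketch above; the geometric noise schedule and shapes are illustrative assumptions, not the paper's settings.

```python
import torch

def progressive_sample(f, sigmas, shape):
    # sigmas: noise levels ordered from lowest (index 0) to highest (last).
    x = torch.randn(shape)  # initialize from a spherical Gaussian
    for t in reversed(range(len(sigmas))):
        # The current sample plays the role of the noisy observation
        # x_tilde for the conditional at noise level t.
        x = sample_conditional(f, x_tilde=x, sigma=sigmas[t], t=t)
    return x

# Illustrative geometric schedule of six noise levels.
sigmas = [0.01 * (2.0 ** i) for i in range(6)]
```

Because each conditional is relatively concentrated around its noisy input, each Langevin stage only needs a short chain, which is what makes the overall sampling process efficient.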