Deep Unsupervised Learning using Nonequilibrium Thermodynamics

Jascha Sohl-Dickstein, Eric A. Weiss, Niru Maheswaranathan, Surya Ganguli

2015 · arXiv


Problem

Framing

Earlier generative models traded tractability against flexibility. None combined a normalized likelihood, exact ancestral sampling, and easy conditioning (e.g., denoising) in a single stochastic framework. The paper closes that gap in two steps. It first gradually diffuses data toward a tractable terminal distribution. It then learns only the small, local reverse kernels that undo each diffusion step. On dead leaves it reports 1.244 bits/pixel.

Currently Used Methods

Foundational

Proposed Method

Architecture

The model defines a forward diffusion chain $q(\mathbf{x}^{(0\cdots T)})$ and a learned reverse chain $p(\mathbf{x}^{(0\cdots T)})$. For images, the reverse model predicts the mean and covariance of each reverse step. It does so with multiscale convolutions followed by several $1 \times 1$ convolutions. Time enters through learned coefficients on bump functions of $t$.
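The bump-function time conditioning can be sketched as a weighted sum of per-bump output maps. This is a minimal sketch under several assumptions. The bumps are taken to be normalized Gaussians in $t$ with evenly spaced centres. The helper names (`bump_weights`, `time_conditioned_output`), the number of bumps, and the width are hypothetical and are not taken from the paper's code.

```python
import numpy as np

def bump_weights(t, num_bumps, T, width):
    """Normalized Gaussian bumps g_j(t), centres tau_j evenly spaced over [0, T].

    Assumption: this centre/width parameterization is illustrative only.
    """
    centers = np.linspace(0.0, T, num_bumps)          # tau_j
    logits = -0.5 * ((t - centers) / width) ** 2      # unnormalized log-bumps
    logits -= logits.max()                            # numerical stability
    g = np.exp(logits)
    return g / g.sum()                                # weights sum to 1

def time_conditioned_output(coeffs, t, T, width):
    """Collapse per-bump network outputs into one per-pixel map at time t.

    coeffs: hypothetical network output of shape (num_bumps, H, W),
    one coefficient map per bump function.
    """
    g = bump_weights(t, coeffs.shape[0], T, width)
    return np.tensordot(g, coeffs, axes=1)            # sum_j g_j(t) * coeffs[j]
```

The network therefore needs only a single forward pass per image. The time dependence is injected afterwards by mixing the coefficient maps, so nearby steps share parameters smoothly.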

Swiss-roll diffusion figure: top row shows forward noising from the data spiral to a near-Gaussian cloud; middle row shows samples from the learned reverse chain; bottom row shows learned reverse drift vectors at three times.
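The forward noising in the top row can be reproduced on a toy 2-D Swiss roll. This sketch assumes the Gaussian kernel $q(\mathbf{x}^{(t)} \mid \mathbf{x}^{(t-1)}) = \mathcal{N}(\mathbf{x}^{(t)}; \sqrt{1-\beta_t}\,\mathbf{x}^{(t-1)}, \beta_t \mathbf{I})$, which this summary does not spell out. The Swiss-roll parameters, the step count, and the linear $\beta_t$ schedule are hypothetical choices made for illustration.

```python
import numpy as np

def swiss_roll(n, rng):
    """2-D Swiss roll (illustrative parameters, not the paper's exact data)."""
    theta = rng.uniform(1.5 * np.pi, 4.5 * np.pi, size=n)
    x = np.stack([theta * np.cos(theta), theta * np.sin(theta)], axis=1)
    return (x - x.mean(axis=0)) / x.std(axis=0)   # zero mean, unit variance per coordinate

def forward_diffusion(x0, betas, rng):
    """Apply x^(t) = sqrt(1 - beta_t) x^(t-1) + sqrt(beta_t) * noise at every step.

    Returns the whole trajectory with shape (T + 1, n, d).
    """
    traj = [x0]
    x = x0
    for beta in betas:
        noise = rng.standard_normal(x.shape)
        x = np.sqrt(1.0 - beta) * x + np.sqrt(beta) * noise
        traj.append(x)
    return np.stack(traj)

rng = np.random.default_rng(0)
x0 = swiss_roll(2000, rng)
betas = np.linspace(1e-3, 0.2, 40)   # hypothetical schedule, T = 40
traj = forward_diffusion(x0, betas, rng)
```

With unit-variance data, each step preserves the marginal variance. The spiral structure fades as $\prod_t (1-\beta_t)$ shrinks, and by $t = T$ the points form a near-isotropic Gaussian cloud.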

Loss / Objective

Training maximizes a variational lower bound on log likelihood.

$$L \ge K, \qquad K = - \sum_{t=2}^{T} \int d\mathbf{x}^{(0)} \, d\mathbf{x}^{(t)} \, q(\mathbf{x}^{(0)}, \mathbf{x}^{(t)}) \, D_{KL}\!\left(q(\mathbf{x}^{(t-1)} \mid \mathbf{x}^{(t)}, \mathbf{x}^{(0)}) \,\|\, p(\mathbf{x}^{(t-1)} \mid \mathbf{x}^{(t)})\right) + H_q(X^{(T)} \mid X^{(0)}) - H_q(X^{(1)} \mid X^{(0)}) - H_p(X^{(T)})$$

The entropy terms depend only on the forward process and on $\pi$. With a fixed diffusion schedule, maximizing $K$ therefore amounts to minimizing the per-step KL divergences.
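For Gaussian diffusion, each KL term in the bound compares two Gaussians, so it has a closed form. This is a minimal sketch that assumes the kernel $\mathcal{N}(\sqrt{1-\beta_t}\,\mathbf{x}^{(t-1)}, \beta_t\mathbf{I})$ and diagonal reverse covariances. The posterior formula below is standard Gaussian algebra and is not quoted from the paper. The function names are hypothetical.

```python
import numpy as np

def posterior_params(x0, xt, betas, t):
    """Mean and variance of q(x^(t-1) | x^(t), x^(0)) for the Gaussian kernel
    q(x^(t) | x^(t-1)) = N(sqrt(1 - beta_t) x^(t-1), beta_t I).

    betas[t - 1] holds beta_t; t is 1-indexed and must satisfy t >= 2.
    """
    if t < 2:
        raise ValueError("the KL terms in K start at t = 2")
    alphas = 1.0 - betas
    abar = np.cumprod(alphas)                    # abar[t - 1] = prod_{s <= t} alpha_s
    abar_t, abar_prev = abar[t - 1], abar[t - 2]
    beta_t, alpha_t = betas[t - 1], alphas[t - 1]
    mean = (np.sqrt(abar_prev) * beta_t / (1 - abar_t)) * x0 \
         + (np.sqrt(alpha_t) * (1 - abar_prev) / (1 - abar_t)) * xt
    var = beta_t * (1 - abar_prev) / (1 - abar_t)
    return mean, var

def kl_diag_gauss(m_q, v_q, m_p, v_p):
    """KL(N(m_q, diag v_q) || N(m_p, diag v_p)) in nats, summed over the last axis."""
    return 0.5 * np.sum(np.log(v_p / v_q) + (v_q + (m_q - m_p) ** 2) / v_p - 1.0, axis=-1)
```

In a full training loop, `m_p` and `v_p` would come from the reverse network at step $t$. The KL would then be averaged over pairs $(\mathbf{x}^{(0)}, \mathbf{x}^{(t)})$ drawn from $q$, and the result divided by $\ln 2$ to report it in bits.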

Sampling Rule / Algorithm

Sampling starts from the tractable terminal distribution $\pi(\mathbf{x}^{(T)})$ and applies the learned reverse kernels for $T$ steps.

$$\mathbf{x}^{(T)} \sim \pi(\mathbf{x}^{(T)}), \qquad \mathbf{x}^{(t-1)} \sim p(\mathbf{x}^{(t-1)} \mid \mathbf{x}^{(t)}), \quad t = T, \ldots, 1$$
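Ancestral sampling is a short loop over the reverse chain. This minimal sketch makes two assumptions. First, $\pi$ is a standard isotropic Gaussian. Second, each reverse kernel is Gaussian, with mean and variance returned by the model. `reverse_mean_var` and `toy_reverse_kernel` are hypothetical stand-ins for the trained network, not the paper's code.

```python
import numpy as np

def ancestral_sample(reverse_mean_var, T, n, d, rng):
    """Draw n samples by running the reverse chain from t = T down to t = 1.

    reverse_mean_var(x, t) -> (mean, var) stands in for the trained network
    (hypothetical signature); var may be a scalar or broadcastable to x.
    Assumes pi(x^(T)) is a standard isotropic Gaussian.
    """
    x = rng.standard_normal((n, d))            # x^(T) ~ pi(x^(T))
    for t in range(T, 0, -1):                  # t = T, ..., 1
        mean, var = reverse_mean_var(x, t)
        x = mean + np.sqrt(var) * rng.standard_normal(x.shape)   # x^(t-1) ~ p(x^(t-1) | x^(t))
    return x                                   # approximate sample of x^(0)

def toy_reverse_kernel(x, t):
    # Hypothetical stand-in: pull every point toward (3, 3) with small fixed noise.
    return 0.9 * x + 0.1 * 3.0, 1e-4

samples = ancestral_sample(toy_reverse_kernel, T=100, n=500, d=2,
                           rng=np.random.default_rng(0))
```

Each step only needs the current state and $t$, so sampling cost grows linearly in $T$. With the toy kernel, the samples collapse near $(3, 3)$. A trained kernel would instead move them toward the data distribution.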

Training Procedure

Evaluation

Datasets

Metrics

Headline results

CIFAR-10 qualitative results: holdout images, the same images corrupted with Gaussian noise, posterior denoised reconstructions, and unconditional samples from the diffusion model.

Table 1: Log-likelihood summary across datasets.

| Dataset | $K$ | $K - L_{\text{null}}$ |
|---|---|---|
| Swiss Roll | 2.35 bits | 6.45 bits |
| Binary Heartbeat | -2.414 bits/seq. | 2.676 bits/seq. |
| MNIST | 82.90 bits | 136.7 bits |
| CIFAR-10 | 4.51 bits/pixel | 0.59 bits/pixel |
| Dead Leaves: MCGSM (baseline) | 1.244 bits/pixel | |
| Dead Leaves: Diffusion probabilistic model | 1.184 bits/pixel | 0.53 bits/pixel |

Ablations

Method Strengths and Weaknesses

Strengths

Weaknesses

Suggestions from the authors

Links

Prior Papers

Further Papers