UQ4NN Solvers

This section provides the mathematical foundations underlying each UQ solver in QUiNN. All solvers share a common goal: given a neural network model \(M(x; w)\) with weights \(w \in \mathbb{R}^K\), and training data \(\{(x_i, y_i)\}_{i=1}^{N}\), approximate the posterior distribution

\[p(w \mid \mathcal{D}) \propto p(\mathcal{D} \mid w) \, p(w),\]

where \(\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{N}\) is the dataset, and use samples from this posterior to propagate uncertainty through the network predictions.

MCMC (NN_MCMC)

Markov chain Monte Carlo directly samples from the posterior \(p(w \mid \mathcal{D})\) by constructing a Markov chain whose stationary distribution is the target posterior. Given \(M_{\text{MCMC}}\) chain samples \(\{w^{(j)}\}_{j=1}^{M_{\text{MCMC}}}\) (after discarding burn-in), predictions are obtained as

\[M(x; w^{(j)}), \quad j = 1, \ldots, M_{\text{MCMC}}.\]

The posterior mean and variance of the prediction at a test point \(x^*\) are estimated as

\[\bar{y}(x^*) = \frac{1}{M_{\text{MCMC}}}\sum_{j=1}^{M_{\text{MCMC}}} M(x^*; w^{(j)}), \qquad \text{Var}[y(x^*)] \approx \frac{1}{M_{\text{MCMC}}-1}\sum_{j=1}^{M_{\text{MCMC}}} \left(M(x^*; w^{(j)}) - \bar{y}(x^*)\right)^2.\]
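As a minimal illustration, both estimators can be computed with NumPy once the sampled predictions are stacked into an array of shape \((M_{\text{MCMC}}, n_{\text{test}})\). The helper name predictive_moments is hypothetical and not part of QUiNN's API. The same aggregation applies to the other sample-based solvers below.

```python
import numpy as np


def predictive_moments(preds):
    """Posterior predictive mean and unbiased variance (hypothetical helper).

    preds: array of shape (M, n_test); row j holds M(x*; w^(j)) at every test point.
    """
    preds = np.asarray(preds, dtype=float)
    mean = preds.mean(axis=0)        # \bar{y}(x*)
    var = preds.var(axis=0, ddof=1)  # 1/(M - 1) normalization
    return mean, var


# Toy example: 3 posterior samples evaluated at 2 test points.
preds = np.array([[1.0, 2.0],
                  [2.0, 2.0],
                  [3.0, 2.0]])
mean, var = predictive_moments(preds)
print(mean, var)  # [2. 2.] [1. 0.]
```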

QUiNN supports three MCMC samplers:

Adaptive Metropolis (AMCMC)

The Adaptive Metropolis algorithm [3] uses a random-walk Metropolis-Hastings sampler with an adaptively tuned proposal covariance. At step \(t\), the proposal is

\[w' = w^{(t)} + \xi, \qquad \xi \sim \mathcal{N}(0, \Sigma_t),\]

where the proposal covariance is updated online using the sample covariance of the chain history:

\[\Sigma_t = \frac{\gamma \cdot 2.4^2}{K} \left( \hat{C}_t + 10^{-8} I_K \right),\]

with \(\hat{C}_t\) the running sample covariance of \(\{w^{(0)}, \ldots, w^{(t)}\}\), \(K\) the parameter dimensionality, and \(\gamma\) a user-tunable scaling factor. The adaptation is triggered after an initial burn-in period \(t_0\), and the covariance is refreshed every \(t_{\text{adapt}}\) steps. The standard Metropolis-Hastings acceptance criterion applies. Because the Gaussian proposal is symmetric, the proposal densities cancel and only the posterior ratio remains:

\[\alpha = \min\!\left(1,\; \frac{p(w' \mid \mathcal{D})}{p(w^{(t)} \mid \mathcal{D})}\right).\]
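The adaptive scheme above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not QUiNN's implementation:

- the function adaptive_metropolis is hypothetical;
- its defaults for \(t_0\) and \(t_{\text{adapt}}\) are hypothetical choices;
- the fixed isotropic proposal used before adaptation (init_scale) is also a hypothetical choice;
- log_post stands in for the unnormalized log-posterior of a network.

```python
import numpy as np


def adaptive_metropolis(log_post, w0, n_steps, gamma=1.0, t0=1000, t_adapt=100,
                        init_scale=0.1, seed=0):
    """Adaptive Metropolis sketch (hypothetical helper, not QUiNN's API).

    log_post(w) returns log p(w | D) up to an additive constant.
    """
    rng = np.random.default_rng(seed)
    w = np.asarray(w0, dtype=float)
    K = w.size
    lp = log_post(w)
    cov = init_scale**2 * np.eye(K)  # assumed fixed proposal before adaptation
    chain = np.empty((n_steps + 1, K))
    chain[0] = w
    n_accept = 0
    for t in range(n_steps):
        # After t0, refresh Sigma_t from the chain history every t_adapt steps.
        if t >= t0 and t % t_adapt == 0:
            C_hat = np.cov(chain[: t + 1], rowvar=False).reshape(K, K)
            cov = gamma * 2.4**2 / K * (C_hat + 1e-8 * np.eye(K))
        w_prop = rng.multivariate_normal(w, cov)
        lp_prop = log_post(w_prop)
        # Symmetric proposal: accept with probability min(1, posterior ratio).
        if np.log(rng.uniform()) < lp_prop - lp:
            w, lp = w_prop, lp_prop
            n_accept += 1
        chain[t + 1] = w
    return chain, n_accept / n_steps


# Usage: a correlated 2-D Gaussian stands in for the network posterior.
mu = np.array([1.0, -1.0])
prec = np.linalg.inv(np.array([[1.0, 0.8], [0.8, 1.0]]))
log_post = lambda w: -0.5 * (w - mu) @ prec @ (w - mu)
chain, acc = adaptive_metropolis(log_post, np.zeros(2), n_steps=20000)
samples = chain[5000:]  # discard burn-in
```

Working in log space avoids overflow when the posterior ratio is very large or very small.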

Hamiltonian Monte Carlo (HMC)

Hamiltonian Monte Carlo [4] augments the parameter space with an auxiliary momentum variable \(p \in \mathbb{R}^K\) and defines a Hamiltonian

\[H(w, p) = U(w) + \tfrac{1}{2} p^\top p, \qquad U(w) = -\log p(w \mid \mathcal{D}).\]

At the start of each iteration the momentum is drawn as \(p \sim \mathcal{N}(0, I_K)\). The leapfrog integrator then evolves the state \((w, p)\) for \(L\) steps with step size \(\varepsilon\):

\[\begin{split}p_{t+\frac{1}{2}} &= p_t + \frac{\varepsilon}{2}\,\nabla_w \log p(w_t \mid \mathcal{D}), \\ w_{t+1} &= w_t + \varepsilon\, p_{t+\frac{1}{2}}, \\ p_{t+1} &= p_{t+\frac{1}{2}} + \frac{\varepsilon}{2}\,\nabla_w \log p(w_{t+1} \mid \mathcal{D}).\end{split}\]

The proposal \((w', p')\) is accepted with probability

\[\alpha = \min\!\left(1,\; \exp\!\big(H(w, p) - H(w', p')\big)\right).\]
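A single HMC transition can be sketched as below: a fresh momentum draw, the leapfrog integrator, then the accept/reject step. The helper hmc_step is hypothetical. It assumes a unit mass matrix, matching the kinetic energy \(\tfrac{1}{2}p^\top p\), and user-supplied callables for the log-posterior and its gradient. For a neural network, those would come from automatic differentiation.

```python
import numpy as np


def hmc_step(w, log_post, grad_log_post, eps, L, rng):
    """One HMC transition with unit mass matrix (hypothetical helper)."""
    p = rng.standard_normal(w.shape)  # fresh momentum p ~ N(0, I_K)
    w_new, p_new = w.copy(), p.copy()
    # Leapfrog: half momentum step, L position steps with full momentum
    # steps in between, and a final half momentum step.
    p_new += 0.5 * eps * grad_log_post(w_new)
    for step in range(L):
        w_new += eps * p_new
        if step < L - 1:
            p_new += eps * grad_log_post(w_new)
    p_new += 0.5 * eps * grad_log_post(w_new)
    # H(w, p) = U(w) + p^T p / 2 with U(w) = -log p(w | D).
    H_old = -log_post(w) + 0.5 * p @ p
    H_new = -log_post(w_new) + 0.5 * p_new @ p_new
    if np.log(rng.uniform()) < H_old - H_new:
        return w_new, True
    return w, False


# Usage: a 2-D standard normal stands in for the network posterior.
rng = np.random.default_rng(1)
log_post = lambda w: -0.5 * w @ w
grad_log_post = lambda w: -w
w, draws, n_acc = np.zeros(2), [], 0
for _ in range(3000):
    w, accepted = hmc_step(w, log_post, grad_log_post, eps=0.2, L=10, rng=rng)
    n_acc += accepted
    draws.append(w)
samples = np.array(draws[500:])
```

Merging the two consecutive half steps between position updates into one full step is the usual way to write leapfrog. It saves one gradient evaluation per step.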

Metropolis-Adjusted Langevin Algorithm (MALA)

MALA [5] is a gradient-informed random-walk sampler that uses the Langevin diffusion to construct proposals. The proposal is

\[w' = w^{(t)} + \frac{\varepsilon^2}{2}\,\nabla_w \log p(w^{(t)} \mid \mathcal{D}) + \varepsilon\, \xi, \qquad \xi \sim \mathcal{N}(0, I_K),\]

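which corresponds to a single Euler-Maruyama discretization step of the Langevin stochastic differential equation. This proposal is not symmetric, so the Metropolis-Hastings acceptance probability must include the proposal density \(q(w' \mid w) = \mathcal{N}\!\left(w' \mid w + \frac{\varepsilon^2}{2}\nabla_w \log p(w \mid \mathcal{D}),\; \varepsilon^2 I_K\right)\):

\[\alpha = \min\!\left(1,\; \frac{p(w' \mid \mathcal{D})\, q(w^{(t)} \mid w')}{p(w^{(t)} \mid \mathcal{D})\, q(w' \mid w^{(t)})}\right).\]

This correction removes the bias introduced by the discretization, so the chain leaves the exact posterior invariant.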
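The sketch below implements MALA, including the proposal-density correction in the acceptance step. The function mala and the one-dimensional Gaussian target in the usage lines are hypothetical illustrations, not QUiNN code.

```python
import numpy as np


def mala(log_post, grad_log_post, w0, eps, n_steps, seed=0):
    """MALA sketch with the asymmetric-proposal correction (hypothetical helper)."""
    rng = np.random.default_rng(seed)

    def log_q(w_to, w_from):
        # log N(w_to | w_from + eps^2/2 * grad log p(w_from | D), eps^2 I), up to a constant.
        mean = w_from + 0.5 * eps**2 * grad_log_post(w_from)
        return -np.sum((w_to - mean) ** 2) / (2 * eps**2)

    w = np.asarray(w0, dtype=float)
    chain, n_accept = [w], 0
    for _ in range(n_steps):
        # One Euler-Maruyama step of the Langevin SDE as the proposal.
        noise = rng.standard_normal(w.shape)
        w_prop = w + 0.5 * eps**2 * grad_log_post(w) + eps * noise
        log_alpha = (log_post(w_prop) + log_q(w, w_prop)
                     - log_post(w) - log_q(w_prop, w))
        if np.log(rng.uniform()) < log_alpha:
            w = w_prop
            n_accept += 1
        chain.append(w)
    return np.array(chain), n_accept / n_steps


# Usage: a 1-D Gaussian N(3, 0.5^2) stands in for the network posterior.
log_post = lambda w: -np.sum((w - 3.0) ** 2) / (2 * 0.25)
grad_log_post = lambda w: -(w - 3.0) / 0.25
chain, acc = mala(log_post, grad_log_post, np.zeros(1), eps=0.4, n_steps=20000)
post = chain[2000:, 0]
```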

Deep Ensemble (NN_Ens)

Deep Ensembles train \(J\) independent networks from random initializations, optionally on random subsets of the data (controlled by the data fraction parameter \(\delta \in (0, 1]\)). Each ensemble member \(j\) minimizes the standard MSE loss

\[\mathcal{L}_j(w_j) = \frac{1}{|\mathcal{D}_j|} \sum_{(x_i, y_i) \in \mathcal{D}_j} \|y_i - M(x_i; w_j)\|^2,\]

where \(\mathcal{D}_j \subseteq \mathcal{D}\) is a random subset of size \(|\mathcal{D}_j| = \lfloor \delta \cdot N \rfloor\). Predictions from all members are aggregated:

\[\bar{y}(x^*) = \frac{1}{J}\sum_{j=1}^{J} M(x^*; w_j), \qquad \text{Var}[y(x^*)] \approx \frac{1}{J-1}\sum_{j=1}^{J} \left(M(x^*; w_j) - \bar{y}(x^*)\right)^2.\]
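The sketch below illustrates the ensemble recipe without a deep-learning framework. It substitutes a cubic polynomial least-squares fit for each network; that fit minimizes the MSE loss exactly. This stand-in model and the helper names are assumptions for illustration only. Because the stand-in is convex, diversity comes solely from the random subsets \(\mathcal{D}_j\). For neural networks, random initialization adds a second source of diversity.

```python
import numpy as np


def fit_member(x, y, degree=3):
    # Stand-in for training one network: a least-squares polynomial fit,
    # which minimizes the MSE loss over the member's subset exactly.
    return np.polyfit(x, y, degree)


def deep_ensemble(x, y, J=10, delta=0.8, degree=3, seed=0):
    rng = np.random.default_rng(seed)
    N = len(x)
    n_sub = int(np.floor(delta * N))  # |D_j| = floor(delta * N)
    members = []
    for _ in range(J):
        idx = rng.choice(N, size=n_sub, replace=False)  # random subset D_j
        members.append(fit_member(x[idx], y[idx], degree))
    return members


def ensemble_predict(members, x_star):
    preds = np.stack([np.polyval(c, x_star) for c in members])  # (J, n_test)
    return preds.mean(axis=0), preds.var(axis=0, ddof=1)


# Usage: noisy samples of sin(3x) on [-1, 1]; x* = 2 lies outside the data.
rng = np.random.default_rng(42)
x = rng.uniform(-1.0, 1.0, 50)
y = np.sin(3.0 * x) + 0.1 * rng.standard_normal(50)
members = deep_ensemble(x, y)
mean, var = ensemble_predict(members, np.array([0.0, 2.0]))
```

The member disagreement, and hence the predicted variance, is expected to be much larger at the extrapolation point \(x^* = 2\) than inside the data range.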

Randomized MAP Sampling (NN_RMS)

Randomized MAP Sampling (RMS) [2] extends the deep ensemble approach by training each member with a randomized prior anchor. Each ensemble member \(j\) minimizes the negative log-posterior

\[\mathcal{L}_j(w_j) = \frac{1}{2\sigma^2}\sum_{i \in \mathcal{D}_j} \|y_i - M(x_i; w_j)\|^2 + \frac{|\mathcal{D}_j|}{N} \cdot \frac{1}{2\sigma_{\text{prior}}^2}\|w_j - w_0^{(j)}\|^2,\]

where \(w_0^{(j)} \sim \mathcal{N}(0, \sigma_{\text{prior}}^2 I_K)\) is a random anchor drawn independently for each member. This yields an implicit sampling scheme: the MAP solutions \(\{w_j^*\}_{j=1}^J\) of the \(J\) randomized objectives serve as approximate posterior samples.
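For a model that is linear in its parameters, \(M(x; w) = \Phi(x)\,w\), the anchored objective is quadratic. Its minimizer then has a closed form, \(w_j^* = A^{-1}b\), where:

- \(A = \Phi^\top\Phi/\sigma^2 + \lambda I_K/\sigma_{\text{prior}}^2\);
- \(b = \Phi^\top y/\sigma^2 + \lambda\, w_0^{(j)}/\sigma_{\text{prior}}^2\);
- \(\lambda = |\mathcal{D}_j|/N\).

The sketch below uses this linear stand-in and a hypothetical helper, rms_member. The linear model is an assumption for illustration; a network would instead be trained by gradient descent.

```python
import numpy as np


def rms_member(Phi, y, sigma, sigma_prior, N, rng):
    """Anchored MAP for a linear stand-in model M(x; w) = Phi(x) w (hypothetical helper).

    Phi holds the features of the member's subset D_j, so |D_j| = Phi.shape[0].
    Returns the minimizer w_j^* and the anchor w_0^(j) it was drawn with.
    """
    n_j, K = Phi.shape
    lam = n_j / N                                 # |D_j| / N weight on the prior term
    w0 = sigma_prior * rng.standard_normal(K)     # anchor w_0^(j) ~ N(0, sigma_prior^2 I)
    # Setting the gradient of L_j to zero gives the linear system A w = b.
    A = Phi.T @ Phi / sigma**2 + lam / sigma_prior**2 * np.eye(K)
    b = Phi.T @ y / sigma**2 + lam / sigma_prior**2 * w0
    return np.linalg.solve(A, b), w0


# Usage: J = 20 members on the full data set (delta = 1), cubic features.
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, 30)
y = 0.5 * x + 0.1 * rng.standard_normal(30)
Phi = np.vander(x, 4, increasing=True)            # columns 1, x, x^2, x^3
members = [rms_member(Phi, y, 0.1, 1.0, N=30, rng=rng) for _ in range(20)]
W = np.stack([w for w, _ in members])             # (J, K) approximate posterior samples
```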

Variational Inference (NN_VI)

Variational inference approximates the posterior \(p(w \mid \mathcal{D})\) with a tractable distribution \(q_\phi(w)\) by minimizing the Kullback-Leibler (KL) divergence \(\mathrm{KL}\bigl(q_\phi(w)\,\|\,p(w \mid \mathcal{D})\bigr)\), which is equivalent to maximizing the Evidence Lower Bound (ELBO). QUiNN implements the Bayes by Backprop method [1].

Variational Family

Each weight \(w_k\) is parameterized with an independent Gaussian:

\[q_\phi(w_k) = \mathcal{N}(w_k \mid \mu_k,\; \sigma_k^2), \qquad \sigma_k = \log(1 + e^{\rho_k}),\]

where \(\phi = \{\mu_k, \rho_k\}_{k=1}^K\) are the variational parameters. The softplus transformation ensures \(\sigma_k > 0\).

Scale Mixture Prior

The prior over each weight is a scale mixture of two Gaussians:

\[p(w_k) = \pi\,\mathcal{N}(w_k \mid 0,\,\sigma_1^2) + (1 - \pi)\,\mathcal{N}(w_k \mid 0,\,\sigma_2^2),\]

where \(\pi \in [0,1]\) and \(\sigma_1, \sigma_2 > 0\) are hyperparameters.

ELBO Loss

The variational loss (per mini-batch) is

\[\mathcal{L}(\phi) = \frac{1}{B}\bigl[\log q_\phi(w) - \log p(w)\bigr] + \frac{N}{2}\log(2\pi\sigma^2) + \frac{N}{2\sigma^2}\,\text{MSE}(w),\]

where \(w \sim q_\phi\), \(B\) is the number of mini-batches, and \(\text{MSE}(w) = \frac{1}{|b|}\sum_{i \in b}\|y_i - M(x_i; w)\|^2\) over the current mini-batch \(b\). At each training step, \(S\) weight samples are drawn with the reparameterization \(w = \mu + \sigma \odot \xi\), \(\xi \sim \mathcal{N}(0, I_K)\). These samples give a Monte Carlo estimate of the ELBO that is differentiable with respect to \(\phi\). At prediction time, weight samples from \(q_\phi(w)\) are drawn to produce an ensemble of outputs.
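The sketch below assembles the loss for fixed variational parameters, using the reparameterization \(w = \mu + \sigma \odot \xi\). To stay framework-free, it assumes a linear stand-in model \(M(x; w) = \Phi(x)\,w\) and computes only the loss value. In practice \(\phi\) would be updated with gradients from automatic differentiation. All function names are hypothetical.

```python
import numpy as np


def softplus(rho):
    return np.logaddexp(0.0, rho)  # log(1 + e^rho), computed stably


def log_normal(w, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - (w - mu) ** 2 / (2 * sigma**2)


def log_mixture_prior(w, pi, s1, s2):
    # Sum over weights of log[pi N(w|0,s1^2) + (1-pi) N(w|0,s2^2)], via logaddexp.
    return np.sum(np.logaddexp(np.log(pi) + log_normal(w, 0.0, s1),
                               np.log1p(-pi) + log_normal(w, 0.0, s2)))


def vi_loss(mu, rho, Phi_b, y_b, N, B, sigma, pi, s1, s2, S=5, rng=None):
    """Monte Carlo estimate of the mini-batch loss for a linear stand-in model."""
    if rng is None:
        rng = np.random.default_rng()
    sig = softplus(rho)
    total = 0.0
    for _ in range(S):
        w = mu + sig * rng.standard_normal(mu.shape)  # reparameterized w ~ q_phi
        log_q = np.sum(log_normal(w, mu, sig))
        log_p = log_mixture_prior(w, pi, s1, s2)
        mse = np.mean((y_b - Phi_b @ w) ** 2)         # MSE over the mini-batch b
        total += ((log_q - log_p) / B
                  + 0.5 * N * np.log(2 * np.pi * sigma**2)
                  + N / (2 * sigma**2) * mse)
    return total / S


# Usage: one mini-batch of 32 points from y = 1 + 2x, with N = 320 and B = 10.
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, 32)
Phi_b = np.stack([np.ones_like(x), x], axis=1)
y_b = 1.0 + 2.0 * x + 0.1 * rng.standard_normal(32)
rho = np.full(2, -5.0)                                 # sigma_k = softplus(-5) ~ 0.0067
kwargs = dict(N=320, B=10, sigma=0.1, pi=0.5, s1=1.0, s2=0.1, rng=rng)
good = vi_loss(np.array([1.0, 2.0]), rho, Phi_b, y_b, **kwargs)
bad = vi_loss(np.zeros(2), rho, Phi_b, y_b, **kwargs)
```

Evaluating the mixture prior through logaddexp keeps the log-density finite even when both Gaussian components are tiny at a given weight.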

Laplace Approximation (NN_Laplace)

The Laplace approximation [6] constructs a Gaussian approximation to the posterior centered at the MAP estimate \(w^*\):

\[p(w \mid \mathcal{D}) \approx \mathcal{N}\!\left(w \;\Big|\; w^*,\; \bigl[\nabla^2_w \mathcal{L}(w^*)\bigr]^{-1}\right),\]

where \(\mathcal{L}(w) = -\log p(w \mid \mathcal{D})\) is the negative log-posterior and \(\nabla^2_w \mathcal{L}(w^*)\) is its Hessian evaluated at the MAP.

Step 1: MAP Training. The network is trained by minimizing the negative log-posterior \(\mathcal{L}(w)\), yielding the MAP estimate \(w^*\).

Step 2: Hessian Computation. QUiNN supports two options, one exact and one approximate:

  • Full Hessian: The exact \(K \times K\) Hessian is computed via second-order automatic differentiation:

    \[H_{ij} = \frac{\partial^2 \mathcal{L}}{\partial w_i \partial w_j}\Bigg|_{w=w^*}.\]
  • Diagonal (Fisher) approximation: The diagonal of the empirical Fisher information matrix is used as a Hessian proxy:

    \[\tilde{H}_{kk} = \frac{1}{N}\sum_{i=1}^{N} \left(\frac{\partial \mathcal{L}_i}{\partial w_k}\Bigg|_{w=w^*}\right)^2,\]

    where \(\mathcal{L}_i\) denotes the per-sample loss. The resulting Hessian is diagonal: \(\tilde{H} = \text{diag}(\tilde{H}_{11}, \ldots, \tilde{H}_{KK})\).
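For a model that is linear in its weights, with Gaussian likelihood and prior, both quantities can be written down directly. That makes the difference between them easy to inspect. The sketch rests on two assumptions:

- the objective is \(\mathcal{L}(w) = \sum_i (y_i - \Phi_i w)^2/(2\sigma^2) + \|w\|^2/(2\sigma_{\text{prior}}^2)\);
- the per-sample loss \(\mathcal{L}_i\) is the likelihood term alone.

For a general network, these quantities would be obtained by automatic differentiation. The helper names are hypothetical.

```python
import numpy as np


def full_hessian(Phi, sigma, sigma_prior):
    # Exact Hessian of L(w) = sum_i (y_i - Phi_i w)^2 / (2 sigma^2) + ||w||^2 / (2 sigma_prior^2).
    # It does not depend on w because the stand-in model is linear in w.
    K = Phi.shape[1]
    return Phi.T @ Phi / sigma**2 + np.eye(K) / sigma_prior**2


def diag_fisher(Phi, y, w_map, sigma):
    # Per-sample gradients of L_i(w) = (y_i - Phi_i w)^2 / (2 sigma^2) at w = w*.
    resid = y - Phi @ w_map                       # shape (N,)
    grads = -(resid[:, None] * Phi) / sigma**2    # shape (N, K): dL_i / dw_k
    return np.mean(grads**2, axis=0)              # tilde H_kk with 1/N averaging


# Usage: straight-line model with bias and slope (K = 2), N = 100.
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, 100)
Phi = np.stack([np.ones_like(x), x], axis=1)
y = 1.0 + 2.0 * x + 0.1 * rng.standard_normal(100)
H = full_hessian(Phi, sigma=0.1, sigma_prior=1.0)
w_map = np.linalg.solve(H, Phi.T @ y / 0.1**2)    # MAP: gradient of L(w) vanishes
H_diag = diag_fisher(Phi, y, w_map, sigma=0.1)
```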

Step 3: Posterior Covariance. The posterior covariance is

\[\Sigma = \left(s \cdot H\right)^{-1},\]

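where \(H\) is either the full Hessian or its diagonal approximation \(\tilde{H}\), and \(s\) is a user-tunable covariance scaling factor. Larger values of \(s\) shrink the posterior covariance.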

Step 4: Prediction. A predictive sample is drawn as

\[w \sim \mathcal{N}(w^*,\, \Sigma), \qquad y(x^*) = M(x^*; w).\]
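Sampling from \(\mathcal{N}(w^*, (sH)^{-1})\) does not require inverting the Hessian. With the Cholesky factorization \(sH = LL^\top\), the vector \(w^* + L^{-\top}z\) with \(z \sim \mathcal{N}(0, I_K)\) has exactly the required covariance. The sketch below uses this for the full Hessian and an element-wise scaling for the diagonal case. The helper laplace_samples is hypothetical. Each sample would then be passed through the network to obtain \(y(x^*)\).

```python
import numpy as np


def laplace_samples(w_map, H, s, n_samples, seed=0):
    """Draw n_samples weights from N(w*, (s H)^{-1}) (hypothetical helper).

    H is either a full (K, K) Hessian or a length-K diagonal approximation.
    """
    rng = np.random.default_rng(seed)
    K = w_map.size
    if H.ndim == 1:
        # Diagonal case: independent coordinates with variance 1 / (s H_kk).
        return w_map + rng.standard_normal((n_samples, K)) / np.sqrt(s * H)
    # Full case: with s H = L L^T, w* + L^{-T} z has covariance (s H)^{-1}.
    L = np.linalg.cholesky(s * H)
    z = rng.standard_normal((K, n_samples))
    return w_map + np.linalg.solve(L.T, z).T


# Usage: compare empirical covariances with (s H)^{-1}.
w_map = np.array([0.5, -0.5])
H = np.array([[4.0, 1.0], [1.0, 2.0]])
full = laplace_samples(w_map, H, s=1.0, n_samples=100_000)
diag = laplace_samples(w_map, np.array([4.0, 2.0]), s=1.0, n_samples=100_000)
```

A triangular solve against the Cholesky factor is cheaper and numerically safer than forming \((sH)^{-1}\) explicitly.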

SWAG (NN_SWAG)

Stochastic Weight Averaging-Gaussian (SWAG) [7] approximates the posterior by fitting a Gaussian distribution to the SGD trajectory after initial training.

Step 1: Pre-training. The network is trained with the negative log-posterior loss to obtain a good initialization.

Step 2: SGD Trajectory Collection. Starting from the pre-trained weights, \(T\) additional SGD steps are performed. At every \(c\)-th step, the current weight vector \(w_t\) is recorded and the running moments are updated:

\[\begin{split}\bar{w}_{n+1} &= \frac{n\,\bar{w}_n + w_t}{n+1}, \\ \overline{w^2}_{n+1} &= \frac{n\,\overline{w^2}_n + w_t \odot w_t}{n+1},\end{split}\]

where \(n = \lfloor t/c \rfloor\) is the snapshot counter and \(\odot\) denotes the element-wise product.

Step 3: Covariance Approximation. The diagonal variance is

\[\Sigma_{\text{diag}} = \overline{w^2} - \bar{w} \odot \bar{w}.\]

For the low-rank variant, the last \(k\) deviation vectors \(d_i = w_{t_i} - \bar{w}\) are stored as columns of a matrix \(D \in \mathbb{R}^{K \times k}\).
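The collection step can be sketched as below. The following are assumptions for illustration:

- the helper swag_collect;
- a toy quadratic loss with Gaussian gradient noise, standing in for mini-batch SGD;
- clipping of small negative variances caused by round-off.

Each deviation vector uses the running mean at the time its snapshot is taken.

```python
import numpy as np


def swag_collect(w_init, grad_fn, lr, T, c, k, rng):
    """Run T SGD steps and collect SWAG moments every c steps (hypothetical helper).

    grad_fn(w, rng) returns a stochastic (mini-batch) gradient of the loss.
    """
    w = np.array(w_init, dtype=float)
    w_bar = np.zeros_like(w)
    w2_bar = np.zeros_like(w)
    deviations = []
    n = 0  # number of snapshots collected so far
    for t in range(1, T + 1):
        w -= lr * grad_fn(w, rng)                    # one SGD step
        if t % c == 0:
            w_bar = (n * w_bar + w) / (n + 1)        # running first moment
            w2_bar = (n * w2_bar + w * w) / (n + 1)  # running second moment
            n += 1
            deviations = (deviations + [w - w_bar])[-k:]  # keep the last k d_i
    sigma_diag = np.maximum(w2_bar - w_bar * w_bar, 0.0)  # clip round-off negatives
    D = np.stack(deviations, axis=1)                 # shape (K, k)
    return w_bar, sigma_diag, D


# Usage: quadratic loss w^T A w / 2 with Gaussian gradient noise mimicking SGD.
rng = np.random.default_rng(0)
A = np.diag([1.0, 4.0])
grad_fn = lambda w, rng: A @ w + 0.5 * rng.standard_normal(2)
w_bar, sigma_diag, D = swag_collect(np.zeros(2), grad_fn, lr=0.05, T=5000, c=10, k=20, rng=rng)
```

The flatter direction of the quadratic (curvature 1) should show a larger SWAG variance than the sharper one (curvature 4).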

Step 4: Prediction. A posterior sample is drawn as

\[w = \bar{w} + \frac{1}{\sqrt{2}}\,\text{diag}\!\left(\sqrt{\Sigma_{\text{diag}}}\right) z_1 + \frac{1}{\sqrt{2}}\,\frac{D\, z_2}{\sqrt{k-1}},\]

where \(z_1 \sim \mathcal{N}(0, I_K)\) and \(z_2 \sim \mathcal{N}(0, I_k)\). If the covariance type is not low-rank, the low-rank term (the last term) is omitted. Predictions are obtained as \(y(x^*) = M(x^*; w)\).
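Given the collected moments, a single posterior draw follows the sampling formula term by term. The helper swag_sample is hypothetical. The usage lines check that the empirical covariance of many draws matches the covariance implied by the formula, \(\tfrac{1}{2}\,\text{diag}(\Sigma_{\text{diag}}) + \tfrac{1}{2(k-1)}DD^\top\).

```python
import numpy as np


def swag_sample(w_bar, sigma_diag, D=None, rng=None):
    """Draw one SWAG posterior sample; D=None gives the diagonal-only variant.

    Hypothetical helper following the sampling formula term by term.
    """
    if rng is None:
        rng = np.random.default_rng()
    z1 = rng.standard_normal(w_bar.size)
    w = w_bar + np.sqrt(sigma_diag) * z1 / np.sqrt(2.0)
    if D is not None:
        k = D.shape[1]
        z2 = rng.standard_normal(k)
        w = w + (D @ z2) / np.sqrt(2.0 * (k - 1))  # low-rank term
    return w


# Usage: check the covariance of many draws against the implied covariance.
rng = np.random.default_rng(0)
w_bar = np.array([1.0, -1.0])
sigma_diag = np.array([0.04, 0.01])
D = 0.1 * rng.standard_normal((2, 5))
samples = np.stack([swag_sample(w_bar, sigma_diag, D, rng) for _ in range(50_000)])
expected = 0.5 * np.diag(sigma_diag) + 0.5 * D @ D.T / (5 - 1)
```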

Summary of Solvers

| Solver | Posterior approximation | Training cost | Memory cost | Key hyperparameters |
|---|---|---|---|---|
| NN_MCMC | Exact (asymptotically) | High (\(O(M_{\text{MCMC}})\) forward/backward passes) | \(O(M_{\text{MCMC}} \cdot K)\) | \(M_{\text{MCMC}}\), sampler type, \(\sigma\), \(\varepsilon\) (HMC) |
| NN_Ens | Implicit (point estimates) | \(J \times\) single training | \(O(J \cdot K)\) | \(J\), \(\delta\) |
| NN_RMS | Implicit (randomized MAP) | \(J \times\) single training | \(O(J \cdot K)\) | \(J\), \(\sigma\), \(\sigma_{\text{prior}}\) |
| NN_VI | Factored Gaussian \(q_\phi(w)\) | \(\sim 2\times\) single training | \(O(2K)\) (for \(\mu, \rho\)) | \(\pi\), \(\sigma_1\), \(\sigma_2\), \(S\) |
| NN_Laplace | Gaussian at MAP | Single training + Hessian | \(O(K^2)\) full / \(O(K)\) diag | la_type, cov_scale, \(\sigma_{\text{prior}}\) |
| NN_SWAG | Low-rank Gaussian | Single training + \(T\) SGD steps | \(O(K \cdot k)\) low-rank | \(k\), \(T\), \(c\), lr_swag |

References

See the References page for the full reference list. Key references for the solvers:

  • AMCMC: Haario et al. [3]

  • HMC: Brooks et al. [4]

  • MALA: Girolami and Calderhead [5]

  • RMS: Pearce et al. [2]

  • VI (Bayes by Backprop): Blundell et al. [1]

  • Laplace: MacKay [6]

  • SWAG: Maddox et al. [7]