NN Architectures

QUiNN provides several neural network architectures that can serve as the underlying model \(M(x;\,w)\) for any of the UQ solvers. All architectures inherit from MLPBase, which extends torch.nn.Module with convenience methods for training, prediction, and parameter management.

Base Class (MLPBase)

MLPBase is the abstract base that every QUiNN architecture must subclass. It stores the input/output dimensions, tracks whether the network has been trained, and exposes the following core interface:

  • fit(xtrn, ytrn, **kwargs) — trains the network via nnfit, stores the best model snapshot, and records the loss history.

  • predict(x) — numpy-in / numpy-out convenience: converts the input, evaluates the (best) model, and returns a numpy array.

  • numpar() — returns the total number of parameters \(K\).

  • predict_plot / plot_1d_fits — diagnostic plotting helpers.
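To make the numpy-in / numpy-out contract concrete, here is a minimal sketch of how predict() and numpar() can be layered on top of torch.nn.Module. The class names SketchBase and TinyLinear are hypothetical. fit(), nnfit, and the best-model snapshot are omitted, so this is an illustration of the interface, not QUiNN's actual MLPBase code.

```python
import numpy as np
import torch


class SketchBase(torch.nn.Module):
    """Hypothetical stand-in for MLPBase: only predict() and numpar() are sketched."""

    def __init__(self, indim, outdim):
        super().__init__()
        self.indim = indim
        self.outdim = outdim
        self.trained = False  # a real fit() would flip this after training

    def numpar(self):
        # Total number of scalar parameters K over all parameter tensors
        return sum(p.numel() for p in self.parameters())

    def predict(self, x):
        # numpy in -> float32 tensor -> forward pass without autograd -> numpy out.
        # (MLPBase evaluates the best snapshot saved by fit(); this sketch uses self.)
        with torch.no_grad():
            y = self(torch.as_tensor(x, dtype=torch.float32))
        return y.numpy()


class TinyLinear(SketchBase):
    """Hypothetical one-layer subclass: only forward() has to be written."""

    def __init__(self, indim, outdim):
        super().__init__(indim, outdim)
        self.lin = torch.nn.Linear(indim, outdim)

    def forward(self, x):
        return self.lin(x)


net = TinyLinear(3, 2)
print(net.numpar())                         # 8  (3*2 weights + 2 biases)
print(net.predict(np.zeros((5, 3))).shape)  # (5, 2)
```

The point of the design is that subclasses only define forward(); the numpy conversion and parameter counting live once in the base class.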

Multilayer Perceptron (MLP)

The primary architecture in QUiNN is a fully connected feedforward network (multilayer perceptron). For an input \(x \in \mathbb{R}^d\), hidden layer widths \((h_1, h_2, \ldots, h_L)\), and output \(y \in \mathbb{R}^o\), the forward pass is

\[\begin{split}z_0 &= x, \\ z_{\ell} &= \phi\!\left(W_\ell\, z_{\ell-1} + b_\ell\right), \quad \ell = 1, \ldots, L, \\ y &= W_{L+1}\, z_L + b_{L+1},\end{split}\]

where \(W_\ell\) and \(b_\ell\) are the weight matrix and bias vector of layer \(\ell\), and \(\phi\) is the elementwise activation function.
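The forward pass above can be written directly in NumPy as a sanity check. This sketch stores samples as rows, so each affine map is applied as z @ W.T + b. The helper name mlp_forward is hypothetical (not part of QUiNN), and tanh is assumed for \(\phi\) by default.

```python
import numpy as np


def mlp_forward(x, weights, biases, phi=np.tanh):
    """Evaluate z_0 = x, z_l = phi(W_l z_{l-1} + b_l), y = W_{L+1} z_L + b_{L+1}.

    x       : array of shape (N, d), one sample per row
    weights : list [W_1, ..., W_{L+1}], W_l of shape (n_l, n_{l-1})
    biases  : list [b_1, ..., b_{L+1}], b_l of shape (n_l,)
    """
    z = x
    for W, b in zip(weights[:-1], biases[:-1]):
        # Row-vector convention: (W z)^T = z^T W^T
        z = phi(z @ W.T + b)
    # The output layer is affine, with no activation
    return z @ weights[-1].T + biases[-1]


# Example: d = 2, hidden widths (4, 3), o = 1
rng = np.random.default_rng(0)
sizes = [2, 4, 3, 1]
Ws = [rng.standard_normal((m, n)) for n, m in zip(sizes[:-1], sizes[1:])]
bs = [rng.standard_normal(m) for m in sizes[1:]]
y = mlp_forward(rng.standard_normal((5, 2)), Ws, bs)
print(y.shape)  # (5, 1)
```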

Constructor Parameters

Parameter        Default   Description
indim            required  Input dimensionality \(d\).
outdim           required  Output dimensionality \(o\).
hls              required  Tuple of hidden-layer widths \((h_1, \ldots, h_L)\).
activ            'relu'    Activation function \(\phi\). Options: 'relu' (\(\max(0,x)\)), 'tanh' (\(\tanh(x)\)), 'sin' (\(A\sin(2\pi x / T)\)), or identity.
biasorno         True      Include bias terms \(b_\ell\).
bnorm            False     Apply Batch Normalization after each layer.
bnlearn          True      Make batch-norm parameters learnable.
dropout          0.0       Dropout fraction \(p\). Applied after each linear layer when \(p > 0\).
final_transform  None      Optional output transform. 'exp' applies \(e^{(\cdot)}\) to enforce positivity.
device           'cpu'     Compute device ('cpu' or 'cuda').

Architecture Diagram

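A three-hidden-layer MLP with widths \((h_1, h_2, h_3)\). Bracketed stages are optional and are enabled by dropout > 0, bnorm=True, and final_transform, respectively: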

Input (d)
  │
  ├─ Linear(d  → h₁) ─ [Dropout] ─ [BatchNorm] ─ Activation
  ├─ Linear(h₁ → h₂) ─ [Dropout] ─ [BatchNorm] ─ Activation
  ├─ Linear(h₂ → h₃) ─ [Dropout] ─ [BatchNorm] ─ Activation
  ├─ Linear(h₃ → o)  ─ [Dropout] ─ [BatchNorm]
  │
  └─ [final_transform]
Output (o)
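The following PyTorch sketch assembles an nn.Sequential in the same order as the diagram: Linear, optional Dropout, optional BatchNorm1d, then an activation on every layer except the output layer. build_mlp is a hypothetical helper. It covers only 'relu', 'tanh', and identity, and it omits the 'sin' activation and final_transform, so treat it as an illustration of the layer ordering rather than as QUiNN's MLP class.

```python
import torch


def build_mlp(indim, outdim, hls, activ="relu", biasorno=True,
              bnorm=False, bnlearn=True, dropout=0.0):
    """Hypothetical helper: stack layers in the order shown in the diagram."""
    act = {"relu": torch.nn.ReLU, "tanh": torch.nn.Tanh}.get(activ, torch.nn.Identity)
    widths = [indim, *hls, outdim]
    modules = []
    for i, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
        modules.append(torch.nn.Linear(n_in, n_out, bias=biasorno))
        if dropout > 0.0:
            modules.append(torch.nn.Dropout(p=dropout))
        if bnorm:
            # affine=bnlearn makes the scale and shift learnable (2 * n_out parameters)
            modules.append(torch.nn.BatchNorm1d(n_out, affine=bnlearn))
        if i < len(hls):
            # Hidden layers get an activation; the output layer does not
            modules.append(act())
    return torch.nn.Sequential(*modules)


net = build_mlp(3, 2, (8, 8, 8), dropout=0.1, bnorm=True)
print(net)
print(net(torch.randn(5, 3)).shape)  # torch.Size([5, 2])
```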

The total parameter count is

\[K = \sum_{\ell=1}^{L+1}\!\bigl(n_{\ell-1}\,n_\ell + n_\ell\bigr),\]

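where \(n_0 = d\), \(n_{L+1} = o\), and \(n_\ell = h_\ell\) for \(\ell = 1,\ldots,L\). This assumes biases are included (with biasorno=False the \(n_\ell\) terms drop out) and batch normalization is off; with bnorm=True and bnlearn=True, each batch-norm layer adds a learnable scale and shift, i.e. \(2n_\ell\) more parameters.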
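As a worked check of the formula, take \(d = 3\), \((h_1, h_2, h_3) = (8, 8, 8)\), and \(o = 2\): the four layers contribute \(3 \cdot 8 + 8 = 32\), \(8 \cdot 8 + 8 = 72\), \(72\), and \(8 \cdot 2 + 2 = 18\) parameters, so \(K = 194\). The small helper below (hypothetical name mlp_numpar) evaluates the sum for arbitrary widths, with an option to drop the bias terms.

```python
def mlp_numpar(indim, outdim, hls, biasorno=True):
    """K = sum_{l=1}^{L+1} (n_{l-1} n_l + n_l); the n_l term is dropped without biases."""
    n = [indim, *hls, outdim]  # n_0 = d, n_1..n_L = h_l, n_{L+1} = o
    return sum(n[l - 1] * n[l] + (n[l] if biasorno else 0) for l in range(1, len(n)))


print(mlp_numpar(3, 2, (8, 8, 8)))                  # 194
print(mlp_numpar(3, 2, (8, 8, 8), biasorno=False))  # 168
```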

Residual Network (RNet)

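The RNet class implements a residual neural network whose layers can be viewed as a discretized ODE (the neural ODE perspective): each residual update is one explicit (forward) Euler step of \(\dot z = \phi\bigl(W(t)\,z + b(t)\bigr)\) for \(t \in [0,1]\). All hidden layers share the same width \(r\), and the forward pass is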

\[\begin{split}z_0 &= x \quad \text{(or } z_0 = \phi(W_{\text{pre}}\,x + b_{\text{pre}}) \text{ if a pre-layer is used)}, \\ z_{\ell+1} &= z_\ell + \Delta t\;\phi\!\left(W_\ell\,z_\ell + b_\ell\right), \quad \ell = 0,\ldots,L, \\ y &= z_{L+1} \quad \text{(or } y = W_{\text{post}}\,z_{L+1} + b_{\text{post}} \text{ if a post-layer is used)},\end{split}\]

where \(\Delta t = 1/(L+1)\) is the step size. Setting mlp=True drops the skip connections, recovering a standard MLP.
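A NumPy sketch of the residual forward pass, without the optional pre- and post-layers. The name rnet_forward is hypothetical, tanh stands in for \(\phi\), and the mlp=True branch assumes the update becomes \(z_{\ell+1} = \phi(W_\ell z_\ell + b_\ell)\) with no \(\Delta t\) factor, which the text implies but does not spell out.

```python
import numpy as np


def rnet_forward(x, weights, biases, phi=np.tanh, mlp=False):
    """Residual pass z_{l+1} = z_l + dt * phi(W_l z_l + b_l) for l = 0, ..., L.

    x       : (N, r) array (pre/post projection layers omitted in this sketch)
    weights : list of L+1 arrays of shape (r, r)
    biases  : list of L+1 arrays of shape (r,)
    mlp     : if True, drop the skip connection (assumed: z_{l+1} = phi(W_l z_l + b_l))
    """
    dt = 1.0 / len(weights)  # L+1 steps of size 1/(L+1) span t in [0, 1]
    z = x
    for W, b in zip(weights, biases):
        update = phi(z @ W.T + b)
        z = update if mlp else z + dt * update
    return z


# Zero weights and a shared bias b: the output should be x + tanh(b)
r, L = 3, 4
x = np.ones((2, r))
b = np.array([0.5, -0.2, 1.0])
y = rnet_forward(x, [np.zeros((r, r))] * (L + 1), [b] * (L + 1))
print(np.allclose(y, x + np.tanh(b)))  # True
```

With zero weights each of the \(L+1\) updates adds \(\Delta t\,\tanh(b)\), so the total shift is exactly \(\tanh(b)\); this is a quick way to confirm that \(\Delta t = 1/(L+1)\) is wired correctly.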

Weight Parameterization

A distinguishing feature of RNet is that the layer weights need not be independent: they can be parameterised as a function of the layer index (interpreted as a time variable \(t \in [0,1]\)). The LayerFcn hierarchy provides:

Class   Weight function \(W(t)\)                        Weight matrices (each \(r \times r\))
NonPar  Independent weight per layer (standard ResNet)  \(L+1\)
Const   \(W(t) = A\)                                    1
Lin     \(W(t) = A + Bt\)                               2
Quad    \(W(t) = A + Bt + Ct^2\)                        3
Cubic   \(W(t) = A + Bt + Ct^2 + Dt^3\)                 4
Poly    \(W(t) = \sum_{i=0}^{p} A_i\,t^i\)              \(p+1\)

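These parameterizations decouple the number of weights from depth: a degree-\(p\) polynomial needs \((p+1)\,r^2\) weight entries regardless of \(L\), whereas NonPar needs \((L+1)\,r^2\), so the savings are large when \(p \ll L\) while the weights can still vary across layers.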
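To illustrate, the sketch below evaluates a Poly parameterization at each layer time and compares weight counts. It does not use the LayerFcn API: poly_weights is a hypothetical helper, it assumes layer \(\ell\) sits at \(t_\ell = \ell/L\) so that \(t\) spans \([0,1]\), and it ignores biases.

```python
import numpy as np


def poly_weights(coeffs, nlayers):
    """Evaluate W(t) = sum_i A_i t^i at the L+1 layer times (hypothetical helper).

    coeffs  : array of shape (p+1, r, r) holding A_0, ..., A_p
    nlayers : L; layer l is assumed to sit at t_l = l / L, so t spans [0, 1]
    Returns an array of shape (L+1, r, r), one weight matrix per layer.
    """
    t = np.linspace(0.0, 1.0, nlayers + 1)
    powers = t[:, None] ** np.arange(coeffs.shape[0])  # powers[l, i] = t_l ** i
    return np.einsum("li,irs->lrs", powers, coeffs)


# Lin parameterization W(t) = A + B t with A = I and B = 2I, on L = 2 layers
r = 2
A_B = np.stack([np.eye(r), 2.0 * np.eye(r)])
W = poly_weights(A_B, nlayers=2)
print(W[:, 0, 0])  # [1. 2. 3.]

# Free weight entries for r = 10, L = 50: NonPar vs. a cubic (p = 3) parameterization
r, L, p = 10, 50, 3
print((L + 1) * r * r, (p + 1) * r * r)  # 5100 400
```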

Constructor Parameters

Parameter    Default   Description
rdim         required  Hidden-layer width \(r\).
nlayers      required  Number of residual layers \(L\); the forward pass above applies \(L+1\) residual updates (\(\ell = 0,\ldots,L\)).
wp_function  None      A LayerFcn instance for weight parameterization. None uses NonPar (standard ResNet).
indim        rdim      Input dimensionality \(d\). If \(d \neq r\), layer_pre must be True.
outdim       rdim      Output dimensionality \(o\). If \(o \neq r\), layer_post must be True.
biasorno     True      Include bias terms.
mlp          False     If True, skip connections are removed (standard MLP behavior).
layer_pre    False     Prepend a linear layer projecting \(d \to r\).
layer_post   False     Append a linear layer projecting \(r \to o\).
final_layer  None      Output transform: 'exp', 'logabs', or 'sum'.
init_factor  1.0       Multiplicative factor applied to weight initialization.

PyTorch Functions

The quinn.nns.nns module provides small reusable building blocks that can be used as standalone models or composed within larger architectures:

Module          Computation
Polynomial(p)   \(\sum_{i=0}^{p} c_i\,x^i\) with learnable coefficients
Polynomial3     \(a + bx + cx^2 + dx^3\) (four scalar parameters)
Constant        \(C\) (single learnable scalar)
Gaussian        \(e^{-x^2}\)
Sine(A, T)      \(A\sin(2\pi x / T)\)
SiLU            \(x\,\sigma(x)\) (Sigmoid Linear Unit)
Expon           \(e^{x}\)
TwoLayerNet     Two linear layers with a cubic polynomial in between
MLP_simple      Lightweight MLP with tanh activations, specified as a single layer-width tuple
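Building blocks like these are ordinary torch.nn.Module subclasses. The sketch below re-implements two of them for illustration only: SineSketch and PolynomialSketch are hypothetical names, treating \(A\) and \(T\) as fixed constants is an assumption, and QUiNN's own Sine and Polynomial classes may differ in defaults and parameter handling.

```python
import math
import torch


class SineSketch(torch.nn.Module):
    """A * sin(2 pi x / T) with A and T fixed at construction (hypothetical)."""

    def __init__(self, A=1.0, T=1.0):
        super().__init__()
        self.A, self.T = A, T

    def forward(self, x):
        return self.A * torch.sin(2.0 * math.pi * x / self.T)


class PolynomialSketch(torch.nn.Module):
    """sum_{i=0}^p c_i x^i with learnable coefficients c_0, ..., c_p (hypothetical)."""

    def __init__(self, p):
        super().__init__()
        self.coeffs = torch.nn.Parameter(torch.zeros(p + 1))

    def forward(self, x):
        # Horner's rule: c_0 + x (c_1 + x (c_2 + ... + x c_p))
        out = torch.zeros_like(x)
        for c in self.coeffs.flip(0):
            out = out * x + c
        return out


poly = PolynomialSketch(2)
with torch.no_grad():
    poly.coeffs.copy_(torch.tensor([1.0, 2.0, 3.0]))  # 1 + 2x + 3x^2
print(poly(torch.tensor([0.0, 1.0, 2.0])).tolist())   # [1.0, 6.0, 17.0]
```

Horner's rule is used so that a degree-\(p\) polynomial costs \(p\) multiply-adds without forming explicit powers of \(x\).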

Numpy Wrapper (NNWrap)

NNWrap provides a numpy-centric interface around any torch.nn.Module, enabling:

  • __call__(x) — evaluate the wrapped model with numpy arrays.

  • p_flatten() / p_unflatten(flat) — flatten all parameters into a 1-D numpy vector and restore them. This is essential for MCMC and Laplace solvers that operate in flat parameter space.

  • predict(x, weights) — evaluate the model after setting parameters from a flat weight vector.

  • calc_loss(weights, loss_fn, x, y) — compute a loss after unflattening weights.

  • calc_lossgrad(weights, loss_fn, x, y) — compute the gradient of a loss w.r.t. the flat parameter vector.

  • calc_hess_full(weights, loss_fn, x, y) — compute the full Hessian of the loss (used by NN_Laplace).
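A minimal analogue of this flat-parameter workflow can be built from torch.nn.utils.parameters_to_vector and vector_to_parameters, with torch.autograd.functional.hessian for the full Hessian. FlatWrap is a hypothetical class written for illustration, not QUiNN's NNWrap. It assumes float32 parameters and PyTorch 2.x (for torch.func.functional_call), and it omits calc_loss, which is just the loss half of calc_lossgrad.

```python
import numpy as np
import torch
from torch.func import functional_call
from torch.nn.utils import parameters_to_vector, vector_to_parameters


class FlatWrap:
    """Hypothetical, minimal analogue of NNWrap (not QUiNN's implementation)."""

    def __init__(self, model):
        self.model = model

    def __call__(self, x):
        # numpy in, numpy out, no autograd graph
        with torch.no_grad():
            return self.model(torch.as_tensor(x, dtype=torch.float32)).numpy()

    def p_flatten(self):
        # Concatenate all parameters (in model.parameters() order) into one 1-D array
        return parameters_to_vector(self.model.parameters()).detach().numpy().copy()

    def p_unflatten(self, flat):
        # torch.tensor copies, so the model never aliases the caller's array
        vector_to_parameters(torch.tensor(flat, dtype=torch.float32),
                             self.model.parameters())

    def predict(self, x, weights):
        self.p_unflatten(weights)
        return self(x)

    def calc_lossgrad(self, weights, loss_fn, x, y):
        # Gradient of loss_fn(model(x), y) with respect to the flat parameter vector
        self.p_unflatten(weights)
        self.model.zero_grad()
        loss = loss_fn(self.model(torch.as_tensor(x, dtype=torch.float32)),
                       torch.as_tensor(y, dtype=torch.float32))
        loss.backward()
        return np.concatenate([p.grad.numpy().ravel() for p in self.model.parameters()])

    def calc_hess_full(self, weights, loss_fn, x, y):
        # Full K x K Hessian: express the loss as a function of the flat vector
        xt = torch.as_tensor(x, dtype=torch.float32)
        yt = torch.as_tensor(y, dtype=torch.float32)
        named = list(self.model.named_parameters())
        sizes = [p.numel() for _, p in named]

        def loss_of_flat(w):
            chunks = torch.split(w, sizes)
            params = {name: c.view_as(p) for (name, p), c in zip(named, chunks)}
            return loss_fn(functional_call(self.model, params, (xt,)), yt)

        w0 = torch.tensor(weights, dtype=torch.float32)
        return torch.autograd.functional.hessian(loss_of_flat, w0).detach().numpy()


# Linear model y = w x + b with MSE loss; flat vector is [w, b]
wrap = FlatWrap(torch.nn.Linear(1, 1))
x = np.array([[1.0], [2.0], [3.0]])
y = np.zeros((3, 1))
w = np.array([1.0, 0.0])
print(wrap.calc_lossgrad(w, torch.nn.MSELoss(), x, y))   # analytic value: [28/3, 4]
print(wrap.calc_hess_full(w, torch.nn.MSELoss(), x, y))  # analytic value: [[28/3, 4], [4, 2]]
```

The flat vector follows model.parameters() order, so an MCMC or Laplace routine can treat the network as a function of a single \(K\)-dimensional array.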

Summary

Architecture  Key feature            When to use                                  Parameter count
MLP           Fully connected        General-purpose regression                   \(O\!\left(\sum h_\ell h_{\ell+1}\right)\)
RNet          Residual / neural ODE  Deep networks, parameter-efficient models    \(O(r^2 p)\) with a degree-\(p\) weight parameterization; \(O(r^2 L)\) with NonPar
BNet          Variational weights    Used internally by NN_VI                     \(2K\)
NNWrap        Numpy interface        MCMC, Laplace, flat-parameter manipulations  Same as wrapped model