NN Architectures
QUiNN provides several neural network architectures that can serve as the
underlying model \(M(x;\,w)\) for any of the UQ solvers. All architectures
inherit from MLPBase, which extends torch.nn.Module with
convenience methods for training, prediction, and parameter management.
Base Class (MLPBase)
MLPBase is the abstract base that every QUiNN architecture must
subclass. It stores the input/output dimensions, tracks whether the network
has been trained, and exposes the following core interface:
- fit(xtrn, ytrn, **kwargs): trains the network via nnfit, stores the best model snapshot, and records the loss history.
- predict(x): numpy-in / numpy-out convenience method that converts the input, evaluates the (best) model, and returns a numpy array.
- numpar(): returns the total number of parameters \(K\).
- predict_plot / plot_1d_fits: diagnostic plotting helpers.
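To make the shape of this interface concrete, here is a minimal sketch in plain PyTorch. It is not QUiNN's MLPBase: the class names SketchBase and Linear1 are hypothetical, and only the numpar and predict behaviour described above is imitated (training and plotting are left out).

```python
import numpy as np
import torch


class SketchBase(torch.nn.Module):
    """Hypothetical stand-in for a base class; not QUiNN's MLPBase."""

    def __init__(self, indim, outdim):
        super().__init__()
        self.indim = indim
        self.outdim = outdim
        self.trained = False  # would be set by a training routine

    def numpar(self):
        # Total number of scalar parameters K.
        return sum(p.numel() for p in self.parameters())

    def predict(self, x):
        # numpy in, numpy out; gradients are not tracked.
        with torch.no_grad():
            y = self(torch.as_tensor(x, dtype=torch.float32))
        return y.numpy()


class Linear1(SketchBase):
    """Smallest possible subclass: a single linear layer."""

    def __init__(self, indim, outdim):
        super().__init__(indim, outdim)
        self.layer = torch.nn.Linear(indim, outdim)

    def forward(self, x):
        return self.layer(x)


model = Linear1(3, 2)
ypred = model.predict(np.zeros((5, 3)))  # numpy array of shape (5, 2)
```

A subclass only has to define forward; the numpy conversion and parameter counting are inherited, which is the division of labour the list above describes.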
Multilayer Perceptron (MLP)
The primary architecture in QUiNN is a fully connected feedforward network (multilayer perceptron). For an input \(x \in \mathbb{R}^d\), hidden layer widths \((h_1, h_2, \ldots, h_L)\), and output \(y \in \mathbb{R}^o\), the forward pass is
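\[
z_0 = x, \qquad z_\ell = \phi\!\left(W_\ell\, z_{\ell-1} + b_\ell\right), \quad \ell = 1, \ldots, L, \qquad y = W_{L+1}\, z_L + b_{L+1},
\]
where \(W_\ell\) and \(b_\ell\) are the weight matrix and bias vector of layer \(\ell\), and \(\phi\) is the elementwise activation function.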
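The forward pass can be written out directly in NumPy. This is a sketch under two assumptions: the output layer is linear (no activation after the last layer), and dropout, batch normalization, and the final transform are switched off. The function name mlp_forward is illustrative and not part of QUiNN.

```python
import numpy as np


def mlp_forward(x, weights, biases, phi=np.tanh):
    """Hidden layers: z = phi(W z + b); output layer: y = W z + b."""
    z = x
    for W, b in zip(weights[:-1], biases[:-1]):
        z = phi(W @ z + b)
    return weights[-1] @ z + biases[-1]


rng = np.random.default_rng(0)
sizes = [3, 5, 4, 2]  # d = 3, hidden widths (5, 4), o = 2
weights = [rng.normal(size=(n_out, n_in))
           for n_in, n_out in zip(sizes[:-1], sizes[1:])]
biases = [rng.normal(size=n_out) for n_out in sizes[1:]]

y = mlp_forward(rng.normal(size=3), weights, biases)  # shape (2,)
```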
Constructor Parameters
| Parameter | Default | Description |
|---|---|---|
| | | Input dimensionality \(d\). |
| | | Output dimensionality \(o\). |
| | | Tuple of hidden-layer widths \((h_1, \ldots, h_L)\). |
| | | Activation function \(\phi\). Options: |
| | | Include bias terms \(b_\ell\). |
| | | Apply Batch Normalization after each layer. |
| | | Make batch-norm parameters learnable. |
| | | Dropout fraction \(p\). Applied after each linear layer when \(p > 0\). |
| | | Optional output transform. |
| | | Compute device ( |
Architecture Diagram
A three-hidden-layer MLP with widths \((h_1, h_2, h_3)\):
Input (d)
│
├─ Linear(d → h₁) ─ [Dropout] ─ [BatchNorm] ─ Activation
├─ Linear(h₁→ h₂) ─ [Dropout] ─ [BatchNorm] ─ Activation
├─ Linear(h₂→ h₃) ─ [Dropout] ─ [BatchNorm] ─ Activation
├─ Linear(h₃→ o) ─ [Dropout] ─ [BatchNorm]
│
└─ [final_transform]
Output (o)
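The layer ordering in the diagram can be reproduced with torch.nn.Sequential. The helper build_mlp and its argument names (hls, dropout, bnorm, activ) are hypothetical and chosen only for this sketch; the ordering Linear, then optional Dropout, then optional BatchNorm, then Activation follows the diagram, including the absence of an activation after the output layer.

```python
import torch


def build_mlp(d, hls, o, dropout=0.0, bnorm=False, activ=torch.nn.Tanh):
    """Assemble layers in the order drawn in the diagram (hypothetical helper)."""
    sizes = [d, *hls, o]
    layers = []
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        layers.append(torch.nn.Linear(n_in, n_out))
        if dropout > 0:
            layers.append(torch.nn.Dropout(p=dropout))
        if bnorm:
            layers.append(torch.nn.BatchNorm1d(n_out))
        if i < len(sizes) - 2:  # no activation after the output layer
            layers.append(activ())
    return torch.nn.Sequential(*layers)


net = build_mlp(2, (8, 8, 8), 1, dropout=0.1, bnorm=True)
net.eval()  # disable dropout and use running batch-norm statistics
y = net(torch.zeros(4, 2))  # batch of 4 inputs, output shape (4, 1)
```

Calling eval() before prediction matters here: in training mode dropout is stochastic and batch normalization uses per-batch statistics, so repeated evaluations of the same input would differ.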
The total parameter count is
\[
K = \sum_{\ell=0}^{L} \left(n_\ell + 1\right) n_{\ell+1},
\]
where \(n_0 = d\), \(n_{L+1} = o\), and \(n_\ell = h_\ell\) for \(\ell = 1,\ldots,L\) (assuming biases are included).
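The count can be checked with a few lines of plain Python. The sketch assumes the usual dense-layer bookkeeping (each layer contributes \(n_\ell\, n_{\ell+1}\) weights plus \(n_{\ell+1}\) biases) and ignores any batch-norm parameters; mlp_numpar is an illustrative name, not a QUiNN function.

```python
def mlp_numpar(d, hls, o, bias=True):
    """Sum n_l * n_{l+1} weights per layer, plus n_{l+1} biases if present."""
    sizes = [d, *hls, o]
    return sum(n_in * n_out + (n_out if bias else 0)
               for n_in, n_out in zip(sizes[:-1], sizes[1:]))


# d = 2, hidden widths (8, 8, 8), o = 1:
# (2*8 + 8) + (8*8 + 8) + (8*8 + 8) + (8*1 + 1) = 24 + 72 + 72 + 9
K = mlp_numpar(2, (8, 8, 8), 1)  # 177
```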
Residual Network (RNet)
The RNet class implements a Residual Neural Network whose layers can be
viewed as a discretised ODE (the neural ODE perspective). All hidden layers
share the same width \(r\), and the forward pass is
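\[
z_{\ell+1} = z_\ell + \Delta t\, \phi\!\left(W_\ell\, z_\ell + b_\ell\right), \qquad \ell = 0, \ldots, L,
\]
where \(z_\ell \in \mathbb{R}^r\) is the hidden state, \(\phi\) is the activation function, and \(\Delta t = 1/(L+1)\) is the step size. Setting mlp=True
drops the skip connections, recovering a standard MLP.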
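A residual forward pass of this kind can be sketched in NumPy. The update used here, \(z \leftarrow z + \Delta t\,\phi(Wz + b)\), is the standard forward-Euler form and is an assumption about RNet's internals, as is the reading of mlp=True as "replace the state by the layer output". The name resnet_forward is hypothetical.

```python
import numpy as np


def resnet_forward(x, weights, biases, phi=np.tanh, mlp=False):
    """Forward-Euler style residual updates with step dt = 1 / (number of layers)."""
    dt = 1.0 / len(weights)  # len(weights) plays the role of L + 1
    z = x
    for W, b in zip(weights, biases):
        update = phi(W @ z + b)
        # With mlp=True the skip connection is dropped (assumed behaviour).
        z = update if mlp else z + dt * update
    return z


r, n_layers = 4, 5  # width r and L + 1 = 5 layers
rng = np.random.default_rng(0)
weights = [rng.normal(size=(r, r)) for _ in range(n_layers)]
biases = [rng.normal(size=r) for _ in range(n_layers)]
x = rng.normal(size=r)

y_res = resnet_forward(x, weights, biases)            # with skip connections
y_mlp = resnet_forward(x, weights, biases, mlp=True)  # plain feedforward
```

With all weights and biases set to zero the residual network returns its input unchanged, which is a quick sanity check that the skip connections are active.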
Weight Parameterization
A distinguishing feature of RNet is that the layer weights need not be
independent: they can be parameterised as a function of the layer index
(interpreted as a time variable \(t \in [0,1]\)). The LayerFcn
hierarchy provides:
| Class | Weight function \(W(t)\) | Parameters |
|---|---|---|
| | Independent weight per layer (standard ResNet) | \(L+1\) |
| | \(W(t) = A\) | 1 |
| | \(W(t) = A + Bt\) | 2 |
| | \(W(t) = A + Bt + Ct^2\) | 3 |
| | \(W(t) = A + Bt + Ct^2 + Dt^3\) | 4 |
| | \(W(t) = \sum_{i=0}^{p} A_i\,t^i\) | \(p+1\) |
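Because a polynomial parameterisation of order \(p\) needs only \(p+1\) weight sets instead of \(L+1\), it reduces the number of free parameters substantially whenever \(p \ll L\), while the depth of the network is unchanged.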
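The saving is easy to see in a small NumPy sketch of the polynomial case \(W(t) = \sum_i A_i t^i\). The mapping from layer index to time (evenly spaced points on \([0,1]\)) is an assumption made for illustration, and poly_weight is a hypothetical helper rather than a member of the LayerFcn hierarchy.

```python
import numpy as np


def poly_weight(coeffs, t):
    """Evaluate W(t) = sum_i A_i t^i, where each A_i is an r x r matrix."""
    return sum(A * t**i for i, A in enumerate(coeffs))


r, L, p = 4, 9, 2  # width, L + 1 = 10 layers, quadratic in t
rng = np.random.default_rng(0)
coeffs = [rng.normal(size=(r, r)) for _ in range(p + 1)]

ts = np.linspace(0.0, 1.0, L + 1)  # assumed layer "times" on [0, 1]
layer_weights = [poly_weight(coeffs, t) for t in ts]

n_free = (p + 1) * r * r         # 48 weight entries to learn
n_independent = (L + 1) * r * r  # 160 entries with one matrix per layer
```

All ten layer matrices are generated from three coefficient matrices, so adding more layers refines the time grid without adding parameters.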
Constructor Parameters
| Parameter | Default | Description |
|---|---|---|
| | | Hidden-layer width \(r\). |
| | | Number of residual layers \(L\). |
| | | A |
| | | Input dimensionality \(d\). If \(d \neq r\), |
| | | Output dimensionality \(o\). If \(o \neq r\), |
| | | Include bias terms. |
| | | If |
| | | Prepend a linear layer projecting \(d \to r\). |
| | | Append a linear layer projecting \(r \to o\). |
| | | Output transform: |
| | | Multiplicative factor applied to weight initialization. |
PyTorch Functions
The quinn.nns.nns module provides small reusable building blocks that can be
used as standalone models or composed within larger architectures:
| Module | Computation |
|---|---|
| | \(\sum_{i=0}^{p} c_i\,x^i\) with learnable coefficients |
| | \(a + bx + cx^2 + dx^3\) (four scalar parameters) |
| | \(C\) (single learnable scalar) |
| | \(e^{-x^2}\) |
| | \(A\sin(2\pi x / T)\) |
| | \(x\,\sigma(x)\) (Sigmoid Linear Unit) |
| | \(e^{x}\) |
| | Two linear layers with a cubic polynomial in between |
| | Lightweight MLP with tanh activations, specified as a single layer-width tuple |
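Building blocks like these are short torch.nn.Module subclasses. The sketch below implements three of the computations from the table under hypothetical class names (GaussianSketch, SineSketch, PolynomialSketch); they are not the classes shipped in quinn.nns.nns, and treating \(A\) and \(T\) as fixed floats rather than learnable parameters is an assumption.

```python
import math
import torch


class GaussianSketch(torch.nn.Module):
    """Computes exp(-x^2) elementwise."""

    def forward(self, x):
        return torch.exp(-x**2)


class SineSketch(torch.nn.Module):
    """Computes A * sin(2 pi x / T) with fixed amplitude A and period T."""

    def __init__(self, A=1.0, T=1.0):
        super().__init__()
        self.A, self.T = A, T

    def forward(self, x):
        return self.A * torch.sin(2 * math.pi * x / self.T)


class PolynomialSketch(torch.nn.Module):
    """Computes sum_i c_i x^i with learnable coefficients c_0..c_p."""

    def __init__(self, order):
        super().__init__()
        self.coefs = torch.nn.Parameter(torch.zeros(order + 1))

    def forward(self, x):
        powers = torch.stack([x**i for i in range(len(self.coefs))], dim=-1)
        return powers @ self.coefs


x = torch.linspace(-1.0, 1.0, 5)
g = GaussianSketch()(x)
s = SineSketch(A=2.0, T=1.0)(x)
# PyTorch's built-in SiLU matches the x * sigmoid(x) definition in the table.
silu_matches = torch.allclose(torch.nn.functional.silu(x), x * torch.sigmoid(x))
```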
Numpy Wrapper (NNWrap)
NNWrap provides a numpy-centric interface around any
torch.nn.Module, enabling:
- __call__(x): evaluate the wrapped model with numpy arrays.
- p_flatten() / p_unflatten(flat): flatten all parameters into a 1-D numpy vector and restore them. This is essential for MCMC and Laplace solvers that operate in flat parameter space.
- predict(x, weights): evaluate the model after setting parameters from a flat weight vector.
- calc_loss(weights, loss_fn, x, y): compute a loss after unflattening weights.
- calc_lossgrad(weights, loss_fn, x, y): compute the gradient of a loss w.r.t. the flat parameter vector.
- calc_hess_full(weights, loss_fn, x, y): compute the full Hessian of the loss (used by NN_Laplace).
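The flatten / unflatten idea can be illustrated with PyTorch's own utilities. This sketch does not use NNWrap; the helper names flatten, unflatten, and loss_grad are hypothetical, and the parameter ordering is simply the order returned by model.parameters().

```python
import numpy as np
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(1, 4), torch.nn.Tanh(), torch.nn.Linear(4, 1)
)


def flatten(model):
    """Return all parameters as one 1-D numpy vector (a copy)."""
    return parameters_to_vector(model.parameters()).detach().numpy().copy()


def unflatten(model, flat):
    """Write a flat numpy vector back into the model parameters."""
    vec = torch.tensor(flat, dtype=torch.float32)  # copies the data
    with torch.no_grad():
        vector_to_parameters(vec, model.parameters())


def loss_grad(model, flat, loss_fn, x, y):
    """Gradient of the loss w.r.t. the flat parameter vector."""
    unflatten(model, flat)
    model.zero_grad()
    xt = torch.as_tensor(x, dtype=torch.float32)
    yt = torch.as_tensor(y, dtype=torch.float32)
    loss_fn(model(xt), yt).backward()
    return torch.cat([p.grad.reshape(-1) for p in model.parameters()]).numpy()


w = flatten(model)  # length K = 1*4 + 4 + 4*1 + 1 = 13
x = np.linspace(-1, 1, 5).reshape(-1, 1)
y = np.ones((5, 1))
g = loss_grad(model, np.zeros_like(w), torch.nn.functional.mse_loss, x, y)
```

A sampler or optimizer can then treat the network as a function of a single vector in \(\mathbb{R}^K\), which is the setting MCMC and Laplace methods expect.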
Summary
| Architecture | Key feature | When to use | Parameter count |
|---|---|---|---|
| | Fully connected | General-purpose regression | \(O\!\left(\sum h_\ell h_{\ell+1}\right)\) |
| | Residual / neural ODE | Deep networks, parameter-efficient models | \(O(r^2 \cdot p)\) with weight parameterization |
| | Variational weights | Used internally by | \(2K\) |
| | Numpy interface | MCMC, Laplace, flat-parameter manipulations | Same as wrapped model |