NN Training
All QUiNN architectures are trained through the generic nnfit function,
which provides a configurable training loop with support for multiple loss
functions, optimizers, learning-rate schedules, mini-batching, and automatic
early stopping.
The standard entry point is MLPBase.fit(xtrn, ytrn, **kwargs), which
delegates to nnfit internally.
Training Objective
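At each gradient step the optimizer updates the parameters to decrease a loss function \(\mathcal{L}(w)\) evaluated on the current mini-batch \(\mathcal{B} \subseteq \{1,\ldots,N\}\). In its plain gradient-descent form the update is
\[
w_{t+1} = w_t - \eta_t\, \nabla_w \mathcal{L}(w_t),
\]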
where \(\eta_t\) is the learning rate at step \(t\).
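A minimal plain-Python sketch of this update, assuming a hypothetical one-parameter model \(f(x; w) = w x\) fitted with a mean-squared-error loss. In nnfit the step itself is delegated to the chosen optimizer; the sketch only illustrates the formula.

```python
# Sketch of w_{t+1} = w_t - eta_t * dL/dw for a hypothetical one-parameter
# model f(x; w) = w * x fitted with a mean-squared-error loss.

def mse_grad(w, xs, ys):
    """Gradient of (1/|B|) * sum_i (w*x_i - y_i)**2 with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def gradient_step(w, xs, ys, eta):
    """One plain gradient-descent step on the mini-batch (xs, ys)."""
    return w - eta * mse_grad(w, xs, ys)


xs = [0.0, 1.0, 2.0, 3.0]
ys = [2.0 * x for x in xs]  # targets generated with the "true" weight w = 2

w = 0.0
for t in range(200):
    w = gradient_step(w, xs, ys, eta=0.05)  # constant learning rate eta_t = 0.05
print(round(w, 6))  # 2.0
```

With a constant \(\eta_t\) the error here shrinks by a fixed factor per step. The schedules described below change \(\eta_t\) over the course of training.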
Loss Functions
nnfit selects the loss through either the loss_fn string or a
user-supplied callable loss_xy.
Mean Squared Error (loss_fn='mse')
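The default loss is the mean squared error over the mini-batch:
\[
\mathcal{L}_{\text{MSE}}(w) = \frac{1}{|\mathcal{B}|\,o}\sum_{i\in\mathcal{B}} \big\|f(x_i;w) - y_i\big\|^2,
\]
where \(f(x_i;w)\) is the network output for input \(x_i\) and \(o\) is the output dimension.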
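A plain-Python sketch of the mini-batch MSE, assuming the mean is taken over all \(|\mathcal{B}| \cdot o\) entries (as `torch.nn.MSELoss` does with its default `reduction='mean'`). The function name `mse_loss` is hypothetical.

```python
# Mini-batch mean squared error, averaged over all |B| * o entries:
#   L(w) = 1/(|B| o) * sum_{i in B} ||f(x_i; w) - y_i||^2

def mse_loss(preds, targets):
    """preds, targets: lists of length |B|, each entry a list of o outputs."""
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same batch size")
    sq_err = 0.0
    count = 0
    for p, y in zip(preds, targets):
        sq_err += sum((pj - yj) ** 2 for pj, yj in zip(p, y))
        count += len(y)
    return sq_err / count


preds = [[1.0, 2.0], [0.0, 0.0]]
targets = [[1.0, 0.0], [1.0, 1.0]]
print(mse_loss(preds, targets))  # (0 + 4 + 1 + 1) / 4 = 1.5
```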
Negative Log-Posterior (loss_fn='logpost')
When a Bayesian prior is used (e.g. for NN_MCMC or NN_RMS), the
loss is the negative log-posterior combining a Gaussian likelihood and a
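Gaussian prior centred at an anchor \(w_0\):
\[
\mathcal{L}_{\text{logpost}}(w) = \frac{1}{2\sigma^2}\sum_{i\in\mathcal{B}} \big\|f(x_i;w) - y_i\big\|^2 + \frac{1}{2\sigma_{\text{prior}}^2}\sum_{k=1}^{K} \big(w_k - w_{0,k}\big)^2 + \text{const},
\]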
where \(\sigma\) is the data noise, \(\sigma_{\text{prior}}\) the
prior standard deviation, and \(K\) the parameter count. This loss
requires the datanoise and priorparams arguments.
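The following plain-Python sketch evaluates the negative log-posterior with additive constants dropped. It assumes that the data term is summed over the mini-batch without any \(N/|\mathcal{B}|\) rescaling; how nnfit scales the likelihood for mini-batches is not stated here. The function and argument names are hypothetical and are not the datanoise/priorparams interface.

```python
# Negative log-posterior up to additive constants: Gaussian likelihood with
# noise sigma on the batch plus a Gaussian prior N(w0, sigma_prior^2 I) on w.

def neg_log_posterior(preds, targets, w, w0, sigma, sigma_prior):
    """preds/targets: lists of output vectors; w, w0: flat parameter lists (length K)."""
    sq_err = sum(
        sum((pj - yj) ** 2 for pj, yj in zip(p, y))
        for p, y in zip(preds, targets)
    )
    data_term = sq_err / (2.0 * sigma ** 2)
    prior_term = sum((wk - w0k) ** 2 for wk, w0k in zip(w, w0)) / (2.0 * sigma_prior ** 2)
    return data_term + prior_term


loss = neg_log_posterior(
    preds=[[1.0]], targets=[[0.0]],
    w=[1.0, 2.0], w0=[0.0, 0.0],
    sigma=1.0, sigma_prior=1.0,
)
print(loss)  # 0.5 (data) + 2.5 (prior) = 3.0
```

A smaller \(\sigma\) weights the data misfit more heavily. A smaller \(\sigma_{\text{prior}}\) pulls the weights more strongly toward the anchor \(w_0\).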
Log-Loss (loss_fn='logloss')
Fits in the log-transformed space with a user-specified shift \(y_{\text{shift}}\):
This is activated by loss_fn='logloss' and requires the lossparams
argument.
Custom Loss (loss_xy)
Any callable with signature loss_xy(x_batch, y_batch) -> scalar tensor
can be passed directly. When loss_xy is provided, the loss_fn
string is ignored. This mechanism is used internally by NN_VI (Bayes
by Backprop) and can be used for any problem-specific objective.
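Here is a sketch of the closure pattern a custom loss might follow. `make_mae_loss` and `model` are hypothetical stand-ins. In nnfit the callable must return a scalar tensor so that PyTorch autograd can differentiate it; this plain-Python version only shows the `loss_xy(x_batch, y_batch)` signature and how model state can be captured.

```python
def make_mae_loss(model):
    """Build a loss_xy(x_batch, y_batch) callable for a given model.

    `model` maps one input vector to one output vector and stands in for the
    network being trained; the loss is the mean absolute error per entry."""
    def loss_xy(x_batch, y_batch):
        total, count = 0.0, 0
        for x, y in zip(x_batch, y_batch):
            pred = model(x)
            total += sum(abs(pj - yj) for pj, yj in zip(pred, y))
            count += len(y)
        return total / count
    return loss_xy


def model(x):
    """Hypothetical fixed 'network': doubles its single input."""
    return [2.0 * x[0]]


loss_xy = make_mae_loss(model)
print(loss_xy([[1.0], [2.0]], [[2.5], [3.0]]))  # (0.5 + 1.0) / 2 = 0.75
```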
Optimizers
nnfit supports two first-order optimizers:
| String | Algorithm |
|---|---|
|  | Adam (adaptive moment estimation), the default. Updates each parameter using bias-corrected estimates of the first and second moments of its gradient. |
|  | Stochastic gradient descent with optional momentum (PyTorch defaults). |
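For reference, this plain-Python sketch shows the textbook Adam update with bias correction. The default hyperparameters match those of `torch.optim.Adam` (`lr=1e-3`, `betas=(0.9, 0.999)`, `eps=1e-8`). The Adam option in nnfit presumably wraps the PyTorch optimizer, so this is an illustration rather than its implementation.

```python
import math


def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for flat parameter list w with gradient g.

    m, v hold the running first and second moment estimates; t is the
    1-based step count used for bias correction."""
    m = [beta1 * mi + (1.0 - beta1) * gi for mi, gi in zip(m, g)]
    v = [beta2 * vi + (1.0 - beta2) * gi * gi for vi, gi in zip(v, g)]
    m_hat = [mi / (1.0 - beta1 ** t) for mi in m]  # bias-corrected 1st moment
    v_hat = [vi / (1.0 - beta2 ** t) for vi in v]  # bias-corrected 2nd moment
    w = [wi - lr * mh / (math.sqrt(vh) + eps) for wi, mh, vh in zip(w, m_hat, v_hat)]
    return w, m, v


# Usage: minimise (w - 3)^2, whose gradient is 2 (w - 3).
w, m, v = [0.0], [0.0], [0.0]
for t in range(1, 301):
    g = [2.0 * (w[0] - 3.0)]
    w, m, v = adam_step(w, g, m, v, t, lr=0.05)
```

Because of the bias correction, the first step has \(\hat m = g\) and \(\hat v = g^2\). Each parameter therefore moves by roughly `lr` against the sign of its gradient, whatever the gradient's magnitude.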
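Both accept an optional weight-decay parameter wd that adds an L2
penalty \(\frac{\lambda}{2}\|w\|^2\), with \(\lambda = \texttt{wd}\), to the loss:
\[
\mathcal{L}_{\text{wd}}(w) = \mathcal{L}(w) + \frac{\lambda}{2}\|w\|^2.
\]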
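The gradient of the penalty is \(\lambda w\). The `weight_decay` argument of `torch.optim.SGD` and `torch.optim.Adam` adds exactly this term to each gradient. A minimal sketch (the helper name `add_weight_decay` is hypothetical):

```python
def add_weight_decay(loss_value, grad, w, wd):
    """Return the loss and gradient of L(w) + (wd/2) * ||w||^2,
    given L(w) (loss_value) and its gradient (grad) at parameters w."""
    penalty = 0.5 * wd * sum(wi * wi for wi in w)
    new_grad = [gi + wd * wi for gi, wi in zip(grad, w)]  # d/dw of penalty is wd * w
    return loss_value + penalty, new_grad


loss, grad = add_weight_decay(1.0, [0.0, 0.0], [1.0, -2.0], wd=0.1)
print(loss, grad)  # 1.25 [0.1, -0.2]
```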
Learning Rate Schedules
Three scheduling modes are available (mutually exclusive):
- Constant: when neither lmbd nor scheduler_lr is set, the learning rate stays at lrate throughout training.
- Lambda schedule: a user-defined function lmbd(epoch) returns a multiplicative factor. The effective rate at epoch \(t\) is \[\eta_t = \texttt{lrate} \times \texttt{lmbd}(t).\]
- ReduceLROnPlateau: set scheduler_lr='ReduceLROnPlateau'. The scheduler monitors the validation loss and multiplies the learning rate by factor whenever the loss plateaus for cooldown epochs: \[\eta \leftarrow \texttt{factor} \times \eta \quad \text{if the validation loss stagnates for } \texttt{cooldown} \text{ epochs}.\]
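Both adaptive modes can be sketched in a few lines of plain Python. `PlateauReducer` and `lambda_rate` are hypothetical names, and the plateau logic follows the semantics described above rather than QUiNN's code. PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau` names this waiting period `patience` and uses its own `cooldown` for a pause after each reduction. `torch.optim.lr_scheduler.LambdaLR` implements the lambda rule.

```python
class PlateauReducer:
    """Multiply the learning rate by `factor` once the validation loss has
    failed to improve for `cooldown` consecutive epochs."""

    def __init__(self, lrate, factor=0.5, cooldown=3):
        self.lr = lrate
        self.factor = factor
        self.cooldown = cooldown
        self.best = float("inf")
        self.stale = 0  # epochs since the last improvement

    def step(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.stale = 0
        else:
            self.stale += 1
            if self.stale >= self.cooldown:
                self.lr *= self.factor
                self.stale = 0  # start counting a new plateau
        return self.lr


def lambda_rate(lrate, lmbd, epoch):
    """Effective rate eta_t = lrate * lmbd(t) for a user-supplied lmbd."""
    return lrate * lmbd(epoch)


sched = PlateauReducer(lrate=0.1, factor=0.5, cooldown=2)
lrs = [sched.step(loss) for loss in [1.0, 0.9, 0.9, 0.9, 0.8, 0.8, 0.8]]
```

In this usage the rate halves twice: once after the two stale epochs at 0.9 and again after the two stale epochs at 0.8.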
Mini-Batch Training
When batch_size is specified and smaller than \(N\), each epoch is
split into \(\lceil N / B \rceil\) sub-epochs, where \(B\) is the
batch size. At the start of every epoch the training data is randomly
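permuted, and sub-epoch \(k\) takes the next contiguous slice of at most \(B\) indices:
\[
\mathcal{B}_k = \{\pi(kB+1),\,\ldots,\,\pi(\min\{(k+1)B,\,N\})\},\qquad k = 0,\ldots,\lceil N/B\rceil - 1,
\]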
where \(\pi\) is the random permutation. When batch_size is
None or exceeds \(N\), full-batch training is used.
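The batching rule above can be sketched with the standard library. The generator name `minibatch_indices` is hypothetical, and indices are 0-based, unlike the 1-based formula.

```python
import math
import random


def minibatch_indices(n, batch_size, rng):
    """Yield the index list of each sub-epoch for one epoch of n samples."""
    if batch_size is None or batch_size > n:
        yield list(range(n))  # full-batch training
        return
    perm = list(range(n))
    rng.shuffle(perm)  # new random permutation pi at the start of the epoch
    for k in range(math.ceil(n / batch_size)):
        # contiguous slice of the permutation; the last one may be shorter
        yield perm[k * batch_size:(k + 1) * batch_size]


rng = random.Random(0)
batches = list(minibatch_indices(10, 4, rng))
print([len(b) for b in batches])  # [4, 4, 2]
```

Every sample appears in exactly one sub-epoch per epoch. Only the grouping changes from epoch to epoch.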
Early Stopping
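At every gradient step the validation loss is evaluated (without gradients). If it improves on the current best, a deep copy of the model is checkpointed:
\[
\text{if } \mathcal{L}_{\text{val}}(w_t) < \mathcal{L}^{\star}:\qquad \mathcal{L}^{\star} \leftarrow \mathcal{L}_{\text{val}}(w_t),\quad w^{\star} \leftarrow w_t.
\]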
The returned model is always the best snapshot, not the final-epoch model.
When no separate validation set is provided (val=None), the training set
is used for both training and validation.
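The following sketch shows the checkpointing logic using `copy.deepcopy`. The names `train_with_checkpoint`, `step_fn`, and `val_loss_fn` are hypothetical, and a dictionary stands in for the model.

```python
import copy


def train_with_checkpoint(model, step_fn, val_loss_fn, n_steps):
    """Run n_steps updates; after each, evaluate the validation loss and keep
    a deep copy of the best model seen so far."""
    best_model, best_loss, best_step = copy.deepcopy(model), float("inf"), -1
    for step in range(n_steps):
        step_fn(model)               # one gradient step (mutates model in place)
        loss = val_loss_fn(model)    # validation loss, no gradients needed
        if loss < best_loss:
            best_loss, best_step = loss, step
            best_model = copy.deepcopy(model)  # snapshot, unaffected by later steps
    return best_model, best_loss, best_step


state = {"w": 0.0}


def step_fn(model):
    model["w"] += 1.0  # stand-in for an optimizer update


def val_loss_fn(model):
    return (model["w"] - 3.0) ** 2


best, best_loss, best_step = train_with_checkpoint(state, step_fn, val_loss_fn, n_steps=5)
print(best["w"], state["w"])  # 3.0 5.0
```

The deep copy matters because the live model keeps changing after the best step. Here training overshoots to \(w = 5\), but the returned snapshot holds \(w = 3\).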
Arguments
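| Argument | Default | Description |
|---|---|---|
|  |  | The … |
| xtrn |  | Training inputs, numpy array of shape \((N,\,d)\). |
| ytrn |  | Training targets, numpy array of shape \((N,\,o)\). |
| val |  | Validation data as an …; when val=None, the training set is also used for validation. |
| loss_fn | 'mse' | Loss identifier: 'mse', 'logpost', or 'logloss'. |
| loss_xy |  | Custom loss callable loss_xy(x_batch, y_batch) -> scalar tensor; when given, loss_fn is ignored. |
| datanoise |  | Data noise \(\sigma\) for loss_fn='logpost'. |
| wd |  | Weight decay (L2 regularisation) coefficient \(\lambda\). |
| priorparams |  | Dictionary with keys … specifying the Gaussian prior for loss_fn='logpost'. |
| lossparams |  | Parameters for custom losses (e.g. the shift \(y_{\text{shift}}\) required by loss_fn='logloss'). |
|  |  | Optimizer string selecting Adam (the default) or SGD. |
| lrate |  | Base learning rate \(\eta\). |
| lmbd |  | Lambda schedule lmbd(epoch) returning a multiplicative factor on lrate. |
| scheduler_lr |  | Adaptive scheduler. Currently only 'ReduceLROnPlateau' is supported. |
|  |  | Total number of training epochs. |
| batch_size |  | Mini-batch size \(B\); None selects full-batch training. |
|  |  | If … |
| cooldown |  | Cooldown epochs for ReduceLROnPlateau. |
| factor |  | Multiplicative factor for ReduceLROnPlateau. |
|  |  | Screen-output frequency (in epochs). |
|  |  | Loss-history plot frequency (in epochs). |
|  |  | Filename suffix for the saved loss-history figures. |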
Return Value
nnfit returns a dictionary with the following keys:
| Key | Content |
|---|---|
|  | Deep copy of the model at the best validation loss. |
|  | Best validation loss value. |
|  | Epoch index at which the best loss occurred. |
|  | Fractional epoch (accounts for sub-epochs in mini-batch training). |
|  | List of … |