Source code for quinn.nns.nnbase

#!/usr/bin/env python
"""Module for the MLP NN base class."""

import torch
import functools
import numpy as np
import matplotlib.pyplot as plt

from .tchutils import npy, tch

from ..nns.nnfit import nnfit
from ..utils.stats import get_domain
from ..utils.plotting import plot_dm
from ..utils.maps import scale01ToDom

# Make double precision the default floating-point dtype for torch. This is a
# global torch setting, so it takes effect as soon as this module is imported.
torch.set_default_dtype(torch.double)


class MLPBase(torch.nn.Module):
    """Base class for an MLP architecture.

    Attributes:
        best_model (torch.nn.Module): Best trained instance, if any.
        device (str): Device this object's model will live in.
        history (list[np.ndarray]): List containing training history, namely,
            [fepoch, loss_trn, loss_trn_full, loss_val]
        indim (int): Input dimensionality.
        outdim (int): Output dimensionality.
        trained (bool): Whether the NN is already trained.
    """

    def __init__(self, indim, outdim, device='cpu'):
        """Initialization.

        Args:
            indim (int): Input dimensionality, `d`.
            outdim (int): Output dimensionality, `o`.
            device (str): Indicates where computations are performed and
                tensors are allocated. Default to 'cpu'.
        """
        super().__init__()
        self.indim = indim
        self.outdim = outdim
        self.best_model = None
        self.trained = False
        self.history = None
        self.device = device

    def forward(self, x):
        """Forward function is not implemented in base class.

        Args:
            x (torch.Tensor): Input of the function.

        Raises:
            NotImplementedError: Needs to be implemented in children.
        """
        raise NotImplementedError

    def predict(self, x):
        """Prediction of the NN.

        Args:
            x (np.ndarray): Input array of size `(N,d)`.

        Returns:
            np.ndarray: Output array of size `(N,o)`.

        Note:
            Both input and outputs are numpy arrays.

        Note:
            If trained, it uses the best trained model, otherwise it will use
            the current weights.

        Note:
            The input is placed on the device of the model that is actually
            evaluated: the best model's ``device`` attribute when trained
            (falling back to ``self.device`` if it has none), and
            ``self.device`` when untrained.
        """
        if self.trained:
            model = self.best_model
            # The best model may be an arbitrary module without a `device`
            # attribute; in that case assume it lives where this object does.
            device = getattr(model, 'device', self.device)
        else:
            # Call forward directly (as the original code did) with the
            # current weights, on this object's own device.
            model = self.forward
            device = self.device

        return npy(model(tch(x, device=device)))

    def numpar(self):
        """Get the number of parameters of NN.

        Returns:
            int: Number of parameters, trainable or not.
        """
        return sum(p.numel() for p in self.parameters())

    def fit(self, xtrn, ytrn, **kwargs):
        """Fit function.

        Args:
            xtrn (np.ndarray): Input array of size `(N,d)`.
            ytrn (np.ndarray): Output array of size `(N,o)`.
            **kwargs (dict): Keyword arguments, forwarded to `nnfit`.

        Returns:
            torch.nn.Module: Best trained instance.
        """
        fit_info = nnfit(self, xtrn, ytrn, **kwargs)

        # Use object.__setattr__ to prevent nn.Module from registering
        # best_model as a submodule (which would double the parameter count).
        object.__setattr__(self, 'best_model', fit_info['best_nnmodel'])
        self.history = fit_info['history']
        self.trained = True

        return self.best_model

    def printParams(self):
        """Print names and values of the trainable parameters."""
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(name, param.data)

    def printParamNames(self):
        """Print names and shapes of the trainable parameters."""
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(name, param.data.shape)

    @staticmethod
    def _default_labels_colors(nlist, labels, colors):
        """Fill in default labels/colors and check there is one per data set.

        Args:
            nlist (int): Number of data sets to be plotted.
            labels (list[str] or None): User labels, or None for 'Set 1', ...
            colors (list[str] or None): User colors, or None to cycle through
                a fixed six-color palette.

        Returns:
            tuple(list[str], list[str]): Labels and colors, each of length
            `nlist`.
        """
        if labels is None:
            labels = [f'Set {i+1}' for i in range(nlist)]
        assert len(labels) == nlist, 'Need exactly one label per data set.'

        if colors is None:
            palette = ['b', 'g', 'r', 'c', 'm', 'y']
            # Cycle through the palette when there are more sets than colors.
            colors = [palette[i % len(palette)] for i in range(nlist)]
        assert len(colors) == nlist, 'Need exactly one color per data set.'

        return labels, colors

    def predict_plot(self, xx_list, yy_list, labels=None, colors=None, iouts=None):
        """Plots the diagonal comparison figures.

        One figure per requested output is written by `plot_dm` to
        `fitdiag_o<iout>.png`.

        Args:
            xx_list (list[np.ndarray]): List of `(N,d)` inputs (e.g., training,
                validation, testing).
            yy_list (list[np.ndarray]): List of `(N,o)` outputs.
            labels (list[str], optional): List of labels. If None, set label
                internally.
            colors (list[str], optional): List of colors. If None, sets colors
                internally.
            iouts (list[int], optional): List of outputs to plot. If None,
                plot all.

        Raises:
            ValueError: If no data set is given.

        Note:
            There is a similar function for probabilistic NN in
            :class:`..solvers.quinn.QUiNNBase`.
        """
        nlist = len(xx_list)
        assert nlist == len(yy_list), 'xx_list and yy_list must have equal length.'
        if nlist == 0:
            raise ValueError('predict_plot requires at least one data set.')

        # Validate the cheap arguments before running any prediction.
        labels, colors = self._default_labels_colors(nlist, labels, colors)

        yy_pred_list = [self.predict(xx) for xx in xx_list]

        if iouts is None:
            # Number of outputs is read off the last prediction.
            iouts = range(yy_pred_list[-1].shape[1])

        for iout in iouts:
            x1 = [yy[:, iout] for yy in yy_list]
            x2 = [yy[:, iout] for yy in yy_pred_list]

            plot_dm(x1, x2, labels=labels, colors=colors,
                    axes_labels=[f'Model output # {iout+1}', f'Fit output # {iout+1}'],
                    figname='fitdiag_o'+str(iout)+'.png',
                    legendpos='in', msize=13)

    def plot_1d_fits(self, xx_list, yy_list, domain=None, ngr=111, true_model=None, labels=None, colors=None):
        """Plotting one-dimensional slices, with the other dimensions at the nominal, of the fit.

        For every (input dimension, output) pair a figure is saved to
        `fit_d<idim>_o<iout>.png`. The slice varies one input over its domain
        while all other inputs are held at the midpoint of theirs.

        Args:
            xx_list (list[np.ndarray]): List of `(N,d)` inputs (e.g., training,
                validation, testing).
            yy_list (list[np.ndarray]): List of `(N,o)` outputs.
            domain (np.ndarray, optional): Domain of the function, `(d,2)`
                array. If None, sets it automatically based on data.
            ngr (int, optional): Number of grid points in the 1d plot.
            true_model (callable, optional): Optionally, plot the true
                function. It is called as `true_model(xgrid, 0.0)` and must
                return an array indexable as `[:, iout]`.
            labels (list[str], optional): List of labels. If None, set label
                internally.
            colors (list[str], optional): List of colors. If None, sets colors
                internally.

        Raises:
            ValueError: If no data set is given.

        Note:
            There is a similar function for probabilistic NN in
            :class:`..solvers.quinn.QUiNNBase`.
        """
        nlist = len(xx_list)
        assert nlist == len(yy_list), 'xx_list and yy_list must have equal length.'
        if nlist == 0:
            raise ValueError('plot_1d_fits requires at least one data set.')

        labels, colors = self._default_labels_colors(nlist, labels, colors)

        if domain is None:
            # Stack all inputs at once to infer the domain from the data.
            domain = get_domain(np.vstack(xx_list))

        mlabel = 'Mean Pred.'

        ndim = xx_list[0].shape[1]
        nout = yy_list[0].shape[1]

        for idim in range(ndim):
            # Unit-cube grid: sweep dimension idim, pin the rest at 0.5.
            xgrid_ = np.full((ngr, ndim), 0.5)
            xgrid_[:, idim] = np.linspace(0., 1., ngr)
            xgrid = scale01ToDom(xgrid_, domain)

            ygrid_pred = self.predict(xgrid)
            # The truth only depends on the grid, so evaluate it once per
            # input dimension instead of once per output.
            # NOTE(review): the meaning of the second argument (0.0) is not
            # visible here -- presumably a noise level; confirm with callers.
            ygrid_true = None if true_model is None else true_model(xgrid, 0.0)

            for iout in range(nout):
                for xx, yy, color, label in zip(xx_list, yy_list, colors, labels):
                    plt.plot(xx[:, idim], yy[:, iout], color+'o',
                             markersize=13, markeredgecolor='w', label=label)

                if ygrid_true is not None:
                    plt.plot(xgrid[:, idim], ygrid_true[:, iout], 'k-',
                             label='Truth', alpha=0.5)

                plt.plot(xgrid[:, idim], ygrid_pred[:, iout], 'm-',
                         linewidth=5, label=mlabel)

                plt.legend()
                plt.xlabel(f'Input # {idim+1}')
                plt.ylabel(f'Output # {iout+1}')
                plt.savefig('fit_d' + str(idim) + '_o' + str(iout) + '.png')
                # Reset the current figure so the next plot starts clean.
                plt.clf()