Source code for quinn.solvers.nn_rms

#!/usr/bin/env python
"""Module for RMS NN wrapper."""

import torch
import numpy as np

from .nn_ens import NN_Ens


class NN_RMS(NN_Ens):
    """Ensemble NN wrapper implementing the RMS method.

    Every ensemble member is trained with a log-posterior loss whose
    Gaussian prior is centred on a randomly drawn anchor vector.
    For details of the method, see :cite:t:`pearce:2018`.

    Attributes:
        datanoise (float): Data noise standard deviation.
        nparams (int): Number of model parameters.
        priorsigma (float): Prior standard deviation.
    """

    def __init__(self, nnmodel, datanoise=0.1, priorsigma=1.0, **kwargs):
        """Set up the ensemble and record the noise and prior settings.

        Args:
            nnmodel (torch.nn.Module): NNWrapper class.
            datanoise (float, optional): Data noise standard deviation.
                Defaults to 0.1.
            priorsigma (float, optional): Gaussian prior standard deviation.
                Defaults to 1.0.
            **kwargs: Forwarded unchanged to the parent class initializer.
        """
        super().__init__(nnmodel, **kwargs)
        self.datanoise = datanoise
        self.priorsigma = priorsigma

        # Count the scalar entries over all parameter tensors of the model;
        # this fixes the length of the anchor vector drawn in fit().
        param_count = 0
        for param in self.nnmodel.parameters():
            param_count += param.numel()
        self.nparams = param_count

    def fit(self, xtrn, ytrn, **kwargs):
        """Train every ensemble member on its own random data subset.

        For each learner, a fresh permutation of the sample indices is
        drawn and the first ``int(N * self.dfrac)`` of them are used for
        training. A new anchor vector of length ``self.nparams`` is sampled
        from a zero-mean normal scaled by ``self.priorsigma``.

        The keys ``lhist_suffix``, ``loss_fn``, ``datanoise`` and
        ``priorparams`` are always set here, replacing any values the caller
        passed under those names.

        Args:
            xtrn (np.ndarray): Input array of size `(N,d)`.
            ytrn (np.ndarray): Output array of size `(N,o)`.
            **kwargs (dict): Any keyword argument that
                :meth:`..nns.nnfit.nnfit` takes.
        """
        # self.nens, self.dfrac and self.learners are not set in this class;
        # presumably they are provided by NN_Ens -- confirm against the parent.
        for jens in range(self.nens):
            print(f"======== Fitting Learner {jens+1}/{self.nens} =======")

            # Random subset of the training samples for this member.
            nsamples = ytrn.shape[0]
            shuffled = np.random.permutation(nsamples)
            subset_idx = shuffled[: int(nsamples * self.dfrac)]

            learner = self.learners[jens]

            # Anchor drawn anew for each member, scaled by the prior sigma.
            anchor = torch.randn(size=(self.nparams,)) * self.priorsigma

            kwargs.update(
                {
                    "lhist_suffix": f"_e{jens}",
                    "loss_fn": "logpost",
                    "datanoise": self.datanoise,
                    "priorparams": {'sigma': self.priorsigma, 'anchor': anchor},
                }
            )
            learner.fit(xtrn[subset_idx], ytrn[subset_idx], **kwargs)