Source code for quinn.solvers.nn_ens

#!/usr/bin/env python
"""Module for Ensemble NN wrapper."""

import numpy as np

from .quinn import QUiNNBase
from ..ens.learner import Learner

class NN_Ens(QUiNNBase):
    """Deep Ensemble NN Wrapper.

    Holds `nens` independent :class:`Learner` objects built from one NN
    model, fits each of them on its own random subset of the training data,
    and exposes their predictions as samples.

    Attributes:
        dfrac (float): Fraction of data each learner sees.
        learners (list[Learner]): List of learners.
        nens (int): Number of ensemble members.
        verbose (bool): Verbose or not.
    """

    def __init__(self, nnmodel, nens=1, dfrac=1.0, verbose=False):
        """Initialization.

        Args:
            nnmodel (torch.nn.Module): PyTorch NN model.
            nens (int, optional): Number of ensemble members. Defaults to 1.
            dfrac (float, optional): Fraction of data for each learner. Defaults to 1.0.
            verbose (bool, optional): Verbose or not. If True, parameter names
                of every learner are printed after construction.
        """
        super().__init__(nnmodel)
        self.verbose = verbose
        self.nens = nens
        self.dfrac = dfrac

        # Every learner is constructed from the same `nnmodel` object.
        # NOTE(review): ensemble members are only independent if Learner
        # makes its own copy of the model -- confirm in Learner.__init__.
        self.learners = [Learner(nnmodel) for _ in range(nens)]

        if self.verbose:
            self.print_params(names_only=True)

    def print_params(self, names_only=False):
        """Print model parameter names and optionally, values.

        Args:
            names_only (bool, optional): Print names only. Default is False.
        """
        for i, learner in enumerate(self.learners):
            print(f"========== Learner {i+1}/{self.nens} ============")
            learner.print_params(names_only=names_only)

    def fit(self, xtrn, ytrn, **kwargs):
        """Fitting function for each ensemble member.

        Each learner is fitted on its own random subset of
        `int(N * dfrac)` rows, drawn without replacement.

        Args:
            xtrn (np.ndarray): Input array of size `(N,d)`.
            ytrn (np.ndarray): Output array of size `(N,o)`.
            **kwargs (dict): Any keyword argument that :meth:`..nns.nnfit.nnfit` takes.
                The `lhist_suffix` entry is always overwritten with `_e<j>`,
                where `<j>` is the zero-based learner index.
        """
        # Both quantities are the same for every learner, so compute them once.
        ntrn = ytrn.shape[0]
        nsub = int(ntrn * self.dfrac)

        for jens in range(self.nens):
            print(f"======== Fitting Learner {jens+1}/{self.nens} =======")

            # A fresh permutation per learner gives each one a different subset.
            permutation = np.random.permutation(ntrn)
            ind_this = permutation[:nsub]

            # Per-member suffix; presumably keeps the loss-history outputs of
            # different learners apart -- see nnfit for the exact meaning.
            kwargs['lhist_suffix'] = f'_e{jens}'
            self.learners[jens].fit(xtrn[ind_this], ytrn[ind_this], **kwargs)

    def predict_sample(self, x):
        """Predict a single, randomly selected sample.

        Args:
            x (np.ndarray): Input array of size `(N,d)`.

        Returns:
            np.ndarray: Output array of size `(N,o)`.
        """
        # The upper bound of randint is exclusive, so jens is in [0, nens-1].
        jens = np.random.randint(0, self.nens)
        return self.learners[jens].predict(x)

    def predict_ens(self, x, nens=None):
        """Predict from all, or a random subset of, ensemble members.

        Args:
            x (np.ndarray): `(N,d)` input array.
            nens (int, optional): Number of ensemble members to predict with.
                Defaults to None, meaning all available members. If more than
                the available number is requested, a warning is printed and
                all members are used.

        Returns:
            np.ndarray: Stacked predictions of the `M` selected members, in
            random order, i.e. `M` random samples of `(N,o)` outputs.

        Note:
            This overloads QUiNN's base predict_ens function.
        """
        if nens is None:
            nens = self.nens
        if nens > self.nens:
            print(f"Warning: Requested {nens} but only {self.nens} ensemble members available.")
            nens = self.nens

        # Shuffle the indices of ALL learners and keep the first `nens`, so
        # that a request for fewer members draws a random subset of the whole
        # ensemble instead of always reusing learners 0..nens-1.
        # max(..., 0) keeps a negative request an empty selection rather than
        # letting it act as a "drop from the end" slice.
        permuted_inds = np.random.permutation(self.nens)[:max(nens, 0)]

        return np.array([self.learners[j].predict(x) for j in permuted_inds])

    def predict_ens_fromsamples(self, x, nens=1):
        """Predict ensemble in a loop using individual predict_sample() calls.

        Members are drawn independently for each sample, so the same learner
        may be used more than once.

        Args:
            x (np.ndarray): `(N,d)` input array.
            nens (int, optional): Number of samples requested. Defaults to 1.

        Returns:
            np.ndarray: Stacked predictions of `M` draws, i.e. `M` random
            samples of `(N,o)` outputs.
        """
        return np.array([self.predict_sample(x) for _ in range(nens)])