Source code for quinn.ens.learner

#!/usr/bin/env python
"""Module for a Learner class that is a wrapper with basic training/prediction functionality."""

import math
import copy

from ..nns.tchutils import npy, tch, print_nnparams
from ..nns.nnfit import nnfit

class Learner():
    """A learner class that holds a PyTorch NN module and helps train it.

    Attributes:
        nnmodel (torch.nn.Module): Main PyTorch NN module (a deep copy of the
            module passed to the constructor).
        best_model (torch.nn.Module): The best trained PyTorch NN module;
            ``None`` until :meth:`fit` has been called.
        trained (bool): Whether the module is trained or not.
        verbose (bool): Whether to be verbose or not.
    """

    def __init__(self, nnmodel, verbose=False):
        """Initialization.

        Args:
            nnmodel (torch.nn.Module): Main PyTorch NN module. It is
                deep-copied, so the caller's object is not modified by this
                learner.
            verbose (bool): Whether to be verbose or not. If true, the
                parameter names of the model are printed on construction.
        """
        super().__init__()
        # Work on a private copy so several learners can be built from the
        # same module object without sharing state.
        self.nnmodel = copy.deepcopy(nnmodel)
        self.trained = False
        self.verbose = verbose
        self.best_model = None

        if self.verbose:
            self.print_params(names_only=True)

    def print_params(self, names_only=False):
        """Print parameters of the learner's model.

        The trained ``best_model`` is printed once the learner has been
        fitted; before that, the untrained ``nnmodel`` is printed.

        Args:
            names_only (bool, optional): Whether to print the parameter names
                only or not.
        """
        model = self.best_model if self.trained else self.nnmodel
        print_nnparams(model, names_only=names_only)

    def init_params(self):
        """An example of random initialization of parameters.

        Every parameter of ``nnmodel`` is overwritten in place with samples
        from a uniform distribution on ``[-stdv, stdv]``, where ``stdv`` is
        ``1 / sqrt(p.size(1))``, or ``1`` when the parameter has no dimension
        with index 1.

        .. todo:: we can and should enrich this.
        """
        for p in self.nnmodel.parameters():
            try:
                stdv = 1. / math.sqrt(p.size(1))
            except IndexError:
                # p.size(1) is not available for this parameter.
                stdv = 1.
            p.data.uniform_(-stdv, stdv)

    def fit(self, xtrn, ytrn, **kwargs):
        """Fitting function for this learner.

        If ``nnmodel`` provides its own callable ``fit`` attribute, that is
        used and its return value is stored as ``best_model``. Otherwise the
        generic ``nnfit`` routine is called and the ``'best_nnmodel'`` entry
        of its result is stored.

        Args:
            xtrn (np.ndarray): Input array of size `(N,d)`.
            ytrn (np.ndarray): Output array of size `(N,o)`.
            **kwargs (dict): Keyword arguments, forwarded unchanged to the
                fitting routine that is used.
        """
        if callable(getattr(self.nnmodel, 'fit', None)):
            # NOTE(review): the model's own fit() is presumed to return the
            # trained module - confirm for custom models.
            self.best_model = self.nnmodel.fit(xtrn, ytrn, **kwargs)
        else:
            fit_info = nnfit(self.nnmodel, xtrn, ytrn, **kwargs)
            self.best_model = fit_info['best_nnmodel']
        self.trained = True

    def predict(self, x):
        """Prediction of the learner.

        Args:
            x (np.ndarray): Input array of size `(N,d)`.

        Returns:
            np.ndarray: Output array of size `(N,o)`.

        Raises:
            AssertionError: If the learner has not been fitted yet.
        """
        # Raised explicitly rather than via an ``assert`` statement so that
        # the check is not stripped under ``python -O``; the exception type is
        # kept for backward compatibility with existing callers.
        if not self.trained:
            raise AssertionError(
                "Learner is not trained yet; call fit() before predict()."
            )

        try:
            device = self.best_model.device
        except AttributeError:
            # The model does not expose a ``device`` attribute.
            device = 'cpu'

        y = self.best_model(tch(x, rgrad=False, device=device))
        return npy(y)