#!/usr/bin/env python
"""Module for the Variational Inference (VI) NN wrapper."""
import copy
import math
import torch
from ..vi.bnet import BNet
from ..solvers.quinn import QUiNNBase
from ..nns.tchutils import npy, tch, print_nnparams
from ..nns.nnfit import nnfit
class NN_VI(QUiNNBase):
    """VI wrapper class. This implements the Bayes-by-backprop method. For details of the method, see :cite:t:`blundell:2015`.

    Attributes:
        best_model (torch.nn.Module): The best PyTorch NN model found during training.
        bmodel (BNet): The underlying Bayesian model.
        device (str): Device on which the computations are done.
        trained (bool): Whether the model is trained or not.
        verbose (bool): Whether to be verbose or not.
    """

    def __init__(self, nnmodel, verbose=False, pi=0.5, sigma1=1.0, sigma2=1.0,
                 mu_init_lower=-0.2, mu_init_upper=0.2,
                 rho_init_lower=-5.0, rho_init_upper=-4.0):
        """Instantiate a VI wrapper object.

        Args:
            nnmodel (torch.nn.Module): The underlying PyTorch NN model.
            verbose (bool, optional): Whether to print out model details or not.
            pi (float): Weight of the first gaussian. The second weight is 1-pi.
            sigma1 (float): Standard deviation of the first gaussian. Can also be a scalar torch.Tensor.
            sigma2 (float): Standard deviation of the second gaussian. Can also be a scalar torch.Tensor.
            mu_init_lower (float): Initialization of mu lower value
            mu_init_upper (float): Initialization of mu upper value
            rho_init_lower (float): Initialization of rho lower value
            rho_init_upper (float): Initialization of rho upper value
        """
        super().__init__(nnmodel)

        self.bmodel = BNet(nnmodel, pi=pi, sigma1=sigma1, sigma2=sigma2,
                           mu_init_lower=mu_init_lower, mu_init_upper=mu_init_upper,
                           rho_init_lower=rho_init_lower, rho_init_upper=rho_init_upper)

        # Not every model carries a `device` attribute; default to CPU then.
        self.device = getattr(nnmodel, 'device', 'cpu')
        self.bmodel.to(self.device)

        self.verbose = verbose
        self.trained = False
        self.best_model = None

        if self.verbose:
            print("=========== Deterministic model parameters ================")
            self.print_params(names_only=True)
            print("=========== Variational model parameters ==================")
            print_nnparams(self.bmodel, names_only=True)
            print("===========================================================")

    def fit(self, xtrn, ytrn, val=None,
            nepochs=600, lrate=0.01, batch_size=None, freq_out=100,
            freq_plot=1000, wd=0,
            cooldown=100,
            factor=0.95,
            nsam=1, scheduler_lr=None, datanoise=0.05):
        """Fit function to train the network.

        Args:
            xtrn (np.ndarray): Training input array of size `(N,d)`.
            ytrn (np.ndarray): Training output array of size `(N,o)`.
            val (tuple, optional): `x,y` tuple of validation points. Default uses the training set for validation.
            nepochs (int, optional): Number of epochs.
            lrate (float, optional): Learning rate or learning rate schedule factor. Default is 0.01.
            batch_size (int, optional): Batch size. Default is None, i.e. single batch.
                Values larger than `N` are clipped to `N`.
            freq_out (int, optional): Frequency, in epochs, of screen output. Defaults to 100.
            freq_plot (int, optional): Frequency, in epochs, of plotting the loss.
            wd (float, optional): Optional weight decay (L2 regularization) parameter.
            cooldown (int, optional): cooldown in ReduceLROnPlateau
            factor (float, optional): factor in ReduceLROnPlateau
            nsam (int, optional): Number of samples for ELBO computation. Defaults to 1.
            scheduler_lr (None, optional): Scheduler of learning rate. See the corresponding argument in :func:`..nns.nnfit.nnfit()`.
            datanoise (float, optional): Datanoise for ELBO computation. Defaults to 0.05.
        """
        ntrn = xtrn.shape[0]
        # The two-value unpack also enforces that ytrn is two-dimensional.
        ntrn_, _outdim = ytrn.shape
        assert ntrn == ntrn_, "xtrn and ytrn must have the same number of rows"

        if batch_size is None or batch_size > ntrn:
            batch_size = ntrn

        # Number of minibatches per epoch, counting a trailing partial batch
        # (ceiling division). This also yields ntrn when batch_size is 1.
        # NOTE(review): assumes nnfit keeps the last partial batch -- confirm.
        num_batches = (ntrn + batch_size - 1) // batch_size

        # Stored on the Bayesian model before its viloss is used as the loss.
        self.bmodel.loss_params = [datanoise, nsam, num_batches]

        fit_info = nnfit(self.bmodel, xtrn, ytrn, val=val,
                         loss_xy=self.bmodel.viloss,
                         lrate=lrate, batch_size=batch_size,
                         nepochs=nepochs,
                         wd=wd,
                         cooldown=cooldown,
                         factor=factor,
                         freq_plot=freq_plot,
                         scheduler_lr=scheduler_lr, freq_out=freq_out)

        self.best_model = fit_info['best_nnmodel']
        self.trained = True

    def predict_sample(self, x):
        """Predict a single sample.

        Args:
            x (np.ndarray): Input array `x` of size `(N,d)`.

        Returns:
            np.ndarray: Output array `y` of size `(N,o)`.

        Note:
            predict_ens() from the parent class will use this to sample an ensemble.
        """
        assert self.trained, "model must be trained with fit() before predicting"

        # Mirror __init__: tolerate a model without a `device` attribute by
        # falling back to the device the Bayesian model was moved to.
        device = getattr(self.best_model, 'device', self.device)
        x_tch = tch(x, rgrad=False, device=device)
        y = npy(self.best_model(x_tch, sample=True))

        return y
######################################################################
######################################################################
######################################################################