Source code for quinn.vi.bnet

#!/usr/bin/env python
"""Module for the Bayesian network."""

import copy
import math
import torch

from ..rvar.rvs import Gaussian_1d, GMM2_1d


[docs] class BNet(torch.nn.Module): """Bayesian NN class. Attributes: device (str): Device where the computations are done. log_prior (float): Value of log-prior. log_variational_posterior (float): Value of logarithm of variational posterior. nnmodel (torch.nn.Module): The underlying PyTorch NN module. nparams (int): Number of deterministic parameters. param_names (list[str]): List of parameter names. param_priors (list[quinn.vi.rvs.RV]): List of parameter priors. params (torch.nn.ParameterList): Variational parameters. rparams (list[quinn.rvar.rvs.RV]): List of variational PDFs. """
[docs] def __init__(self, nnmodel, pi=0.5, sigma1=1.0, sigma2=1.0, mu_init_lower=-0.2, mu_init_upper=0.2, rho_init_lower=-5.0, rho_init_upper=-4.0): """Instantiate a Bayesian NN object given an underlying PyTorch NN module. Args: nnmodel (torch.nn.Module): The original PyTorch NN module. pi (float): Weight of the first gaussian. The second weight is 1-pi. sigma1 (float): Standard deviation of the first gaussian. Can also be a scalar torch.Tensor. sigma2 (float): Standard deviation of the second gaussian. Can also be a scalar torch.Tensor. mu_init_lower (float): Initialization of mu lower value mu_init_upper (float): Initialization of mu upper value rho_init_lower (float): Initialization of rho lower value rho_init_upper (float): Initialization of rho upper value """ super().__init__() assert(isinstance(nnmodel, torch.nn.Module)) self.nnmodel = copy.deepcopy(nnmodel) try: self.device = nnmodel.device except AttributeError: self.device = 'cpu' # for name, param in self.nnmodel.named_parameters(): # print(name) # if param.requires_grad: # name_ = (name+'').replace('.', '_') # self.del_attr(self.nnmodel, [name]) # self.nnmodel.register_parameter(name, torch.nn.Parameter(param.data)) # print("######") # for name, param in self.nnmodel.named_parameters(): # print(name) # sys.exit() self.param_names = [] self.rparams = [] self.param_priors = [] i=0 for name, param in self.nnmodel.named_parameters(): if param.requires_grad: #param.requires_grad = False mu = torch.nn.Parameter(torch.Tensor(param.shape).uniform_(mu_init_lower, mu_init_upper)) self.register_parameter(name.replace('.', '_')+"_mu", mu) rho = torch.nn.Parameter(torch.Tensor(param.shape).uniform_(rho_init_lower, rho_init_upper)) self.register_parameter(name.replace('.', '_')+"_rho", rho) if i==0: self.params = torch.nn.ParameterList([mu, rho]) else: self.params.append(mu) self.params.append(rho) self.rparams.append(Gaussian_1d(mu, logsigma=rho)) ## PRIOR self.param_priors.append(GMM2_1d(pi, sigma1, sigma2)) self.param_names.append(name) # for i, param_name in enumerate(self.param_names): # #print("AAAA ", i, param_name) # self.set_attr(self.nnmodel,param_name.split("."), par_samples[i]) i+=1 self.log_prior = 0.0 self.log_variational_posterior = 0.0 self.nparams = len(self.rparams) #print("AAAAA ", self.param_names) for param_name in self.param_names: self.del_attr(self.nnmodel,param_name.split("."))
# Inspired by this https://discuss.pytorch.org/t/how-does-one-have-the-parameters-of-a-model-not-be-leafs/70076/10 # this could be better https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-subobjects-chained-properties
[docs] def del_attr(self, obj, names): """Deletes attributes from a given object. Args: obj (any): The object of interest. names (list): List that corresponds to the attribute to be deleted. If list is ['A', 'B', 'C'], the attribute A.B.C is deleted recursively. """ # print("Del ", names) if len(names) == 1: delattr(obj, names[0]) else: self.del_attr(getattr(obj, names[0]), names[1:])
[docs] def set_attr(self, obj, names, val): """Sets attributes of a given object. Args: obj (any): The object of interest. names (list): List that corresponds to the attribute of interest. If list is ['A', 'B', 'C'], the attribute A.B.C is filled with value val. val (torch.Tensor): Value to be set. """ # print("Set ", names, val) if len(names) == 1: setattr(obj, names[0], val) else: self.set_attr(getattr(obj, names[0]), names[1:], val)
[docs] def forward(self, x, sample=False, par_samples=None): """Forward function of Bayesian NN object. Args: x (torch.Tensor): Input array of size `(N,d)`. sample (bool, optional): Whether this is used in a sampling mode or not. par_samples (None, optional): Parameter samples. Default is None, in which cases the mean values of variational PDFs are used. Returns: torch.Tensor: Output array of size `(N,o)`. """ if self.training or sample: assert par_samples is None par_samples = [] for rpar in self.rparams: par_samples.append(rpar.sample()) else: if par_samples is None: par_samples = [] for rpar in self.rparams: par_samples.append(rpar.mu) assert(len(par_samples)==self.nparams) if self.training: self.log_prior = 0.0 for par_sample, param_prior in zip(par_samples, self.param_priors): self.log_prior += param_prior.log_prob(par_sample) self.log_variational_posterior = 0.0 for par_sample, rpar in zip(par_samples, self.rparams): self.log_variational_posterior += rpar.log_prob(par_sample) else: self.log_prior, self.log_variational_posterior = 0, 0 for i, param_name in enumerate(self.param_names): #print("AAAA ", i, param_name, param_name.split("."), par_samples[i]) self.set_attr(self.nnmodel,param_name.split("."), par_samples[i]) #print([i for i in self.nnmodel.coefs]) #self.nnmodel.register_parameter(param_name.replace(".","_"), torch.nn.Parameter(par_samples[i])) #print("BBB ", list(self.nnmodel.parameters())) #print(dir(self)) #print(par_samples) return self.nnmodel(x)
[docs] def sample_elbo(self, x, target, nsam, likparams=None): """Sample from ELBO. Args: x (torch.Tensor): A 2d input tensor. target (torch.Tensor): A 2d output tensor. nsam (int): Number of samples likparams (tuple, optional): Other parameters of the likelihood, e.g. data noise. Returns: tuple: (log_prior, log_variational_posterior, negative_log_likelihood) """ shape_x = x.shape batch_size = shape_x[0] batch_size_, outdim = target.shape assert(batch_size == batch_size_) # FIXME: device = x.device outputs = torch.zeros(nsam, batch_size, outdim, device=device) log_priors = torch.zeros(nsam, device=device) log_variational_posteriors = torch.zeros(nsam, device=device) for i in range(nsam): outputs[i] = self(x, sample=True) log_priors[i] = self.log_prior log_variational_posteriors[i] = self.log_variational_posterior #print("AA ", outputs) log_prior = log_priors.mean() log_variational_posterior = log_variational_posteriors.mean() #print(F.mse_loss(outputs, target, reduction='mean').shape) # outputs is MxBxd, target is Bxd, below broadcasting works, and we average over MxBxd (usually d=1) #negative_log_likelihood = F.mse_loss(outputs, target, reduction='none').mean() #print(outputs.shape, target.shape) ## FIXME transfer data to device is expensive. datasigma = torch.Tensor([likparams[0]]).to(device) negative_log_likelihood = batch_size * torch.log(datasigma) + 0.5*batch_size*torch.log(2.0*torch.tensor(math.pi))+ 0.5 * batch_size * ((outputs - target)**2).mean() / datasigma**2 return log_prior, log_variational_posterior, negative_log_likelihood
[docs] def viloss(self, data, target): """Variational loss function `L(x,y)`. Args: data (torch.Tensor): A 2d input tensor `x`. target (torch.Tensor): A 2d output tensor `y`. Returns: float: The value of loss function. """ datanoise, nsam, num_batches = self.loss_params log_prior, log_variational_posterior, negative_log_likelihood = self.sample_elbo(data, target, nsam, likparams=[datanoise]) return (log_variational_posterior - log_prior)/num_batches + negative_log_likelihood