# Source code for quinn.nns.rnet

#!/usr/bin/env python
"""Module containing ResNet class and layer weight parameterization class."""

import math
import torch
import torch.nn.functional as F

from .nnbase import MLPBase


########################################################################
########################################################################
########################################################################


class RNet(MLPBase):
    """Residual Neural Network (ResNet) class.

    The input optionally goes through a pre-resnet linear layer (with
    activation), then through ``L+1`` residual (or plain MLP) layers whose
    weights/biases are produced by ``wp_function``, then an optional
    post-resnet linear layer and an optional final layer function.

    Attributes:
        activ (torch.nn.Module): Activation function (``Tanh`` if nonlinear, else ``Identity``).
        bias_post (torch.nn.Parameter): Bias vector of post-resnet layer, if any.
        bias_pre (torch.nn.Parameter): Bias vector of pre-resnet layer, if any.
        biasorno (bool): Whether or not to include biases in each resnet layer.
        final_layer (str): Final layer function, if any. Supported options are
            ``'exp'`` (exponential), ``'logabs'`` (log of absolute value) and
            ``'sum'`` (sum over dimension ``sum_dim``). Any other value,
            including None, leaves the output unchanged.
        indim (int): Input dimensionality `d`.
        init_factor (float): Multiplicative factor of initialized weights.
        layer_post (bool): Whether there is a post-resnet linear layer.
        layer_pre (bool): Whether there is a pre-resnet linear layer.
        mlp (bool): If True, residual connections are ignored, and this becomes a regular MLP.
        nlayers (int): Number of layers `L`.
        outdim (int): Output dimensionality `o`.
        rdim (int): Width of the ResNet `r`, i.e. number of units in each hidden layer.
        step_size (float): Time step size, `1/(L+1)`.
        sum_dim (int): Which dimension the final sum, if any, is with respect to.
        weight_post (torch.nn.Parameter): Weight matrix of post-resnet layer, if any.
        weight_pre (torch.nn.Parameter): Weight matrix of pre-resnet layer, if any.
        wp_function (LayerFcn): Weight parameterization function.
    """

    def __init__(self, rdim, nlayers, wp_function=None, indim=None, outdim=None,
                 biasorno=True, nonlin=True, mlp=False, layer_pre=False,
                 layer_post=False, final_layer=None, device='cpu',
                 init_factor=1.0, sum_dim=1):
        """Instantiate ResNet object.

        Args:
            rdim (int): Width of the ResNet `r`, i.e. number of units in each hidden layer.
            nlayers (int): Number of layers `L`.
            wp_function (LayerFcn, optional): Weight parameterization function. Defaults to a regular ResNet without weight parameterization.
            indim (int, optional): Input dimensionality `d`. Defaults to width `r`.
            outdim (int, optional): Output dimensionality `o`. Defaults to width `r`.
            biasorno (bool, optional): Whether or not to include biases in each resnet layer. Default is True.
            nonlin (bool, optional): Whether to use nonlinear activation function between layers. Defaults to True.
            mlp (bool, optional): If True, residual connections are ignored, and this becomes a regular MLP. Default is False.
            layer_pre (bool, optional): Whether there is a pre-resnet linear layer. Defaults to False.
            layer_post (bool, optional): Whether there is a post-resnet linear layer. Defaults to False.
            final_layer (str, optional): Final layer function, if any. Options:
                "exp" for exponential function; "logabs" for log of absolute
                value; "sum" for sum function, which reduces the rank of the
                output tensor. Any other value (default None) means no final layer.
            device (str): Where computations are performed and tensors are allocated. Defaults to cpu.
            init_factor (float, optional): Multiply initial condition tensors by factor. Defaults to 1.0.
            sum_dim (int, optional): If final layer function is sum, selects which dimension to sum over. Defaults to 1.

        Raises:
            AssertionError: If ``wp_function`` is not a ``LayerFcn``, or if the
                input (output) dimensionality differs from ``rdim`` without a
                pre (post) linear layer.
        """
        super().__init__(indim, outdim, device=device)
        if self.indim is None:
            self.indim = rdim
        if self.outdim is None:
            self.outdim = rdim
        self.nlayers = nlayers
        self.biasorno = biasorno
        if wp_function is None:
            # One independent weight (and bias) per layer, i.e. a plain ResNet.
            self.wp_function = NonPar(nlayers + 1)
        else:
            assert isinstance(wp_function, LayerFcn), \
                "wp_function must be a LayerFcn instance"
            self.wp_function = wp_function
        self.step_size = 1.0 / (nlayers + 1.0)
        self.mlp = mlp
        self.layer_pre = layer_pre
        self.layer_post = layer_post
        self.final_layer = final_layer
        self.init_factor = init_factor
        self.sum_dim = sum_dim  # only used when final_layer == "sum"
        self.rdim = rdim

        # A pre/post linear layer is required to map between d (or o) and the width r.
        if self.indim != self.rdim:
            assert self.layer_pre, "indim != rdim requires layer_pre=True"
        if self.outdim != self.rdim:
            assert self.layer_post, "outdim != rdim requires layer_post=True"

        # Parameter creation order is kept fixed so that seeded random
        # initializations stay reproducible.
        if self.layer_pre:
            self.weight_pre = self._uniform_param(self.rdim, self.indim, fan_in=self.indim)
            self.bias_pre = self._uniform_param(self.rdim, fan_in=self.indim)
        if self.layer_post:
            self.weight_post = self._uniform_param(self.outdim, self.rdim, fan_in=self.rdim)
            self.bias_post = self._uniform_param(self.outdim, fan_in=self.rdim)

        # Parameters consumed by wp_function; registered by name (ww_<k>, bb_<k>)
        # so that forward() can look them up with getattr.
        for ip in range(self.wp_function.npar):
            self.register_parameter(
                name='ww_' + str(ip),
                param=self._uniform_param(self.rdim, self.rdim, fan_in=self.rdim))
        if self.biasorno:
            for ip in range(self.wp_function.npar):
                self.register_parameter(
                    name='bb_' + str(ip),
                    param=self._uniform_param(self.rdim, fan_in=self.rdim))

        if nonlin:
            self.activ = torch.nn.Tanh()
        else:
            self.activ = torch.nn.Identity()
        self.to(device)

    def _uniform_param(self, *shape, fan_in):
        """Return a Parameter of the given shape drawn uniformly from
        ``init_factor * [-1, 1) / sqrt(fan_in)``."""
        return torch.nn.Parameter(
            self.init_factor * (2. * torch.rand(*shape) - 1.) / math.sqrt(fan_in))

    def forward(self, x):
        r"""Forward function.

        Args:
            x (torch.Tensor): Input tensor `x` of size :math:`(N,d)`.

        Returns:
            torch.Tensor: Output tensor of size :math:`(N,o)`, transformed by the
            final layer function if one is set (``"sum"`` removes dimension
            ``sum_dim``).
        """
        out = x + 0.0  # work on a new tensor rather than the caller's x
        # Note that the prelayer has activation, too, to avoid two linear layers in succession
        if self.layer_pre:
            out = self.activ(F.linear(out, self.weight_pre, self.bias_pre))

        npar = self.wp_function.npar
        paramsw = [getattr(self, 'ww_' + str(ip)) for ip in range(npar)]
        if self.biasorno:
            paramsb = [getattr(self, 'bb_' + str(ip)) for ip in range(npar)]

        for i in range(self.nlayers + 1):
            t = self.step_size * i
            weight = self.wp_function(paramsw, t)
            bias = self.wp_function(paramsb, t) if self.biasorno else None
            if self.mlp:
                out = self.activ(F.linear(out, weight, bias))
            else:
                # Residual update: out + h * act(W(t) out + b(t)), h = step_size.
                out = out + self.step_size * self.activ(F.linear(out, weight, bias))

        if self.layer_post:
            out = F.linear(out, self.weight_post, self.bias_post)

        if self.final_layer == "exp":
            out = torch.exp(out)
        elif self.final_layer == "logabs":
            out = torch.log(torch.abs(out))
        elif self.final_layer == "sum":
            out = torch.sum(out, dim=self.sum_dim)

        return out


########################################################################
########################################################################
########################################################################
class LayerFcn:
    """Base class for layer weight parameterization functions.

    A layer function takes a list of parameters and a scalar 'time' ``t``,
    and returns the weight (or bias) to use at that time. RNet passes
    ``t = step_size * i`` for layer ``i``. Subclasses define ``npar`` and
    implement ``__call__``.

    Attributes:
        npar (int): Number of parameters in the parameterization (parameters can be Tensors).
            Unset (None) in this base class.
    """

    def __init__(self):
        """Create the function with the parameter count left unset."""
        self.npar = None

    def __call__(self, pars, t):
        """Evaluate the parameterization.

        Args:
            pars (list[torch.nn.Parameter]): List of parameters.
            t (float): 'Time' associated with the layer.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError
class Const(LayerFcn):
    """Constant weight parameterization: every layer uses the same parameter.

    Attributes:
        npar (int): Number of parameters; always 1.
    """

    def __init__(self):
        """Create the function with a single parameter."""
        super().__init__()
        self.npar = 1

    def __call__(self, pars, t):
        """Return the single parameter, regardless of ``t``.

        Args:
            pars (list[torch.nn.Parameter]): List holding exactly one parameter.
            t (float): 'Time' associated with the layer (unused).

        Returns:
            torch.nn.Parameter: The constant parameter.
        """
        assert len(pars) == self.npar
        (constant,) = pars
        return constant
class Lin(LayerFcn):
    """Linear weight parameterization ``p0 + p1 * t``.

    Attributes:
        npar (int): Number of parameters; always 2.
    """

    def __init__(self):
        """Create the function with two coefficients."""
        super().__init__()
        self.npar = 2

    def __call__(self, pars, t):
        """Evaluate the linear function of ``t``.

        Args:
            pars (list[torch.nn.Parameter]): Coefficients ``[p0, p1]``.
            t (float): 'Time' associated with the layer.

        Returns:
            torch.Tensor: ``p0 + p1 * t``.
        """
        assert len(pars) == self.npar
        intercept, slope = pars
        return intercept + slope * t
class Quad(LayerFcn):
    """Quadratic weight parameterization ``p0 + p1 * t + p2 * t**2``.

    Attributes:
        npar (int): Number of parameters; always 3.
    """

    def __init__(self):
        """Create the function with three coefficients."""
        super().__init__()
        self.npar = 3

    def __call__(self, pars, t):
        """Evaluate the quadratic function of ``t``.

        Args:
            pars (list[torch.nn.Parameter]): Coefficients ``[p0, p1, p2]``.
            t (float): 'Time' associated with the layer.

        Returns:
            torch.Tensor: ``p0 + p1 * t + p2 * t**2``.
        """
        assert len(pars) == self.npar
        c0, c1, c2 = pars
        return c0 + c1 * t + c2 * t**2
class Cubic(LayerFcn):
    """Cubic weight parameterization ``p0 + p1 * t + p2 * t**2 + p3 * t**3``.

    Attributes:
        npar (int): Number of parameters; always 4.
    """

    def __init__(self):
        """Create the function with four coefficients."""
        super().__init__()
        self.npar = 4

    def __call__(self, pars, t):
        """Evaluate the cubic function of ``t``.

        Args:
            pars (list[torch.nn.Parameter]): Coefficients ``[p0, p1, p2, p3]``.
            t (float): 'Time' associated with the layer.

        Returns:
            torch.Tensor: ``p0 + p1 * t + p2 * t**2 + p3 * t**3``.
        """
        assert len(pars) == self.npar
        c0, c1, c2, c3 = pars
        return c0 + c1 * t + c2 * t**2 + c3 * t**3
class Poly(LayerFcn):
    """Polynomial weight parameterization ``sum_k p_k * t**k``.

    Attributes:
        npar (int): Number of parameters, i.e. polynomial order plus one.
    """

    def __init__(self, order):
        """Create the function for a polynomial of the given order.

        Args:
            order (int): Order of the polynomial.
        """
        super().__init__()
        self.npar = order + 1

    def __call__(self, pars, t):
        """Evaluate the polynomial in ``t``.

        Args:
            pars (list[torch.nn.Parameter]): Coefficients ``[p_0, ..., p_order]``.
            t (float): 'Time' associated with the layer.

        Returns:
            torch.Tensor: ``p_0 + p_1 * t + ... + p_order * t**order``.
        """
        assert len(pars) == self.npar
        total = 0.0
        for power, coef in enumerate(pars):
            total += coef * t**power
        return total
class NonPar(LayerFcn):
    """Non-parametric weight parameterization, i.e. effectively a regular ResNet without weight parameterization.

    The interval ``[0, 1)`` is split into ``npar`` equal bins. Each bin has
    its own parameter.

    Attributes:
        npar (int): Number of parameters.
    """

    # Tolerance absorbing floating-point error in ``t * npar`` for on-grid t.
    _EPS = 1e-8

    def __init__(self, npar):
        """Instantiation.

        Args:
            npar (int): Should be one more than the number of layers `L`.
        """
        super().__init__()
        self.npar = npar

    def __call__(self, pars, t):
        """Select the parameter belonging to time ``t``.

        Args:
            pars (list[torch.nn.Parameter]): List of ``npar`` parameters.
            t (float): 'Time' associated with the layer; parameter
                ``int(t * npar)`` is selected.

        Returns:
            torch.nn.Parameter: Non-parametric: a new parameter per `t` value (i.e. per layer).
        """
        assert len(pars) == self.npar
        # RNet passes t = step_size * i with step_size = 1/npar. The product
        # t * npar can then fall just below the integer i, e.g.
        # (1/49) * 49 == 0.9999999999999999. Plain truncation would then
        # reuse pars[i-1] and leave pars[i] unused, so nudge by a tiny epsilon.
        idx = int(t * self.npar + self._EPS)
        return pars[idx]