ens

learner

Module for a Learner class: a wrapper around a PyTorch NN module that provides basic training and prediction functionality.

class quinn.ens.learner.Learner(nnmodel, verbose=False)

Bases: object

A learner class that holds a PyTorch NN module and helps train it.

nnmodel

Main PyTorch NN module.

Type:

torch.nn.Module

best_model

The best trained PyTorch NN module.

Type:

torch.nn.Module

trained

Whether the module is trained or not.

Type:

bool

verbose

Whether to be verbose or not.

Type:

bool

__init__(nnmodel, verbose=False)

Initialization.

Parameters:
  • nnmodel (torch.nn.Module) – Main PyTorch NN module.

  • verbose (bool) – Whether to be verbose or not.
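To illustrate the documented attributes, here is a minimal sketch of a wrapper with the same constructor signature. `SketchLearner` is a hypothetical name, and starting `best_model` at `None` is an assumption; the actual `Learner.__init__` may initialize it differently.

```python
import torch.nn as nn


class SketchLearner:
    """Hypothetical stand-in mirroring the documented attributes of Learner."""

    def __init__(self, nnmodel, verbose=False):
        self.nnmodel = nnmodel    # main PyTorch NN module
        self.best_model = None    # assumption: unset until training runs
        self.trained = False      # no fit has been performed yet
        self.verbose = verbose    # controls diagnostic printing


# Any torch.nn.Module can be wrapped, e.g. a small MLP mapping d=3 inputs to o=2 outputs.
net = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 2))
learner = SketchLearner(net, verbose=True)
print(learner.trained)  # False
```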

print_params(names_only=False)

Print parameters of the learner’s model.

Parameters:

names_only (bool, optional) – Whether to print the parameter names only or not.
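One way such a printout can be produced is by iterating over `torch.nn.Module.named_parameters()`. The free function below is a hypothetical sketch, not the library's code; it also returns the printed lines so they can be inspected.

```python
import torch.nn as nn


def print_params(model, names_only=False):
    """Hypothetical sketch: print (and return) the parameters of a torch.nn.Module."""
    lines = []
    for name, param in model.named_parameters():
        if names_only:
            lines.append(name)
        else:
            # Show the name, the shape, and the values as a nested list.
            lines.append(f"{name} {tuple(param.shape)} {param.detach().tolist()}")
    for line in lines:
        print(line)
    return lines


net = nn.Sequential(nn.Linear(3, 2))
names = print_params(net, names_only=True)  # prints "0.weight" then "0.bias"
```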

init_params()

An example of random initialization of parameters.

Todo

We can and should enrich this.
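The documentation does not say which distribution `init_params()` draws from. Purely as an assumption for illustration, this sketch samples the weights of every `nn.Linear` layer from a normal distribution and zeroes the biases; the function and its `std` parameter are hypothetical.

```python
import torch
import torch.nn as nn


def init_params(model, std=0.1):
    """Hypothetical sketch: random-normal weights and zero biases for every nn.Linear."""
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # nn.init functions modify the tensors in place without tracking gradients.
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 2))
init_params(net)
```

Walking `model.modules()` reaches nested layers too, so the same function works for arbitrarily deep `nn.Sequential` stacks.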

fit(xtrn, ytrn, **kwargs)

Fitting function for this learner.

Parameters:
  • xtrn (np.ndarray) – Input array of size (N,d).

  • ytrn (np.ndarray) – Output array of size (N,o).

  • **kwargs (dict) – Keyword arguments.
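The accepted keyword arguments are not documented here. The sketch below shows one plausible fitting loop under stated assumptions: full-batch Adam, mean-squared-error loss, and hypothetical keyword arguments `nepochs` and `lrate`. It also keeps a deep copy of the lowest-loss model, which is one way a `best_model` attribute could be populated; it is not necessarily what `Learner.fit` does.

```python
import copy

import numpy as np
import torch
import torch.nn as nn


def fit(model, xtrn, ytrn, nepochs=200, lrate=0.01):
    """Hypothetical full-batch training loop; returns (best_model, loss history)."""
    x = torch.from_numpy(xtrn).float()  # (N, d)
    y = torch.from_numpy(ytrn).float()  # (N, o)
    optimizer = torch.optim.Adam(model.parameters(), lr=lrate)
    loss_fn = nn.MSELoss()

    best_loss, best_model, history = float("inf"), None, []
    for _ in range(nepochs):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        # Snapshot the parameters that produced this loss, before the update changes them.
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_model = copy.deepcopy(model)
        loss.backward()
        optimizer.step()
        history.append(loss.item())
    return best_model, history


rng = np.random.default_rng(0)
xtrn = rng.normal(size=(100, 3))       # N=100 samples, d=3 inputs
ytrn = xtrn @ rng.normal(size=(3, 2))  # o=2 outputs from a linear target
torch.manual_seed(0)
net = nn.Linear(3, 2)
best, history = fit(net, xtrn, ytrn)
```

Copying the model before `optimizer.step()` matters: copying afterwards would store parameters that differ from the ones that achieved the recorded loss.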

predict(x)

Prediction function of this learner.

Parameters:

x (np.ndarray) – Input array of size (N,d).

Returns:

Output array of size (N,o).

Return type:

np.ndarray
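A minimal sketch of the NumPy-in, NumPy-out contract described above, assuming the wrapped module maps a float32 tensor of shape (N,d) to one of shape (N,o). The free function `predict` here is hypothetical, not the library's implementation.

```python
import numpy as np
import torch
import torch.nn as nn


def predict(model, x):
    """Hypothetical sketch: map an (N, d) NumPy array to an (N, o) NumPy array."""
    model.eval()           # switch off training-only behaviour such as dropout
    with torch.no_grad():  # no gradients are needed for inference
        y = model(torch.from_numpy(x).float())
    return y.numpy()


net = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 2))
y = predict(net, np.zeros((5, 3)))
print(y.shape)  # (5, 2)
```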