FishLeg package

Submodules

FishLeg.fishleg module

class FishLeg.fishleg.FishLeg(model: Module, draw: Callable[[Module, Tensor], Tuple[Tensor, Tensor]], nll: Callable[[Module, Tuple[Tensor, Tensor]], Tensor], aux_dataloader: DataLoader, likelihood: Optional[FishLikelihood] = None, fish_lr: float = 0.05, damping: float = 0.5, weight_decay: float = 1e-05, beta: float = 0.9, update_aux_every: int = 10, aux_lr: float = 0.0001, aux_betas: Tuple[float, float] = (0.9, 0.999), aux_eps: float = 1e-08, num_steps=None, batch_speedup: bool = False, full: bool = True, normalization: bool = False, fine_tune: bool = False, module_names: List[str] = [], skip_names: List[str] = [], initialization: str = 'uniform', scale: float = 1.0, warmup: int = 0, warmup_data: Optional[DataLoader] = None, warmup_loss: Optional[Callable] = None, device: str = 'cpu', config=None, verbose=False)

Bases: Optimizer

Implements the FishLeg algorithm.

As described in https://openreview.net/forum?id=c9lAOPvQHS.

Parameters
  • model (torch.nn.Module) – A PyTorch neural network module, which can be nested in a tree structure

  • draw (Callable[[nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]) – Sampling function that takes a model \(f\) and input data \(\mathbf X\), and returns \((\mathbf X, \mathbf y)\), where \(\mathbf y\) is sampled from the conditional distribution \(p(\mathbf y|f(\mathbf X))\)

  • nll (Callable[[nn.Module, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]) – A function that takes a model and a data tuple, and evaluates the negative log-likelihood.

  • aux_dataloader (torch.utils.data.DataLoader) – A DataLoader that supplies the data batches used for the auxiliary parameter updates.

  • likelihood (FishLikelihood) – A FishLeg likelihood; it must provide a Qv method if any of its parameters are learnable.

  • fish_lr (float) – Learning rate for the parameters of the input model, updated by FishLeg (default: 0.05)

  • damping (float) – Static damping \(\gamma\) applied to the Fisher matrix, for stability when the FIM becomes near-singular (default: 5e-1)

  • weight_decay (float) – L2 penalty on weights (default: 1e-5)

  • beta (float) – Coefficient for the running average of the gradient (default: 0.9)

  • update_aux_every (int) – Number of iterations between auxiliary updates. If negative, -update_aux_every auxiliary updates are run in each outer iteration. (default: 10)

  • aux_lr (float) – Learning rate for the auxiliary parameters, which are updated with Adam (default: 1e-4)

  • aux_betas (Tuple[float, float]) – Coefficients used for computing running averages of gradient and its square for auxiliary parameters (default: (0.9, 0.999))

  • aux_eps (float) – Term added to the denominator to improve numerical stability for auxiliary parameters (default: 1e-8)

  • batch_speedup (bool) – Whether to use the batch speed-up for the Qv product, which helps when the batch size is smaller than the parameter size (default: False)

  • full (bool) – Whether to use full inner and outer diagonal rescaling in the block Kronecker approximation of Q. (default: True)

  • normalization (bool) – Whether to normalize the gradients when computing the auxiliary loss; this is important for learning about curvature even when gradients are small (default: False)

  • fine_tune (bool) – Whether to use the Fisher of the pretraining tasks as the preconditioner while fine-tuning on a downstream task. If True, Q is kept fixed and continual learning is performed (default: False)

  • module_names (List) – A list of names of the modules to be optimized/pruned by FishLeg. (default: [], meaning all modules are optimized/pruned by FishLeg)

  • initialization (str) – Initialization of weights (default: 'uniform')

  • warmup (int) – If warmup is zero, the default SGD warmup is used, where Q is initialized as a scaled identity matrix. If warmup is positive, the diagonal of Q is initialized as \(\frac{1}{g^2 + \gamma}\); in this case, warmup_data and warmup_loss should be provided for sampling gradients. (default: 0)

  • scale (float) – Initial scale \(\eta\) of the inverse Fisher Information Matrix approximation. With SGD warmup we suggest \(\eta=\gamma^{-1}\); if warmup is positive, scale should be 1. (default: 1)

  • device (str) – The device on which computations with PyTorch tensors are performed (default: 'cpu').
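The scheduling rule for update_aux_every and the positive-warmup initialization of the diagonal of Q can be illustrated as two small functions. This is a minimal sketch under stated assumptions, not the package's code: the helper names are hypothetical, the first auxiliary update is assumed to happen at iteration 0, a value of 0 is treated as invalid, and \(g^2\) is taken to be the mean squared gradient over the warm-up samples.

```python
import torch


def aux_updates_this_step(iteration: int, update_aux_every: int) -> int:
    """Hypothetical helper: how many auxiliary updates to run at an outer iteration.

    Positive k: one auxiliary update every k iterations (assumed to start at 0).
    Negative k: -k auxiliary updates at every iteration.
    """
    if update_aux_every > 0:
        return 1 if iteration % update_aux_every == 0 else 0
    if update_aux_every < 0:
        return -update_aux_every
    raise ValueError("update_aux_every must be non-zero in this sketch")


def warmup_diag(grad_samples: torch.Tensor, damping: float) -> torch.Tensor:
    """Hypothetical helper: diagonal initialization 1 / (g^2 + gamma).

    grad_samples has shape (num_samples, num_params); g^2 is assumed to be the
    mean squared gradient over the samples.
    """
    g_sq = grad_samples.pow(2).mean(dim=0)
    return 1.0 / (g_sq + damping)


# update_aux_every=10: auxiliary updates at iterations 0, 10 and 20 out of 25.
schedule = [aux_updates_this_step(i, 10) for i in range(25)]
print(sum(schedule))  # 3

# update_aux_every=-2: two auxiliary updates at every iteration.
print(aux_updates_this_step(7, -2))  # 2

grads = torch.tensor([[1.0, 0.0], [1.0, 2.0]])
print(warmup_diag(grads, damping=0.5))  # tensor([0.6667, 0.4000])
```

Larger squared gradients give smaller initial entries of Q, which is why this warm-up resembles an Adam-style preconditioner rather than a scaled identity.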

Example:
>>> import torch
>>> from torch import nn
>>> from FishLeg.fishleg import FishLeg
>>> from FishLeg.fishleg_likelihood import FixedGaussianLikelihood
>>>
>>> aux_loader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=100)
>>> train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=100)
>>>
>>> likelihood = FixedGaussianLikelihood(sigma=torch.tensor(1.0))
>>>
>>> def nll(model, data):
...     # data is an (inputs, targets) tuple
...     data_x, data_y = data
...     pred_y = model(data_x)
...     return likelihood.nll(pred_y, data_y)
>>>
>>> def draw(model, data_x):
...     # return the inputs together with targets sampled from p(y|f(x))
...     pred_y = model(data_x)
...     return (data_x, likelihood.draw(pred_y))
>>>
>>> model = nn.Sequential(
...     nn.Linear(2, 5),
...     nn.ReLU(),
...     nn.Linear(5, 1),
... )
>>>
>>> opt = FishLeg(
...     model,
...     draw,
...     nll,
...     aux_loader
... )
>>>
>>> for iteration, (data_x, data_y) in enumerate(train_loader):
...     opt.zero_grad()
...     pred_y = model(data_x)
...     loss = nn.MSELoss()(pred_y, data_y)
...     loss.backward()
...     opt.step()
...     if iteration % 10 == 0:
...         print(loss.detach())

init_model_aux(model: Module, module_names: List[str], skip_names: List[str], config=None) Union[Module, List]

Given a model to optimize, its parameters can be divided into two groups:

  1. those fixed as pre-trained;

  2. those to be optimized using FishLeg.

Modules in the second group are replaced with FishLeg modules.

Args:
model (torch.nn.Module, required): A model containing modules to replace with FishLeg modules, which add functionality related to the FishLeg algorithm.

Returns:
torch.nn.Module: the model with its modules replaced.

pretrain_fish(dataloader: DataLoader, loss: Callable[[Module, Tuple[Tensor, Tensor]], Tensor], iterations: int = 10000, difference: bool = False, verbose: bool = False, testloader: Optional[DataLoader] = None, batch_size: int = 500, fisher: bool = True) List
step(closure=None) None

Performs a single optimization step of FishLeg.

update_aux(train=True, fisher=True) None

Performs a single auxiliary parameter update using Adam, by minimizing the following objective:

\[nll(model, \theta + \epsilon Q(\lambda)g) + nll(model, \theta - \epsilon Q(\lambda)g) - 2\epsilon^2g^T Q(\lambda)g\]

where \(\theta\) denotes the model parameters, \(\lambda\) the auxiliary parameters, \(g\) a gradient vector, and \(\epsilon\) a small finite-difference step.
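To see why this objective learns an inverse curvature, expand the first two terms to second order: they sum to \(2\,nll(\theta) + \epsilon^2 g^T Q H Q g\), where \(H\) is the curvature. Up to the constant \(2\,nll(\theta)\), the objective is therefore \(2\epsilon^2(\frac{1}{2} g^T Q H Q g - g^T Q g)\), which is minimized when \(Q(\lambda)g \approx H^{-1}g\). The toy check below uses a quadratic stand-in for nll (a function of \(\theta\) only, with Hessian H playing the role of the damped Fisher), for which the expansion is exact; it is an illustration, not the package's implementation.

```python
import torch

torch.manual_seed(0)
n, dt = 4, torch.float64

# Quadratic stand-in for the NLL: nll(theta) = 0.5 * theta^T H theta, Hessian H.
A = torch.randn(n, n, dtype=dt)
H = A @ A.T + n * torch.eye(n, dtype=dt)


def nll(theta: torch.Tensor) -> torch.Tensor:
    return 0.5 * theta @ H @ theta


theta = torch.randn(n, dtype=dt)  # model parameters
g = torch.randn(n, dtype=dt)      # gradient vector
B = torch.randn(n, n, dtype=dt)
Q = B @ B.T                       # a symmetric positive semi-definite Q(lambda)
eps = 0.5  # any step works here because the stand-in NLL is exactly quadratic

v = Q @ g
objective = nll(theta + eps * v) + nll(theta - eps * v) - 2 * eps**2 * (g @ Q @ g)

# Second-order expansion, exact for a quadratic:
# objective - 2 nll(theta) = eps^2 * (g^T Q H Q g - 2 g^T Q g)
lhs = objective - 2 * nll(theta)
rhs = eps**2 * (g @ Q @ H @ Q @ g - 2 * (g @ Q @ g))
print(torch.allclose(lhs, rhs))  # True
```

For a non-quadratic nll the same identity holds up to terms of order \(\epsilon^4\), which is why \(\epsilon\) is kept small in practice.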

warmup_aux(dataloader: Optional[DataLoader] = None, loss: Optional[Callable[[Module, Tuple[Tensor, Tensor]], Tensor]] = None, scale: float = 1.0) None

Warms up the auxiliary parameters. If warmup is larger than zero, the initialization follows an approximate Adam; if warmup is zero, it follows SGD.

FishLeg.fishleg_layers module

class FishLeg.fishleg_layers.FishBatchNorm2d(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, init_scale=1.0, device=None, dtype=None)

Bases: BatchNorm2d, FishModule

Qv(v: Tuple, full=False)

\(Q(\lambda)\) is a positive definite matrix which will effectively estimate the inverse damped Fisher Information Matrix. Appropriate choices for \(Q\) should take into account the architecture of the model/module. It is usually parameterized as a positive definite Kronecker-factored block-diagonal matrix, with block sizes reflecting the layer structure of the neural networks.

Args:
aux (Dict, required): the auxiliary parameters \(\lambda\), a dictionary whose keys are the names of the auxiliary parameters and whose values are the corresponding auxiliary parameters of the module. These auxiliary parameters form \(Q(\lambda)\).

v (Tuple[Tensor, …], required): values of the original parameters, in an order that aligns with self.order, to multiply with \(Q(\lambda)\).

full (bool, optional): whether to use full inner and outer rescaling.

Returns:
Tuple[Tensor, …]: the computed \(Q(\lambda)v\) products, in the same order as self.order.

affine: bool
diagQ()
eps: float
momentum: float
num_features: int
track_running_stats: bool
class FishLeg.fishleg_layers.FishConv2d(in_channels: int, out_channels: int, kernel_size, stride=1, padding=0, dilation=1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', device=None)

Bases: Conv2d, FishModule

Qv(v: Tuple[Tensor, Optional[Tensor]], full: bool = False) Tuple[Tensor, Optional[Tensor]]

Inspired by the K-FAC treatment of convolutional layers by Grosse and Martens: a Kronecker product of sizes (out_channels ⊗ (in_channels_eff * k_size))
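A sketch of where the Kronecker factor sizes above could come from for a Conv2d kernel. It assumes that in_channels_eff means in_channels / groups and that k_size means kernel height times kernel width; the bias is ignored, and L, R are hypothetical random factors rather than learned auxiliary parameters. Flattening the 4-D kernel into an out_channels by (in_channels_eff * k_size) matrix lets a Kronecker-structured Q be applied with two small matrix products.

```python
import torch
from torch import nn

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=4, out_channels=6, kernel_size=3, groups=2, bias=False)

# PyTorch stores Conv2d weights as (out_channels, in_channels / groups, kH, kW).
out_channels, in_per_group, kh, kw = conv.weight.shape  # (6, 2, 3, 3)
in_eff_k = in_per_group * kh * kw                        # 2 * 9 = 18

# One row per output channel.
W = conv.weight.detach().reshape(out_channels, in_eff_k)

# Hypothetical Kronecker factors with the sizes named above.
L = torch.randn(out_channels, out_channels)
R = torch.randn(in_eff_k, in_eff_k)

# Apply Q without forming the 108 x 108 Kronecker matrix, then restore the kernel layout.
QW = (L @ L.T @ W @ R @ R.T).reshape(conv.weight.shape)
print(QW.shape)  # torch.Size([6, 2, 3, 3])
```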

bias: Optional[Tensor]
diagQ() Tensor

Uses the same maths as the FishLinear layer.

dilation: Tuple[int, ...]
groups: int
in_channels: int
kernel_size: Tuple[int, ...]
out_channels: int
output_padding: Tuple[int, ...]
padding: Union[str, Tuple[int, ...]]
padding_mode: str
stride: Tuple[int, ...]
transposed: bool
warmup(v: Optional[Tuple[Tensor, Tensor]] = None, init_scale: float = 1.0) None
weight: Tensor
class FishLeg.fishleg_layers.FishLayerNorm(normalized_shape, eps: float = 1e-05, elementwise_affine: bool = True, init_scale=1.0, device=None, dtype=None)

Bases: LayerNorm, FishModule

Qv(v: Tuple, full=False)

\(Q(\lambda)\) is a positive definite matrix which will effectively estimate the inverse damped Fisher Information Matrix. Appropriate choices for \(Q\) should take into account the architecture of the model/module. It is usually parameterized as a positive definite Kronecker-factored block-diagonal matrix, with block sizes reflecting the layer structure of the neural networks.

Args:
aux (Dict, required): the auxiliary parameters \(\lambda\), a dictionary whose keys are the names of the auxiliary parameters and whose values are the corresponding auxiliary parameters of the module. These auxiliary parameters form \(Q(\lambda)\).

v (Tuple[Tensor, …], required): values of the original parameters, in an order that aligns with self.order, to multiply with \(Q(\lambda)\).

full (bool, optional): whether to use full inner and outer rescaling.

Returns:
Tuple[Tensor, …]: the computed \(Q(\lambda)v\) products, in the same order as self.order.

diagQ()
elementwise_affine: bool
eps: float
normalized_shape: Tuple[int, ...]
class FishLeg.fishleg_layers.FishLinear(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None)

Bases: Linear, FishModule

Qg() Tuple[Tensor, Tensor]

Computes the Qg product more cheaply when the batch size is smaller than the parameter size. By the chain rule:

\[DW_i = g_i\hat{a}^T_{i-1}\]

where \(DW_i\) is the gradient of the parameters of the ith layer, \(g_i\) is the gradient with respect to the output of the ith layer, and \(\hat{a}_{i-1}\) is the input to the ith layer, i.e. the output of the (i-1)th layer.
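The speed-up can be sketched for a bias-free linear layer, assuming Q acts on a weight gradient as \(DW \mapsto L L^T\, DW\, R R^T\) with hypothetical factors L and R. Over a batch, \(DW = G^T A\) is a sum of per-example outer products, so the factors can be applied to the per-example vectors first and the batch contracted last; every matrix product then involves the batch dimension instead of a full layer dimension. The autograd check confirms the chain-rule identity; the rest is an illustration, not FishLinear.Qg itself.

```python
import torch

torch.manual_seed(0)
batch, n_in, n_out = 3, 50, 40  # batch much smaller than the layer width
dt = torch.float64

a = torch.randn(batch, n_in, dtype=dt)   # inputs to the layer, a_{i-1}
g = torch.randn(batch, n_out, dtype=dt)  # gradients w.r.t. the layer output, g_i
L = torch.randn(n_out, n_out, dtype=dt)  # hypothetical Kronecker factors of Q
R = torch.randn(n_in, n_in, dtype=dt)

# Chain rule: autograd's weight gradient equals sum_n g_n a_n^T = g^T a.
layer = torch.nn.Linear(n_in, n_out, bias=False).double()
layer(a).backward(g)  # g plays the role of dLoss/dOutput
DW = g.T @ a
print(torch.allclose(layer.weight.grad, DW))  # True

# Naive: form DW first, then apply Q as L L^T DW R R^T.
naive = L @ (L.T @ DW) @ R @ R.T

# Speed-up: apply the factors to the per-example vectors, contract the batch last.
left = L @ (L.T @ g.T)  # (n_out, batch)
right = (a @ R) @ R.T   # (batch, n_in)
fast = left @ right
print(torch.allclose(naive, fast))  # True
```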

Qv(v: Tuple[Tensor, Tensor], full: bool = False) Tuple[Tensor, Tensor]

For fully-connected layers, the default structure of \(Q\) is block-diagonal, with one block per layer:

\[Q_l = (R_lR_l^T \otimes L_lL_l^T)\]

where \(l\) denotes the l-th layer. The matrix \(R_l\) has size \((N_{l-1} + 1) \times (N_{l-1} + 1)\) while the matrix \(L_l\) has size \(N_l \times N_l\). The auxiliary parameters \(\lambda\) are represented by the matrices \(L_l, R_l\). A Kronecker form that adds full inner and outer diagonal rescaling is

\[Q_l = A_l(L_l \otimes R_l^T) D_l^2 (L_l^T \otimes R_l) A_l\]

where \(A_l\) and \(D_l\) are two diagonal matrices of the appropriate size.
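The plain block \(Q_l = R_lR_l^T \otimes L_lL_l^T\) never needs to be formed explicitly. With column-major vectorization, the standard identity \((A \otimes B)\,\mathrm{vec}(X) = \mathrm{vec}(B X A^T)\) gives \(Q_l\,\mathrm{vec}(W) = \mathrm{vec}(L_lL_l^T W R_lR_l^T)\), where W is the \(N_l \times (N_{l-1}+1)\) weight matrix with the bias appended as a column. The sketch below checks this numerically with random stand-ins for \(L_l, R_l\); the vec convention and bias layout are assumptions and may differ from those used inside FishLinear.Qv, and the rescaled form is not covered.

```python
import torch

torch.manual_seed(0)
n_out, n_in, dt = 3, 4, torch.float64
L = torch.randn(n_out, n_out, dtype=dt)
R = torch.randn(n_in + 1, n_in + 1, dtype=dt)  # +1 row/column for the bias
W = torch.randn(n_out, n_in + 1, dtype=dt)     # weight with the bias appended as a column


def vec(M: torch.Tensor) -> torch.Tensor:
    """Column-major vectorization: stack the columns of M."""
    return M.T.reshape(-1)


Q = torch.kron(R @ R.T, L @ L.T)       # explicit block, 15 x 15 here
explicit = Q @ vec(W)
factored = vec(L @ L.T @ W @ R @ R.T)  # same product without forming Q
print(torch.allclose(explicit, factored))  # True
```

The factored form costs two small matrix products per layer instead of a matrix-vector product with a \(N_l(N_{l-1}+1)\)-square matrix.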

diagQ() Tuple

The Q matrix defines the inverse Fisher approximation as follows:

\[Q_l = (R_lR_l^T \otimes L_lL_l^T)\]

where \(l\) denotes the l-th layer. The matrix \(R_l\) has size \((N_{l-1} + 1) \times (N_{l-1} + 1)\) while the matrix \(L_l\) has size \(N_l \times N_l\). The auxiliary parameters \(\lambda\) are represented by the matrices \(L_l, R_l\).

The diagonal of this matrix is therefore

\[\text{diag}(Q_l) = \text{diag}(R_l R_l^T) \otimes \text{diag}(L_l L_l^T)\]

where \(\text{diag}(R_l R_l^T)_i = \sum_j (R_l)_{ij}^2\) is obtained by summing the squared entries of each row of \(R_l\) over its columns (likewise for \(L_l\)), and \(\otimes\) remains the Kronecker product, now taken between two vectors.
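A quick numerical check of the diagonal formula: \(\text{diag}(R_lR_l^T)\) is computed from row sums of squares, and the Kronecker product of the two vectors is compared with the diagonal of the explicitly formed block. L and R are random stand-ins for the auxiliary parameters.

```python
import torch

torch.manual_seed(0)
L = torch.randn(3, 3, dtype=torch.float64)
R = torch.randn(5, 5, dtype=torch.float64)

# diag(R R^T)_i = sum_j R_ij^2: sum the squared entries along each row.
diag_RRt = R.pow(2).sum(dim=1)
diag_LLt = L.pow(2).sum(dim=1)
diag_Q = torch.kron(diag_RRt, diag_LLt)  # Kronecker product of two vectors

explicit = torch.diagonal(torch.kron(R @ R.T, L @ L.T))
print(torch.allclose(diag_Q, explicit))  # True
```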

in_features: int
out_features: int
save_layer_grad_output(grad_output: Tuple[Tensor, ...]) None
save_layer_input(input_: List[Tensor]) None
warmup(v: Optional[Tuple[Tensor, Tensor]] = None, batch_speedup: bool = False, init_scale: float = 1.0) None
weight: Tensor

FishLeg.fishleg_likelihood module

class FishLeg.fishleg_likelihood.BernoulliLikelihood(device: str = 'cpu')

Bases: FishLikelihood

The Bernoulli likelihood used for classification. Using the standard Normal CDF \(\Phi(x)\) and the identity \(\Phi(-x) = 1-\Phi(x)\), we can write the likelihood for labels \(y \in \{-1, 1\}\) as:

\[p(y|f(x))=\Phi(yf(x))\]

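A minimal sketch of the probit likelihood above using torch.distributions.Normal, assuming labels in \(\{-1, 1\}\) and an NLL summed over examples. probit_nll and probit_draw are hypothetical names; this is not the package's BernoulliLikelihood implementation.

```python
import torch
from torch.distributions import Normal

std_normal = Normal(0.0, 1.0)


def probit_nll(preds: torch.Tensor, observations: torch.Tensor) -> torch.Tensor:
    # observations in {-1, +1}; p(y | f) = Phi(y * f), summed over examples
    return -torch.log(std_normal.cdf(observations * preds)).sum()


def probit_draw(preds: torch.Tensor) -> torch.Tensor:
    # P(y = +1) = Phi(f); map {0, 1} Bernoulli draws to {-1, +1}
    ones = torch.bernoulli(std_normal.cdf(preds))
    return 2 * ones - 1


f = torch.tensor([0.0, 1.5, -2.0])
# Phi(-x) = 1 - Phi(x): the probabilities of y = +1 and y = -1 sum to one.
total = std_normal.cdf(f) + std_normal.cdf(-f)
labels = probit_draw(f)
```

For very negative \(y f(x)\) the log of the CDF can underflow; a production version would use a numerically stable log-CDF.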
draw(preds: Tensor) Tensor

Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)

Parameters

preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

nll(preds: Tensor, observations: Tensor) Tensor

Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)

Parameters
  • observations (torch.Tensor) – Values of \(y\).

  • preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

Return type

torch.Tensor

class FishLeg.fishleg_likelihood.FishLikelihood

Bases: object

A likelihood in FishLeg specifies a probabilistic model that maps latent function values \(f(\mathbf X)\) to observed labels \(y\).

For example, in the case of regression, a Gaussian likelihood can be chosen, as

\[y(\mathbf x) = f(\mathbf x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2}_{n} \mathbf I)\]

As for the case of classification, a Bernoulli distribution can be chosen

\[\begin{split}y(\mathbf x) = \begin{cases} 1 & \text{w/ probability} \:\: \sigma(f(\mathbf x)) \\ 0 & \text{w/ probability} \:\: 1-\sigma(f(\mathbf x)) \end{cases}\end{split}\]
abstract draw(preds, **kwargs)

Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)

Parameters

preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

get_parameters() List

Returns a list of learnable parameters.

abstract nll(preds, observations, **kwargs)

Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)

Parameters
  • observations (torch.Tensor) – Values of \(y\).

  • preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

Return type

torch.Tensor

class FishLeg.fishleg_likelihood.FixedGaussianLikelihood(sigma: Tensor, device: str = 'cpu')

Bases: FishLikelihood

The standard likelihood for regression, assuming fixed (known) heteroscedastic noise:

\[y = f(x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2})\]

Parameters

sigma (torch.Tensor) – Known observation standard deviation for each example.

draw(preds: Tensor) Tensor

Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)

Parameters

preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

property get_variance: Tensor
nll(preds: Tensor, observations: Tensor) Tensor

Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)

Parameters
  • observations (torch.Tensor) – Values of \(y\).

  • preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

Return type

torch.Tensor
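Per example, the Gaussian negative log-likelihood is \(\frac{1}{2}\left(\frac{y - f(x)}{\sigma}\right)^2 + \log\sigma + \frac{1}{2}\log 2\pi\). The sketch below, with hypothetical function names and a sum over examples assumed, checks this formula against torch.distributions.Normal.log_prob and draws samples via \(y = f(x) + \epsilon\). The same formula underlies GaussianLikelihood, where sigma would additionally be a learnable tensor.

```python
import math

import torch
from torch.distributions import Normal


def gaussian_nll(preds, observations, sigma):
    # -log N(y; f(x), sigma^2), summed over examples
    return (0.5 * ((observations - preds) / sigma) ** 2
            + torch.log(sigma)
            + 0.5 * math.log(2 * math.pi)).sum()


def gaussian_draw(preds, sigma):
    # y = f(x) + eps, eps ~ N(0, sigma^2), one sigma per example
    return preds + sigma * torch.randn_like(preds)


preds = torch.tensor([0.0, 1.0, 2.0])
obs = torch.tensor([0.5, 0.5, 2.5])
sigma = torch.tensor([1.0, 0.5, 2.0])  # heteroscedastic: one std per example
ref = -Normal(preds, sigma).log_prob(obs).sum()
print(torch.allclose(gaussian_nll(preds, obs, sigma), ref))  # True
```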

class FishLeg.fishleg_likelihood.GaussianLikelihood(sigma: Tensor, device: str = 'cpu')

Bases: FishLikelihood

The standard likelihood for regression with heteroscedastic noise, where the noise standard deviation is learned during training:

\[y = f(x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2})\]

Parameters

sigma (torch.Tensor) – standard deviation for each example; also to be learned during training.

Qv(v) List
draw(preds: Tensor) Tensor

Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)

Parameters

preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

get_aux_parameters() List
get_parameters() List

Returns a list of learnable parameters.

init_aux(init_scale) None
nll(preds: Tensor, observations: Tensor) Tensor

Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)

Parameters
  • observations (torch.Tensor) – Values of \(y\).

  • preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

Return type

torch.Tensor

class FishLeg.fishleg_likelihood.SoftMaxLikelihood(device: str = 'cpu')

Bases: FishLikelihood

draw(preds: Tensor) Tensor

Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)

Parameters

preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

nll(preds: Tensor, observations: Tensor) Tensor

Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)

Parameters
  • observations (torch.Tensor) – Values of \(y\).

  • preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

Return type

torch.Tensor
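SoftMaxLikelihood has no description here. Assuming it is a categorical likelihood over softmax(preds), with preds being unnormalized logits of shape (batch, classes) and observations being class indices, draw and nll could look like the sketch below. This reading is inferred from the class name only; the function names and the sum reduction are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical


def softmax_nll(preds: torch.Tensor, observations: torch.Tensor) -> torch.Tensor:
    # preds: (batch, classes) logits; observations: (batch,) class indices
    return F.cross_entropy(preds, observations, reduction="sum")


def softmax_draw(preds: torch.Tensor) -> torch.Tensor:
    # sample one class index per example from softmax(preds)
    return Categorical(logits=preds).sample()


logits = torch.tensor([[2.0, 0.0, -1.0], [0.0, 0.0, 0.0]])
labels = torch.tensor([0, 2])
manual = -(torch.log_softmax(logits, dim=1)[torch.arange(2), labels]).sum()
print(torch.allclose(softmax_nll(logits, labels), manual))  # True
```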

FishLeg.utils module

FishLeg.utils.get_named_layers_by_regex(module: Module, param_names: List[str], params_strict: bool = False) List[NamedLayerParam]
Parameters
  • module – the module to get the matching layers and params from

  • param_names – a list of names or regex patterns to match with full parameter paths. Regex patterns must be specified with the prefix ‘re:’

  • params_strict – if True, this function raises an exception if any name or regex in param_names does not match at least one parameter

Returns

a list of NamedLayerParam tuples whose full parameter names in the given module match one of the given regex patterns or parameter names
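A hypothetical sketch of the matching rule described above, returning plain (layer_name, layer, param_name) tuples rather than the package's NamedLayerParam. It assumes that 're:' patterns are applied with re.match against the full parameter paths from named_parameters(), which may differ from the real function.

```python
import re
from typing import List, Tuple

import torch.nn as nn


def match_param_names(
    module: nn.Module, param_names: List[str], params_strict: bool = False
) -> List[Tuple[str, nn.Module, str]]:
    """Hypothetical stand-in for get_named_layers_by_regex."""
    layers = dict(module.named_modules())
    matches, unmatched = [], set(param_names)
    for full_name, _ in module.named_parameters():
        for pattern in param_names:
            if pattern.startswith("re:"):
                hit = re.match(pattern[3:], full_name) is not None
            else:
                hit = pattern == full_name
            if hit:
                layer_name, _, param_name = full_name.rpartition(".")
                matches.append((layer_name, layers[layer_name], param_name))
                unmatched.discard(pattern)
                break
    if params_strict and unmatched:
        raise ValueError(f"no parameters matched: {sorted(unmatched)}")
    return matches


model = nn.Sequential(nn.Linear(2, 5), nn.ReLU(), nn.Linear(5, 1))
names = [f"{layer}.{param}" for layer, _, param in match_param_names(model, ["re:.*weight$"])]
print(names)  # ['0.weight', '2.weight']
```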

FishLeg.utils.recursive_getattr(obj, attr)
FishLeg.utils.recursive_setattr(obj, attr, value)
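recursive_getattr and recursive_setattr are undocumented. Assuming they resolve dotted attribute paths such as "0.1" on nested modules, a common pattern when swapping submodules in place, they could be sketched as below. getattr_path and setattr_path are hypothetical stand-ins, not the package's code.

```python
import functools

import torch.nn as nn


def getattr_path(obj, attr: str):
    # "a.b.c" -> getattr(getattr(getattr(obj, "a"), "b"), "c")
    return functools.reduce(getattr, attr.split("."), obj)


def setattr_path(obj, attr: str, value) -> None:
    # resolve the parent "a.b", then set its attribute "c"
    parent, _, name = attr.rpartition(".")
    target = getattr_path(obj, parent) if parent else obj
    setattr(target, name, value)


model = nn.Sequential(nn.Sequential(nn.Linear(2, 5), nn.ReLU()), nn.Linear(5, 1))
print(getattr_path(model, "0.0.out_features"))  # 5
setattr_path(model, "0.1", nn.Tanh())            # replace the inner ReLU
print(type(model[0][1]).__name__)                # Tanh
```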
FishLeg.utils.update_dict(replace: Module, module: Module) Module

Module contents