FishLeg package
Submodules
FishLeg.fishleg module
- class FishLeg.fishleg.FishLeg(model: Module, draw: Callable[[Module, Tensor], Tuple[Tensor, Tensor]], nll: Callable[[Module, Tuple[Tensor, Tensor]], Tensor], aux_dataloader: DataLoader, likelihood: Optional[FishLikelihood] = None, fish_lr: float = 0.05, damping: float = 0.5, weight_decay: float = 1e-05, beta: float = 0.9, update_aux_every: int = 10, aux_lr: float = 0.0001, aux_betas: Tuple[float, float] = (0.9, 0.999), aux_eps: float = 1e-08, num_steps=None, batch_speedup: bool = False, full: bool = True, normalization: bool = False, fine_tune: bool = False, module_names: List[str] = [], skip_names: List[str] = [], initialization: str = 'uniform', scale: float = 1.0, warmup: int = 0, warmup_data: Optional[DataLoader] = None, warmup_loss: Optional[Callable] = None, device: str = 'cpu', config=None, verbose=False)
Bases: Optimizer
Implements the FishLeg algorithm, as described in https://openreview.net/forum?id=c9lAOPvQHS.
- Parameters
model (torch.nn.Module) – a pytorch neural network module, can be nested in a tree structure
draw (Callable[[nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]) – Sampling function that takes a model \(f\) and input data \(\mathbf X\), and returns \((\mathbf X, \mathbf y)\), where \(\mathbf y\) is sampled from the conditional distribution \(p(\mathbf y|f(\mathbf X))\)
nll (Callable[[nn.Module, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]) – A function that takes a model and data, and evaluates the negative log-likelihood.
aux_dataloader (torch.utils.data.DataLoader) – A data loader that supplies batches of the configured size for the auxiliary updates.
likelihood (FishLikelihood) – A FishLeg likelihood, with a Qv method if any of its parameters are learnable.
fish_lr (float) – Learning rate for the parameters of the input model, which are optimized using FishLeg (default: 5e-2)
damping (float) – Static damping applied to Fisher matrix, \(\gamma\), for stability when FIM becomes near-singular. (default: 5e-1)
weight_decay (float) – L2 penalty on weights (default: 1e-5)
beta (float) – coefficient for running averages of gradient (default: 0.9)
update_aux_every (int) – Number of iterations after which an auxiliary update is executed. If negative, -update_aux_every auxiliary updates are run in each outer iteration. (default: 10)
aux_lr (float) – Learning rate for the auxiliary parameters, which are optimized using Adam (default: 1e-4)
aux_betas (Tuple[float, float]) – Coefficients used for computing running averages of gradient and its square for auxiliary parameters (default: (0.9, 0.999))
aux_eps (float) – Term added to the denominator to improve numerical stability for auxiliary parameters (default: 1e-8)
batch_speedup (bool) – Whether to use speed-up Qv product (default: False)
full (bool) – Whether to use full inner and outer diagonal rescaling for the block Kronecker approximation of Q. (default: True)
normalization (bool) – Whether to use normalization on gradients when calculating the auxiliary loss, this is important to learn about curvature even when gradients are small (default: False)
fine_tune (bool) – Whether to use Fisher as preconditioner of pretrained tasks, and fine-tune on a downstream task. If True, Q will be fixed and continual learning will be performed (default: False)
module_names (List) – A list of the names of the modules to be optimized/pruned by FishLeg. (default: [], meaning all modules are optimized/pruned by FishLeg)
initialization (string) – Initialization of weights (default: uniform)
warmup (int) – If warmup is zero, the default SGD warmup will be used, where Q is initialized as a scaled identity matrix. If warmup is positive, the diagonal of Q will be initialized as \(\frac{1}{g^2 + \gamma}\); and in this case, warmup_data and warmup_loss should be provided for sampling of gradients.
scale (float) – Helps specify the initial scale of the inverse Fisher Information Matrix approximation. If using SGD warmup, we suggest \(\eta=\gamma^{-1}\). If warmup is positive, scale should be 1. (default: 1)
device (str) – The device where calculations will be performed using PyTorch Tensors.
- Example:
>>> import torch
>>> import torch.nn as nn
>>> from FishLeg.fishleg import FishLeg
>>> from FishLeg.fishleg_likelihood import FixedGaussianLikelihood
>>>
>>> # train_data is the user's own dataset of (x, y) pairs
>>> aux_loader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=100)
>>> train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=100)
>>>
>>> likelihood = FixedGaussianLikelihood(sigma=1.0)
>>>
>>> # nll takes the model and an (x, y) tuple, matching the documented signature
>>> def nll(model, data):
>>>     data_x, data_y = data
>>>     pred_y = model.forward(data_x)
>>>     return likelihood.nll(pred_y, data_y)
>>>
>>> # draw returns (x, y) with y sampled from p(y | f(x))
>>> def draw(model, data_x):
>>>     pred_y = model.forward(data_x)
>>>     return (data_x, likelihood.draw(pred_y))
>>>
>>> model = nn.Sequential(
>>>     nn.Linear(2, 5),
>>>     nn.ReLU(),
>>>     nn.Linear(5, 1),
>>> )
>>>
>>> opt = FishLeg(
>>>     model,
>>>     draw,
>>>     nll,
>>>     aux_loader
>>> )
>>>
>>> for iteration, (data_x, data_y) in enumerate(train_loader):
>>>     opt.zero_grad()
>>>     pred_y = model(data_x)
>>>     loss = nn.MSELoss()(pred_y, data_y)
>>>     loss.backward()
>>>     opt.step()
>>>     if iteration % 10 == 0:
>>>         print(loss.detach())
- init_model_aux(model: Module, module_names: List[str], skip_names: List[str], config=None) Union[Module, List]
Given a model to optimize, its parameters can be divided into:
those fixed as pre-trained;
those to be optimized using FishLeg.
Modules in the second group are replaced with FishLeg modules.
- Args:
- model (torch.nn.Module, required): A model containing modules to replace with FishLeg modules, which carry extra functionality related to the FishLeg algorithm.
- Returns:
torch.nn.Module, the model with the replaced modules.
- pretrain_fish(dataloader: DataLoader, loss: Callable[[Module, Tuple[Tensor, Tensor]], Tensor], iterations: int = 10000, difference: bool = False, verbose: bool = False, testloader: Optional[DataLoader] = None, batch_size: int = 500, fisher: bool = True) List
- step(closure=None) None
Performs a single optimization step of FishLeg.
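One reading of the update_aux_every rule, as it is described for the constructor, is sketched below in plain Python. The helper name `aux_updates_at` is hypothetical, and both the zero-based iteration count and the treatment of a zero value are assumptions; the sketch only encodes the documented rule and is not the library's implementation.

```python
def aux_updates_at(iteration: int, update_aux_every: int) -> int:
    """Number of auxiliary updates to run at a given outer iteration.

    Hypothetical helper: a positive value means one auxiliary update every
    `update_aux_every` iterations; a negative value means `-update_aux_every`
    auxiliary updates in every outer iteration.
    """
    if update_aux_every > 0:
        return 1 if iteration % update_aux_every == 0 else 0
    if update_aux_every < 0:
        return -update_aux_every
    return 0  # assumption: zero disables auxiliary updates


# With the default of 10, iterations 0, 10 and 20 each trigger one update.
schedule = [aux_updates_at(t, 10) for t in range(21)]
```

With a negative setting such as -3, every call to step would be accompanied by three auxiliary updates under this reading.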
- update_aux(train=True, fisher=True) None
Performs a single auxiliary parameter update using Adam, by minimizing the following objective:
\[nll(model, \theta + \epsilon Q(\lambda)g) + nll(model, \theta - \epsilon Q(\lambda)g) - 2\epsilon^2g^T Q(\lambda)g\]
where \(\theta\) denotes the parameters of the model and \(\lambda\) denotes the auxiliary parameters.
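To see why minimizing this objective pushes \(Q\) towards an inverse curvature, the scalar sketch below evaluates it for a quadratic negative log-likelihood with curvature `curvature`. Everything here (the function names, the grid search in place of Adam, the absence of damping) is an illustrative assumption rather than the library's code; for this toy problem the minimizer is \(q = 1/\text{curvature}\).

```python
def aux_objective(q, theta, g, eps, nll):
    """Scalar version of nll(theta + eps*q*g) + nll(theta - eps*q*g) - 2*eps^2*g*q*g."""
    step = eps * q * g
    return nll(theta + step) + nll(theta - step) - 2.0 * eps ** 2 * g * q * g


curvature = 4.0  # second derivative of the toy negative log-likelihood


def quadratic_nll(theta):
    return 0.5 * curvature * theta ** 2


theta, g, eps = 0.3, 1.7, 1e-2

# Brute-force search over candidate values of q in (0, 1]; the library
# uses Adam on the auxiliary parameters instead.
candidates = [i / 1000 for i in range(1, 1001)]
best_q = min(candidates, key=lambda q: aux_objective(q, theta, g, eps, quadratic_nll))
```

The two perturbed evaluations act as a finite-difference estimate of the curvature along \(Qg\), so no explicit Fisher matrix is needed.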
- warmup_aux(dataloader: Optional[DataLoader] = None, loss: Optional[Callable[[Module, Tuple[Tensor, Tensor]], Tensor]] = None, scale: float = 1.0) None
Warms up the auxiliary parameters. If warmup is larger than zero, the warm-up follows approximate Adam; if warmup is zero, it follows SGD.
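The two warm-up modes can be sketched for a diagonal \(Q\) as follows, using the initialization described for the warmup and scale parameters. The function name `init_diag_q` is hypothetical, and averaging the squared gradients over the first `warmup` samples is an assumption about how \(g^2\) is estimated.

```python
from typing import List, Sequence


def init_diag_q(grad_samples: Sequence[Sequence[float]], damping: float,
                warmup: int, scale: float = 1.0) -> List[float]:
    """Initial diagonal of Q for one flat parameter vector (hypothetical helper)."""
    n_params = len(grad_samples[0])
    if warmup <= 0:
        # SGD-style warm-up: Q starts as a scaled identity matrix.
        return [scale] * n_params
    # Adam-style warm-up: 1 / (g^2 + damping), with g^2 averaged over samples.
    used = grad_samples[:warmup]
    mean_sq = [sum(g[i] ** 2 for g in used) / len(used) for i in range(n_params)]
    return [1.0 / (m + damping) for m in mean_sq]


adam_like = init_diag_q([[1.0, 3.0]], damping=0.5, warmup=1)
sgd_like = init_diag_q([[1.0, 3.0]], damping=0.5, warmup=0, scale=2.0)
```

Here `scale=2.0` follows the documented suggestion of setting the scale to the inverse of the damping when SGD warm-up is used.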
FishLeg.fishleg_layers module
- class FishLeg.fishleg_layers.FishBatchNorm2d(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, init_scale=1.0, device=None, dtype=None)
Bases: BatchNorm2d, FishModule
- Qv(v: Tuple, full=False)
\(Q(\lambda)\) is a positive definite matrix which will effectively estimate the inverse damped Fisher Information Matrix. Appropriate choices for \(Q\) should take into account the architecture of the model/module. It is usually parameterized as a positive definite Kronecker-factored block-diagonal matrix, with block sizes reflecting the layer structure of the neural networks.
- Args:
- aux: (Dict, required): auxiliary parameters \(\lambda\), a dictionary whose keys are the names of the auxiliary parameters and whose values are the auxiliary parameters of the module. These auxiliary parameters form \(Q(\lambda)\).
- v: (Tuple[Tensor, …], required): Values of the original parameters, in an order that aligns with self.order, to multiply with \(Q(\lambda)\).
- full: (bool, optional): whether to use full inner and outer re-scaling.
- Returns:
- Tuple[Tensor, …]: The calculated \(Q(\lambda)v\) products, in the same order as self.order.
- affine: bool
- diagQ()
- eps: float
- momentum: float
- num_features: int
- track_running_stats: bool
- class FishLeg.fishleg_layers.FishConv2d(in_channels: int, out_channels: int, kernel_size, stride=1, padding=0, dilation=1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', device=None)
Bases: Conv2d, FishModule
- Qv(v: Tuple[Tensor, Optional[Tensor]], full: bool = False) Tuple[Tensor, Optional[Tensor]]
Inspired by KFAC’s conv2D layer by Grosse and Martens: Kronecker product of sizes (out_channels ⊗ (in_channels_eff * k_size))
- bias: Optional[Tensor]
- diagQ() Tensor
Uses similar maths to the Linear layer.
- dilation: Tuple[int, ...]
- groups: int
- in_channels: int
- kernel_size: Tuple[int, ...]
- out_channels: int
- output_padding: Tuple[int, ...]
- padding: Union[str, Tuple[int, ...]]
- padding_mode: str
- stride: Tuple[int, ...]
- transposed: bool
- warmup(v: Optional[Tuple[Tensor, Tensor]] = None, init_scale: float = 1.0) None
- weight: Tensor
- class FishLeg.fishleg_layers.FishLayerNorm(normalized_shape, eps: float = 1e-05, elementwise_affine: bool = True, init_scale=1.0, device=None, dtype=None)
Bases: LayerNorm, FishModule
- Qv(v: Tuple, full=False)
\(Q(\lambda)\) is a positive definite matrix which will effectively estimate the inverse damped Fisher Information Matrix. Appropriate choices for \(Q\) should take into account the architecture of the model/module. It is usually parameterized as a positive definite Kronecker-factored block-diagonal matrix, with block sizes reflecting the layer structure of the neural networks.
- Args:
- aux: (Dict, required): auxiliary parameters \(\lambda\), a dictionary whose keys are the names of the auxiliary parameters and whose values are the auxiliary parameters of the module. These auxiliary parameters form \(Q(\lambda)\).
- v: (Tuple[Tensor, …], required): Values of the original parameters, in an order that aligns with self.order, to multiply with \(Q(\lambda)\).
- full: (bool, optional): whether to use full inner and outer re-scaling.
- Returns:
- Tuple[Tensor, …]: The calculated \(Q(\lambda)v\) products, in the same order as self.order.
- diagQ()
- elementwise_affine: bool
- eps: float
- normalized_shape: Tuple[int, ...]
- class FishLeg.fishleg_layers.FishLinear(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None)
Bases: Linear, FishModule
- Qg() Tuple[Tensor, Tensor]
Speeds up the Qg product when the batch size is smaller than the parameter size. By the chain rule:
\[DW_i = g_i\hat{a}^T_{i-1}\]
where \(DW_i\) is the gradient of the parameters of the ith layer, \(g_i\) is the gradient w.r.t. the output of the ith layer, and \(\hat{a}_{i-1}\) is the input to the ith layer, i.e. the output of the (i-1)th layer.
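One way to read this speed-up: because the per-example weight gradient is an outer product, a Kronecker-factored \(Q\) can be applied to the two vectors separately instead of to the full gradient matrix. The sketch below checks this identity on small integer matrices in plain Python; the factor values and helper names are illustrative assumptions, and the library's Qg may organize the computation differently.

```python
def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def outer(u, v):
    return [[x * y for y in v] for x in u]


def gram(m):
    """Return m m^T."""
    return matmul(m, [list(col) for col in zip(*m)])


L = [[1, 0], [2, 1]]                      # output-side factor (N_l x N_l)
R = [[1, 0, 0], [1, 2, 0], [0, 1, 1]]     # input-side factor
g = [3, -1]      # gradient w.r.t. the layer output, for one example
a = [1, 2, 1]    # layer input, with a trailing 1 standing in for the bias

# Dense route: form DW = g a^T, then multiply by both factors.
dense = matmul(matmul(gram(L), outer(g, a)), gram(R))

# Fast route: transform g and a separately, then take one outer product.
fast = outer(matvec(gram(L), g), matvec(gram(R), a))
```

For a batch, the fast route sums one such outer product per example, which is cheaper than the dense route whenever the batch is small relative to the layer width.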
- Qv(v: Tuple[Tensor, Tensor], full: bool = False) Tuple[Tensor, Tensor]
For fully-connected layers, the default structure of \(Q\) as a block-diagonal matrix is
\[Q_l = (R_lR_l^T \otimes L_lL_l^T)\]
where \(l\) denotes the l-th layer. The matrix \(R_l\) has size \((N_{l-1} + 1) \times (N_{l-1} + 1)\) while the matrix \(L_l\) has size \(N_l \times N_l\). The auxiliary parameters \(\lambda\) are represented by the matrices \(L_l, R_l\). A Kronecker form that introduces full inner and outer diagonal rescaling structure is
\[Q_l = A_l(L_l \otimes R_l^T) D_l^2 (L_l^T \otimes R_l) A_l\]
where \(A_l\) and \(D_l\) are two diagonal matrices of the appropriate size.
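The default Kronecker structure means that \(Q_l v\) never requires building \(Q_l\) explicitly: with column-major vectorization, \((R_lR_l^T \otimes L_lL_l^T)\,\text{vec}(V) = \text{vec}(L_lL_l^T\, V\, R_lR_l^T)\). The plain-Python sketch below verifies this on small integer matrices. The helper names and example values are assumptions, and the diagonal rescaling matrices \(A_l\) and \(D_l\) of the full form are left out.

```python
def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def kron(a, b):
    """Kronecker product of two matrices given as nested lists."""
    return [[x * y for x in row_a for y in row_b] for row_a in a for row_b in b]


def vec(m):
    """Stack the columns of m into one flat list (column-major order)."""
    return [m[i][j] for j in range(len(m[0])) for i in range(len(m))]


def qv(L, R, V):
    """Apply Q = (R R^T kron L L^T) to V without building Q."""
    return matmul(matmul(matmul(L, transpose(L)), V), matmul(R, transpose(R)))


L = [[1, 0], [2, 1]]                    # N_l x N_l
R = [[1, 0, 0], [1, 2, 0], [0, 1, 1]]   # (N_{l-1} + 1) x (N_{l-1} + 1)
V = [[1, 2, 3], [4, 5, 6]]              # weights with the bias as last column

# Explicit route: build the 6 x 6 matrix Q and multiply it with vec(V).
Q = kron(matmul(R, transpose(R)), matmul(L, transpose(L)))
explicit = [sum(q * x for q, x in zip(row, vec(V))) for row in Q]

# Implicit route: two small matrix products, then vectorize.
implicit = vec(qv(L, R, V))
```

The implicit route costs two small matrix products per layer, whereas the explicit matrix grows with the square of the layer's parameter count.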
- diagQ() Tuple
The Q matrix defines the inverse Fisher approximation as below:
\[Q_l = (R_lR_l^T \otimes L_lL_l^T)\]
where \(l\) denotes the l-th layer. The matrix \(R_l\) has size \((N_{l-1} + 1) \times (N_{l-1} + 1)\) while the matrix \(L_l\) has size \(N_l \times N_l\). The auxiliary parameters \(\lambda\) are represented by the matrices \(L_l, R_l\).
The diagonal of this matrix is therefore calculated by
\[\text{diag}(Q_l) = \text{diag}(R_l R_l^T) \otimes \text{diag}(L_l L_l^T)\]
where each \(\text{diag}\) term is obtained by summing the squared entries of the corresponding factor over its columns, and \(\otimes\) remains the Kronecker product.
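The diagonal formula can be checked numerically without forming \(Q_l\): each entry of \(\text{diag}(R_lR_l^T)\) is a row-wise sum of squares of \(R_l\), and likewise for \(L_l\). The following plain-Python sketch compares that shortcut with the diagonal of the explicit Kronecker product; the helper names and the small integer factors are illustrative assumptions.

```python
def row_sq_sums(m):
    """diag(m m^T): for each row, sum the squared entries over the columns."""
    return [sum(x * x for x in row) for row in m]


def kron_vec(u, v):
    """Kronecker product of two vectors."""
    return [x * y for x in u for y in v]


def gram(m):
    """Return m m^T."""
    return [[sum(x * y for x, y in zip(r1, r2)) for r2 in m] for r1 in m]


def kron(a, b):
    """Kronecker product of two matrices given as nested lists."""
    return [[x * y for x in row_a for y in row_b] for row_a in a for row_b in b]


L = [[1, 0], [2, 1]]
R = [[1, 0, 0], [1, 2, 0], [0, 1, 1]]

# Shortcut: only the row-wise squared sums of each factor are needed.
diag_q = kron_vec(row_sq_sums(R), row_sq_sums(L))

# Reference: build Q explicitly and read off its diagonal.
Q = kron(gram(R), gram(L))
explicit_diag = [Q[i][i] for i in range(len(Q))]
```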
- in_features: int
- out_features: int
- save_layer_grad_output(grad_output: Tuple[Tensor, ...]) None
- save_layer_input(input_: List[Tensor]) None
- warmup(v: Optional[Tuple[Tensor, Tensor]] = None, batch_speedup: bool = False, init_scale: float = 1.0) None
- weight: Tensor
FishLeg.fishleg_likelihood module
- class FishLeg.fishleg_likelihood.BernoulliLikelihood(device: str = 'cpu')
Bases: FishLikelihood
The Bernoulli likelihood used for classification. Using the standard Normal CDF \(\Phi(x)\) and the identity \(\Phi(-x) = 1-\Phi(x)\), we can write the likelihood as:
\[p(y|f(x))=\Phi(yf(x))\]
- draw(preds: Tensor) Tensor
Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)
- Parameters
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- nll(preds: Tensor, observations: Tensor) Tensor
Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)
- Parameters
observations (torch.Tensor) – Values of \(y\).
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- Return type
torch.Tensor
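The Bernoulli likelihood \(p(y|f(x))=\Phi(yf(x))\) described for this class can be sketched in plain Python as below. The sketch assumes labels \(y \in \{-1, +1\}\), which is what the identity \(\Phi(-x) = 1-\Phi(x)\) implies, and it sums the loss over examples; the function names are hypothetical, and the class itself may use a different label encoding or reduction.

```python
import math
import random


def normal_cdf(x: float) -> float:
    """Standard Normal CDF, written with the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def probit_nll(preds, observations):
    """Sum of -log Phi(y * f) over examples, with labels y in {-1, +1}."""
    return -sum(math.log(normal_cdf(y * f)) for f, y in zip(preds, observations))


def probit_draw(preds, rng):
    """Sample y = +1 with probability Phi(f), otherwise y = -1."""
    return [1 if rng.random() < normal_cdf(f) else -1 for f in preds]


rng = random.Random(0)          # seeded for reproducibility
preds = [0.0, 2.0, -2.0]
samples = probit_draw(preds, rng)
loss = probit_nll(preds, [1, 1, -1])
```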
- class FishLeg.fishleg_likelihood.FishLikelihood
Bases: object
A likelihood in FishLeg specifies a probabilistic model that defines the mapping from latent function values \(f(\mathbf X)\) to observed labels \(y\).
For example, in the case of regression, a Gaussian likelihood can be chosen:
\[y(\mathbf x) = f(\mathbf x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2}_{n} \mathbf I)\]
For classification, a Bernoulli distribution can be chosen:
\[\begin{split}y(\mathbf x) = \begin{cases} 1 & \text{w/ probability} \:\: \sigma(f(\mathbf x)) \\ 0 & \text{w/ probability} \:\: 1-\sigma(f(\mathbf x)) \end{cases}\end{split}\]
- abstract draw(preds, **kwargs)
Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)
- Parameters
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- get_parameters() List
Returns a list of learnable parameters.
- abstract nll(preds, observations, **kwargs)
Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)
- Parameters
observations (torch.Tensor) – Values of \(y\).
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- Return type
torch.Tensor
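The shape of this interface, a draw method and an nll method that both take model predictions, can be illustrated with a stand-alone sketch for the sigmoid Bernoulli case. The classes below are stand-ins written with the standard library only: `SketchLikelihood` is not the library's FishLikelihood, plain lists replace tensors, and the mean reduction in nll is an assumption.

```python
import math
import random
from abc import ABC, abstractmethod


class SketchLikelihood(ABC):
    """Stand-in for the draw/nll interface; not the library's base class."""

    @abstractmethod
    def draw(self, preds):
        """Sample labels from p(y | f(x))."""

    @abstractmethod
    def nll(self, preds, observations):
        """Return the negative log-likelihood of the observations."""


class SigmoidBernoulli(SketchLikelihood):
    """y = 1 with probability sigmoid(f), otherwise y = 0."""

    def __init__(self, seed: int = 0):
        self.rng = random.Random(seed)

    @staticmethod
    def _sigmoid(f: float) -> float:
        return 1.0 / (1.0 + math.exp(-f))

    def draw(self, preds):
        return [1 if self.rng.random() < self._sigmoid(f) else 0 for f in preds]

    def nll(self, preds, observations):
        total = 0.0
        for f, y in zip(preds, observations):
            p = self._sigmoid(f)
            total -= math.log(p) if y == 1 else math.log(1.0 - p)
        return total / len(preds)  # assumption: mean over examples


likelihood = SigmoidBernoulli(seed=0)
labels = likelihood.draw([0.0, 3.0, -3.0])
value = likelihood.nll([0.0], [1])
```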
- class FishLeg.fishleg_likelihood.FixedGaussianLikelihood(sigma: Tensor, device: str = 'cpu')
Bases: FishLikelihood
The standard likelihood for regression, assuming fixed heteroscedastic noise.
\[y = f(x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2})\]
- Parameters
sigma (torch.Tensor) – Known observation standard deviation for each example.
- draw(preds: Tensor) Tensor
Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)
- Parameters
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- property get_variance: Tensor
- nll(preds: Tensor, observations: Tensor) Tensor
Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)
- Parameters
observations (torch.Tensor) – Values of \(y\).
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- Return type
torch.Tensor
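For reference, the Gaussian model \(y = f(x) + \epsilon\) with known \(\sigma\) gives the per-example negative log-likelihood \(\tfrac{1}{2}\log(2\pi\sigma^2) + (y - f(x))^2 / (2\sigma^2)\). The plain-Python sketch below implements that textbook formula and the matching sampler. The function names are hypothetical, a single scalar `sigma` is used instead of one value per example, and the class may drop the constant term or reduce differently.

```python
import math
import random


def gaussian_nll(preds, observations, sigma: float) -> float:
    """Mean Gaussian negative log-likelihood with a fixed standard deviation."""
    var = sigma ** 2
    per_example = [
        0.5 * math.log(2.0 * math.pi * var) + (y - f) ** 2 / (2.0 * var)
        for f, y in zip(preds, observations)
    ]
    return sum(per_example) / len(per_example)


def gaussian_draw(preds, sigma: float, rng: random.Random):
    """Sample y = f + eps with eps ~ N(0, sigma^2)."""
    return [f + rng.gauss(0.0, sigma) for f in preds]


rng = random.Random(0)  # seeded for reproducibility
samples = gaussian_draw([0.0, 1.0, 2.0], sigma=1.0, rng=rng)
at_mean = gaussian_nll([0.0], [0.0], sigma=1.0)
one_sigma_away = gaussian_nll([0.0], [1.0], sigma=1.0)
```

Moving the observation one standard deviation away from the prediction raises the loss by exactly one half, which is a quick sanity check for any implementation of this likelihood.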
- class FishLeg.fishleg_likelihood.GaussianLikelihood(sigma: Tensor, device: str = 'cpu')
Bases: FishLikelihood
The standard likelihood for regression, assuming heteroscedastic noise whose standard deviation is learned during training.
\[y = f(x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2})\]
- Parameters
sigma (torch.Tensor) – standard deviation for each example; also to be learned during training.
- Qv(v) List
- draw(preds: Tensor) Tensor
Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)
- Parameters
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- get_aux_parameters() List
- get_parameters() List
Returns a list of learnable parameters.
- init_aux(init_scale) None
- nll(preds: Tensor, observations: Tensor) Tensor
Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)
- Parameters
observations (torch.Tensor) – Values of \(y\).
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- Return type
torch.Tensor
- class FishLeg.fishleg_likelihood.SoftMaxLikelihood(device: str = 'cpu')
Bases: FishLikelihood
- draw(preds: Tensor) Tensor
Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)
- Parameters
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- nll(preds: Tensor, observations: Tensor) Tensor
Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)
- Parameters
observations (torch.Tensor) – Values of \(y\).
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- Return type
torch.Tensor
FishLeg.utils module
- FishLeg.utils.get_named_layers_by_regex(module: Module, param_names: List[str], params_strict: bool = False) List[NamedLayerParam]
- Parameters
module – the module to get the matching layers and params from
param_names – a list of names or regex patterns to match with full parameter paths. Regex patterns must be specified with the prefix ‘re:’
params_strict – if True, this function will raise an exception if a parameter is not found to match every name or regex in param_names
- Returns
a list of NamedLayerParam tuples whose full parameter names in the given module match one of the given regex patterns or parameter names
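The 're:' prefix convention described above can be sketched with the standard `re` module. This is not the library's function: `match_param_names` is a hypothetical helper that works on a plain list of parameter paths, and matching patterns from the start of the name with `re.match` is an assumption about the intended semantics.

```python
import re
from typing import List


def match_param_names(all_names: List[str], patterns: List[str],
                      strict: bool = False) -> List[str]:
    """Select parameter paths by exact name or by 're:'-prefixed regex."""
    matched: List[str] = []
    for pattern in patterns:
        if pattern.startswith("re:"):
            regex = re.compile(pattern[3:])  # strip the 're:' prefix
            hits = [name for name in all_names if regex.match(name)]
        else:
            hits = [name for name in all_names if name == pattern]
        if strict and not hits:
            raise ValueError(f"no parameter matches {pattern!r}")
        # Keep first-seen order and avoid duplicates across patterns.
        matched.extend(hit for hit in hits if hit not in matched)
    return matched


names = ["encoder.0.weight", "encoder.0.bias", "head.weight"]
selected = match_param_names(names, [r"re:encoder\..*", "head.weight"])
```

With `strict=True`, a name or pattern that matches nothing raises an error, mirroring the behaviour documented for params_strict.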
- FishLeg.utils.recursive_getattr(obj, attr)
- FishLeg.utils.recursive_setattr(obj, attr, value)
- FishLeg.utils.update_dict(replace: Module, module: Module) Module