K-FAC Methods#

The bayesian_lora.kfac module provides functions for calculating an approximate Fisher information matrix (or GGN) using Kronecker-factored approximate curvature.

Recall that K-FAC first finds a block-diagonal approximation to the full Fisher / GGN. If we had a simple 4-layer network, then this would be:

[Figure: block-diagonal approximation of the full Fisher / GGN, with one block per layer]

Each of these blocks (\(\mathbf{G}_{\ell \ell}\)) is further approximated as the Kronecker product of two factors: one corresponding to the input activations, \(\mathbf{A}_{\ell-1}\), and another to the output gradients, \(\mathbf{S}_{\ell}\). That is, for a particular layer / nn.Module indexed by \(\ell\), we approximate its block of the full Fisher as

(1)#\[\mathbf{G}_{\ell\ell} \approx \mathbf{A}_{\ell-1} \otimes \mathbf{S}_{\ell}.\]

These factors (curvature information around the network’s current parameters) are calculated over some dataset \(\mathcal{D}\); this is what the bayesian_lora.calculate_kronecker_factors() function below computes.
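Concretely, in the usual K-FAC formulation (a sketch of the standard definitions; the exact scaling and sampling conventions are handled by the library), \(\mathbf{A}_{\ell-1}\) is the second moment over \(\mathcal{D}\) of the layer’s input activations \(\mathbf{a}_{\ell-1}\), and \(\mathbf{S}_{\ell}\) is the second moment of the gradients of the loss with respect to the layer’s outputs, \(\mathbf{g}_{\ell}\):

\[\mathbf{A}_{\ell-1} = \mathbb{E}_{\mathcal{D}}\big[\mathbf{a}_{\ell-1} \mathbf{a}_{\ell-1}^{\top}\big], \qquad \mathbf{S}_{\ell} = \mathbb{E}_{\mathcal{D}}\big[\mathbf{g}_{\ell} \mathbf{g}_{\ell}^{\top}\big].\]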

Rather than using numerical indices \(\ell \in \{1, 2, \ldots, L\}\), we use the nn.Module’s name to identify the different blocks, and return the factors in dictionaries of type dict[str, t.Tensor].

Full-Rank K-FAC#

The simplest variant is a full-rank Kronecker factorisation, meaning that we store the \(\mathbf{A}\) and \(\mathbf{S}\) matrices exactly.

bayesian_lora.calculate_kronecker_factors(model: Module, forward_call: Callable[[Module, Any], Float[Tensor, 'batch n_classes']], loader: DataLoader, n_kfac: int | None = None, lr_threshold: int = 512, target_module_keywords: list[str] = [''], exclude_bias: bool = False, use_tqdm: bool = False) dict[str, tuple[Float[Tensor, 'l_in l_in_or_n_kfac'], Float[Tensor, 'l_out l_out_or_n_kfac']]]#

Calculate the Kronecker factors (A, S) of the likelihood, used to approximate the GGN / Fisher.

Parameters:
  • model – the model for which we are calculating the Kronecker factors. Note that it needn’t have LoRA adapters.

  • forward_call – A function which accepts the model and a batch from the provided data loader, and returns the parameters of the model’s predictive distribution as a Tensor. Usually these are the logits over each class label (see the sketch after this parameter list).

  • loader – a data loader for the dataset with which to calculate the curvature / Kronecker factors.

  • n_kfac – an optional integer rank to use for a low-rank approximation of large Kronecker factors. If this is None, then no low-rank approximations are used.

  • lr_threshold – the threshold beyond which the side length of a Kronecker factor is considered large and a low-rank approximation is applied.

  • target_module_keywords – a list of keywords which identify the network modules whose parameters we want to include in the Hessian calculation. This is particularly useful when working with LoRA adapters. By default, this is [""], which targets every module.

  • exclude_bias – whether to ignore bias terms (NOTE: this is a hack and should not be used)

  • use_tqdm – whether to show progress with tqdm.
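As a rough illustration of the forward_call argument, here is a hypothetical classification setup; the (inputs, labels) batch layout and the bare model(inputs) call are assumptions, so adapt them to your own data loader and model:

import torch as t

def fwd_call(model: t.nn.Module, batch) -> t.Tensor:
    # Hypothetical batch layout: a (inputs, labels) pair; only the inputs
    # are needed here. Return the parameters of the predictive distribution
    # (per-class logits) as a single Tensor.
    inputs, _labels = batch
    logits = model(inputs)  # shape: (batch, n_classes)
    return logits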

Warning

This function has only been implemented for nn.Linear. Models implemented using Conv1D (e.g. GPT2) will sadly not work for now.

Examples

Full-rank Kronecker factor calculation.

>>> factors = calculate_kronecker_factors(
...     model, fwd_call, loader
... )

Low-rank Kronecker factors on LoRA adapters, with large factors approximated at rank n_kfac=10:

>>> factors = calculate_kronecker_factors(
...     model, fwd_call, loader, n_kfac=10,
...     lr_threshold=512, target_module_keywords=["lora"],
... )

Returns:

A dictionary of Kronecker factors, keyed by module name; each value is a tuple (A, S), where A is the activation factor and S is the output gradient factor.

Notice how these Kronecker factors can themselves be approximated as low-rank, which is particularly useful for LLMs, where the factors may be \(4096 \times 4096\) for each layer in a transformer.
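Carrying on from the low-rank example above, you can check which factors ended up low-rank by inspecting their shapes (the printed names depend on your model, so they are not shown here):

>>> for name, (A, S) in factors.items():
...     print(name, tuple(A.shape), tuple(S.shape))

Factors approximated as low rank have n_kfac as their last dimension, while full-rank factors are square.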

Internal Functions#

The above is the main way to use the K-FAC functionality from this library. It calls a number of internal functions, which we document here for re-use and completeness.

bayesian_lora.kfac.register_hooks(model: Module, activations: dict[str, Float[Tensor, 'l_in l_in_or_n_kfac']], output_grads: dict[str, Float[Tensor, 'l_out l_out_or_n_kfac']], target_module_keywords: list[str], n_kfac: int | None = 10, lr_threshold: int = 100, exclude_bias: bool = False) list[RemovableHandle]#

Registers the activation and output gradient hooks.

Parameters:
  • model – the nn.Module on which to attach the hooks (usually the full model)

  • activations – dictionary in which to store the input activations. The side length is l_in (i.e. equal to the number of input features in layer l), or l_in + 1 if there is a bias. The last dimension is n_kfac if l_in >= lr_threshold.

  • output_grads – dictionary in which to store the output gradients. The side length l_out is equal to the number of output features of layer l (regardless of the presence of a bias; unlike the activations). The last dimension is n_kfac if l_out >= lr_threshold

  • target_module_keywords – a list of keywords identifying the network modules to include in the GGN. Note, only nn.Linear layers are currently supported.

  • n_kfac – the rank we use to approximate large Kronecker factors. If set to None, we treat all factors as full rank (this turns off the low-rank approximation).

  • lr_threshold – threshold beyond which to consider a layer’s input to be wide (to decide whether to approximate a Kronecker factor as low rank). LoRA layers with a wide input (e.g. LoRA-A) will have a low-rank approximation of their activation Kronecker factor, A, while LoRA layers with a narrow input (e.g. LoRA-B) will have a low-rank approximation of their output-gradient Kronecker factor, S (see the worked shapes below).

  • exclude_bias – whether to ignore bias terms (just consider the weights)

Returns:

  • a list of hooks (for later removal).
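To make the lr_threshold logic above concrete, consider a hypothetical rank-16 LoRA adapter on a 4096-wide layer, with n_kfac=10, lr_threshold=512 and no bias terms:

LoRA-A: nn.Linear(4096, 16)  ->  A has shape (4096, 10) (low rank),  S has shape (16, 16) (full rank)
LoRA-B: nn.Linear(16, 4096)  ->  A has shape (16, 16) (full rank),   S has shape (4096, 10) (low rank)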

bayesian_lora.kfac.remove_hooks(hooks: list[RemovableHandle]) None#

Remove the hooks from the module.

Parameters:

hooks – list of hooks, returned from register_hooks
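Putting register_hooks and remove_hooks together, here is a minimal sketch of the pattern that calculate_kronecker_factors() roughly follows; the cross-entropy loss and the (inputs, labels) batch layout below are placeholder assumptions, since the library itself chooses the loss appropriate to the Fisher / GGN it approximates:

import torch as t
import torch.nn.functional as F
from bayesian_lora.kfac import register_hooks, remove_hooks

activations: dict[str, t.Tensor] = {}   # filled by the input-activation hooks
output_grads: dict[str, t.Tensor] = {}  # filled by the output-gradient hooks

hooks = register_hooks(
    model, activations, output_grads,
    target_module_keywords=["lora"], n_kfac=10, lr_threshold=512,
)

for batch in loader:
    logits = fwd_call(model, batch)      # forward pass fires the activation hooks
    inputs, labels = batch               # assumed batch layout
    F.cross_entropy(logits, labels).backward()  # backward fires the gradient hooks
    model.zero_grad()

remove_hooks(hooks)  # detach the hooks once the factors have been accumulated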

bayesian_lora.kfac.save_input_hook(module_name: str, activations: dict[str, Float[Tensor, 'l_in l_in_or_n_kfac']], n_kfac: int | None, lr_threshold: int, has_bias: bool = False, svd_dtype: dtype = torch.float64)#

A closure which returns a new hook to capture a layer’s input activations.

Parameters:
  • module_name – name used as a key for the ‘activations’ dictionary. While modules themselves can be hashed, this makes the Kronecker factors more portable.

  • activations – mapping from layer / module name to input activation Kronecker factor

  • n_kfac – the rank we use if we’re using a low-rank approximation to this Kronecker factor

  • lr_threshold – if the side length l_in (or l_in + 1 with a bias) exceeds this threshold, and n_kfac is not None, treat the factor as low-rank

  • has_bias – does this layer have a bias?

  • svd_dtype – dtype to cast tensors to for SVD calculations

bayesian_lora.kfac.save_output_grad_hook(module_name: str, output_grads: dict[str, Float[Tensor, 'l_out l_out_or_n_kfac']], n_kfac: int | None, lr_threshold: int, svd_dtype: dtype = torch.float64)#

A closure which returns a new hook to capture a layer’s output gradients.

Parameters:
  • module_name – name used as a key for the ‘output_grads’ dictionary. While modules themselves can be hashed, this makes the Kronecker factors more portable.

  • output_grads – mapping from layer / module name to the output gradient Kronecker factor.

  • n_kfac – the rank we use if we’re using a low-rank approximation to this Kronecker factor

  • lr_threshold – if the side length l_out exceeds this threshold, and n_kfac is not None, treat the factor as low-rank

  • svd_dtype – dtype to cast tensors to for SVD calculations