K-FAC Methods#
The bayesian_lora.kfac module provides functions for calculating an approximate Fisher information matrix (or GGN) using Kronecker-factored approximate curvature.
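Recall that K-FAC first finds a block-diagonal approximation to the full Fisher / GGN. For a simple 4-layer network, this would be the block-diagonal matrix
\[\begin{pmatrix} \mathbf{G}_{11} & & & \\ & \mathbf{G}_{22} & & \\ & & \mathbf{G}_{33} & \\ & & & \mathbf{G}_{44} \end{pmatrix}\]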
Each of these blocks (\(\mathbf{G}_{\ell \ell}\)) is further approximated as the Kronecker product of two factors, one corresponding to the input activations, \(\mathbf{A}_{\ell-1}\), and another to the output gradients, \(\mathbf{S}_{\ell}\). That is, for a particular layer / nn.Module indexed by \(\ell\), we approximate its block of the full Fisher as
\[\mathbf{G}_{\ell \ell} \approx \mathbf{A}_{\ell-1} \otimes \mathbf{S}_{\ell}\]
These factors (curvature information around the network’s current parameters)
are calculated over some dataset \(\mathcal{D}\), and this is what the
bayesian_lora.calculate_kronecker_factors()
function below calculates.
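To see why a Kronecker product is a natural form for a layer's block, the following minimal sketch (plain PyTorch, not the library's implementation) checks the identity for a single made-up datapoint. The names `a`, `g`, `G_exact` and `G_kfac` are illustrative assumptions. For one datapoint the factorisation is exact up to floating point; the approximation only arises once the two factors are accumulated separately over a dataset.

```python
import torch

torch.manual_seed(0)
l_in, l_out = 4, 3

# A single datapoint: layer input `a` and output gradient `g` (both made up).
a = torch.randn(l_in)
g = torch.randn(l_out)

# For a linear layer, the gradient w.r.t. the weight is the outer product g a^T.
grad_W = torch.outer(g, a)                 # (l_out, l_in)
vec_grad = grad_W.T.reshape(-1)            # column-stacked vec(grad_W)

# Exact block for this one datapoint vs. its Kronecker-factored form.
G_exact = torch.outer(vec_grad, vec_grad)  # (l_in * l_out, l_in * l_out)
A = torch.outer(a, a)                      # activation factor, (l_in, l_in)
S = torch.outer(g, g)                      # output-gradient factor, (l_out, l_out)
G_kfac = torch.kron(A, S)

print(torch.allclose(G_exact, G_kfac))
```

Storing `A` and `S` takes `l_in**2 + l_out**2` entries, rather than the `(l_in * l_out)**2` entries of the full block.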
Rather than using numerical indices \(\ell \in \{1, 2, \ldots, L\}\), we use the nn.Module’s name to identify the different blocks, and return the factors in dictionaries of type dict[str, t.Tensor].
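As a rough illustration of name-based indexing, the sketch below lists the names PyTorch assigns to the linear layers of a hypothetical `TinyNet` and builds a placeholder dictionary with the full-rank factor shapes. `TinyNet` and the zero-filled tensors are assumptions made for the example; the real factors come from the library.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """A hypothetical two-layer classifier, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8, bias=False)
        self.fc2 = nn.Linear(8, 2, bias=False)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyNet()

# Module names play the role of the layer index.
linear_layers = {
    name: module
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear)
}
print(list(linear_layers))  # ['fc1', 'fc2']

# Placeholder (A, S) pairs with full-rank shapes: (l_in, l_in) and (l_out, l_out).
placeholder_factors = {
    name: (
        torch.zeros(m.in_features, m.in_features),
        torch.zeros(m.out_features, m.out_features),
    )
    for name, m in linear_layers.items()
}
```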
Full-Rank K-FAC#
The simplest variant is a full-rank Kronecker factorisation, meaning that we store the \(\mathbf{A}\) and \(\mathbf{S}\) matrices exactly.
- bayesian_lora.calculate_kronecker_factors(model: Module, forward_call: Callable[[Module, Any], Float[Tensor, 'batch n_classes']], loader: DataLoader, n_kfac: int | None = None, lr_threshold: int = 512, target_module_keywords: list[str] = [''], exclude_bias: bool = False, use_tqdm: bool = False) dict[str, tuple[Float[Tensor, 'l_in l_in_or_n_kfac'], Float[Tensor, 'l_out l_out_or_n_kfac']]] #
Calculate the Kronecker factors, (A, S), for the likelihood, used to approximate the GGN / Fisher.
- Parameters:
model – the model for which we are calculating the Kronecker factors. Note that it needn’t have LoRA adapters.
forward_call – a function which accepts the model and a batch from the provided data loader, and returns the parameters of the model’s predictive distribution as a Tensor. Usually this contains the logits over each class label.
loader – a data loader for the dataset with which to calculate the curvature / Kronecker factors.
n_kfac – an optional integer rank to use for a low-rank approximation of large Kronecker factors. If this is None, then no low-rank approximations are used.
lr_threshold – the threshold beyond which the side length of a Kronecker factor is considered large and a low-rank approximation is applied.
target_module_keywords – a list of keywords which identify the network modules whose parameters we want to include in the Hessian calculation. This is particularly useful when working with LoRA adapters. By default, this is [""], targeting every module.
exclude_bias – whether to ignore bias terms (NOTE: this is a hack and should not be used).
use_tqdm – whether to show progress with tqdm.
Warning
This function has only been implemented for nn.Linear. Models implemented using Conv1D (e.g. GPT2) will sadly not work for now.
Examples
Full-rank Kronecker factor calculation.
>>> factors = calculate_kronecker_factors(
...     model, fwd_call, loader
... )
Low-rank Kronecker factors on LoRA adaptors with inputs
>>> factors = calculate_kronecker_factors(
...     model, fwd_call, loader, n_kfac=10,
...     lr_threshold=512, target_module_keywords=["lora"],
... )
- Returns:
A dictionary containing the Kronecker factors; keyed by module name, containing a tuple (A, S) with the activation factor (A) as the first element, and the output gradient factor (S) as the second element.
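The examples above assume that `model`, `fwd_call` and `loader` already exist. A minimal sketch of what these might look like for a toy classifier is given below; the model, the random dataset and the body of `fwd_call` are all assumptions for illustration. The only requirement taken from the signature is that `fwd_call` receives the model and a batch, and returns a tensor of shape (batch, n_classes).

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy 3-class classifier built from nn.Linear layers only.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

# Random stand-in dataset: 16 examples with 4 features each.
dataset = TensorDataset(torch.randn(16, 4), torch.randint(0, 3, (16,)))
loader = DataLoader(dataset, batch_size=8)


def fwd_call(model, batch):
    """Return the logits for one batch; the labels are not needed here."""
    inputs, _labels = batch
    return model(inputs)


logits = fwd_call(model, next(iter(loader)))
print(logits.shape)  # torch.Size([8, 3])
```

With these in place, the returned dictionary can be traversed as `for name, (A, S) in factors.items(): ...`, following the (A, S) ordering described above.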
Notice that these Kronecker factors can themselves be approximated as low-rank, which is particularly useful for LLMs, where the factors may be \(4096 \times 4096\) for each layer in a transformer.
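To give a feel for the saving, the sketch below forms a rank-`n_kfac` factor `B` of a positive semi-definite matrix via a truncated SVD, so that `B @ B.T` approximates the original. This is one standard construction and is an assumption here; the library's own low-rank procedure may differ in its details. The sizes are kept small so that the example runs quickly.

```python
import torch

torch.manual_seed(0)
l_in, n_kfac = 64, 4

# A PSD matrix of rank n_kfac, standing in for an activation factor.
a = torch.randn(n_kfac, l_in, dtype=torch.float64)
A = a.T @ a                                # (l_in, l_in)

# Truncated SVD: keep only the top n_kfac singular directions.
U, s, _ = torch.linalg.svd(A)
B = U[:, :n_kfac] * s[:n_kfac].sqrt()      # (l_in, n_kfac)

# Relative reconstruction error; tiny here because A has rank n_kfac.
rel_err = torch.linalg.norm(A - B @ B.T) / torch.linalg.norm(A)

# Storage for a 4096-wide layer: full factor vs. rank-10 factor.
full_entries = 4096 * 4096                 # 16,777,216
low_rank_entries = 4096 * 10               # 40,960
```

For real factors the discarded singular values are not zero, so the reconstruction error depends on how quickly the spectrum decays.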
Internal Functions#
The above is the main way to use the K-FAC functionality from this library. It calls a number of internal functions, which we document here for re-use and completeness.
- bayesian_lora.kfac.register_hooks(model: Module, activations: dict[str, Float[Tensor, 'l_in l_in_or_n_kfac']], output_grads: dict[str, Float[Tensor, 'l_out l_out_or_n_kfac']], target_module_keywords: list[str], n_kfac: int | None = 10, lr_threshold: int = 100, exclude_bias: bool = False) list[RemovableHandle] #
Registers the activation and output gradient hooks.
- Parameters:
model – the nn.Module on which to attach the hooks (usually the full model).
activations – dictionary in which to store the parameter activations. The side length is l_in (i.e. equal to the number of input features in layer l), or l_in + 1 if there is a bias. The last dimension is n_kfac if l_in >= lr_threshold.
output_grads – dictionary in which to store the output gradients. The side length l_out is equal to the number of output features of layer l (regardless of the presence of a bias; unlike the activations). The last dimension is n_kfac if l_out >= lr_threshold.
target_module_keywords – a list of keywords identifying the network modules to include in the GGN. Note, only nn.Linear layers are currently supported.
n_kfac – the rank we use to approximate large Kronecker factors. If set to None, we treat all factors as full rank (turns off the lr approximation).
lr_threshold – threshold beyond which to consider a layer’s input to be wide (to decide whether to approximate a Kronecker factor as low rank). LoRA layers with a wide input (e.g. LoRA-A) will have a low-rank approximation of their activation Kronecker factor, A, while LoRA layers with a narrow input (e.g. LoRA-B) will have a low-rank approximation of their output-gradient Kronecker factor, S.
exclude_bias – whether to ignore bias terms (just consider the weights)
- Returns:
a list of hooks (for later removal).
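The interaction between n_kfac and lr_threshold can be summarised with a small helper that predicts the stored shape of each factor. The function `factor_shapes` is hypothetical and not part of the library. It assumes a `>=` comparison against the factor's side length, including the extra bias column where present, which is one reading of the parameter descriptions above.

```python
def factor_shapes(l_in, l_out, n_kfac=10, lr_threshold=100, has_bias=False):
    """Predict the stored shapes of (A, S) for one linear layer (illustrative)."""
    side_in = l_in + 1 if has_bias else l_in
    use_lr_in = n_kfac is not None and side_in >= lr_threshold
    use_lr_out = n_kfac is not None and l_out >= lr_threshold
    a_shape = (side_in, n_kfac if use_lr_in else side_in)
    s_shape = (l_out, n_kfac if use_lr_out else l_out)
    return a_shape, s_shape


# LoRA-A: wide input, narrow output, so only A is stored low-rank.
print(factor_shapes(4096, 8))   # ((4096, 10), (8, 8))

# LoRA-B: narrow input, wide output, so only S is stored low-rank.
print(factor_shapes(8, 4096))   # ((8, 8), (4096, 10))

# n_kfac=None switches the low-rank approximation off entirely.
print(factor_shapes(4096, 8, n_kfac=None))  # ((4096, 4096), (8, 8))
```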
- bayesian_lora.kfac.remove_hooks(hooks: list[RemovableHandle]) None #
Remove the hooks from the module.
- Parameters:
hooks – list of hooks, returned from register_hooks
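The register / remove lifecycle follows PyTorch's usual hook pattern, in which each registration returns a RemovableHandle. The sketch below uses PyTorch's own `register_forward_pre_hook` directly rather than the library functions; `make_input_hook` and `captured` are hypothetical names. It shows that a hook stops firing once its handle has been removed.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3))
captured = {}


def make_input_hook(name):
    """Return a forward pre-hook which records the layer's input under `name`."""
    def hook(module, args):
        captured[name] = args[0].detach()
    return hook


# Register one hook per linear layer and keep the handles for later removal.
handles = [
    module.register_forward_pre_hook(make_input_hook(name))
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear)
]

model(torch.randn(2, 4))
first_pass_keys = list(captured)   # the hook fired for layer "0"

# Removing the handles detaches the hooks from the model.
for handle in handles:
    handle.remove()

captured.clear()
model(torch.randn(2, 4))
second_pass_keys = list(captured)  # empty: no hooks remain
```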
- bayesian_lora.kfac.save_input_hook(module_name: str, activations: dict[str, Float[Tensor, 'l_in l_in_or_n_kfac']], n_kfac: int | None, lr_threshold: int, has_bias: bool = False, svd_dtype: dtype = torch.float64)#
A closure which returns a new hook to capture a layer’s input activations.
- Parameters:
module_name – name used as a key for the ‘activations’ dictionary. While modules themselves can be hashed, this makes the Kronecker factors more portable.
activations – mapping from layer / module name to input activation Kronecker factor
n_kfac – the rank we use if we’re using a low-rank approximation to this Kronecker factor
lr_threshold – if the side length l_in+1 exceeds this threshold, and n_kfac is not None, treat the factor as low-rank
has_bias – does this layer have a bias?
svd_dtype – dtype to cast tensors to for SVD calculations
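A closure of this kind might look roughly like the following sketch, which accumulates a full-rank activation factor and appends a column of ones when the layer has a bias, giving the l_in + 1 side length. This is an assumption about the general approach, not the library's code: it omits the low-rank path and the svd_dtype cast, and `make_input_hook` is a hypothetical name.

```python
import torch
import torch.nn as nn


def make_input_hook(module_name, activations, has_bias=False):
    """Return a forward pre-hook accumulating a^T a under `module_name`."""
    def hook(module, args):
        a = args[0].detach()
        a = a.reshape(-1, a.shape[-1])            # flatten batch / sequence dims
        if has_bias:
            ones = torch.ones(a.shape[0], 1, dtype=a.dtype)
            a = torch.cat([a, ones], dim=-1)      # side length becomes l_in + 1
        activations[module_name] = activations.get(module_name, 0) + a.T @ a
    return hook


layer = nn.Linear(4, 3, bias=True)
activations = {}
handle = layer.register_forward_pre_hook(
    make_input_hook("fc", activations, has_bias=True)
)

# Two batches (2 rows, then 3 rows) accumulate into the same factor.
layer(torch.randn(2, 4))
layer(torch.randn(3, 4))
handle.remove()

A = activations["fc"]   # shape (5, 5); A[-1, -1] counts the 5 rows seen
```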
- bayesian_lora.kfac.save_output_grad_hook(module_name: str, output_grads: dict[str, Float[Tensor, 'l_out l_out_or_n_kfac']], n_kfac: int | None, lr_threshold: int, svd_dtype: dtype = torch.float64)#
A closure which returns a new hook to capture a layer’s output gradients.
- Parameters:
module_name – name used as a key for the ‘output_grads’ dictionary. While modules themselves can be hashed, this makes the Kronecker factors more portable.
output_grads – mapping from layer / module name to the output gradient Kronecker factor.
n_kfac – the rank we use if we’re using a low-rank approximation to this Kronecker factor
lr_threshold – if the side length l_out exceeds this threshold, and n_kfac is not None, treat the factor as low-rank
svd_dtype – dtype to cast tensors to for SVD calculations
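The output-gradient side can be sketched in the same way with PyTorch's `register_full_backward_hook`, whose hook receives the gradients with respect to the module's outputs. As before, this is an illustrative assumption rather than the library's implementation: `make_output_grad_hook` is a hypothetical name, only the full-rank case is shown, and the cross-entropy loss with random labels is a stand-in for whatever likelihood is actually backpropagated.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def make_output_grad_hook(module_name, output_grads):
    """Return a backward hook accumulating g^T g under `module_name`."""
    def hook(module, grad_input, grad_output):
        g = grad_output[0].detach()
        g = g.reshape(-1, g.shape[-1])            # flatten batch / sequence dims
        output_grads[module_name] = output_grads.get(module_name, 0) + g.T @ g
    return hook


model = nn.Sequential(nn.Linear(4, 6), nn.Linear(6, 3))
output_grads = {}
handle = model[1].register_full_backward_hook(
    make_output_grad_hook("1", output_grads)
)

# Stand-in objective: summed cross-entropy against random labels.
logits = model(torch.randn(5, 4))
labels = torch.randint(0, 3, (5,))
F.cross_entropy(logits, labels, reduction="sum").backward()
handle.remove()

S = output_grads["1"]   # shape (l_out, l_out) = (3, 3), independent of any bias
```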