Bayesian LoRA#
This file contains the main methods relating to the Bayesian Low-Rank Adaptation for Large Language Models paper: namely, calculating the model evidence, used to tune the prior and network hyperparameters, and calculating the posterior predictive parameters, used to make (linearised) predictions.
Model Evidence#
The model evidence, or marginal likelihood, is a scalar value that indicates the evidence the data provide for a particular model. A model with a higher marginal likelihood is better supported by the data under the given prior.
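For intuition (this is the generic Laplace-approximated form, not necessarily the exact expression the library implements), the log evidence of a model with MAP parameters $\theta_{\text{MAP}}$ is approximately

$$\log p(\mathcal{D}) \approx \log p(\mathcal{D} \mid \theta_{\text{MAP}}) + \log p(\theta_{\text{MAP}}) - \frac{1}{2} \log \det \frac{\mathbf{H}}{2\pi},$$

where $\mathbf{H}$ is the Hessian of the negative log posterior at $\theta_{\text{MAP}}$; here, the Hessian is approximated using the (low-rank) Kronecker factors passed to the function below.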
- bayesian_lora.main.model_evidence(model: Module, LL: Tensor, factors: dict[str, tuple[Float[Tensor, 'l_in l_in_or_n_kfac'], Float[Tensor, 'l_out l_out_or_n_kfac']]], n_lora: int, n_kfac: int, s2: Float[Tensor, '1']) Float[Tensor, '1'] #
Use this function to calculate the marginal likelihood / model evidence; for instance to tune the value of s2 (prior variance).
- Parameters:
model – your model
LL – the log likelihood on a dataset of interest
factors – dictionary of Kronecker factors
n_lora – LoRA rank
n_kfac – K-FAC rank
s2 – prior variance
- Returns:
model evidence
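As an illustrative sketch (hypothetical setup: model, LL and factors are assumed to have been computed beforehand, and the n_lora / n_kfac values must match those used when computing the Kronecker factors), one might tune s2 with a simple grid search:

import torch

from bayesian_lora.main import model_evidence

# Hypothetical: `model`, `LL` and `factors` computed beforehand, e.g. with
# LL summed over a validation set and the factors from a K-FAC calculation.
best_s2, best_evidence = None, float("-inf")
for s2_val in (1e-2, 1e-1, 1.0, 10.0):
    s2 = torch.tensor([s2_val])
    evidence = model_evidence(model, LL, factors, n_lora=4, n_kfac=10, s2=s2)
    if evidence.item() > best_evidence:
        best_s2, best_evidence = s2, evidence.item()

Since the return value is a Tensor, it should also be possible to make s2 a learnable parameter and maximise the evidence with a gradient-based optimiser instead.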
Posterior Predictive#
This involves two steps: calculating the mean, then the variance. For the first, we invoke the (admittedly, awkwardly named) jacobian_mean function, which returns the Jacobian and the mean, in that order.
- bayesian_lora.main.jacobian_mean(model: Module, batch_inputs: BatchEncoding, target_ids: Tensor | None = None, is_s2s: bool = False, output_callback: Callable[[ModelOutput], Tensor] | None = None) tuple[dict[str, Tensor], Tensor] #
Calculates the Jacobian and logit means
- Parameters:
model – the LoRA LLM from which to make predictions
batch_inputs – the batch inputs, exactly as you would pass them into your model with model(**inputs).
target_ids – selects specific model outputs. Leave this as None if either a) you wish to consider all model outputs, or b) you are providing an output_callback to post-process the model output.
is_s2s – whether this is a sequence-to-sequence (s2s) model. This can be omitted if you provide an output_callback.
output_callback – a function that takes the results of model(**batch_inputs) and returns the logits of interest
- Returns:
The Jacobian (a dictionary of module keys and Jacobian Tensors) and the logit mean predictions.
As you can see, there are two ways of calling this function, which determine how we’ll handle the outputs from the wrapped network call.
Directly, with parameters

Here, we assume that the model either is or is not a sequence-to-sequence model (is_s2s defaults to False), and that we may optionally want to pick out some specific logits from the model's full vocabulary:

jacobian, f_mu = jacobian_mean(
    model, batch_inputs, target_ids=dset.target_ids, is_s2s=False
)
Custom output callback

Here, we allow the user to provide a callback function, which takes in the result of the model's forward call and returns the logits of interest, with arbitrary post-processing in between:

def default_output_callback(outputs: ModelOutput) -> Tensor:
    logits = outputs.logits if cfg.llm.is_s2s else outputs.logits[:, -1]
    target_logits = logits[:, dset.target_ids]
    return target_logits

jacobian, f_mu = jacobian_mean(
    model, batch_inputs, output_callback=default_output_callback
)
For the second step, we calculate the output logits’ covariance matrix.
- bayesian_lora.main.variance(inputs, jacobian, factors: dict[str, tuple[Float[Tensor, 'l_in l_in_or_n_kfac'], Float[Tensor, 'l_out l_out_or_n_kfac']]], s2: Tensor, n_logits: int, n_lora: int, n_kfac: int, device: str)#
Calculates the variance matrix for performing (linearised) prediction.
- Parameters:
inputs (dict) – tokenized batch of inputs (returned from a HF Tokenizer)
jacobian (dict) – a dictionary of first derivatives for each of the target module’s parameters
factors – dictionary of Kronecker factors
s2 – prior variance (scalar valued tensor)
n_logits – the number of logits to predict (e.g. the number of classes in your Categorical likelihood)
n_lora – rank used in the LoRA adapters
n_kfac – rank used for the low-rank approximation of large Kronecker factors
device – device on which to accumulate the variance matrix
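Putting the two steps together, here is a minimal sketch of making linearised predictions (hypothetical values throughout: f_mu and jacobian come from the jacobian_mean call above, n_classes is the number of classes in your likelihood, and n_lora / n_kfac must match the values used when computing the factors):

import torch

from bayesian_lora.main import variance

f_var = variance(
    batch_inputs, jacobian, factors, s2,
    n_logits=n_classes, n_lora=4, n_kfac=10, device="cuda",
)

# Treat the output logits as Gaussian with mean f_mu and covariance f_var:
# sample logits and average the softmax probabilities for a Monte Carlo
# estimate of the posterior predictive. In practice, you may need to add a
# small jitter to the diagonal of f_var to keep it positive definite.
dist = torch.distributions.MultivariateNormal(f_mu, covariance_matrix=f_var)
logit_samples = dist.sample((100,))        # [100, batch, n_logits]
probs = logit_samples.softmax(-1).mean(0)  # [batch, n_logits]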