Bayesian LoRA#

This module contains the main methods from the Bayesian Low-Rank Adaptation for Large Language Models paper: calculating the model evidence (for tuning the prior and network hyperparameters), and calculating the posterior predictive parameters (for making linearised predictions).

Model Evidence#

The model evidence, or marginal likelihood, is a scalar that quantifies the support the data provide for a particular model. Under a given prior, a model with a higher marginal likelihood is better supported by the data.
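
As background intuition (this is the standard Laplace-approximation identity around a MAP estimate, stated here for reference rather than as a verbatim transcription of this library's internals), the log evidence is approximated as

    % theta* is the MAP estimate, d the number of parameters, and H the
    % Hessian of the negative log posterior evaluated at theta*.
    \log p(\mathcal{D}) \approx \log p(\mathcal{D} \mid \theta^*) + \log p(\theta^*)
        + \frac{d}{2} \log 2\pi - \frac{1}{2} \log \det H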

bayesian_lora.main.model_evidence(model: Module, LL: Tensor, factors: dict[str, tuple[Float[Tensor, 'l_in l_in_or_n_kfac'], Float[Tensor, 'l_out l_out_or_n_kfac']]], n_lora: int, n_kfac: int, s2: Float[Tensor, '1']) → Float[Tensor, '1']#

Use this function to calculate the marginal likelihood / model evidence; for instance, to tune the value of s2 (the prior variance).

Parameters:
  • model – your model

  • LL – the log likelihood on a dataset of interest

  • factors – dictionary of Kronecker factors

  • n_lora – LoRA rank

  • n_kfac – K-FAC rank

  • s2 – prior variance

Returns:

model evidence
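
As a minimal sketch of how this might be used (assuming LL and the Kronecker factors have already been computed, and that n_lora and n_kfac match the values used when computing the factors; the grid bounds and ranks below are illustrative placeholders), one could grid-search the prior variance:

    import torch
    from bayesian_lora.main import model_evidence

    # Illustrative grid of candidate prior variances.
    candidate_s2 = torch.logspace(-3, 1, steps=20)

    best_s2, best_evidence = None, -float("inf")
    for s2 in candidate_s2:
        # `model`, `LL` and `factors` are assumed to exist already.
        evidence = model_evidence(
            model, LL, factors, n_lora=8, n_kfac=10, s2=s2.view(1)
        )
        if evidence.item() > best_evidence:
            best_s2, best_evidence = s2.item(), evidence.item()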

Posterior Predictive#

This involves two steps: calculating the mean, and calculating the variance.

For the first, we invoke the (admittedly awkwardly named) jacobian_mean function, which returns the Jacobian and the mean, respectively.

bayesian_lora.main.jacobian_mean(model: Module, batch_inputs: BatchEncoding, target_ids: Tensor | None = None, is_s2s: bool = False, output_callback: Callable[[ModelOutput], Tensor] | None = None) → tuple[dict[str, Tensor], Tensor]#

Calculates the Jacobian and logit means

Parameters:
  • model – the LoRA LLM from which to make predictions

  • batch_inputs – the batch inputs, exactly as you would pass them into your model with model(**batch_inputs).

  • target_ids – selects specific model outputs. Leave this as None if either a) you wish to consider all model outputs or b) you are providing an output_callback to post-process the model output.

  • is_s2s – whether this is a sequence-to-sequence (s2s) model. May be omitted if you provide an output_callback.

  • output_callback – a function that takes the results of model(**batch_inputs) and returns the logits of interest

Returns:

The Jacobian (a dictionary of module keys and Jacobian Tensors) and the logit mean predictions.

As the signature suggests, there are two ways of calling this function; these determine how the outputs of the wrapped network call are handled.

  1. Directly, with parameters. Here, we assume that the model either is or is not a sequence-to-sequence model (is_s2s, defaulting to False), and that we may optionally want to pick out specific logits from the model’s full vocabulary:

    jacobian, f_mu = jacobian_mean(
        model, batch_inputs, target_ids=dset.target_ids, is_s2s=False
    )
    
  2. Custom output callback. Here, the user provides a callback function which takes the result of the model’s forward call and returns the logits of interest, with arbitrary post-processing in between:

    def default_output_callback(outputs: ModelOutput) -> Tensor:
        # For s2s models, keep all output logits; for causal LMs, keep
        # only the logits at the final sequence position.
        logits = outputs.logits if cfg.llm.is_s2s else outputs.logits[:, -1]
        # Select only the logits corresponding to the target token ids.
        target_logits = logits[:, dset.target_ids]
        return target_logits

    jacobian, f_mu = jacobian_mean(
        model, batch_inputs, output_callback=default_output_callback
    )
    

For the second step, we calculate the output logits’ covariance matrix.

bayesian_lora.main.variance(inputs, jacobian, factors: dict[str, tuple[Float[Tensor, 'l_in l_in_or_n_kfac'], Float[Tensor, 'l_out l_out_or_n_kfac']]], s2: Tensor, n_logits: int, n_lora: int, n_kfac: int, device: str)#

Calculates the variance matrix for performing (linearised) prediction.

Parameters:
  • inputs (dict) – tokenized batch of inputs (returned from a HF Tokenizer)

  • jacobian (dict) – a dictionary of first derivatives for each of the target module’s parameters

  • factors – dictionary of Kronecker factors

  • s2 – prior variance (scalar valued tensor)

  • n_logits – the number of logits to predict (e.g. the number of classes in your Categorical likelihood)

  • n_lora – rank used in the LoRA adapters

  • n_kfac – rank used for the low-rank approximation of large Kronecker factors

  • device – device on which to accumulate the variance matrix
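
Putting the two steps together, a linearised prediction for a single batch might look like the following sketch (assuming a classification setup where n_classes is the number of target logits, with factors and a tuned s2 already computed; the ranks, device, and sample count are illustrative, and the exact tensor shapes may need adapting to your model):

    from torch.distributions import MultivariateNormal
    from bayesian_lora.main import jacobian_mean, variance

    # Step 1: Jacobian and mean logits for this batch.
    jacobian, f_mu = jacobian_mean(
        model, batch_inputs, target_ids=dset.target_ids
    )

    # Step 2: covariance matrix over the output logits.
    f_var = variance(
        batch_inputs, jacobian, factors, s2,
        n_logits=n_classes, n_lora=8, n_kfac=10, device="cuda",
    )

    # One way to use the posterior predictive: sample logits from the
    # Gaussian N(f_mu, f_var) and average the class probabilities.
    dist = MultivariateNormal(f_mu, covariance_matrix=f_var)
    probs = dist.sample((1000,)).softmax(-1).mean(0)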