4. Inference Overview
Inferring physical galaxy parameters from photometric observations can be seen as the task of learning a mapping \(f : \mathcal{X} \to \Theta\), from the space of \(n\)-dimensional photometric observations \(\mathcal{X}\) (corresponding to \(n\) filters), to the space of physical parameters \(\Theta\) (such as mass, star formation, E(B-V), AGN disk inclination and so forth).
Since the photometric observations \(\mathbf{x}\in\mathcal{X}\) alone are unlikely to be sufficient to constrain the full range of physical parameters \(\theta \in \Theta\) that we’d like to infer, we instead output distributions over the physical parameters, conditioned on the information in the photometric observations. That is, the mapping \(f\) is one-to-many, and a reasonable way to deal with this is to work with distributions over the outputs, \(p(\theta \vert \mathbf{x})\).
In the simulation section, we generated a dataset of \((\theta, \mathbf{x})\) pairs, \(\mathcal{D} = \big\{(\theta_{i}, \mathbf{x}_{i})\big\}_{i=1}^{N}\), which gives us a fairly standard supervised machine learning setup.
We can therefore appeal to the broad machine learning literature, which presents many ways to tackle this problem, for instance using generative or autoregressive models. To avoid a clash of notation with that literature, we will henceforth denote the physical galaxy parameters as \(\mathbf{y}\) (previously \(\theta\)). This matches the machine learning convention of denoting the outputs to be predicted as \(\mathbf{y}\), and the model parameters as \(\theta\). The inputs remain \(\mathbf{x}\).
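To make the one-to-many point concrete, here is a toy sketch that is not part of the codebase. It assumes a hypothetical forward model \(x = y^2\) plus a little noise, so that a single observation \(x \approx 1\) is consistent with both \(y \approx 1\) and \(y \approx -1\). A point estimate such as the conditional mean lands near 0, a value the forward model essentially never produces for this observation, whereas samples from an (approximate) \(p(y \vert x)\) retain both modes.

```python
import random
import statistics


def simulate(n: int, seed: int = 0) -> list[tuple[float, float]]:
    """Hypothetical toy forward model: x = y**2 + small Gaussian noise."""
    rng = random.Random(seed)
    pairs = []
    for _ in range(n):
        y = rng.uniform(-2.0, 2.0)            # toy "physical parameter"
        x = y ** 2 + rng.gauss(0.0, 0.05)     # toy "photometric observation"
        pairs.append((y, x))
    return pairs


def conditional_samples(pairs, x_obs: float, tol: float = 0.1) -> list[float]:
    """Crude rejection-style approximation of p(y | x = x_obs):
    keep the y values whose simulated x lies close to the observation."""
    return [y for y, x in pairs if abs(x - x_obs) < tol]


pairs = simulate(50_000)
ys = conditional_samples(pairs, x_obs=1.0)

# The mean sits between the two valid solutions; the samples do not.
print(f"conditional mean of y: {statistics.fmean(ys):+.2f}")
print(f"fraction of samples with y > 0: {sum(y > 0 for y in ys) / len(ys):.2f}")
```

Roughly half of the retained samples lie near each of \(y = \pm 1\), which is exactly the information a single point prediction would discard.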
4.1. Using the Models
In line with this program’s conventions of using configuration classes rather
than command-line arguments, all the options for running inference can be set in
the config.py file.
There are a few classes to bear in mind:
4.2. ForwardModelParameters
We re-use the definitions of the forward model parameters to describe which parameters to learn. This avoids repetition, and prevents any discrepancies between the forward model and the machine learning models.
We use the following heuristics to detect parameters:
- If an element of the parameter dictionary from ForwardModelParams().all_params has an 'isfree': True property, then it is treated as a parameter to model. Note that this might come from a template library, and not be explicitly defined in ForwardModelParams.model_params!
- To determine the bounds on the acceptable range of these parameters, we look at the prior.range property, which is implemented on all the Prospector priors.
- If the prior distribution for a free parameter is LogUniform or LogNormal, then we treat this parameter as having log scaling. This means that to normalise these parameters, we first take the logarithm of their values (as well as of the bounds on the distribution), before applying the scale and offset that constrain the values to the [0, 1] range.
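A minimal, self-contained sketch of these heuristics follows. It does not import Prospector: Uniform and LogUniform below are hypothetical stand-ins that only mimic the range attribute and class names of Prospector priors, and the parameter dictionaries (and their bounds) are illustrative. The sketch normalises log-scaled parameters in log space, so each decade of the prior range occupies an equal share of [0, 1]; the base of the logarithm does not affect the result.

```python
import math
from dataclasses import dataclass


@dataclass
class Uniform:
    """Hypothetical stand-in for a Prospector prior exposing `range`."""
    mini: float
    maxi: float

    @property
    def range(self) -> tuple[float, float]:
        return (self.mini, self.maxi)


class LogUniform(Uniform):
    """Same bounds as Uniform, but the parameter is log-distributed."""


LOG_PRIORS = ("LogUniform", "LogNormal")


def free_params(all_params: list[dict]) -> list[dict]:
    # Heuristic 1: model a parameter iff it is marked 'isfree': True.
    return [p for p in all_params if p.get("isfree", False)]


def is_log(p: dict) -> bool:
    # Heuristic 3: log-distributed priors imply log scaling.
    return type(p["prior"]).__name__ in LOG_PRIORS


def normalise(p: dict, value: float) -> float:
    # Heuristic 2: the bounds come from the prior's `range`.
    lo, hi = p["prior"].range
    if is_log(p):
        lo, hi, value = math.log10(lo), math.log10(hi), math.log10(value)
    return (value - lo) / (hi - lo)


def denormalise(p: dict, u: float) -> float:
    # Inverse of normalise: map u in [0, 1] back to the physical range.
    lo, hi = p["prior"].range
    if is_log(p):
        log_lo, log_hi = math.log10(lo), math.log10(hi)
        return 10 ** (log_lo + u * (log_hi - log_lo))
    return lo + u * (hi - lo)


all_params = [
    {"name": "mass", "isfree": True, "prior": LogUniform(1e8, 1e12)},
    {"name": "dust2", "isfree": True, "prior": Uniform(0.0, 2.0)},
    {"name": "zred", "isfree": False, "prior": Uniform(0.0, 4.0)},
]

for p in free_params(all_params):
    # Round trip: both parameters should recover 0.25 (up to rounding error).
    print(p["name"], normalise(p, denormalise(p, 0.25)))
```

Note that a mass of 1e10 normalises to 0.5 here, halfway between 1e8 and 1e12 in log space, whereas a linear normalisation would place it at roughly 0.0025.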
4.2.1. General Inference Parameters
The general inference parameters are contained in the InferenceParams class (the
base class for this is defined in agnfinder/inference/inference.py).
- class InferenceParams(ConfigClass)
General parameters for the inference code.
- Parameters
model (model_t) – The model to use for inference.
split_ratio (float) – The train/test split ratio: the fraction of the dataset used for training, with the remainder used for testing.
logging_frequency (int) – How often (in iterations) to output logs during training.
dataset_loc (str) – Path to an hdf5 file or a directory of hdf5 files.
retrain_model (bool) – Whether to re-train an identically configured model.
use_existing_checkpoints (bool) – Whether to pick-up training from any existing model checkpoints or start from scratch.
overwrite_results (bool) – Whether to overwrite results from an identical model.
ident (str) – An optional string to identify a specific training run when saved to disk.
catalogue_loc (str) – Path to a catalogue of (real) observations (for prediction).
filters (FilterSet) – The filter set used when loading the catalogue of real observations (for prediction).
- Example
>>> class InferenceParams(inference.InferenceParams):
...     model: model_t = san.SAN
...     split_ratio: float = 0.8
...     logging_frequency: int = 10000
...     dataset_loc: str = './data/cubes/photometry_simulation.hdf5'
...     retrain_model: bool = True
...     overwrite_results: bool = False
...     ident: str = 'an_informative_identifier'
...
...     # Prediction:
...     catalogue_loc: str = './data/DES_VIDEO_v1.0.1.fits'
...     filters: FilterSet = Filters.DES  # {Euclid, DES, Reliable, All}
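To illustrate how a fractional split_ratio such as 0.8 partitions the simulated dataset, here is a small dependency-free sketch. The helper train_test_indices is hypothetical and is not the codebase's implementation (the split is presumably performed inside utils.load_simulated_data, which receives split_ratio); it simply shuffles the indices with a fixed seed and keeps the first split_ratio fraction for training.

```python
import random


def train_test_indices(n: int, split_ratio: float, seed: int = 0) -> tuple[list[int], list[int]]:
    """Hypothetical helper: shuffle range(n) reproducibly and split it by split_ratio."""
    if not 0.0 < split_ratio < 1.0:
        raise ValueError("split_ratio must lie strictly between 0 and 1")
    idx = list(range(n))
    random.Random(seed).shuffle(idx)   # fixed seed, so the split is reproducible
    n_train = int(n * split_ratio)     # e.g. 1000 * 0.8 -> 800 training examples
    return idx[:n_train], idx[n_train:]


train_idx, test_idx = train_test_indices(1000, split_ratio=0.8)
print(len(train_idx), len(test_idx))  # 800 200
```

Fixing the seed matters in practice: if the split changed between runs, a model re-loaded from disk could be evaluated on examples it was trained on.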
Most argument names along with their corresponding type should be self-explanatory.
The dataset_loc property should point to the output of a simulation run
(that is, the output of running make sim). Please see the Photometry
Sampling section for more information about this.
When a model is initialised, a descriptive name is generated based on its
parameters. If the training method (trainmodel, see below) is called on a
model with identical parameters to a previously trained and saved model, and
the retrain_model argument is set to False, then we attempt to load (the
state_dict of) this previous identical model instead of training the model
immediately. If loading fails for some reason (e.g. the file does not exist),
then training proceeds as normal.
If retrain_model == True, then the overwrite_results argument specifies
what to do when saving the resulting model—if set to True, then the
previously saved model will be overwritten. If set to False, then a number
is appended to the current model’s name to make it unique.
Since a number is not particularly informative, you can also set a unique and
ideally informative identifier using the ident field to differentiate models
which might have identical parameters (e.g. trained on different datasets etc.).
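The naming rules above can be summarised in a short sketch. The function resolve_save_name, its arguments and the example names are hypothetical (the real logic lives in the Model base class and may differ in detail); existing stands in for the names of models already saved to disk.

```python
def resolve_save_name(base_name: str, existing: set[str], retrain_model: bool,
                      overwrite_results: bool, ident: str = "") -> tuple[str, str]:
    """Hypothetical: return (action, name), where action is 'load' or 'train'."""
    name = f"{base_name}_{ident}" if ident else base_name
    if name not in existing:
        return "train", name          # nothing to load, and no name clash
    if not retrain_model:
        return "load", name           # reuse the identical saved model
    if overwrite_results:
        return "train", name          # retrain and overwrite the old result
    n = 1
    while f"{name}_{n}" in existing:  # retrain, keeping the old result(s)
        n += 1
    return "train", f"{name}_{n}"


existing = {"san_e20", "san_e20_1"}
print(resolve_save_name("san_e20", existing, retrain_model=False, overwrite_results=False))
print(resolve_save_name("san_e20", existing, retrain_model=True, overwrite_results=False))
print(resolve_save_name("san_e20", existing, retrain_model=True, overwrite_results=False, ident="des"))
```

The three calls return ('load', 'san_e20'), ('train', 'san_e20_2') and ('train', 'san_e20_des'): the numeric suffix keeps both results, but an ident makes the saved name far easier to recognise later.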
When we want to use a (trained) model to predict galaxy parameters (e.g. median
or mode), we can specify the catalogue of galaxy observations that we would like
to run the model on using the catalogue_loc parameter. In order to load this
successfully, you must also specify the filters used.
4.2.2. Model Parameters
In general, each different model will have a number of parameters which are unique to it. However, there are some common parameters which are shared across all the models in the codebase.
To reflect this, model parameter classes inherit from a base ModelParams class,
which specifies things such as the data type, the device to use and so forth.
- class spt.inference.base.ModelParams
Generic parameters shared by all models. Users will generally not initialise this class directly; rather, they will initialise the classes inheriting from it for specific models.
- Example
>>> class ExampleModelParams(ModelParams):
...     epochs: int = 20
...     batch_size: int = 1024
...     dtype: torch.dtype = torch.float32
...     device: torch.device = torch.device("cuda")
...     cond_dim: int = 8
...     data_dim: int = 9
- abstract property batch_size: int
The mini-batch size.
- abstract property cond_dim: int
Length of the 1D conditioning information vector, \(\mathbf{x}\).
- abstract property data_dim: int
Length of the (possibly flattened) 1D data vector, \(\mathbf{y}\).
- property device: torch.device
The device on which to run this model.
- abstract property dtype: torch.dtype
The data type to use with this model, e.g. torch.float32.
- abstract property epochs: int
The number of epochs to train the model for.
Putting this parameter here risks incurring a ‘type error’: this is really an inference parameter (how long we train the model for); however, since it has such a large effect on the resulting saved model, we prefer to associate it with the model itself.
Since all models are concerned with learning a distribution \(p(\mathbf{y} \vert
\mathbf{x})\), for \(\mathbf{y} \in \mathbb{R}^{N}\) and \(\mathbf{x}
\in \mathbb{R}^{M}\), we can reliably set parameters data_dim = N and
cond_dim = M for all models.
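The following sketch shows the pattern of abstract properties on a parameter base class, and how data_dim and cond_dim follow from the lengths of \(\mathbf{y}\) and \(\mathbf{x}\). It is a dependency-free approximation: ModelParamsSketch and ToyModelParams are hypothetical names, and dtype and device are plain strings here rather than torch.dtype and torch.device, so that the snippet runs without PyTorch.

```python
from abc import ABC, abstractmethod


class ModelParamsSketch(ABC):
    """Hypothetical, torch-free analogue of ModelParams."""

    @property
    @abstractmethod
    def epochs(self) -> int: ...

    @property
    @abstractmethod
    def batch_size(self) -> int: ...

    @property
    @abstractmethod
    def dtype(self) -> str: ...        # a torch.dtype in the real class

    @property
    @abstractmethod
    def data_dim(self) -> int: ...     # N, where y is in R^N

    @property
    @abstractmethod
    def cond_dim(self) -> int: ...     # M, where x is in R^M

    @property
    def device(self) -> str:           # concrete default (a torch.device in the real class)
        return "cpu"


class ToyModelParams(ModelParamsSketch):
    # Plain class attributes override the abstract properties,
    # mirroring the ExampleModelParams doctest above.
    epochs: int = 20
    batch_size: int = 1024
    dtype: str = "float32"
    data_dim: int = 9
    cond_dim: int = 8


# One hypothetical simulated pair: y has N = 9 entries, x has M = 8 entries.
y, x = [0.5] * 9, [0.1] * 8
mp = ToyModelParams()
assert (mp.data_dim, mp.cond_dim) == (len(y), len(x))
print(mp.epochs, mp.device)  # 20 cpu
```

Because the properties are abstract, forgetting one of them in a subclass fails loudly at instantiation time (with a TypeError) rather than silently later in training.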
Aside:
At first, putting epochs in the ModelParams (instead of the InferenceParams) might seem to commit a ‘type error’: surely the training duration has more to do with the training procedure than with the model itself? The batch_size parameter might seem similarly misplaced. Since these parameters have a large effect on model performance, I claim that they should be treated similarly to architectural parameters, and are therefore associated with a model. For instance, when we come to load a trained model, we do care how long it was trained for, so it makes more sense to associate this parameter with the model itself, treating it as a model parameter rather than merely a parameter of the training procedure.
4.2.3. Training the models
Having configured the inference parameters, you will also need a dataset loaded
to train the model on. A utility function (utils.load_simulated_data) is
available to help with this.
To initialise a model, we pass an instance of its model parameter class to the
model’s constructor. The trainmodel method can then be called to run the
training procedure.
During training, models save a checkpoint after every epoch. This means that you can interrupt training at any time, and only lose the progress made during the current epoch. You can also later check for overfitting by loading the model state from an earlier point during training.
The checkpoints are saved in a directory with the same name as the final model
results, which are saved with an additional .pt extension. If you are
re-training an identically parametrised model, the code will first attempt to load
an existing saved model before falling back to running the training procedure.
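The checkpoint-and-resume behaviour can be sketched without PyTorch. Everything here is hypothetical: the train function, the JSON checkpoint format and the epoch_<k>.json file names are illustrative, whereas the real code saves model state into a directory named after the model. The point is that, because a checkpoint is written at the end of every epoch, an interrupted run resumes from the last completed epoch (mirroring the use_existing_checkpoints option).

```python
import json
import tempfile
from pathlib import Path
from typing import Optional


def latest_checkpoint(ckpt_dir: Path) -> Optional[tuple[int, dict]]:
    """Return (completed_epochs, state) from the newest checkpoint, or None."""
    ckpts = sorted(ckpt_dir.glob("epoch_*.json"),
                   key=lambda p: int(p.stem.split("_")[1]))
    if not ckpts:
        return None
    newest = ckpts[-1]
    return int(newest.stem.split("_")[1]), json.loads(newest.read_text())


def train(ckpt_dir: Path, epochs: int, use_existing_checkpoints: bool = True,
          stop_after: Optional[int] = None) -> dict:
    """Toy training loop that writes a checkpoint after every epoch."""
    state, start = {"weight": 0.0}, 0
    found = latest_checkpoint(ckpt_dir) if use_existing_checkpoints else None
    if found is not None:
        start, state = found                  # resume after the last completed epoch
    for epoch in range(start, epochs):
        if stop_after is not None and epoch == stop_after:
            break                             # simulate an interruption
        state["weight"] += 1.0                # stand-in for one epoch of updates
        (ckpt_dir / f"epoch_{epoch + 1}.json").write_text(json.dumps(state))
    return state


with tempfile.TemporaryDirectory() as d:
    ckpt_dir = Path(d)
    train(ckpt_dir, epochs=5, stop_after=3)   # interrupted after 3 completed epochs
    final = train(ckpt_dir, epochs=5)         # resumes from epoch_3.json
    print(final["weight"])                    # 5.0
```

Keeping every epoch's checkpoint (rather than overwriting a single file) is what makes it possible to go back and inspect earlier states when checking for overfitting.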
The following is a full example, using the SAN model:
import logging
import torchvision.transforms as transforms  # assumed: torchvision's ToTensor
import agnfinder.config as cfg  # assumed module path for config.py
import agnfinder.nbutils as nbu
from agnfinder.inference import utils  # assumed module path
from agnfinder.inference.san import SAN  # assumed module path
# Configure the logger (defaults to INFO-level logs)
cfg.configure_logging()
# Initialise the inference and model parameters; defined in config.py
ip = cfg.InferenceParams()
sp = cfg.SANParams()
# Get the dataloaders for training and testing
train_loader, test_loader = utils.load_simulated_data(
    path=ip.dataset_loc,
    split_ratio=ip.split_ratio,
    batch_size=sp.batch_size,
    normalise_phot=utils.normalise_phot_np,
    transforms=[transforms.ToTensor()])
logging.info('Created data loaders')
# Initialise the model
model = SAN(sp)
logging.info('Initialised SAN model')
# Run the training procedure
model.trainmodel(train_loader, ip)
logging.info('Trained SAN model')
# (Example: use the model for something)
x, _ = nbu.new_sample(test_loader)
posterior_samples = model.sample(x, n_samples=1000)
logging.info('Successfully sampled from model')
4.3. Creating New Models
To ensure that there are consistent interfaces for all the models (to the
benefit of users), and that common code is not duplicated between models (to the
benefit of developers), all the models implemented in the codebase inherit from
an abstract Model class (found in agnfinder/inference/inference.py:Model).
To create a new model, inherit the Model class and ensure that
you have implemented all the abstract properties and methods.
The following shows the constructor, and abstract methods of the Model
class.
- class Model(torch.nn.Module, ABC)
Base model class for AGNFinder
- __init__(self, mp: ModelParams, overwrite_results: bool = False, logging_callbacks: list[Callable] = [])
- Parameters
mp (ModelParams) – The model parameters.
overwrite_results (bool) – Overwrite previous results when saving.
logging_callbacks (list[Callable[[Model], None]]) – Functions executed when logging.
- name(self) -> str
Returns a natural-language name for the model.
- __repr__(self) -> str
Give a natural-language description of the model. Do include information such as
self.epochs, self.name and self.batch_size, as well as other architecture-specific details for your specific model.
- fpath(self) -> str
Returns a file path to save the model to, which should be unique for every different parametrisation of the model.
- preprocess(self, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]
- Parameters
x (Tensor) – The inputs (usually photometric observations)
y (Tensor) – The targets (usually physical galaxy parameters)
- Returns
The pre-processed inputs and targets (e.g. cast to a specific data type, re-ordered, or placed in a specific device’s memory).
- trainmodel(self, train_loader: DataLoader, ip: InferenceParams, *args, **kwargs) -> None
- Parameters
train_loader (DataLoader) – The PyTorch DataLoader containing the training data.
ip (InferenceParams) – Inference parameters containing details of the training procedure.
Note that any additional model-specific arguments can also be provided using
*args and **kwargs.
This method has a decorator applied in the superclass (which is inherited by all sub-classes) that takes care of saving the trained model to disk (using Model.fpath), as well as loading an existing model rather than repeating training.
- sample(self, x: Tensor, n_samples: int = 1000, *args, **kwargs) -> Tensor
- Parameters
x (Tensor) – The conditioning data, \(\mathbf{x}\).
n_samples (int) – The number of samples to draw from the posterior.
A convenience method for drawing (conditional) samples from \(p(\mathbf{y} \vert \mathbf{x})\) for a single conditioning point.
Since different models may require additional arguments to perform the sampling, these can be provided using the *args and **kwargs parameters.
This is the only function pertaining to the actual use of the models which is required to be consistent across models. Individual models may provide additional methods for using them.
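To illustrate the structure described above, here is a dependency-free sketch of a base class that enforces the interface with abstract methods and wraps each subclass's trainmodel so that training is skipped when a saved model already exists at fpath(). It is not the codebase's Model class: it uses pickle rather than PyTorch's state_dict mechanism, omits preprocess and __repr__ for brevity, and the names ModelSketch, save_or_load, MeanModel and IP are hypothetical. The wrapping is done in __init_subclass__ so that subclasses inherit it automatically; the real implementation may apply its decorator differently.

```python
import os
import pickle
import random
import statistics
import tempfile
from abc import ABC, abstractmethod
from functools import wraps


def save_or_load(train_fn):
    """Wrap trainmodel: load from fpath() if allowed, otherwise train and save."""
    @wraps(train_fn)
    def wrapper(self, train_loader, ip, *args, **kwargs):
        path = self.fpath()
        if not getattr(ip, "retrain_model", False) and os.path.exists(path):
            with open(path, "rb") as f:
                self.__dict__.update(pickle.load(f))   # restore the learned state
            self.loaded_from_disk = True
            return
        train_fn(self, train_loader, ip, *args, **kwargs)
        with open(path, "wb") as f:
            pickle.dump(self.state(), f)
    return wrapper


class ModelSketch(ABC):
    """Hypothetical, torch-free analogue of the Model base class."""

    def __init__(self, save_dir: str):
        self.save_dir = save_dir
        self.loaded_from_disk = False

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Wrap each subclass's own trainmodel, so saving/loading is inherited.
        if "trainmodel" in cls.__dict__:
            cls.trainmodel = save_or_load(cls.__dict__["trainmodel"])

    def fpath(self) -> str:
        # Unique per parametrisation, provided name() encodes the parameters.
        return os.path.join(self.save_dir, f"{self.name()}.pkl")

    @abstractmethod
    def name(self) -> str: ...

    @abstractmethod
    def state(self) -> dict: ...

    @abstractmethod
    def trainmodel(self, train_loader, ip, *args, **kwargs) -> None: ...

    @abstractmethod
    def sample(self, x, n_samples: int = 1000, *args, **kwargs) -> list: ...


class MeanModel(ModelSketch):
    """Toy model: ignores x and samples y from a Gaussian fitted to the targets."""

    def name(self) -> str:
        return "meanmodel_v1"

    def state(self) -> dict:
        return {"mu": self.mu, "sigma": self.sigma}

    def trainmodel(self, train_loader, ip, *args, **kwargs) -> None:
        ys = [y for _, y in train_loader]
        self.mu, self.sigma = statistics.fmean(ys), statistics.stdev(ys)

    def sample(self, x, n_samples: int = 1000, *args, **kwargs) -> list:
        rng = random.Random(kwargs.get("seed", 0))   # model-specific extra kwarg
        return [rng.gauss(self.mu, self.sigma) for _ in range(n_samples)]


class IP:                                 # minimal stand-in for InferenceParams
    retrain_model = False


data = [(None, float(v)) for v in range(10)]   # (x, y) pairs; x is unused here
with tempfile.TemporaryDirectory() as d:
    first = MeanModel(d)
    first.trainmodel(data, IP())    # no saved file yet: trains, then saves
    second = MeanModel(d)
    second.trainmodel(data, IP())   # same name(): loads instead of training
    print(second.loaded_from_disk, second.mu)   # True 4.5
```

A new model only has to implement the abstract methods; the save/load behaviour comes for free, which is the main benefit of keeping it in the base class.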
4.3.1. Estimating Parameters
The final stage is to estimate (statistics of) the parameters for real observations. For example, we might be interested in the (principal) mode and median of a parameter’s distribution.
The code for doing this is in agnfinder/inference/parameter_estimation.py.
This will use the model specified in InferenceParams.model, and that model’s
corresponding configuration as defined in config.py.
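As a rough illustration of these statistics, the sketch below computes the median and an approximate principal mode from posterior samples of a single parameter. It is not the code in parameter_estimation.py: the histogram-based mode estimate, the n_bins default and the synthetic samples are assumptions made for illustration (a kernel density estimate would be a common alternative).

```python
import random
import statistics


def median_and_mode(samples: list[float], n_bins: int = 50) -> tuple[float, float]:
    """Return the median and the centre of the fullest histogram bin (principal mode)."""
    lo, hi = min(samples), max(samples)
    if lo == hi:                                   # all samples identical
        return lo, lo
    width = (hi - lo) / n_bins
    counts = [0] * n_bins
    for s in samples:
        counts[min(int((s - lo) / width), n_bins - 1)] += 1   # s == hi goes in the last bin
    peak = max(range(n_bins), key=counts.__getitem__)
    return statistics.median(samples), lo + (peak + 0.5) * width


# Hypothetical posterior samples: a narrow peak at 2 (about 45% of the mass) plus a
# broad plateau over [3, 9] (about 55%). The median falls on the plateau, while the
# principal mode sits on the peak.
rng = random.Random(1)
samples = [rng.gauss(2.0, 0.1) if rng.random() < 0.45 else rng.uniform(3.0, 9.0)
           for _ in range(20_000)]
median, mode = median_and_mode(samples)
print(f"median = {median:.2f}, principal mode = {mode:.2f}")
```

For multi-modal or skewed posteriors like this one, the two summaries can disagree substantially, which is why it can be worth reporting both.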