Probabilistic models
UnlockNN includes an extensible interface for adding uncertainty quantification to any trained Keras model using a variational Gaussian process (VGP). A VGP is a modification of a Gaussian process that vastly improves model scalability: instead of using the entire training dataset as index points, a VGP uses variational inference to compute a smaller set of inducing index points that lead to a good approximation of the “full” Gaussian process.
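To make the scalability argument concrete, the NumPy sketch below builds the low-rank (Nyström) approximation K_nn ≈ K_nm K_mm⁻¹ K_mn of an RBF kernel matrix from M inducing points and measures how close it is to the full N × N matrix. This only illustrates why a small set of inducing points can stand in for the whole dataset; it is not unlocknn's code. A VGP learns the inducing point locations and kernel parameters by variational inference, whereas here the points are simply fixed on a grid, and the `rbf` and `nystrom_error` helpers are hypothetical.

```python
import numpy as np


def rbf(a: np.ndarray, b: np.ndarray, lengthscale: float = 1.0) -> np.ndarray:
    """Hypothetical RBF kernel: exp(-|a - b|^2 / (2 * lengthscale^2))."""
    sq_dists = (
        np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
    )
    # Clip tiny negative values caused by floating-point round-off
    return np.exp(-0.5 * np.maximum(sq_dists, 0.0) / lengthscale**2)


def nystrom_error(x: np.ndarray, num_inducing: int, jitter: float = 1e-6) -> float:
    """Relative Frobenius error of K_nn ≈ K_nm K_mm^-1 K_mn using a grid of inducing points."""
    # Inducing points are fixed on a grid here; a VGP would learn their locations
    z = np.linspace(x.min(), x.max(), num_inducing)[:, None]
    k_nn = rbf(x, x)
    k_nm = rbf(x, z)
    k_mm = rbf(z, z) + jitter * np.eye(num_inducing)  # jitter keeps K_mm invertible
    # Only an M x M system is solved (O(M^3)) instead of an N x N one (O(N^3));
    # the dense q_nn is formed here purely to measure the approximation error
    q_nn = k_nm @ np.linalg.solve(k_mm, k_nm.T)
    return float(np.linalg.norm(k_nn - q_nn) / np.linalg.norm(k_nn))


rng = np.random.default_rng(0)
x = rng.uniform(-3.0, 3.0, size=(500, 1))  # N = 500 "training" index points
errors = {m: nystrom_error(x, m) for m in (3, 20)}
for m, err in errors.items():
    print(f"M = {m:2d} inducing points: relative error {err:.1e}")
```

With inducing points spaced well below the kernel lengthscale, 20 points reproduce the 500 × 500 kernel matrix closely, while 3 points do not.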
By supplanting the last layer(s) of a Keras model with a VGP, the model’s predictions become Gaussian distributions, rather than tensors. The mean of this distribution is the de facto “prediction” and the standard deviation indicates uncertainty. For example, the mean plus or minus two standard deviations gives an approximate 95% confidence interval.
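The “two standard deviations ≈ 95%” rule of thumb can be checked with Python's standard-library `statistics.NormalDist`: ±2σ actually covers about 95.45% of a Gaussian, and the exact multiplier for a central 95% interval is about 1.96. The `interval` helper below is hypothetical and only shows how a predicted mean and standard deviation turn into bounds.

```python
from statistics import NormalDist
from typing import Tuple

standard_normal = NormalDist()  # mean 0, standard deviation 1

# Probability mass within two standard deviations of the mean
coverage_2sd = standard_normal.cdf(2.0) - standard_normal.cdf(-2.0)
# Exact multiplier for a central 95% interval (the 97.5th percentile)
z_95 = standard_normal.inv_cdf(0.975)

print(f"P(|X - mean| <= 2 sd) = {coverage_2sd:.4f}")  # 0.9545
print(f"exact 95% multiplier = {z_95:.3f}")           # 1.960


def interval(mean: float, stddev: float, num_sd: float = 2.0) -> Tuple[float, float]:
    """Hypothetical helper: (lower, upper) bounds of mean ± num_sd * stddev."""
    return mean - num_sd * stddev, mean + num_sd * stddev
```

For example, `interval(1.0, 0.5)` gives `(0.0, 2.0)`; pass `num_sd=z_95` for an exact 95% interval.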
Important
The caveat of using a VGP is that the probabilistic model must undergo further training to determine the inducing index point locations and the kernel parameters, which can be expensive!
The module also contains MEGNetProbModel, a specific implementation for use with
MEGNet, a highly performant graph neural network model for materials property prediction.
The ProbNN API
To add uncertainty quantification to an arbitrary Keras model,
create a class that inherits from ProbNN
and implements ProbNN.train(). Any variables that are
necessary for reloading the model should also be appended
to ProbNN.CONFIG_VARS.
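The subclassing pattern can be sketched without Keras. Below, `ToyProbNN` is a hypothetical stand-in for `ProbNN` that keeps only the two pieces a subclass must deal with: the abstract `train()` method and the `CONFIG_VARS` list used to build the saved configuration. The real base class also wraps the Keras model and VGP; the subclass name, its `batch_size` attribute and the way `CONFIG_VARS` is extended are illustrative assumptions.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Sequence


class ToyProbNN(ABC):
    """Hypothetical stand-in for unlocknn.model.ProbNN (illustration only)."""

    CONFIG_VARS: List[str] = ["num_inducing_points", "kl_weight"]

    def __init__(self, num_inducing_points: int, kl_weight: float = 1.0) -> None:
        self.num_inducing_points = num_inducing_points
        self.kl_weight = kl_weight

    @property
    def config(self) -> Dict[str, Any]:
        # Collect every attribute named in CONFIG_VARS for saving to disk
        return {name: getattr(self, name) for name in self.CONFIG_VARS}

    @abstractmethod
    def train(self, *args, **kwargs) -> None:
        """Subclasses must handle data processing and fitting."""


class MyProbModel(ToyProbNN):
    # Build a new list rather than mutating the parent's, so the base class is unaffected
    CONFIG_VARS = ToyProbNN.CONFIG_VARS + ["batch_size"]

    def __init__(
        self, num_inducing_points: int, kl_weight: float = 1.0, batch_size: int = 32
    ) -> None:
        super().__init__(num_inducing_points, kl_weight)
        self.batch_size = batch_size
        self.history: List[int] = []

    def train(self, inputs: Sequence[float], targets: Sequence[float], epochs: int) -> None:
        if len(inputs) != len(targets):
            raise ValueError("inputs and targets must have the same length")
        # A real subclass would convert the data here, then call self.model.fit(...)
        self.history.append(epochs)


model = MyProbModel(num_inducing_points=10, batch_size=64)
model.train([0.1, 0.2], [1.0, 2.0], epochs=5)
print(model.config)  # {'num_inducing_points': 10, 'kl_weight': 1.0, 'batch_size': 64}
```

Because `train()` is abstract, the base class itself cannot be instantiated; only concrete subclasses can.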
- class unlocknn.model.ProbNN(nn: ~keras.engine.training.Model, num_inducing_points: int, kernel: ~unlocknn.kernel_layers.KernelLayer = <unlocknn.kernel_layers.RBFKernelFn object>, latent_layer: ~typing.Union[str, int] = -2, target_shape: ~typing.Union[~typing.Tuple[int], int] = 1, metrics: ~typing.List[~typing.Union[str, ~keras.metrics.Metric]] = [], kl_weight: float = 1.0, optimizer: ~keras.optimizer_v2.optimizer_v2.OptimizerV2 = <keras.optimizer_v2.adam.Adam object>, index_initializer: ~typing.Optional[~keras.initializers.initializers_v2.Initializer] = None, use_normalization: bool = True)
Wrapper for creating a probabilistic NN model.
- Parameters
nn – The base NN model to modify.
num_inducing_points – The number of inducing index points for the VGP.
kernel – A KernelLayer for the VGP to use.
latent_layer – The name or index of the layer of the NN to be fed into the VGP.
target_shape – The shape of the target values.
metrics – A list of metrics to record during training.
kl_weight – The relative weighting of the Kullback-Leibler divergence in the loss function.
optimizer – The model optimizer, needed for recompilation.
index_initializer – A custom initializer to use for the VGP index points.
use_normalization – Whether to use a BatchNormalization layer before the VGP. Recommended for better training efficiency.
compile – Whether to compile the model for training. Not needed when loading the model for inference only.
- CONFIG_VARS
A list of attribute names, as strings, to include in metadata when saving. These variables will be saved in a config.json file and used when re-instantiating the model upon loading with load(): they are passed as keyword arguments.
- Type
List[str]
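The save/load contract described for CONFIG_VARS (attributes written to config.json, then passed back as keyword arguments) can be sketched with the standard library alone. `ToySavable` is a hypothetical class, not unlocknn's implementation; the real save() also stores the model weights, kernel and checkpoints, which are omitted here.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Union


class ToySavable:
    """Hypothetical stand-in showing how CONFIG_VARS can drive save() and load()."""

    CONFIG_VARS: List[str] = ["num_inducing_points", "kl_weight", "use_normalization"]

    def __init__(
        self, num_inducing_points: int, kl_weight: float = 1.0, use_normalization: bool = True
    ) -> None:
        self.num_inducing_points = num_inducing_points
        self.kl_weight = kl_weight
        self.use_normalization = use_normalization

    def save(self, save_path: Union[str, Path]) -> None:
        save_dir = Path(save_path)
        save_dir.mkdir(parents=True, exist_ok=True)
        config: Dict[str, Any] = {name: getattr(self, name) for name in self.CONFIG_VARS}
        # Values must be JSON-serializable (ints, floats, bools, strings, lists, ...)
        (save_dir / "config.json").write_text(json.dumps(config, indent=2))

    @classmethod
    def load(cls, save_path: Union[str, Path]) -> "ToySavable":
        config_file = Path(save_path) / "config.json"
        if not config_file.exists():
            raise FileNotFoundError(f"No config found at {config_file}")
        config = json.loads(config_file.read_text())
        return cls(**config)  # saved values are passed back as keyword arguments


with tempfile.TemporaryDirectory() as tmp:
    original = ToySavable(num_inducing_points=500, kl_weight=0.5)
    original.save(Path(tmp) / "model")
    restored = ToySavable.load(Path(tmp) / "model")

print(vars(restored))  # {'num_inducing_points': 500, 'kl_weight': 0.5, 'use_normalization': True}
```

One consequence of this design is that every name in CONFIG_VARS must match an `__init__` keyword argument, otherwise re-instantiation fails.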
- ckpt_callback(ckpt_path: PathLike = 'checkpoint.h5')
Get the default configuration for a model checkpoint callback.
- compile(new_kl_weight: ~typing.Optional[float] = None, optimizer: ~keras.optimizer_v2.optimizer_v2.OptimizerV2 = <keras.optimizer_v2.adam.Adam object>, new_metrics: ~typing.Optional[~typing.List[~typing.Union[str, ~keras.metrics.Metric]]] = None)
Compile the probabilistic NN.
Recompilation is required whenever layers are (un)frozen.
- Parameters
new_kl_weight – The relative weighting of the Kullback-Leibler divergence in the loss function. Default (None) is to leave unchanged.
optimizer – The model optimizer, needed for recompilation.
new_metrics – New metrics with which to compile.
- property config: Dict[str, Any]
Get the configuration parameters needed to save to disk.
- classmethod load(save_path: PathLike, load_ckpt: bool = True, **kwargs) → ProbNN
Load a ProbNN from disk.
Loaded models must be recompiled before training.
- Parameters
save_path – The path to the model’s save directory.
load_ckpt – Whether to load the best checkpoint’s weights, instead of those saved at the time of the last save().
**kwargs – Keyword arguments required by subclasses.
- Returns
The loaded model.
- Raises
FileNotFoundError – If the save_path or any components do not exist.
- property nn_frozen: bool
Determine whether all NN layers are frozen.
- property norm_frozen: Optional[bool]
Determine whether the BatchNormalization layer is frozen, if it exists.
- predict(input) → ndarray
Predict target values and standard deviations for a given input.
- Parameters
input – The input(s) to the model.
- Returns
A numpy array containing predicted means and standard deviations.
- save(save_path: PathLike, ckpt_path: Optional[PathLike] = 'checkpoint.h5')
Save the model to disk.
- Parameters
save_path – The directory in which to save the model.
ckpt_path – Where to look for checkpoints, which will be copied over to the save directory for future usage. Specify ckpt_path=None if no checkpoints exist.
- save_kernel(kernel_save_path: Path)
Save the VGP’s kernel to disk.
- set_frozen(layers: Union[Literal['NN', 'VGP', 'Norm'], List[Literal['NN', 'VGP', 'Norm']]], freeze: bool = True, recompile: bool = True, **compilation_kwargs) → None
Freeze or thaw probabilistic NN layers.
- Parameters
layers – Name or list of names of layers to freeze or thaw.
freeze – Whether to freeze (True) or thaw (False) the layers.
recompile – Whether to recompile the model after the operation.
**compilation_kwargs – Keyword arguments to pass to compile().
- Raises
ValueError – If one or more layers are invalid names.
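The argument handling of set_frozen() and the “recompile after (un)freezing” rule can be mirrored in a small Keras-free sketch. `ToyFreezer` is hypothetical: it only tracks which layer groups are frozen and counts recompilations. The schedule in the usage lines (freeze the NN and normalization layer first, then thaw the NN to fine-tune) is an assumed example of typical usage, not a documented recommendation.

```python
from typing import Dict, List, Literal, Union

LayerName = Literal["NN", "VGP", "Norm"]
VALID_LAYERS = ("NN", "VGP", "Norm")


class ToyFreezer:
    """Hypothetical stand-in mirroring set_frozen()'s arguments; no Keras involved."""

    def __init__(self) -> None:
        self.frozen: Dict[str, bool] = {name: False for name in VALID_LAYERS}
        self.compile_count = 0

    def compile(self, **compilation_kwargs) -> None:
        # Stand-in for recompiling, which is required whenever layers are (un)frozen
        self.compile_count += 1

    def set_frozen(
        self,
        layers: Union[LayerName, List[LayerName]],
        freeze: bool = True,
        recompile: bool = True,
        **compilation_kwargs,
    ) -> None:
        # Accept a single name or a list of names
        names = [layers] if isinstance(layers, str) else list(layers)
        invalid = [name for name in names if name not in VALID_LAYERS]
        if invalid:
            raise ValueError(f"Invalid layer name(s): {invalid}; expected {VALID_LAYERS}")
        for name in names:
            self.frozen[name] = freeze
        if recompile:
            self.compile(**compilation_kwargs)


model = ToyFreezer()
model.set_frozen(["NN", "Norm"])          # e.g. train only the VGP first
model.set_frozen("NN", freeze=False)      # later, thaw the NN to fine-tune
print(model.frozen, model.compile_count)  # {'NN': False, 'VGP': False, 'Norm': True} 2
```

Validating all names before changing any state means an invalid list leaves the model untouched.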
- abstract train(*args, **kwargs) → None
Train the model.
This method should handle data processing, then call self.model.fit to train the underlying model.
- update_pred_model(force_new: bool = False) → None
Instantiate or update the predictor model.
The predictor model is saved in pred_model. This method works around the fact that a model returning the VGP distribution’s mean and standard deviation simultaneously cannot be saved or trained.
This method creates a clone model, so it must be called before making a prediction whenever the model’s weights have changed. By default, the method checks whether the pre-existing pred_model’s weights are similar to the model’s weights before cloning, and skips execution if they are. Setting force_new=True skips this check.
- Parameters
force_new – Whether to force the creation of a new model, skipping the weights equality check.
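The skip-if-unchanged logic of update_pred_model() can be sketched with NumPy arrays standing in for layer weights. `ToyPredictorCache` and `weights_match` are hypothetical, and `np.allclose` with its default tolerances is an assumed notion of “similar”; the real method compares and clones Keras models.

```python
from typing import List, Optional

import numpy as np


def weights_match(a: List[np.ndarray], b: List[np.ndarray]) -> bool:
    """True if two weight lists have the same shapes and (approximately) equal values."""
    return len(a) == len(b) and all(
        x.shape == y.shape and np.allclose(x, y) for x, y in zip(a, b)
    )


class ToyPredictorCache:
    """Hypothetical sketch of update_pred_model()'s skip-if-unchanged logic."""

    def __init__(self) -> None:
        self.pred_weights: Optional[List[np.ndarray]] = None
        self.num_clones = 0

    def update_pred_model(self, model_weights: List[np.ndarray], force_new: bool = False) -> None:
        if (
            not force_new
            and self.pred_weights is not None
            and weights_match(self.pred_weights, model_weights)
        ):
            return  # predictor already reflects these weights: skip the expensive clone
        # Stand-in for cloning the model: copy so later in-place updates are detected
        self.pred_weights = [w.copy() for w in model_weights]
        self.num_clones += 1


weights = [np.ones((3, 2)), np.zeros(2)]
cache = ToyPredictorCache()
cache.update_pred_model(weights)                  # first call: clone
cache.update_pred_model(weights)                  # unchanged: skipped
weights[0][0, 0] = 5.0                            # e.g. after more training
cache.update_pred_model(weights)                  # changed: clone again
cache.update_pred_model(weights, force_new=True)  # forced: clone again
print(cache.num_clones)  # 3
```

Copying the weights matters: keeping references instead would make the comparison always succeed after in-place updates.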
- property vgp_frozen: bool
Determine whether the VGP is frozen.
MEGNet probabilistic model
Use MEGNetProbModel to add uncertainty quantification
to a MEGNetModel:
- class unlocknn.model.MEGNetProbModel(meg_model: ~megnet.models.megnet.MEGNetModel, num_inducing_points: int, kernel: ~unlocknn.kernel_layers.KernelLayer = <unlocknn.kernel_layers.RBFKernelFn object>, latent_layer: ~typing.Union[str, int] = -2, target_shape: ~typing.Optional[~typing.Union[~typing.Tuple[int], int]] = None, metrics: ~typing.List[~typing.Union[str, ~keras.metrics.Metric]] = [], kl_weight: float = 1.0, optimizer: ~keras.optimizer_v2.optimizer_v2.OptimizerV2 = <keras.optimizer_v2.adam.Adam object>, index_initializer: ~typing.Optional[~keras.initializers.initializers_v2.Initializer] = None, use_normalization: bool = True, **kwargs)
ProbNN for MEGNetModels.
- Parameters
meg_model – The base MEGNetModel to modify.
num_inducing_points – The number of inducing index points for the VGP.
kernel – A KernelLayer for the VGP to use.
latent_layer – The name or index of the layer of the NN to be fed into the VGP.
target_shape – The shape of the target values.
metrics – A list of metrics to record during training.
kl_weight – The relative weighting of the Kullback-Leibler divergence in the loss function.
optimizer – The model optimizer, needed for recompilation.
index_initializer – A custom initializer to use for the VGP index points. See also unlocknn.initializers.
use_normalization – Whether to use a BatchNormalization layer before the VGP. Recommended for better training efficiency.
Warning
metrics are malfunctioning and may give vastly incorrect values; use unlocknn.metrics.evaluate_uq_metrics() instead!
- evaluate(eval_structs: List[Structure], eval_targets: List[Union[float, ndarray]], batch_size: int = 128, scrub_failed_structs: bool = False) → Dict[str, float]
Evaluate model metrics.
Warning
This method is malfunctioning and may give vastly incorrect values; use unlocknn.metrics.evaluate_uq_metrics() instead!
- Parameters
eval_structs – Structures on which to evaluate performance.
eval_targets – True target values for structures.
batch_size – The batch size for evaluation.
scrub_failed_structs – Whether to discard structures that could not be converted to graphs.
- Returns
Dictionary of {metric: value}.
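Because evaluate() is flagged as unreliable, a quick independent sanity check can be computed directly from the means and standard deviations that predict() returns. The helper below is hypothetical and is not unlocknn.metrics.evaluate_uq_metrics(), whose metrics are not reproduced here: it reports the mean absolute error, the average predicted standard deviation, and the fraction of targets falling inside mean ± 2 standard deviations (about 0.95 for a well-calibrated Gaussian).

```python
from typing import Dict, Sequence

import numpy as np


def simple_uq_summary(
    targets: Sequence[float], means: Sequence[float], stddevs: Sequence[float]
) -> Dict[str, float]:
    """Hypothetical helper; not unlocknn.metrics.evaluate_uq_metrics()."""
    targets, means, stddevs = (np.asarray(v, dtype=float) for v in (targets, means, stddevs))
    abs_err = np.abs(targets - means)
    return {
        "mae": float(abs_err.mean()),
        # Average predicted uncertainty: smaller values mean "sharper" predictions
        "mean_stddev": float(stddevs.mean()),
        # Fraction of targets inside mean ± 2 sd; about 0.95 for a well-calibrated Gaussian
        "coverage_2sd": float(np.mean(abs_err <= 2.0 * stddevs)),
    }


summary = simple_uq_summary(
    targets=[0.0, 1.0, 2.0], means=[0.0, 1.5, 2.0], stddevs=[0.1, 0.1, 0.1]
)
print(summary)
```

With a trained MEGNetProbModel, `means, stddevs = model.predict(structures)` would supply the last two arguments.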
- classmethod load(save_path: PathLike, load_ckpt: bool = True) → MEGNetProbModel
Load a MEGNetProbModel from disk.
- Parameters
save_path – The path to the model’s save directory.
load_ckpt – Whether to load the best checkpoint’s weights, instead of those saved at the time of the last save().
- Returns
The loaded model.
- Raises
FileNotFoundError – If the save_path or any components do not exist.
- property nn_frozen: bool
Determine whether all NN layers are frozen.
- property norm_frozen: Optional[bool]
Determine whether the BatchNormalization layer is frozen, if it exists.
- predict(input: Union[Structure, Dict[str, Union[ndarray, List[Union[int, float]]]], List[Union[Structure, Dict[str, Union[ndarray, List[Union[int, float]]]]]]], batch_size: int = 128) → Tuple[ndarray, ndarray]
Predict target values and standard deviations for a given input.
- Parameters
input – The input structure(s).
batch_size – The batch size for predictions.
- Returns
The mean values of the predicted distribution(s).
The standard deviations of the predicted distribution(s).
Examples
Predict the formation energy of a binary compound with an approximate 95% confidence interval (two standard deviations) as the uncertainty estimate:
>>> from unlocknn.download import load_data, load_pretrained
>>> binary_model = load_pretrained("binary_e_form")
>>> binary_data = load_data("binary_e_form")
>>> example_struct = binary_data.loc[0, "structure"]
>>> prediction, stddev = binary_model.predict(example_struct)
>>> print(
...     "Predicted formation energy = "
...     f"{prediction.item():.3f} ± {stddev.item() * 2:.3f} eV."
... )
Predicted formation energy = -0.736 ± 0.054 eV.
- save(save_path: PathLike, ckpt_path: Optional[PathLike] = 'checkpoint.h5') → None
Save the model to disk.
- Parameters
save_path – The directory in which to save the model.
ckpt_path – Where to look for checkpoints, which will be copied over to the save directory for future usage. Specify ckpt_path=None if no checkpoints exist.
- set_frozen(layers: Union[Literal['NN', 'VGP', 'Norm'], List[Literal['NN', 'VGP', 'Norm']]], freeze: bool = True, recompile: bool = True, **compilation_kwargs) → None
Freeze or thaw probabilistic NN layers.
- Parameters
layers – Name or list of names of layers to freeze or thaw.
freeze – Whether to freeze (True) or thaw (False) the layers.
recompile – Whether to recompile the model after the operation.
**compilation_kwargs – Keyword arguments to pass to compile().
- Raises
ValueError – If one or more layers are invalid names.
- train(inputs: List[Structure], targets: List[Union[float, ndarray]], epochs: int, val_inputs: Optional[List[Union[Structure, Dict[str, Union[ndarray, List[Union[int, float]]]]]]] = None, val_targets: Optional[List[Union[float, ndarray]]] = None, callbacks: List[Callback] = [], use_default_ckpt_handler: bool = True, ckpt_path: PathLike = 'checkpoint.h5', batch_size: int = 32, scrub_failed_structs: bool = False, verbose: Literal[0, 1, 2] = 2)
Train the model.
- Parameters
inputs – A list of training crystal structures or graphs.
targets – A list of training target values.
epochs – The number of training epochs.
val_inputs – A list of validation crystal structures or graphs.
val_targets – A list of validation target values.
callbacks – A list of additional callbacks.
use_default_ckpt_handler – Whether to use the default checkpoint callback.
ckpt_path – Where to save checkpoints, if use_default_ckpt_handler=True.
batch_size – The batch size for training and validation.
scrub_failed_structs – Whether to discard structures that could not be converted to graphs.
verbose – The level of verbosity. See Keras’s documentation on Model.fit.
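Conceptually, scrub_failed_structs has to drop each failing structure together with its target so that inputs and targets stay aligned. The sketch below shows that idea with hypothetical helpers (`scrub_failed`, `toy_convert`) and plain integers standing in for crystal structures; it does not claim to reproduce how unlocknn or MEGNet's graph converter performs the filtering, and it deliberately does not assume which exception type a real converter raises.

```python
from typing import Any, Callable, List, Sequence, Tuple


def scrub_failed(
    inputs: Sequence[Any],
    targets: Sequence[float],
    convert: Callable[[Any], Any],
) -> Tuple[List[Any], List[float]]:
    """Hypothetical helper: drop inputs that fail conversion, keeping targets aligned."""
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length")
    kept_graphs, kept_targets = [], []
    for item, target in zip(inputs, targets):
        try:
            graph = convert(item)
        except Exception:  # the real converter's exception type is not assumed here
            continue
        kept_graphs.append(graph)
        kept_targets.append(target)
    return kept_graphs, kept_targets


def toy_convert(n_atoms: int) -> dict:
    """Hypothetical converter that rejects 'structures' with no atoms."""
    if n_atoms <= 0:
        raise ValueError("cannot build a graph from an empty structure")
    return {"n_atoms": n_atoms}


graphs, kept = scrub_failed([2, 0, 5], [-0.7, 0.1, -1.2], toy_convert)
print(graphs, kept)  # [{'n_atoms': 2}, {'n_atoms': 5}] [-0.7, -1.2]
```

Without scrubbing, a single unconvertible structure would otherwise have to be removed by hand before training.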
- property vgp_frozen: bool
Determine whether the VGP is frozen.