Probabilistic models

UnlockNN includes an extensible interface for adding uncertainty quantification to any trained Keras model using a variational Gaussian process (VGP). A VGP is a modification of a Gaussian process that vastly improves scalability: instead of using the entire training dataset as index points, it uses variational inference to learn a much smaller set of inducing index points that give a good approximation of the “full” Gaussian process.
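To see why a small set of inducing points can stand in for the whole training set, the sketch below builds a low-rank (Nyström) approximation of an RBF kernel matrix from M inducing points and measures how close it is to the full N × N matrix. This is not UnlockNN's code: it fixes the inducing points on a grid, whereas a VGP learns their locations by variational inference. The function names are hypothetical.

```python
import numpy as np


def rbf_kernel(a: np.ndarray, b: np.ndarray, lengthscale: float = 1.0) -> np.ndarray:
    """Squared-exponential (RBF) kernel between two sets of 1-D points."""
    sq_dist = (a[:, None] - b[None, :]) ** 2
    return np.exp(-0.5 * sq_dist / lengthscale**2)


def nystrom_approx(x: np.ndarray, z: np.ndarray, jitter: float = 1e-6) -> np.ndarray:
    """Approximate K(x, x) from inducing points z as K_xz K_zz^{-1} K_zx."""
    k_xz = rbf_kernel(x, z)
    k_zz = rbf_kernel(z, z) + jitter * np.eye(len(z))  # jitter for numerical stability
    return k_xz @ np.linalg.solve(k_zz, k_xz.T)


x = np.linspace(0.0, 10.0, 200)  # "training" index points
k_full = rbf_kernel(x, x)        # 200 x 200: grows quadratically with the data size

errors = {}
for m in (5, 10, 20):
    z = np.linspace(0.0, 10.0, m)  # m inducing points, far fewer than 200
    errors[m] = float(np.abs(k_full - nystrom_approx(x, z)).max())
    print(f"{m:>2} inducing points: max abs error = {errors[m]:.2e}")
```

A handful of well-placed inducing points already reproduces the full kernel closely, which is the property a VGP exploits.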

By replacing the last layer(s) of a Keras model with a VGP, the model’s predictions become Gaussian distributions rather than tensors. The mean of each distribution is the de facto “prediction” and its standard deviation quantifies the uncertainty. For example, the mean ± two standard deviations gives an approximate 95% confidence interval.
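Given predicted means and standard deviations as NumPy arrays, turning them into interval bounds is straightforward. The minimal sketch below uses made-up predictions and hypothetical helper names; it also shows that ±2σ covers about 95.4% of a Gaussian, while the exact 95% interval uses z ≈ 1.96.

```python
import math

import numpy as np


def gaussian_interval(mean: np.ndarray, stddev: np.ndarray, z: float = 2.0):
    """Return (lower, upper) bounds of the central interval mean ± z * stddev."""
    return mean - z * stddev, mean + z * stddev


def coverage(z: float) -> float:
    """Probability mass of a Gaussian within ± z standard deviations of its mean."""
    return math.erf(z / math.sqrt(2.0))


# Hypothetical predictions for three inputs.
means = np.array([-0.74, 0.12, 1.05])
stddevs = np.array([0.027, 0.050, 0.010])

lower, upper = gaussian_interval(means, stddevs)
print(f"±2σ coverage: {coverage(2.0):.4f}")      # ±2σ coverage: 0.9545
print(f"±1.96σ coverage: {coverage(1.96):.4f}")  # ±1.96σ coverage: 0.9500
```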

Important

The caveat of using a VGP is that the probabilistic model must undergo further training to determine the inducing index point locations and the kernel parameters, which can be expensive!

The module also provides MEGNetProbModel, an implementation specifically for MEGNet (a highly performant graph neural network model for materials property prediction).

The ProbNN API

To add uncertainty quantification to an arbitrary Keras model, create a class that inherits from ProbNN, implements ProbNN.train() and adds to CONFIG_VARS any variables that are needed to reload the model.
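One detail worth getting right when adding to CONFIG_VARS is to build a new list in the subclass rather than calling .append() on the inherited one, which would modify the base class's list for every subclass. Because the real ProbNN constructor needs a Keras model, the self-contained sketch below uses a hypothetical stand-in base class, `BaseProbNN`, whose CONFIG_VARS entries are assumptions; `ScaledProbNN`, `target_scale`, and `RecordingModel` are likewise illustrative only.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class BaseProbNN(ABC):
    """Hypothetical stand-in for unlocknn.model.ProbNN (illustration only)."""

    # Assumed entries; the real ProbNN.CONFIG_VARS may differ.
    CONFIG_VARS: List[str] = ["num_inducing_points", "kl_weight"]

    def __init__(self, model: Any, num_inducing_points: int, kl_weight: float = 1.0):
        self.model = model  # stands in for the underlying Keras model
        self.num_inducing_points = num_inducing_points
        self.kl_weight = kl_weight

    @abstractmethod
    def train(self, *args, **kwargs) -> None:
        """Process data, then call self.model.fit."""


class ScaledProbNN(BaseProbNN):
    """Example subclass with one extra attribute that must survive save/load."""

    # Build a new list: .append() on the inherited list would also change
    # BaseProbNN.CONFIG_VARS, since both names would refer to the same object.
    CONFIG_VARS = BaseProbNN.CONFIG_VARS + ["target_scale"]

    def __init__(
        self, model: Any, num_inducing_points: int, target_scale: float = 1.0, **kwargs
    ):
        super().__init__(model, num_inducing_points, **kwargs)
        self.target_scale = target_scale

    def train(self, inputs, targets, epochs: int, **fit_kwargs) -> None:
        # Data processing first: rescale the targets...
        scaled_targets = [t / self.target_scale for t in targets]
        # ...then delegate to the underlying model's fit().
        self.model.fit(inputs, scaled_targets, epochs=epochs, **fit_kwargs)


class RecordingModel:
    """Hypothetical stand-in for a Keras model that records fit() calls."""

    def __init__(self):
        self.fit_calls = []

    def fit(self, x, y, epochs: int, **kwargs) -> None:
        self.fit_calls.append((x, y, epochs))


prob_nn = ScaledProbNN(RecordingModel(), num_inducing_points=10, target_scale=2.0)
prob_nn.train([[0.0], [1.0]], [4.0, 6.0], epochs=3)
print(ScaledProbNN.CONFIG_VARS)  # ['num_inducing_points', 'kl_weight', 'target_scale']
print(BaseProbNN.CONFIG_VARS)    # ['num_inducing_points', 'kl_weight']
```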

class unlocknn.model.ProbNN(nn: ~keras.engine.training.Model, num_inducing_points: int, kernel: ~unlocknn.kernel_layers.KernelLayer = <unlocknn.kernel_layers.RBFKernelFn object>, latent_layer: ~typing.Union[str, int] = -2, target_shape: ~typing.Union[~typing.Tuple[int], int] = 1, metrics: ~typing.List[~typing.Union[str, ~keras.metrics.Metric]] = [], kl_weight: float = 1.0, optimizer: ~keras.optimizer_v2.optimizer_v2.OptimizerV2 = <keras.optimizer_v2.adam.Adam object>, index_initializer: ~typing.Optional[~keras.initializers.initializers_v2.Initializer] = None, use_normalization: bool = True)

Wrapper for creating a probabilistic NN model.

Parameters
  • nn – The base NN model to modify.

  • num_inducing_points – The number of inducing index points for the VGP.

  • kernel – A KernelLayer for the VGP to use.

  • latent_layer – The name or index of the layer of the NN to be fed into the VGP.

  • target_shape – The shape of the target values.

  • metrics – A list of metrics to record during training.

  • kl_weight – The relative weighting of the Kullback-Leibler divergence in the loss function.

  • optimizer – The model optimizer, needed for recompilation.

  • index_initializer – A custom initializer to use for the VGP index points.

  • use_normalization – Whether to use a BatchNormalization layer before the VGP. Recommended for better training efficiency.

  • compile – Whether to compile the model for training. Not needed when loading the model for inference only.
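As a rough illustration of what kl_weight controls, the sketch below computes the closed-form KL divergence between two multivariate Gaussians (for a VGP, typically the variational distribution over the inducing-point values and their prior) and adds it, scaled by a weight, to a data-fit term. The function names and the simple `nll + kl_weight * kl` combination are assumptions for illustration; UnlockNN's actual loss may scale the terms differently.

```python
import numpy as np


def gaussian_kl(mu_q, cov_q, mu_p, cov_p) -> float:
    """KL(q || p) between multivariate Gaussians N(mu_q, cov_q) and N(mu_p, cov_p)."""
    mu_q, mu_p = np.asarray(mu_q, float), np.asarray(mu_p, float)
    cov_q, cov_p = np.asarray(cov_q, float), np.asarray(cov_p, float)
    k = mu_q.size
    diff = mu_p - mu_q
    trace_term = np.trace(np.linalg.solve(cov_p, cov_q))
    mahalanobis = diff @ np.linalg.solve(cov_p, diff)
    _, logdet_p = np.linalg.slogdet(cov_p)
    _, logdet_q = np.linalg.slogdet(cov_q)
    return float(0.5 * (trace_term + mahalanobis - k + logdet_p - logdet_q))


def weighted_loss(nll: float, kl: float, kl_weight: float = 1.0) -> float:
    """Hypothetical loss: data-fit term plus the KL penalty scaled by kl_weight."""
    return nll + kl_weight * kl


# Variational distribution over 2 inducing values vs. a standard normal prior.
kl = gaussian_kl([0.5, -0.5], np.diag([0.5, 0.5]), np.zeros(2), np.eye(2))
for w in (0.0, 0.1, 1.0):
    print(f"kl_weight={w}: loss = {weighted_loss(2.0, kl, w):.4f}")
```

A smaller kl_weight lets the variational distribution drift further from the prior in exchange for a better data fit; kl_weight=0 ignores the prior entirely.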

CONFIG_VARS

A list of attribute names, as strings, to include in metadata when saving. These variables are saved in a config.json file and, when the model is re-instantiated by load(), passed to the constructor as keyword arguments.

Type

List[str]
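The config.json mechanism described above can be sketched with the standard library alone: collect each attribute named in CONFIG_VARS into a dict, write it as JSON, and later pass it back to the constructor as keyword arguments. `TinyModel`, `save_config`, and `load_config` are hypothetical; the real save() also handles weights and checkpoints.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List


class TinyModel:
    """Hypothetical model whose constructor arguments are listed in CONFIG_VARS."""

    CONFIG_VARS: List[str] = ["num_inducing_points", "kl_weight"]

    def __init__(self, num_inducing_points: int, kl_weight: float = 1.0):
        self.num_inducing_points = num_inducing_points
        self.kl_weight = kl_weight

    @property
    def config(self) -> Dict[str, Any]:
        # Gather the attributes named in CONFIG_VARS.
        return {name: getattr(self, name) for name in self.CONFIG_VARS}


def save_config(model: TinyModel, save_dir: Path) -> None:
    save_dir.mkdir(parents=True, exist_ok=True)
    (save_dir / "config.json").write_text(json.dumps(model.config))


def load_config(cls, save_dir: Path):
    # Re-instantiate by passing the saved values as keyword arguments.
    config = json.loads((save_dir / "config.json").read_text())
    return cls(**config)


with tempfile.TemporaryDirectory() as tmp:
    original = TinyModel(num_inducing_points=500, kl_weight=0.5)
    save_config(original, Path(tmp) / "model")
    restored = load_config(TinyModel, Path(tmp) / "model")

print(restored.config)  # {'num_inducing_points': 500, 'kl_weight': 0.5}
```

One consequence of this design is that every attribute listed in CONFIG_VARS must be JSON-serializable (numbers, strings, lists, dicts).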

ckpt_callback(ckpt_path: PathLike = 'checkpoint.h5')

Get the default configuration for a model checkpoint callback.

compile(new_kl_weight: ~typing.Optional[float] = None, optimizer: ~keras.optimizer_v2.optimizer_v2.OptimizerV2 = <keras.optimizer_v2.adam.Adam object>, new_metrics: ~typing.Optional[~typing.List[~typing.Union[str, ~keras.metrics.Metric]]] = None)

Compile the probabilistic NN.

Recompilation is required whenever layers are (un)frozen.

Parameters
  • new_kl_weight – The relative weighting of the Kullback-Leibler divergence in the loss function. Default (None) is to leave unchanged.

  • optimizer – The model optimizer, needed for recompilation.

  • new_metrics – New metrics with which to compile.

property config: Dict[str, Any]

Get the configuration parameters needed to save to disk.

classmethod load(save_path: PathLike, load_ckpt: bool = True, **kwargs) → ProbNN

Load a ProbNN from disk.

Loaded models must be recompiled before training.

Parameters
  • save_path – The path to the model’s save directory.

  • load_ckpt – Whether to load the best checkpoint’s weights, instead of those saved at the time of the last save().

  • **kwargs – Keyword arguments required by subclasses.

Returns

The loaded model.

Raises

FileNotFoundError – If the save_path or any components do not exist.

property nn_frozen: bool

Determine whether all NN layers are frozen.

property norm_frozen: Optional[bool]

Determine whether the BatchNormalization layer is frozen, if it exists.

predict(input) → ndarray

Predict target values and standard deviations for a given input.

Parameters

input – The input(s) to the model.

Returns

A NumPy array containing the predicted means and standard deviations.

save(save_path: PathLike, ckpt_path: Optional[PathLike] = 'checkpoint.h5')

Save the model to disk.

Parameters
  • save_path – The directory in which to save the model.

  • ckpt_path – Where to look for checkpoints, which will be copied over to the save directory for future usage. Specify ckpt_path=None if no checkpoints exist.

save_kernel(kernel_save_path: Path)

Save the VGP’s kernel to disk.

set_frozen(layers: Union[Literal['NN', 'VGP', 'Norm'], List[Literal['NN', 'VGP', 'Norm']]], freeze: bool = True, recompile: bool = True, **compilation_kwargs) → None

Freeze or thaw probabilistic NN layers.

Parameters
  • layers – Name or list of names of layers to freeze or thaw.

  • freeze – Whether to freeze (True) or thaw (False) the layers.

  • recompile – Whether to recompile the model after the operation.

  • **compilation_kwargs – Keyword arguments to pass to compile().

Raises

ValueError – If one or more layers are invalid names.
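The accepted layer names and the ValueError behaviour can be mirrored in a small helper. The sketch below assumes that set_frozen() normalises a single name to a list and then checks every entry; `normalize_layer_names` is hypothetical and not part of UnlockNN.

```python
from typing import List, Sequence, Union

VALID_LAYER_NAMES = ("NN", "VGP", "Norm")


def normalize_layer_names(layers: Union[str, Sequence[str]]) -> List[str]:
    """Turn a single name or a list of names into a validated list."""
    names = [layers] if isinstance(layers, str) else list(layers)
    invalid = [name for name in names if name not in VALID_LAYER_NAMES]
    if invalid:
        raise ValueError(
            f"Invalid layer name(s) {invalid}; expected one of {VALID_LAYER_NAMES}."
        )
    return names


print(normalize_layer_names("NN"))             # ['NN']
print(normalize_layer_names(["VGP", "Norm"]))  # ['VGP', 'Norm']
try:
    normalize_layer_names(["NN", "vgp"])  # names are case-sensitive in this sketch
except ValueError as err:
    print(err)
```

With a real model, one plausible schedule is to freeze the NN while the VGP is first fitted, e.g. `prob_model.set_frozen("NN")`, and later thaw it with `prob_model.set_frozen("NN", freeze=False)`; because recompile defaults to True, no separate compile() call is needed.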

abstract train(*args, **kwargs) → None

Train the model.

This method should handle data processing, then call self.model.fit to train the underlying model.

update_pred_model(force_new: bool = False) → None

Instantiate or update the predictor model.

The predictor model is saved in pred_model. This method works around the fact that a model returning the VGP distribution’s mean and standard deviation simultaneously cannot be saved or trained.

This method creates a clone model, so it must be called before making a prediction whenever the model’s weights have changed. By default, the method first checks whether the existing pred_model’s weights match the model’s weights, and skips cloning if they do. Setting force_new=True skips this check.

Parameters

force_new – Whether to force the creation of a new model, skipping the weights equality check.
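The “skip if unchanged” check can be sketched with NumPy by comparing each pair of weight arrays (in Keras, Model.get_weights() returns such a list) and rebuilding only when something differs. Both helpers are hypothetical; the sketch uses np.allclose, which treats tiny floating-point differences as equal, and the tolerance UnlockNN actually uses is not documented here.

```python
from typing import Optional, Sequence

import numpy as np


def weights_match(current: Sequence[np.ndarray], cached: Sequence[np.ndarray]) -> bool:
    """Return True if two lists of weight arrays have equal shapes and values."""
    if len(current) != len(cached):
        return False
    return all(
        a.shape == b.shape and np.allclose(a, b) for a, b in zip(current, cached)
    )


def needs_rebuild(
    current: Sequence[np.ndarray],
    cached: Optional[Sequence[np.ndarray]],
    force_new: bool = False,
) -> bool:
    """Decide whether a (hypothetical) prediction model must be re-cloned."""
    return force_new or cached is None or not weights_match(current, cached)


model_weights = [np.ones((3, 2)), np.zeros(2)]
cached_weights = [w.copy() for w in model_weights]

print(needs_rebuild(model_weights, cached_weights))                  # False
print(needs_rebuild(model_weights, cached_weights, force_new=True))  # True
model_weights[1] += 0.5  # simulate a training step changing the weights
print(needs_rebuild(model_weights, cached_weights))                  # True
```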

property vgp_frozen: bool

Determine whether the VGP is frozen.

MEGNet probabilistic model

Use MEGNetProbModel to add uncertainty quantification to a MEGNetModel:

class unlocknn.model.MEGNetProbModel(meg_model: ~megnet.models.megnet.MEGNetModel, num_inducing_points: int, kernel: ~unlocknn.kernel_layers.KernelLayer = <unlocknn.kernel_layers.RBFKernelFn object>, latent_layer: ~typing.Union[str, int] = -2, target_shape: ~typing.Optional[~typing.Union[~typing.Tuple[int], int]] = None, metrics: ~typing.List[~typing.Union[str, ~keras.metrics.Metric]] = [], kl_weight: float = 1.0, optimizer: ~keras.optimizer_v2.optimizer_v2.OptimizerV2 = <keras.optimizer_v2.adam.Adam object>, index_initializer: ~typing.Optional[~keras.initializers.initializers_v2.Initializer] = None, use_normalization: bool = True, **kwargs)

ProbNN for MEGNetModels.

Parameters
  • meg_model – The base MEGNetModel to modify.

  • num_inducing_points – The number of inducing index points for the VGP.

  • kernel – A KernelLayer for the VGP to use.

  • latent_layer – The name or index of the layer of the NN to be fed into the VGP.

  • target_shape – The shape of the target values.

  • metrics – A list of metrics to record during training.

  • kl_weight – The relative weighting of the Kullback-Leibler divergence in the loss function.

  • optimizer – The model optimizer, needed for recompilation.

  • index_initializer – A custom initializer to use for the VGP index points. See also unlocknn.initializers.

  • use_normalization – Whether to use a BatchNormalization layer before the VGP. Recommended for better training efficiency.

Warning

The recorded metrics do not work correctly and may give vastly incorrect values; use unlocknn.metrics.evaluate_uq_metrics() instead!

evaluate(eval_structs: List[Structure], eval_targets: List[Union[float, ndarray]], batch_size: int = 128, scrub_failed_structs: bool = False) → Dict[str, float]

Evaluate model metrics.

Warning

This method does not work correctly and may give vastly incorrect values; use unlocknn.metrics.evaluate_uq_metrics() instead!

Parameters
  • eval_structs – Structures on which to evaluate performance.

  • eval_targets – True target values for structures.

  • batch_size – The batch size for evaluation.

  • scrub_failed_structs – Whether to discard structures that could not be converted to graphs.

Returns

Dictionary of {metric: value}.
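As an independent cross-check while the built-in metrics are unreliable, the sketch below computes three common quantities directly from predicted means and standard deviations: the mean absolute error, the mean Gaussian negative log-likelihood, and the fraction of true values inside the ±2σ interval. It is not the implementation of unlocknn.metrics.evaluate_uq_metrics(), whose exact metric set is not described here; the function name and example numbers are hypothetical.

```python
import numpy as np


def gaussian_uq_summary(targets, means, stddevs) -> dict:
    """Basic accuracy and uncertainty metrics for Gaussian predictions."""
    y = np.asarray(targets, dtype=float)
    mu = np.asarray(means, dtype=float)
    sigma = np.asarray(stddevs, dtype=float)
    residual = y - mu
    # Per-sample negative log-likelihood of y under N(mu, sigma^2).
    nll = 0.5 * np.log(2 * np.pi * sigma**2) + residual**2 / (2 * sigma**2)
    return {
        "mae": float(np.mean(np.abs(residual))),
        "mean_nll": float(np.mean(nll)),
        "coverage_2sigma": float(np.mean(np.abs(residual) <= 2 * sigma)),
    }


summary = gaussian_uq_summary(
    targets=[-0.70, 0.10, 1.00, 0.50],
    means=[-0.74, 0.12, 1.05, 0.20],
    stddevs=[0.03, 0.05, 0.01, 0.10],
)
print(summary)
```

Coverage well below about 0.95 (here 0.5) suggests the predicted standard deviations are too small, i.e. the model is overconfident.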

classmethod load(save_path: PathLike, load_ckpt: bool = True) → MEGNetProbModel

Load a MEGNetProbModel from disk.

Parameters
  • save_path – The path to the model’s save directory.

  • load_ckpt – Whether to load the best checkpoint’s weights, instead of those saved at the time of the last save().

Returns

The loaded model.

Raises

FileNotFoundError – If the save_path or any components do not exist.

property nn_frozen: bool

Determine whether all NN layers are frozen.

property norm_frozen: Optional[bool]

Determine whether the BatchNormalization layer is frozen, if it exists.

predict(input: Union[Structure, Dict[str, Union[ndarray, List[Union[int, float]]]], List[Union[Structure, Dict[str, Union[ndarray, List[Union[int, float]]]]]]], batch_size: int = 128) → Tuple[ndarray, ndarray]

Predict target values and standard deviations for a given input.

Parameters
  • input – The input structure(s).

  • batch_size – The batch size for predictions.

Returns

The mean values of the predicted distribution(s).

The standard deviations of the predicted distribution(s).

Examples

Predict the formation energy of a binary compound, with an uncertainty estimate of two standard deviations (approximately a 95% confidence interval):

>>> from unlocknn.download import load_data, load_pretrained
>>> binary_model = load_pretrained("binary_e_form")
>>> binary_data = load_data("binary_e_form")
>>> example_struct = binary_data.loc[0, "structure"]
>>> prediction, stddev = binary_model.predict(example_struct)
>>> print(
...     "Predicted formation energy = "
...     f"{prediction.item():.3f} ± {stddev.item() * 2:.3f} eV."
... )
Predicted formation energy = -0.736 ± 0.054 eV.

save(save_path: PathLike, ckpt_path: Optional[PathLike] = 'checkpoint.h5') → None

Save the model to disk.

Parameters
  • save_path – The directory in which to save the model.

  • ckpt_path – Where to look for checkpoints, which will be copied over to the save directory for future usage. Specify ckpt_path=None if no checkpoints exist.

set_frozen(layers: Union[Literal['NN', 'VGP', 'Norm'], List[Literal['NN', 'VGP', 'Norm']]], freeze: bool = True, recompile: bool = True, **compilation_kwargs) → None

Freeze or thaw probabilistic NN layers.

Parameters
  • layers – Name or list of names of layers to freeze or thaw.

  • freeze – Whether to freeze (True) or thaw (False) the layers.

  • recompile – Whether to recompile the model after the operation.

  • **compilation_kwargs – Keyword arguments to pass to compile().

Raises

ValueError – If one or more layers are invalid names.

train(inputs: List[Structure], targets: List[Union[float, ndarray]], epochs: int, val_inputs: Optional[List[Union[Structure, Dict[str, Union[ndarray, List[Union[int, float]]]]]]] = None, val_targets: Optional[List[Union[float, ndarray]]] = None, callbacks: List[Callback] = [], use_default_ckpt_handler: bool = True, ckpt_path: PathLike = 'checkpoint.h5', batch_size: int = 32, scrub_failed_structs: bool = False, verbose: Literal[0, 1, 2] = 2)

Train the model.

Parameters
  • inputs – A list of training crystal structures or graphs.

  • targets – A list of training target values.

  • epochs – The number of training epochs.

  • val_inputs – A list of validation crystal structures or graphs.

  • val_targets – A list of validation target values.

  • callbacks – A list of additional callbacks.

  • use_default_ckpt_handler – Whether to use the default checkpoint callback.

  • ckpt_path – Where to save checkpoints, if use_default_ckpt_handler=True.

  • batch_size – The batch size for training and validation.

  • scrub_failed_structs – Whether to discard structures that could not be converted to graphs.

  • verbose – The level of verbosity. See Keras’s documentation on Model.fit.
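When structures that cannot be converted to graphs are discarded, the matching targets must be dropped too, so that inputs and targets stay aligned. The sketch below shows that bookkeeping, assuming a converter that raises an exception on failure; `scrub_failed` and `to_graph` are hypothetical stand-ins, and the real graph conversion performed by MEGNet is not shown.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

S = TypeVar("S")
G = TypeVar("G")


def scrub_failed(
    structures: Sequence[S],
    targets: Sequence[float],
    convert: Callable[[S], G],
) -> Tuple[List[G], List[float], List[int]]:
    """Convert structures, dropping failures together with their targets.

    Returns the converted graphs, the aligned targets and the failed indices.
    """
    graphs, kept_targets, failed = [], [], []
    for i, (struct, target) in enumerate(zip(structures, targets)):
        try:
            graphs.append(convert(struct))
        except Exception:  # the real failure type depends on the converter
            failed.append(i)
            continue
        kept_targets.append(target)
    return graphs, kept_targets, failed


def to_graph(struct: str) -> dict:
    """Hypothetical converter that fails on one particular input."""
    if struct == "unconvertible":
        raise RuntimeError("cannot convert structure")
    return {"name": struct}


graphs, kept, failed = scrub_failed(
    ["NaCl", "unconvertible", "MgO"], [-2.1, 0.3, -3.0], to_graph
)
print(failed)  # [1]
print(kept)    # [-2.1, -3.0]
```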

property vgp_frozen: bool

Determine whether the VGP is frozen.