# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
This file contains a readily usable implementation of the robust Gaussian process
model of [Ament2024pursuit]_, leveraging the Relevance Pursuit algorithm.
In particular, this file contains a ``RobustRelevancePursuitMixin`` class,
and a concrete implementation of a ``SingleTaskGP`` model,
``RobustRelevancePursuitSingleTaskGP``, which has the same API as a standard
``SingleTaskGP`` model, but automatically instantiates the robust likelihood
``SparseOutlierGaussianLikelihood`` and dispatches the relevance pursuit
algorithm during model fitting via ``fit_gpytorch_mll``.
Even though a standard ``SingleTaskGP`` model is expressive enough to implement
the robust model by changing the likelihood, its optimization is more complex.
So the main reason for the ``RobustRelevancePursuitMixin`` class is to hide
this complexity by using multiple dispatch of ``fit_gpytorch_mll``, which needs
to do two distinct operations in the context of the robust model:
    (1) It needs to toggle the relevance pursuit discrete optimization algorithm
        that changes the support, and as a sub-task,
    (2) it still needs to carry out the numerical optimization of the
        hyper-parameters given a fixed support, but still with a
        ``SparseOutlierGaussianLikelihood``.

Since the types of the marginal likelihood (``MarginalLogLikelihood``) and the
likelihood (``SparseOutlierGaussianLikelihood``) are the same in both calls, the
only way we can leverage the multiple dispatch mechanism is via the model type.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Self
import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.models import SingleTaskGP
from botorch.models.likelihoods.sparse_outlier_noise import (
SparseOutlierGaussianLikelihood,
SparseOutlierNoise,
)
from botorch.models.model import Model
from botorch.models.relevance_pursuit import (
backward_relevance_pursuit,
get_posterior_over_support,
)
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.utils.types import _DefaultType, DEFAULT
from gpytorch.likelihoods import (
FixedNoiseGaussianLikelihood,
GaussianLikelihood,
Likelihood,
)
from gpytorch.means.mean import Mean
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.module import Module
from torch import Tensor
# Default grid of outlier fractions considered during relevance pursuit. When
# ``custom_fit`` is called without ``numbers_of_outliers``, each fraction is
# multiplied by the number of training points and truncated to an integer.
FRACTIONS_OF_OUTLIERS: list[float] = [
    0.0, 0.05, 0.1, 0.15, 0.2,
    0.3, 0.4, 0.5, 0.75, 1.0,
]  # fmt: skip
class RobustRelevancePursuitMixin(ABC):
    """A Mixin class for robust relevance pursuit models, which wraps a base
    likelihood with a ``SparseOutlierGaussianLikelihood`` to detect outliers, and
    calls the relevance pursuit algorithm during model fitting via
    ``fit_gpytorch_mll``.

    This is distinct from the ``RelevancePursuitMixin`` class, which is a Mixin
    class to equip a specific module (the likelihood, in the case of the robust
    model) with the relevance pursuit algorithms.
    """

    def __init__(
        self,
        base_likelihood: GaussianLikelihood | FixedNoiseGaussianLikelihood,
        dim: int,
        prior_mean_of_support: float | None = None,
        convex_parameterization: bool = True,
        cache_model_trace: bool = False,
    ) -> None:
        """Initializes a robust relevance pursuit model, which wraps a base
        likelihood with a ``SparseOutlierGaussianLikelihood`` to detect outliers,
        and calls the relevance pursuit algorithm during model fitting via
        ``fit_gpytorch_mll``.

        For details, see [Ament2024pursuit]_ or https://arxiv.org/abs/2410.24222.

        Args:
            base_likelihood: The base likelihood that will be wrapped by a
                ``SparseOutlierGaussianLikelihood`` to detect outliers. Only its
                ``noise_covar`` is reused as the base noise model.
            dim: The number of training data points, i.e. the maximum
                dimensionality of the support set of the likelihood.
            prior_mean_of_support: The mean value for the default exponential
                prior distribution over the support size. If None, defaults to
                ``int(0.2 * dim)``.
            convex_parameterization: If True, use a convex parameterization of
                the sparse noise model. See ``SparseOutlierGaussianLikelihood``
                for details.
            cache_model_trace: If True, cache the model trace during relevance
                pursuit.
        """
        self.likelihood = SparseOutlierGaussianLikelihood(
            base_noise=base_likelihood.noise_covar,
            dim=dim,
            convex_parameterization=convex_parameterization,
        )
        # Results of the Bayesian model comparison; populated by ``custom_fit``.
        self.bmc_support_sizes: Tensor | None = None
        self.bmc_probabilities: Tensor | None = None
        self.cache_model_trace = cache_model_trace
        # Only populated by ``custom_fit`` when ``cache_model_trace`` is True.
        self.model_trace: list[SingleTaskGP] | None = None
        self.prior_mean_of_support: float = (
            int(0.2 * dim) if prior_mean_of_support is None else prior_mean_of_support
        )

    @abstractmethod
    def to_standard_model(self) -> Model:
        """Converts this ``RobustRelevancePursuitMixin`` to an equivalent
        standard model with the same robust likelihood and hyper-parameters. This
        leaves the model structure and predictions unchanged, but leads
        ``fit_gpytorch_mll``'s dispatch to *numerically* optimize the
        hyper-parameters of the model with a fixed support set, as opposed to
        dispatching to the discrete optimization via the relevance pursuit
        algorithm.

        Returns:
            A standard model.
        """

    def load_standard_model(self, standard_model: Model) -> Self:
        """Loads the state dict of a model into the ``RobustRelevancePursuitMixin``.

        Args:
            standard_model: A standard model with the same parameter structure and
                likelihood as the ``RobustRelevancePursuitMixin`` model.

        Returns:
            The ``RobustRelevancePursuitMixin`` with the standard model's state
            dict.
        """
        # need special case for the likelihood because raw_rho's shape changes
        # throughout the optimization
        self.likelihood = standard_model.likelihood
        # overwrite state_dict in place
        self.load_state_dict(standard_model.state_dict())
        return self

    def custom_fit(
        self,
        mll: MarginalLogLikelihood,
        *,
        numbers_of_outliers: list[int] | None = None,
        fractions_of_outliers: list[float] | None = None,
        timeout_sec: float | None = None,
        relevance_pursuit_optimizer: Callable = backward_relevance_pursuit,
        reset_parameters: bool = True,
        reset_dense_parameters: bool = False,
        closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
        optimizer: Callable | None = None,
        closure_kwargs: dict[str, Any] | None = None,
        optimizer_kwargs: Mapping[str, Any] | None = None,
    ) -> MarginalLogLikelihood:
        """Fits a RobustRelevancePursuitGP model using the given marginal likelihood.

        For details, see [Ament2024pursuit]_ or https://arxiv.org/abs/2410.24222.

        Args:
            mll: The marginal likelihood to fit. Its ``model`` attribute is
                temporarily replaced by ``mll.model.to_standard_model()`` while
                relevance pursuit runs, and is restored before returning, also
                when an exception is raised.
            numbers_of_outliers: An optional list of numbers of outliers to
                consider during relevance pursuit. By default, the algorithm
                falls back to a default list of fractions of outliers, see below.
            fractions_of_outliers: An optional list of fractions of outliers to
                consider if numbers_of_outliers is None. By default, the
                algorithm uses the module-level ``FRACTIONS_OF_OUTLIERS``. Each
                fraction is multiplied by the number of training points and
                truncated to an integer.
            timeout_sec: An optional total time budget in seconds. It is divided
                evenly by the number of entries in ``numbers_of_outliers`` and
                passed on as the ``timeout_sec`` entry of ``optimizer_kwargs``,
                overriding any value already present under that key.
            relevance_pursuit_optimizer: The relevance pursuit optimizer to use.
            reset_parameters: If True, reset sparse parameters after each
                iteration.
            reset_dense_parameters: If True, reset dense parameters after each
                iteration.
            closure: A closure to compute loss and gradients.
            optimizer: The numerical optimizer.
            closure_kwargs: Additional arguments to pass to the closure.
            optimizer_kwargs: Additional arguments to pass to fit_gpytorch_mll.
                The mapping is copied and never mutated.

        Returns:
            The fitted marginal likelihood.

        Raises:
            UnsupportedError: If ``mll`` is an approximate marginal likelihood.
        """
        if isinstance(mll, _ApproximateMarginalLogLikelihood):
            raise UnsupportedError(
                "Relevance Pursuit does not yet support approximate inference. "
            )

        sparse_module = SparseOutlierNoise._from_model(mll.model)
        n = sparse_module.dim  # equal to the number of training data points

        if numbers_of_outliers is None:
            if fractions_of_outliers is None:
                fractions_of_outliers = FRACTIONS_OF_OUTLIERS
            # list from which BMC chooses
            numbers_of_outliers = [int(p * n) for p in fractions_of_outliers]

        # Copy so that the caller's mapping is never modified.
        optimizer_kwargs_: dict[str, Any] = (
            {} if optimizer_kwargs is None else dict(optimizer_kwargs)
        )
        if timeout_sec is not None:
            # Split the budget across sparsity levels; the ``max`` guards against
            # a ZeroDivisionError when an empty list of levels is passed.
            num_levels = max(len(numbers_of_outliers), 1)
            optimizer_kwargs_["timeout_sec"] = timeout_sec / num_levels

        # Need to convert model to avoid recursion through fit_gpytorch_mll,
        # since relevance pursuit expects to call the base fit_gpytorch_mll.
        original_model = mll.model  # Robust Relevance Pursuit Model
        mll.model = original_model.to_standard_model()
        try:
            sparse_module = SparseOutlierNoise._from_model(mll.model)
            sparse_module, model_trace = relevance_pursuit_optimizer(
                sparse_module=sparse_module,
                mll=mll,
                sparsity_levels=numbers_of_outliers,
                reset_parameters=reset_parameters,
                reset_dense_parameters=reset_dense_parameters,
                # The trace is always needed for the model comparison below.
                record_model_trace=True,
                # These are the args of the canonical mll fit routine
                closure=closure,
                optimizer=optimizer,
                closure_kwargs=closure_kwargs,
                optimizer_kwargs=optimizer_kwargs_,
            )

            # Bayesian model comparison
            bmc_support_sizes, bmc_probabilities = get_posterior_over_support(
                SparseOutlierNoise,
                model_trace,
                prior_mean_of_support=original_model.prior_mean_of_support,
            )
        finally:
            # Always restore the original model pointer, so that a failure during
            # optimization does not leave ``mll`` attached to the temporary
            # standard model.
            mll.model = original_model

        map_index = torch.argmax(bmc_probabilities)
        map_model = model_trace[map_index]  # choosing model with highest BMC score

        # overwrite the original model's state with the chosen model
        # NOTE(review): ``load_standard_model`` re-points the model's likelihood,
        # but ``mll.likelihood`` is left untouched here - confirm whether callers
        # rely on the two being the same object after fitting.
        original_model.load_standard_model(map_model)

        # Store the bmc results
        original_model.bmc_support_sizes = bmc_support_sizes
        original_model.bmc_probabilities = bmc_probabilities
        if original_model.cache_model_trace:
            original_model.model_trace = model_trace
        return mll
class RobustRelevancePursuitSingleTaskGP(SingleTaskGP, RobustRelevancePursuitMixin):
    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        train_Yvar: Tensor | None = None,
        likelihood: Likelihood | None = None,
        covar_module: Module | None = None,
        mean_module: Mean | None = None,
        outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
        input_transform: InputTransform | None = None,
        convex_parameterization: bool = True,
        prior_mean_of_support: float | None = None,
        cache_model_trace: bool = False,
    ) -> None:
        r"""A robust single-task GP model that toggles the relevance pursuit
        algorithm during model fitting via ``fit_gpytorch_mll``.

        For details, see [Ament2024pursuit]_ or https://arxiv.org/abs/2410.24222.

        Args:
            train_X: A ``batch_shape x n x d`` tensor of training features.
            train_Y: A ``batch_shape x n x m`` tensor of training observations.
            train_Yvar: An optional ``batch_shape x n x m`` tensor of observed
                measurement noise.
            likelihood: A base likelihood that will be wrapped by a
                ``SparseOutlierGaussianLikelihood`` to detect outliers. If
                omitted, use a standard ``GaussianLikelihood`` with inferred
                noise level if ``train_Yvar`` is None, and a
                ``FixedNoiseGaussianLikelihood`` with the given noise
                observations if ``train_Yvar`` is not None.
            covar_module: The module computing the covariance (Kernel) matrix.
                If omitted, uses an ``RBFKernel``.
            mean_module: The mean function to be used. If omitted, use a
                ``ConstantMean``.
            outcome_transform: An outcome transform that is applied to the
                training data during instantiation and to the posterior during
                inference (that is, the ``Posterior`` obtained by calling
                ``.posterior`` on the model will be on the original scale). We
                use a ``Standardize`` transform if no ``outcome_transform`` is
                specified. Pass down ``None`` to use no outcome transform.
            input_transform: An input transform that is applied in the model's
                forward pass.
            convex_parameterization: If True, use a convex parameterization of
                the sparse noise model. See ``SparseOutlierGaussianLikelihood``
                for details.
            prior_mean_of_support: The mean value for the default exponential
                prior distribution over the support size.
            cache_model_trace: If True, cache the model trace during relevance
                pursuit.

        Example:
            >>> m = RobustRelevancePursuitSingleTaskGP(train_X=X, train_Y=Y)
            >>> mll = ExactMarginalLogLikelihood(model=m, likelihood=m.likelihood)
            >>> mll = fit_gpytorch_mll(mll)
        """
        # Keep the data exactly as passed in; ``to_standard_model`` feeds these
        # to a fresh ``SingleTaskGP`` constructor.
        self._original_X = train_X
        self._original_Y = train_Y

        # Step 1: build the canonical GP.
        super().__init__(
            train_X=train_X,
            train_Y=train_Y,
            train_Yvar=train_Yvar,
            likelihood=likelihood,
            covar_module=covar_module,
            mean_module=mean_module,
            outcome_transform=outcome_transform,
            input_transform=input_transform,
        )

        # Step 2: replace the likelihood created in step 1 with the robust one,
        # which has one potential outlier slot per training point.
        num_train_points = train_X.shape[-2]
        RobustRelevancePursuitMixin.__init__(
            self,
            base_likelihood=self.likelihood,
            dim=num_train_points,
            prior_mean_of_support=prior_mean_of_support,
            convex_parameterization=convex_parameterization,
            cache_model_trace=cache_model_trace,
        )

    def to_standard_model(self) -> Model:
        """Builds a plain ``SingleTaskGP`` with the same modules as this model.

        The returned model is constructed from this model's likelihood,
        covariance module, mean module and transforms, and from the training
        data cached at construction time. Its type is what keeps
        ``fit_gpytorch_mll`` from recursing into the relevance pursuit dispatch.

        Returns:
            A ``SingleTaskGP``, put into eval mode if this model is not in
            training mode.
        """
        # Record the mode up front; the cached original inputs mean the model
        # does not have to be switched to training mode to rebuild it.
        was_training = self.training

        # The transforms are looked up defensively, falling back to None.
        outcome_tf = getattr(self, "outcome_transform", None)
        input_tf = getattr(self, "input_transform", None)

        standard_model = SingleTaskGP(
            train_X=self._original_X,
            train_Y=self._original_Y,
            train_Yvar=None,  # not needed because likelihood is already instantiated
            likelihood=self.likelihood,
            covar_module=self.covar_module,
            mean_module=self.mean_module,
            outcome_transform=outcome_tf,
            input_transform=input_tf,
        )
        if not was_training:
            standard_model.eval()
        return standard_model