Source code for botorch.models.fully_bayesian_multitask

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


r"""Multi-task Gaussian Process Regression models with fully Bayesian inference."""

from __future__ import annotations

from collections.abc import Mapping
from typing import Any, NoReturn, TypeVar

import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.fully_bayesian import (
    _check_jax_available,
    _HAS_JAX,
    matern52_kernel,
    MCMC_DIM,
    MIN_INFERRED_NOISE_LEVEL,
    reshape_and_detach,
    SaasPyroModel,
)

if _HAS_JAX:
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as numpyro_dist
from botorch.models.gpytorch import (
    BatchedMultiOutputGPyTorchModel,
    MultiTaskGPyTorchModel,
)
from botorch.models.multitask import MultiTaskGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import MaternKernel
from gpytorch.kernels.index_kernel import IndexKernel
from gpytorch.kernels.kernel import Kernel
from gpytorch.means.constant_mean import ConstantMean
from gpytorch.means.mean import Mean
from gpytorch.means.multitask_mean import MultitaskMean
from torch import Tensor
from torch.nn.parameter import Parameter
from typing_extensions import Self

# Can replace with Self type once 3.11 is the minimum version
TSaasFullyBayesianMultiTaskGP = TypeVar(
    "TSaasFullyBayesianMultiTaskGP", bound="SaasFullyBayesianMultiTaskGP"
)


[docs] class MultiTaskPyroMixin: r"""Mixin with universal multi-task logic for PyroModel subclasses. Stores task-related attributes (``task_feature``, ``num_tasks``, ``task_rank``) and adjusts ``ard_num_dims`` to exclude the task column. Overrides ``sample_mean`` to return per-task means and ``_prepare_features`` to strip the task column. Place before the ``PyroModel`` subclass in the MRO. """
[docs] def set_inputs( self, train_X: Tensor, train_Y: Tensor, train_Yvar: Tensor | None = None, task_feature: int | None = None, task_rank: int | None = None, all_tasks: list[int] | None = None, ) -> None: """Set training data and configure multi-task attributes. Args: train_X: Training inputs (n x (d + 1)), including a task column. train_Y: Training targets (n x 1). train_Yvar: Observed noise variance (n x 1). Inferred if None. task_feature: The index of the task feature column. task_rank: The number of learned task embeddings. Defaults to the number of tasks. all_tasks: A list of all task indices. If omitted, all tasks will be inferred from the task feature column of the training data. """ super().set_inputs(train_X, train_Y, train_Yvar) task_feature = task_feature % train_X.shape[-1] self.task_feature = task_feature if all_tasks is None: all_tasks = train_X[:, task_feature].unique().to(dtype=torch.long).tolist() self.num_tasks = len(all_tasks) self.task_rank = task_rank or self.num_tasks self.ard_num_dims = self.train_X.shape[-1] - 1
[docs] def sample_mean(self) -> jnp.ndarray: r"""Sample per-task mean constants. Returns a vector of shape ``(num_tasks,)`` with one mean per task. """ return numpyro.sample( "mean", numpyro_dist.Normal( jnp.array(0.0), jnp.array(1.0), ).expand((self.num_tasks,)), )
def _get_task_indices_and_base_idxr(self, **tkwargs: Any) -> tuple[Tensor, Tensor]: r"""Compute the task indices and the base feature index selector. Returns: A tuple of ``(task_indices, base_idxr)`` where ``task_indices`` are long-typed task assignments and ``base_idxr`` selects the non-task columns. """ base_idxr = torch.arange(self.ard_num_dims, device=tkwargs["device"]) base_idxr[self.task_feature :] += 1 task_indices = self.train_X[..., self.task_feature].to( device=tkwargs["device"], dtype=torch.long ) return task_indices, base_idxr def _get_task_indices_and_base_idxr_jax( self, ) -> tuple[jnp.ndarray, jnp.ndarray]: r"""JAX version of _get_task_indices_and_base_idxr for use in sample().""" base_idxr = jnp.arange(self.ard_num_dims) base_idxr = base_idxr.at[self.task_feature :].add(1) task_indices = self.train_X_jax[..., self.task_feature].astype(jnp.int32) return task_indices, base_idxr def _prepare_features(self, X: jnp.ndarray) -> jnp.ndarray: """Strip the task column from X, selecting only base features.""" _, base_idxr = self._get_task_indices_and_base_idxr_jax() return X[..., base_idxr]
[docs] class LatentFeatureMultiTaskPyroMixin(MultiTaskPyroMixin): r"""Mixin that adds ICM-style multi-task capabilities via latent features. Extends ``MultiTaskPyroMixin`` with an ICM task covariance using learned latent task embeddings and a Matern-5/2 task kernel. Place before the ``PyroModel`` subclass in the MRO:: class MultitaskSaasPyroModel(LatentFeatureMultiTaskPyroMixin, SaasPyroModel): ... Overrides the dispatch methods ``_maybe_multitask_transform``, ``_build_mean_module``, ``_build_multitask_covariance``, and ``get_dummy_mcmc_samples``. """
[docs] def sample_latent_features(self) -> jnp.ndarray: r"""Sample latent task feature embeddings.""" return numpyro.sample( "latent_features", numpyro_dist.Normal( jnp.array(0.0), jnp.array(1.0), ).expand((self.num_tasks, self.task_rank)), )
[docs] def sample_task_lengthscale( self, concentration: float = 6.0, rate: float = 3.0 ) -> jnp.ndarray: r"""Sample the task kernel lengthscale.""" return numpyro.sample( "task_lengthscale", numpyro_dist.Gamma( jnp.array(concentration), jnp.array(rate), ).expand((self.task_rank,)), )
def _build_task_covar(self) -> tuple[jnp.ndarray, jnp.ndarray]: r"""Sample latent features and task lengthscale and build n x n task covar. Returns: A tuple of ``(task_covar, task_indices)`` where ``task_covar`` is an ``n x n`` task covariance matrix and ``task_indices`` are the task assignments. """ task_indices, _ = self._get_task_indices_and_base_idxr_jax() task_latent_features = self.sample_latent_features()[task_indices] task_lengthscale = self.sample_task_lengthscale() task_covar = matern52_kernel( X=task_latent_features, lengthscale=task_lengthscale ) return task_covar, task_indices def _maybe_multitask_transform( self, K_noiseless: jnp.ndarray, mean: jnp.ndarray ) -> tuple[jnp.ndarray, jnp.ndarray]: r"""Multiply K by task covariance and index mean by task assignments.""" task_covar, task_indices = self._build_task_covar() K_noiseless = K_noiseless * task_covar return K_noiseless, mean[task_indices] def _build_mean_module( self, mcmc_samples: dict[str, Tensor], batch_shape: torch.Size, **tkwargs: Any, ) -> Mean: """Build a ``MultitaskMean`` with per-task constants from MCMC samples.""" mean_module = MultitaskMean( base_means=ConstantMean(batch_shape=batch_shape), num_tasks=self.num_tasks, ).to(**tkwargs) for i in range(self.num_tasks): mean_module.base_means[i].constant.data = reshape_and_detach( target=mean_module.base_means[i].constant.data, new_value=mcmc_samples["mean"][:, i], ) return mean_module def _build_multitask_covariance( self, mcmc_samples: dict[str, Tensor], covar_module: Kernel, batch_shape: torch.Size, **tkwargs: Any, ) -> Kernel: """Build task IndexKernel and combine with data covariance.""" data_indices = torch.arange(self.train_X.shape[-1] - 1) data_indices[self.task_feature :] += 1 covar_module.active_dims = data_indices.to(device=tkwargs["device"]) latent_covar_module = MaternKernel( nu=2.5, ard_num_dims=self.task_rank, batch_shape=batch_shape, ).to(**tkwargs) latent_covar_module.lengthscale = reshape_and_detach( target=latent_covar_module.lengthscale, new_value=mcmc_samples["task_lengthscale"], ) latent_features = mcmc_samples["latent_features"] task_covar = latent_covar_module(latent_features) task_covar_module = IndexKernel( num_tasks=self.num_tasks, rank=self.task_rank, batch_shape=latent_features.shape[:-2], active_dims=torch.tensor([self.task_feature], device=tkwargs["device"]), ) task_covar_module.covar_factor = Parameter( task_covar.cholesky().to_dense().detach() ) task_covar_module = task_covar_module.to(**tkwargs) task_covar_module.var = torch.zeros_like(task_covar_module.var) covar_module = covar_module * task_covar_module return covar_module
[docs] def get_dummy_mcmc_samples( self, num_mcmc_samples: int, **tkwargs: Any, ) -> dict[str, Tensor]: """Return dummy MCMC samples for state dict loading. Calls ``super()`` for base model keys, then reshapes ``mean`` to ``(S, num_tasks)`` and adds ``task_lengthscale`` and ``latent_features``. """ mcmc_samples = super().get_dummy_mcmc_samples( num_mcmc_samples=num_mcmc_samples, **tkwargs ) mcmc_samples["mean"] = torch.ones(num_mcmc_samples, self.num_tasks, **tkwargs) mcmc_samples["task_lengthscale"] = torch.ones( num_mcmc_samples, self.task_rank, **tkwargs ) mcmc_samples["latent_features"] = torch.ones( num_mcmc_samples, self.num_tasks, self.task_rank, **tkwargs ) return mcmc_samples
[docs] class MultitaskSaasPyroModel(LatentFeatureMultiTaskPyroMixin, SaasPyroModel): r""" Multi-task SAAS model. Backward-compatible subclass that composes ``LatentFeatureMultiTaskPyroMixin`` with ``SaasPyroModel``. """ pass
[docs] class SaasFullyBayesianMultiTaskGP(MultiTaskGP): r"""A fully Bayesian multi-task GP model with the SAAS prior. This model assumes that the inputs have been normalized to [0, 1]^d and that the output has been stratified standardized to have zero mean and unit variance for each task. The SAAS model [Eriksson2021saasbo]_ with a Matern-5/2 is used as data kernel by default. You are expected to use ``fit_fully_bayesian_model_nuts`` to fit this model as it isn't compatible with ``fit_gpytorch_mll``. Example: >>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2) >>> i1, i2 = torch.zeros(10, 1), torch.ones(20, 1) >>> train_X = torch.cat([ >>> torch.cat([X1, i1], -1), torch.cat([X2, i2], -1), >>> ]) >>> train_Y = torch.cat(f1(X1), f2(X2)).unsqueeze(-1) >>> train_Yvar = 0.01 * torch.ones_like(train_Y) >>> mtsaas_gp = SaasFullyBayesianMultiTaskGP( >>> train_X, train_Y, train_Yvar, task_feature=-1, >>> ) >>> fit_fully_bayesian_model_nuts(mtsaas_gp) >>> posterior = mtsaas_gp.posterior(test_X) """ _is_fully_bayesian = True _is_ensemble = True def __init__( self, train_X: Tensor, train_Y: Tensor, task_feature: int, train_Yvar: Tensor | None = None, output_tasks: list[int] | None = None, rank: int | None = None, all_tasks: list[int] | None = None, outcome_transform: OutcomeTransform | None = None, input_transform: InputTransform | None = None, pyro_model: MultitaskSaasPyroModel | None = None, validate_task_values: bool = True, ) -> None: r"""Initialize the fully Bayesian multi-task GP model. Args: train_X: Training inputs (n x (d + 1)) train_Y: Training targets (n x 1) train_Yvar: Observed noise variance (n x 1). If None, we infer the noise. Note that the inferred noise is common across all tasks. task_feature: The index of the task feature (``-d <= task_feature <= d``). output_tasks: A list of task indices for which to compute model outputs for. If omitted, return outputs for all task indices. rank: The num of learned task embeddings to be used in the task kernel. If omitted, use a full rank (i.e. number of tasks) kernel. all_tasks: A list of all task indices. If omitted, all tasks will be inferred from the task feature column of the training data. Used to inform the model about the total number of tasks, including any unobserved tasks. outcome_transform: An outcome transform that is applied to the training data during instantiation and to the posterior during inference (that is, the ``Posterior`` obtained by calling ``.posterior`` on the model will be on the original scale). Note that ``.train()`` will be called on the outcome transform during instantiation of the model. input_transform: An input transform that is applied to the inputs ``X`` in the model's forward pass. pyro_model: Optional ``PyroModel`` that has the same signature as ``MultitaskSaasPyroModel``. Defaults to ``MultitaskSaasPyroModel``. validate_task_values: If True, validate that the task values supplied in the input are expected tasks values. If false, unexpected task values will be mapped to the first output_task if supplied. """ _check_jax_available() if not ( train_X.ndim == train_Y.ndim == 2 and len(train_X) == len(train_Y) and train_Y.shape[-1] == 1 ): raise ValueError( "Expected train_X to have shape n x d and train_Y to have shape n x 1" ) if train_Yvar is not None and train_Y.shape != train_Yvar.shape: raise ValueError( "Expected train_Yvar to be None or have the same shape as train_Y" ) with torch.no_grad(): transformed_X = self.transform_inputs( X=train_X, input_transform=input_transform ) if outcome_transform is not None: outcome_transform.train() # Ensure we learn parameters here on init train_Y, train_Yvar = outcome_transform( Y=train_Y, Yvar=train_Yvar, X=transformed_X ) if train_Yvar is not None: # Clamp after transforming train_Yvar = train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL) super().__init__( train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar, task_feature=task_feature, output_tasks=output_tasks, rank=rank, # We already transformed the data above, this avoids applying the # default ``Standardize`` transform twice. As outcome_transform is # set on ``self`` below, it will be applied to the posterior in the # ``posterior`` method of ``MultiTaskGP``. outcome_transform=None, all_tasks=all_tasks, validate_task_values=validate_task_values, ) self.to(train_X) self.mean_module = None self.covar_module = None self.likelihood = None if pyro_model is None: pyro_model = MultitaskSaasPyroModel() # apply task_mapper x_before, task_idcs, x_after = self._split_inputs(transformed_X) pyro_model.set_inputs( train_X=torch.cat([x_before, task_idcs, x_after], dim=-1), train_Y=train_Y, train_Yvar=train_Yvar, task_feature=task_feature, task_rank=self._rank, all_tasks=all_tasks, ) self.pyro_model: MultitaskSaasPyroModel = pyro_model if outcome_transform is not None: self.outcome_transform = outcome_transform if input_transform is not None: self.input_transform = input_transform
[docs] def train( self, mode: bool = True, reset: bool = True ) -> TSaasFullyBayesianMultiTaskGP: r"""Puts the model in ``train`` mode. Args: mode: A boolean indicating whether to put the model in training mode. reset: A boolean indicating whether to reset the model to its initial state. If ``mode`` is False, this argument is ignored. Returns: The model itself. """ super().train(mode=mode) if mode and reset: self.mean_module = None self.covar_module = None self.likelihood = None return self
@property def median_lengthscale(self) -> Tensor: r"""Median lengthscales across the MCMC samples.""" self._check_if_fitted() lengthscale = self.covar_module.kernels[0].base_kernel.lengthscale.clone() return lengthscale.median(0).values.squeeze(0) @property def num_mcmc_samples(self) -> int: r"""Number of MCMC samples in the model.""" self._check_if_fitted() return self.covar_module.kernels[0].batch_shape[0] @property def batch_shape(self) -> torch.Size: r"""Batch shape of the model, equal to the number of MCMC samples. Note that ``SaasFullyBayesianMultiTaskGP`` does not support batching over input data at this point. """ self._check_if_fitted() return torch.Size([self.num_mcmc_samples])
[docs] def fantasize(self, *args, **kwargs) -> NoReturn: raise NotImplementedError("Fantasize is not implemented!")
def _check_if_fitted(self): r"""Raise an exception if the model hasn't been fitted.""" if self.covar_module is None: raise RuntimeError( "Model has not been fitted. You need to call " "`fit_fully_bayesian_model_nuts` to fit the model." )
[docs] def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None: r"""Load the MCMC hyperparameter samples into the model. This method will be called by ``fit_fully_bayesian_model_nuts`` when the model has been fitted in order to create a batched MultiTaskGP model. """ ( self.mean_module, self.covar_module, self.likelihood, _, ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
[docs] def eval(self) -> Self: r"""Puts the model in eval mode. Circumvents the need to call MultiTaskGP.eval(), which computes the task_covar_matrix for non-observed tasks. This is not needed for fully Bayesian models, since the non-observed tasks' covar factors are instead sampled. Returns: The model itself. """ self._check_if_fitted() return MultiTaskGPyTorchModel.eval(self)
[docs] def posterior( self, X: Tensor, output_indices: list[int] | None = None, observation_noise: bool = False, posterior_transform: PosteriorTransform | None = None, **kwargs: Any, ) -> GaussianMixturePosterior: r"""Computes the posterior over model outputs at the provided points. Returns: A ``GaussianMixturePosterior`` object. Includes observation noise if specified. """ self._check_if_fitted() posterior = super().posterior( X=X.unsqueeze(MCMC_DIM), output_indices=output_indices, observation_noise=observation_noise, posterior_transform=posterior_transform, **kwargs, ) posterior = GaussianMixturePosterior(distribution=posterior.distribution) return posterior
[docs] def forward(self, X: Tensor) -> MultivariateNormal: self._check_if_fitted() return super().forward(X)
[docs] def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): r"""Custom logic for loading the state dict. The standard approach of calling ``load_state_dict`` currently doesn't play well with the ``SaasFullyBayesianMultiTaskGP`` since the ``mean_module``, ``covar_module`` and ``likelihood`` aren't initialized until the model has been fitted. The reason for this is that we don't know the number of MCMC samples until NUTS is called. Given the state dict, we can initialize a new model with some dummy samples and then load the state dict into this model. This currently only works for a ``MultitaskSaasPyroModel`` and supporting more Pyro models likely requires moving the model construction logic into the Pyro model itself. TODO: If this were to inherit from ``SaasFullyBayesianSingleTaskGP``, we could simplify this method and eliminate some others. """ if not isinstance(self.pyro_model, MultitaskSaasPyroModel): raise NotImplementedError( # pragma: no cover "load_state_dict only works for MultitaskSaasPyroModel" ) raw_mean = state_dict["mean_module.base_means.0.raw_constant"] num_mcmc_samples = len(raw_mean) tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype} mcmc_samples = self.pyro_model.get_dummy_mcmc_samples( num_mcmc_samples=num_mcmc_samples, **tkwargs ) self.load_mcmc_samples(mcmc_samples=mcmc_samples) # Load the actual samples from the state dict super().load_state_dict(state_dict=state_dict, strict=strict)
[docs] def condition_on_observations( self, X: Tensor, Y: Tensor, **kwargs: Any ) -> BatchedMultiOutputGPyTorchModel: """Conditions on additional observations for a Fully Bayesian model (either identical across models or unique per-model). Args: X: A ``batch_shape x num_samples x d``-dim Tensor, where ``d`` is the dimension of the feature space and ``batch_shape`` is the number of sampled models. Y: A ``batch_shape x num_samples x 1``-dim Tensor, where ``d`` is the dimension of the feature space and ``batch_shape`` is the number of sampled models. Returns: BatchedMultiOutputGPyTorchModel: A fully bayesian model conditioned on given observations. The returned model has ``batch_shape`` copies of the training data in case of identical observations (and ``batch_shape`` training datasets otherwise). """ if X.ndim == 2 and Y.ndim == 2: # To avoid an error in GPyTorch when inferring the batch dimension, we add # the explicit batch shape here. The result is that the conditioned model # will have 'batch_shape' copies of the training data. X = X.repeat(self.batch_shape + (1, 1)) Y = Y.repeat(self.batch_shape + (1, 1)) elif X.ndim < Y.ndim: # We need to duplicate the training data to enable correct batch # size inference in gpytorch. X = X.repeat(*(Y.shape[:-2] + (1, 1))) return super().condition_on_observations(X, Y, **kwargs)