Source code for botorch.models.contextual_multioutput

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
References

.. [Feng2020HDCPS]
    Q. Feng, B. Latham, H. Mao and E. Backshy. High-Dimensional Contextual Policy
    Search with Unknown Context Rewards using Bayesian Optimization.
    Advances in Neural Information Processing Systems 33, NeurIPS 2020.
"""

from typing import Any

import torch
from botorch.models.multitask import _compute_multitask_mean, MultiTaskGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
from botorch.utils.types import _DefaultType, DEFAULT
from gpytorch.constraints import Interval
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels.rbf_kernel import RBFKernel
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.module import Module
from linear_operator.operators import LinearOperator
from torch import Tensor
from torch.nn import ModuleList


[docs] class LCEMGP(MultiTaskGP): r"""The Multi-Task GP with the latent context embedding multioutput (LCE-M) kernel. See [Feng2020HDCPS]_ for a reference on the model and its use in Bayesian optimization. """ def __init__( self, train_X: Tensor, train_Y: Tensor, task_feature: int, train_Yvar: Tensor | None = None, mean_module: Module | None = None, covar_module: Module | None = None, likelihood: Likelihood | None = None, context_cat_feature: Tensor | None = None, context_emb_feature: Tensor | None = None, embs_dim_list: list[int] | None = None, output_tasks: list[int] | None = None, all_tasks: list[int] | None = None, outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT, input_transform: InputTransform | None = None, ) -> None: r""" Args: train_X: (n x d) X training data. train_Y: (n x 1) Y training data. task_feature: Column index of train_X to get context indices. train_Yvar: An optional (n x 1) tensor of observed variances of each training Y. If None, we infer the noise. Note that the inferred noise is common across all tasks. mean_module: The mean function to be used. Defaults to ``ConstantMean``. covar_module: The module for computing the covariance matrix between the non-task features. Defaults to ``RBFKernel``. likelihood: A likelihood. The default is selected based on ``train_Yvar``. If ``train_Yvar`` is None, a standard ``GaussianLikelihood`` with inferred noise level is used. Otherwise, a FixedNoiseGaussianLikelihood is used. context_cat_feature: (n_contexts x k) one-hot encoded context features. Rows are ordered by context indices, where k is the number of categorical variables. If None, task indices will be used and k = 1. context_emb_feature: (n_contexts x m) pre-given continuous embedding features. Rows are ordered by context indices. embs_dim_list: Embedding dimension for each categorical variable. The length equals k. If None, the embedding dimension is set to 1 for each categorical variable. output_tasks: A list of task indices for which to compute model outputs for. If omitted, return outputs for all task indices. all_tasks: By default, multi-task GPs infer the list of all tasks from the task features in ``train_X``. This is an experimental feature that enables creation of multi-task GPs with tasks that don't appear in the training data. Note that when a task is not observed, the corresponding task covariance will heavily depend on random initialization and may behave unexpectedly. outcome_transform: An outcome transform that is applied to the training data during instantiation and to the posterior during inference (that is, the ``Posterior`` obtained by calling ``.posterior`` on the model will be on the original scale). We use a ``Standardize`` transform if no ``outcome_transform`` is specified. Pass down ``None`` to use no outcome transform. input_transform: An input transform that is applied in the model's forward pass. """ super().__init__( train_X=train_X, train_Y=train_Y, task_feature=task_feature, train_Yvar=train_Yvar, mean_module=mean_module, covar_module=covar_module, likelihood=likelihood, output_tasks=output_tasks, all_tasks=all_tasks, outcome_transform=outcome_transform, input_transform=input_transform, ) # Overwriting the covar_module created in the parent class if covar_module is None: self.covar_module = get_covar_module_with_dim_scaled_prior( ard_num_dims=self.num_non_task_features ) else: self.covar_module = covar_module self.device = train_X.device if all_tasks is None: all_tasks_tensor = train_X[:, task_feature].unique() self.all_tasks = all_tasks_tensor.to(dtype=torch.long).tolist() else: all_tasks_tensor = torch.tensor(all_tasks, dtype=torch.long) self.all_tasks = all_tasks self.all_tasks.sort() # These are the context indices. if context_cat_feature is None: context_cat_feature = all_tasks_tensor.unsqueeze(-1).to(device=self.device) self.context_cat_feature: Tensor = ( context_cat_feature # row indices = context indices ) self.context_emb_feature = context_emb_feature # construct emb_dims based on categorical features if embs_dim_list is None: # set embedding_dim = 1 for each categorical variable embs_dim_list = [1 for _i in range(context_cat_feature.size(1))] n_embs = sum(embs_dim_list) self.emb_dims = [ (len(context_cat_feature[:, i].unique()), embs_dim_list[i]) for i in range(context_cat_feature.size(1)) ] # construct embedding layer: need to handle multiple categorical features self.emb_layers = ModuleList( [ torch.nn.Embedding(num_embeddings=x, embedding_dim=y, max_norm=1.0) for x, y in self.emb_dims ] ) self.task_covar_module_base = RBFKernel( ard_num_dims=n_embs, lengthscale_constraint=Interval( 0.0, 2.0, transform=None, initial_value=1.0 ), ) self.to(train_X) def _eval_context_covar(self) -> LinearOperator: """Obtain the context covariance matrix, a linear operator with shape (num_contexts x num_contexts). This first generates the embedding features for all contexts, then evaluates the task covariance matrix with those embeddings to get the task covariance matrix. """ all_embs = self._task_embeddings() return self.task_covar_module_base(all_embs) def _task_embeddings(self) -> Tensor: """Generate embedding features for all contexts.""" embeddings = [ emb_layer( self.context_cat_feature[:, i].to( dtype=torch.long, device=self.device ) # pyre-ignore ) for i, emb_layer in enumerate(self.emb_layers) ] embeddings = torch.cat(embeddings, dim=1) # add given embeddings if any if self.context_emb_feature is not None: embeddings = torch.cat( [embeddings, self.context_emb_feature.to(self.device)], dim=1, # pyre-ignore ) return embeddings
[docs] def task_covar_module(self, task_idcs: Tensor) -> Tensor: r"""Compute the task covariance matrix for a given tensor of task / context indices. Args: task_idcs: Task index tensor of shape (n x 1) or (b x n x 1). Returns: Task covariance matrix of shape (b x n x n). """ # NOTE: This can probably be re-written more efficiently using # IndexKernel (or an IndexKernel subclass) and the ``evaluate_task_covar`` # and then have the forward pass evaluate a ProductKernel of the two. # This is a tensor of shape (num_tasks x num_tasks). covar_matrix = self._eval_context_covar().to_dense() # Here, we index into the base covar matrix to extract # the rows & columns corresponding to the task indices. # First indexing operation picks the rows for each index in # task indices (results in b x n x num_tasks). We then transpose # to make the picked rows into columns (b x num_tasks x n), and # pick the rows again to result in the final covariance matrix. # The result is a symmetric tensor of shape (b x n x n). # An alternative implementation could pick the columns directly # by moving the transpose operation into the index of gather, # however, this does not seem to make any noticeable difference. base_idx = task_idcs.squeeze(-1) expanded_idx = task_idcs.expand( *([-1] * (task_idcs.dim() - 1)), task_idcs.shape[-2] ) return ( covar_matrix[base_idx].transpose(-1, -2).gather(index=expanded_idx, dim=-2) )
[docs] def forward(self, x: Tensor) -> MultivariateNormal: if self.training: x = self.transform_inputs(x) x_before, task_idcs, x_after = self._split_inputs(x) x_basic = torch.cat([x_before, x_after], dim=-1) # Compute base mean using helper function mean_x = _compute_multitask_mean(self.mean_module, x_before, task_idcs, x_after) covar_x = self.covar_module(x_basic) # Compute task covariances covar_i = self.task_covar_module(task_idcs) # Combine the two in an ICM fashion covar = covar_x.mul(covar_i) return MultivariateNormal(mean_x, covar)
[docs] @classmethod def construct_inputs( cls, training_data: SupervisedDataset | MultiTaskDataset, task_feature: int, output_tasks: list[int] | None = None, context_cat_feature: Tensor | None = None, context_emb_feature: Tensor | None = None, embs_dim_list: list[int] | None = None, **kwargs, ) -> dict[str, Any]: r"""Construct ``Model`` keyword arguments from a dataset and other args. Args: training_data: A ``SupervisedDataset`` or a ``MultiTaskDataset``. task_feature: Column index of embedded task indicator features. output_tasks: A list of task indices for which to compute model outputs for. If omitted, return outputs for all task indices. context_cat_feature: (n_contexts x k) one-hot encoded context features. Rows are ordered by context indices, where k is the number of categorical variables. If None, task indices will be used and k = 1. context_emb_feature: (n_contexts x m) pre-given continuous embedding features. Rows are ordered by context indices. embs_dim_list: Embedding dimension for each categorical variable. The length equals k. If None, the embedding dimension is set to 1 for each categorical variable. """ base_inputs = super().construct_inputs( training_data=training_data, task_feature=task_feature, output_tasks=output_tasks, **kwargs, ) if context_cat_feature is not None: base_inputs["context_cat_feature"] = context_cat_feature if context_emb_feature is not None: base_inputs["context_emb_feature"] = context_emb_feature if embs_dim_list is not None: base_inputs["embs_dim_list"] = embs_dim_list return base_inputs