#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Assorted helper methods and objects for working with BoTorch models."""
from __future__ import annotations
import json
import warnings
from collections.abc import Iterator
from contextlib import contextmanager, ExitStack
from typing import TYPE_CHECKING
import torch
from botorch import settings
from botorch.exceptions import InputDataError, InputDataWarning
from botorch.settings import _Flag
from gpytorch import settings as gpt_settings
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.module import Module
from torch import Tensor
if TYPE_CHECKING:
from botorch.models.gpytorch import GPyTorchModel
def _make_X_full(X: Tensor, output_indices: list[int], tf: int) -> Tensor:
    r"""Build the multi-task input tensor by splicing in a task-index column.

    For every entry of ``output_indices`` a copy of ``X`` is made with a
    constant column (holding that index) inserted at position ``tf``. The
    copies are stacked along the ``n`` (second-to-last) dimension.

    Args:
        X: The raw input tensor (without task information).
        output_indices: The output indices to generate (passed in via ``posterior``).
        tf: The task feature index, i.e. the column position of the task index.

    Returns:
        Tensor: The full input tensor for the multi-task model, including task
        indices.
    """
    left = X[..., :tf]
    right = X[..., tf:]
    column_shape = X.shape[:-1] + torch.Size([1])
    stacked = []
    for task_value in output_indices:
        task_column = torch.full(
            column_shape, fill_value=task_value, device=X.device, dtype=X.dtype
        )
        stacked.append(torch.cat([left, task_column, right], dim=-1))
    return torch.cat(stacked, dim=-2)
[docs]
def add_output_dim(X: Tensor, original_batch_shape: torch.Size) -> tuple[Tensor, int]:
    r"""Insert the output dimension at the correct location.

    The trailing batch dimensions of X must match the original batch dimensions
    of the training inputs, but can also include extra batch dimensions.

    Args:
        X: A ``(new_batch_shape) x (original_batch_shape) x n x d`` tensor of
            features.
        original_batch_shape: the batch shape of the model's training inputs.

    Returns:
        2-element tuple containing

        - A ``(new_batch_shape) x (original_batch_shape) x m x n x d`` tensor of
          features, where the inserted ``m`` dimension has size 1.
        - The index corresponding to the output dimension.

    Raises:
        RuntimeError: If the batch shape of ``X`` cannot be broadcast with
            ``original_batch_shape``. The underlying broadcasting error is
            chained as ``__cause__``.
    """
    X_batch_shape = X.shape[:-2]
    if len(X_batch_shape) > 0 and len(original_batch_shape) > 0:
        # check that X_batch_shape supports broadcasting or augments
        # original_batch_shape with extra batch dims
        try:
            torch.broadcast_shapes(X_batch_shape, original_batch_shape)
        except RuntimeError as err:
            raise RuntimeError(
                "The trailing batch dimensions of X must match the trailing "
                f"batch dimensions of the training inputs. Got {X.shape=} "
                f"and {original_batch_shape=}."
            ) from err
    # insert the (singleton) ``m`` dimension right before ``n x d``
    X = X.unsqueeze(-3)
    output_dim_idx = max(len(original_batch_shape), len(X_batch_shape))
    return X, output_dim_idx
[docs]
def check_no_nans(Z: Tensor) -> None:
    r"""Validate that ``Z`` is free of NaN entries.

    Args:
        Z: The input tensor.

    Raises:
        InputDataError: If any element of ``Z`` is NaN.
    """
    has_nan = torch.isnan(Z).any()
    if not has_nan:
        return
    raise InputDataError("Input data contains NaN values.")
[docs]
def check_min_max_scaling(
    X: Tensor,
    strict: bool = False,
    atol: float = 1e-2,
    raise_on_fail: bool = False,
    ignore_dims: list[int] | None = None,
) -> None:
    r"""Check that tensor is normalized to the unit cube.

    Args:
        X: A ``batch_shape x n x d`` input tensor. Typically the training inputs
            of a model.
        strict: If True, require ``X`` to be scaled to the unit cube. Each
            checked feature's minimum over the ``n`` observations must be
            (within ``atol``) 0, and its maximum must be 1. If False, ``X``
            only has to be contained within the unit cube.
        atol: The tolerance for the boundary checks. Used for the containment
            check and, if ``strict=True``, for the scaling check.
        raise_on_fail: If True, raise an exception instead of a warning.
        ignore_dims: Subset of dimensions where the min-max scaling check is omitted.

    Raises:
        InputDataError: If the check fails and ``raise_on_fail=True``.
    """
    ignore_dims = ignore_dims or []
    check_dims = sorted(set(range(X.shape[-1])) - set(ignore_dims))
    # Nothing to check: every feature is ignored, or there are no observations
    # (reducing over an empty observation dimension would raise).
    if len(check_dims) == 0 or X.shape[-2] == 0:
        return None
    with torch.no_grad():
        X_check = X[..., check_dims]
        # Min-max scaling is a per-feature transform, so the extrema are taken
        # over the observation dimension (-2), giving ``batch_shape x d``.
        Xmin = torch.min(X_check, dim=-2).values
        Xmax = torch.max(X_check, dim=-2).values
        msg = None
        if strict and max(torch.abs(Xmin).max(), torch.abs(Xmax - 1).max()) > atol:
            msg = "scaled"
        # Not being contained is the stronger violation and takes precedence.
        if torch.any(Xmin < -atol) or torch.any(Xmax > 1 + atol):
            msg = "contained"
        if msg is not None:
            # NOTE: If you update this message, update the warning filters as well.
            # See https://github.com/meta-pytorch/botorch/pull/2508.
            msg = (
                f"Data (input features) is not {msg} to the unit cube. "
                "Please consider min-max scaling the input data."
            )
            if raise_on_fail:
                raise InputDataError(msg)
            warnings.warn(msg, InputDataWarning, stacklevel=2)
[docs]
def check_standardization(
    Y: Tensor,
    atol_mean: float = 1e-2,
    atol_std: float = 1e-2,
    raise_on_fail: bool = False,
) -> None:
    r"""Check that tensor is standardized (zero mean, unit variance).

    Args:
        Y: The input tensor of shape ``batch_shape x n x m``. Typically the
            train targets of a model. Standardization is checked across the
            ``n``-dimension.
        atol_mean: The tolerance for the mean check.
        atol_std: The tolerance for the std check.
        raise_on_fail: If True, raise an exception instead of a warning.

    Raises:
        InputDataError: If the check fails and ``raise_on_fail=True``.
    """
    with torch.no_grad():
        Ymean = torch.mean(Y, dim=-2)
        mean_not_zero = torch.abs(Ymean).max() > atol_mean
        if Y.shape[-2] <= 1:
            # With at most one observation the sample std is not meaningful,
            # so only the mean is checked.
            if mean_not_zero:
                # NOTE: If you update this message, update the warning filters as well.
                # See https://github.com/meta-pytorch/botorch/pull/2508.
                msg = (
                    f"Data (outcome observations) is not standardized (mean = {Ymean})."
                    " Please consider scaling the input to zero mean and unit variance."
                )
                if raise_on_fail:
                    raise InputDataError(msg)
                warnings.warn(msg, InputDataWarning, stacklevel=2)
        else:
            Ystd = torch.std(Y, dim=-2)
            std_not_one = torch.abs(Ystd - 1).max() > atol_std
            if mean_not_zero or std_not_one:
                # NOTE: If you update this message, update the warning filters as well.
                # See https://github.com/meta-pytorch/botorch/pull/2508.
                msg = (
                    "Data (outcome observations) is not standardized "
                    f"(std = {Ystd}, mean = {Ymean}). "
                    "Please consider scaling the input to zero mean and unit variance."
                )
                if raise_on_fail:
                    raise InputDataError(msg)
                warnings.warn(msg, InputDataWarning, stacklevel=2)
[docs]
def mod_batch_shape(module: Module, names: list[str], b: int) -> None:
    r"""Walk an attribute path and resize the last ``batch_shape`` entry of its leaf.

    The module tree is modified in place. The path is followed one attribute at
    a time. Only the final attribute is touched, and only if it has a
    non-empty ``batch_shape``.

    Args:
        module: The module to be modified.
        names: The list of names to access the attribute. If the full name of
            the module is ``"module.sub_module.leaf_module"``, this will be
            ``["sub_module", "leaf_module"]``.
        b: The new size of the last element of the module's ``batch_shape``
            attribute. A non-positive ``b`` drops that last element instead.
    """
    current = module
    last_position = len(names) - 1
    for position, attr_name in enumerate(names):
        child = getattr(current, attr_name)
        at_leaf = position == last_position
        if at_leaf and hasattr(child, "batch_shape") and len(child.batch_shape) > 0:
            tail = torch.Size([b]) if b > 0 else torch.Size([])
            child.batch_shape = child.batch_shape[:-1] + tail
            return
        current = child
[docs]
@contextmanager
def gpt_posterior_settings():
    r"""Enter the GPyTorch settings used while computing model posteriors.

    ``debug`` is turned off and ``fast_pred_var`` is turned on, but each only
    if that GPyTorch setting reports it is still at its default. This means
    user-chosen values are respected. ``detach_test_caches`` is always set,
    using the value returned by ``settings.propagate_grads.off()``.
    """
    with ExitStack() as stack:
        if gpt_settings.debug.is_default():
            stack.enter_context(gpt_settings.debug(False))
        if gpt_settings.fast_pred_var.is_default():
            stack.enter_context(gpt_settings.fast_pred_var())
        # presumably True when gradient propagation is disabled, so test caches
        # get detached in that case -- TODO confirm against botorch.settings
        detach_flag = settings.propagate_grads.off()
        stack.enter_context(gpt_settings.detach_test_caches(detach_flag))
        yield
[docs]
def detect_duplicates(
    X: Tensor,
    rtol: float = 0,
    atol: float = 1e-8,
) -> Iterator[tuple[int, int]]:
    """Iterate over ``(duplicate index, original index)`` pairs of rows of ``X``.

    Row ``j`` is reported as a duplicate of an earlier row ``i < j`` when their
    infinity-norm distance is below the tolerance. For each duplicate, the
    closest earlier row is reported. Only 2-d tensors are supported.

    Args:
        X: the datapoints tensor with potential duplicated entries
        rtol: relative tolerance (scaled by the larger max-abs entry of the
            two rows being compared)
        atol: absolute tolerance

    Raises:
        ValueError: If ``X`` is not 2-dimensional (raised eagerly, at call time).
    """
    if X.dim() != 2:
        raise ValueError("X must have 2 dimensions.")
    tolerance = atol
    if rtol:
        row_scale = X.abs().max(dim=-1, keepdim=True).values
        tolerance = tolerance + rtol * row_scale.max(row_scale.transpose(-1, -2))
    n = X.shape[-2]
    # Only the strict upper triangle (earlier row, later row) holds distances;
    # everything else stays +inf so it can never register as a match.
    pairwise = torch.full((n, n), float("inf"), device=X.device, dtype=X.dtype)
    upper_rows, upper_cols = torch.triu_indices(n, n, offset=1)
    pairwise[upper_rows, upper_cols] = torch.nn.functional.pdist(X, p=float("inf"))
    # For every column (candidate duplicate), find the closest earlier row.
    slack, closest_row = (pairwise - tolerance).min(dim=-2)
    return (
        (dup_idx, int(orig_idx))
        for dup_idx, (gap, orig_idx) in enumerate(zip(slack, closest_row))
        if gap < 0
    )
[docs]
def consolidate_duplicates(
    X: Tensor, Y: Tensor, rtol: float = 0.0, atol: float = 1e-8
) -> tuple[Tensor, Tensor, Tensor]:
    """Drop duplicated Xs and update the indices tensor Y accordingly.

    Supporting 2d Tensor only as in batch mode block design is not guaranteed.

    Args:
        X: the datapoints tensor
        Y: the integer index tensor to be updated (e.g., pairwise comparisons);
            every entry must be a valid row index into ``X``.
        rtol: relative tolerance
        atol: absolute tolerance

    Returns:
        consolidated_X: the consolidated X
        consolidated_Y: the consolidated Y (e.g., pairwise comparisons indices),
            a ``torch.long`` tensor with the same shape as ``Y``
        new_indices: new index of each original item in X, a ``torch.long``
            tensor of size X.shape[-2] (in both the duplicate and the
            no-duplicate case)

    Raises:
        ValueError: If ``X`` is not 2-dimensional.
    """
    if len(X.shape) != 2:
        raise ValueError("X must have 2 dimensions.")
    n = X.shape[-2]
    dup_map = dict(detect_duplicates(X=X, rtol=rtol, atol=atol))
    # Handle edge cases conservatively: if an item is both a duplicate and the
    # kept original of another duplicate (a chain), keep every item involved.
    common_set = set(dup_map.keys()).intersection(dup_map.values())
    dup_map = {
        dup: kept
        for dup, kept in dup_map.items()
        if dup not in common_set and kept not in common_set
    }
    if not dup_map:
        return X, Y, torch.arange(n, dtype=torch.long, device=Y.device)
    unique_indices = [i for i in range(n) if i not in dup_map]
    # After dropping the duplicates, the kept ones' indices may also shift up.
    new_idx_map = {old: new for new, old in enumerate(unique_indices)}
    # Kept originals are never themselves dropped (chains were filtered above),
    # so their new index is already in the map.
    for dup, kept in dup_map.items():
        new_idx_map[dup] = new_idx_map[kept]
    new_indices = torch.tensor(
        [new_idx_map[i] for i in range(n)], dtype=torch.long, device=Y.device
    )
    consolidated_X = X[unique_indices, :]
    # Vectorized remap: look up every entry of Y in the old->new index table.
    consolidated_Y = new_indices[Y.long()]
    return consolidated_X, consolidated_Y, new_indices
[docs]
class fantasize(_Flag):
    r"""Flag that tracks whether execution is inside a ``fantasize`` context.

    The class-level default state is ``False``.
    """

    # Default (class-level) state of the flag.
    _state: bool = False
[docs]
def get_task_value_remapping(
    all_task_values: Tensor,
    dtype: torch.dtype,
) -> Tensor | None:
    """Build a lookup table that sends task values to contiguous float indices.

    The ``k``-th entry of ``all_task_values`` is mapped to ``k``. Every other
    position of the lookup table holds NaN.

    Args:
        all_task_values: A sorted long-valued tensor of all possible task values
            in the full task space.
        dtype: The dtype of the model inputs (e.g. ``X``), which the new
            task values should have mapped to (e.g. float, double).

    Returns:
        ``None`` when ``all_task_values`` already equals ``[0, 1, ..., n-1]``.
        Otherwise, a tensor of shape ``all_task_values.max() + 1`` such that
        ``mapper[task_value]`` yields the new task value (same shape as the
        index). Values absent from ``all_task_values`` map to NaN.

    Raises:
        ValueError: If ``dtype`` is neither ``torch.float`` nor ``torch.double``.
    """
    if dtype not in (torch.float, torch.double):
        raise ValueError(f"dtype must be torch.float or torch.double, but got {dtype}.")
    device = all_task_values.device
    contiguous = torch.arange(
        len(all_task_values), dtype=all_task_values.dtype, device=device
    )
    if torch.equal(contiguous, all_task_values):
        # Already 0..n-1: no remapping required.
        return None
    # The number of tasks is expected to be small, so a dense table is cheap.
    table_size = int(all_task_values.max().item()) + 1
    mapper = torch.full((table_size,), float("nan"), dtype=dtype, device=device)
    mapper[all_task_values] = contiguous.to(dtype=dtype)
    return mapper
[docs]
def extract_targets_and_noise_single_output(model) -> tuple[Tensor, Tensor | None]:
    r"""Read a single-output (m=1) model's targets and, if fixed, its noise.

    Args:
        model: A GPyTorch model.

    Returns:
        A tuple ``(Y, Yvar)``. ``Y`` is ``model.train_targets`` with a trailing
        singleton dimension added. ``Yvar`` is the likelihood's
        ``noise_covar.noise`` reshaped the same way when the likelihood is a
        ``FixedNoiseGaussianLikelihood``, and ``None`` otherwise.
    """
    targets = model.train_targets.unsqueeze(-1)
    likelihood = model.likelihood
    if not isinstance(likelihood, FixedNoiseGaussianLikelihood):
        return targets, None
    return targets, likelihood.noise_covar.noise.unsqueeze(-1)
[docs]
def restore_targets_and_noise_single_output(
model, Y: Tensor, Yvar: Tensor | None, strict: bool
) -> None:
r"""Restore targets and noise variance for single-output models (m=1).
Args:
model: A GPyTorch model.
Y: Targets tensor in shape ``batch_shape x n x 1``.
Yvar: Optional noise variance tensor in shape ``batch_shape x n x 1``.
strict: Whether to strictly enforce shape constraints.
"""
Y = Y.squeeze(-1)
if Yvar is not None and isinstance(model.likelihood, FixedNoiseGaussianLikelihood):
Yvar = Yvar.squeeze(-1)
model.likelihood.noise_covar.noise = Yvar
model.set_train_data(targets=Y, strict=strict)
[docs]
def get_data_for_optimization_help(
    model: GPyTorchModel,
    path: str = "optimization_help_data.json",
) -> None:
    r"""Dump a model's training data and ``state_dict`` to JSON for help requests.

    The JSON object written to ``path`` has the keys ``"dtype"`` (the dtype of
    the first training input, without the ``torch.`` prefix), ``"train_X"``,
    ``"train_Y"`` (a 1-d target tensor gets a trailing dimension), and
    ``"state_dict"``. All tensors are detached, moved to CPU and converted to
    nested lists. This single file can be attached to a GitHub issue.

    See the following tutorial for an example of how to use this file to get
    help with optimization:
    https://github.com/meta-pytorch/botorch/blob/main/tutorials/optimization_issue_diagnostics/optimization_issue_diagnostics.ipynb

    Args:
        model: A BoTorch model with training data.
        path: File path where the JSON data will be saved.
            Defaults to "optimization_help_data.json".
    """

    def _as_list(t):
        return t.detach().cpu().tolist()

    inputs = model.train_inputs[0]
    targets = model.train_targets
    if targets.ndim == 1:
        targets = targets.unsqueeze(-1)
    serialized_state = {name: _as_list(value) for name, value in model.state_dict().items()}
    payload = {
        "dtype": str(inputs.dtype).replace("torch.", ""),
        "train_X": _as_list(inputs),
        "train_Y": _as_list(targets),
        "state_dict": serialized_state,
    }
    with open(path, "w") as handle:
        json.dump(payload, handle)