Source code for botorch.models.empirical_gps.empirical_1d_gp

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Empirical one-dimensional Gaussian Process models.

These models use a collection of historical one-dimensional curves to define an
empirical prior mean and covariance for a ``SingleTaskGP``. They support both
single-output and batch-independent multi-output modeling.

References

.. [lin2026empirical]
    J. A. Lin, S. Ament, L. C. Tiao, D. Eriksson, M. Balandat, and E. Bakshy.
    Empirical Gaussian Processes. International Conference on Machine Learning
    (ICML), 2026. https://arxiv.org/abs/2602.12082
"""

from __future__ import annotations

from botorch.exceptions.errors import UnsupportedError
from botorch.models import SingleTaskGP
from botorch.models.empirical_gps.utils import (
    center_curves,
    compute_basis_matrix,
    compute_orthogonal_basis,
    compute_sample_covariance,
    extract_slice_for_interp,
    instantiate_ard,
    LinearInterpolation1D,
)
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from gpytorch.kernels import Kernel
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.means import Mean
from torch import Tensor


# =============================================================================
# EmpiricalOneDimensionalGP Model
# =============================================================================


class EmpiricalOneDimensionalGP(SingleTaskGP):
    """A ``SingleTaskGP`` whose prior is estimated from historical 1-D curves.

    Unless custom modules are passed in, the prior mean is an
    ``EmpiricalOneDimensionalMean`` and the prior covariance is an
    ``EmpiricalOneDimensionalKernel``, both constructed from the same
    collection of related historical curves. This gives a data-driven prior
    for problems in which similar curves have already been observed.
    """

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        historical_X: Tensor,
        historical_Y: Tensor,
        train_Yvar: Tensor | None = None,
        likelihood: Likelihood | None = None,
        input_transform: InputTransform | None = None,
        outcome_transform: OutcomeTransform | None = None,
        mean_module: Mean | None = None,
        covar_module: Kernel | None = None,
        ard: bool = False,
    ) -> None:
        """Builds the model from training data and historical curves.

        The prior mean and covariance are derived from the historical curves
        as described in [lin2026empirical]_. Both the single-output case and
        the multi-output case with independent outputs (batch-independent
        GPs) are handled.

        Args:
            train_X: `batch_shape x n x 1`-dim Tensor of training inputs.
            train_Y: `batch_shape x n x m`-dim Tensor of training targets.
            historical_X: `num_progression x 1`-dim Tensor of progression
                values at which the historical curves were recorded.
            historical_Y: `num_curves x num_progression x m`-dim Tensor of
                historical curves; m is the number of outputs (m=1 for
                single-output).
            train_Yvar: `batch_shape x n x m`-dim Tensor of observation noise.
            likelihood: A likelihood. When omitted, a standard
                GaussianLikelihood with inferred noise is used if train_Yvar
                is None, otherwise a FixedNoiseGaussianLikelihood with the
                supplied noise observations.
            input_transform: Not yet supported; must be None.
            outcome_transform: Not yet supported; must be None.
            mean_module: Optional custom mean module. Defaults to an
                ``EmpiricalOneDimensionalMean`` over the historical data.
            covar_module: Optional custom covariance module. If given, it
                must be an ``EmpiricalOneDimensionalKernel`` whose ``ard``
                attribute equals the ``ard`` argument. Defaults to a kernel
                built from the historical data.
            ard: Whether to use Automatic Relevance Determination on the
                basis.

        Raises:
            UnsupportedError: If input_transform or outcome_transform is
                provided.
            ValueError: If historical_Y is not 3-dimensional, if its number
                of outputs differs from that of train_Y, if covar_module is
                not an ``EmpiricalOneDimensionalKernel``, or if ``ard`` does
                not match ``covar_module.ard``.
        """
        # Transforms are rejected up front; input_transform is checked first.
        unsupported_msg = None
        if input_transform is not None:
            unsupported_msg = (
                "input_transform is not yet supported for EmpiricalOneDimensionalGP."
            )
        elif outcome_transform is not None:
            unsupported_msg = (
                "outcome_transform is not yet supported for EmpiricalOneDimensionalGP."
            )
        if unsupported_msg is not None:
            raise UnsupportedError(unsupported_msg)

        # The historical curves must be laid out as curves x progression x m.
        if historical_Y.ndim != 3:
            raise ValueError(
                f"Expected historical_Y to be 3-dim (num_curves x num_progression x m),"
                f" got {historical_Y.ndim}-dim."
            )

        # Training targets and historical curves must agree on m.
        num_outputs_train = train_Y.size(-1)
        num_outputs_historical = historical_Y.size(-1)
        if num_outputs_train != num_outputs_historical:
            raise ValueError(
                f"Number of outputs in train_Y ({num_outputs_train}) must match "
                f"historical_Y ({num_outputs_historical})."
            )

        # Either build the default empirical kernel or vet the supplied one.
        if covar_module is None:
            covar_module = EmpiricalOneDimensionalKernel(
                X_full=historical_X,
                Y_full=historical_Y,
                ard=ard,
            )
        else:
            if not isinstance(covar_module, EmpiricalOneDimensionalKernel):
                raise ValueError(
                    "covar_module must be an instance of EmpiricalOneDimensionalKernel."
                )
            if ard != covar_module.ard:
                raise ValueError("`ard` argument must equal `covar_module.ard`.")

        # A user-supplied mean module is used as-is.
        if mean_module is None:
            mean_module = EmpiricalOneDimensionalMean(
                X_full=historical_X,
                Y_full=historical_Y,
            )

        super().__init__(
            train_X=train_X,
            train_Y=train_Y,
            train_Yvar=train_Yvar,
            likelihood=likelihood,
            mean_module=mean_module,
            covar_module=covar_module,
            input_transform=input_transform,
            outcome_transform=outcome_transform,
        )
# =============================================================================
# Mean Module
# =============================================================================
class EmpiricalOneDimensionalMean(Mean):
    """Mean function obtained by averaging historical one-dimensional curves.

    The historical curves are averaged across the curve dimension once at
    construction time; evaluating the mean interpolates that average at the
    query inputs.
    """

    def __init__(
        self,
        X_full: Tensor,
        Y_full: Tensor,
    ):
        """Builds the averaged curve and its interpolant.

        Args:
            X_full: `num_progression x 1`-dim Tensor of progression values.
            Y_full: `num_curves x num_progression x m`-dim Tensor of
                historical curves; m is the number of outputs (m=1 for
                single-output).

        Raises:
            ValueError: If Y_full is not 3-dimensional.
        """
        if Y_full.ndim != 3:
            raise ValueError(
                f"Expected Y_full to be 3-dim (num_curves x num_progression x m), "
                f"got {Y_full.ndim}-dim."
            )
        super().__init__()

        self.X_full = X_full  # num_progression x 1
        self.num_outputs = Y_full.shape[-1]

        # Average over the curve dimension: the result is num_progression x m.
        curve_average = Y_full.mean(dim=0)
        # The interpolant is fed one row per output: m x num_progression.
        self.mean_full = curve_average.T

        self.f = LinearInterpolation1D(
            self.X_full.squeeze(-1),
            self.mean_full,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Evaluates the empirical mean at the given inputs.

        Args:
            x: Input locations. For single-output, a `batch_shape x n x d`-dim
                Tensor. For multi-output (after the SingleTaskGP transform), a
                `batch_shape x m x n x d`-dim Tensor whose m slices are
                identical (replicated by SingleTaskGP).

        Returns:
            A `batch_shape x n`-dim Tensor for single-output, or a
            `batch_shape x m x n`-dim Tensor for multi-output.
        """
        query_points = extract_slice_for_interp(x, self.num_outputs)

        # The interpolant returns m x batch_shape x n; move the output
        # dimension next to n to obtain batch_shape x m x n.
        # NOTE(review): the original comment states that out-of-range inputs
        # are rejected by the interpolant (bounds_error=True) -- confirm in
        # LinearInterpolation1D.
        mean_values = self.f(query_points).movedim(0, -2)

        # With a single output the m=1 dimension is dropped.
        return mean_values.squeeze(-2) if self.num_outputs == 1 else mean_values
# =============================================================================
# Kernel Module
# =============================================================================
class EmpiricalOneDimensionalKernel(Kernel):
    r"""Kernel given by the empirical covariance of historical 1-D curves.

    The covariance between progression points is computed by interpolating
    centered historical curves at those points.

    By default, when `num_curves > num_progression` and `ard=False`, an
    orthogonal (SVD-style) basis replaces the raw centered curves to speed up
    computation. This reduces complexity from O(n1 * num_curves * n2) to
    O(n1 * r * n2) where r = min(num_curves, num_progression).
    """

    ard: bool = False

    def __init__(
        self,
        X_full: Tensor,
        Y_full: Tensor,
        ard: bool = False,
        curve_weights: Tensor | None = None,
        use_svd: bool | None = None,
        correction: int = 0,
    ) -> None:
        """Centers the historical curves and prepares the interpolated basis.

        Args:
            X_full: `num_progression x 1`-dim Tensor of progression values.
            Y_full: `num_curves x num_progression x m`-dim Tensor of
                historical curves; m is the number of outputs (m=1 for
                single-output).
            ard: Whether to use Automatic Relevance Determination (ARD).
            curve_weights: `num_curves`-dim Tensor of ARD weights.
            use_svd: Whether to use SVD acceleration. If None (default), SVD
                is used when num_curves > num_progression and ard=False. If
                True or False, directly toggles SVD on or off. Note: using
                ARD on the SVD basis implies a different prior than ARD on the
                original basis. This flag explicitly allows both approaches.
            correction: Degree of freedom correction to use for the
                computation of sample covariance, see
                `compute_sample_covariance` for details.

        Raises:
            ValueError: If Y_full is not 3-dimensional.
        """
        if Y_full.ndim != 3:
            raise ValueError(
                f"Expected Y_full to be 3-dim (num_curves x num_progression x m), "
                f"got {Y_full.ndim}-dim."
            )
        super().__init__()

        # Safe to unpack: the rank was validated above.
        n_curves, n_steps, n_outputs = Y_full.shape

        self.X_full = X_full
        self.num_curves = n_curves
        self.correction = correction
        self.num_outputs = n_outputs

        # Remove the across-curve mean (curve dimension is dim 0).
        _curve_mean, centered = center_curves(Y_full, curve_dim=0)

        # Put the output dimension first: m x num_curves x num_progression.
        basis = centered.movedim(-1, 0)

        # An explicit use_svd wins; otherwise SVD is enabled only without ARD
        # and when there are more curves than progression points.
        self._use_svd = (
            (not ard and n_curves > n_steps) if use_svd is None else use_svd
        )

        # Without SVD the basis keeps one row per historical curve.
        self._effective_num_curves = n_curves
        if self._use_svd:
            # Must run on the reshaped tensor so the decomposition is batched
            # over outputs; the curve dim shrinks to r = min(curves, steps).
            basis = compute_orthogonal_basis(basis, method="eigh")
            self._effective_num_curves = min(n_curves, n_steps)

        self.f = LinearInterpolation1D(self.X_full.squeeze(-1), basis)
        self.Y_full = Y_full
        self.Y_full_centered = centered

        if not ard:
            self.curve_weights = curve_weights
            self.ard = False
            return

        # With SVD + ARD the weights apply to the r basis vectors rather than
        # to the original curves, hence the effective curve count.
        instantiate_ard(
            obj=self,
            num_curves=self._effective_num_curves,
            curve_weights=curve_weights,
            dtype=Y_full.dtype,
            device=Y_full.device,
        )

    @property
    def use_svd(self) -> bool:
        """Whether this kernel uses the SVD efficiency technique."""
        return self._use_svd

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
    ) -> Tensor:
        """Evaluates the kernel matrix k(x1, x2).

        Args:
            x1: Input tensor. For single-output, `batch_shape x n1 x d`; for
                multi-output, `batch_shape x m x n1 x d` (after
                `multioutput_to_batch_mode_transform` has been applied).
            x2: Input tensor with the same shape structure as x1.
            diag: If True, only returns the diagonal.
            last_dim_is_batch: Not supported; must be False.

        Returns:
            A `batch_shape x n1 x n2`-dim covariance matrix for single-output,
            or a `batch_shape x m x n1 x n2`-dim covariance matrix for
            multi-output. If diag=True, returns the diagonal with one fewer
            trailing dimension.

        Raises:
            NotImplementedError: If last_dim_is_batch is True.
        """
        if last_dim_is_batch:
            raise NotImplementedError(
                "last_dim_is_batch=True not supported by EmpiricalOneDimensionalKernel."
            )

        m = self.num_outputs

        def _basis_at(points: Tensor) -> Tensor:
            # Interpolated (and, if set, weighted) basis at the given points.
            return compute_basis_matrix(
                f=self.f,
                x=points,
                num_outputs=m,
                curve_weights=self.curve_weights,
            )

        # For multi-output inputs only one of the replicated slices is used.
        points1 = extract_slice_for_interp(x1, m)
        if x2 is x1:
            points2 = points1
        else:
            points2 = extract_slice_for_interp(x2, m)

        # When both arguments share the same points the basis is computed
        # once and the covariance helper is told so via U2=None.
        shared_points = points2 is points1
        basis1 = _basis_at(points1)
        basis2 = None if shared_points else _basis_at(points2)

        cov = compute_sample_covariance(
            U1=basis1,
            U2=basis2,
            num_curves=self.num_curves,
            diag=diag,
            correction=self.correction,
        )

        # The helper puts m first. Move it just ahead of the trailing data
        # dimensions: batch_shape x m x n (diag) or batch_shape x m x n1 x n2.
        output_dim = -2 if diag else -3
        cov = cov.movedim(0, output_dim)

        # With a single output the m=1 dimension is dropped.
        if m == 1:
            cov = cov.squeeze(output_dim)
        return cov