#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
References
.. [lin2025scalable]
J. A. Lin, S. Ament, M. Balandat, D. Eriksson, J. M. Hernández-Lobato, E. Bakshy.
Scalable Gaussian Processes with Latent Kronecker Structure.
International Conference on Machine Learning 2025.
.. [lin2024scaling]
J. A. Lin, S. Ament, M. Balandat, E. Bakshy. Scaling Gaussian Processes
for Learning Curve Prediction via Latent Kronecker Structure. NeurIPS 2024
Bayesian Decision-making and Uncertainty Workshop.
.. [lin2023sampling]
J. A. Lin, J. Antorán, S. Padhy, D. Janz, J. M. Hernández-Lobato, A. Terenin.
Sampling from Gaussian Process Posteriors using Stochastic Gradient Descent.
Advances in Neural Information Processing Systems 2023.
"""
import contextlib
import warnings
from typing import Any
import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.exceptions.warnings import InputDataWarning
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model import FantasizeMixin, Model
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.latent_kronecker import LatentKroneckerGPPosterior
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.types import _DefaultType, DEFAULT
from gpytorch.distributions import Distribution, MultivariateNormal
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.means import Mean, ZeroMean
from gpytorch.models.exact_gp import ExactGP
from gpytorch.module import Module
from linear_operator import settings
from linear_operator.operators import (
ConstantDiagLinearOperator,
KroneckerProductLinearOperator,
LinearOperator,
MaskedLinearOperator,
)
from torch import Tensor
[docs]
class LatentKroneckerGP(GPyTorchModel, ExactGP, FantasizeMixin):
    r"""
    A multi-task GP model which uses Kronecker structure despite missing entries.

    Leverages pathwise conditioning and iterative linear system solvers to
    efficiently draw samples from the GP posterior. See [lin2024scaling]_
    and [lin2025scalable]_ for details.

    For more information about pathwise conditioning, see [wilson2021pathwise]_
    and [Maddox2021bohdo]_. Details about iterative linear system solvers for GPs
    with pathwise conditioning can be found in [lin2023sampling]_.

    NOTE: This model requires iterative methods for efficient posterior inference.
    To enable iterative methods, the ``use_iterative_methods`` helper method can be
    used as a context manager.

    Example:
        >>> model = LatentKroneckerGP(train_X, train_T, train_Y)
        >>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
        >>> with model.use_iterative_methods():
        ...     fit_gpytorch_mll(mll)
        >>> samples = model.posterior(test_X, test_T).rsample()
    """

    def __init__(
        self,
        train_X: Tensor,
        train_T: Tensor,
        train_Y: Tensor,
        likelihood: Likelihood | None = None,
        mean_module_X: Mean | None = None,
        mean_module_T: Mean | None = None,
        covar_module_X: Module | None = None,
        covar_module_T: Module | None = None,
        input_transform: InputTransform | None = None,
        outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
    ) -> None:
        r"""
        Args:
            train_X: A ``batch_shape x n x d`` tensor of training features.
            train_T: A ``batch_shape x t x 1`` tensor of training time steps.
            train_Y: A ``batch_shape x n x t`` tensor of training observations,
                corresponding to the Cartesian product of ``train_X`` and
                ``train_T``. Non-finite entries (e.g. NaN) are treated as
                missing observations; the pattern of missing entries must be
                identical across ``batch_shape``.
            likelihood: A likelihood. If omitted, use a standard
                ``GaussianLikelihood`` with inferred homoskedastic noise level.
            mean_module_X: The mean function to be used for X.
                If omitted, a ``ZeroMean`` will be used.
            mean_module_T: The mean function to be used for T.
                If omitted, a ``ZeroMean`` will be used.
            covar_module_X: The module computing the covariance matrix of X.
                If omitted, a ``MaternKernel`` will be used.
            covar_module_T: The module computing the covariance matrix of T.
                If omitted, a ``MaternKernel`` wrapped in a ``ScaleKernel``
                will be used (both with the model's batch shape).
            input_transform: An input transform that is applied to X in the
                model's forward pass.
            outcome_transform: An outcome transform that is applied to Y.
                Note that ``.train()`` will be called on the outcome transform during
                instantiation of the model.

        Raises:
            BotorchTensorDimensionError: If ``train_T``, after broadcasting over
                ``batch_shape``, does not have shape ``batch_shape x t x 1``.
            ValueError: If the pattern of missing values in ``train_Y`` differs
                across ``batch_shape``.
        """
        with torch.no_grad():
            # transform inputs here to check resulting shapes
            # actual transforms will be applied in forward() and posterior()
            transformed_X = self.transform_inputs(
                X=train_X, input_transform=input_transform
            )
        self._validate_tensor_args(X=transformed_X, Y=train_Y)
        batch_shape, ard_num_dims = transformed_X.shape[:-2], transformed_X.shape[-1]
        # the last dimension of train_Y indexes T, so there is one output per T value
        self._num_outputs = train_Y.shape[-1]
        expected_shape = torch.Size([*batch_shape, self._num_outputs, 1])
        train_T = torch.broadcast_to(train_T, (*batch_shape, *train_T.shape[-2:]))
        if train_T.shape != expected_shape:
            raise BotorchTensorDimensionError(
                f"Expected train_T with shape {expected_shape} but got {train_T.shape}."
            )

        # non-finite entries of train_Y are treated as missing observations
        mask_valid_batch = train_Y.isfinite()
        # flatten over batch_shape
        mask_valid_flat = mask_valid_batch.reshape(-1, *mask_valid_batch.shape[-2:])
        # check that all masks are equal across batch_shape
        if not torch.all((mask_valid_flat == mask_valid_flat[0]).all(dim=(-2, -1))):
            raise ValueError(
                "Pattern of missing values in train_Y must be equal across batch_shape."
            )
        # Row-major flattening of the (n, t) mask: X index outer, T index inner,
        # which is the index ordering of KroneckerProductLinearOperator(K_X, K_T).
        self.mask_valid = mask_valid_flat[0].flatten()
        # keep only the observed entries: batch_shape x n_observed
        train_Y = train_Y.reshape(*batch_shape, -1)[..., self.mask_valid]

        if outcome_transform is DEFAULT:
            outcome_transform = Standardize(m=1, batch_shape=batch_shape)
        if outcome_transform is not None:
            outcome_transform.train()
            # transform outputs once and keep the results
            train_Y, _ = outcome_transform(train_Y.unsqueeze(-1), X=transformed_X)
            train_Y = train_Y.squeeze(-1)

        if likelihood is None:
            likelihood = GaussianLikelihood(batch_shape=batch_shape)

        # T is stored next to X in train_inputs; see the ``train_T`` property.
        ExactGP.__init__(
            self,
            train_inputs=[train_X, train_T],
            train_targets=train_Y,
            likelihood=likelihood,
        )

        if mean_module_X is None:
            mean_module_X = ZeroMean(batch_shape=batch_shape)
        self.mean_module_X: Module = mean_module_X
        if mean_module_T is None:
            mean_module_T = ZeroMean(batch_shape=batch_shape)
        self.mean_module_T: Module = mean_module_T

        if covar_module_X is None:
            covar_module_X = MaternKernel(
                ard_num_dims=ard_num_dims, batch_shape=batch_shape
            )
        if covar_module_T is None:
            # The outputscale is batched like every other default module, so that
            # batched models do not share a single outputscale across the batch.
            covar_module_T = ScaleKernel(
                base_kernel=MaternKernel(ard_num_dims=1, batch_shape=batch_shape),
                batch_shape=batch_shape,
            )
        self.covar_module_X: Module = covar_module_X
        self.covar_module_T: Module = covar_module_T

        if input_transform is not None:
            self.input_transform = input_transform
        if outcome_transform is not None:
            self.outcome_transform = outcome_transform
        self.to(train_X)

    @property
    def train_T(self) -> Tensor:
        """The training T values (second element of train_inputs).

        T is stored in train_inputs (alongside X) to enable GPyTorch's
        multi-input prediction strategy via ``_get_test_prior_mean_and_covariances``.
        This also allows using different T values at test time, e.g., evaluating
        the posterior at a subset of task indices.

        The helper methods below (``transform_inputs``, ``_set_transformed_inputs``,
        ``_revert_to_original_inputs``) ensure T is preserved through BoTorch's
        input transform machinery, which expects single-input models.
        """
        return self.train_inputs[1]

    def _is_T_input(self, X: Tensor) -> bool:
        """Check if X is the T input by identity comparison."""
        return (
            hasattr(self, "train_inputs")
            and self.train_inputs is not None
            and len(self.train_inputs) > 1
            and X is self.train_inputs[1]
        )

    def _set_transformed_inputs(self) -> None:
        r"""Transform X while preserving T in train_inputs."""
        if not (hasattr(self, "train_inputs") and len(self.train_inputs) > 1):
            return super()._set_transformed_inputs()
        T = self.train_inputs[1]
        super()._set_transformed_inputs()
        # super() calls set_train_data which sets train_inputs = (X_tf,), losing T
        if hasattr(self, "train_inputs"):
            self.train_inputs = (self.train_inputs[0], T)

    def _revert_to_original_inputs(self) -> None:
        r"""Revert X while preserving T in train_inputs."""
        T = (
            self.train_inputs[1]
            if (hasattr(self, "train_inputs") and len(self.train_inputs) > 1)
            else None
        )
        super()._revert_to_original_inputs()
        # super() calls set_train_data which sets train_inputs = (X,), losing T
        if T is not None and hasattr(self, "train_inputs"):
            self.train_inputs = (self.train_inputs[0], T)

    def use_iterative_methods(
        self,
        tol: float = 0.01,
        max_iter: int = 10000,
        covar_root_decomposition: bool = False,
        log_prob: bool = True,
        solves: bool = True,
    ) -> contextlib.ExitStack:
        r"""Return a context manager that enables iterative linear algebra.

        The ``linear_operator`` settings are entered immediately when this method
        is called. They are handed to the returned ``ExitStack`` and restored when
        that stack is closed. Use the result directly in a ``with`` statement.

        Args:
            tol: Tolerance passed to ``settings.cg_tolerance``.
            max_iter: Iteration cap passed to ``settings.max_cg_iterations``.
            covar_root_decomposition: Forwarded to ``settings.fast_computations``.
            log_prob: Forwarded to ``settings.fast_computations``.
            solves: Forwarded to ``settings.fast_computations``.

        Returns:
            A ``contextlib.ExitStack`` holding the entered settings contexts.
        """
        with contextlib.ExitStack() as stack:
            stack.enter_context(
                settings.fast_computations(
                    covar_root_decomposition=covar_root_decomposition,
                    log_prob=log_prob,
                    solves=solves,
                )
            )
            stack.enter_context(settings.cg_tolerance(tol))
            stack.enter_context(settings.max_cg_iterations(max_iter))
            # transfer ownership of the entered contexts out of this `with`
            return stack.pop_all()

    def _get_mean(self, X: Tensor, T: Tensor, mask: Tensor | None = None) -> Tensor:
        r"""Kronecker-structured prior mean over the (X, T) grid.

        Builds the Kronecker product of the X mean and the T mean, flattened with
        X as the outer and T as the inner index. If ``mask`` is given, only the
        entries where ``mask`` is True are returned.
        """
        mean_X = self.mean_module_X(X).unsqueeze(-1)
        mean_T = self.mean_module_T(T).unsqueeze(-1)
        mean = KroneckerProductLinearOperator(mean_X, mean_T).squeeze(-1)
        return mean[..., mask] if mask is not None else mean

    def _get_test_prior_mean_and_covariances(
        self,
        train_inputs: list[Tensor],
        test_inputs: list[Tensor],
        **kwargs,
    ) -> tuple[
        Tensor,
        LinearOperator,
        LinearOperator,
        torch.Size,
        torch.Size,
        type[Distribution],
    ]:
        """Computes Kronecker-structured covariances with masking for posterior.

        This enables proper posterior mean and variance computation while maintaining
        the Kronecker structure for efficiency. The test_train_covar is masked on the
        train dimension to handle missing observations.

        Args:
            train_inputs: List containing [X_train, T_train].
            test_inputs: List containing [X_test, T_test].
            **kwargs: Additional arguments (unused, kept for compatibility).

        Returns:
            A tuple containing:
            - test_mean: The prior mean evaluated on the test set
            - test_test_covar: Covariance between test points (Kronecker structure)
            - test_train_covar: Covariance between test and train points (masked)
            - batch_shape: The batch shape of the model
            - test_shape: Shape of the test output
            - posterior_class: MultivariateNormal
        """
        X_train, T_train = train_inputs[0], train_inputs[1]
        X_test, T_test = test_inputs[0], test_inputs[1]

        # Compute Kronecker-structured covariances
        K_X_test_test = self.covar_module_X(X_test)
        K_T_test_test = self.covar_module_T(T_test)
        K_X_test_train = self.covar_module_X(X_test, X_train)
        K_T_test_train = self.covar_module_T(T_test, T_train)

        test_test_covar = KroneckerProductLinearOperator(K_X_test_test, K_T_test_test)
        test_train_covar_full = KroneckerProductLinearOperator(
            K_X_test_train, K_T_test_train
        )

        # Apply masking for missing observations
        # The train dimension needs masking, test dimension is full
        n_test = X_test.shape[-2] * T_test.shape[-2]
        # Create full test mask (all valid)
        test_mask = torch.ones(n_test, dtype=torch.bool, device=X_test.device)
        # Apply mask to test_train_covar (only train dimension masked)
        test_train_covar = MaskedLinearOperator(
            test_train_covar_full, row_mask=test_mask, col_mask=self.mask_valid
        )

        # Compute prior mean on test set
        test_mean = self._get_mean(X_test, T_test)

        batch_shape = torch.broadcast_shapes(X_train.shape[:-2], X_test.shape[:-2])
        test_shape = torch.Size([n_test])

        return (
            test_mean,
            test_test_covar,
            test_train_covar,
            batch_shape,
            test_shape,
            MultivariateNormal,
        )

    def __call__(self, *args, **kwargs):
        """Forward pass that handles optional T parameter.

        Appends ``self.train_T`` when only X is provided. This is necessary
        because ``fit_gpytorch_mll`` and the MLL training pipeline call
        ``model(train_X)`` with only the X input.

        Args:
            *args: Either (X,) or (X, T). If only X is provided, uses self.train_T.
        """
        if len(args) == 1:
            args = (args[0], self.train_T)
        return ExactGP.__call__(self, *args, **kwargs)

    def forward(self, *args, **kwargs) -> MultivariateNormal:
        r"""
        Computes the joint distribution at the given input locations.

        In training mode, the input transform is applied to X and the distribution
        is restricted to the observed entries (``self.mask_valid``). In eval mode,
        X is used as given; the first ``len(self.mask_valid)`` flattened entries
        are masked with ``self.mask_valid`` and all remaining entries are kept.
        NOTE(review): the eval branch presumably assumes the leading entries are
        the training grid (train inputs first, as in GPyTorch's concatenated
        train/test inputs) — confirm.

        Args:
            *args: Either (X,) for backward compatibility, or (X, T).
                If only X is provided, uses self.train_T for T.

        Returns:
            MultivariateNormal: The joint distribution at the specified input locations.
        """
        if len(args) == 1:
            X = args[0]
            T = self.train_T
        else:
            X, T = args[0], args[1]

        if self.training:
            X = self.transform_inputs(X)
            mask = self.mask_valid
        else:
            num_outputs = X.shape[-2] * T.shape[-2]
            mask = torch.ones(num_outputs, dtype=torch.bool, device=X.device)
            mask[: self.mask_valid.shape[-1]] = self.mask_valid

        mean = self._get_mean(X, T, mask=mask)
        covar_X = self.covar_module_X(X)
        covar_T = self.covar_module_T(T)
        covar = KroneckerProductLinearOperator(covar_X, covar_T)
        covar = MaskedLinearOperator(covar, row_mask=mask, col_mask=mask)
        return MultivariateNormal(mean, covar)

    def posterior(
        self,
        X: Tensor,
        T: Tensor | None = None,
        observation_noise: bool | Tensor = False,
        posterior_transform: PosteriorTransform | None = None,
        **kwargs: Any,
    ) -> GPyTorchPosterior:
        r"""Computes the posterior over model outputs at the provided points.

        Leverages GPyTorch's inference stack with our custom Kronecker-structured
        covariances (via the overridden ``_get_test_prior_mean_and_covariances``).
        Sampling uses pathwise conditioning for efficiency.

        NOTE: This method puts the model into eval mode (as is standard for BoTorch
        posteriors), which also applies the input transform to the train inputs.

        NOTE: For efficient inference with large datasets, wrap the call in the
        ``model.use_iterative_methods()`` context manager, e.g.:

            >>> with model.use_iterative_methods():
            ...     posterior = model.posterior(X, T)

        Args:
            X: A ``(batch_shape) x q x d``-dim Tensor of test features.
            T: A ``(batch_shape) x t x 1``-dim Tensor of test T values.
                If None, defaults to using ``self.train_T``.
            observation_noise: If True, add observation noise. Currently not
                supported.
            posterior_transform: An optional PosteriorTransform. Currently not
                supported.

        Returns:
            A ``LatentKroneckerGPPosterior`` with proper mean/variance and efficient
            pathwise sampling.

        Raises:
            NotImplementedError: If a posterior transform is given, the likelihood
                is not a ``GaussianLikelihood``, or observation noise is requested.
        """
        if posterior_transform is not None:
            raise NotImplementedError(
                "Posterior transforms currently not supported for "
                f"{self.__class__.__name__}"
            )
        if not isinstance(self.likelihood, GaussianLikelihood):
            raise NotImplementedError(
                "Only GaussianLikelihood currently supported for "
                f"{self.__class__.__name__}"
            )
        if observation_noise is not False:
            raise NotImplementedError(
                "Observation noise currently not supported for "
                f"{self.__class__.__name__}"
            )

        if T is None:
            T = self.train_T

        # Switch to eval mode before predicting: this applies the input transform
        # to the stored train_inputs and makes forward() take its eval branch
        # rather than re-transforming X and applying the training mask.
        self.eval()
        X_test = self.transform_inputs(X)

        # Compute the real posterior distribution via GPyTorch's inference stack,
        # which uses our overridden _get_test_prior_mean_and_covariances for
        # Kronecker structure. This gives exact posterior mean and variance.
        #
        # NOTE: This eagerly computes the train-train solve (via GPyTorch's
        # cached mean_cache), using either direct Cholesky or CG depending on
        # whether use_iterative_methods() is active. This is the same system
        # that pathwise sampling solves. The posterior covariance remains lazy
        # (no Cholesky until .covariance_matrix is accessed; .variance only
        # needs the diagonal). Sampling (rsample) still uses pathwise
        # conditioning for efficiency, see
        # LatentKroneckerGPPosterior.rsample_from_base_samples.
        distribution = self(X_test, T)
        return LatentKroneckerGPPosterior(self, distribution, X, T)

    def _rsample_from_base_samples(
        self,
        X: Tensor,
        T: Tensor,
        base_samples: Tensor,
        observation_noise: bool | Tensor = False,
    ) -> Tensor:
        r"""Sample from the posterior distribution at the provided points ``X``
        using Matheron's rule (pathwise conditioning).

        The last dimension of ``base_samples`` is split into three consecutive
        chunks of sizes ``n_train_full + n_train + n_test``, where
        ``n_train_full = n * t`` is the size of the full training grid,
        ``n_train`` is the number of observed training entries, and
        ``n_test = q * t_test`` is the size of the test grid.

        Args:
            X: A ``(batch_shape) x q x d``-dim Tensor, where ``d`` is the dimension
                of the feature space and ``q`` is the number of points considered
                jointly
            T: A ``(batch_shape) x t x 1``-dim Tensor of ``T``-locations at which to
                evaluate the posterior samples.
            base_samples: A Tensor of ``N(0, I)`` base samples of shape
                ``sample_shape x batch_shape x (n_train_full + n_train + n_test)``,
                typically obtained from a ``Sampler``. This is used for
                deterministic optimization.
            observation_noise: Unused; present for interface compatibility.

        Returns:
            Samples from the posterior, a tensor of shape
            ``sample_shape x broadcast_batch_shape x q x t_test`` (untransformed
            with the outcome transform, if any).
        """
        # toggle eval mode to switch the behavior of input / outcome transforms
        # this also implicitly applies the input transform to the train_inputs
        self.eval()

        X_train = self.train_inputs[0]
        X_test = self.transform_inputs(X)

        n_train_full = X_train.shape[-2] * self._num_outputs
        n_train = self.train_targets.shape[-1]
        n_test = X_test.shape[-2] * T.shape[-2]
        sample_shape = base_samples.shape[: -len(self.batch_shape) - 1]

        w_train, eps_base, w_test = torch.split(
            base_samples, [n_train_full, n_train, n_test], dim=-1
        )
        eps = torch.sqrt(self.likelihood.noise) * eps_base

        # calculate prior sample evaluated at training data
        K_train_train_X = self.covar_module_X(X_train)
        K_train_train_T = self.covar_module_T(self.train_T)
        K_train_train = KroneckerProductLinearOperator(K_train_train_X, K_train_train_T)
        L_train_train_X = K_train_train_X.cholesky(upper=False)
        L_train_train_T = K_train_train_T.cholesky(upper=False)
        L_train_train = KroneckerProductLinearOperator(L_train_train_X, L_train_train_T)
        m_train = self._get_mean(X_train, self.train_T, mask=self.mask_valid)
        f_prior_train = L_train_train @ w_train.unsqueeze(-1)
        f_prior_train = m_train + f_prior_train.squeeze(-1)[..., self.mask_valid]

        # assemble and solve pathwise conditioning linear system
        K_train_train_valid = MaskedLinearOperator(
            K_train_train, row_mask=self.mask_valid, col_mask=self.mask_valid
        )
        noise_covar = ConstantDiagLinearOperator(
            self.likelihood.noise
            * torch.ones(*self.batch_shape, 1, dtype=X.dtype, device=X.device),
            diag_shape=n_train,
        )
        H = K_train_train_valid + noise_covar
        v = self.train_targets - (f_prior_train + eps)
        # expand once here to avoid repeated expansion
        # by MaskedLinearOperator later
        H_inv_v = torch.zeros(
            *sample_shape,
            *self.batch_shape,
            n_train_full,
            dtype=X.dtype,
            device=X.device,
        )
        with self.use_iterative_methods():
            H_inv_v[..., self.mask_valid] = H.solve(v.unsqueeze(-1)).squeeze(-1)

        # calculate prior sample evaluated at test data via conditional sampling
        K_test_test_X = self.covar_module_X(X_test).evaluate_kernel()
        K_test_test_T = self.covar_module_T(T).evaluate_kernel()
        K_train_test_X = self.covar_module_X(X_train, X_test).evaluate_kernel()
        K_train_test_T = self.covar_module_T(self.train_T, T).evaluate_kernel()
        L_train_test_X = L_train_train_X.solve_triangular(
            K_train_test_X.tensor, upper=False
        )
        L_train_test_T = L_train_train_T.solve_triangular(
            K_train_test_T.tensor, upper=False
        )
        L_test_test_X = (
            K_test_test_X - L_train_test_X.transpose(-2, -1) @ L_train_test_X
        ).cholesky(upper=False)
        L_test_test_T = (
            K_test_test_T - L_train_test_T.transpose(-2, -1) @ L_train_test_T
        ).cholesky(upper=False)
        # NOTE(review): with P_X = L_train_test_X^T L_train_test_X and P_T defined
        # likewise, these factors give f_prior_test the covariance
        #   P_X (x) P_T + (K_X - P_X) (x) (K_T - P_T),
        # whereas the exact prior covariance is K_X (x) K_T, which additionally
        # contains (K_X - P_X) (x) P_T + P_X (x) (K_T - P_T). When T equals
        # train_T, K_T - P_T is ~0. Confirm sample variances against the exact
        # posterior covariance before relying on them.
        L_test_train = KroneckerProductLinearOperator(
            L_train_test_X.transpose(-2, -1), L_train_test_T.transpose(-2, -1)
        )
        L_test_test = KroneckerProductLinearOperator(L_test_test_X, L_test_test_T)

        # match dimensions for broadcasting
        broadcast_shape = L_test_train.shape[:-2]
        extra_batch_dims = len(broadcast_shape) - len(self.batch_shape)
        for _ in range(extra_batch_dims):
            w_train = w_train.unsqueeze(len(sample_shape))
            w_test = w_test.unsqueeze(len(sample_shape))
            H_inv_v = H_inv_v.unsqueeze(len(sample_shape))

        m_test = self._get_mean(X_test, T)
        f_prior_test = L_test_train @ w_train.unsqueeze(-1)
        f_prior_test = f_prior_test + L_test_test @ w_test.unsqueeze(-1)
        f_prior_test = m_test + f_prior_test.squeeze(-1)

        K_train_test = KroneckerProductLinearOperator(K_train_test_X, K_train_test_T)
        # no MaskedLinearOperator here because H_inv_v is already expanded
        samples = K_train_test.transpose(-2, -1) @ H_inv_v.unsqueeze(-1)
        samples = samples + f_prior_test.unsqueeze(-1)
        # reshape samples to separate X and T dimensions
        # samples.shape = (*sample_shape, *broadcast_shape, n_test_x * n_t, 1)
        samples = samples.reshape(*samples.shape[:-2], X_test.shape[-2], T.shape[-2])
        # samples.shape = (*sample_shape, *broadcast_shape, n_test_x, n_t)
        if hasattr(self, "outcome_transform") and self.outcome_transform is not None:
            samples, _ = self.outcome_transform.untransform(samples, X=X)
        return samples

    def condition_on_observations(
        self, X: Tensor, Y: Tensor, noise: Tensor | None = None, **kwargs: Any
    ) -> Model:
        r"""Not supported for this model.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"Conditioning currently not supported for {self.__class__.__name__}"
        )