#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Model fitting routines."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from copy import deepcopy
from functools import partial
from itertools import filterfalse
from typing import Any, TYPE_CHECKING
from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage
if TYPE_CHECKING:
from botorch.models.fully_bayesian import AbstractFullyBayesianSingleTaskGP
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
from botorch.exceptions.errors import ModelFittingError, UnsupportedError
from botorch.exceptions.warnings import OptimizationWarning
from botorch.logging import logger
from botorch.models import SingleTaskGP
from botorch.models.map_saas import get_map_saas_model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.optim.closures import get_loss_closure_with_grads
from botorch.optim.core import _LBFGSB_MAXITER_MAXFUN_REGEX
from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch
from botorch.optim.utils import (
_warning_handler_template,
get_parameters,
sample_all_priors,
)
from botorch.utils.context_managers import (
module_rollback_ctx,
parameter_rollback_ctx,
TensorCheckpoint,
)
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from linear_operator.utils.errors import NotPSDError
from torch import device, Tensor
from torch.nn import Parameter
from torch.utils.data import DataLoader
def _debug_warn(w: WarningMessage) -> bool:
    """Predicate bound as the ``debug`` argument of the default warning handler.

    Args:
        w: A recorded warning.

    Returns:
        True when the text of ``w.message`` is matched by
        ``_LBFGSB_MAXITER_MAXFUN_REGEX``; False otherwise.
    """
    # TODO: Better handle cases where warning handling logic
    # affects both debug and rethrow functions.
    return bool(_LBFGSB_MAXITER_MAXFUN_REGEX.search(str(w.message)))
def _rethrow_warn(w: WarningMessage) -> bool:
    """Predicate bound as the ``rethrow`` argument of the default warning handler.

    Args:
        w: A recorded warning.

    Returns:
        True when ``w.category`` is not a subclass of ``OptimizationWarning``, or
        when the warning text contains ``"Optimization timed out after"``; False
        for every other ``OptimizationWarning``.
    """
    if issubclass(w.category, OptimizationWarning):
        # Among optimization warnings, only the timeout message qualifies.
        return "Optimization timed out after" in str(w.message)
    return True
# Default ``warning_handler`` used by ``_fit_fallback``: the generic
# ``_warning_handler_template`` with its ``debug`` and ``rethrow`` predicates
# pre-bound to ``_debug_warn`` and ``_rethrow_warn``.
DEFAULT_WARNING_HANDLER = partial(
    _warning_handler_template, debug=_debug_warn, rethrow=_rethrow_warn
)
[docs]
def fit_gpytorch_mll(
    mll: MarginalLogLikelihood,
    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
    optimizer: Callable | None = None,
    closure_kwargs: dict[str, Any] | None = None,
    optimizer_kwargs: dict[str, Any] | None = None,
    **kwargs: Any,
) -> MarginalLogLikelihood:
    r"""Entry point for fitting a model wrapped in a GPyTorch MarginalLogLikelihood.

    The fit routine is selected as follows, first match wins:

    1. ``mll.model`` has a ``custom_fit`` attribute: it is called directly.
    2. ``mll`` is a ``SumMarginalLogLikelihood`` over a ``ModelListGP``: every
       entry of ``mll.mlls`` is fitted by a recursive call to this function.
    3. ``mll`` is an ``_ApproximateMarginalLogLikelihood``: handled by
       ``_fit_fallback_approximate``.
    4. Anything else: handled by ``_fit_fallback``.

    Args:
        mll: A GPyTorch MarginalLogLikelihood instance.
        closure: Forward-backward closure for obtaining objective values and
            gradients. Forwarded as-is to the selected routine.
        optimizer: User specified optimization algorithm. When ``None``, no
            ``optimizer`` keyword is forwarded, so the selected routine applies
            its own default.
        closure_kwargs: Keyword arguments intended for ``closure``.
        optimizer_kwargs: Keyword arguments intended for ``optimizer``.
        **kwargs: Additional keyword arguments forwarded to the selected routine.

    Returns:
        The value returned by the selected routine. For the
        ``SumMarginalLogLikelihood`` case this is ``mll.eval()`` when none of the
        sub-MLLs is left in training mode, and ``mll`` as-is otherwise.
    """
    if optimizer is not None:  # defer to per-method defaults
        kwargs["optimizer"] = optimizer

    # Every branch below forwards exactly this set of keyword arguments.
    forwarded = {
        "closure": closure,
        "closure_kwargs": closure_kwargs,
        "optimizer_kwargs": optimizer_kwargs,
        **kwargs,
    }

    if hasattr(mll.model, "custom_fit"):
        return mll.model.custom_fit(mll=mll, **forwarded)

    if isinstance(mll, SumMarginalLogLikelihood) and isinstance(mll.model, ModelListGP):
        mll.train()
        for component in mll.mlls:
            fit_gpytorch_mll(mll=component, **forwarded)
        # Only switch the joint MLL to eval mode if every component fit did so.
        if any(component.training for component in mll.mlls):
            return mll
        return mll.eval()

    if isinstance(mll, _ApproximateMarginalLogLikelihood):
        fit_routine = _fit_fallback_approximate
    else:
        fit_routine = _fit_fallback
    return fit_routine(mll=mll, **forwarded)
def _fit_fallback(
    mll: MarginalLogLikelihood,
    *,
    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
    optimizer: Callable = fit_gpytorch_mll_scipy,
    closure_kwargs: dict[str, Any] | None = None,
    optimizer_kwargs: dict[str, Any] | None = None,
    max_attempts: int = 5,
    pick_best_of_all_attempts: bool = False,
    warning_handler: Callable[[WarningMessage], bool] = DEFAULT_WARNING_HANDLER,
    caught_exception_types: tuple[type[BaseException], ...] = (NotPSDError,),
    **ignore: Any,
) -> MarginalLogLikelihood:
    r"""Generic fallback method for fitting Gaussian processes.

    Attempts to fit a model using the provided optimizer, then determines whether or
    not to retry by evaluating a given policy on emitted warning messages. The first
    attempt is run using the initialized parameter values; subsequent attempts begin
    by resampling tunable parameters.

    Args:
        mll: The MarginalLogLikelihood to fit. It is put in training mode before
            the first attempt.
        closure: Forward-backward closure for obtaining objective values and
            gradients. Responsible for setting parameters' ``grad`` attributes.
            If no closure is provided and ``closure_kwargs`` is given, one is
            obtained by calling ``get_loss_closure_with_grads``. If neither is
            provided, ``closure=None`` is passed through to ``optimizer``.
        optimizer: The underlying optimization algorithm to run. Should return
            an ``OptimizationResult`` object, whose ``fval`` field records the negative
            MLL value. Defaults to ``fit_gpytorch_mll_scipy``.
        closure_kwargs: Keyword arguments bound to ``closure`` (via ``partial``).
        optimizer_kwargs: Keyword arguments passed to ``optimizer``.
        max_attempts: The maximum number of fit attempts allowed. The attempt budget
            is NOT shared between calls to this method.
        pick_best_of_all_attempts: If True, the model will be fit
            ``max_attempts`` times, and the attempt that produces largest MLL
            value will be returned. First attempt uses the initial hyper
            parameter values, the subsequent attempts will call
            ``sample_all_priors`` to sample the initial values. If any attempt
            produces an error, the resulting parameters are discarded. If
            optimizer timeout is used, the ``timeout_sec`` will be used as is
            for each attempt, and it should be manually adjusted accordingly.
        warning_handler: A function used to filter warnings produced when calling
            ``optimizer``. Any unfiltered warnings (those for which ``warning_handler``
            returns ``False``) will be rethrown and trigger a model fitting retry.
        caught_exception_types: A tuple of exception types whose instances, when
            raised during an attempt, are caught and logged at the ``DEBUG``
            level, after which the loop moves on to the next attempt.
        **ignore: This function ignores unrecognized keyword arguments.

    Returns:
        The ``mll`` instance after a call to ``mll.eval()``, i.e. in evaluation
        mode (``mll.training == False``).

    Raises:
        ModelFittingError: If no attempt succeeded, i.e. the loop finished without
            an early return and without recording a best state dict.
    """
    # Setup
    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
    # The three values below are filled in lazily inside the attempt loop.
    params_nograd: dict[str, Parameter] = None  # pyre-ignore [9]
    ckpt_nograd: dict[str, TensorCheckpoint] = None  # pyre-ignore [9]
    ckpt: dict[str, TensorCheckpoint] = None  # pyre-ignore [9]

    # Build closure. When no closure is provided and no closure_kwargs are
    # needed, pass closure=None through to the optimizer so that it can use
    # its own internal dispatch (e.g. batched independent fitting in
    # fit_gpytorch_mll_scipy).
    mll.train()
    if closure is None and closure_kwargs is not None:
        closure = get_loss_closure_with_grads(
            mll, parameters=get_parameters(mll, requires_grad=True)
        )
    if closure_kwargs is not None:
        closure = partial(closure, **closure_kwargs)

    # Record best MLL & corresponding state dict.
    best_mll: float = -float("inf")
    best_state_dict = None

    # Attempt to fit the model
    for attempt in range(1, 1 + max_attempts):
        # Wrap with rollback contextmanager so that each loop iteration reloads the
        # original state_dict upon exiting (unless we clear ``ckpt``).
        with module_rollback_ctx(mll, checkpoint=ckpt, device=device("cpu")) as ckpt:
            if attempt > 1:  # resample free parameters
                if params_nograd is None:
                    params_nograd = get_parameters(mll, requires_grad=False)

                if ckpt_nograd is None:  # reuse primary checkpoint
                    ckpt_nograd = {name: ckpt[name] for name in params_nograd}

                # Resample under a rollback over the ``requires_grad=False``
                # parameters, using their entries from the primary checkpoint.
                with parameter_rollback_ctx(params_nograd, checkpoint=ckpt_nograd):
                    sample_all_priors(mll.model)

            try:
                # Fit the model
                with catch_warnings(record=True) as warning_list:
                    simplefilter("always", category=OptimizationWarning)
                    result = optimizer(mll, closure=closure, **optimizer_kwargs)

                # Resolve warnings and determine whether or not to retry.
                # Each warning the handler does not filter is re-emitted and
                # marks this attempt as unsuccessful.
                success = True
                for w in filterfalse(warning_handler, warning_list):
                    warn_explicit(str(w.message), w.category, w.filename, w.lineno)
                    success = False

                if success and not pick_best_of_all_attempts:
                    # If not picking best of all attempts, return the first
                    # successful attempt.
                    ckpt.clear()  # do not rollback upon exiting
                    return mll.eval()
                elif success:
                    # Update best MLL and corresponding state dict.
                    # Optimizers minimize negative MLL, so we negate fval.
                    current_mll = -result.fval
                    if current_mll > best_mll:
                        best_mll = current_mll
                        # Deepcopy is important here, otherwise they get updated.
                        best_state_dict = deepcopy(mll.state_dict())
                        message = f"Fit attempt #{attempt}: New best MLL: {best_mll}."
                    else:
                        message = (
                            f"Fit attempt #{attempt}: Current MLL {current_mll} did "
                            f"not beat best MLL so far {best_mll}."
                        )
                    logger.debug(message)

                # Ensure mll is in the right mode if going for another attempt.
                mll = mll if mll.training else mll.train()
                if not success:
                    # NOTE(review): the literal space after "policy" makes this
                    # render as "policy ." / "policy ; retrying..." -- confirm
                    # whether the space is intended.
                    logger.debug(
                        f"Fit attempt #{attempt} of {max_attempts} triggered retry "
                        f"policy {'.' if attempt == max_attempts else '; retrying...'}",
                    )

            except caught_exception_types as err:
                # Swallow only the configured exception types; anything else
                # propagates to the caller.
                logger.debug(
                    f"Fit attempt #{attempt} of {max_attempts} failed with exception:\n"
                    f"{err}",
                )

    # If picking best of all attempts, return MLL with best state dict.
    if best_state_dict is not None:
        mll.load_state_dict(best_state_dict)
        return mll.eval()

    raise ModelFittingError("All attempts to fit the model have failed.")
def _fit_fallback_approximate(
    mll: _ApproximateMarginalLogLikelihood,
    *,
    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
    data_loader: DataLoader | None = None,
    optimizer: Callable | None = None,
    full_batch_limit: int = 1024,
    **kwargs: Any,
) -> _ApproximateMarginalLogLikelihood:
    r"""Fallback method for fitting approximate Gaussian processes.

    Resolves the closure and the optimizer, then delegates to ``_fit_fallback``.

    Args:
        mll: The approximate MarginalLogLikelihood to fit.
        closure: Forward-backward closure for obtaining objective values and
            gradients. Responsible for setting parameters' ``grad`` attributes.
            When ``data_loader`` is given instead, a closure is built from it by
            calling ``get_loss_closure_with_grads``.
        data_loader: An optional DataLoader to pass to
            ``get_loss_closure_with_grads``. May only be provided when
            ``closure=None``.
        optimizer: The underlying optimization algorithm to run. When ``None``,
            ``fit_gpytorch_mll_scipy`` is chosen if no closure is in use and
            ``len(mll.model.train_targets) <= full_batch_limit``; otherwise
            ``fit_gpytorch_mll_torch`` is chosen.
        full_batch_limit: Threshold used when choosing the default ``optimizer``.
        **kwargs: Keyword arguments passed to ``_fit_fallback``.

    Returns:
        The value returned by ``_fit_fallback``.

    Raises:
        UnsupportedError: If both ``data_loader`` and ``closure`` are passed.
    """
    if data_loader is not None and closure is not None:
        raise UnsupportedError("Only one of `data_loader` or `closure` may be passed.")

    if data_loader is not None:
        trainable = get_parameters(mll, requires_grad=True)
        closure = get_loss_closure_with_grads(
            mll=mll, data_loader=data_loader, parameters=trainable
        )

    if optimizer is None:
        # Full-batch scipy fitting is the default only when no closure is in use
        # and the training set is small enough.
        use_scipy = (
            closure is None and len(mll.model.train_targets) <= full_batch_limit
        )
        optimizer = fit_gpytorch_mll_scipy if use_scipy else fit_gpytorch_mll_torch

    return _fit_fallback(mll=mll, closure=closure, optimizer=optimizer, **kwargs)
[docs]
def fit_fully_bayesian_model_nuts(
    model: AbstractFullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
    max_tree_depth: int = 6,
    warmup_steps: int = 512,
    num_samples: int = 256,
    thinning: int = 16,
    disable_progbar: bool = False,
    jit_compile: bool = False,
    seed: int = 0,
) -> None:
    r"""Fit a fully Bayesian model using the No-U-Turn-Sampler (NUTS).

    Runs NumPyro's ``NUTS`` kernel inside ``MCMC`` on ``model.pyro_model.sample``,
    post-processes and thins the samples, and loads them into ``model``. The model
    is put in training mode before sampling and in evaluation mode afterwards.

    Args:
        model: Fully Bayesian GP to be fitted.
        max_tree_depth: Maximum tree depth for NUTS.
        warmup_steps: The number of warmup (burn-in) steps.
        num_samples: The number of MCMC samples drawn before thinning.
        thinning: The amount of thinning. Every ``thinning``-th sample is kept,
            so roughly ``num_samples / thinning`` samples are retained.
        disable_progbar: If True, the sampler's progress bar is turned off.
        jit_compile: Not read anywhere in this function's body; passing it has
            no effect on the code shown here.
        seed: Random seed used to build the JAX PRNG key for the sampler.

    Example:
        >>> gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
        >>> fit_fully_bayesian_model_nuts(gp)
    """
    # Local import to avoid pulling in JAX/numpyro at module level,
    # which would break environments without NumPy >= 2.0.
    import jax
    from numpyro.infer import MCMC, NUTS

    model.train()

    # Run NUTS inference.
    kernel = NUTS(
        model.pyro_model.sample, dense_mass=True, max_tree_depth=max_tree_depth
    )
    sampler = MCMC(
        kernel,
        num_warmup=warmup_steps,
        num_samples=num_samples,
        progress_bar=not disable_progbar,
    )
    sampler.run(jax.random.PRNGKey(seed))

    # Post-process the raw draws, then thin each entry in place.
    posterior = model.pyro_model.postprocess_mcmc_samples(
        mcmc_samples=sampler.get_samples()
    )
    for name, draws in posterior.items():
        posterior[name] = draws[::thinning]

    # Hand the thinned samples back to the BoTorch model.
    model.load_mcmc_samples(posterior)
    model.eval()
[docs]
def get_fitted_map_saas_model(
    train_X: Tensor,
    train_Y: Tensor,
    train_Yvar: Tensor | None = None,
    input_transform: InputTransform | None = None,
    outcome_transform: OutcomeTransform | None = None,
    tau: Tensor | float | None = None,
    optimizer_kwargs: dict[str, Any] | None = None,
) -> SingleTaskGP:
    """Build a MAP SAAS model and fit it with ``fit_gpytorch_mll``.

    The model is created by ``get_map_saas_model`` and fitted through an
    ``ExactMarginalLogLikelihood`` over the model and its likelihood.

    Args:
        train_X: Tensor of shape ``n x d`` with training inputs.
        train_Y: Tensor of shape ``n x 1`` with training targets.
        train_Yvar: Optional tensor of shape ``n x 1`` with observed noise,
            inferred if None.
        input_transform: An optional input transform. When given, its
            ``train()`` method is called and the result is handed to the model.
        outcome_transform: An optional outcome transform.
        tau: Fixed value of the global shrinkage tau. If None, the model
            places a HC(0.1) prior on tau. Can be a tensor for batched models
            where each batch has a different sparsity prior.
        optimizer_kwargs: A dict of options for the optimizer passed
            to ``fit_gpytorch_mll``. A falsy value is replaced by an empty dict.

    Returns:
        The fitted model.
    """
    # Always hand a dict to the fitting routine.
    if not optimizer_kwargs:
        optimizer_kwargs = {}

    transform_for_model = None
    if input_transform is not None:
        transform_for_model = input_transform.train()

    saas_gp = get_map_saas_model(
        train_X=train_X,
        train_Y=train_Y,
        train_Yvar=train_Yvar,
        input_transform=transform_for_model,
        outcome_transform=outcome_transform,
        tau=tau,
    )
    objective = ExactMarginalLogLikelihood(
        model=saas_gp, likelihood=saas_gp.likelihood
    )
    fit_gpytorch_mll(objective, optimizer_kwargs=optimizer_kwargs)
    return saas_gp