#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Tools for model fitting."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any
from warnings import warn
from botorch.exceptions.warnings import OptimizationWarning
from botorch.optim.closures import get_loss_closure_with_grads
from botorch.optim.core import (
OptimizationResult,
OptimizationStatus,
scipy_minimize,
torch_minimize,
)
from botorch.optim.stopping import ExpMAStoppingCriterion, StoppingCriterion
from botorch.optim.utils import get_parameters_and_bounds, TorchAttr
from botorch.utils.types import DEFAULT
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from numpy import ndarray
from torch import Tensor
from torch.nn import Module
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
TBoundsDict = dict[str, tuple[float | None, float | None]]
TScipyObjective = Callable[
[ndarray, MarginalLogLikelihood, dict[str, TorchAttr]], tuple[float, ndarray]
]
TModToArray = Callable[
[Module, TBoundsDict | None, set[str] | None],
tuple[ndarray, dict[str, TorchAttr], ndarray | None],
]
TArrayToMod = Callable[[Module, ndarray, dict[str, TorchAttr]], Module]
[docs]
def fit_gpytorch_mll_scipy(
mll: MarginalLogLikelihood,
parameters: dict[str, Tensor] | None = None,
bounds: dict[str, tuple[float | None, float | None]] | None = None,
closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
closure_kwargs: dict[str, Any] | None = None,
method: str = "L-BFGS-B",
options: dict[str, Any] | None = None,
callback: Callable[[dict[str, Tensor], OptimizationResult], None] | None = None,
timeout_sec: float | None = None,
) -> OptimizationResult:
r"""Generic scipy.optimize-based fitting routine for GPyTorch MLLs.
For ``BatchedMultiOutputGPyTorchModel`` instances with a non-trivial
``_aug_batch_shape`` (e.g., multi-output ``SingleTaskGP`` or
``EnsembleMapSaasSingleTaskGP``), this automatically runs
``fmin_l_bfgs_b_batched`` to optimize each batch element's
hyperparameters independently. This converts the single
high-dimensional optimization problem into multiple lower-dimensional
problems that are easier to solve.
The model and likelihood in mll must already be in train mode.
Args:
mll: MarginalLogLikelihood to be maximized.
parameters: Optional dictionary of parameters to be optimized. Defaults
to all parameters of ``mll`` that require gradients.
bounds: A dictionary of user-specified bounds for ``parameters``. Used to update
default parameter bounds obtained from ``mll``.
closure: Callable that returns a tensor and an iterable of gradient
tensors. Responsible for setting the ``grad`` attributes of
``parameters``. If no closure is provided, one will be obtained
by calling ``get_loss_closure_with_grads``. When no closure is
provided and the model is a batched multi-output model, batched
independent fitting is used automatically.
closure_kwargs: Keyword arguments passed to ``closure``.
method: Solver type, passed along to scipy.optimize.minimize.
options: Dictionary of solver options, passed along to
scipy.optimize.minimize or ``fmin_l_bfgs_b_batched``.
callback: Optional callback taking ``parameters`` and an
``OptimizationResult`` as its sole arguments.
timeout_sec: Timeout in seconds after which to terminate the fitting loop
(note that timing out can result in bad fits!). Not currently
supported for batched independent fitting.
Returns:
The final OptimizationResult.
"""
# Avoid circular import: models.gpytorch imports from optim.
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
model = mll.model
if (
closure is None
and isinstance(model, BatchedMultiOutputGPyTorchModel)
and model._aug_batch_shape.numel() > 1
):
return _fit_gpytorch_mll_scipy_independent(
mll=mll,
parameters=parameters,
bounds=bounds,
options=options,
callback=callback,
timeout_sec=timeout_sec,
)
# Resolve ``parameters`` and update default bounds
_parameters, _bounds = get_parameters_and_bounds(mll)
bounds = _bounds if bounds is None else {**_bounds, **bounds}
if parameters is None:
parameters = {n: p for n, p in _parameters.items() if p.requires_grad}
if closure is None:
closure = get_loss_closure_with_grads(mll, parameters=parameters)
if closure_kwargs is not None:
closure = partial(closure, **closure_kwargs)
result = scipy_minimize(
closure=closure,
parameters=parameters,
bounds=bounds,
method=method,
options=options,
callback=callback,
timeout_sec=timeout_sec,
)
if result.status not in [OptimizationStatus.SUCCESS, OptimizationStatus.STOPPED]:
warn(
f"`scipy_minimize` terminated with status {result.status}, displaying"
f" original message from `scipy.optimize.minimize`: {result.message}",
OptimizationWarning,
stacklevel=2,
)
return result
[docs]
def fit_gpytorch_mll_torch(
mll: MarginalLogLikelihood,
parameters: dict[str, Tensor] | None = None,
bounds: dict[str, tuple[float | None, float | None]] | None = None,
closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
closure_kwargs: dict[str, Any] | None = None,
step_limit: int | None = None,
stopping_criterion: StoppingCriterion | None = DEFAULT,
optimizer: Optimizer | Callable[..., Optimizer] = Adam,
scheduler: _LRScheduler | Callable[..., _LRScheduler] | None = None,
callback: Callable[[dict[str, Tensor], OptimizationResult], None] | None = None,
timeout_sec: float | None = None,
) -> OptimizationResult:
r"""Generic torch.optim-based fitting routine for GPyTorch MLLs.
Args:
mll: MarginalLogLikelihood to be maximized.
parameters: Optional dictionary of parameters to be optimized. Defaults
to all parameters of ``mll`` that require gradients.
bounds: A dictionary of user-specified bounds for ``parameters``. Used to update
default parameter bounds obtained from ``mll``.
closure: Callable that returns a tensor and an iterable of gradient
tensors. Responsible for setting the ``grad`` attributes of
``parameters``. If no closure is provided, one will be obtained
by calling ``get_loss_closure_with_grads``.
closure_kwargs: Keyword arguments passed to ``closure``.
step_limit: Optional upper bound on the number of optimization steps.
stopping_criterion: A StoppingCriterion for the optimization loop.
optimizer: A ``torch.optim.Optimizer`` instance or a factory that takes
a list of parameters and returns an ``Optimizer`` instance.
scheduler: A ``torch.optim.lr_scheduler._LRScheduler`` instance or a factory
that takes an ``Optimizer`` instance and returns an ``_LRSchedule``.
callback: Optional callback taking ``parameters`` and an
OptimizationResult as its sole arguments.
timeout_sec: Timeout in seconds after which to terminate the fitting loop
(note that timing out can result in bad fits!).
Returns:
The final OptimizationResult.
"""
if stopping_criterion == DEFAULT:
stopping_criterion = ExpMAStoppingCriterion()
# Resolve ``parameters`` and update default bounds
param_dict, bounds_dict = get_parameters_and_bounds(mll)
if parameters is None:
parameters = {n: p for n, p in param_dict.items() if p.requires_grad}
if closure is None:
closure = get_loss_closure_with_grads(mll, parameters)
if closure_kwargs is not None:
closure = partial(closure, **closure_kwargs)
return torch_minimize(
closure=closure,
parameters=parameters,
bounds=bounds_dict if bounds is None else {**bounds_dict, **bounds},
optimizer=optimizer,
scheduler=scheduler,
step_limit=step_limit,
stopping_criterion=stopping_criterion,
callback=callback,
timeout_sec=timeout_sec,
)
def _fit_gpytorch_mll_scipy_independent(
mll: MarginalLogLikelihood,
parameters: dict[str, Tensor] | None = None,
bounds: dict[str, tuple[float | None, float | None]] | None = None,
options: dict[str, Any] | None = None,
callback: Callable[[dict[str, Tensor], OptimizationResult], None] | None = None,
timeout_sec: float | None = None,
) -> OptimizationResult:
r"""Fit a batched model by independently optimizing each batch element's
hyperparameters using parallel L-BFGS-B.
This is an internal helper called by ``fit_gpytorch_mll_scipy`` when the
model is a ``BatchedMultiOutputGPyTorchModel`` with a non-trivial
``_aug_batch_shape``.
Args:
mll: MarginalLogLikelihood to be maximized.
parameters: Optional dictionary of parameters to be optimized.
bounds: A dictionary of user-specified bounds for ``parameters``.
options: Dictionary of solver options passed to
``fmin_l_bfgs_b_batched`` (e.g., ``maxiter``, ``pgtol``).
callback: Optional callback passed to ``fmin_l_bfgs_b_batched``.
timeout_sec: Timeout in seconds. Not currently supported for batched
fitting; a warning is issued if provided.
Returns:
The final OptimizationResult. The ``fval`` field contains the sum of
per-batch-element negative MLL values.
"""
if timeout_sec is not None:
warn(
"timeout_sec is not supported for batched independent fitting "
"and will be ignored.",
OptimizationWarning,
stacklevel=2,
)
# Avoid circular imports: closures and batched_lbfgs_b import from optim.
from botorch.optim.batched_lbfgs_b import fmin_l_bfgs_b_batched
from botorch.optim.closures import (
BatchedNDarrayOptimizationClosure,
get_loss_closure,
)
from botorch.optim.utils.numpy_utils import get_per_element_bounds
# Resolve parameters and bounds
_parameters, _bounds = get_parameters_and_bounds(mll)
bounds = _bounds if bounds is None else {**_bounds, **bounds}
if parameters is None:
parameters = {n: p for n, p in _parameters.items() if p.requires_grad}
batch_shape = mll.model._aug_batch_shape
# Build forward closure (returns per-batch neg MLL, NOT summed)
forward = get_loss_closure(mll)
# Build batched closure
batched_closure = BatchedNDarrayOptimizationClosure(
forward=forward,
parameters=parameters,
batch_shape=batch_shape,
)
# Extract per-element bounds
bounds_np = get_per_element_bounds(parameters, bounds, batch_shape)
# Get initial state
x0 = batched_closure.state # (batch_size, per_element_size)
# Resolve options for fmin_l_bfgs_b_batched
_recognized_options = {
"gtol",
"maxiter",
"maxcor",
"ftol",
"pgtol",
"maxls",
"factr",
}
lbfgsb_options: dict[str, Any] = {}
if options is not None:
# Map scipy-style option names to fmin_l_bfgs_b_batched kwargs
for key, value in options.items():
if key == "gtol":
lbfgsb_options["pgtol"] = value
elif key in ("maxiter", "maxcor", "ftol", "pgtol", "maxls", "factr"):
lbfgsb_options[key] = value
unrecognized = set(options.keys()) - _recognized_options
if unrecognized:
warn(
f"Unrecognized options for batched independent fitting "
f"will be ignored: {sorted(unrecognized)}.",
OptimizationWarning,
stacklevel=2,
)
# Run batched L-BFGS-B
xs, fs, results = fmin_l_bfgs_b_batched(
func=batched_closure,
x0=x0,
bounds=bounds_np,
pass_batch_indices=True,
callback=callback,
**lbfgsb_options,
)
# Write optimal state back to model parameters
batched_closure.state = xs
# Determine overall status from individual results
all_success = all(r.get("success", False) for r in results)
max_nit = max(r.get("nit", 0) for r in results)
if all_success:
status = OptimizationStatus.SUCCESS
else:
# Check if any hit maxiter
any_maxiter = any(r.get("warnflag", 0) == 1 for r in results)
status = (
OptimizationStatus.STOPPED if any_maxiter else OptimizationStatus.FAILURE
)
return OptimizationResult(
fval=float(fs.sum()),
step=max_nit,
status=status,
message=(
f"Batched L-BFGS-B: {sum(r.get('success', False) for r in results)}"
f"/{len(results)} outputs converged."
),
)