# Source code for botorch.optim.closures.model_closures

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Utilities for building model-based closures."""

from __future__ import annotations

from collections.abc import Callable, Sequence
from functools import partial
from itertools import chain, repeat
from typing import Any

from botorch.optim.closures.core import ForwardBackwardClosure
from gpytorch.mlls import (
    ExactMarginalLogLikelihood,
    MarginalLogLikelihood,
    SumMarginalLogLikelihood,
)
from torch import Tensor
from torch.utils.data import DataLoader


def get_loss_closure(
    mll: MarginalLogLikelihood,
    data_loader: DataLoader | None = None,
) -> Callable[[], Tensor]:
    r"""Build a zero-argument loss closure that evaluates the negated ``mll``.

    This factory picks how ``mll`` is evaluated, in the following order of
    precedence:

    1. ``mll.compute_custom_loss``, returned unchanged, if ``mll`` defines it.
    2. ``mll.model.compute_custom_loss`` with ``mll=mll`` bound via
       ``functools.partial``, if the model defines it.
    3. A fallback closure that draws batches from ``data_loader``, if one is
       given.
    4. A dedicated closure for ``ExactMarginalLogLikelihood`` or
       ``SumMarginalLogLikelihood`` that uses the model's own training data.
    5. A generic fallback closure that uses the model's own training data.

    Args:
        mll: A MarginalLogLikelihood instance whose negative defines the loss.
        data_loader: An optional DataLoader instance for cases where training
            data is passed in rather than obtained from ``mll.model``.

    Returns:
        A closure that takes zero positional arguments and returns the negated
        value of ``mll``.
    """
    # User-supplied evaluation routines always win, even over ``data_loader``.
    if hasattr(mll, "compute_custom_loss"):
        return mll.compute_custom_loss

    model = mll.model
    if hasattr(model, "compute_custom_loss"):
        return partial(model.compute_custom_loss, mll=mll)

    if data_loader is not None:
        return _get_loss_closure_fallback_external(mll=mll, data_loader=data_loader)

    # Likelihood types with dedicated handling of internally managed data;
    # checked in order, first match wins.
    specialized_factories = (
        (ExactMarginalLogLikelihood, _get_loss_closure_exact_internal),
        (SumMarginalLogLikelihood, _get_loss_closure_sum_internal),
    )
    for mll_type, factory in specialized_factories:
        if isinstance(mll, mll_type):
            return factory(mll=mll)

    return _get_loss_closure_fallback_internal(mll=mll)
def get_loss_closure_with_grads(
    mll: MarginalLogLikelihood,
    parameters: dict[str, Tensor],
    data_loader: DataLoader | None = None,
) -> ForwardBackwardClosure:
    r"""Wrap the loss closure for ``mll`` so that calling it also runs backward.

    The forward pass is whatever ``get_loss_closure`` selects for ``mll`` and
    ``data_loader``; the result is wrapped in a ``ForwardBackwardClosure`` that
    reports gradients for ``parameters``. See ``get_loss_closure`` for how the
    forward closure is chosen.

    Args:
        mll: A MarginalLogLikelihood instance whose negative defines the loss.
        parameters: A dictionary of tensors whose ``grad`` fields are to be
            returned.
        data_loader: An optional DataLoader instance for cases where training
            data is passed in rather than obtained from ``mll.model``.

    Returns:
        A closure that takes zero positional arguments and returns the reduced
        and negated value of ``mll`` along with the gradients of
        ``parameters``.
    """
    return ForwardBackwardClosure(
        forward=get_loss_closure(mll=mll, data_loader=data_loader),
        parameters=parameters,
    )
def _get_loss_closure_fallback_external(
    mll: MarginalLogLikelihood,
    data_loader: DataLoader,
) -> Callable[[], Tensor]:
    r"""Fallback loss closure with externally provided data.

    Every call of the returned closure consumes the next batch from
    ``data_loader``; once a pass over the loader is exhausted, a new pass is
    started. The first ``len(mll.model.train_inputs)`` entries of a batch are
    passed to ``mll.model`` and the remaining entries are passed to ``mll``.

    Args:
        mll: A MarginalLogLikelihood instance whose negative defines the loss.
        data_loader: Source of batches; each batch must be a ``Sequence``.

    Returns:
        A closure that forwards its keyword arguments to ``mll`` and returns the
        negated value of ``mll`` on the next batch. Calling it raises
        ``TypeError`` if a batch is not a ``Sequence``. It raises ``ValueError``
        if a pass over ``data_loader`` yields no batches at all.
    """

    def _single_pass():
        # One pass over the loader. A fresh ``iter(data_loader)`` is created per
        # pass instead of caching batches (as ``itertools.cycle`` would), so
        # per-pass loader behavior such as reshuffling is kept. An empty pass
        # raises; otherwise ``next(batch_generator)`` would keep chaining empty
        # passes forever and never return.
        is_empty = True
        for batch in data_loader:
            is_empty = False
            yield batch
        if is_empty:
            raise ValueError(
                "Expected `data_loader` to generate at least one batch, "
                "but it did not generate any."
            )

    batch_generator = chain.from_iterable(_single_pass() for _ in repeat(None))

    def closure(**kwargs: Any) -> Tensor:
        batch = next(batch_generator)
        if not isinstance(batch, Sequence):
            raise TypeError(
                "Expected `data_loader` to generate a batch of tensors, "
                f"but found {type(batch)}."
            )

        # Leading entries are model inputs; the rest (e.g. targets) go to ``mll``.
        num_inputs = len(mll.model.train_inputs)
        model_output = mll.model(*batch[:num_inputs])
        log_likelihood = mll(model_output, *batch[num_inputs:], **kwargs)
        return -log_likelihood

    return closure


def _get_loss_closure_fallback_internal(
    mll: MarginalLogLikelihood,
) -> Callable[[], Tensor]:
    r"""Fallback loss closure with internally managed data.

    The returned closure evaluates ``mll.model`` on ``mll.model.train_inputs``.
    It scores the output against ``mll.model.train_targets``, forwards any
    keyword arguments to ``mll``, and returns the negated result.
    """

    def closure(**kwargs: Any) -> Tensor:
        model_output = mll.model(*mll.model.train_inputs)
        log_likelihood = mll(model_output, mll.model.train_targets, **kwargs)
        return -log_likelihood

    return closure


def _get_loss_closure_exact_internal(
    mll: ExactMarginalLogLikelihood,
) -> Callable[[], Tensor]:
    r"""ExactMarginalLogLikelihood loss closure with internally managed data.

    In addition to the model output and training targets, the closure passes
    each training input, run through ``model.transform_inputs``, as an extra
    positional argument to ``mll``. Keyword arguments are forwarded to ``mll``.
    """

    def closure(**kwargs: Any) -> Tensor:
        model = mll.model
        # The inputs will get transformed in forward here.
        model_output = model(*model.train_inputs)
        log_likelihood = mll(
            model_output,
            model.train_targets,
            # During model training, the model inputs get transformed in the forward
            # pass. The train_inputs property is not transformed yet, so we need to
            # transform it before passing it to the likelihood for consistency.
            *(model.transform_inputs(X=t_in) for t_in in model.train_inputs),
            **kwargs,
        )
        return -log_likelihood

    return closure


def _get_loss_closure_sum_internal(
    mll: SumMarginalLogLikelihood,
) -> Callable[[], Tensor]:
    r"""SumMarginalLogLikelihood loss closure with internally managed data.

    Each element of ``model.train_inputs`` is itself iterated (presumably one
    group of inputs per sub-model; confirm against the model type). For each
    element, ``mll`` receives one lazily evaluated generator of transformed
    inputs as an extra positional argument. Keyword arguments are forwarded to
    ``mll``.
    """

    def closure(**kwargs: Any) -> Tensor:
        model = mll.model
        # The inputs will get transformed in forward here.
        model_output = model(*model.train_inputs)
        log_likelihood = mll(
            model_output,
            model.train_targets,
            # During model training, the model inputs get transformed in the forward
            # pass. The train_inputs property is not transformed yet, so we need to
            # transform it before passing it to the likelihood for consistency.
            # NOTE(review): this uses the top-level ``model.transform_inputs`` for
            # every group; confirm it applies each sub-model's own transform.
            *(
                (model.transform_inputs(X=t_in) for t_in in sub_t_in)
                for sub_t_in in model.train_inputs
            ),
            **kwargs,
        )
        return -log_likelihood

    return closure