Source code for botorch.optim.closures.core

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Core methods for building closures in torch and interfacing with numpy."""

from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Any

import numpy as np
import numpy.typing as npt
import torch
from botorch.optim.utils import _handle_numerical_errors
from botorch.optim.utils.numpy_utils import as_ndarray
from botorch.utils.context_managers import zero_grad_ctx
from numpy import float64 as np_float64, zeros as np_zeros
from torch import Tensor

FILL_VALUE = 0.0


class ForwardBackwardClosure:
    r"""Callable that runs a forward pass and its backward pass in one step."""

    def __init__(
        self, forward: Callable[[], Tensor], parameters: dict[str, Tensor]
    ) -> None:
        r"""Stores the forward callable and the tensors to differentiate.

        Args:
            forward: Callable that returns a tensor.
            parameters: A dictionary of tensors whose ``grad`` fields are
                to be returned.
        """
        self.parameters = parameters
        self.forward = forward

    def __call__(self, **kwargs: Any) -> tuple[Tensor, tuple[Tensor | None, ...]]:
        r"""Evaluates ``forward``, backpropagates, and collects the gradients.

        Args:
            **kwargs: Keyword arguments passed to ``self.forward``.

        Returns:
            A tuple ``(value, grads)``: the summed output of ``forward`` and the
            ``grad`` field of each parameter, in the dictionary's iteration order.
        """
        with zero_grad_ctx(parameters=self.parameters):
            output = self.forward(**kwargs)
            # Reduce to a scalar so that ``backward`` needs no explicit gradient.
            total = output.sum()
            total.backward()
            gradients = tuple(tensor.grad for tensor in self.parameters.values())

        return total, gradients
class NdarrayOptimizationClosure:
    r"""Adds stateful behavior and a numpy.ndarray-typed API to a closure with an
    expected return type Tuple[Tensor, Union[Tensor, Sequence[Optional[Tensor]]]].

    The parameters are exposed as a single flat float64 vector, concatenated in
    the iteration order of ``parameters``. Entries whose gradient is ``None``
    are reported as ``FILL_VALUE`` (0.0) in the returned gradient.

    NOTE(review): earlier documentation stated that NaN values are replaced with
    0.0 in the returned ndarray. No NaN substitution is visible in this class;
    confirm against ``as_ndarray`` before relying on it.
    """

    def __init__(
        self,
        closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]],
        parameters: dict[str, Tensor],
    ) -> None:
        r"""Initializes a NdarrayOptimizationClosure instance.

        Args:
            closure: A ForwardBackwardClosure instance.
            parameters: A dictionary of tensors representing the closure's state.
                Expected to correspond with the first ``len(parameters)`` optional
                gradient tensors returned by ``closure``.
        """
        self.closure = closure
        self.parameters = parameters
        # Gradient buffer, allocated lazily and reused across calls.
        self._gradient_ndarray: npt.NDArray | None = None

    def __call__(
        self, state: npt.NDArray | None = None, **kwargs: Any
    ) -> tuple[npt.NDArray, npt.NDArray]:
        r"""Evaluates the closure and returns its value and flattened gradient.

        Args:
            state: Optional flat ndarray. When given, it is written into the
                parameters before the closure is evaluated.
            **kwargs: Keyword arguments passed to ``self.closure``.

        Returns:
            A tuple ``(value, grads)``. On success ``grads`` is the instance's
            reusable gradient buffer, so it is overwritten by the next call.
            If a ``RuntimeError`` is raised during evaluation, both entries are
            whatever ``_handle_numerical_errors`` returns.
        """
        if state is not None:
            self.state = state

        try:
            value_tensor, grad_tensors = self.closure(**kwargs)
            value = as_ndarray(values=value_tensor, dtype=np_float64)
            grads = self._get_gradient_ndarray()
            index = 0
            for param, grad in zip(self.parameters.values(), grad_tensors):
                size = param.numel()
                if grad is not None:
                    # ``reshape`` instead of ``view``: a non-contiguous gradient
                    # would make ``view`` raise a RuntimeError that the handler
                    # below would then treat as a numerical failure.
                    grads[index : index + size] = as_ndarray(
                        values=grad.reshape(-1), dtype=np_float64
                    )
                # Advance even when grad is None so later slices stay aligned.
                index += size
        except RuntimeError as e:
            value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)

        return value, grads

    @property
    def state(self) -> npt.NDArray:
        r"""Returns the current parameter values as a flat float64 ndarray.

        Raises:
            RuntimeError: If there are no parameters.
        """
        if not self.parameters:
            raise RuntimeError("No parameters to get state from.")

        size = sum(tnsr.numel() for tnsr in self.parameters.values())
        out = np.empty([size], dtype=np_float64)
        index = 0
        for tnsr in self.parameters.values():
            size = tnsr.numel()
            # ``reshape`` also accepts non-contiguous parameters.
            out[index : index + size] = as_ndarray(tnsr.reshape(-1))
            index += size
        return out

    @state.setter
    def state(self, state: npt.NDArray) -> None:
        r"""Copies values from a flat ndarray into the parameters, in place."""
        with torch.no_grad():
            index = 0
            for tnsr in self.parameters.values():
                size = tnsr.numel()
                # Zero-dimensional tensors take a single scalar, not a slice.
                vals = state[index : index + size] if tnsr.ndim else state[index]
                tnsr.copy_(
                    torch.as_tensor(vals, device=tnsr.device, dtype=tnsr.dtype).reshape(
                        tnsr.shape
                    )
                )
                index += size

    def _get_gradient_ndarray(self) -> npt.NDArray:
        r"""Returns the gradient buffer reset to ``FILL_VALUE``.

        The buffer is allocated on first use and reused afterwards.
        """
        if self._gradient_ndarray is not None:
            self._gradient_ndarray.fill(FILL_VALUE)
            return self._gradient_ndarray

        size = sum(param.numel() for param in self.parameters.values())
        self._gradient_ndarray = np_zeros(size, dtype=np_float64)
        return self._gradient_ndarray
class BatchedNDarrayOptimizationClosure:
    r"""Wraps a forward closure and batched parameters for use with
    ``fmin_l_bfgs_b_batched``.

    ``NdarrayOptimizationClosure`` flattens every parameter into one 1D vector.
    This class instead keeps a 2D array of shape
    ``(batch_size, per_element_size)``, one row per batch element, so that each
    batch element (e.g., each output of a ``BatchedMultiOutputGPyTorchModel``)
    can be optimized independently with batched L-BFGS-B.
    """

    def __init__(
        self,
        forward: Callable[[], Tensor],
        parameters: dict[str, Tensor],
        batch_shape: torch.Size,
    ) -> None:
        r"""Initializes a BatchedNDarrayOptimizationClosure instance.

        Args:
            forward: Callable that returns a tensor of shape ``batch_shape``
                (per-batch-element loss values, e.g., negated per-output MLL).
            parameters: A dictionary of parameter tensors, each with shape
                ``(*batch_shape, *trailing_shape)``.
            batch_shape: The batch shape shared by all parameters (typically
                ``model._aug_batch_shape``).
        """
        self.forward = forward
        self.parameters = parameters
        self.batch_shape = batch_shape
        self.batch_size = max(int(torch.Size(batch_shape).numel()), 1)

        # Number of columns each parameter occupies in a row of the 2D state.
        num_leading = len(batch_shape)
        self._trailing_sizes: dict[str, int] = {
            name: max(int(torch.Size(param.shape[num_leading:]).numel()), 1)
            for name, param in parameters.items()
        }
        self._per_element_size = sum(self._trailing_sizes.values())

    @property
    def state(self) -> npt.NDArray:
        """Current parameter values as a 2D ndarray of shape
        ``(batch_size, per_element_size)``."""
        matrix = np.empty((self.batch_size, self._per_element_size), dtype=np_float64)
        start = 0
        for name, param in self.parameters.items():
            stop = start + self._trailing_sizes[name]
            rows = param.detach().reshape(self.batch_size, stop - start)
            matrix[:, start:stop] = as_ndarray(rows, dtype=np_float64)
            start = stop
        return matrix

    @state.setter
    def state(self, state: npt.NDArray) -> None:
        """Writes a 2D ndarray of shape ``(batch_size, per_element_size)`` into
        the parameters, in place."""
        with torch.no_grad():
            start = 0
            for name, param in self.parameters.items():
                stop = start + self._trailing_sizes[name]
                columns = torch.as_tensor(
                    state[:, start:stop], device=param.device, dtype=param.dtype
                )
                param.copy_(columns.reshape(param.shape))
                start = stop

    def __call__(
        self,
        state: npt.NDArray | None = None,
        batch_indices: npt.NDArray | None = None,
        **kwargs: Any,
    ) -> tuple[npt.NDArray, npt.NDArray]:
        """Evaluate the closure and return per-batch values and gradients.

        Args:
            state: Optional 2D ndarray to set as the current state before
                evaluation. Shape ``(active_batch_size, per_element_size)`` if
                ``batch_indices`` is provided, else
                ``(batch_size, per_element_size)``.
            batch_indices: Optional 1D ndarray of indices into the original
                batch, indicating which elements are being evaluated. Used with
                ``fmin_l_bfgs_b_batched(pass_batch_indices=True)``.
            **kwargs: Keyword arguments passed to ``self.forward``.

        Returns:
            A tuple ``(values, grads)`` where ``values`` has shape
            ``(active_batch_size,)`` and ``grads`` has shape
            ``(active_batch_size, per_element_size)``.
        """
        if state is not None:
            if batch_indices is None:
                self.state = state
            else:
                # Overwrite only the active rows; the others keep their values.
                merged = self.state
                merged[batch_indices] = state
                self.state = merged

        grad_shape = (self.batch_size, self._per_element_size)
        try:
            with zero_grad_ctx(parameters=self.parameters):
                per_batch_values = self.forward(**kwargs)
                # Backpropagate through the sum of the per-batch values.
                per_batch_values.sum().backward()

            values = as_ndarray(
                per_batch_values.detach().reshape(self.batch_size),
                dtype=np_float64,
            )
            grads = np.zeros(grad_shape, dtype=np_float64)
            start = 0
            for name, param in self.parameters.items():
                stop = start + self._trailing_sizes[name]
                if param.grad is not None:
                    grads[:, start:stop] = as_ndarray(
                        param.grad.reshape(self.batch_size, stop - start),
                        dtype=np_float64,
                    )
                start = stop
        except RuntimeError as e:
            # Only the value from the handler is used; the gradients are zeros.
            value, _unused_grad = _handle_numerical_errors(
                e, x=self.state.ravel(), dtype=np_float64
            )
            values = np.full(self.batch_size, value / self.batch_size, dtype=np_float64)
            grads = np.zeros(grad_shape, dtype=np_float64)

        if batch_indices is None:
            return values, grads
        return values[batch_indices], grads[batch_indices]