Source code for botorch.posteriors.ensemble

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Ensemble posteriors. Used in conjunction with ensemble models.
"""

from __future__ import annotations

import torch
from botorch.posteriors.posterior import Posterior
from torch import Tensor
from torch.distributions.multinomial import Multinomial


class EnsemblePosterior(Posterior):
    r"""Ensemble posterior, that should be used for ensemble models that compute
    eagerly a finite number of samples per X value as for example a deep ensemble
    or a random forest.

    The posterior is represented by the member predictions (``values``) and an
    optional set of member ``weights``. Two families of moments are exposed:
    ``mean`` / ``variance`` reduce over the ensemble dimension only, while
    ``mixture_mean`` / ``mixture_variance`` reduce over all dimensions except
    the trailing ``q x m`` ones.
    """

    def __init__(self, values: Tensor, weights: Tensor | None = None) -> None:
        r"""
        Args:
            values: Values of the samples produced by this posterior as
                a ``(b) x s x q x m`` tensor where ``m`` is the output size of the
                model and ``s`` is the ensemble size.
            weights: Optional weights for the ensemble members as a tensor of
                shape ``(s,)``. If None, uses uniform weights.

        Raises:
            ValueError: If ``values`` has fewer than three dimensions, or if the
                trailing dimension of ``weights`` is not the ensemble size.
        """
        if values.ndim < 3:
            raise ValueError("Values has to be at least three-dimensional.")
        # A scalar or size-1 weight tensor would silently broadcast against the
        # ensemble dimension and turn the "weighted mean" into a plain sum, so
        # reject anything whose last dimension is not the ensemble size.
        if weights is not None and (
            weights.ndim == 0 or weights.shape[-1] != values.shape[-3]
        ):
            raise ValueError(
                "Expected the last dimension of weights to match the ensemble "
                f"size {values.shape[-3]}, but got weights of shape "
                f"{tuple(weights.shape)}."
            )
        self.values = values
        # `.to(values)` aligns both dtype and device with the values tensor.
        self._weights = weights.to(values) if weights is not None else None

        # Pre-compute normalized weights and mixture properties for efficiency
        self._mixture_dims = list(range(self.values.ndim - 2))
        self._normalized_weights = self._compute_normalized_weights()
        self._normalized_mixture_weights = self._compute_normalized_mixture_weights()

    @property
    def ensemble_size(self) -> int:
        r"""The size of the ensemble"""
        return self.values.shape[-3]

    @property
    def mixture_size(self) -> int:
        r"""The total number of elements in the mixture dimensions"""
        return self.values.shape[:-2].numel()

    def _compute_normalized_weights(self) -> Tensor:
        r"""Compute the weights normalized to sum to one over the last dimension.

        Returns:
            The user-provided weights divided by their sum along ``dim=-1``, or
            a uniform ``(s,)`` tensor if no weights were provided.
        """
        if self._weights is not None:
            return self._weights / self._weights.sum(dim=-1, keepdim=True)
        return (
            torch.ones(
                self.ensemble_size,
                dtype=self.dtype,
                device=self.device,
            )
            / self.ensemble_size
        )

    def _compute_normalized_mixture_weights(self) -> Tensor:
        r"""Compute the weights normalized to sum to one over all mixture dims.

        Returns:
            A tensor of shape ``values.shape[:-2]`` whose entries sum to one.
        """
        if self._weights is not None:
            unnorm_weights = self._weights.expand(self.values.shape[:-2])
            return unnorm_weights / unnorm_weights.sum(
                dim=self._mixture_dims, keepdim=True
            )
        return (
            torch.ones(
                self.values.shape[:-2],
                dtype=self.dtype,
                device=self.device,
            )
            / self.mixture_size
        )

    @property
    def weights(self) -> Tensor:
        r"""The weights of the individual models in the ensemble.
        uniformly weighted by default."""
        return self._normalized_weights

    @property
    def mixture_weights(self) -> Tensor:
        r"""The weights of the individual models in the ensemble.
        uniformly weighted by default, and normalized over ensemble and batch
        dimensions of the model."""
        return self._normalized_mixture_weights

    @property
    def mixture_dims(self) -> list[int]:
        r"""The mixture dimensions of the posterior. For ensemble posteriors,
        this includes all dimensions except the last two (query points and
        outputs)."""
        return self._mixture_dims

    @property
    def device(self) -> torch.device:
        r"""The torch device of the posterior."""
        return self.values.device

    @property
    def dtype(self) -> torch.dtype:
        r"""The torch dtype of the posterior."""
        return self.values.dtype

    @property
    def mean(self) -> Tensor:
        r"""The mean of the posterior as a ``(b) x q x m``-dim Tensor."""
        # Weighted average across ensemble dimension
        return (self.values * self.weights[..., None, None]).sum(dim=-3)

    @property
    def variance(self) -> Tensor:
        r"""The variance of the posterior as a ``(b) x q x m``-dim Tensor.

        Computed as the weighted sample variance across the ensemble outputs.
        This treats weights as probability weights (normalized to sum to 1) and
        computes the unbiased weighted sample variance using the formula:
        Var = Σ(w_i * (x_i - μ)²) / (1 - Σw_i²)
        where the sum over w_i² is taken over the ensemble dimension only.
        Source: https://en.wikipedia.org/wiki/Weighted_arithmetic_mean under
        "Reliability Weights".

        A single-member ensemble has zero variance by convention, since the
        formula above would evaluate to 0 / 0.
        """
        if self.ensemble_size == 1:
            return torch.zeros_like(self.values.squeeze(-3))

        # Add dimensions for query points and outputs to enable broadcasting
        weights = self.weights[..., None, None]
        squared_deviations = (self.values - self.mean.unsqueeze(-3)) ** 2
        # Reduce w_i² over the ensemble dimension only (as documented), so the
        # correction factor stays right even if weights carry batch dimensions.
        correction = 1 - (weights**2).sum(dim=-3)
        return (weights * squared_deviations).sum(dim=-3) / correction

    @property
    def mixture_mean(self) -> Tensor:
        r"""The mixture mean of the posterior as a ``q x m``-dim Tensor.

        Computed as the weighted average of the ensemble outputs over all
        mixture dimensions (batch and ensemble).
        """
        return (self.values * self.mixture_weights[..., None, None]).sum(
            dim=self.mixture_dims
        )

    @property
    def mixture_variance(self) -> Tensor:
        r"""The mixture variance of the posterior as a ``q x m``-dim Tensor.

        Computed as the weighted sample variance across the ensemble outputs.
        This treats weights as probability weights (normalized to sum to 1) and
        computes the unbiased weighted sample variance using the formula:
        Var = Σ(w_i * (x_i - μ)²) / (1 - Σw_i²)
        where w_i is normalized over the entire mixture, and the sum over w_i²
        is taken over all mixture dimensions.
        Source: https://en.wikipedia.org/wiki/Weighted_arithmetic_mean under
        "Reliability Weights".

        A single-element mixture has zero variance by convention (mirroring
        ``variance``), since the formula above would evaluate to 0 / 0.
        """
        if self.mixture_size == 1:
            return torch.zeros(
                self.values.shape[-2:], dtype=self.dtype, device=self.device
            )

        # Add dimensions for query points and outputs to enable broadcasting
        weights = self.mixture_weights[..., None, None]
        squared_deviations = (self.values - self.mixture_mean.unsqueeze(-3)) ** 2
        return (weights * squared_deviations).sum(dim=self.mixture_dims) / (
            1 - (weights**2).sum()
        )

    def _extended_shape(
        self,
        sample_shape: torch.Size = torch.Size(),  # noqa: B008
    ) -> torch.Size:
        r"""Returns the shape of the samples produced by the posterior with
        the given ``sample_shape``.
        """
        return sample_shape + self.values.shape[:-3] + self.values.shape[-2:]

    @property
    def batch_shape(self) -> torch.Size:
        r"""The batch shape of the posterior, i.e. all dimensions of ``values``
        preceding the ensemble dimension."""
        return self.values.shape[:-3]

    def rsample(
        self,
        sample_shape: torch.Size | None = None,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients).

        Based on the sample shape, base samples are generated and passed to
        ``rsample_from_base_samples``.

        Args:
            sample_shape: A ``torch.Size`` object specifying the sample shape. To
                draw ``n`` samples, set to ``torch.Size([n])``. To draw ``b`` batches
                of ``n`` samples each, set to ``torch.Size([b, n])``. If None or
                empty, a single sample is drawn.

        Returns:
            Samples from the posterior, a tensor of shape
            ``self._extended_shape(sample_shape=sample_shape)``.
        """
        if sample_shape is None or len(sample_shape) == 0:
            sample_shape = torch.Size([1])
        # NOTE This occasionally happens in Hypervolume evals when there
        # are no points which improve over the reference point. In this case, we
        # create a posterior for all the points which improve over the reference,
        # which is an empty set.
        if self.values.numel() == 0:
            return torch.empty(
                *self._extended_shape(sample_shape=sample_shape),
                device=self.device,
                dtype=self.dtype,
            )
        # Draw one-hot member indicators over the ensemble (last) dimension and
        # convert them into integer member indices with argmax.
        base_samples = (
            Multinomial(
                probs=self.mixture_weights,
            )
            .sample(sample_shape=sample_shape)
            .argmax(dim=-1)
        )
        return self.rsample_from_base_samples(
            sample_shape=sample_shape, base_samples=base_samples
        )

    def rsample_from_base_samples(
        self, sample_shape: torch.Size, base_samples: Tensor
    ) -> Tensor:
        r"""Sample from the posterior (with gradients) using base samples.

        This is intended to be used with a sampler that produces the corresponding base
        samples, and enables acquisition optimization via Sample Average Approximation.

        Args:
            sample_shape: A ``torch.Size`` object specifying the sample shape. To
                draw ``n`` samples, set to ``torch.Size([n])``. To draw ``b`` batches
                of ``n`` samples each, set to ``torch.Size([b, n])``.
            base_samples: A Tensor of indices as base samples of shape
                ``sample_shape x batch_shape``, typically obtained from
                ``IndexSampler``. This is used for deterministic optimization.
                The predictions of the ensemble corresponding to the indices are
                then sampled.

        Returns:
            Samples from the posterior, a tensor of shape
            ``self._extended_shape(sample_shape=sample_shape)``.

        Raises:
            ValueError: If ``base_samples`` does not have shape
                ``sample_shape + self.batch_shape``.
        """
        # Check that base_samples has exactly the shape sample_shape x batch_shape
        expected_shape = sample_shape + self.batch_shape
        if base_samples.shape != expected_shape:
            raise ValueError(
                f"Sample_shape={expected_shape} does not match "
                f"the leading dimensions of base_samples.shape={base_samples.shape}."
            )

        if not self.batch_shape:
            # Without batch dimensions the ensemble dim is the leading one, so
            # plain integer indexing selects the sampled members.
            return self.values[base_samples]

        # Values is always going to be 4-dimensional with this reshape,
        # even if we have more than one batch dimension
        batch_numel = self.batch_shape.numel()
        values = self.values.reshape((batch_numel,) + self.values.shape[-3:])
        # Collapse the base samples to enable index selecting along the
        # ensemble dim (dim -3)
        collapsed_base_samples = base_samples.reshape(sample_shape + (batch_numel,))
        # The first index enumerates the flattened batch (0, 1, ..., numel - 1)
        # and the second picks the sampled ensemble member for each batch
        # element. The batch index is created on the values' device to avoid
        # an implicit host-to-device copy.
        batch_index = torch.arange(batch_numel, device=self.device)
        return values[batch_index, collapsed_base_samples].reshape(
            self._extended_shape(sample_shape=sample_shape)
        )