Source code for botorch.models.utils.priors

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Callable, Optional

import torch
from gpytorch.priors import Prior
from gpytorch.priors.utils import BUFFERED_PREFIX
from torch import Tensor
from torch.distributions import Beta
from torch.nn import Module as TModule


class BetaPrior(Prior, Beta):
    """Beta Prior parameterized by concentration1 (alpha) and concentration0 (beta).

    pdf(x) = x^(alpha - 1) * (1 - x)^(beta - 1) / B(alpha, beta)

    where alpha > 0 and beta > 0 are the concentration parameters.
    Supported on [0, 1], useful as a prior on correlation parameters.
    """

    # Beta.concentration1/concentration0 are @property descriptors (they
    # delegate to an internal Dirichlet), so _bufferize_attributes cannot
    # delattr them. We store separate buffers with the BUFFERED_PREFIX and
    # sync them back into the Dirichlet whenever the buffers may have changed
    # (after `_apply` and after `_load_from_state_dict`).
    _PARAM_NAMES = ("concentration1", "concentration0")

    def __init__(
        self,
        concentration1: float,
        concentration0: float,
        validate_args: bool = False,
        transform: Optional[Callable[[Tensor], Tensor]] = None,
    ) -> None:
        """Initialize BetaPrior.

        Args:
            concentration1: Alpha (first concentration) parameter. Tensors are
                passed here as well (see `expand`).
            concentration0: Beta (second concentration) parameter. Tensors are
                passed here as well (see `expand`).
            validate_args: Whether to validate input arguments.
            transform: Optional transform to apply before computing log_prob.
        """
        # Initialize the Module machinery first so that `register_buffer`
        # below has somewhere to store the buffers.
        TModule.__init__(self)
        Beta.__init__(
            self,
            concentration1=concentration1,
            concentration0=concentration0,
            validate_args=validate_args,
        )
        # Mirror each concentration into a buffer; clone so the buffer does
        # not alias the tensor returned by the property.
        for attr in self._PARAM_NAMES:
            self.register_buffer(
                f"{BUFFERED_PREFIX}{attr}", getattr(self, attr).clone()
            )
        self._transform = transform

    def _update_dirichlet_concentration(self) -> None:
        """Sync buffered values back into the underlying Dirichlet distribution.

        The buffers are the source of truth after a conversion or a state-dict
        load; the Dirichlet's concentration is rebuilt from them with
        concentration1 first and concentration0 second along the last dim.
        """
        c1 = getattr(self, f"{BUFFERED_PREFIX}concentration1")
        c0 = getattr(self, f"{BUFFERED_PREFIX}concentration0")
        self._dirichlet.concentration = torch.stack([c1, c0], dim=-1)

    def _apply(self, fn: Callable, *args, **kwargs):
        """Apply `fn` to the module's tensors, then re-sync the Dirichlet.

        Args:
            fn: The function handed to `torch.nn.Module._apply`.
            *args: Extra positional arguments forwarded to the parent `_apply`.
            **kwargs: Extra keyword arguments forwarded to the parent `_apply`.
                Recent PyTorch versions pass `recurse` here (for example from
                `Module.to_empty`); forwarding it unchanged keeps this override
                compatible with both older and newer signatures.

        Returns:
            Whatever the parent `_apply` returns.
        """
        to_return = super()._apply(fn, *args, **kwargs)
        # The buffers were just converted; the Dirichlet's own tensor is not
        # a buffer and must be rebuilt from them.
        self._update_dirichlet_concentration()
        return to_return

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Load buffers from `state_dict`, then re-sync the Dirichlet.

        Args:
            state_dict: The state dict being loaded.
            prefix: The key prefix for this module's entries.
            *args: Forwarded unchanged to the parent implementation.
            **kwargs: Forwarded unchanged to the parent implementation.
        """
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
        self._update_dirichlet_concentration()

    def expand(self, batch_shape: torch.Size) -> "BetaPrior":
        """Return a new BetaPrior with parameters expanded to `batch_shape`.

        Args:
            batch_shape: The desired batch shape; converted with `torch.Size`.

        Returns:
            A new `BetaPrior` that keeps this prior's `validate_args` setting
            and transform.
        """
        batch_shape = torch.Size(batch_shape)
        return BetaPrior(
            self.concentration1.expand(batch_shape),
            self.concentration0.expand(batch_shape),
            validate_args=self._validate_args,
            transform=self._transform,
        )