#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Helpers for multitask modeling.
"""
from __future__ import annotations
import torch
from gpytorch.distributions import MultitaskMultivariateNormal
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from linear_operator import to_linear_operator
def separate_mtmvn(mvn: MultitaskMultivariateNormal) -> list[MultivariateNormal]:
    r"""Split a multitask MVN into one independent MVN per task.

    Within each task, the covariance across data points is preserved; all
    covariance between different tasks is dropped.

    Args:
        mvn: A ``MultitaskMultivariateNormal`` whose mean has trailing shape
            ``n x t`` (``n`` data points, ``t`` tasks). Both interleaved and
            non-interleaved covariance layouts are supported.

    Returns:
        A list of ``t`` ``MultivariateNormal`` distributions. Entry ``c`` has mean
        ``mvn.mean[..., c]`` and, as covariance, the ``n x n`` block of
        ``mvn.lazy_covariance_matrix`` that belongs to task ``c``.
    """
    # T150340766 Upstream as a class method on gpytorch MultitaskMultivariateNormal.
    full_covar = mvn.lazy_covariance_matrix
    num_data, num_tasks = mvn.mean.shape[-2:]
    # The layout does not change per task, so read it (and the device) once.
    interleaved = mvn._interleaved
    device = full_covar.device
    mvns = []
    for c in range(num_tasks):
        if interleaved:
            # Interleaved layout: the row for (data i, task c) is i * num_tasks + c,
            # so task c occupies the strided positions c, c + num_tasks, ...
            # Strided access requires tensor indexing. Rows and columns are indexed
            # in two separate steps rather than with a single paired tensor index.
            task_indices = torch.arange(c, num_data * num_tasks, num_tasks, device=device)
            task_covar = full_covar[..., task_indices, :]
            task_covar = task_covar[..., :, task_indices]
        else:
            # Non-interleaved layout: task c occupies the contiguous block
            # [c * num_data, (c + 1) * num_data). Slice-based indexing is handled
            # more efficiently by LinearOperator than tensor indexing.
            start = c * num_data
            end = start + num_data
            task_covar = full_covar[..., start:end, start:end]
        mvns.append(
            MultivariateNormal(mvn.mean[..., c], to_linear_operator(task_covar))
        )
    return mvns