# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyMC-ArviZ conversion code."""
import logging
import warnings
from collections.abc import Iterable, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Optional,
Union,
cast,
)
import numpy as np
import xarray
from arviz import dict_to_dataset, rcParams
from arviz_base.base import requires
from arviz_base.types import CoordSpec, DimSpec
from pytensor.graph import ancestors
from pytensor.tensor.sharedvar import SharedVariable
from rich.progress import Console
from rich.theme import Theme
from xarray import Dataset, DataTree
import pymc
from pymc.model import Model, modelcontext
from pymc.progress_bar import CustomProgress, default_progress_theme
from pymc.pytensorf import PointFunc, extract_obs_data
from pymc.util import get_default_varnames
if TYPE_CHECKING:
from pymc.backends.base import MultiTrace
___all__ = [""]
_log = logging.getLogger(__name__)
RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False
# random variable object ...
Var = Any
def dict_to_dataset_drop_incompatible_coords(
vars_dict, *args, dims, coords, sample_dims=None, **kwargs
):
safe_coords = coords
dropped_coords = {}
if not RAISE_ON_INCOMPATIBLE_COORD_LENGTHS:
coords_lengths = {k: len(v) for k, v in coords.items()}
for var_name, var in vars_dict.items():
# Iterate in reversed because of chain/draw batch dimensions
for dim, dim_length in zip(reversed(dims.get(var_name, ())), reversed(var.shape)):
coord_length = coords_lengths.get(dim, None)
if (coord_length is not None) and (coord_length != dim_length):
warnings.warn(
f"Incompatible coordinate length of {coord_length} for dimension '{dim}' of variable '{var_name}'.\n"
"This usually happens when a sliced or concatenated variable is wrapped as a `pymc.dims.Deterministic`."
"The originate coordinates for this dim will not be included in the returned dataset for any of the variables. "
"Instead they will default to `np.arange(var_length)` and the shorter variables will be right-padded with nan.\n"
"To make this warning into an error set `pymc.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS` to `True`",
UserWarning,
)
if safe_coords is coords:
safe_coords = coords.copy()
dropped_coords[dim] = safe_coords.pop(dim)
coords_lengths.pop(dim)
ds = dict_to_dataset(
vars_dict, *args, dims=dims, coords=safe_coords, sample_dims=sample_dims, **kwargs
)
# After alignment, shorter variables are padded with NaN so the final dim
# length may match the original coord length. Re-assign those coords.
reassign = {
dim: coord_vals
for dim, coord_vals in dropped_coords.items()
if dim in ds.dims and ds.sizes[dim] == len(coord_vals)
}
if reassign:
ds = ds.assign_coords(reassign)
return ds
def find_observations(model: "Model") -> dict[str, Var]:
"""If there are observations available, return them as a dictionary."""
observations = {}
for obs in model.observed_RVs:
aux_obs = model.rvs_to_values.get(obs, None)
if aux_obs is not None:
try:
obs_data = extract_obs_data(aux_obs)
observations[obs.name] = obs_data
except TypeError:
warnings.warn(f"Could not extract data from symbolic observation {obs}")
else:
warnings.warn(f"No data for observation {obs}")
return observations
def find_constants(model: "Model") -> dict[str, Var]:
"""If there are constants available, return them as a dictionary."""
model_vars = model.basic_RVs + model.deterministics + model.potentials
value_vars = set(model.rvs_to_values.values())
constant_data = {}
for var in model.data_vars:
if var in value_vars:
# An observed value variable could also be part of the generative graph
if var not in ancestors(model_vars):
continue
if isinstance(var, SharedVariable):
var_value = var.get_value()
else:
var_value = var.data
constant_data[var.name] = var_value
return constant_data
def coords_and_dims_for_inferencedata(model: Model) -> tuple[dict[str, Any], dict[str, Any]]:
"""Parse PyMC model coords and dims format to one accepted by InferenceData."""
coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in model.coords.items()
if cvals is not None
}
dims = {dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()}
return coords, dims
class _DefaultTrace:
"""
Utility for collecting samples into a dictionary.
Name comes from its similarity to ``defaultdict``:
entries are lazily created.
Parameters
----------
samples : int
The number of samples that will be collected, per variable,
into the trace.
Attributes
----------
trace_dict : Dict[str, np.ndarray]
A dictionary constituting a trace. Should be extracted
after a procedure has filled the `_DefaultTrace` using the
`insert()` method
"""
def __init__(self, samples: int):
self._len: int = samples
self.trace_dict: dict[str, np.ndarray] = {}
def insert(self, k: str, v, idx: int):
"""
Insert `v` as the value of the `idx`th sample for the variable `k`.
Parameters
----------
k: str
Name of the variable.
v: anything that can go into a numpy array (including a numpy array)
The value of the `idx`th sample from variable `k`
ids: int
The index of the sample we are inserting into the trace.
"""
value_shape = np.shape(v)
# initialize if necessary
if k not in self.trace_dict:
array_shape = (self._len, *value_shape)
self.trace_dict[k] = np.empty(array_shape, dtype=np.array(v).dtype)
# do the actual insertion
if value_shape == ():
self.trace_dict[k][idx] = v
else:
self.trace_dict[k][idx, :] = v
class DataTreeConverter:
"""Encapsulate conventions of InferenceData schema in DataTree conversion."""
model: Model | None = None
posterior_predictive: Mapping[str, np.ndarray] | None = None
predictions: Mapping[str, np.ndarray] | None = None
prior: Mapping[str, np.ndarray] | None = None
def __init__(
self,
*,
trace=None,
prior=None,
posterior_predictive=None,
log_likelihood=False,
log_prior=False,
predictions=None,
coords: CoordSpec | None = None,
dims: DimSpec | None = None,
sample_dims: list | None = None,
model=None,
save_warmup: bool | None = None,
include_transformed: bool = False,
):
self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
self.include_transformed = include_transformed
self.trace = trace
# this permits us to get the model from command-line argument or from with model:
self.model = modelcontext(model)
self.attrs = None
if trace is not None:
self.nchains = trace.nchains if hasattr(trace, "nchains") else 1
if hasattr(trace.report, "n_draws") and trace.report.n_draws is not None:
self.ndraws = trace.report.n_draws
self.attrs = {
"sampling_time": trace.report.t_sampling,
"tuning_steps": trace.report.n_tune,
}
else:
self.ndraws = len(trace)
if self.save_warmup:
warnings.warn(
"Warmup samples will be stored in posterior group and will not be"
" excluded from stats and diagnostics."
" Do not slice the trace manually before conversion",
UserWarning,
)
self.ntune = len(self.trace) - self.ndraws
self.posterior_trace, self.warmup_trace = self.split_trace()
else:
self.nchains = self.ndraws = 0
self.prior = prior
self.posterior_predictive = posterior_predictive
self.log_likelihood = log_likelihood
self.log_prior = log_prior
self.predictions = predictions
if all(elem is None for elem in (trace, predictions, posterior_predictive, prior)):
raise ValueError(
"When constructing DataTree you must pass at least"
" one of trace, prior, posterior_predictive or predictions."
)
user_coords = {} if coords is None else coords
user_dims = {} if dims is None else dims
model_coords, model_dims = coords_and_dims_for_inferencedata(self.model)
self.coords = {**model_coords, **user_coords}
self.dims = {**model_dims, **user_dims}
if sample_dims is None:
sample_dims = ["chain", "draw"]
self.sample_dims = sample_dims
self.observations = find_observations(self.model)
def split_trace(self) -> tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
"""Split MultiTrace object into posterior and warmup.
Returns
-------
trace_posterior: MultiTrace or None
The slice of the trace corresponding to the posterior. If the posterior
trace is empty, None is returned
trace_warmup: MultiTrace or None
The slice of the trace corresponding to the warmup. If the warmup trace is
empty or ``save_warmup=False``, None is returned
"""
trace_posterior = None
trace_warmup = None
if self.save_warmup and self.ntune > 0:
trace_warmup = self.trace[: self.ntune]
if self.ndraws > 0:
trace_posterior = self.trace[self.ntune :]
return trace_posterior, trace_warmup
@requires("trace")
def posterior_to_xarray(self):
"""Convert the posterior to xarray datasets with separate warmup group."""
var_names = get_default_varnames(
self.trace.varnames, include_transformed=self.include_transformed
)
result = {}
if self.warmup_trace:
data_warmup = {}
for var_name in var_names:
data_warmup[var_name] = np.array(
self.warmup_trace.get_values(var_name, combine=False, squeeze=False)
)
result["warmup_posterior"] = dict_to_dataset(
data_warmup,
inference_library=pymc,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
)
if self.posterior_trace:
data = {}
for var_name in var_names:
data[var_name] = np.array(
self.posterior_trace.get_values(var_name, combine=False, squeeze=False)
)
result["posterior"] = dict_to_dataset(
data,
inference_library=pymc,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
)
return result if result else {}
@requires("trace")
def sample_stats_to_xarray(self):
"""Extract sample_stats from PyMC trace with separate warmup group."""
rename_key = {
"model_logp": "lp",
"mean_tree_accept": "acceptance_rate",
"depth": "tree_depth",
"tree_size": "n_steps",
}
result = {}
if self.warmup_trace:
data_warmup = {}
for stat in self.trace.stat_names:
name = rename_key.get(stat, stat)
if name in {"tune", "in_warmup"}:
continue
data_warmup[name] = np.array(
self.warmup_trace.get_sampler_stats(stat, combine=False, squeeze=False)
)
result["warmup_sample_stats"] = dict_to_dataset(
data_warmup,
inference_library=pymc,
dims=None,
coords=self.coords,
attrs=self.attrs,
)
if self.posterior_trace:
data = {}
for stat in self.trace.stat_names:
name = rename_key.get(stat, stat)
if name in {"tune", "in_warmup"}:
continue
data[name] = np.array(
self.posterior_trace.get_sampler_stats(stat, combine=False, squeeze=False)
)
result["sample_stats"] = dict_to_dataset(
data,
inference_library=pymc,
dims=None,
coords=self.coords,
attrs=self.attrs,
)
return result if result else {}
@requires(["posterior_predictive"])
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
data = self.posterior_predictive
dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
return dict_to_dataset(
data,
inference_library=pymc,
coords=self.coords,
dims=dims,
sample_dims=self.sample_dims,
)
@requires(["predictions"])
def predictions_to_xarray(self):
"""Convert predictions (out of sample predictions) to xarray."""
data = self.predictions
dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
return dict_to_dataset(
data,
inference_library=pymc,
coords=self.coords,
dims=dims,
sample_dims=self.sample_dims,
)
def priors_to_xarray(self):
"""Convert prior samples (and if possible prior predictive too) to xarray."""
if self.prior is None:
return {"prior": None, "prior_predictive": None}
if self.observations is not None:
prior_predictive_vars = list(set(self.observations).intersection(self.prior))
prior_vars = [key for key in self.prior.keys() if key not in prior_predictive_vars]
else:
prior_vars = list(self.prior.keys())
prior_predictive_vars = None
priors_dict = {}
for group, var_names in zip(
("prior", "prior_predictive"), (prior_vars, prior_predictive_vars)
):
priors_dict[group] = (
None
if var_names is None
else dict_to_dataset_drop_incompatible_coords(
{k: np.expand_dims(self.prior[k], 0) for k in var_names},
inference_library=pymc,
coords=self.coords,
dims=self.dims,
sample_dims=self.sample_dims,
)
)
return priors_dict
@requires("observations")
@requires("model")
def observed_data_to_xarray(self):
"""Convert observed data to xarray."""
if self.predictions:
return None
return dict_to_dataset(
self.observations,
inference_library=pymc,
coords=self.coords,
dims=self.dims,
sample_dims=[],
)
@requires("model")
def constant_data_to_xarray(self):
"""Convert constant data to xarray."""
constant_data = find_constants(self.model)
if not constant_data:
return None
xarray_dataset = dict_to_dataset(
constant_data,
inference_library=pymc,
coords=self.coords,
dims=self.dims,
sample_dims=[],
)
# provisional handling of scalars in constant
# data to prevent promotion to rank 1
# in the future this will be handled by arviz
scalars = [var_name for var_name, value in constant_data.items() if np.ndim(value) == 0]
for s in scalars:
s_dim_0_name = f"{s}_dim_0"
# Only squeeze if the dimension exists in the dataset
if s_dim_0_name in xarray_dataset.sizes:
xarray_dataset = xarray_dataset.squeeze(s_dim_0_name, drop=True)
return xarray_dataset
def to_inference_data(self):
"""Convert all available data to an DataTree object.
Note that if groups can not be created (e.g., there is no `trace`, so
the `posterior` and `sample_stats` can not be extracted), then the DataTree
will not have those groups.
"""
id_dict = {
**(self.posterior_to_xarray() or {}),
**(self.sample_stats_to_xarray() or {}),
"posterior_predictive": self.posterior_predictive_to_xarray(),
"predictions": self.predictions_to_xarray(),
**self.priors_to_xarray(),
"observed_data": self.observed_data_to_xarray(),
}
if self.predictions:
id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
else:
id_dict["constant_data"] = self.constant_data_to_xarray()
id_dict = {k: v for k, v in id_dict.items() if v is not None}
if not self.save_warmup:
id_dict = {k: v for k, v in id_dict.items() if not k.startswith("warmup_")}
idata = DataTree.from_dict({f"/{k}": v for k, v in id_dict.items()})
if self.log_likelihood:
from pymc.stats.log_density import compute_log_likelihood
idata = compute_log_likelihood(
idata,
var_names=None if self.log_likelihood is True else self.log_likelihood,
extend_inferencedata=True,
model=self.model,
sample_dims=self.sample_dims,
progressbar=False,
)
if self.log_prior:
from pymc.stats.log_density import compute_log_prior
idata = compute_log_prior(
idata,
var_names=None if self.log_prior is True else self.log_prior,
extend_inferencedata=True,
model=self.model,
sample_dims=self.sample_dims,
progressbar=False,
)
return idata
[docs]
def to_inference_data(
trace: Optional["MultiTrace"] = None,
*,
prior: Mapping[str, Any] | None = None,
posterior_predictive: Mapping[str, Any] | None = None,
log_likelihood: bool | Iterable[str] = False,
log_prior: bool | Iterable[str] = False,
coords: CoordSpec | None = None,
dims: DimSpec | None = None,
sample_dims: list | None = None,
model: "Model" = None,
save_warmup: bool | None = None,
include_transformed: bool = False,
) -> xarray.DataTree:
"""Convert pymc data into an InferenceData object.
All three of them are optional arguments, but at least one of ``trace``,
``prior`` and ``posterior_predictive`` must be present.
For a usage example read the
:ref:`Creating InferenceData section on from_pymc <creating_InferenceData>`
Parameters
----------
trace : MultiTrace, optional
Trace generated from MCMC sampling. Output of
:func:`~pymc.sampling.sample`.
prior : dict, optional
Dictionary with the variable names as keys, and values numpy arrays
containing prior and prior predictive samples.
posterior_predictive : dict, optional
Dictionary with the variable names as keys, and values numpy arrays
containing posterior predictive samples.
log_likelihood : bool or array_like of str, optional
List of variables to calculate `log_likelihood`. Defaults to False.
If set to True, computes `log_likelihood` for all observed variables.
log_prior : bool or array_like of str, optional
List of variables to calculate `log_prior`. Defaults to False.
If set to True, computes `log_prior` for all unobserved variables.
coords : dict of {str: array-like}, optional
Map of coordinate names to coordinate values
dims : dict of {str: list of str}, optional
Map of variable names to the coordinate names to use to index its dimensions.
model : Model, optional
Model used to generate ``trace``. It is not necessary to pass ``model`` if in
``with`` context.
save_warmup : bool, optional
Save warmup iterations InferenceData object. If not defined, use default
defined by the rcParams.
include_transformed : bool, optional
Save the transformed parameters in the InferenceData object. By default, these are
not saved.
Returns
-------
xarray.DataTree
"""
if isinstance(trace, DataTree):
return trace
return DataTreeConverter(
trace=trace,
prior=prior,
posterior_predictive=posterior_predictive,
log_likelihood=log_likelihood,
log_prior=log_prior,
coords=coords,
dims=dims,
sample_dims=sample_dims,
model=model,
save_warmup=save_warmup,
include_transformed=include_transformed,
).to_inference_data()
### Later I could have this return ``None`` if the ``idata_orig`` argument is supplied. But
### perhaps we should have an inplace argument?
[docs]
def predictions_to_inference_data(
predictions,
posterior_trace: Optional["MultiTrace"] = None,
model: Optional["Model"] = None,
coords: CoordSpec | None = None,
dims: DimSpec | None = None,
sample_dims: list | None = None,
idata_orig: xarray.DataTree | None = None,
inplace: bool = False,
) -> xarray.DataTree:
"""Translate out-of-sample predictions into ``DataTree``.
Parameters
----------
predictions: Dict[str, np.ndarray]
The predictions are the return value of :func:`~pymc.sample_posterior_predictive`,
a dictionary of strings (variable names) to numpy ndarrays (draws).
Requires the arrays to follow the convention ``chain, draw, *shape``.
posterior_trace: MultiTrace
This should be a trace that has been thinned appropriately for
``pymc.sample_posterior_predictive``. Specifically, any variable whose shape is
a deterministic function of the shape of any predictor (explanatory, independent, etc.)
variables must be *removed* from this trace.
model: Model
The pymc model. It can be omitted if within a model context.
coords: Dict[str, array-like[Any]]
Coordinates for the variables. Map from coordinate names to coordinate values.
dims: Dict[str, array-like[str]]
Map from variable name to ordered set of coordinate names.
idata_orig: DataTree, optional
If supplied, then modify this inference data in place, adding ``predictions`` and
(if available) ``predictions_constant_data`` groups. If this is not supplied, make a
fresh DataTree
inplace: boolean, optional
If idata_orig is supplied and inplace is True, merge the predictions into idata_orig,
rather than returning a fresh DataTree object.
Returns
-------
DataTree:
May be modified ``idata_orig``.
"""
if inplace and not idata_orig:
raise ValueError(
"Do not pass True for inplace unless passing an existing DataTree as idata_orig"
)
converter = DataTreeConverter(
trace=posterior_trace,
predictions=predictions,
model=model,
coords=coords,
dims=dims,
sample_dims=sample_dims,
log_likelihood=False,
)
if hasattr(idata_orig, "posterior"):
assert idata_orig is not None
converter.nchains = idata_orig["posterior"].sizes["chain"]
converter.ndraws = idata_orig["posterior"].sizes["draw"]
else:
aelem = next(iter(predictions.values()))
converter.nchains, converter.ndraws = aelem.shape[:2]
new_idata = converter.to_inference_data()
if idata_orig is None:
return new_idata
elif inplace:
existing_groups = set(idata_orig.children) & set(new_idata.children)
conflicting = existing_groups - {"observed_data", "constant_data"}
if conflicting:
warnings.warn(
f"groups {conflicting} already exist in the DataTree and will be overwritten.",
UserWarning,
stacklevel=2,
)
idata_orig.update(new_idata)
return idata_orig
else:
new_idata.update(idata_orig)
return new_idata
def dataset_to_point_list(
ds: xarray.Dataset | dict[str, xarray.DataArray], sample_dims: Sequence[str]
) -> tuple[list[dict[str, np.ndarray]], dict[str, Any]]:
# All keys of the dataset must be a str
var_names = cast(list[str], list(ds.keys()))
for vn in var_names:
if not isinstance(vn, str):
raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
num_sample_dims = len(sample_dims)
stacked_dims = {dim_name: ds[var_names[0]][dim_name] for dim_name in sample_dims}
transposed_dict = {vn: da.transpose(*sample_dims, ...) for vn, da in ds.items()}
stacked_size = np.prod(transposed_dict[var_names[0]].shape[:num_sample_dims], dtype=int)
stacked_dict = {
vn: da.values.reshape((stacked_size, *da.shape[num_sample_dims:]))
for vn, da in transposed_dict.items()
}
points = [
{vn: stacked_dict[vn][i, ...] for vn in var_names}
for i in range(np.prod([len(coords) for coords in stacked_dims.values()]))
]
# use the list of points
return cast(list[dict[str, np.ndarray]], points), stacked_dims
def apply_function_over_dataset(
fn: PointFunc,
dataset: Dataset,
*,
output_var_names: Sequence[str],
coords,
dims,
sample_dims: Sequence[str] = ("chain", "draw"),
progressbar: bool = True,
progressbar_theme: Theme | None = default_progress_theme,
) -> Dataset:
posterior_pts, stacked_dims = dataset_to_point_list(dataset, sample_dims)
n_pts = len(posterior_pts)
out_dict = _DefaultTrace(n_pts)
indices = range(n_pts)
with CustomProgress(
console=Console(theme=progressbar_theme), disable=not progressbar
) as progress:
task = progress.add_task("Computing ...", total=n_pts)
for idx in indices:
out = fn(posterior_pts[idx])
fn.f.trust_input = True # If we arrive here the dtypes are valid
for var_name, val in zip(output_var_names, out, strict=True):
out_dict.insert(var_name, val, idx)
progress.advance(task)
progress.update(task, refresh=True, completed=n_pts)
out_trace = out_dict.trace_dict
for key, val in out_trace.items():
out_trace[key] = val.reshape(
(
*[len(coord) for coord in stacked_dims.values()],
*val.shape[1:],
)
)
return dict_to_dataset(
out_trace,
inference_library=pymc,
dims=dims,
coords=coords,
sample_dims=list(sample_dims),
skip_event_dims=True,
)