# Source code for torch_concepts.nn.modules.mid.inference.pyro.base

"""PyroBaseInference — base class for Pyro-backed inference engines.

Provides the shared Pyro plumbing required by any engine that uses Pyro's
effect handlers (``poutine.trace``, ``poutine.replay``, ``pyro.infer.SVI``):

- ``model_fn`` / ``guide_fn``: bound Pyro stochastic functions that traverse
  the wrapped PGM topologically and emit ``pyro.sample`` sites.
- ``_pyro_relaxed_distribution``: pyro-compatible straight-through relaxation
  for the discrete distribution families.
- ``dist_to_params`` / ``trace_to_params``: helpers to harvest distribution
  parameters from a Pyro trace into the engine-agnostic
  :class:`InferenceOutput.params` schema.

Parameter sharing with the wrapped PGM is inherited from
:class:`BaseInference` (the engine holds a reference to ``pgm``, so
``engine.parameters()`` enumerates the same tensors as ``pgm.parameters()``).
"""
from __future__ import annotations

from typing import Dict, List, Optional

import torch
import torch.distributions as td

from ...models.bayesian_network import BayesianNetwork
from ...models.variable import Delta
from ..base import BaseInference
from ..utils import build_distribution, reshape_value_to_event
from .utils import dist_to_params, trace_to_params


def _import_pyro():
    """Lazily import Pyro, raising a clear error if it is not installed."""
    try:
        import pyro
        import pyro.distributions as pyro_dist
        import pyro.poutine as poutine
        return pyro, pyro_dist, poutine
    except ImportError as exc:
        raise ImportError(
            "Pyro-based inference requires the `pyro-ppl` package. "
            "Install it with: pip install pyro-ppl"
        ) from exc


# -----------------------------------------------------------------------------
class PyroBaseInference(BaseInference):
    """Base class for inference engines backed by Pyro.

    Bundles the model/guide stochastic functions and the Pyro-side parameter
    harvesters. Subclasses (e.g. :class:`VariationalInference`) supply their
    own ``query`` method that orchestrates effect handlers.
    """

    name = "PyroBaseInference"

    def __init__(self, pgm: BayesianNetwork):
        super().__init__(pgm)

    # ------------------------------------------------------------------
    # Distribution helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _pyro_relaxed_distribution(
        variable,
        params: Dict[str, torch.Tensor],
        temperature: torch.Tensor,
    ) -> td.Distribution:
        """Build a Pyro-compatible relaxed distribution for ``pyro.sample`` sites.

        For the Bernoulli, OneHotCategorical, Normal, MultivariateNormal and
        Delta families this returns a ``pyro.distributions`` instance (subclass
        of ``TorchDistribution``), which ``pyro.sample`` requires for unobserved
        sites. Plain ``torch.distributions`` objects are not callable and would
        raise ``TypeError: 'X' object is not callable`` at runtime. The discrete
        families use Pyro's own straight-through estimators, which register
        correctly with Pyro's effect-handler stack.

        Args:
            variable: Variable whose ``distribution`` class selects the family.
            params: Distribution parameters, flat ``(*batch, size)``.
            temperature: Relaxation temperature for the straight-through
                (Bernoulli / OneHotCategorical) families.

        Returns:
            A distribution whose batch shape is ``(*batch,)`` for the families
            listed above. Any other family falls back to
            ``build_distribution(variable, params)``.
        """
        # Parameters are flat (*batch, size). The single size axis is
        # reinterpreted as the event (``to_event(1)`` / ``event_dim=1``) so
        # batch_shape stays (*batch,) and the ``pyro.plate("batch", ...)`` dim
        # lines up. The variable's declared shape is restored on the sampled
        # realization.
        _, pyro_dist, _ = _import_pyro()
        D = variable.distribution

        if issubclass(D, td.Bernoulli):
            d = pyro_dist.RelaxedBernoulliStraightThrough(temperature=temperature, **params)
            return d.to_event(1)
        if issubclass(D, td.OneHotCategorical):
            # The one-hot axis is already the event dim; no to_event needed.
            return pyro_dist.RelaxedOneHotCategoricalStraightThrough(
                temperature=temperature, **params
            )
        if issubclass(D, td.Normal):
            return pyro_dist.Normal(**params).to_event(1)
        if issubclass(D, td.MultivariateNormal):
            return pyro_dist.MultivariateNormal(**params)
        if D.__name__ == "Delta":
            # Map ``value`` (our Delta convention) to ``v`` (Pyro Delta convention).
            return pyro_dist.Delta(params["value"], event_dim=1)
        # Fallback for any other family: try the exact torch distribution.
        # NOTE(review): if ``build_distribution`` returns a plain
        # ``torch.distributions`` object, ``pyro.sample`` at an *unobserved*
        # site will raise the TypeError described above — confirm that latent
        # variables only use the families handled explicitly here.
        return build_distribution(variable, params)

    # ------------------------------------------------------------------
    # Plate (member) addressing — shared by the Pyro engines, reusing the
    # CPD's slicing so a plate behaves the same as under the torch engine.
    # ------------------------------------------------------------------
    def _gather_parents(self, cpd, cache, data):
        """Parent values for ``cpd``.

        Member-handle parents are sliced out of their plate's (sampled or
        observed) value. This is the Pyro counterpart of the torch engine's
        member-as-parent handling.

        Raises:
            ValueError: if a parent is neither in ``cache`` nor in ``data``.
        """
        parents: Dict[str, torch.Tensor] = {}
        for p in cpd.parents:
            src = p.plate.name  # owning plate for a member handle, else p.name
            value = cache.get(src, data.get(src))
            if value is None:
                raise ValueError(
                    f"{self.name}: parent {p.name!r} of {cpd.variable.name!r} is "
                    "neither sampled nor in data."
                )
            if p.name != src:  # member handle -> slice its column from the plate value
                value = self.pgm.factors[src].select_value(value, p.name)
            parents[p.name] = value
        return parents

    def _expose_members(self, params, query_names):
        """Add per-member entries for queried plate members.

        Each entry is sliced (a view) from its plate's params, so members are
        addressable by name in the output. ``params`` is updated in place and
        also returned.
        """
        for name in query_names:
            var = self.pgm.resolve(name)
            if name != var.name and var.name in params:
                params[name] = self.pgm.factors[var.name].select(params[var.name], name)
        return params

    # ------------------------------------------------------------------
    # Stochastic functions (bound to ``self.pgm``)
    # ------------------------------------------------------------------
    def model_fn(
        self,
        data: Dict[str, torch.Tensor],
        temperature: torch.Tensor,
        latent_names: List[str],
        batch_size: Optional[int] = None,
        layer_kwargs: Optional[Dict[str, Dict]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Pyro stochastic function for the generative model.

        Walks ``self.pgm.levels`` level by level. Parents must already be
        sampled or observed when a child is reached (see ``_gather_parents``),
        so the levels are expected to be in topological order. Each variable
        becomes a ``pyro.sample`` site:

        - Variables present in ``data`` are scored against their exact
          distribution (``obs=`` keyword to ``pyro.sample``).
        - Variables absent from ``data`` are sampled via a straight-through
          relaxation so gradients flow through the discrete sites.

        Registers ``self.pgm`` with Pyro's param store via ``pyro.module`` on
        every call, so SVI updates flow back into the original PGM's
        ``nn.Parameter`` tensors (no parameter duplication).

        Args:
            data: Observed values keyed by variable name; the leading axis is
                the batch axis.
            temperature: Relaxation temperature for unobserved discrete sites.
            latent_names: Unused here; presumably accepted so ``model_fn`` and
                ``guide_fn`` share a call signature under SVI.
            batch_size: Batch size to use when ``data`` is empty.
            layer_kwargs: Optional per-variable keyword arguments forwarded to
                non-root CPD calls.

        Returns:
            Realizations of every variable, keyed by name and reshaped to the
            variable's event shape.

        Raises:
            ValueError: if ``data`` is empty and ``batch_size`` is None, or if
                a parent value is unavailable.
        """
        pyro, _, _ = _import_pyro()
        pgm = self.pgm
        pyro.module("pgm", pgm)
        if layer_kwargs is None:
            layer_kwargs = {}

        if data:
            B = next(iter(data.values())).shape[0]
        elif batch_size is not None:
            B = batch_size
        else:
            raise ValueError(
                "Cannot infer batch size: `data` is empty and `batch_size` was not provided."
            )

        cache: Dict[str, torch.Tensor] = {}
        with pyro.plate("batch", B, dim=-1):
            for level in pgm.levels:
                for var in level:
                    cpd = pgm.factors[var.name]
                    if cpd.is_root:
                        params = cpd.root_params(B)
                    else:
                        parent_values = self._gather_parents(cpd, cache, data)
                        params = cpd(
                            parent_values=parent_values,
                            **layer_kwargs.get(var.name, {}),
                        )

                    obs = data.get(var.name, None)
                    if obs is not None:
                        # The distribution's event is the flat size axis, so match
                        # the observation to it: (*batch, *shape) -> (*batch, size).
                        obs = obs.reshape(obs.shape[0], var.size)
                        d = build_distribution(var, params)
                    else:
                        d = self._pyro_relaxed_distribution(var, params, temperature)

                    sample = pyro.sample(var.name, d, obs=obs)
                    # Cache the realization in the variable's event shape; downstream
                    # CPD aggregation re-flattens it as needed.
                    cache[var.name] = reshape_value_to_event(var, sample)
        return cache

    def guide_fn(
        self,
        data: Dict[str, torch.Tensor],
        temperature: torch.Tensor,
        latent_names: List[str],
        layer_kwargs: Optional[Dict[str, Dict]] = None,
        batch_size: Optional[int] = None,
    ) -> None:
        """Pyro stochastic function for the variational posterior.

        Runs a ``pyro.sample`` site for each latent variable, using its
        registered guide CPD from ``self.pgm.guides``. Registers the guide
        ``nn.ModuleDict`` with Pyro's param store via ``pyro.module`` on every
        call, so SVI updates flow back into the original guide CPDs'
        ``nn.Parameter`` tensors.

        Args:
            data: Observed values keyed by variable name. Non-root guide CPDs
                read their parents directly from here.
            temperature: Relaxation temperature for the latent sites.
            latent_names: Names of the latent variables to sample.
            layer_kwargs: Optional per-variable keyword arguments forwarded to
                non-root guide CPD calls.
            batch_size: Batch size to use when ``data`` is empty. It should
                match the value passed to ``model_fn`` so both ``"batch"``
                plates have the same size. Defaults to 1 when neither
                ``data`` nor ``batch_size`` is given.

        Raises:
            KeyError: if a latent name has no guide, or a non-root guide's
                parent is missing from ``data``.
        """
        pyro, _, _ = _import_pyro()
        pgm = self.pgm
        pyro.module("pgm_guides", pgm.guides)
        if layer_kwargs is None:
            layer_kwargs = {}

        # Mirror model_fn's batch-size resolution so the plates agree. Keep
        # the historical default of 1 when nothing determines the size.
        if data:
            B = next(iter(data.values())).shape[0]
        elif batch_size is not None:
            B = batch_size
        else:
            B = 1

        with pyro.plate("batch", B, dim=-1):
            for name in latent_names:
                cpd = pgm.guides[name]
                if cpd.is_root:
                    params = cpd(parent_values={})
                    # Broadcast the unbatched root params across the batch plate.
                    params = {
                        k: v.unsqueeze(0).expand(B, *v.shape)
                        for k, v in params.items()
                    }
                else:
                    parent_values = {p.name: data[p.name] for p in cpd.parents}
                    params = cpd(
                        parent_values=parent_values,
                        **layer_kwargs.get(name, {}),
                    )
                q = self._pyro_relaxed_distribution(cpd.variable, params, temperature)
                pyro.sample(name, q)