Source code for torch_concepts.nn.modules.mid.inference.pyro.variational

"""Variational inference engine (Pyro backend).

This engine uses Pyro's effect handlers (``poutine.trace`` and
``poutine.replay``) to collect distribution parameters from the generative
model and the variational guide.

The Pyro stochastic functions themselves are inherited from
:class:`PyroBaseInference` (``model_fn`` / ``guide_fn``); this class only
orchestrates the effect-handler plumbing.
"""
from __future__ import annotations

import sys
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn

from ...models.bayesian_network import BayesianNetwork
from ...models.cpd import ParametricCPD
from ..utils import make_temperature_schedule
from ....outputs import InferenceOutput
from .base import PyroBaseInference, trace_to_params, _import_pyro


_YELLOW_START = "\033[33m"
_YELLOW_END = "\033[0m"


def _yellow_notice(msg: str) -> None:
    """Print a yellow notice on ``stderr``."""
    print(f"{_YELLOW_START}{msg}{_YELLOW_END}", file=sys.stderr, flush=True)


[docs] class VariationalInference(PyroBaseInference): """Variational inference engine. Uses Pyro's effect handlers to trace the generative model and the variational guide, intercept sample sites, and collect distribution parameters. The Pyro stochastic functions are provided by :class:`PyroBaseInference`. Parameters ---------- pgm : BayesianNetwork The probabilistic graphical model. latents : dict, optional Declaration of latent (unobservable) variables and their guide CPDs. Maps each latent variable name to a user-provided :class:`~torch_concepts.nn.modules.mid.models.cpd.ParametricCPD` that acts as the variational guide for that variable. If omitted or empty, no guides are registered and the engine warns that variational inference may not behave as expected. initial_temperature, annealing, annealing_rate Temperature schedule for the relaxed-discrete sites; see :func:`~torch_concepts.nn.modules.mid.inference.base.make_temperature_schedule`. """ name = "VariationalInference"
[docs] def __init__( self, pgm: BayesianNetwork, latents: Optional[Dict[str, ParametricCPD]] = None, n_samples: int = 1, max_plate_nesting: int = 1, initial_temperature: float = 1.0, annealing: Union[str, Callable[[int], float]] = "constant", annealing_rate: float = 0.0, ): super().__init__(pgm) # Detect PGM device before building guides. try: _pgm_device = next(pgm.parameters()).device except StopIteration: _pgm_device = torch.device("cpu") self._latent_names = self._build_guides(pgm, latents=latents or {}) # Move the newly registered guides to the same device as the PGM. pgm.guides.to(_pgm_device) self.n_samples = int(n_samples) self.max_plate_nesting = int(max_plate_nesting) self._warned_latent_evidence = False # Retained for repr/introspection; the live schedule lives in ``_schedule``. self.initial_temperature = float(initial_temperature) self.annealing = annealing self.annealing_rate = float(annealing_rate) self._schedule = make_temperature_schedule( initial_temperature, annealing, annealing_rate ) self._step: int = 0 self.register_buffer( "_temperature", torch.tensor(float(self._schedule(self._step))), ) # Construction-time yellow notices. if self._latent_names: _yellow_notice( f"{self.name} Warning:\nDeclared latent (unobservable) variables: " f"{self._latent_names}. This inference algorithm expects to be queried " "with those variables absent from `query` (or mapped to None). " "No guarantees if you query with any of these variables observed, or if you " "query with other unobserved variables that are not declared latent." ) else: _yellow_notice( f"{self.name} Warning:\nYou are using variational inference without " "declaring unobservable variables. The engine might not " "behave as expected." ) _yellow_notice( f"{self.name} Warning:\nContract — pass all variables in `query`: observed " "variables with tensor values, latent variables absent or set to None. " "`evidence` is still accepted and merged (`query` values take priority)." )
def __repr__(self) -> str: return self._format_repr( latents=list(self._latent_names), n_samples=self.n_samples, max_plate_nesting=self.max_plate_nesting, initial_temperature=self.initial_temperature, annealing=self.annealing, annealing_rate=self.annealing_rate, ) # ------------------------------------------------------------------ @classmethod def _build_guides( cls, pgm: BayesianNetwork, latents: Dict[str, ParametricCPD], ) -> List[str]: if not latents: pgm.guides = nn.ModuleDict() return [] if not isinstance(latents, dict): raise TypeError( f"{cls.__name__}: `latents` must be a dict mapping latent " "variable names to ParametricCPD instances, " f"got {type(latents).__name__}." ) all_var_names = {v.name for v in pgm.variables.values()} latent_set = set(latents.keys()) for lat_name, cpd in latents.items(): if lat_name not in all_var_names: raise ValueError( f"{cls.__name__}: unknown latent variable name {lat_name!r}." ) if not isinstance(cpd, ParametricCPD): raise TypeError( f"{cls.__name__}: guide for {lat_name!r} must be a " f"ParametricCPD, got {type(cpd).__name__}." ) if cpd.variable.name != lat_name: raise ValueError( f"{cls.__name__}: guide CPD variable name " f"{cpd.variable.name!r} does not match the dict key {lat_name!r}." ) for p in cpd.parents: if p.name not in all_var_names: raise ValueError( f"{cls.__name__}: guide for {lat_name!r}: parent " f"{p.name!r} is not a variable of the PGM." ) if p.name in latent_set: raise ValueError( f"{cls.__name__}: guide for {lat_name!r} cannot " f"condition on latent variable {p.name!r}." ) if pgm.variables[p.name] is not p: raise ValueError( f"{cls.__name__}: guide for {lat_name!r}: parent " f"{p.name!r} is a different Variable instance than the one " "registered in the PGM. Pass the same object." ) pgm.guides = nn.ModuleDict(latents) return list(latents.keys()) # ------------------------------------------------------------------ @property def latent_names(self) -> List[str]: return list(self._latent_names) @property def guide_conditioning(self) -> Dict[str, List[str]]: return { name: [p.name for p in self.pgm.guides[name].parents] for name in self._latent_names } @property def temperature(self) -> torch.Tensor: return self._temperature def step(self) -> None: self._step += 1 self._temperature.fill_(float(self._schedule(self._step))) # ------------------------------------------------------------------ def _merge_observables( self, query: Dict[str, Optional[torch.Tensor]], evidence: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: data: Dict[str, torch.Tensor] = dict(evidence) for name, val in query.items(): if val is not None: data[name] = val return data def _align_param_keys( self, params: Dict[str, Dict[str, torch.Tensor]], use_guides: bool = False, ) -> Dict[str, Dict[str, torch.Tensor]]: """Relabel ``'probs'`` ↔ ``'logits'`` to match each CPD's original key. ``dist_to_params`` always extracts ``probs`` from relaxed discrete distributions (since Pyro's internal representation is logits and Pyro reconstructs distribution objects during tracing). This method converts the extracted key back to what the user originally wrote in their ``parametrization`` dict, so ``guide_params['Y']`` contains ``'logits'`` when the guide CPD was built with ``{'logits': ...}``. """ aligned = {} for name, pdict in params.items(): cpd = (self.pgm.guides[name] if use_guides and name in self.pgm.guides else self.pgm.factors[name] if not use_guides else None) if cpd is None: aligned[name] = pdict continue cpd_keys = set(cpd.parametrization.keys()) pdict = dict(pdict) # shallow copy — don't mutate caller's dict if "logits" in cpd_keys and "probs" in pdict and "logits" not in pdict: probs = pdict.pop("probs") pdict["logits"] = torch.log(probs.clamp(min=1e-8)) - torch.log( (1.0 - probs).clamp(min=1e-8) ) elif "probs" in cpd_keys and "logits" in pdict and "probs" not in pdict: pdict["probs"] = torch.sigmoid(pdict.pop("logits")) aligned[name] = pdict return aligned # ------------------------------------------------------------------ def query( self, query: Union[List[str], Dict[str, Optional[torch.Tensor]], None] = None, evidence: Dict[str, torch.Tensor] = None, layer_kwargs: Dict[str, Dict] = {}, ) -> InferenceOutput: """Run variational inference and return model and guide parameters.""" _, _, poutine = _import_pyro() if query is None: query = {} if evidence is None: evidence = {} query = self._normalize_query(query) self._validate_containers(query, evidence) data = self._merge_observables(query, evidence) supplied_latents = [name for name in self._latent_names if name in data] if supplied_latents and not self._warned_latent_evidence: _yellow_notice( f"{self.name} Warning:\nDeclared latent variables were supplied as " f"observations: {supplied_latents}. Variational inference is " "guaranteed only when declared latent variables match the " "unobserved variables." ) self._warned_latent_evidence = True non_latent_missing = [ v.name for v in self.pgm.variables.values() if v.name not in self._latent_names and v.name not in data ] if non_latent_missing: _yellow_notice( f"{self.name} Warning:\nThe following non-latent variables were not " f"supplied: {non_latent_missing}. They will be sampled from " "the respective distributions; we cannot guarantee the result." ) temperature = self.temperature latent_names = self._latent_names if self.pgm.has_guides: guide_fn = lambda: self.guide_fn(data, temperature, latent_names, layer_kwargs) guide_tr = poutine.trace(guide_fn).get_trace() model_fn = lambda: self.model_fn(data, temperature, latent_names, layer_kwargs=layer_kwargs) replayed = poutine.replay(model_fn, trace=guide_tr) model_tr = poutine.trace(replayed).get_trace() guide_params = self._align_param_keys( trace_to_params(guide_tr), use_guides=True ) else: model_fn = lambda: self.model_fn(data, temperature, latent_names, layer_kwargs=layer_kwargs) model_tr = poutine.trace(model_fn).get_trace() guide_params = {} model_params = self._expose_members( self._align_param_keys(trace_to_params(model_tr), use_guides=False), set(query), ) return InferenceOutput(params=model_params, guide_params=guide_params)