Source code for torch_concepts.nn.modules.mid.models.bayesian_network

"""
BayesianNetwork: a directed factor graph wiring a list of ``Variable``s to a list of ``ParametricCPD``s.
"""

from __future__ import annotations

from collections import defaultdict, deque
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn

from .cpd import ParametricCPD
from .probabilistic_model import ProbabilisticModel
from .variable import Variable

class BayesianNetwork(ProbabilisticModel):
    """Directed graphical model pairing each ``Variable`` with one ``ParametricCPD``.

    Construction checks the structure (exactly one factor per variable, no
    duplicated factor, every parent reference resolves to a registered
    variable, no cycles), computes a topological order of the variables, and
    keeps the result on the instance for inference engines to consume.

    Parameters
    ----------
    variables : list of Variable
        All random variables in the graph.
    factors : list of ParametricCPD
        One CPD per variable; the order of the list does not matter.

    Raises
    ------
    ValueError
        If the number of factors differs from the number of variables, a
        factor is duplicated or names an unknown variable, a parent reference
        does not resolve to the registered object, or the graph has a cycle.
    TypeError
        If a factor lists a parent that is not a ``Variable``.
    """

    def __init__(
        self,
        variables: List[Variable],
        factors: List[ParametricCPD],
    ):
        # The base class registers ``self.variables`` (dict) and ``self.guides``.
        super().__init__(variables)

        # ---- factors --------------------------------------------------------
        n_variables, n_factors = len(variables), len(factors)
        if n_factors != n_variables:
            raise ValueError(
                f"Got {n_variables} variables but {n_factors} factors; "
                "exactly one factor per variable is required."
            )

        # ``_factors`` is keyed by the name of each CPD's child variable and is
        # published read-only through the ``factors`` property.
        self._factors: nn.ModuleDict = nn.ModuleDict()
        for cpd in factors:
            child = cpd.variable.name
            if child in self._factors:
                raise ValueError(f"Duplicate factor for variable {child!r}.")
            if child not in self.variables:
                raise ValueError(
                    f"Factor name {child!r} has no matching Variable."
                )
            self._factors[child] = cpd

        # Parent references are checked (and de-duplicated) before sorting so
        # that the sort only ever sees names present in ``self.variables``.
        self._validate_graph()

        # Variables in an order where every parent precedes its children.
        self.sorted_variables: List[Variable] = self._topological_sort()

        # Filled in by the ``levels`` property the first time it is read.
        self._levels_cache: Optional[List[List[Variable]]] = None

    @property
    def factors(self) -> nn.ModuleDict:
        """Return the ``{child variable name: ParametricCPD}`` mapping."""
        return self._factors

    @property
    def levels(self) -> List[List[Variable]]:
        """Return the variables grouped by topological depth.

        A variable without parents has depth ``0``; any other variable has
        depth one greater than the deepest of its parents. Entry ``d`` of the
        result holds the variables of depth ``d``, in topological order. No
        variable is a parent of another variable in the same level, so a
        level's CPDs can be evaluated together once all earlier levels are
        available.

        The list is computed on first access and the same object is returned
        afterwards; nothing in this class invalidates it, which assumes the
        graph is not modified after construction.
        """
        cached = self._levels_cache
        if cached is None:
            depth_of: Dict[str, int] = {}
            by_depth: Dict[int, List[Variable]] = defaultdict(list)
            for var in self.sorted_variables:
                parents = self._factors[var.name].parents
                if parents:
                    # NOTE(review): relies on ``Variable.plate`` naming the
                    # registered variable for plain variables as well as for
                    # member handles - confirm against ``Variable``.
                    level = 1 + max(depth_of[p.plate.name] for p in parents)
                else:
                    level = 0
                depth_of[var.name] = level
                by_depth[level].append(var)
            cached = [by_depth[level] for level in sorted(by_depth)]
            self._levels_cache = cached
        return cached

    # ----------------------------------------------------------- validate
    def _validate_graph(self) -> None:
        """Check every CPD's parents against ``self.variables`` and de-duplicate.

        A plain parent must be the very object stored under its name in
        ``self.variables``; an equal-named but distinct instance is rejected.
        A parent carrying a non-``None`` ``_plate`` attribute is treated as a
        member handle: its plate must be the registered object and must list
        the parent's name among its ``members``. Once a CPD passes, its
        ``parents`` attribute is reassigned to a list with repeated objects
        removed, first occurrence kept.

        Raises
        ------
        TypeError
            If a parent is not a ``Variable``.
        ValueError
            If a parent (or its plate) is missing from ``self.variables`` or
            is a different instance than the registered one.
        """
        registered = self.variables
        for child_name, cpd in self._factors.items():
            for parent in cpd.parents:
                if not isinstance(parent, Variable):
                    # Defensive re-check: the parents list may have been
                    # altered between CPD construction and registration here.
                    raise TypeError(
                        f"Factor {child_name!r}: parent must be a Variable, "
                        f"got {type(parent).__name__}."
                    )

                owner = getattr(parent, "_plate", None)
                if owner is None:
                    # Plain variable: must be registered, by identity.
                    if parent.name not in registered:
                        raise ValueError(
                            f"Factor {child_name!r}: parent {parent.name!r} not in variables list."
                        )
                    if registered[parent.name] is not parent:
                        raise ValueError(
                            f"Factor {child_name!r}: parent {parent.name!r} is a different "
                            "Variable instance than the one registered in "
                            "`variables`. Pass the same object."
                        )
                else:
                    # Member handle: the owning plate is what must be
                    # registered, and it must actually own this member.
                    if registered.get(owner.name) is not owner:
                        raise ValueError(
                            f"Factor {child_name!r}: parent {parent.name!r} is a member of plate "
                            f"{owner.name!r}, which is not the registered variable. "
                            "Pass the same plate object via `variables`."
                        )
                    if parent.name not in owner.members:
                        raise ValueError(
                            f"Factor {child_name!r}: {owner.name!r} has no member {parent.name!r}."
                        )

            # Drop repeated parent objects, keeping the first occurrence.
            unique_parents: Dict[int, Variable] = {}
            for parent in cpd.parents:
                unique_parents.setdefault(id(parent), parent)
            cpd.parents = list(unique_parents.values())

    # ----------------------------------------------------------- topo sort
    def _topological_sort(self) -> List[Variable]:
        """Return the variables ordered so that parents come before children.

        Counts incoming edges per variable, then repeatedly emits a variable
        whose count has dropped to zero and decrements the counts of its
        children.

        Raises
        ------
        ValueError
            If some variables are never emitted, i.e. the graph has a cycle.
        """
        pending: Dict[str, int] = dict.fromkeys(self.variables, 0)
        dependents: Dict[str, List[str]] = defaultdict(list)
        for child_name, cpd in self._factors.items():
            for parent in cpd.parents:
                pending[child_name] += 1
                # NOTE(review): same reliance on ``Variable.plate`` as in
                # ``levels`` - confirm it is defined for plain variables.
                dependents[parent.plate.name].append(child_name)

        ready = deque(name for name, count in pending.items() if count == 0)
        ordered: List[Variable] = []
        while ready:
            current = ready.popleft()
            ordered.append(self.variables[current])
            for child_name in dependents[current]:
                pending[child_name] -= 1
                if pending[child_name] == 0:
                    ready.append(child_name)

        # Anything left with a positive count sits on (or behind) a cycle.
        if len(ordered) != len(self.variables):
            raise ValueError(
                "BayesianNetwork: variables/factors form a cycle; the graph "
                "must be a DAG."
            )
        return ordered