Source code for torch_concepts.nn.modules.mid.models.factor

"""
Abstract class for PGM factors.
"""

from __future__ import annotations

import inspect
from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, List, Set, Tuple, Union

import torch
import torch.nn as nn
from .variable import Variable
from ...low.lazy import LazyConstructor


# Known PyC parameter-name combinations
_PYC_PARAM_SETS = [
    {'concepts'},
    {'embeddings'},
    {'concepts', 'embeddings'},
]


def _cat_parents(inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Concatenate parent values along the last dim, preserving their shape.

    No flattening or reshaping is performed: every parent value keeps its full
    event shape and the tensors are concatenated along ``dim=-1``. This
    deliberately raises when the values have mismatched non-concatenation
    dimensions (e.g. a matrix-valued parent alongside a vector-valued one) —
    such combinations are ambiguous and must be resolved with a custom
    ``aggregate``. Values are cast to floating point so discrete parents can
    feed float layers, but their shape is left untouched.
    """
    vals = [
        v.float() if not v.is_floating_point() else v
        for v in inputs.values()
    ]
    return torch.cat(vals, dim=-1)


def _module_input_names(mod: nn.Module) -> Set[str]:
    """Return the explicit keyword/positional parameter names of ``mod.forward``.

    A PyC :class:`~torch_concepts.nn.Sequential` forwards its inputs straight to
    its first layer, so its input signature *is* that first layer's.
    """
    from ...low.sequential import Sequential

    while isinstance(mod, Sequential) and len(mod) > 0:
        mod = mod[0]
    sig = inspect.signature(mod.forward)
    return {
        name
        for name, p in sig.parameters.items()
        if name != "self"
        and p.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
    }


[docs] class ParametricFactor(nn.Module, ABC): """Abstract class for factors parameterised by torch.nn.Module. Concrete factor types (directed: :class:`ParametricCPD`; undirected: ``ParametricPotential``) must subclass this and implement :meth:`forward`. Subclasses call ``super().__init__(parametrization, aggregate)`` to store: - ``self.parametrization`` — an ``nn.ModuleDict`` mapping parameter names to ``nn.Module`` instances. - ``self._module_signatures`` — cached ``forward`` param-name sets, one per module in ``parametrization`` (computed once at construction time). - ``self._aggregators`` — per-parameter aggregation callables, resolved at construction time. Standard modules get :meth:`_standard_aggregate`; PyC modules get :meth:`_pyc_aggregate`; user-supplied callables override. ``aggregate`` accepts: - ``None`` — auto-select :meth:`_standard_aggregate` or :meth:`_pyc_aggregate` per module based on its ``forward`` signature. - A single ``Callable`` — use it for every parameter module. - A ``Dict[str, Callable]`` — use the keyed callable for the matching parameter module; auto-select the default for any missing key. A user-supplied aggregate is called with a signature that matches the parameter module's kind: for a **PyC** module it receives the parent values already split by type — ``agg(concepts, embeddings)``, each a ``Dict[Variable, Tensor]`` — and must return the ``{'concepts': ..., 'embeddings': ...}`` dict the module expects; for a **standard** module it receives the single ``agg(inputs)`` dict and returns one concatenated tensor. See :meth:`_resolve_aggregator`. """
[docs] def __init__( self, parametrization: Dict[str, nn.Module], aggregate: Optional[ Union[ Callable, Dict[str, Callable], ] ] = None, ): super().__init__() parametrization = self._initialize_parametrization(parametrization) # Cache each module's forward parameter names once at construction time. self._module_signatures: Dict[str, Set[str]] = { pname: _module_input_names(mod) for pname, mod in parametrization.items() } # Normalise the user input to one entry per parameter (``None`` = use # the auto-selected default), then adapt each to the uniform # ``inputs -> result`` call site used by the CPD's forward. if aggregate is None: per_param: Dict[str, Optional[Callable]] = {pname: None for pname in parametrization} elif callable(aggregate): per_param = {pname: aggregate for pname in parametrization} elif isinstance(aggregate, dict): bad = [k for k, v in aggregate.items() if not callable(v)] if bad: raise TypeError( f"ParametricFactor: aggregate dict contains non-callable " f"values for keys {bad}." ) per_param = {pname: aggregate.get(pname) for pname in parametrization} else: raise TypeError( "ParametricFactor: `aggregate` must be None, a callable, or a " f"dict mapping parameter names to callables, got {type(aggregate).__name__}." ) self._aggregators: Dict[str, Callable] = { pname: self._resolve_aggregator(pname, agg) for pname, agg in per_param.items() } self.parametrization = parametrization
def _initialize_parametrization( self, parametrization: Dict[str, nn.Module], ) -> nn.ModuleDict: """Normalise ``parametrization`` into an ``nn.ModuleDict``. Accepts a plain dict (or an existing ``nn.ModuleDict``) mapping each parameter name to a ready ``nn.Module``. Concrete subclasses resolve any :class:`LazyConstructor` entries before calling ``super().__init__`` — the input/output sizes a lazy layer needs come from the factor's variables, which only the subclass knows (see :meth:`ParametricCPD._instantiate_lazy`). As a safeguard, an already-built ``LazyConstructor`` is unwrapped to its concrete module. """ modules: Dict[str, nn.Module] = {} for pname, module in parametrization.items(): if isinstance(module, LazyConstructor) and module.module is not None: module = module.module modules[pname] = module return nn.ModuleDict(modules) def _is_pyc(self, pname: str) -> bool: """Whether the parameter module follows the PyC ``concepts``/``embeddings`` calling convention (vs. a standard single-tensor module).""" return self._module_signatures[pname] in _PYC_PARAM_SETS # For entries not covered by the user, pick _pyc_aggregate or # _standard_aggregate based on the cached module signature. def _select_default(self, pname: str) -> Callable: return self._pyc_aggregate if self._is_pyc(pname) else self._standard_aggregate def _resolve_aggregator( self, pname: str, user_aggregate: Optional[Callable], ) -> Callable: """Adapt an aggregate to the uniform ``inputs -> result`` call site. ``None`` selects the auto-chosen default (:meth:`_select_default`). A user-supplied aggregate is dispatched by the parameter module's kind: - **PyC** module — called as ``agg(concepts, embeddings)`` over the type-split parent dicts (each ``Dict[Variable, Tensor]``); it returns the ``{'concepts': ..., 'embeddings': ...}`` dict the module expects. - **standard** module — called as ``agg(inputs)`` over the single parent dict; it returns one concatenated tensor. """ if user_aggregate is None: return self._select_default(pname) if self._is_pyc(pname): def aggregator(inputs, _agg=user_aggregate): concepts, embeddings = self._split_by_type(inputs) return _agg(concepts, embeddings) return aggregator return user_aggregate def _standard_aggregate( self, inputs: Dict[Variable, torch.Tensor], ) -> torch.Tensor: """Default aggregation for standard torch modules. Concatenates the parent values along the last dim without reshaping (see :func:`_cat_parents`). """ return _cat_parents(inputs) def _split_by_type( self, inputs: Dict[Variable, torch.Tensor], ) -> Tuple[Dict[Variable, torch.Tensor], Dict[Variable, torch.Tensor]]: """Partition parent values into ``(concepts, embeddings)`` dicts by variable type, preserving parent order.""" concepts: Dict[Variable, torch.Tensor] = {} embeddings: Dict[Variable, torch.Tensor] = {} for p in self.parents: if p not in inputs: continue if p.variable_type == "concept": concepts[p] = inputs[p] elif p.variable_type == "embedding": embeddings[p] = inputs[p] else: raise ValueError( f"ParametricCPD({self.variable.name!r}): parent " f"{p.name!r} has invalid type {p.variable_type!r}, " "expected 'embedding' or 'concept'." ) return concepts, embeddings def _pyc_aggregate( self, inputs: Dict[Variable, torch.Tensor], ) -> Dict[str, torch.Tensor]: """Default aggregation for PyC-style modules. Splits parent values by type (:meth:`_split_by_type`) and concatenates each group along the last dim (see :func:`_cat_parents`), returning a dict with keys ``'concepts'`` and/or ``'embeddings'`` matching the module's ``forward`` signature. """ concepts, embeddings = self._split_by_type(inputs) out: Dict[str, torch.Tensor] = {} if concepts: out["concepts"] = _cat_parents(concepts) if embeddings: out["embeddings"] = _cat_parents(embeddings) return out @abstractmethod def forward( self, inputs: Optional[Dict[str, torch.Tensor]] = None, ) -> Dict[str, torch.Tensor]: """Compute the factor output given its input variable values. Subclasses define the precise signature and semantics: - :class:`ParametricCPD` accepts ``parent_values`` and returns a named distribution-parameter dict (e.g. ``{"probs": ...}``). - A future ``ParametricPotential`` will accept clique variable values and return a log-potential tensor. """