Source code for torch_concepts.nn.modules.outputs
"""Output containers for PGM inference engines."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import torch
# ---------------------------------------------------------------------------
# Parameter-dict type alias
# ---------------------------------------------------------------------------
ParamDict = Dict[str, torch.Tensor]
# ---------------------------------------------------------------------------
# params -> logits assembly
# ---------------------------------------------------------------------------
# FIXME: this is a bit of a hack, but it works for now. We must make the ModelOutput
# and InferenceOutput classes more flexible in the future: store the tensors and
# provide utilities for per-concept views.
def logits_from_params(
    params: Dict[str, ParamDict],
    keys: Optional[List[str]] = None,
) -> Optional[torch.Tensor]:
    """Concatenate per-variable ``'logits'`` tensors from an output's ``params``.

    The single place the library turns the per-variable parameter dict produced
    by inference into the flat ``(batch, sum_cardinalities)`` logits tensor that
    losses and metrics consume.

    Parameters
    ----------
    params : dict[str, ParamDict]
        Per-variable parameter dicts (e.g. ``{'c1': {'logits': ...}, ...}``).
    keys : list[str], optional
        Variable names to assemble, in order. When ``None`` (default), every
        variable whose entry is a dict carrying a ``'logits'`` key is used, in
        insertion order.

    Returns
    -------
    torch.Tensor or None
        The selected logits concatenated along the last dim, or ``None`` when
        there is nothing to concatenate (no variable carries logits, or
        ``keys`` is empty).

    Raises
    ------
    KeyError
        If ``keys`` is given explicitly and names a variable that is absent
        from ``params`` or whose parameter dict has no ``'logits'`` entry.
        Such variables are not skipped silently, since that would change the
        column layout of the result without the caller noticing.
    """
    if keys is None:
        # Auto-discovery: keep only entries that actually expose logits.
        keys = [
            name
            for name, entry in params.items()
            if isinstance(entry, dict) and 'logits' in entry
        ]

    parts = []
    for name in keys:
        try:
            tensor = params[name]['logits']
        except KeyError:
            # Same exception type as a raw lookup, but the message now says
            # which variable failed and why.
            if name not in params:
                reason = "is not present in params"
            else:
                reason = "has no 'logits' entry"
            raise KeyError(f"variable {name!r} {reason}") from None
        parts.append(tensor)

    # torch.cat rejects an empty list, so "nothing to assemble" maps to None.
    return torch.cat(parts, dim=-1) if parts else None
# ---------------------------------------------------------------------------
# InferenceOutput
# ---------------------------------------------------------------------------
[docs]
@dataclass
class InferenceOutput:
    """Container returned by every inference engine.

    Every field has a default, so an instance can be built with no arguments:
    the three dictionaries start out empty and ``probabilities`` starts out
    as ``None``.

    Attributes
    ----------
    params : dict[str, ParamDict]
        Named parameter tensors of the model-side distribution, grouped per
        variable (e.g. ``{'c': {'probs': ...}}``).
    guide_params : dict[str, ParamDict]
        Named parameter tensors of the variational guide, grouped per latent.
    samples : dict[str, torch.Tensor]
        Sampled values, one entry per variable.
    probabilities : torch.Tensor or None
        Joint conditional probabilities for a fully realised query batch.
    """

    # ``default_factory`` gives each instance its own dict rather than a
    # shared mutable default.
    params: Dict[str, ParamDict] = field(default_factory=dict)
    guide_params: Dict[str, ParamDict] = field(default_factory=dict)
    samples: Dict[str, torch.Tensor] = field(default_factory=dict)

    probabilities: Optional[torch.Tensor] = None
[docs]
@dataclass
class ModelOutput:
    """Structured result of a high-level model's ``forward()`` call.

    Every field has a default, so an instance can be built with no arguments:
    the three dictionaries start out empty and every tensor-valued field, as
    well as ``extra``, starts out as ``None``.

    Attributes
    ----------
    params : dict[str, ParamDict]
        Named parameter tensors of the model-side distribution, grouped per
        variable (e.g. ``{'c': {'probs': ...}}``).
    guide_params : dict[str, ParamDict]
        Named parameter tensors of the variational guide, grouped per latent.
    samples : dict[str, torch.Tensor]
        Sampled values, one entry per variable.
    probabilities : torch.Tensor or None
        Joint conditional probabilities for a fully realised query batch.
    logits : torch.Tensor or None
        Optional logits tensor. Flagged for removal by a FIXME in the source;
        presumably the flat logits assembled from ``params`` -- confirm
        against callers.
    target : torch.Tensor or None
        Optional target tensor; presumably the ground truth paired with the
        predictions -- confirm against callers.
    extra : dict[str, torch.Tensor] or None
        Optional mapping of names to additional tensors. Defaults to ``None``
        rather than an empty dict.
    """

    # ``default_factory`` gives each instance its own dict rather than a
    # shared mutable default.
    params: Dict[str, ParamDict] = field(default_factory=dict)
    guide_params: Dict[str, ParamDict] = field(default_factory=dict)
    samples: Dict[str, torch.Tensor] = field(default_factory=dict)

    probabilities: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None  # FIXME: to be removed
    target: Optional[torch.Tensor] = None
    extra: Optional[Dict[str, torch.Tensor]] = None