Interpretable Probabilistic Graphical Models

The Mid-Level API lets you describe any interpretable deep learning model as a probabilistic graphical model (PGM): a set of random variables connected by factors. Different probabilistic inferences can be performed on the PGM. It is the right entry point if you think in terms of probabilistic or causal models.

Overview of the PyC Mid-Level API

A mid-level model is assembled from four building blocks:

  • Variables — the random variables (concepts, embeddings) that make up the model.

  • Factors — conditional distributions or potential functions wiring variable together.

  • ProbabilisticModel — the container that collects variables and factors into a PGM.

  • Inference — engines that answer queries over the model.

Expand each block below for an explanation and an example. The running example builds a Concept Bottleneck Model input latent concepts task as a probabilistic model.

Variables

A Variable is a random variable in the model. Two concrete kinds are provided:

  • ConceptVariable — an interpretable, named variable (a concept or a task). A single call can declare several concepts at once.

  • EmbeddingVariable — a vector-valued node (e.g. the raw input or a latent representation), typically given a Delta distribution.

A variable is defined by its name, its distribution (torch.distributions), and its size (scalar) or shape (multi-dimensional, e.g. an image tensor).

import torch_concepts as pyc
from torch.distributions import Bernoulli, OneHotCategorical, Normal
from torch_concepts import EmbeddingVariable, ConceptVariable
from torch_concepts.distributions import Delta

input_var = EmbeddingVariable("input", distribution=Delta, shape=(3, 224, 224))  # RGB image
latent_var = EmbeddingVariable("latent", distribution=Delta, size=64)
smoking  = ConceptVariable("smoking",  distribution=Bernoulli)
genotype = ConceptVariable("genotype", distribution=OneHotCategorical, size=3)
tar      = ConceptVariable("tar",      distribution=Normal)
cancer   = ConceptVariable("cancer",   distribution=Bernoulli)

Passing a list of names creates independent variables (one node per concept in the graph). To group several concepts under a single node — one factor covers all of them — use a plate variable with members:

# One graph node; members are addressed individually for fine-grained parent wiring
binary_concepts = ConceptVariable("binary_concepts", distribution=Bernoulli,
                                  members=["smoking", "cancer"])

smoking_handle = binary_concepts.member("smoking")   # parent handle for downstream factors
Factors

A factor encodes a relationship over a set of variables and it is the building block that gives a PGM its structure. The abstract base class is ParametricFactor; different subclasses encode different kinds of relationship (conditional distributions for directed graphs, potential functions for undirected ones, etc.).

Currently, the only implemented factor is ParametricCPD — a conditional probability distribution p(variable | parents) parameterised by a pyc_logo PyC or pytorch_logo PyTorch module. Support for undirected factors (potentials) is planned for the near future.

Each ParametricCPD declares its parents, which is exactly how the directed graph structure is defined. The parametrization dict keys must match the distribution’s constructor arguments (e.g., logits for a Bernoulli distribution). Parent-less (root) variables use a prior such as LearnablePrior.

from torch_concepts.nn import ParametricCPD, LearnablePrior, Sequential, LinearConceptToConcept, LinearEmbeddingToConcept
import torch.nn as nn

# Input —> root, generally provided as evidence at inference time
input_cpd = ParametricCPD(
     input_var,
     parents=[],
     parametrization=LearnablePrior(size=1)
 )

# Latent | Input —> Delta (deterministic backbone)
latent_cpd = ParametricCPD(
    latent_var,
    parents=[input_var],
    parametrization=nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 64), nn.ReLU())
)

# Genotype | Latent —> OneHotCategorical, parametrize logits
genotype_cpd = ParametricCPD(
    genotype,
    parents=[latent_var],
    parametrization={'logits': LinearEmbeddingToConcept(in_embeddings=64, out_concepts=3)},
)

# Smoking | Genotype —> Bernoulli, parametrize logits
smoking_cpd = ParametricCPD(
    smoking,
    parents=[genotype],
    parametrization={'logits': LinearConceptToConcept(in_concepts=3, out_concepts=1)},
)

# Tar | Smoking —> Normal, both loc and scale must be parametrized; scale must be positive
tar_cpd = ParametricCPD(
    tar,
    parents=[smoking],
    parametrization={
        'loc':   LinearConceptToConcept(in_concepts=1, out_concepts=1),
        'scale': Sequential(LinearConceptToConcept(in_concepts=1, out_concepts=1), nn.Softplus()),
    },
)

# Cancer | Genotype, Tar —> Bernoulli, parametrize logits
cancer_cpd = ParametricCPD(
    cancer,
    parents=[genotype, tar],
    parametrization={'logits': LinearConceptToConcept(in_concepts=4, out_concepts=1)},
)
ProbabilisticModel

ProbabilisticModel is the abstract base for probabilistic graphical models. The concrete class you instantiate is BayesianNetwork, a directed model that wires a list of variables to a list of factors (one factor per variable).

from torch_concepts.nn import BayesianNetwork

model = BayesianNetwork(
    variables=[input_var, latent_var, genotype, smoking, tar, cancer],
    factors=[input_cpd, latent_cpd, genotype_cpd, smoking_cpd, tar_cpd, cancer_cpd],
)
Inference

An inference engine answers queries of the form “give me these variables, given this evidence” via engine.query(query, evidence). pyc_logo PyC ships several engines:

query accepts a list of variable names (predict them) or a dict mapping names to observed values (clamp them as evidence, e.g. for teacher forcing during training). The result exposes per-variable distribution parameters in out.params[name] and samples (when applicable) in out.samples[name].

from torch_concepts.nn import DeterministicInference

inference = DeterministicInference(model, activate_before_propagation=True)

x = torch.randn(16, 3, 224, 224)
out = inference.query(query=["genotype", "smoking", "tar", "cancer"], evidence={'input': x})

genotype_logits = out.params['genotype']['logits']   # (16, 3)
cancer_logits   = out.params['cancer']['logits']     # (16, 1)
Putting it Together: Concept Bottleneck Model

The blocks above assemble into a full CBM. During training, pass observed concept and task values as the query dict — they are clamped as evidence for teacher forcing:

import torch
from torch.distributions import Bernoulli
from torch_concepts import EmbeddingVariable, ConceptVariable
from torch_concepts.distributions import Delta
from torch_concepts.nn import (
    LinearEmbeddingToConcept, LinearConceptToConcept,
    ParametricCPD, BayesianNetwork, DeterministicInference, LearnablePrior,
)

x_var = EmbeddingVariable("x", distribution=Delta, size=16)
c_var = ConceptVariable(["c1", "c2"], distribution=Bernoulli)
y_var = ConceptVariable("y", distribution=Bernoulli)

model = BayesianNetwork(
    variables=[x_var, *c_var, y_var],
    factors=[
        ParametricCPD(x_var, parents=[], parametrization=LearnablePrior(size=1)),
        ParametricCPD(c_var, parents=[x_var],
                      parametrization={'logits': LinearEmbeddingToConcept(16, out_concepts=1)}),
        ParametricCPD(y_var, parents=[*c_var],
                      parametrization={'logits': LinearConceptToConcept(2, out_concepts=1)}),
    ],
)
inference = DeterministicInference(model, activate_before_propagation=True)

# At training time, pass observed labels as query to clamp them as evidence
x = torch.randn(32, 16)
c_true = torch.randint(0, 2, (32, 2)).float()
y_true = torch.randint(0, 2, (32, 1)).float()
pred = inference.query(query={"c1": c_true[:, 0], "c2": c_true[:, 1], "y": y_true},
                       evidence={"x": x})

Next Steps