Interpretable Probabilistic Models
PyC can be used to build interpretable concept-based probabilistic models.
Warning
This API is still under development and interfaces might change in future releases.
Design principles
Probabilistic Models
At this API level, models are represented as probabilistic models where:
Variable objects represent random variables in the probabilistic model. A variable is defined by its name, parents, and distribution type. For instance, we can define three concepts as:
concepts = pyc.EndogenousVariable(
    concepts=["c1", "c2", "c3"],
    parents=[],
    distribution=torch.distributions.RelaxedBernoulli
)
ParametricCPD objects represent conditional probability distributions (CPDs) between variables in the probabilistic model and are parameterized by PyC layers. For instance, we can define three parametric CPDs for the above concepts as:
concept_cpd = pyc.nn.ParametricCPD(
    concepts=["c1", "c2", "c3"],
    parametrization=pyc.nn.LinearZC(in_features=10, out_features=3)
)
ProbabilisticModel objects are a collection of variables and CPDs. For instance, we can define a model as:
probabilistic_model = pyc.nn.ProbabilisticModel(
    variables=concepts,
    parametric_cpds=concept_cpd
)
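As a reading aid, the three objects above line up with the usual factorization of a directed probabilistic model. The notation is ours rather than PyC's: $v_i$ is a variable, $\mathrm{pa}(v_i)$ its parents, and $\theta_i$ the parameters of its CPD; the graph is assumed to be acyclic.

```latex
p(v_1, \dots, v_n) = \prod_{i=1}^{n} p_{\theta_i}\bigl(v_i \mid \mathrm{pa}(v_i)\bigr)
```

Under this reading, each Variable contributes one $v_i$ together with its parent set, each ParametricCPD contributes one factor, and a ProbabilisticModel collects the factors into the product.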
Inference
Inference is performed using efficient tensorial probabilistic inference algorithms. For instance, we can perform ancestral sampling as:
# `wanda` and `x` are assumed to be defined beforehand.
inference_engine = pyc.nn.AncestralSamplingInference(
    probabilistic_model=probabilistic_model,
    graph_learner=wanda,
    temperature=1.0
)
predictions = inference_engine.query(["c1"], evidence={'input': x})
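To make the idea of ancestral sampling concrete, here is a minimal, library-independent sketch in plain Python. It does not use PyC and is not how `AncestralSamplingInference` is implemented; the toy network, its probabilities, and the helper `ancestral_sample` are all hypothetical. The only point is the order of operations: each variable is sampled after its parents, conditioned on the values they took.

```python
import random

# Hypothetical toy network (not PyC objects): each variable lists its parents
# and a function mapping the sampled parent values to P(variable = 1).
network = {
    "c1": ([], lambda: 0.7),
    "c2": (["c1"], lambda c1: 0.9 if c1 else 0.2),
    "y": (["c1", "c2"], lambda c1, c2: 0.95 if (c1 and c2) else 0.1),
}

def ancestral_sample(network, order, rng):
    """Draw one joint sample by visiting variables parents-first."""
    sample = {}
    for name in order:
        parents, prob_fn = network[name]
        p = prob_fn(*(sample[parent] for parent in parents))
        sample[name] = 1 if rng.random() < p else 0
    return sample

rng = random.Random(0)
order = ["c1", "c2", "y"]  # a topological order of the toy graph
samples = [ancestral_sample(network, order, rng) for _ in range(5000)]

# The empirical frequency of a root variable should approach its prior (0.7).
freq_c1 = sum(s["c1"] for s in samples) / len(samples)
print(f"Empirical P(c1 = 1) ~ {freq_c1:.2f}")
```

Averaging many such samples gives Monte Carlo estimates of marginals; a tensorized engine would presumably do the same thing over whole batches at once.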
Detailed Guides
Interpretable Probabilistic Models
Import Libraries
Start by importing PyC and PyTorch:
import torch
import torch_concepts as pyc
Create Sample Data
batch_size = 16
input_dim = 64
x = torch.randn(batch_size, input_dim)
Define Variables and Graph Structure
Variables represent random variables in the probabilistic model. To define a variable, specify its name, parents, and distribution type. By specifying parents, we define the graph structure of the model.
# Define input variable
input_var = pyc.InputVariable(
    concepts=["input"],
    parents=[],
)
# Define concept variables
concepts = pyc.EndogenousVariable(
    concepts=["round", "smooth", "bright"],
    parents=["input"],
    distribution=torch.distributions.RelaxedBernoulli
)
# Define task variables
tasks = pyc.EndogenousVariable(
    concepts=["class_A", "class_B"],
    parents=["round", "smooth", "bright"],
    distribution=torch.distributions.RelaxedBernoulli
)
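The parent lists above fully determine a directed graph, and any inference procedure has to visit variables parents-first. The sketch below checks this with the Python standard library only; the `parents` dictionary is a hand-written stand-in for the variables defined above, not something PyC exposes under that name.

```python
from graphlib import TopologicalSorter  # standard library, Python 3.9+

# Plain-dict stand-in for the variables above: name -> list of parent names.
parents = {
    "input": [],
    "round": ["input"],
    "smooth": ["input"],
    "bright": ["input"],
    "class_A": ["round", "smooth", "bright"],
    "class_B": ["round", "smooth", "bright"],
}

# TopologicalSorter expects node -> predecessors, which is exactly `parents`.
# static_order() raises graphlib.CycleError if the parent lists form a cycle.
order = list(TopologicalSorter(parents).static_order())
print(order)
```

The input comes first, the three concepts follow in some order, and the two tasks come last, which mirrors the input, concepts, tasks pipeline of a concept bottleneck model.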
Define ParametricCPDs
ParametricCPDs are conditional probability distributions parameterized by PyC or PyTorch layers.
Define a ParametricCPD for each variable based on its parents.
# ParametricCPD for input (no parents)
input_factor = pyc.nn.ParametricCPD(
    concepts=["input"],
    parametrization=torch.nn.Identity()
)
# ParametricCPD for concepts (from input)
concept_cpd = pyc.nn.ParametricCPD(
    concepts=["round", "smooth", "bright"],
    parametrization=pyc.nn.LinearZC(
        in_features=input_dim,
        out_features=1
    )
)
# ParametricCPD for tasks (from concepts)
task_cpd = pyc.nn.ParametricCPD(
    concepts=["class_A", "class_B"],
    parametrization=pyc.nn.LinearCC(
        in_features_endogenous=3,
        out_features=1
    )
)
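To illustrate what "a CPD parameterized by a layer" can mean, the following plain-Python sketch builds a Bernoulli CPD whose logit is a linear function of the parent values. This is an assumption made for illustration, not the implementation of `LinearZC` or `LinearCC`; the helper `linear_bernoulli_cpd` and the weights are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def linear_bernoulli_cpd(weights, bias):
    """Return a function mapping parent values to P(child = 1 | parents)."""
    def cpd(parent_values):
        logit = bias + sum(w * v for w, v in zip(weights, parent_values))
        return sigmoid(logit)
    return cpd

# Hypothetical weights for a task with three concept parents
# (round, smooth, bright); a real layer would learn these values.
class_a_cpd = linear_bernoulli_cpd(weights=[2.0, 1.0, -1.5], bias=-0.5)

p_all_off = class_a_cpd([0.0, 0.0, 0.0])
p_round_on = class_a_cpd([1.0, 0.0, 0.0])
print(f"P(class_A | no concepts active) = {p_all_off:.3f}")
print(f"P(class_A | round active)       = {p_round_on:.3f}")
```

Because the weights act directly on named concepts, the sign and size of each weight can be read as that concept's contribution to the task, which is the interpretability argument for linear task predictors.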
Build Concept-based Probabilistic Model
A concept-based probabilistic model is defined by collecting all variables and their corresponding ParametricCPDs.
# Create the probabilistic model
prob_model = pyc.nn.ProbabilisticModel(
    variables=[input_var, *concepts, *tasks],
    parametric_cpds=[input_factor, *concept_cpd, *task_cpd]
)
Probabilistic Inference
Deterministic Inference
We can perform deterministic inference by querying the model for concept and task predictions given input evidence:
# Create inference engine
inference_engine = pyc.nn.DeterministicInference(
    probabilistic_model=prob_model,
)
# Query concept predictions
concept_predictions = inference_engine.query(
    query_concepts=["round", "smooth", "bright"],
    evidence={'input': x}
)
# Query task predictions given concepts
task_predictions = inference_engine.query(
    query_concepts=["class_A", "class_B"],
    evidence={
        'input': x,
        'round': concept_predictions[:, 0],
        'smooth': concept_predictions[:, 1],
        'bright': concept_predictions[:, 2]
    }
)
print(f"Concept predictions: {concept_predictions}")
print(f"Task predictions: {task_predictions}")
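The two queries above differ only in their evidence: the second one supplies values for the concepts instead of letting the model compute them. A minimal plain-Python sketch of that behavior is shown below. It is a conceptual stand-in, not PyC's `query`; the toy model and the `query` helper are hypothetical, and we assume that evidence simply overrides a variable's CPD.

```python
def query(model, order, query_names, evidence):
    """Compute values parents-first, using evidence instead of the CPD when given."""
    values = {}
    for name in order:
        if name in evidence:
            values[name] = evidence[name]  # clamp to the supplied value
        else:
            parent_names, fn = model[name]
            values[name] = fn(*(values[p] for p in parent_names))
    return [values[name] for name in query_names]

# Hypothetical deterministic toy model: name -> (parents, function of parents).
model = {
    "input": ([], lambda: 0.0),
    "round": (["input"], lambda x: 1.0 if x > 0.5 else 0.0),
    "class_A": (["round"], lambda r: 0.9 * r + 0.05),
}
order = ["input", "round", "class_A"]

# Only the input is given, so "round" is computed from it.
from_input = query(model, order, ["round", "class_A"], {"input": 0.8})
# "round" is also given, so its CPD is skipped and the task sees 0.0.
with_override = query(model, order, ["class_A"], {"input": 0.8, "round": 0.0})
print(from_input, with_override)
```

Supplying concept values as evidence is also what makes it possible to feed ground-truth or corrected concepts to the task predictor at test time.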
Ancestral Sampling
While deterministic inference is the standard approach in deep learning, PyC also supports probabilistic inference methods.
For instance, we can perform ancestral sampling to obtain predictions by sampling from each variable’s distribution:
# Create inference engine
inference_engine = pyc.nn.AncestralSamplingInference(
    probabilistic_model=prob_model,
    temperature=1.0
)
# Query concept predictions
concept_predictions = inference_engine.query(
    query_concepts=["round", "smooth", "bright"],
    evidence={'input': x}
)
# Query task predictions given concepts
task_predictions = inference_engine.query(
    query_concepts=["class_A", "class_B"],
    evidence={
        'input': x,
        'round': concept_predictions[:, 0],
        'smooth': concept_predictions[:, 1],
        'bright': concept_predictions[:, 2]
    }
)
print(f"Concept predictions: {concept_predictions}")
print(f"Task predictions: {task_predictions}")
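The variables in this guide use a RelaxedBernoulli distribution, and the engine takes a `temperature` argument. This page does not say how PyC uses that argument internally; assuming it plays the usual role of the relaxed Bernoulli (binary Concrete) temperature, the plain-Python sketch below shows its effect. The function `relaxed_bernoulli_sample` is a hypothetical helper written for illustration.

```python
import math
import random

def relaxed_bernoulli_sample(logit, temperature, rng):
    """Sample from a relaxed Bernoulli (binary Concrete) distribution."""
    u = rng.random()
    u = min(max(u, 1e-12), 1.0 - 1e-12)  # keep both logarithms finite
    logistic_noise = math.log(u) - math.log(1.0 - u)
    z = (logit + logistic_noise) / temperature
    return 1.0 / (1.0 + math.exp(-z))

def fraction_near_binary(samples, margin=0.05):
    """Share of samples within `margin` of 0 or 1."""
    return sum(s < margin or s > 1.0 - margin for s in samples) / len(samples)

rng = random.Random(0)
cold = [relaxed_bernoulli_sample(0.0, 0.1, rng) for _ in range(2000)]
warm = [relaxed_bernoulli_sample(0.0, 5.0, rng) for _ in range(2000)]

print(f"near-binary at temperature 0.1: {fraction_near_binary(cold):.2f}")
print(f"near-binary at temperature 5.0: {fraction_near_binary(warm):.2f}")
```

Low temperatures push samples toward 0 or 1, approximating discrete Bernoulli draws, while high temperatures keep them near 0.5. Samples stay differentiable with respect to the logit in both cases, which is why relaxed distributions are convenient for gradient-based training.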
Interventions
We can perform interventions on specific concepts to observe their effects on other variables, similarly to how interventions are performed using low-level APIs.
from torch_concepts.nn import DoIntervention, UniformPolicy
from torch_concepts.nn import intervention
strategy = DoIntervention(model=prob_model.parametric_cpds, constants=100.0)
policy = UniformPolicy(out_features=prob_model.concept_to_variable["round"].size)
original_predictions = inference_engine.query(
    query_concepts=["round", "smooth", "bright", "class_A", "class_B"],
    evidence={'input': x}
)
# Apply intervention to encoder
with intervention(
    policies=policy,
    strategies=strategy,
    target_concepts=["round", "smooth"]
):
    # Queries issued inside the block see the intervened concepts
    intervened_predictions = inference_engine.query(
        query_concepts=["round", "smooth", "bright", "class_A", "class_B"],
        evidence={'input': x}
    )
print(f"Original endogenous: {original_predictions[0]}")
print(f"Intervened endogenous: {intervened_predictions[0]}")
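Conceptually, a do-intervention replaces the CPDs of the target variables with constants for the duration of the block and restores them afterwards. The sketch below reproduces that pattern in plain Python with a context manager. It is not the implementation of `DoIntervention` or `intervention`; the toy model and the `do` helper are hypothetical.

```python
from contextlib import contextmanager

# Hypothetical toy model: name -> (parents, function of parent values).
model = {
    "round": ([], lambda: 0.2),
    "smooth": ([], lambda: 0.4),
    "class_A": (["round", "smooth"], lambda r, s: 0.6 * r + 0.3 * s),
}
order = ["round", "smooth", "class_A"]

def forward(model, order):
    """Evaluate every variable parents-first."""
    values = {}
    for name in order:
        parent_names, fn = model[name]
        values[name] = fn(*(values[p] for p in parent_names))
    return values

@contextmanager
def do(model, constants):
    """Temporarily replace the targeted CPDs with constant functions."""
    originals = {name: model[name] for name in constants}
    try:
        for name, value in constants.items():
            # Bind `value` now; the constant ignores whatever the parents are.
            model[name] = (model[name][0], lambda *_, v=value: v)
        yield model
    finally:
        model.update(originals)  # restore the original CPDs

before = forward(model, order)
with do(model, {"round": 1.0, "smooth": 1.0}):
    during = forward(model, order)
after = forward(model, order)
print(before["class_A"], during["class_A"], after["class_A"])
```

Only the downstream variable changes while the intervention is active, and the model returns to its original behavior once the block exits.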
Next Steps
Explore the full Mid-Level API documentation
Try the High-Level API for out-of-the-box models
Learn about probabilistic inference methods