High-level API
High-level API models allow you to quickly build and train concept-based models using pre-configured components and minimal code.
Documentation
Design principles
Annotations
Annotations define the structure of concepts and tasks in your model by describing their types, cardinalities, and distributions.
Basic Annotation Structure
Annotations consist of axis annotations that describe variables along a dimension:
import torch_concepts as pyc
from torch.distributions import Bernoulli, Categorical
# Define concepts and tasks
labels = ["is_round", "is_smooth", "color", "class_A", "class_B"]
cardinalities = [1, 1, 3, 1, 1] # binary, binary, categorical(3), binary, binary
# Metadata with types and distributions
metadata = {
    'is_round': {'type': 'discrete', 'distribution': Bernoulli},
    'is_smooth': {'type': 'discrete', 'distribution': Bernoulli},
    'color': {'type': 'discrete', 'distribution': Categorical},
    'class_A': {'type': 'discrete', 'distribution': Bernoulli},
    'class_B': {'type': 'discrete', 'distribution': Bernoulli}
}
# Axis 1 holds the concept/task variables
annotations = pyc.Annotations({
    1: pyc.AxisAnnotation(
        labels=labels,
        cardinalities=cardinalities,
        metadata=metadata
    )
})
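To make the role of cardinalities concrete, the following sketch computes which columns each variable would occupy if the variables were laid out contiguously along the annotated axis, in label order. That layout and the helper name column_slices are assumptions made for illustration; the sketch uses only the standard library and does not describe PyC internals.

```python
from itertools import accumulate

labels = ["is_round", "is_smooth", "color", "class_A", "class_B"]
cardinalities = [1, 1, 3, 1, 1]


def column_slices(labels, cardinalities):
    """Hypothetical helper: map each label to its (start, stop) column range.

    Assumes variables are stored contiguously, in label order, each taking
    as many columns as its cardinality.
    """
    if len(labels) != len(cardinalities):
        raise ValueError("labels and cardinalities must have the same length")
    stops = list(accumulate(cardinalities))  # running total of widths
    starts = [0] + stops[:-1]
    return {name: (lo, hi) for name, lo, hi in zip(labels, starts, stops)}


slices = column_slices(labels, cardinalities)
total_width = sum(cardinalities)

print(slices["color"])  # (2, 5): the three columns of the categorical concept
print(total_width)      # 7
```

Under this assumption, a binary variable contributes one column and a categorical variable with k classes contributes k columns, so the five variables above span seven columns in total.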
GroupConfig for Automatic Configuration
For models with many concepts, use GroupConfig to automatically assign configurations based on concept type:
from torch_concepts import GroupConfig
# Define annotations without individual distributions
annotations = pyc.Annotations({
    1: pyc.AxisAnnotation(
        labels=["is_round", "is_smooth", "color", "shape"],
        cardinalities=[1, 1, 3, 4],
        metadata={
            'is_round': {'type': 'discrete'},   # binary (card=1)
            'is_smooth': {'type': 'discrete'},  # binary (card=1)
            'color': {'type': 'discrete'},      # categorical (card=3)
            'shape': {'type': 'discrete'}       # categorical (card=4)
        }
    )
})
# Automatically assign distributions by type
variable_distributions = GroupConfig(
    binary=Bernoulli,        # for cardinality=1
    categorical=Categorical  # for cardinality>1
)
This approach scales efficiently to datasets with hundreds of concepts (e.g., CUB-200 with 312 attributes).
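The grouping rule stated in the comments above (cardinality 1 is binary, cardinality greater than 1 is categorical) can be sketched in plain Python. The helper resolve_by_cardinality is hypothetical and the distribution names are plain strings standing in for the actual classes; this illustrates the rule only and is not how GroupConfig is implemented.

```python
def resolve_by_cardinality(cardinality, binary, categorical):
    """Hypothetical helper: pick the binary or categorical entry for one concept."""
    if cardinality < 1:
        raise ValueError(f"cardinality must be >= 1, got {cardinality}")
    return binary if cardinality == 1 else categorical


labels = ["is_round", "is_smooth", "color", "shape"]
cardinalities = [1, 1, 3, 4]

# Strings stand in for the distribution classes to keep the sketch dependency-free.
resolved = {
    name: resolve_by_cardinality(card, binary="Bernoulli", categorical="Categorical")
    for name, card in zip(labels, cardinalities)
}
print(resolved)
```

Because the choice depends only on cardinality, two group-level entries are enough to cover any number of concepts, which is what keeps the configuration short for large attribute sets.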
Out-of-the-box Models
PyC provides ready-to-use models that can be instantiated with minimal configuration:
Concept Bottleneck Model (CBM)
A CBM learns interpretable concept representations and uses them to predict tasks:
from torch_concepts.nn import ConceptBottleneckModel
model = ConceptBottleneckModel(
    input_size=2048,  # e.g., ResNet feature dimension
    annotations=annotations,
    task_names=['class_A', 'class_B'],
    variable_distributions=variable_distributions,  # Optional: GroupConfig or dict
    latent_encoder_kwargs={
        'hidden_size': 128,
        'n_layers': 2,
        'activation': 'relu',
        'dropout': 0.1
    }
)
BlackBox Model
A standard neural network for comparison baselines:
from torch_concepts.nn import BlackBox
model = BlackBox(
    input_size=2048,
    annotations=annotations,
    task_names=['class_A', 'class_B'],
    latent_encoder_kwargs={
        'hidden_size': 256,
        'n_layers': 3
    }
)
Losses and Metrics
Configure losses and metrics using GroupConfig to automatically handle mixed concept types:
Concept Loss
import torch.nn as nn
from torch_concepts.nn import ConceptLoss
from torch_concepts import GroupConfig
# Different loss functions for different concept types
loss_config = GroupConfig(
    binary=nn.BCEWithLogitsLoss(),
    categorical=nn.CrossEntropyLoss()
)
concept_loss = ConceptLoss(
    annotations=annotations,
    fn_collection=loss_config
)
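As a rough illustration of what a mixed-type concept loss involves, the sketch below applies binary cross-entropy with logits to single-column concepts and softmax cross-entropy to multi-column concepts for one sample, then averages the per-concept values. The function names, the contiguous column layout, and the unweighted mean are all assumptions for illustration; how ConceptLoss actually slices and aggregates is not specified here.

```python
import math


def bce_with_logits(z, y):
    """Numerically stable binary cross-entropy on a single logit z, target y in {0, 1}."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample; target is a class index."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]


def mixed_concept_loss(logits_row, targets, cardinalities):
    """Hypothetical sketch: mean of per-concept losses for one sample."""
    losses = []
    start = 0
    for card, target in zip(cardinalities, targets):
        chunk = logits_row[start:start + card]
        if card == 1:
            losses.append(bce_with_logits(chunk[0], float(target)))
        else:
            losses.append(cross_entropy(chunk, int(target)))
        start += card
    return sum(losses) / len(losses)


cardinalities = [1, 1, 3]                # is_round, is_smooth, color
logits_row = [0.0, 0.0, 0.0, 0.0, 0.0]   # uninformative logits
targets = [1, 0, 2]                      # binary labels and a class index
loss = mixed_concept_loss(logits_row, targets, cardinalities)
print(loss)
```

With all logits at zero, each binary concept contributes log 2 and the three-class concept contributes log 3, which gives a quick sanity check for the routing by cardinality.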
Concept Metrics
from torch_concepts.nn import ConceptMetrics
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy
# Different metrics for different concept types
metrics_config = GroupConfig(
    binary={'accuracy': BinaryAccuracy()},
    categorical={'accuracy': MulticlassAccuracy}
)
concept_metrics = ConceptMetrics(
    annotations=annotations,
    fn_collection=metrics_config,
    summary_metrics=True,    # Compute average across concepts
    perconcept_metrics=True  # Compute per-concept metrics
)
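The difference between per-concept and summary metrics can be sketched with plain accuracy on already-decoded predictions. The helper concept_accuracies is hypothetical, and treating the summary as an unweighted mean over concepts is an assumption based on the "average across concepts" comment above, not a statement about the ConceptMetrics implementation.

```python
def concept_accuracies(predictions, targets, labels):
    """Hypothetical helper: per-concept accuracy plus their unweighted mean.

    predictions and targets are lists of rows with one decoded value per concept
    (0/1 for binary concepts, a class index for categorical ones).
    """
    per_concept = {}
    for j, name in enumerate(labels):
        correct = sum(1 for p, t in zip(predictions, targets) if p[j] == t[j])
        per_concept[name] = correct / len(targets)
    summary = sum(per_concept.values()) / len(per_concept)
    return per_concept, summary


labels = ["is_round", "is_smooth", "color"]
predictions = [[1, 0, 2], [0, 0, 1], [1, 1, 0], [1, 0, 2]]
targets = [[1, 0, 2], [1, 0, 1], [1, 0, 0], [1, 0, 2]]

per_concept, summary = concept_accuracies(predictions, targets, labels)
print(per_concept)  # {'is_round': 0.75, 'is_smooth': 0.75, 'color': 1.0}
print(summary)
```

Per-concept values show which individual concepts are poorly predicted, while the summary collapses them into a single number that is easier to track during training.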
Training Modes
High-level models support two training approaches:
Manual PyTorch Training
import torch.nn as nn
import torch.optim as optim
from torch_concepts.nn import ConceptBottleneckModel
# x (input features) and targets (float labels matching the queried
# variables) are assumed to be defined elsewhere
model = ConceptBottleneckModel(input_size=64, annotations=annotations,
                               task_names=['class_A'])
optimizer = optim.AdamW(model.parameters(), lr=0.001)
loss_fn = nn.BCEWithLogitsLoss()
for epoch in range(100):
    optimizer.zero_grad()
    predictions = model(x, query=['is_round', 'is_smooth', 'class_A'])
    loss = loss_fn(predictions, targets)
    loss.backward()
    optimizer.step()
PyTorch Lightning Training
import torch
from pytorch_lightning import Trainer
# Model with integrated loss and optimizer
model = ConceptBottleneckModel(
    input_size=64,
    annotations=annotations,
    task_names=['class_A'],
    loss=concept_loss,
    metrics=concept_metrics,
    optim_class=torch.optim.AdamW,
    optim_kwargs={'lr': 0.001}
)
trainer = Trainer(max_epochs=100)
# datamodule: a LightningDataModule assumed to be defined elsewhere
trainer.fit(model, datamodule)
Querying Models
High-level models support flexible querying of concepts and tasks:
model.eval()
with torch.no_grad():
    # Query specific variables
    concepts = model(x, query=['is_round', 'is_smooth', 'color'])
    # Query tasks only
    tasks = model(x, query=['class_A', 'class_B'])
    # Query everything
    all_predictions = model(x, query=['is_round', 'is_smooth',
                                      'color', 'class_A', 'class_B'])
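One way to picture a query is as a selection of columns by variable name. The sketch below does this for a single output row, assuming the values for the queried variables are concatenated in query order and each variable is as wide as its cardinality. The helper select_query and that layout are assumptions for illustration; the shape the model actually returns for a query is not documented here.

```python
def select_query(row, query, labels, cardinalities):
    """Hypothetical helper: pick the columns of the queried variables from one row."""
    offsets = {}
    start = 0
    for name, card in zip(labels, cardinalities):
        offsets[name] = (start, start + card)
        start += card
    unknown = [name for name in query if name not in offsets]
    if unknown:
        raise KeyError(f"unknown variables in query: {unknown}")
    selected = []
    for name in query:  # preserve the order given in the query
        lo, hi = offsets[name]
        selected.extend(row[lo:hi])
    return selected


labels = ["is_round", "is_smooth", "color", "class_A", "class_B"]
cardinalities = [1, 1, 3, 1, 1]
row = [0.9, -1.2, 0.1, 2.0, -0.5, 1.5, -0.3]  # made-up logits for one sample

tasks = select_query(row, ["class_A", "class_B"], labels, cardinalities)
color = select_query(row, ["color"], labels, cardinalities)
print(tasks)  # [1.5, -0.3]
print(color)  # [0.1, 2.0, -0.5]
```

Rejecting unknown names up front is a design choice in this sketch: a misspelled variable fails loudly instead of silently returning fewer columns.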