Out-of-the-box Models

pyc_logo PyC provides ready-to-use models for concept-based learning with minimal configuration. Models support both manual PyTorch training and automatic pl_logo PyTorch Lightning training.

Design Principles

pyc_logo PyC out-of-the-box models handle complexity automatically:

  • Type-Aware Routing: Predictions automatically routed to correct loss and metric functions based on concept types

  • Minimal Configuration: Use GroupConfig to specify settings once per type (binary, categorical) rather than per concept

  • Flexible Training: Choose between manual PyTorch control or automatic Lightning training

Two Training Modes

Manual PyTorch Mode: Initialize without loss/optimizer for full control

model = ConceptBottleneckModel(
    input_size=256,
    annotations=ann,
    variable_distributions=variable_distributions,
    task_names=['cancer']
)

# Write your own training loop
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(100):
    # Your training code

Lightning Mode: Initialize with loss/optimizer for automatic training

model = ConceptBottleneckModel(
    input_size=256,
    annotations=ann,
    task_names=['cancer'],
    loss=concept_loss,         # torch loss or ConceptLoss
    metrics=concept_metrics,   # torchmetrics or ConceptMetrics
    optim_class=torch.optim.AdamW,
    optim_kwargs={'lr': 0.001}
)

# Automatic training
trainer = Trainer(max_epochs=100)
trainer.fit(model, datamodule)

Detailed Guides

Annotations

Concept and Task Metadata

Annotations store metadata about concepts including names, cardinalities, distribution types, and custom attributes. They specify the structure and properties of concepts for models, losses, and metrics.

Quick Start

from torch_concepts.annotations import AxisAnnotation, Annotations
from torch.distributions import Bernoulli, Categorical

# Define concept structure with distributions
ann = Annotations({
    1: AxisAnnotation(
        labels=['is_round', 'is_smooth', 'color', 'class_A', 'class_B'],
        cardinalities=[1, 1, 3, 1, 1],
        metadata={
            'is_round': {'type': 'discrete', 'distribution': Bernoulli},
            'is_smooth': {'type': 'discrete', 'distribution': Bernoulli},
            'color': {'type': 'discrete', 'distribution': Categorical},
            'class_A': {'type': 'discrete', 'distribution': Bernoulli},
            'class_B': {'type': 'discrete', 'distribution': Bernoulli}
        }
    )
})

Key Components

  • labels: List of concept and task names

  • cardinalities: Number of classes for each (1 for binary, >1 for categorical)

  • metadata: Dictionary with concept properties including distribution types

Distribution Assignment Methods

Distributions can be provided in three ways:

Method 1: In annotations metadata (recommended)

ann = Annotations({
    1: AxisAnnotation(
        labels=['is_round', 'color'],
        cardinalities=[1, 3],
        metadata={
            'is_round': {'type': 'discrete', 'distribution': Bernoulli},
            'color': {'type': 'discrete', 'distribution': Categorical}
        }
    )
})

# Use directly in model
model = ConceptBottleneckModel(
    input_size=256,
    annotations=ann,
    task_names=['class_A']
)

Method 2: Via variable_distributions dictionary

# Annotations without distributions
ann = Annotations({
    1: AxisAnnotation(
        labels=['is_round', 'color'],
        cardinalities=[1, 3],
        metadata={
            'is_round': {'type': 'discrete'},
            'color': {'type': 'discrete'}
        }
    )
})

# Provide distributions separately
variable_distributions = {
    'is_round': Bernoulli,
    'color': Categorical
}

model = ConceptBottleneckModel(
    input_size=256,
    annotations=ann,
    variable_distributions=variable_distributions,
    task_names=['class_A']
)

Method 3: Using GroupConfig (for mixed types)

from torch_concepts import GroupConfig

# Annotations with mixed types
ann = Annotations({
    1: AxisAnnotation(
        labels=['is_round', 'is_smooth', 'color', 'shape'],
        cardinalities=[1, 1, 3, 4],
        metadata={
            'is_round': {'type': 'discrete'},   # binary (card=1)
            'is_smooth': {'type': 'discrete'},  # binary (card=1)
            'color': {'type': 'discrete'},      # categorical (card=3)
            'shape': {'type': 'discrete'}       # categorical (card=4)
        }
    )
})

# GroupConfig automatically assigns by concept type
variable_distributions = GroupConfig(
    binary=Bernoulli,      # all concepts with cardinality=1
    categorical=Categorical # all concepts with cardinality>1
)

model = ConceptBottleneckModel(
    input_size=256,
    annotations=ann,
    variable_distributions=variable_distributions,
    task_names=['class_A']
)

Usage with Loss and Metrics

from torch_concepts.nn import ConceptLoss, ConceptMetrics
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy

# Loss configuration
loss_config = GroupConfig(
    binary=BCEWithLogitsLoss(),
    categorical=CrossEntropyLoss()
)
loss = ConceptLoss(annotations=ann, fn_collection=loss_config)

# Metrics configuration
metrics_config = GroupConfig(
    binary={'accuracy': BinaryAccuracy()},
    categorical={'accuracy': (MulticlassAccuracy, {'average': 'macro'})}
)
metrics = ConceptMetrics(
    annotations=ann,
    fn_collection=metrics_config,
    summary_metrics=True,
    perconcept_metrics=True
)

Special Cases

Missing distributions: If distributions are not in metadata and variable_distributions is not provided, the model will raise an assertion error.

Task concepts: Concepts that are prediction targets (tasks) should be included in the annotations and specified via the task_names parameter.

Custom metadata: Add custom fields to metadata for application-specific needs:

metadata={
    'is_round': {
        'type': 'discrete',
        'distribution': Bernoulli,
        'description': 'Object has rounded shape',
        'importance': 0.8
    }
}
GroupConfig

Type-Based Configuration Helper

GroupConfig simplifies configuration for models with mixed concept types (binary and categorical). Instead of configuring each concept individually, configure once per type.

Quick Start

from torch_concepts import GroupConfig
from torch.distributions import Bernoulli, Categorical
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss

# Configure distributions by type
variable_distributions = GroupConfig(
    binary=Bernoulli,
    categorical=Categorical
)

# Configure losses by type
loss_config = GroupConfig(
    binary=BCEWithLogitsLoss(),
    categorical=CrossEntropyLoss()
)

# Configure metrics by type
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy

metrics_config = GroupConfig(
    binary={'accuracy': BinaryAccuracy()},
    categorical={'accuracy': MulticlassAccuracy}
)

Automatic Type Detection

GroupConfig automatically determines concept types based on cardinalities:

  • Binary: cardinality = 1

  • Categorical: cardinality > 1

  • Continuous: when type=’continuous’ in metadata (not yet fully supported)

# Annotations with mixed types
ann = Annotations({
    1: AxisAnnotation(
        labels=['c1', 'c2', 'c3', 'c4'],
        cardinalities=[1, 1, 3, 5],  # 2 binary + 2 categorical
        metadata={...}
    )
})

# Single configuration for all binary, another for all categorical
variable_distributions = GroupConfig(
    binary=Bernoulli,      # Applied to c1, c2 (cardinality=1)
    categorical=Categorical # Applied to c3, c4 (cardinality>1)
)

Benefits

  1. Scalability: Configure 312 CUB-200 attributes as easily as 5 concepts

  2. Consistency: Same settings applied to all concepts of the same type

  3. Maintainability: Change one configuration instead of hundreds

  4. Type Safety: Validates that all required types are configured

Usage with Models

from torch_concepts.nn import ConceptBottleneckModel

model = ConceptBottleneckModel(
    input_size=256,
    annotations=ann,
    variable_distributions=GroupConfig(
        binary=Bernoulli,
        categorical=Categorical
    ),
    task_names=['class_A', 'class_B']
)

Usage with Loss Functions

from torch_concepts.nn import ConceptLoss

loss = ConceptLoss(
    annotations=ann,
    fn_collection=GroupConfig(
        binary=BCEWithLogitsLoss(),
        categorical=CrossEntropyLoss()
    )
)

Usage with Metrics

from torch_concepts.nn import ConceptMetrics

metrics = ConceptMetrics(
    annotations=ann,
    fn_collection=GroupConfig(
        binary={'accuracy': BinaryAccuracy(), 'f1': BinaryF1Score()},
        categorical={'accuracy': (MulticlassAccuracy, {'average': 'macro'})}
    ),
    summary_metrics=True,
    perconcept_metrics=False
)

Special Cases

All same type: GroupConfig works even when all concepts are the same type:

# All binary
variable_distributions = GroupConfig(binary=Bernoulli)

# All categorical
variable_distributions = GroupConfig(categorical=Categorical)

Missing types: If a required type is not configured, an error is raised:

# ERROR: has categorical concepts but only binary configured
variable_distributions = GroupConfig(binary=Bernoulli)
# Will raise error when used with mixed annotations
Loss Functions

Type-Aware Loss Computation

ConceptLoss automatically routes predictions to appropriate loss functions based on concept types (binary, categorical). It handles mixed concept types seamlessly.

Quick Start

from torch_concepts.nn import ConceptLoss
from torch_concepts import GroupConfig
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss

# Configure losses by type
loss_config = GroupConfig(
    binary=BCEWithLogitsLoss(),
    categorical=CrossEntropyLoss()
)

# Create type-aware loss
loss = ConceptLoss(annotations=ann, fn_collection=loss_config)

# Use in training
predictions = model(x)
targets = batch['concepts']
loss_value = loss(predictions, targets)

Automatic Routing

ConceptLoss automatically:

  1. Splits predictions and targets by concept type

  2. Routes binary concepts to binary loss

  3. Routes categorical concepts to categorical loss

  4. Aggregates results

# Mixed predictions: 2 binary + 3-class categorical + 1 binary
predictions = torch.randn(32, 6)  # Shape: [batch, 1+1+3+1]

# Mixed targets: 2 binary + 1 categorical (class indices) + 1 binary
targets = torch.cat([
    torch.randint(0, 2, (32, 2)),  # Binary targets
    torch.randint(0, 3, (32, 1)),  # Categorical target (indices)
    torch.randint(0, 2, (32, 1))   # Binary target
], dim=1)

# Automatic routing to appropriate losses
loss_value = loss(predictions, targets)

Weighted Loss

Use WeightedConceptLoss for custom weighting:

from torch_concepts.nn import WeightedConceptLoss

loss = WeightedConceptLoss(
    annotations=ann,
    fn_collection=loss_config,
    concept_loss_weight=0.5,  # Weight for concept predictions
    task_loss_weight=1.0       # Weight for task predictions
)

Integration with Models

from torch_concepts.nn import ConceptBottleneckModel

# Lightning training mode
model = ConceptBottleneckModel(
    input_size=256,
    annotations=ann,
    task_names=['class_A', 'class_B'],
    loss=loss,  # Automatic loss computation
    optim_class=torch.optim.AdamW,
    optim_kwargs={'lr': 0.001}
)

# Manual training mode
model = ConceptBottleneckModel(
    input_size=256,
    annotations=ann,
    task_names=['class_A', 'class_B']
)

optimizer = torch.optim.Adam(model.parameters())
for batch in dataloader:
    predictions = model(batch['inputs'])
    loss_value = loss(predictions, batch['concepts'])
    loss_value.backward()
    optimizer.step()

Special Cases

Target format: Targets must match the concept space structure:

  • Binary concepts: targets are 0 or 1 (shape: [batch, n_binary])

  • Categorical concepts: targets are class indices (shape: [batch, 1] per concept)

Reduction: Losses support different reduction modes (‘mean’, ‘sum’, ‘none’):

loss_config = GroupConfig(
    binary=BCEWithLogitsLoss(reduction='mean'),
    categorical=CrossEntropyLoss(reduction='mean')
)
Metrics

Type-Aware Metric Tracking

ConceptMetrics automatically routes predictions to appropriate metrics based on concept types and provides both summary (aggregate) and per-concept tracking.

Quick Start

from torch_concepts.nn import ConceptMetrics
from torch_concepts import GroupConfig
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy

# Configure metrics by type
metrics_config = GroupConfig(
    binary={'accuracy': BinaryAccuracy()},
    categorical={'accuracy': MulticlassAccuracy}
)

# Create metrics tracker
metrics = ConceptMetrics(
    annotations=ann,
    fn_collection=metrics_config,
    summary_metrics=True,      # Aggregate by type
    perconcept_metrics=True     # Individual concept tracking
)

# During training
metrics.update(preds=predictions, target=targets, split='train')

# End of epoch
results = metrics.compute('train')
metrics.reset('train')

Summary vs Per-Concept Metrics

Summary metrics: Aggregate performance across all concepts of each type

metrics = ConceptMetrics(
    annotations=ann,
    fn_collection=metrics_config,
    summary_metrics=True,
    perconcept_metrics=False
)

results = metrics.compute('train')
# Output: {
#     'train/SUMMARY-binary_accuracy': tensor(0.8542),
#     'train/SUMMARY-categorical_accuracy': tensor(0.7621)
# }

Per-concept metrics: Track each concept individually

metrics = ConceptMetrics(
    annotations=ann,
    fn_collection=metrics_config,
    summary_metrics=False,
    perconcept_metrics=True
)

results = metrics.compute('train')
# Output: {
#     'train/is_round_accuracy': tensor(0.9000),
#     'train/is_smooth_accuracy': tensor(0.8500),
#     'train/color_accuracy': tensor(0.7621)
# }

Selective tracking: Track only specific concepts

metrics = ConceptMetrics(
    annotations=ann,
    fn_collection=metrics_config,
    summary_metrics=True,
    perconcept_metrics=['is_round', 'color']  # Only these
)

Multiple Metrics per Type

from torchmetrics.classification import BinaryF1Score, BinaryPrecision

metrics_config = GroupConfig(
    binary={
        'accuracy': BinaryAccuracy(),
        'f1': BinaryF1Score(),
        'precision': BinaryPrecision()
    },
    categorical={
        'accuracy': (MulticlassAccuracy, {'average': 'macro'}),
        'f1': (MulticlassF1Score, {'average': 'weighted'})
    }
)

Split-Aware Tracking

Maintain independent metrics for train/validation/test:

# Training loop
for batch in train_loader:
    predictions = model(batch['inputs'])
    metrics.update(pred=predictions, target=batch['concepts'], split='train')

# Validation loop
for batch in val_loader:
    predictions = model(batch['inputs'])
    metrics.update(pred=predictions, target=batch['concepts'], split='val')

# Compute separately
train_results = metrics.compute('train')
val_results = metrics.compute('val')

# Reset for next epoch
metrics.reset('train')
metrics.reset('val')

Integration with Lightning

from torch_concepts.nn import ConceptBottleneckModel

model = ConceptBottleneckModel(
    input_size=256,
    annotations=ann,
    task_names=['class_A', 'class_B'],
    loss=loss,
    metrics=metrics,  # Automatic metric tracking
    optim_class=torch.optim.AdamW,
    optim_kwargs={'lr': 0.001}
)

trainer = Trainer(max_epochs=100)
trainer.fit(model, datamodule)
# Metrics automatically logged

Special Cases

Metric configuration methods: Three ways to specify metrics

  1. Pre-instantiated: {'accuracy': BinaryAccuracy()}

  2. Class + kwargs: {'accuracy': (BinaryAccuracy, {'threshold': 0.6})}

  3. Class only: {'accuracy': BinaryAccuracy}

Target format: Targets must be in concept space:

  • Binary: 0 or 1 values

  • Categorical: class indices (0 to num_classes-1)

num_classes: For categorical metrics, num_classes is automatically set based on cardinalities

Models

Pre-Built Concept-Based Models

PyC provides ready-to-use models like ConceptBottleneckModel that support both manual PyTorch training and automatic Lightning training.

Quick Start

from torch_concepts.nn import ConceptBottleneckModel
from torch_concepts import GroupConfig
from torch.distributions import Bernoulli, Categorical

# Basic model
model = ConceptBottleneckModel(
    input_size=256,
    annotations=ann,
    variable_distributions=GroupConfig(
        binary=Bernoulli,
        categorical=Categorical
    ),
    task_names=['class_A', 'class_B']
)

Manual PyTorch Training

# Model without loss/optimizer
model = ConceptBottleneckModel(
    input_size=256,
    annotations=ann,
    task_names=['class_A', 'class_B'],
    latent_encoder_kwargs={'hidden_size': 128, 'n_layers': 2}
)

# Custom training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
loss_fn = nn.BCEWithLogitsLoss()

model.train()
for epoch in range(100):
    for batch in dataloader:
        optimizer.zero_grad()

        # Forward pass - query all concepts and tasks
        predictions = model(
            batch['inputs']['x'],
            query=['round', 'smooth', 'bright', 'class_A', 'class_B']
        )

        loss = loss_fn(predictions, batch['targets'])
        loss.backward()
        optimizer.step()

Lightning Training

from torch_concepts.nn import ConceptLoss, ConceptMetrics

# Model with loss, metrics, and optimizer
model = ConceptBottleneckModel(
    input_size=256,
    annotations=ann,
    task_names=['class_A', 'class_B'],
    loss=ConceptLoss(annotations=ann, fn_collection=loss_config),
    metrics=ConceptMetrics(
        annotations=ann,
        fn_collection=metrics_config,
        summary_metrics=True,
        perconcept_metrics=True
    ),
    optim_class=torch.optim.AdamW,
    optim_kwargs={'lr': 0.001}
)

# Automatic training
from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=100)
trainer.fit(model, datamodule)

Model Architecture

model = ConceptBottleneckModel(
    input_size=256,                    # After backbone (if any)
    annotations=ann,
    task_names=['class_A', 'class_B'],

    # Optional backbone for feature extraction
    backbone=torchvision.models.resnet18(pretrained=True),

    # Latent encoder configuration
    latent_encoder_kwargs={
        'hidden_size': 128,    # Hidden dimension
        'n_layers': 2,         # Number of layers
        'activation': 'relu',  # Activation function
        'dropout': 0.1         # Dropout rate
    },

    # Distribution configuration
    variable_distributions=GroupConfig(
        binary=Bernoulli,
        categorical=Categorical
    )
)

Querying Models

Models support flexible querying of concepts and tasks:

model.eval()
with torch.no_grad():
    # Query all variables
    all_preds = model(x, query=['round', 'smooth', 'bright', 'class_A'])
    # Shape: [batch, 4]

    # Query only concepts
    concept_preds = model(x, query=['round', 'smooth', 'bright'])
    # Shape: [batch, 3]

    # Query only tasks
    task_preds = model(x, query=['class_A', 'class_B'])
    # Shape: [batch, 2]

    # Query specific subset
    subset_preds = model(x, query=['round', 'class_A'])
    # Shape: [batch, 2]

Available Models

  • ConceptBottleneckModel: Standard CBM with joint training

  • ConceptBottleneckModel_Joint: Explicit joint training variant

  • BlackBox: Non-interpretable baseline for comparison

Special Cases

Backbone integration: For image data, use a backbone for feature extraction

import torchvision.models as models

backbone = models.resnet18(pretrained=True)
# Remove final classification layer
backbone = nn.Sequential(*list(backbone.children())[:-1])

model = ConceptBottleneckModel(
    input_size=512,  # ResNet18 output size
    annotations=ann,
    backbone=backbone,
    task_names=['class_A']
)

No latent encoder: For pre-computed features, skip the encoder

model = ConceptBottleneckModel(
    input_size=256,
    annotations=ann,
    task_names=['class_A'],
    latent_encoder_kwargs=None  # Use Identity, no encoding
)

Complete Example

Putting it all together:

import torch
from torch.distributions import Bernoulli, Categorical
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy
from pytorch_lightning import Trainer

from torch_concepts import GroupConfig
from torch_concepts.nn import (
    ConceptBottleneckModel,
    ConceptLoss,
    ConceptMetrics
)
from torch_concepts.data.datamodules import BnLearnDataModule

# Use the insurance dataset from BnLearn (mixed binary and categorical concepts)
datamodule = BnLearnDataModule(
    name='insurance',
    root='./data/insurance',
    seed=42,
    n_gen=1000,
    batch_size=32,
    val_size=0.1,
    test_size=0.2
)

# Setup the datamodule to load/generate data
datamodule.setup('fit')

# Get annotations from the dataset
ann = datamodule.annotations

print(f"Dataset concepts: {ann[1].labels}")
print(f"Concept cardinalities: {ann[1].cardinalities}")

# 2. Create loss and metrics
loss = ConceptLoss(
    annotations=ann,
    fn_collection=GroupConfig(
        binary=BCEWithLogitsLoss(),
        categorical=CrossEntropyLoss()
    )
)

metrics = ConceptMetrics(
    annotations=ann,
    fn_collection=GroupConfig(
        binary={'accuracy': BinaryAccuracy()},
        categorical={'accuracy': (MulticlassAccuracy, {'average': 'micro'})}
    ),
    summary_metrics=True,
    perconcept_metrics=True
)

# 3. Create model with all configurations
# Get input size from first batch
sample_batch = next(iter(datamodule.train_dataloader()))
# The batch['inputs'] is the tensor directly, not a nested dict
if isinstance(sample_batch['inputs'], dict):
    input_size = sample_batch['inputs']['x'].shape[1]
else:
    input_size = sample_batch['inputs'].shape[1]
print(f"Input size: {input_size}")

model = ConceptBottleneckModel(
    input_size=input_size,
    annotations=ann,
    variable_distributions=GroupConfig(
        binary=Bernoulli,
        categorical=Categorical
    ),
    task_names=[],  # No task names for this unsupervised example
    loss=loss,
    metrics=metrics,
    optim_class=torch.optim.AdamW,
    optim_kwargs={'lr': 0.001},
    latent_encoder_kwargs={'hidden_size': 64, 'n_layers': 1}
)

print(f"\nModel created successfully!")
print(f"Number of concepts: {len(ann[1].labels)}")
print(f"Binary concepts: {sum(1 for c in ann[1].cardinalities if c == 1)}")
print(f"Categorical concepts: {sum(1 for c in ann[1].cardinalities if c > 1)}")

# 4. Train with Lightning
trainer = Trainer(max_epochs=10, enable_checkpointing=False, logger=False)
trainer.fit(model, datamodule=datamodule)

# 5. Evaluate
test_results = trainer.test(model, datamodule=datamodule)

# 6. Make predictions
model.eval()
test_batch = next(iter(datamodule.test_dataloader()))
# Get the actual tensor from batch
if isinstance(test_batch['inputs'], dict):
    test_data = test_batch['inputs']['x'][:10]
else:
    test_data = test_batch['inputs'][:10]

with torch.no_grad():
    # Query first 3 concepts
    test_predictions = model(test_data, query=ann[1].labels[:3])
    print(f"\n✓ Test predictions shape: {test_predictions.shape}")
    print(f"✓ Queried concepts: {ann[1].labels[:3]}")

Next Steps