Base classes (high level)

This module provides abstract base classes for high-level model implementations.

Summary

Base Model Classes

BaseModel

Abstract base class for concept-based models.

Class Documentation

class BaseModel(input_size: int, annotations: Annotations, variable_distributions: Mapping | None = None, backbone: str | Callable[[Tensor], Tensor] | None = None, latent_encoder: Module | None = None, latent_encoder_kwargs: Dict | None = None, **kwargs)

Bases: Module, ABC

Abstract base class for concept-based models.

Provides common functionality for models that use a backbone for feature extraction and an encoder for latent representations. All concrete model implementations should inherit from this class.

BaseModel is flexible and supports two distinct training paradigms:

Mode 1: Standard PyTorch Training (Manual Loop)

Initialize the model without loss or optimizer parameters for full manual control. You define the training loop, optimizer, and loss function externally.

Mode 2: PyTorch Lightning Training (Automatic)

Initialize the model with loss, optim_class, and optim_kwargs for automatic training via the PyTorch Lightning Trainer. The model inherits its training logic from Learner classes.

Parameters:
  • input_size (int) – Dimensionality of input features after backbone processing. If no backbone is used (backbone=None), this should match raw input dimensionality.

  • annotations (Annotations) – Concept annotations containing variable names, cardinalities, and optional distribution metadata. Distributions specify how the model represents each concept (e.g., Bernoulli for binary, Categorical for multi-class).

  • variable_distributions (Mapping, optional) – Dictionary mapping concept names to torch.distributions classes (e.g., {'c1': Bernoulli, 'c2': Categorical}). Required if annotations lack ‘distribution’ metadata. If provided, distributions are added to annotations internally. Can also be a GroupConfig object. Defaults to None.

  • backbone (BackboneType, optional) – Feature extraction module (e.g., ResNet, ViT) applied before the latent encoder. Can be an nn.Module or any callable. If None, inputs are assumed to be pre-computed features. Defaults to None.

  • latent_encoder (nn.Module, optional) – Custom encoder mapping backbone outputs to latent space. If provided, latent_encoder_kwargs are passed to its constructor. If None and latent_encoder_kwargs is provided, an MLP is used. Defaults to None.

  • latent_encoder_kwargs (Dict, optional) – Arguments for latent encoder construction. Common keys:

    - ‘hidden_size’ (int): Latent dimension
    - ‘n_layers’ (int): Number of hidden layers
    - ‘activation’ (str): Activation function name

    If None, uses nn.Identity (no encoding). Defaults to None.

  • **kwargs – Additional arguments passed to nn.Module superclass.
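
The latent encoder rules in the parameter descriptions can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's implementation: build_latent_encoder and the ACTIVATIONS lookup are hypothetical, and the default values for 'hidden_size', 'n_layers', and 'activation' are invented for the example. The only behavior taken from the documentation is that latent_encoder_kwargs=None yields nn.Identity and that provided kwargs yield an MLP.

```python
import torch
import torch.nn as nn

# Assumed name-to-class lookup; the library may resolve activation names differently.
ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh}


def build_latent_encoder(input_size, latent_encoder_kwargs=None):
    """Hypothetical helper, not part of torch_concepts.

    Returns the encoder and the size of its output (the latent size).
    """
    if latent_encoder_kwargs is None:
        # No kwargs: no encoding, so the latent size equals the input size.
        return nn.Identity(), input_size

    hidden_size = latent_encoder_kwargs.get("hidden_size", 64)  # assumed default
    n_layers = latent_encoder_kwargs.get("n_layers", 1)  # assumed default
    activation = ACTIVATIONS[latent_encoder_kwargs.get("activation", "relu")]

    # Stack Linear + activation pairs, tracking the running feature size.
    layers, in_features = [], input_size
    for _ in range(n_layers):
        layers += [nn.Linear(in_features, hidden_size), activation()]
        in_features = hidden_size
    return nn.Sequential(*layers), in_features


identity, identity_size = build_latent_encoder(10)
mlp, mlp_size = build_latent_encoder(10, {"hidden_size": 16, "n_layers": 2})
latent = mlp(torch.randn(4, 10))  # shape (4, 16)
```

The second return value plays the role of latent_size: downstream concept encoders would size their input layer from it.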

concept_annotations

Axis-1 annotations with distribution metadata for each concept.

Type:

AxisAnnotation

concept_names

List of concept variable names from annotations.

Type:

List[str]

backbone

Feature extraction module (None if using pre-computed features).

Type:

BackboneType or None

latent_encoder

Encoder transforming backbone outputs to latent representations.

Type:

nn.Module

latent_size

Dimensionality of latent encoder output (input to concept encoders).

Type:

int

Notes

  • Concept Distributions: The model needs to know which distribution to use for each concept (Bernoulli, Categorical, Normal, etc.). This can be provided in two ways:

    1. In annotations metadata: metadata={'c1': {'distribution': Bernoulli}}

    2. Via variable_distributions parameter at initialization

    If distributions are in annotations, variable_distributions is not needed. If not, variable_distributions is required and will be added to annotations.

  • Subclasses must implement forward(), filter_output_for_loss(), and filter_output_for_metrics() methods.

  • For Lightning training, subclasses typically inherit from both BaseModel and a Learner class (e.g., JointLearner) via multiple inheritance.

  • The latent_size attribute is critical for downstream concept encoders to determine input dimensionality.
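
The distribution lookup described in the first note can be sketched in plain Python. This is an illustration only: resolve_distributions is a hypothetical helper, the Bernoulli and Categorical classes below are placeholders standing in for the torch.distributions classes, and both the precedence (metadata wins when both sources are given) and the ValueError are assumptions rather than documented behavior.

```python
class Bernoulli:
    """Placeholder standing in for torch.distributions.Bernoulli."""


class Categorical:
    """Placeholder standing in for torch.distributions.Categorical."""


def resolve_distributions(labels, metadata=None, variable_distributions=None):
    """Hypothetical helper: return metadata in which every concept has a 'distribution'."""
    metadata = metadata or {}
    variable_distributions = variable_distributions or {}
    resolved = {}
    for name in labels:
        entry = dict(metadata.get(name, {}))  # copy so the caller's dict is untouched
        if "distribution" not in entry:
            # Fall back to the variable_distributions mapping.
            if name not in variable_distributions:
                raise ValueError(f"No distribution given for concept {name!r}")
            entry["distribution"] = variable_distributions[name]
        resolved[name] = entry
    return resolved


labels = ["c1", "c2"]

# Way 1: distributions already present in the annotation metadata.
from_metadata = resolve_distributions(
    labels,
    metadata={"c1": {"distribution": Bernoulli}, "c2": {"distribution": Categorical}},
)

# Way 2: distributions supplied separately and merged in.
from_parameter = resolve_distributions(
    labels,
    variable_distributions={"c1": Bernoulli, "c2": Categorical},
)
```

Both calls produce the same per-concept metadata, which is why variable_distributions is only needed when the annotations lack the 'distribution' entry.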

Examples

Distributions specify how the model represents concepts. Provide them either in the annotations metadata or via the variable_distributions parameter:

>>> import torch
>>> import torch.nn as nn
>>> from torch.distributions import Bernoulli
>>> from torch_concepts.nn import ConceptBottleneckModel
>>> from torch_concepts.annotations import AxisAnnotation, Annotations
>>>
>>> # Option 1: Distributions in annotations metadata
>>> ann = Annotations({
...     1: AxisAnnotation(
...         labels=['c1', 'c2', 'task'],
...         cardinalities=[1, 1, 1],
...         metadata={
...             'c1': {'type': 'binary', 'distribution': Bernoulli},
...             'c2': {'type': 'binary', 'distribution': Bernoulli},
...             'task': {'type': 'binary', 'distribution': Bernoulli}
...         }
...     )
... })
>>> model = ConceptBottleneckModel(
...     input_size=10,
...     annotations=ann,  # Distributions already in metadata
...     task_names=['task']
... )
>>>
>>> # Option 2: Distributions via variable_distributions parameter
>>> ann_no_dist = Annotations({
...     1: AxisAnnotation(
...         labels=['c1', 'c2', 'task'],
...         cardinalities=[1, 1, 1]
...     )
... })
>>> variable_distributions = {'c1': Bernoulli, 'c2': Bernoulli, 'task': Bernoulli}
>>> model = ConceptBottleneckModel(
...     input_size=10,
...     annotations=ann_no_dist,
...     variable_distributions=variable_distributions,  # Added here
...     task_names=['task']
... )
>>>
>>> # Manual training loop
>>> optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
>>> loss_fn = nn.BCEWithLogitsLoss()
>>> x = torch.randn(32, 10)
>>> y = torch.randint(0, 2, (32, 3)).float()
>>>
>>> for epoch in range(100):
...     optimizer.zero_grad()
...     out = model(x, query=['c1', 'c2', 'task'])
...     loss = loss_fn(out, y)
...     loss.backward()
...     optimizer.step()

See also

torch_concepts.nn.modules.high.models.cbm.ConceptBottleneckModel

Concrete CBM implementation

torch_concepts.nn.modules.high.learners.JointLearner

Lightning training logic for joint models

torch_concepts.annotations.Annotations

Concept annotation container

property backbone: str | Callable[[Tensor], Tensor] | None

The backbone feature extractor.

Returns the backbone module used for feature extraction from raw inputs. If None, the model expects pre-computed features as inputs.

Returns:

Backbone module (e.g., ResNet, ViT) or None if using pre-computed features.

Return type:

BackboneType or None

property latent_encoder: Module

The encoder mapping backbone output to latent space.

Returns the latent encoder module that transforms backbone features (or raw inputs if no backbone) into latent representations used by concept encoders.

Returns:

Latent encoder network (MLP, custom module, or nn.Identity if no encoding).

Return type:

nn.Module

training: bool

maybe_apply_backbone(x: Tensor, backbone_args: Mapping[str, Any] | None = None) → Tensor

Apply the backbone to x unless features are pre-computed.

Parameters:
  • x (torch.Tensor) – Raw input tensor or already computed embeddings.

  • backbone_args (Mapping[str, Any], optional) – Extra keyword arguments forwarded to the backbone callable when it is invoked. Defaults to None.

Returns:

Feature embeddings.

Return type:

torch.Tensor

Raises:

TypeError – If backbone is not None and not callable.
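
The control flow documented for this method can be sketched as a standalone function. This is a sketch under assumptions, not the library source: the real method reads the backbone from the model instance, whereas here it is passed as an argument, the error message is invented, and plain lists stand in for tensors so that the example has no dependencies.

```python
def maybe_apply_backbone(x, backbone=None, backbone_args=None):
    """Standalone sketch of the documented behavior (hypothetical signature)."""
    if backbone is None:
        # Inputs are treated as pre-computed features and returned unchanged.
        return x
    if not callable(backbone):
        raise TypeError(f"backbone must be callable or None, got {type(backbone).__name__}")
    # Forward the extra keyword arguments to the backbone.
    return backbone(x, **(backbone_args or {}))


def scale(x, factor=1.0):
    """Toy backbone: multiplies every feature by a factor."""
    return [value * factor for value in x]


features = [1.0, 2.0]
passthrough = maybe_apply_backbone(features)
scaled = maybe_apply_backbone(features, backbone=scale, backbone_args={"factor": 2.0})
```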

filter_output_for_loss(out_concepts)

Filter model outputs before passing to loss function.

Override this method to customize what outputs are passed to the loss. Useful when your model returns auxiliary outputs that shouldn’t be included in loss computation, or vice versa.

Parameters:

out_concepts – Model output (typically concept predictions).

Returns:

Filtered output passed to loss function. By default, returns out_concepts unchanged.

Example

>>> def filter_output_for_loss(self, out):
...     # Only use concept predictions, ignore attention weights
...     return out['concepts']

filter_output_for_metrics(out_concepts)

Filter model outputs before passing to metrics.

Override this method to customize what outputs are passed to metrics. Useful when your model returns auxiliary outputs that shouldn’t be included in metric computation, or vice versa.

Parameters:

out_concepts – Model output (typically concept predictions).

Returns:

Filtered output passed to metrics. By default, returns out_concepts unchanged.
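
The two filter hooks can select different parts of the same model output. The sketch below is illustrative only: SketchModel is a plain-Python stand-in rather than a real BaseModel subclass, and the output keys 'concepts', 'attention', and 'reconstruction' are hypothetical. It shows a case where an auxiliary output is kept for the loss but dropped for metrics.

```python
class SketchModel:
    """Stand-in for a BaseModel subclass; only forward() and the two hooks are shown."""

    def forward(self, x):
        # Hypothetical output: concept predictions plus two auxiliary outputs.
        return {
            "concepts": [0.9, 0.2, 0.7],
            "attention": [0.5, 0.3, 0.2],
            "reconstruction": list(x),
        }

    def filter_output_for_loss(self, out):
        # The loss uses concept predictions and the reconstruction, not attention.
        return {"concepts": out["concepts"], "reconstruction": out["reconstruction"]}

    def filter_output_for_metrics(self, out):
        # Metrics only score the concept predictions.
        return out["concepts"]


model = SketchModel()
out = model.forward([1.0, 2.0])
loss_inputs = model.filter_output_for_loss(out)
metric_inputs = model.filter_output_for_metrics(out)
```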