Base classes (high level)
This module provides abstract base classes for high-level model implementations.
Summary
Base Model Classes
BaseModel – Abstract base class for concept-based models.
Class Documentation
- class BaseModel(input_size: int, annotations: Annotations, variable_distributions: Mapping | None = None, backbone: str | Callable[[Tensor], Tensor] | None = None, latent_encoder: Module | None = None, latent_encoder_kwargs: Dict | None = None, **kwargs)
Abstract base class for concept-based models.
Provides common functionality for models that use a backbone for feature extraction and an encoder for latent representations. All concrete model implementations should inherit from this class.
BaseModel is flexible and supports two distinct training paradigms:
Mode 1: Standard PyTorch Training (Manual Loop)
Initialize model without loss/optimizer parameters for full manual control. You define the training loop, optimizer, and loss function externally.
Mode 2: PyTorch Lightning Training (Automatic)
Initialize model with loss, optim_class, and optim_kwargs for automatic training via PyTorch Lightning Trainer. The model inherits training logic from Learner classes.
- Parameters:
input_size (int) – Dimensionality of input features after backbone processing. If no backbone is used (backbone=None), this should match raw input dimensionality.
annotations (Annotations) – Concept annotations containing variable names, cardinalities, and optional distribution metadata. Distributions specify how the model represents each concept (e.g., Bernoulli for binary, Categorical for multi-class).
variable_distributions (Mapping, optional) – Dictionary mapping concept names to torch.distributions classes (e.g., {'c1': Bernoulli, 'c2': Categorical}). Required if annotations lack ‘distribution’ metadata. If provided, distributions are added to annotations internally. Can also be a GroupConfig object. Defaults to None.
backbone (BackboneType, optional) – Feature extraction module (e.g., ResNet, ViT) applied before latent encoder. Can be nn.Module or callable. If None, assumes inputs are pre-computed features. Defaults to None.
latent_encoder (nn.Module, optional) – Custom encoder mapping backbone outputs to latent space. If provided, latent_encoder_kwargs are passed to this constructor. If None and latent_encoder_kwargs provided, uses MLP. Defaults to None.
latent_encoder_kwargs (Dict, optional) – Arguments for latent encoder construction. Common keys:
  - ‘hidden_size’ (int): Latent dimension
  - ‘n_layers’ (int): Number of hidden layers
  - ‘activation’ (str): Activation function name
If None, uses nn.Identity (no encoding). Defaults to None.
**kwargs – Additional arguments passed to nn.Module superclass.
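The rules for building the latent encoder (custom module, MLP, or identity) can be summarised as a small decision function. This is a minimal plain-Python sketch, not the library's code: the helper `select_latent_encoder` and the string labels it returns are hypothetical, and it assumes that a custom latent_encoder takes precedence whenever one is given, so nn.Identity is only used when both latent_encoder and latent_encoder_kwargs are None.

```python
from typing import Any, Dict, Optional, Tuple


def select_latent_encoder(
    latent_encoder: Optional[Any] = None,
    latent_encoder_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper mirroring the documented encoder-selection rules.

    Returns a (kind, kwargs) pair, where kind is "custom", "mlp" or "identity".
    """
    kwargs = dict(latent_encoder_kwargs or {})
    if latent_encoder is not None:
        # Assumed precedence: a custom encoder wins; its kwargs go to its constructor.
        return "custom", kwargs
    if latent_encoder_kwargs is not None:
        # No custom encoder, but kwargs were given: an MLP is built from them.
        return "mlp", kwargs
    # Neither was given: nn.Identity, i.e. no latent encoding at all.
    return "identity", kwargs


print(select_latent_encoder())  # ('identity', {})
print(select_latent_encoder(latent_encoder_kwargs={"hidden_size": 64, "n_layers": 2}))
# ('mlp', {'hidden_size': 64, 'n_layers': 2})
```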
- concept_annotations
Axis-1 annotations with distribution metadata for each concept.
- Type:
- backbone
Feature extraction module (None if using pre-computed features).
- Type:
BackboneType or None
- latent_encoder
Encoder transforming backbone outputs to latent representations.
- Type:
nn.Module
Notes
Concept Distributions: The model needs to know which distribution to use for each concept (Bernoulli, Categorical, Normal, etc.). This can be provided in two ways:
- In annotations metadata: metadata={'c1': {'distribution': Bernoulli}}
- Via variable_distributions parameter at initialization
If distributions are in annotations, variable_distributions is not needed. If not, variable_distributions is required and will be added to annotations.
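The two ways of supplying distributions can be read as a merge rule. The sketch below is plain Python and only illustrates that rule: `resolve_distributions` is a hypothetical helper, strings such as "Bernoulli" stand in for the actual torch.distributions classes so no torch import is needed, and both the precedence (a distribution already in metadata is kept, variable_distributions only fills gaps) and the ValueError are assumptions rather than documented behaviour.

```python
from typing import Dict, List, Mapping, Optional


def resolve_distributions(
    labels: List[str],
    metadata: Mapping[str, Mapping[str, object]],
    variable_distributions: Optional[Mapping[str, object]] = None,
) -> Dict[str, Dict[str, object]]:
    """Hypothetical sketch: make sure every concept ends up with a distribution."""
    variable_distributions = variable_distributions or {}
    merged: Dict[str, Dict[str, object]] = {}
    missing = []
    for name in labels:
        # Copy so the caller's metadata is not mutated.
        entry = dict(metadata.get(name, {}))
        if "distribution" not in entry and name in variable_distributions:
            # Assumed rule: variable_distributions only fills concepts lacking one.
            entry["distribution"] = variable_distributions[name]
        if "distribution" not in entry:
            missing.append(name)
        merged[name] = entry
    if missing:
        # Neither source specified a distribution for these concepts.
        raise ValueError(f"No distribution given for: {missing}")
    return merged


dists = {"c1": "Bernoulli", "c2": "Bernoulli", "task": "Bernoulli"}
merged = resolve_distributions(["c1", "c2", "task"], {}, dists)
print(merged["task"])  # {'distribution': 'Bernoulli'}
```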
Subclasses must implement forward(); they can also override filter_output_for_loss() and filter_output_for_metrics(), which by default return their input unchanged.
For Lightning training, subclasses typically inherit from both BaseModel and a Learner class (e.g., JointLearner) via multiple inheritance.
The latent_size attribute is critical for downstream concept encoders to determine input dimensionality.
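How the latent width follows from the encoder settings can be sketched for the two built-in cases. `infer_latent_size` is a hypothetical helper, not the library's attribute logic: it relies on the documented meaning of ‘hidden_size’ as the latent dimension and on nn.Identity preserving width, assumes ‘hidden_size’ is present whenever latent_encoder_kwargs is given, and does not cover a custom latent_encoder, whose output width it cannot know.

```python
from typing import Any, Dict, Optional


def infer_latent_size(
    input_size: int,
    latent_encoder_kwargs: Optional[Dict[str, Any]] = None,
) -> int:
    """Hypothetical: latent width a downstream concept encoder would receive.

    Only the built-in cases are handled; a custom latent_encoder is out of scope.
    """
    if latent_encoder_kwargs is None:
        # nn.Identity leaves features untouched, so latent width == input width.
        return input_size
    # For the built-in MLP, 'hidden_size' is documented as the latent dimension.
    return int(latent_encoder_kwargs["hidden_size"])


print(infer_latent_size(10))  # 10
print(infer_latent_size(10, {"hidden_size": 64, "n_layers": 2}))  # 64
```

A downstream concept encoder would then size its first layer from this value rather than from input_size, which is why the attribute matters.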
Examples
Distributions specify how the model represents concepts. Provide them either in annotations metadata OR via variable_distributions parameter:
>>> import torch
>>> import torch.nn as nn
>>> from torch.distributions import Bernoulli
>>> from torch_concepts.nn import ConceptBottleneckModel
>>> from torch_concepts.annotations import AxisAnnotation, Annotations
>>>
>>> # Option 1: Distributions in annotations metadata
>>> ann = Annotations({
...     1: AxisAnnotation(
...         labels=['c1', 'c2', 'task'],
...         cardinalities=[1, 1, 1],
...         metadata={
...             'c1': {'type': 'binary', 'distribution': Bernoulli},
...             'c2': {'type': 'binary', 'distribution': Bernoulli},
...             'task': {'type': 'binary', 'distribution': Bernoulli}
...         }
...     )
... })
>>> model = ConceptBottleneckModel(
...     input_size=10,
...     annotations=ann,  # Distributions already in metadata
...     task_names=['task']
... )
>>>
>>> # Option 2: Distributions via variable_distributions parameter
>>> ann_no_dist = Annotations({
...     1: AxisAnnotation(
...         labels=['c1', 'c2', 'task'],
...         cardinalities=[1, 1, 1]
...     )
... })
>>> variable_distributions = {'c1': Bernoulli, 'c2': Bernoulli, 'task': Bernoulli}
>>> model = ConceptBottleneckModel(
...     input_size=10,
...     annotations=ann_no_dist,
...     variable_distributions=variable_distributions,  # Added here
...     task_names=['task']
... )
>>>
>>> # Manual training loop
>>> optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
>>> loss_fn = nn.BCEWithLogitsLoss()
>>> x = torch.randn(32, 10)
>>> y = torch.randint(0, 2, (32, 3)).float()
>>>
>>> for epoch in range(100):
...     optimizer.zero_grad()
...     out = model(x, query=['c1', 'c2', 'task'])
...     loss = loss_fn(out, y)
...     loss.backward()
...     optimizer.step()
See also
torch_concepts.nn.modules.high.models.cbm.ConceptBottleneckModel – Concrete CBM implementation
torch_concepts.nn.modules.high.learners.JointLearner – Lightning training logic for joint models
torch_concepts.annotations.Annotations – Concept annotation container
- property backbone: str | Callable[[Tensor], Tensor] | None
The backbone feature extractor.
Returns the backbone module used for feature extraction from raw inputs. If None, the model expects pre-computed features as inputs.
- Returns:
Backbone module (e.g., ResNet, ViT) or None if using pre-computed features.
- Return type:
BackboneType or None
- property latent_encoder: Module
The encoder mapping backbone output to latent space.
Returns the latent encoder module that transforms backbone features (or raw inputs if no backbone) into latent representations used by concept encoders.
- Returns:
Latent encoder network (MLP, custom module, or nn.Identity if no encoding).
- Return type:
nn.Module
- maybe_apply_backbone(x: Tensor, backbone_args: Mapping[str, Any] | None = None) → Tensor
Apply the backbone to x unless features are pre-computed.
- Parameters:
x (torch.Tensor) – Raw input tensor or already computed embeddings.
backbone_args (Mapping[str, Any], optional) – Extra keyword arguments forwarded to the backbone callable when it is invoked. Defaults to None.
- Returns:
Feature embeddings.
- Return type:
torch.Tensor
- Raises:
TypeError – If backbone is not None and not callable.
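The documented behaviour of maybe_apply_backbone (pass-through when there is no backbone, TypeError when the backbone is not callable, extra keyword arguments forwarded on the call) can be sketched as a free function. `apply_backbone` is a hypothetical stand-in written in plain Python so it runs without torch; the real method takes the backbone from the model rather than as an argument.

```python
from typing import Any, Callable, Mapping, Optional


def apply_backbone(
    x: Any,
    backbone: Optional[Callable[..., Any]],
    backbone_args: Optional[Mapping[str, Any]] = None,
) -> Any:
    """Hypothetical sketch of the documented maybe_apply_backbone behaviour."""
    if backbone is None:
        # No backbone: x is treated as pre-computed features and returned as-is.
        return x
    if not callable(backbone):
        raise TypeError(
            f"backbone must be callable or None, got {type(backbone).__name__}"
        )
    # Forward any extra keyword arguments to the backbone call.
    return backbone(x, **(backbone_args or {}))


feats = [0.1, 0.2]
print(apply_backbone(feats, None) is feats)  # True
print(apply_backbone(3, lambda v, scale=1: v * scale, {"scale": 2}))  # 6
```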
- filter_output_for_loss(out_concepts)
Filter model outputs before passing to loss function.
Override this method to customize what outputs are passed to the loss. Useful when your model returns auxiliary outputs that shouldn’t be included in loss computation, or vice versa.
- Parameters:
out_concepts – Model output (typically concept predictions).
- Returns:
Filtered output passed to loss function. By default, returns out_concepts unchanged.
Example
>>> def filter_output_for_loss(self, out):
...     # Only use concept predictions, ignore attention weights
...     return out['concepts']
- filter_output_for_metrics(out_concepts)
Filter model outputs before passing to metrics.
Override this method to customize what outputs are passed to metrics. Useful when your model returns auxiliary outputs that shouldn’t be included in metric computation, or vice versa.
- Parameters:
out_concepts – Model output (typically concept predictions).
- Returns:
Filtered output passed to metrics. By default, returns out_concepts unchanged.
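To show where the two hooks sit, here is a plain-Python sketch of a training step that calls them. `SketchModel`, `training_step`, and the dict-shaped output with 'concepts' and 'attention' keys are hypothetical stand-ins, not the library's BaseModel or JointLearner; the point is only that the loss and the metrics each receive a filtered view of the same forward output.

```python
from typing import Any, Callable, Dict, List, Tuple

Score = Callable[[List[float], List[float]], float]


class SketchModel:
    """Hypothetical model whose forward also returns auxiliary outputs."""

    def forward(self, x: List[float]) -> Dict[str, Any]:
        concepts = [v * 0.5 for v in x]
        attention = [1.0 / len(x)] * len(x)
        return {"concepts": concepts, "attention": attention}

    def filter_output_for_loss(self, out: Dict[str, Any]) -> List[float]:
        # Only concept predictions enter the loss; attention weights are dropped.
        return out["concepts"]

    def filter_output_for_metrics(self, out: Dict[str, Any]) -> List[float]:
        # Metrics are computed on the same concept predictions here.
        return out["concepts"]


def training_step(model: SketchModel, x: List[float], y: List[float],
                  loss_fn: Score, metric_fn: Score) -> Tuple[float, float]:
    # One forward pass; each consumer sees its own filtered view of the output.
    out = model.forward(x)
    loss = loss_fn(model.filter_output_for_loss(out), y)
    metric = metric_fn(model.filter_output_for_metrics(out), y)
    return loss, metric


def mse(pred: List[float], target: List[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def mae(pred: List[float], target: List[float]) -> float:
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)


print(training_step(SketchModel(), [2.0, 4.0], [1.0, 1.0], mse, mae))  # (0.5, 0.5)
```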