Base classes (mid level)

This module provides abstract base classes for building probabilistic models at the mid level.

Summary

Base Constructor Classes

BaseConstructor

Abstract base class for all concept-based models.

Class Documentation

class BaseConstructor(input_size: int, annotations: Annotations, encoder: LazyConstructor | Module, predictor: LazyConstructor | Module, *args, **kwargs)

Bases: Module

Abstract base class for all concept-based models.

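This class stores the input size and concept annotations shared by concept-based neural networks, and exposes the concept labels and a name-to-index mapping. Concrete subclasses build the encoder and predictor and implement forward().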
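A subclass is expected to call the base constructor and then implement forward(), as the Example below does. The torch-free sketch that follows illustrates that contract with Python's abc module. ConceptModelSketch and SumModel are hypothetical names; the sketch is not meant to show how BaseConstructor itself is implemented or how it enforces the contract.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence


class ConceptModelSketch(ABC):
    """Hypothetical, torch-free analogue of a concept-model base class."""

    def __init__(self, input_size: int, labels: Sequence[str]) -> None:
        # Shared state that every concrete model can rely on.
        self.input_size = input_size
        self.labels = tuple(labels)

    @abstractmethod
    def forward(self, x: List[float]) -> List[float]:
        """Concrete subclasses define how inputs flow through the concepts."""


class SumModel(ConceptModelSketch):
    """Toy subclass: emits one score per concept (the sum of the input)."""

    def forward(self, x: List[float]) -> List[float]:
        return [sum(x)] * len(self.labels)


model = SumModel(2, ("color", "shape"))
print(model.forward([1.0, 2.0]))  # [3.0, 3.0]
```

Because forward() is abstract, instantiating ConceptModelSketch directly raises TypeError; only subclasses that implement it can be created.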

Attributes:
  • input_size (int) – Size of the input features.

  • annotations (Annotations) – Concept annotations with metadata.

  • labels (List[str]) – List of concept labels.

  • name2id (Dict[str, int]) – Mapping from concept names to indices.
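The Example section reports model.name2id['color'] as 0 for the labels ('color', 'shape', 'size'), which is consistent with indices following label order. The sketch below shows one way such a mapping could be derived, assuming label order defines the index; build_name2id is a hypothetical helper, not part of the library.

```python
from typing import Dict, Sequence


def build_name2id(labels: Sequence[str]) -> Dict[str, int]:
    """Hypothetical helper: map each concept label to its position."""
    if len(set(labels)) != len(labels):
        # Duplicate labels would make the name -> index lookup ambiguous.
        raise ValueError("concept labels must be unique")
    return {name: idx for idx, name in enumerate(labels)}


labels = ("color", "shape", "size")
name2id = build_name2id(labels)
print(name2id["color"])  # 0
```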

Parameters:
  • input_size – Size of the input features.

  • annotations – Annotations object containing concept metadata.

  • encoder – LazyConstructor (or an already-built Module) that encodes root concepts from the inputs.

  • predictor – LazyConstructor (or an already-built Module) that makes predictions from the concepts.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.
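The encoder and predictor must agree on the width of the concept layer. In the Example, the encoder is torch.nn.Linear(784, 3) (input_size features in, one unit per concept out) and the predictor is torch.nn.Linear(3, 10). The sketch below computes that width from the annotation cardinalities, assuming each concept occupies cardinality contiguous units in label order. concept_layout is a hypothetical helper; the library may lay out concepts differently.

```python
from typing import Dict, Sequence, Tuple


def concept_layout(
    labels: Sequence[str], cardinalities: Sequence[int]
) -> Tuple[int, Dict[str, Tuple[int, int]]]:
    """Return the concept-vector width and each concept's [start, stop) slice.

    Hypothetical helper: assumes concepts are stored contiguously in label
    order, each taking `cardinality` units.
    """
    if len(labels) != len(cardinalities):
        raise ValueError("labels and cardinalities must have the same length")
    slices: Dict[str, Tuple[int, int]] = {}
    start = 0
    for name, card in zip(labels, cardinalities):
        slices[name] = (start, start + card)
        start += card
    return start, slices


width, slices = concept_layout(("color", "shape", "size"), [1, 1, 1])
print(width)  # 3, the encoder output size used in the Example
```

Under this assumption a three-valued 'shape' concept (cardinalities [1, 3, 1]) would widen the concept layer to 5 units, so both the encoder's output size and the predictor's input size would need to change accordingly.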

Example

>>> import torch
>>> from torch_concepts import Annotations, AxisAnnotation
>>> from torch_concepts.nn import LazyConstructor
>>> from torch_concepts.nn.modules.mid.base.model import BaseConstructor
>>> from torch.distributions import RelaxedBernoulli
>>>
>>> # Create annotations for concepts
>>> concept_labels = ('color', 'shape', 'size')
>>> cardinalities = [1, 1, 1]
>>> metadata = {
...     'color': {'distribution': RelaxedBernoulli},
...     'shape': {'distribution': RelaxedBernoulli},
...     'size': {'distribution': RelaxedBernoulli}
... }
>>> annotations = Annotations({1: AxisAnnotation(
...     labels=concept_labels,
...     cardinalities=cardinalities,
...     metadata=metadata
... )})
>>>
>>> # Create a concrete model class
>>> class MyConceptModel(BaseConstructor):
...     def __init__(self, input_size, annotations, encoder, predictor):
...         super().__init__(input_size, annotations, encoder, predictor)
...         # Build encoder and predictor
...         self.encoder = self._encoder_builder
...         self.predictor = self._predictor_builder
...
...     def forward(self, x):
...         concepts = self.encoder(x)
...         predictions = self.predictor(concepts)
...         return predictions
>>>
>>> # Create encoder and predictor modules (plain torch.nn.Module instances are accepted)
>>> encoder = torch.nn.Linear(784, 3)  # Simple encoder
>>> predictor = torch.nn.Linear(3, 10)  # Simple predictor
>>>
>>> # Instantiate model
>>> model = MyConceptModel(
...     input_size=784,
...     annotations=annotations,
...     encoder=encoder,
...     predictor=predictor
... )
>>>
>>> # Generate random input (e.g., flattened MNIST image)
>>> x = torch.randn(8, 784)  # batch_size=8, pixels=784
>>>
>>> # Forward pass
>>> output = model(x)
>>> print(output.shape)  # torch.Size([8, 10])
>>>
>>> # Access concept labels
>>> print(model.labels)  # ('color', 'shape', 'size')
>>>
>>> # Get concept index by name
>>> idx = model.name2id['color']
>>> print(idx)  # 0