Base classes (mid level)
This module provides abstract base classes for building probabilistic models at the mid level.
Summary
Base Constructor Classes
BaseConstructor | Abstract base class for all concept-based models.
Class Documentation
- class BaseConstructor(input_size: int, annotations: Annotations, encoder: LazyConstructor | Module, predictor: LazyConstructor | Module, *args, **kwargs)
Bases: Module
Abstract base class for all concept-based models.
This class provides the foundation for building concept-based neural networks.
- annotations
Concept annotations with metadata.
- Type: Annotations
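To make the link between the annotations and the concept names concrete, the following pure-Python sketch models a single annotation axis. AxisAnnotationSketch is a hypothetical stand-in, not the torch_concepts AxisAnnotation class. It assumes that concept indices follow the order of labels (consistent with name2id['color'] returning 0 in the example below) and that the number of concept units is the sum of the cardinalities.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple


# Hypothetical stand-in for one axis of concept annotations; the real
# AxisAnnotation class in torch_concepts may store more (or different) fields.
@dataclass(frozen=True)
class AxisAnnotationSketch:
    labels: Tuple[str, ...]
    cardinalities: Tuple[int, ...]
    metadata: Dict[str, Dict[str, Any]] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Every concept label needs exactly one cardinality entry.
        if len(self.labels) != len(self.cardinalities):
            raise ValueError("labels and cardinalities must have the same length")

    @property
    def name2id(self) -> Dict[str, int]:
        # Map each concept name to its position along the concept axis.
        return {name: i for i, name in enumerate(self.labels)}

    @property
    def total_size(self) -> int:
        # Assumption: one output unit per cardinality value, summed over concepts.
        return sum(self.cardinalities)


axis = AxisAnnotationSketch(
    labels=("color", "shape", "size"),
    cardinalities=(1, 1, 1),
    metadata={"color": {"distribution": "RelaxedBernoulli"}},
)
print(axis.name2id)     # {'color': 0, 'shape': 1, 'size': 2}
print(axis.total_size)  # 3
```

With three binary concepts of cardinality 1, the sketch yields three concept units, which matches the Linear(784, 3) encoder used in the example.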
- Parameters:
input_size – Number of input features.
annotations – Annotations object containing concept metadata.
encoder – LazyConstructor or already-built Module that encodes the root concepts from the inputs.
predictor – LazyConstructor or already-built Module that makes predictions from the concepts.
*args – Variable-length positional arguments.
**kwargs – Arbitrary keyword arguments.
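Because encoder and predictor accept either a LazyConstructor or an already-built Module, a base class has to turn each argument into a concrete layer. The sketch below illustrates one way to do that in plain Python. Layer, LazySketch and resolve are hypothetical names, and the assumption that a lazy constructor is built from input and output sizes is not confirmed by this page.

```python
from typing import Callable, Union


class Layer:
    """Hypothetical stand-in for an already-built layer such as torch.nn.Linear."""

    def __init__(self, in_features: int, out_features: int) -> None:
        self.in_features = in_features
        self.out_features = out_features

    def __repr__(self) -> str:
        return f"Layer({self.in_features}, {self.out_features})"


class LazySketch:
    """Hypothetical stand-in for LazyConstructor: defers building until sizes are known."""

    def __init__(self, factory: Callable[[int, int], Layer]) -> None:
        self.factory = factory

    def build(self, in_features: int, out_features: int) -> Layer:
        return self.factory(in_features, out_features)


def resolve(component: Union[LazySketch, Layer], in_features: int, out_features: int) -> Layer:
    # A lazy component is built on demand; an already-built one is used as-is.
    if isinstance(component, LazySketch):
        return component.build(in_features, out_features)
    return component


n_concepts = 3
encoder = resolve(LazySketch(Layer), in_features=784, out_features=n_concepts)
predictor = resolve(Layer(3, 10), in_features=n_concepts, out_features=10)
print(encoder, predictor)  # Layer(784, 3) Layer(3, 10)
```

Deferring construction lets the base class derive sizes (for example, the number of concepts) from the annotations instead of requiring the caller to compute them.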
Example
>>> import torch
>>> from torch_concepts import Annotations, AxisAnnotation
>>> from torch_concepts.nn import LazyConstructor
>>> from torch_concepts.nn.modules.mid.base.model import BaseConstructor
>>> from torch.distributions import RelaxedBernoulli
>>>
>>> # Create annotations for concepts
>>> concept_labels = ('color', 'shape', 'size')
>>> cardinalities = [1, 1, 1]
>>> metadata = {
...     'color': {'distribution': RelaxedBernoulli},
...     'shape': {'distribution': RelaxedBernoulli},
...     'size': {'distribution': RelaxedBernoulli}
... }
>>> annotations = Annotations({1: AxisAnnotation(
...     labels=concept_labels,
...     cardinalities=cardinalities,
...     metadata=metadata
... )})
>>>
>>> # Create a concrete model class
>>> class MyConceptModel(BaseConstructor):
...     def __init__(self, input_size, annotations, encoder, predictor):
...         super().__init__(input_size, annotations, encoder, predictor)
...         # Build encoder and predictor
...         self.encoder = self._encoder_builder
...         self.predictor = self._predictor_builder
...
...     def forward(self, x):
...         concepts = self.encoder(x)
...         predictions = self.predictor(concepts)
...         return predictions
>>>
>>> # Create the encoder and predictor modules
>>> encoder = torch.nn.Linear(784, 3)  # 784 input features -> 3 concepts
>>> predictor = torch.nn.Linear(3, 10)  # 3 concepts -> 10 outputs
>>>
>>> # Instantiate the model
>>> model = MyConceptModel(
...     input_size=784,
...     annotations=annotations,
...     encoder=encoder,
...     predictor=predictor
... )
>>>
>>> # Generate random input (e.g., flattened MNIST images)
>>> x = torch.randn(8, 784)  # batch_size=8, pixels=784
>>>
>>> # Forward pass
>>> output = model(x)
>>> print(output.shape)
torch.Size([8, 10])
>>>
>>> # Access concept labels
>>> print(model.labels)
('color', 'shape', 'size')
>>>
>>> # Get concept index by name
>>> idx = model.name2id['color']
>>> print(idx)
0
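The example above requires torch and torch_concepts. The framework-free sketch below isolates the abstract-base-class pattern it relies on: the base class stores the shared components and declares forward as abstract, so it cannot be instantiated directly, while a subclass supplies the actual encoder-to-predictor flow. ConceptModelSketch and TwoStageModel are hypothetical names, and the sketch uses abc.ABC in place of torch.nn.Module.

```python
from abc import ABC, abstractmethod
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


class ConceptModelSketch(ABC):
    """Hypothetical, framework-free analogue of an abstract concept-model base."""

    def __init__(
        self,
        labels: Sequence[str],
        encoder: Callable[[Vector], Vector],
        predictor: Callable[[Vector], Vector],
    ) -> None:
        # Shared state every concrete model needs.
        self.labels: Tuple[str, ...] = tuple(labels)
        self.encoder = encoder
        self.predictor = predictor

    @abstractmethod
    def forward(self, x: Vector) -> Vector:
        """Map an input vector to task outputs; subclasses must implement this."""

    def __call__(self, x: Vector) -> Vector:
        # Mirror the Module convention: calling the model runs forward.
        return self.forward(x)


class TwoStageModel(ConceptModelSketch):
    def forward(self, x: Vector) -> Vector:
        concepts = self.encoder(x)       # inputs -> concept scores
        return self.predictor(concepts)  # concept scores -> task outputs


def encode(x: Vector) -> Vector:
    # Toy encoder: three concept scores from a 4-dimensional input.
    return [x[0] + x[1], x[2], x[3]]


def predict(c: Vector) -> Vector:
    # Toy predictor: two task outputs from the three concept scores.
    return [sum(c), c[0] - c[2]]


model = TwoStageModel(("color", "shape", "size"), encode, predict)
print(model([1.0, 2.0, 3.0, 4.0]))  # [10.0, -1.0]
```

torch.nn.Module does not use abc to block instantiation; a Module subclass that never defines forward can be created, and calling it raises NotImplementedError instead.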