# Source code for torch_concepts.nn.modules.mid.base.model
"""
Base model class for concept-based architectures.

This module provides the base class shared by concept-based models. It stores
the common configuration of such a model: the input size, the concept
annotations, and the encoder and predictor layers.
"""
from typing import Union
import torch
from torch.nn import Module
from .....annotations import Annotations
from ...low.lazy import LazyConstructor
class BaseConstructor(torch.nn.Module):
    """
    Base class for concept-based models.

    This class stores the configuration shared by concept-based neural
    networks: the input size, the concept annotations, and the encoder and
    predictor layers. It also builds a name-to-index lookup for the concept
    labels. It does not define ``forward``. Subclasses are expected to turn the
    stored encoder/predictor into the layers they use and to implement
    ``forward`` themselves.

    Attributes:
        input_size (int): Size of the input features.
        annotations (Annotations): Concept annotations with metadata.
        labels: Concept labels, as returned by
            ``annotations.get_axis_labels(axis=1)``.
        name2id (Dict[str, int]): Mapping from each concept label to its
            position in ``labels``.
        _encoder_builder: The ``encoder`` argument, stored unchanged.
        _predictor_builder: The ``predictor`` argument, stored unchanged.

    Args:
        input_size: Size of the input features.
        annotations: Annotations object containing concept metadata.
        encoder: LazyConstructor or ``torch.nn.Module`` used to encode root
            concepts from inputs.
        predictor: LazyConstructor or ``torch.nn.Module`` used to make
            predictions from concepts.
        *args: Accepted so subclasses can forward extra arguments; ignored.
        **kwargs: Accepted so subclasses can forward extra keyword
            arguments; ignored.

    Example:
        >>> import torch
        >>> from torch_concepts import Annotations, AxisAnnotation
        >>> from torch_concepts.nn.modules.mid.base.model import BaseConstructor
        >>> from torch.distributions import RelaxedBernoulli
        >>>
        >>> # Create annotations for concepts
        >>> concept_labels = ('color', 'shape', 'size')
        >>> cardinalities = [1, 1, 1]
        >>> metadata = {
        ...     'color': {'distribution': RelaxedBernoulli},
        ...     'shape': {'distribution': RelaxedBernoulli},
        ...     'size': {'distribution': RelaxedBernoulli},
        ... }
        >>> annotations = Annotations({1: AxisAnnotation(
        ...     labels=concept_labels,
        ...     cardinalities=cardinalities,
        ...     metadata=metadata,
        ... )})
        >>>
        >>> # Create a concrete model class
        >>> class MyConceptModel(BaseConstructor):
        ...     def __init__(self, input_size, annotations, encoder, predictor):
        ...         super().__init__(input_size, annotations, encoder, predictor)
        ...         # Use the stored layers directly
        ...         self.encoder = self._encoder_builder
        ...         self.predictor = self._predictor_builder
        ...
        ...     def forward(self, x):
        ...         concepts = self.encoder(x)
        ...         return self.predictor(concepts)
        >>>
        >>> # Plain torch modules can serve as encoder and predictor
        >>> encoder = torch.nn.Linear(784, 3)
        >>> predictor = torch.nn.Linear(3, 10)
        >>>
        >>> model = MyConceptModel(
        ...     input_size=784,
        ...     annotations=annotations,
        ...     encoder=encoder,
        ...     predictor=predictor,
        ... )
        >>>
        >>> # Random input, e.g. a batch of 8 flattened 28x28 images
        >>> x = torch.randn(8, 784)
        >>> output = model(x)
        >>> output.shape
        torch.Size([8, 10])
        >>>
        >>> # Access concept labels and look up a concept index by name
        >>> list(model.labels)
        ['color', 'shape', 'size']
        >>> model.name2id['color']
        0
    """

    def __init__(
        self,
        input_size: int,
        annotations: Annotations,
        encoder: Union[LazyConstructor, Module],  # layer for root concepts
        predictor: Union[LazyConstructor, Module],
        *args,
        **kwargs,
    ):
        """Store the model configuration and index the concept labels.

        ``*args`` and ``**kwargs`` are accepted so that subclasses can pass
        extra arguments through without error. They are ignored here and are
        not forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__()
        self.input_size = input_size
        self.annotations = annotations
        # Stored under private names so subclasses decide how to expose them.
        # When these are torch.nn.Module instances, nn.Module.__setattr__
        # registers them as submodules under these attribute names.
        self._encoder_builder = encoder
        self._predictor_builder = predictor
        # Concept labels are read from annotation axis 1.
        self.labels = annotations.get_axis_labels(axis=1)
        # Position of each label; if a label repeats, its last index wins.
        self.name2id = {name: i for i, name in enumerate(self.labels)}