Base classes (low level)
This module provides abstract base classes for building concept-based neural networks at the low level. These classes define the fundamental interfaces for encoders, predictors, graph learners, and inference modules.
Summary
Base Layer Classes
BaseConceptLayer | Abstract base class for concept layers.
BaseEncoder | Abstract base class for concept encoder layers.
BasePredictor | Abstract base class for concept predictor layers.
Graph Learning Classes
BaseGraphLearner | Abstract base class for concept graph learning modules.
Inference Classes
BaseInference | Abstract base class for inference modules.
BaseIntervention | Abstract base class for intervention modules.
Class Documentation
Layer Classes
- class BaseConceptLayer(out_features: int, in_features_endogenous: int | None = None, in_features: int | None = None, in_features_exogenous: int | None = None, *args, **kwargs)
Abstract base class for concept layers.
This class provides the foundation for all concept-based layers, defining the interface and basic structure for concept encoders and predictors.
- Parameters:
out_features – Number of output features.
in_features_endogenous – Number of input logit features (optional).
in_features – Number of input latent features (optional).
in_features_exogenous – Number of exogenous input features (optional).
Example
>>> import torch
>>> from torch_concepts.nn import BaseConceptLayer
>>> # Create a custom concept layer
>>> class MyConceptLayer(BaseConceptLayer):
...     def __init__(self, out_features, in_features_endogenous):
...         super().__init__(
...             out_features=out_features,
...             in_features_endogenous=in_features_endogenous
...         )
...         self.linear = torch.nn.Linear(in_features_endogenous, out_features)
...
...     def forward(self, endogenous):
...         return torch.sigmoid(self.linear(endogenous))
>>> # Example usage
>>> layer = MyConceptLayer(out_features=5, in_features_endogenous=10)
>>> # Generate random input
>>> endogenous = torch.randn(2, 10)  # batch_size=2, in_features_endogenous=10
>>> # Forward pass
>>> output = layer(endogenous)
>>> print(output.shape)
torch.Size([2, 5])
- forward(*args, **kwargs) → Tensor
Forward pass through the concept layer.
Must be implemented by subclasses.
- Returns:
Output tensor.
- Return type:
Tensor
- Raises:
NotImplementedError – This is an abstract method.
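The constructor accepts up to three optional input widths (endogenous, latent, exogenous). One plausible way a subclass could use them is sketched below: it assumes every provided input is a flat `(batch, features)` tensor and that all provided inputs are concatenated in a fixed order before a single linear map. `ConcatConceptLayer` is a hypothetical name, and it subclasses plain `torch.nn.Module` so the sketch runs without torch_concepts; the library's own layers may combine inputs differently.

```python
from typing import Optional

import torch


class ConcatConceptLayer(torch.nn.Module):
    """Hypothetical layer mirroring BaseConceptLayer's constructor arguments.

    Assumption: each provided input is a (batch, features) tensor, and the
    provided inputs are concatenated (endogenous, latent, exogenous order).
    """

    def __init__(
        self,
        out_features: int,
        in_features_endogenous: Optional[int] = None,
        in_features: Optional[int] = None,
        in_features_exogenous: Optional[int] = None,
    ):
        super().__init__()
        # Keep only the input widths that were actually declared.
        widths = [
            w
            for w in (in_features_endogenous, in_features, in_features_exogenous)
            if w is not None
        ]
        if not widths:
            raise ValueError("at least one input size must be given")
        self.linear = torch.nn.Linear(sum(widths), out_features)

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        # Inputs must be passed in the same order as their declared widths.
        return self.linear(torch.cat(inputs, dim=-1))


layer = ConcatConceptLayer(out_features=5, in_features_endogenous=10, in_features=32)
out = layer(torch.randn(2, 10), torch.randn(2, 32))
print(out.shape)  # torch.Size([2, 5])
```

Sizing one linear layer by the sum of the declared widths keeps the sketch minimal; a real layer might instead give each input its own projection.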
- class BaseEncoder(out_features: int, in_features: int | None = None, in_features_exogenous: int | None = None)
Bases: BaseConceptLayer
Abstract base class for concept encoder layers.
Encoders transform input features (latent or exogenous variables) into concept representations.
- Parameters:
out_features – Number of output concept features.
in_features – Number of input latent features (optional).
in_features_exogenous – Number of exogenous input features (optional).
Example
>>> import torch
>>> from torch_concepts.nn import BaseEncoder
>>> # Create a custom encoder
>>> class MyEncoder(BaseEncoder):
...     def __init__(self, out_features, in_features):
...         super().__init__(
...             out_features=out_features,
...             in_features=in_features
...         )
...         self.net = torch.nn.Sequential(
...             torch.nn.Linear(in_features, 128),
...             torch.nn.ReLU(),
...             torch.nn.Linear(128, out_features)
...         )
...
...     def forward(self, latent):
...         return self.net(latent)
>>> # Example usage
>>> encoder = MyEncoder(out_features=10, in_features=784)
>>> # Generate a random latent input (e.g., a flattened 28x28 MNIST image)
>>> x = torch.randn(4, 784)  # batch_size=4, in_features=784
>>> # Encode to concepts
>>> concepts = encoder(x)
>>> print(concepts.shape)
torch.Size([4, 10])
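Encoders can also start from exogenous variables instead of a latent vector. The sketch below assumes, purely for illustration, that exogenous inputs arrive as one embedding per concept with shape `(batch, n_concepts, emb_dim)`, and scores every embedding with one shared linear layer to obtain one logit per concept. `ExogenousScorer` is a hypothetical name and the tensor layout is an assumption; torch_concepts may organize exogenous features differently.

```python
import torch


class ExogenousScorer(torch.nn.Module):
    """Hypothetical encoder: per-concept exogenous embeddings -> concept logits.

    Assumption: exogenous has shape (batch, n_concepts, in_features_exogenous).
    """

    def __init__(self, in_features_exogenous: int):
        super().__init__()
        # A single scorer shared by all concepts.
        self.score = torch.nn.Linear(in_features_exogenous, 1)

    def forward(self, exogenous: torch.Tensor) -> torch.Tensor:
        # (batch, n_concepts, emb_dim) -> (batch, n_concepts, 1) -> (batch, n_concepts)
        return self.score(exogenous).squeeze(-1)


encoder = ExogenousScorer(in_features_exogenous=16)
exogenous = torch.randn(4, 10, 16)  # batch_size=4, n_concepts=10, emb_dim=16
logits = encoder(exogenous)
print(logits.shape)  # torch.Size([4, 10])
```

Because the scorer is shared, its parameter count does not depend on the number of concepts; a per-concept scorer would be the natural alternative when concepts need separate read-outs.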
- class BasePredictor(out_features: int, in_features_endogenous: int, in_features: int | None = None, in_features_exogenous: int | None = None, in_activation: Callable = torch.sigmoid)
Bases: BaseConceptLayer
Abstract base class for concept predictor layers.
Predictors take concept representations (plus latent or exogenous variables) and predict other concept representations.
- in_activation¶
Activation function for input (default: sigmoid).
- Type:
Callable
- Parameters:
out_features – Number of output concept features.
in_features_endogenous – Number of input logit features.
in_features – Number of input latent features (optional).
in_features_exogenous – Number of exogenous input features (optional).
in_activation – Activation function for input (default: torch.sigmoid).
Example
>>> import torch
>>> from torch_concepts.nn import BasePredictor
>>> # Create a custom predictor
>>> class MyPredictor(BasePredictor):
...     def __init__(self, out_features, in_features_endogenous):
...         super().__init__(
...             out_features=out_features,
...             in_features_endogenous=in_features_endogenous,
...             in_activation=torch.sigmoid
...         )
...         self.linear = torch.nn.Linear(in_features_endogenous, out_features)
...
...     def forward(self, endogenous):
...         # Turn the input concept logits into probabilities
...         probs = self.in_activation(endogenous)
...         # Predict the next concepts (here: task logits)
...         return self.linear(probs)
>>> # Example usage
>>> predictor = MyPredictor(out_features=3, in_features_endogenous=10)
>>> # Generate random concept logits (endogenous inputs)
>>> concept_endogenous = torch.randn(4, 10)  # batch_size=4, n_concepts=10
>>> # Predict task logits from concepts
>>> task_endogenous = predictor(concept_endogenous)
>>> print(task_endogenous.shape)
torch.Size([4, 3])
>>> # Convert task logits to probabilities
>>> task_probs = torch.sigmoid(task_endogenous)
>>> print(task_probs.shape)
torch.Size([4, 3])
- prune(mask: Tensor)
Prune the predictor by removing connections based on the given mask.
This method removes unnecessary connections in the predictor layer based on a binary mask, which can help reduce model complexity and improve interpretability.
- Parameters:
mask – A binary mask indicating which connections to keep (1) or remove (0).
- Raises:
NotImplementedError – Must be implemented by subclasses that support pruning.
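One common way to realize mask-based pruning for a linear predictor is to store the binary mask as a buffer and multiply it into the weight matrix on every forward pass. The sketch below does this; `MaskedLinearPredictor` is a hypothetical class built on plain `torch.nn.Module`, and it assumes the mask has the same `(out_features, in_features_endogenous)` shape as the weight. It is not the pruning logic of any concrete torch_concepts predictor.

```python
import torch
import torch.nn.functional as F


class MaskedLinearPredictor(torch.nn.Module):
    """Hypothetical linear predictor that supports mask-based pruning."""

    def __init__(self, out_features: int, in_features_endogenous: int):
        super().__init__()
        self.linear = torch.nn.Linear(in_features_endogenous, out_features)
        # Start with every connection kept (mask of ones).
        self.register_buffer("mask", torch.ones_like(self.linear.weight))

    def prune(self, mask: torch.Tensor) -> None:
        # Assumption: mask has shape (out_features, in_features_endogenous),
        # with 1 = keep the connection and 0 = remove it.
        if mask.shape != self.linear.weight.shape:
            raise ValueError("mask must have the same shape as the weight matrix")
        self.mask.copy_(mask)

    def forward(self, endogenous: torch.Tensor) -> torch.Tensor:
        # Same default input activation as BasePredictor (sigmoid).
        probs = torch.sigmoid(endogenous)
        # Masked weights: pruned connections contribute nothing,
        # whatever values are stored in self.linear.weight.
        return F.linear(probs, self.linear.weight * self.mask, self.linear.bias)


predictor = MaskedLinearPredictor(out_features=3, in_features_endogenous=4)

# Keep only the connections from concepts 0 and 1 to every output.
mask = torch.zeros(3, 4)
mask[:, :2] = 1.0
predictor.prune(mask)

# Changing the pruned concepts (2 and 3) no longer affects the output.
x = torch.randn(5, 4)
x_changed = x.clone()
x_changed[:, 2:] = torch.randn(5, 2)
print(torch.allclose(predictor(x), predictor(x_changed)))  # True
```

Applying the mask inside `forward` (rather than zeroing weights once) also zeroes the gradients of pruned weights, so training cannot silently reconnect them.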
Graph Learning Classes
- class BaseGraphLearner(row_labels: List[str], col_labels: List[str])
Abstract base class for concept graph learning modules.
This class provides the foundation for learning the structure of concept graphs from data. Subclasses implement specific graph learning algorithms such as WANDA, NOTEARS, or other structure learning methods.
- Parameters:
row_labels – List of concept names for graph rows.
col_labels – List of concept names for graph columns.
- Raises:
AssertionError – If row_labels and col_labels have different lengths.
Example
>>> import torch
>>> from torch_concepts.nn import BaseGraphLearner
>>> class MyGraphLearner(BaseGraphLearner):
...     def __init__(self, row_labels, col_labels):
...         super().__init__(row_labels, col_labels)
...         self.graph_params = torch.nn.Parameter(
...             torch.randn(self.n_labels, self.n_labels)
...         )
...
...     def weighted_adj(self):
...         return torch.sigmoid(self.graph_params)
>>> # Create learner
>>> concepts = ['c1', 'c2', 'c3']
>>> learner = MyGraphLearner(concepts, concepts)
>>> adj_matrix = learner.weighted_adj()
>>> print(adj_matrix.shape)
torch.Size([3, 3])
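Since the rows and columns of the adjacency matrix are labelled with concept names, a learned graph can be read back as a list of named edges. The helper below is a hypothetical sketch, not part of torch_concepts: it assumes entry `(i, j)` weighs an edge from `row_labels[i]` to `col_labels[j]`, and keeps only edges whose weight exceeds a chosen threshold. In practice `adj` would be the output of a `weighted_adj()` call; here it is a hand-written tensor.

```python
from typing import List, Tuple

import torch


def labeled_edges(
    adj: torch.Tensor,
    row_labels: List[str],
    col_labels: List[str],
    threshold: float = 0.5,
) -> List[Tuple[str, str, float]]:
    """Hypothetical helper: return (source, target, weight) for strong edges.

    Assumption: adj[i, j] is the weight of the edge row_labels[i] -> col_labels[j].
    """
    edges = []
    # nonzero() yields the (i, j) pairs where the condition holds, in row-major order.
    for i, j in (adj > threshold).nonzero().tolist():
        edges.append((row_labels[i], col_labels[j], adj[i, j].item()))
    return edges


concepts = ["c1", "c2", "c3"]
adj = torch.tensor([[0.0, 0.9, 0.1],
                    [0.0, 0.0, 0.7],
                    [0.2, 0.0, 0.0]])
edges = labeled_edges(adj, concepts, concepts)
print([(src, dst) for src, dst, _ in edges])  # [('c1', 'c2'), ('c2', 'c3')]
```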
- abstract weighted_adj() → Tensor
Return the learned weighted adjacency matrix.
This method must be implemented by subclasses to return the current estimate of the concept graph’s adjacency matrix.
- Returns:
Weighted adjacency matrix of shape (n_labels, n_labels).
- Return type:
Tensor
- Raises:
NotImplementedError – This is an abstract method.
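The class description names NOTEARS as one structure learning method. NOTEARS keeps a weighted adjacency matrix W acyclic by driving the penalty h(W) = tr(exp(W ∘ W)) − d to zero, where ∘ is the elementwise product and d is the number of nodes (here, `n_labels`); h(W) = 0 exactly when the graph has no directed cycles. The snippet below is a standalone illustration of that published formula applied to a `weighted_adj()`-style matrix, not torch_concepts' own implementation.

```python
import torch


def notears_acyclicity(adj: torch.Tensor) -> torch.Tensor:
    """NOTEARS acyclicity function h(W) = tr(exp(W * W)) - d.

    W * W is the elementwise square, so every entry is non-negative. The
    result is 0 exactly when the weighted graph has no directed cycles and
    positive otherwise; torch.matrix_exp is differentiable, so h can be
    added to a training loss as a penalty.
    """
    d = adj.shape[0]
    return torch.trace(torch.matrix_exp(adj * adj)) - d


# A DAG c1 -> c2 -> c3, and the same graph with a back edge c3 -> c1.
dag = torch.tensor([[0.0, 0.8, 0.0],
                    [0.0, 0.0, 0.5],
                    [0.0, 0.0, 0.0]])
cyclic = dag.clone()
cyclic[2, 0] = 0.6

h_dag = notears_acyclicity(dag).item()
h_cyc = notears_acyclicity(cyclic).item()
print(abs(h_dag) < 1e-5, h_cyc > 0)  # True True
```

The penalty works because tr((W ∘ W)^k) is positive only if some directed cycle of length k exists, and the matrix exponential sums these traces over all k.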
Inference Classes
- class BaseInference
Bases: torch.nn.Module
Abstract base class for inference modules.
Inference modules define how to query concept-based models to obtain concept predictions, supporting various inference strategies such as forward inference, ancestral sampling, or stochastic inference.
Example
>>> import torch
>>> from torch_concepts.nn import BaseInference
>>> # Create a custom inference class
>>> class SimpleInference(BaseInference):
...     def __init__(self, model):
...         super().__init__()
...         self.model = model
...
...     def query(self, x, **kwargs):
...         # Simple forward pass through the model
...         return self.model(x)
>>> # Example usage
>>> dummy_model = torch.nn.Linear(10, 5)
>>> inference = SimpleInference(dummy_model)
>>> # Generate random input
>>> x = torch.randn(2, 10)  # batch_size=2, input_features=10
>>> # Query concepts by calling the module (forward delegates to query)
>>> concepts = inference(x)
>>> print(concepts.shape)
torch.Size([2, 5])
>>> # Or call the query method directly
>>> concepts = inference.query(x)
>>> print(concepts.shape)
torch.Size([2, 5])
- forward(x: Tensor, *args, **kwargs) → Tensor
Forward pass delegates to the query method.
- Parameters:
x – Input tensor.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns:
Queried concepts.
- Return type:
Tensor
- abstract query(*args, **kwargs) → Tensor
Query the model to obtain concept predictions.
This method must be implemented by subclasses to define the specific inference strategy.
- Parameters:
*args – Variable length argument list (typically includes input x).
**kwargs – Arbitrary keyword arguments (may include intervention c).
- Returns:
Queried concept predictions.
- Return type:
Tensor
- Raises:
NotImplementedError – This is an abstract method.
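Ancestral sampling is one of the inference strategies listed for this class. As a rough sketch of what such a `query` could compute, the function below visits binary concepts in topological order, computes each concept's probability from its already-sampled parents, and draws a Bernoulli sample. The graph representation (a dict of parent lists), the logistic conditional, and the name `ancestral_sample` are all assumptions made for this illustration.

```python
from typing import Dict, List, Optional

import torch


def ancestral_sample(
    parents: Dict[str, List[str]],
    order: List[str],
    weights: Dict[str, List[float]],
    bias: Dict[str, float],
    n_samples: int,
    generator: Optional[torch.Generator] = None,
) -> Dict[str, torch.Tensor]:
    """Hypothetical ancestral sampler for binary concepts.

    Assumptions: `order` is a topological order of the concept graph, and
    p(c = 1 | parents) = sigmoid(bias[c] + sum_k weights[c][k] * parent_k).
    """
    samples: Dict[str, torch.Tensor] = {}
    for name in order:
        logit = torch.full((n_samples,), float(bias[name]))
        # Parents were sampled earlier because `order` is topological.
        for parent, w in zip(parents.get(name, []), weights.get(name, [])):
            logit = logit + w * samples[parent]
        samples[name] = torch.bernoulli(torch.sigmoid(logit), generator=generator)
    return samples


parents = {"c1": [], "c2": ["c1"], "y": ["c1", "c2"]}
weights = {"c2": [4.0], "y": [3.0, 3.0]}
bias = {"c1": 0.0, "c2": -2.0, "y": -4.5}
gen = torch.Generator().manual_seed(0)
draws = ancestral_sample(parents, ["c1", "c2", "y"], weights, bias, n_samples=1000, generator=gen)
for name, values in draws.items():
    print(name, values.mean().item())  # empirical P(name = 1)
```

Each draw is a full joint sample of all concepts, so averaging over `n_samples` estimates marginals; a deterministic forward inference would instead pass probabilities rather than samples to the children.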
- class BaseIntervention(model: Module)
Bases: BaseInference, ABC
Abstract base class for intervention modules.
Intervention modules modify concept-based models by replacing certain modules, enabling causal reasoning and what-if analysis.
This class provides a framework for implementing different intervention strategies on concept-based models.
- model¶
The concept-based model to apply interventions to.
- Type:
nn.Module
- Parameters:
model – The neural network model to intervene on.
Example
>>> import torch
>>> import torch.nn as nn
>>> from torch_concepts.nn import BaseIntervention
>>> # Create a custom intervention class
>>> class CustomIntervention(BaseIntervention):
...     def query(self, module_name, **kwargs):
...         # Get the module to intervene on
...         module = self.model.get_submodule(module_name)
...         # Apply intervention logic
...         return module(**kwargs)
>>> # Create a simple concept model
>>> class ConceptModel(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.encoder = nn.Linear(10, 5)
...         self.predictor = nn.Linear(5, 3)
...
...     def forward(self, x):
...         concepts = torch.sigmoid(self.encoder(x))
...         return self.predictor(concepts)
>>> # Example usage
>>> model = ConceptModel()
>>> intervention = CustomIntervention(model)
>>> # Generate random input
>>> x = torch.randn(2, 10)  # batch_size=2, input_features=10
>>> # Query the encoder module (nn.Linear.forward names its argument `input`)
>>> encoder_output = intervention.query('encoder', input=x)
>>> print(encoder_output.shape)
torch.Size([2, 5])
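Interventions are described as replacing certain modules. A minimal sketch of that idea, using only standard PyTorch, is a context manager that temporarily swaps a named child module for one that returns fixed concept values (a do-style intervention) and restores the original afterwards. `replace_module` and `ConstantConcepts` are hypothetical names, the sketch assumes the target is a direct child of the model, and the fixed values are concept logits because this model applies a sigmoid after its encoder.

```python
import contextlib

import torch
import torch.nn as nn


class ConstantConcepts(nn.Module):
    """Hypothetical stand-in that ignores its input and returns fixed values."""

    def __init__(self, values: torch.Tensor):
        super().__init__()
        self.register_buffer("values", values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast the fixed concept values to the batch size of x.
        return self.values.expand(x.shape[0], -1)


@contextlib.contextmanager
def replace_module(model: nn.Module, name: str, new_module: nn.Module):
    """Temporarily replace the direct child `name` of `model`."""
    original = getattr(model, name)
    setattr(model, name, new_module)
    try:
        yield model
    finally:
        # Always restore the original module, even if an error occurred.
        setattr(model, name, original)


class ConceptModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(10, 5)   # input -> concept logits
        self.predictor = nn.Linear(5, 3)  # concept probabilities -> task logits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.predictor(torch.sigmoid(self.encoder(x)))


model = ConceptModel()
x = torch.randn(2, 10)
do_logits = torch.tensor([10.0, -10.0, 10.0, -10.0, 10.0])  # fix every concept
with replace_module(model, "encoder", ConstantConcepts(do_logits)):
    y_do = model(x)
    # Every sample now shares the same concepts, hence the same task logits.
    print(torch.allclose(y_do[0], y_do[1]))  # True
print(isinstance(model.encoder, nn.Linear))  # True: original encoder restored
```

Restoring the module in a `finally` clause keeps the intervention scoped, so the unmodified model is available again as soon as the what-if query is done.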