Base classes (low level)

This module provides abstract base classes for building concept-based neural networks at the low level. These classes define the fundamental interfaces for encoders, predictors, graph learners, and inference modules.

Summary

Base Layer Classes

BaseConceptLayer

Abstract base class for concept layers.

BaseEncoder

Abstract base class for concept encoder layers.

BasePredictor

Abstract base class for concept predictor layers.

Graph Learning Classes

BaseGraphLearner

Abstract base class for concept graph learning modules.

Inference Classes

BaseInference

Abstract base class for inference modules.

BaseIntervention

Abstract base class for intervention modules.

Class Documentation

Layer Classes

class BaseConceptLayer(out_features: int, in_features_endogenous: int | None = None, in_features: int | None = None, in_features_exogenous: int | None = None, *args, **kwargs)

Bases: ABC, Module

Abstract base class for concept layers.

This class provides the foundation for all concept-based layers, defining the interface and basic structure for concept encoders and predictors.

in_features_endogenous

Number of input logit features.

Type:

int

in_features

Number of input latent features.

Type:

int

in_features_exogenous

Number of exogenous input features.

Type:

int

out_features

Number of output features.

Type:

int

Parameters:
  • out_features – Number of output features.

  • in_features_endogenous – Number of input logit features (optional).

  • in_features – Number of input latent features (optional).

  • in_features_exogenous – Number of exogenous input features (optional).

Example

>>> import torch
>>> from torch_concepts.nn import BaseConceptLayer
>>>
>>> # Create a custom concept layer
>>> class MyConceptLayer(BaseConceptLayer):
...     def __init__(self, out_features, in_features_endogenous):
...         super().__init__(
...             out_features=out_features,
...             in_features_endogenous=in_features_endogenous
...         )
...         self.linear = torch.nn.Linear(in_features_endogenous, out_features)
...
...     def forward(self, endogenous):
...         return torch.sigmoid(self.linear(endogenous))
>>>
>>> # Example usage
>>> layer = MyConceptLayer(out_features=5, in_features_endogenous=10)
>>>
>>> # Generate random input
>>> endogenous = torch.randn(2, 10)  # batch_size=2, in_features_endogenous=10
>>>
>>> # Forward pass
>>> output = layer(endogenous)
>>> print(output.shape)  # torch.Size([2, 5])

forward(*args, **kwargs) → Tensor

Forward pass through the concept layer.

Must be implemented by subclasses.

Returns:

Output tensor.

Return type:

torch.Tensor

Raises:

NotImplementedError – This is an abstract method.

class BaseEncoder(out_features: int, in_features: int | None = None, in_features_exogenous: int | None = None)

Bases: BaseConceptLayer

Abstract base class for concept encoder layers.

Encoders transform input features (latent or exogenous variables) into concept representations.

Parameters:
  • out_features – Number of output concept features.

  • in_features – Number of input latent features (optional).

  • in_features_exogenous – Number of exogenous input features (optional).

Example

>>> import torch
>>> from torch_concepts.nn import BaseEncoder
>>>
>>> # Create a custom encoder
>>> class MyEncoder(BaseEncoder):
...     def __init__(self, out_features, in_features):
...         super().__init__(
...             out_features=out_features,
...             in_features=in_features
...         )
...         self.net = torch.nn.Sequential(
...             torch.nn.Linear(in_features, 128),
...             torch.nn.ReLU(),
...             torch.nn.Linear(128, out_features)
...         )
...
...     def forward(self, latent):
...         return self.net(latent)
>>>
>>> # Example usage
>>> encoder = MyEncoder(out_features=10, in_features=784)
>>>
>>> # Generate a random batch of inputs (e.g., flattened 28x28 MNIST images)
>>> x = torch.randn(4, 784)  # batch_size=4, pixels=784
>>>
>>> # Encode to concepts
>>> concepts = encoder(x)
>>> print(concepts.shape)  # torch.Size([4, 10])

class BasePredictor(out_features: int, in_features_endogenous: int, in_features: int | None = None, in_features_exogenous: int | None = None, in_activation: Callable = torch.sigmoid)

Bases: BaseConceptLayer

Abstract base class for concept predictor layers.

Predictors take concept representations (optionally together with latent or exogenous variables) and predict other concept representations, such as task labels.

in_activation

Activation function applied to the input concept logits (default: torch.sigmoid).

Type:

Callable

Parameters:
  • out_features – Number of output concept features.

  • in_features_endogenous – Number of input logit features.

  • in_features – Number of input latent features (optional).

  • in_features_exogenous – Number of exogenous input features (optional).

  • in_activation – Activation function applied to the input concept logits (default: torch.sigmoid).

Example

>>> import torch
>>> from torch_concepts.nn import BasePredictor
>>>
>>> # Create a custom predictor
>>> class MyPredictor(BasePredictor):
...     def __init__(self, out_features, in_features_endogenous):
...         super().__init__(
...             out_features=out_features,
...             in_features_endogenous=in_features_endogenous,
...             in_activation=torch.sigmoid
...         )
...         self.linear = torch.nn.Linear(in_features_endogenous, out_features)
...
...     def forward(self, endogenous):
...         # Apply activation to input endogenous
...         probs = self.in_activation(endogenous)
...         # Predict next concepts
...         return self.linear(probs)
>>>
>>> # Example usage
>>> predictor = MyPredictor(out_features=3, in_features_endogenous=10)
>>>
>>> # Generate random concept logits (endogenous inputs)
>>> concept_endogenous = torch.randn(4, 10)  # batch_size=4, n_concepts=10
>>>
>>> # Predict task labels from concepts
>>> task_endogenous = predictor(concept_endogenous)
>>> print(task_endogenous.shape)  # torch.Size([4, 3])
>>>
>>> # Convert task logits to probabilities
>>> task_probs = torch.sigmoid(task_endogenous)
>>> print(task_probs.shape)  # torch.Size([4, 3])

prune(mask: Tensor)

Prune the predictor by removing connections based on the given mask.

This method removes unnecessary connections in the predictor layer based on a binary mask, which can help reduce model complexity and improve interpretability.

Parameters:

mask – A binary mask indicating which connections to keep (1) or remove (0).

Raises:

NotImplementedError – Must be implemented by subclasses that support pruning.
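
A subclass that supports pruning could implement prune by keeping the mask as a buffer and re-applying it to the weights on every forward pass, so removed connections stay at zero during further training. The MaskedLinearPredictor below is a hypothetical sketch of that idea, not the library's implementation: it subclasses torch.nn.Module directly so that it runs without torch_concepts, and it assumes the mask has one entry per weight, i.e. shape (out_features, in_features_endogenous).

```python
import torch


class MaskedLinearPredictor(torch.nn.Module):
    """Hypothetical predictor whose connections can be pruned with a binary mask.

    A real implementation would subclass BasePredictor; torch.nn.Module is used
    here so the sketch runs without torch_concepts.
    """

    def __init__(self, out_features: int, in_features_endogenous: int):
        super().__init__()
        self.in_activation = torch.sigmoid
        self.linear = torch.nn.Linear(in_features_endogenous, out_features)
        # Assumed mask layout: one entry per weight, (out_features, in_features_endogenous).
        self.register_buffer("mask", torch.ones_like(self.linear.weight))

    def prune(self, mask: torch.Tensor) -> None:
        if mask.shape != self.linear.weight.shape:
            raise ValueError(
                f"mask shape {tuple(mask.shape)} does not match "
                f"weight shape {tuple(self.linear.weight.shape)}"
            )
        # Keep entries where mask == 1, remove entries where mask == 0.
        self.mask.copy_(mask.to(self.mask.dtype))
        with torch.no_grad():
            self.linear.weight.mul_(self.mask)

    def forward(self, endogenous: torch.Tensor) -> torch.Tensor:
        probs = self.in_activation(endogenous)
        # Re-apply the mask so pruned connections stay at zero during training.
        weight = self.linear.weight * self.mask
        return torch.nn.functional.linear(probs, weight, self.linear.bias)


# Example: all three task outputs stop depending on concepts 5..9.
torch.manual_seed(0)
predictor = MaskedLinearPredictor(out_features=3, in_features_endogenous=10)
mask = torch.ones(3, 10)
mask[:, 5:] = 0
predictor.prune(mask)
```

Because the mask is a registered buffer, it moves with the module across devices and is saved in its state_dict.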

Graph Learning Classes

class BaseGraphLearner(row_labels: List[str], col_labels: List[str])

Bases: Module, ABC

Abstract base class for concept graph learning modules.

This class provides the foundation for learning the structure of concept graphs from data. Subclasses implement specific graph learning algorithms such as WANDA, NOTEARS, or other structure learning methods.
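
As an illustration of what a structure learning subclass might compute, NOTEARS-style methods add a smooth acyclicity penalty to the training loss, h(W) = tr(exp(W ∘ W)) - d, where W is the d × d weighted adjacency matrix and ∘ is the elementwise product; h(W) is zero exactly when W contains no directed cycles. The function below is a hypothetical sketch of that penalty and is not part of torch_concepts; it follows the convention that rows are source concepts and columns are target concepts.

```python
import torch


def notears_acyclicity(adj: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper computing the NOTEARS penalty h(W) = tr(exp(W * W)) - d.

    h(W) is zero exactly when the weighted graph W has no directed cycles and
    is differentiable, so it can be added to a structure-learning loss.
    """
    d = adj.shape[0]
    # Squaring elementwise makes every entry non-negative before the matrix exponential.
    return torch.trace(torch.matrix_exp(adj * adj)) - d


# Rows are source concepts, columns are target concepts.
dag = torch.tensor([[0.0, 1.0, 0.0],   # c1 -> c2
                    [0.0, 0.0, 1.0],   # c2 -> c3
                    [0.0, 0.0, 0.0]])
cyclic = dag.clone()
cyclic[2, 0] = 1.0                      # add c3 -> c1, closing a cycle

print(notears_acyclicity(dag).item())     # approximately 0
print(notears_acyclicity(cyclic).item())  # positive
```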

row_labels

Labels for graph rows (source concepts).

Type:

List[str]

col_labels

Labels for graph columns (target concepts).

Type:

List[str]

n_labels

Number of concepts in the graph.

Type:

int

Parameters:
  • row_labels – List of concept names for graph rows.

  • col_labels – List of concept names for graph columns.

Raises:

AssertionError – If row_labels and col_labels have different lengths.

Example

>>> import torch
>>> from torch_concepts.nn import BaseGraphLearner
>>>
>>> class MyGraphLearner(BaseGraphLearner):
...     def __init__(self, row_labels, col_labels):
...         super().__init__(row_labels, col_labels)
...         self.graph_params = torch.nn.Parameter(
...             torch.randn(self.n_labels, self.n_labels)
...         )
...
...     def weighted_adj(self):
...         return torch.sigmoid(self.graph_params)
>>>
>>> # Create learner
>>> concepts = ['c1', 'c2', 'c3']
>>> learner = MyGraphLearner(concepts, concepts)
>>> adj_matrix = learner.weighted_adj()
>>> print(adj_matrix.shape)
torch.Size([3, 3])

abstract weighted_adj() → Tensor

Return the learned weighted adjacency matrix.

This method must be implemented by subclasses to return the current estimate of the concept graph’s adjacency matrix.

Returns:

Weighted adjacency matrix of shape (n_labels, n_labels).

Return type:

torch.Tensor

Raises:

NotImplementedError – This is an abstract method.
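
The weighted adjacency matrix is often thresholded to read off a discrete concept graph. The helper below is a hypothetical sketch (its name and the 0.5 default threshold are assumptions, not library API): it turns any (n_labels, n_labels) tensor, such as a detached output of weighted_adj(), into a list of (source, target, weight) edges using the row and column labels.

```python
import torch


def adjacency_to_edges(adj, row_labels, col_labels, threshold=0.5):
    """Hypothetical helper: return (source, target, weight) for entries above `threshold`.

    Rows are read as source concepts and columns as target concepts,
    matching the row_labels / col_labels convention of BaseGraphLearner.
    """
    if tuple(adj.shape) != (len(row_labels), len(col_labels)):
        raise ValueError("adjacency shape does not match the label lists")
    edges = []
    # nonzero() lists the (row, col) indices of every True entry in row-major order.
    for i, j in (adj > threshold).nonzero().tolist():
        edges.append((row_labels[i], col_labels[j], float(adj[i, j])))
    return edges


concepts = ["c1", "c2", "c3"]
adj = torch.tensor([[0.1, 0.9, 0.2],
                    [0.0, 0.3, 0.8],
                    [0.4, 0.0, 0.0]])
for source, target, weight in adjacency_to_edges(adj, concepts, concepts):
    print(f"{source} -> {target} ({weight:.2f})")
```

With the tensor above, only the entries 0.9 and 0.8 exceed the threshold, so the loop prints the edges c1 -> c2 and c2 -> c3.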

Inference Classes

class BaseInference

Bases: Module

Abstract base class for inference modules.

Inference modules define how to query concept-based models to obtain concept predictions, supporting various inference strategies such as forward inference, ancestral sampling, or stochastic inference.

Example

>>> import torch
>>> from torch_concepts.nn import BaseInference
>>>
>>> # Create a custom inference class
>>> class SimpleInference(BaseInference):
...     def __init__(self, model):
...         super().__init__()
...         self.model = model
...
...     def query(self, x, **kwargs):
...         # Simple forward pass through model
...         return self.model(x)
>>>
>>> # Example usage
>>> dummy_model = torch.nn.Linear(10, 5)
>>> inference = SimpleInference(dummy_model)
>>>
>>> # Generate random input
>>> x = torch.randn(2, 10)  # batch_size=2, input_features=10
>>>
>>> # Query concepts using forward method
>>> concepts = inference(x)
>>> print(concepts.shape)  # torch.Size([2, 5])
>>>
>>> # Or use query method directly
>>> concepts = inference.query(x)
>>> print(concepts.shape)  # torch.Size([2, 5])

forward(x: Tensor, *args, **kwargs) → Tensor

Forward pass delegates to the query method.

Parameters:
  • x – Input tensor.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

Queried concepts.

Return type:

torch.Tensor

abstract query(*args, **kwargs) → Tensor

Query the model to obtain concept predictions.

This method must be implemented by subclasses to define the specific inference strategy.

Parameters:
  • *args – Variable length argument list (typically includes input x).

  • **kwargs – Arbitrary keyword arguments (may include intervened concept values c).

Returns:

Queried concept predictions.

Return type:

torch.Tensor

Raises:

NotImplementedError – This is an abstract method.
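
The delegation between forward and query can be sketched with a small stand-in class. SketchInference mirrors the documented behavior (calling the module runs query), and InterventionAwareInference shows one possible query strategy that overwrites selected predicted concepts with user-supplied values. Both classes, and the c / c_mask keyword names, are hypothetical and only illustrate the pattern.

```python
from abc import ABC, abstractmethod

import torch


class SketchInference(torch.nn.Module, ABC):
    """Hypothetical stand-in for BaseInference: calling the module runs query()."""

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Delegate to the subclass-specific inference strategy.
        return self.query(x, *args, **kwargs)

    @abstractmethod
    def query(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError


class InterventionAwareInference(SketchInference):
    """Predicts concept probabilities, then overwrites the intervened ones with c."""

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def query(self, x, c=None, c_mask=None):
        probs = torch.sigmoid(self.encoder(x))
        if c is not None:
            # c_mask marks intervened concepts (True = take the value from c);
            # without a mask, every concept is replaced.
            if c_mask is None:
                c_mask = torch.ones_like(c, dtype=torch.bool)
            probs = torch.where(c_mask, c, probs)
        return probs


torch.manual_seed(0)
inference = InterventionAwareInference(torch.nn.Linear(10, 5))
x = torch.randn(2, 10)
c = torch.ones(2, 5)
c_mask = torch.zeros(2, 5, dtype=torch.bool)
c_mask[:, 0] = True  # intervene on the first concept only
out = inference(x, c=c, c_mask=c_mask)
```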

class BaseIntervention(model: Module)

Bases: BaseInference, ABC

Abstract base class for intervention modules.

Intervention modules modify concept-based models by replacing certain modules, enabling causal reasoning and what-if analysis.

This class provides a framework for implementing different intervention strategies on concept-based models.
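
One way to realize "replacing certain modules" is a context manager that swaps a named submodule for a replacement and restores the original afterwards, for example substituting the concept encoder with a module that returns ground-truth concepts. The sketch below is an assumption about how such an intervention could work, not the library's implementation; replace_submodule, FixedConcepts, and the toy ConceptModel are hypothetical names, and a BaseIntervention subclass could perform the same swap inside its query method.

```python
import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def replace_submodule(model: nn.Module, name: str, replacement: nn.Module):
    """Hypothetical helper: temporarily swap the submodule `name` for `replacement`."""
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    original = getattr(parent, child_name)
    setattr(parent, child_name, replacement)
    try:
        yield model
    finally:
        # Always restore the original module, even if the forward pass fails.
        setattr(parent, child_name, original)


class FixedConcepts(nn.Module):
    """Ignores its input and returns user-supplied concept values (e.g. ground truth)."""

    def __init__(self, concepts: torch.Tensor):
        super().__init__()
        self.register_buffer("concepts", concepts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.concepts


class ConceptModel(nn.Module):
    """Toy model: the encoder outputs concept probabilities, the predictor maps them to tasks."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(10, 5), nn.Sigmoid())
        self.predictor = nn.Linear(5, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.predictor(self.encoder(x))


torch.manual_seed(0)
model = ConceptModel()
x = torch.randn(2, 10)
true_concepts = torch.tensor([[1.0, 0.0, 1.0, 0.0, 0.0],
                              [0.0, 1.0, 1.0, 0.0, 1.0]])

with replace_submodule(model, "encoder", FixedConcepts(true_concepts)):
    y_intervened = model(x)  # predictor sees the ground-truth concepts
y_original = model(x)        # encoder has been restored
```

Restoring the original module in a finally block keeps the model unchanged even if the intervened forward pass raises an exception.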

model

The concept-based model to apply interventions to.

Type:

nn.Module

Parameters:

model – The neural network model to intervene on.

Example

>>> import torch
>>> import torch.nn as nn
>>> from torch_concepts.nn import BaseIntervention
>>>
>>> # Create a custom intervention class
>>> class CustomIntervention(BaseIntervention):
...     def query(self, module_name, **kwargs):
...         # Get the module to intervene on
...         module = self.model.get_submodule(module_name)
...         # Apply intervention logic
...         return module(**kwargs)
>>>
>>> # Create a simple concept model
>>> class ConceptModel(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.encoder = nn.Linear(10, 5)
...         self.predictor = nn.Linear(5, 3)
...
...     def forward(self, x):
...         concepts = torch.sigmoid(self.encoder(x))
...         return self.predictor(concepts)
>>>
>>> # Example usage
>>> model = ConceptModel()
>>> intervention = CustomIntervention(model)
>>>
>>> # Generate random input
>>> x = torch.randn(2, 10)  # batch_size=2, input_features=10
>>>
>>> # Query encoder module
>>> encoder_output = intervention.query('encoder', input=x)
>>> print(encoder_output.shape)  # torch.Size([2, 5])