Concept Predictors
This module provides predictor implementations that map from concepts to target predictions.
Summary
Predictor Classes
LinearCC | Linear concept predictor.
MixCUC | Concept exogenous predictor with mixture of concept activations and exogenous features.
HyperLinearCUC | Hypernetwork-based linear predictor for concept-based models.
CallableCC | A predictor that applies a custom callable function to concept representations.
Class Documentation
- class LinearCC(in_features_endogenous: int, out_features: int, in_activation: Callable = torch.sigmoid)
Bases: BasePredictor
Linear concept predictor.
This predictor maps input concept endogenous (logits) to output concept endogenous: it applies the input activation and then a linear layer.
- in_activation
Activation function for inputs (default: sigmoid).
- Type:
Callable
- predictor
The prediction network.
- Type:
nn.Sequential
- Parameters:
in_features_endogenous – Number of input logit features.
out_features – Number of output concept features.
in_activation – Activation function to apply to input endogenous (default: torch.sigmoid).
Example
>>> import torch
>>> from torch_concepts.nn import LinearCC
>>>
>>> # Create predictor
>>> predictor = LinearCC(
...     in_features_endogenous=10,
...     out_features=5
... )
>>>
>>> # Forward pass
>>> in_endogenous = torch.randn(2, 10)  # batch_size=2, in_features=10
>>> out_endogenous = predictor(in_endogenous)
>>> print(out_endogenous.shape)
torch.Size([2, 5])
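The computation described for this class can be pictured without any framework. The following is a rough sketch for a single sample, not the library's implementation: it assumes `in_activation` is applied to the input logits before one linear map, and the helper name `linear_cc_forward` and all weight values are hypothetical.

```python
import math

def sigmoid(x: float) -> float:
    """Logistic function, the documented default for in_activation."""
    return 1.0 / (1.0 + math.exp(-x))

def linear_cc_forward(logits, weight, bias):
    """Hypothetical sketch: activate the input logits, then apply a linear map.

    logits: list of in_features floats
    weight: out_features rows, each with in_features floats
    bias:   list of out_features floats
    """
    probs = [sigmoid(v) for v in logits]
    return [sum(w * p for w, p in zip(row, probs)) + b
            for row, b in zip(weight, bias)]

# Two input concepts, one output; zero logits activate to 0.5 each.
out = linear_cc_forward([0.0, 0.0], weight=[[1.0, 1.0]], bias=[0.5])
print(out)  # [1.5]
```

In the actual class the weights are learned parameters and the inputs are batched tensors; the sketch only shows the order of operations.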
References
Koh et al. “Concept Bottleneck Models”, ICML 2020. https://arxiv.org/pdf/2007.04612
- forward(endogenous: Tensor) -> Tensor
Forward pass through the predictor.
- Parameters:
endogenous – Input endogenous of shape (batch_size, in_features_endogenous).
- Returns:
Predicted concept probabilities of shape (batch_size, out_features).
- Return type:
Tensor
- prune(mask: Tensor)
Prune input features based on a binary mask.
Removes input features where mask is False/0, reducing model complexity.
- Parameters:
mask – Binary mask of shape (in_features_endogenous,) indicating which features to keep (True/1) or remove (False/0).
Example
>>> import torch
>>> from torch_concepts.nn import LinearCC
>>>
>>> predictor = LinearCC(in_features_endogenous=10, out_features=5)
>>>
>>> # Prune first 3 features
>>> mask = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=torch.bool)
>>> predictor.prune(mask)
>>>
>>> # Now only accepts 7 input features
>>> endogenous = torch.randn(2, 7)
>>> probs = predictor(endogenous)
>>> print(probs.shape)
torch.Size([2, 5])
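One way to picture the effect of pruning is as dropping columns of the linear layer's weight matrix, so that the layer accepts fewer inputs while keeping the same number of outputs. The sketch below works under that assumption with plain lists; `prune_columns` is a hypothetical helper and is not part of torch_concepts.

```python
def prune_columns(weight, mask):
    """Keep only the input columns whose mask entry is truthy.

    weight: out_features rows, each with in_features values
    mask:   in_features booleans (True/1 keeps the feature)
    """
    if any(len(row) != len(mask) for row in weight):
        raise ValueError("mask length must match the number of input features")
    return [[w for w, keep in zip(row, mask) if keep] for row in weight]

weight = [[1, 2, 3, 4],
          [5, 6, 7, 8]]
mask = [False, True, True, False]

pruned = prune_columns(weight, mask)
print(pruned)  # [[2, 3], [6, 7]]
```

The number of rows (outputs) is unchanged, while the number of columns drops to the count of kept features.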
- class MixCUC(in_features_endogenous: int, in_features_exogenous: int, out_features: int, in_activation: Callable = torch.sigmoid, cardinalities: List[int] | None = None)
Bases: BasePredictor
Concept exogenous predictor with mixture of concept activations and exogenous features.
This predictor implements the task predictor of the Concept Embedding Model (CEM), which combines concept activations with learned exogenous features using a mixture operation.
Main reference: “Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off” (Espinosa Zarlenga et al., NeurIPS 2022).
- predictor
Linear predictor module.
- Type:
nn.Module
- Parameters:
in_features_endogenous – Number of input concept endogenous.
in_features_exogenous – Number of exogenous features (must be even).
out_features – Number of output task features.
in_activation – Activation function for concept endogenous (default: sigmoid).
cardinalities – List of concept group cardinalities (optional).
Example
>>> import torch
>>> from torch_concepts.nn import MixCUC
>>>
>>> # Create predictor with 10 concepts, 20 exogenous dims, 3 tasks
>>> predictor = MixCUC(
...     in_features_endogenous=10,
...     in_features_exogenous=10,  # Must be half of exogenous latent size when no cardinalities are provided
...     out_features=3,
...     in_activation=torch.sigmoid
... )
>>>
>>> # Generate random inputs
>>> concept_endogenous = torch.randn(4, 10)  # batch_size=4, n_concepts=10
>>> exogenous = torch.randn(4, 10, 20)  # (batch, n_concepts, emb_size)
>>>
>>> # Forward pass
>>> task_endogenous = predictor(endogenous=concept_endogenous, exogenous=exogenous)
>>> print(task_endogenous.shape)  # torch.Size([4, 3])
>>>
>>> # With concept groups (e.g., color has 3 values, shape has 4, etc.)
>>> predictor_grouped = MixCUC(
...     in_features_endogenous=10,
...     in_features_exogenous=20,  # Must be equal to exogenous latent size when cardinalities are provided
...     out_features=3,
...     cardinalities=[3, 4, 3]  # 3 groups summing to 10
... )
>>>
>>> # Forward pass with grouped concepts
>>> task_endogenous = predictor_grouped(endogenous=concept_endogenous, exogenous=exogenous)
>>> print(task_endogenous.shape)  # torch.Size([4, 3])
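The mixture operation in the ungrouped case can be sketched for a single sample as follows. This is an illustration in the spirit of the CEM paper rather than the library's code: it assumes each concept's exogenous vector has even length and is read as an "active" half followed by an "inactive" half, which are blended by the concept probability. The function name `mix_concept_embeddings` and the half ordering are assumptions.

```python
import math

def mix_concept_embeddings(logits, embeddings):
    """Hypothetical sketch of a CEM-style mixture for one sample.

    logits:     n_concepts floats (concept logits)
    embeddings: n_concepts vectors, each of even length 2 * m, read as
                [active half | inactive half] (an assumed layout)
    Returns n_concepts mixed vectors of length m.
    """
    mixed = []
    for logit, emb in zip(logits, embeddings):
        if len(emb) % 2 != 0:
            raise ValueError("embedding size must be even")
        m = len(emb) // 2
        p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid, the documented default
        active, inactive = emb[:m], emb[m:]
        mixed.append([p * a + (1.0 - p) * i for a, i in zip(active, inactive)])
    return mixed

# One concept with logit 0 (p = 0.5) and a 4-dim embedding split into 2 + 2.
mixed = mix_concept_embeddings([0.0], [[2.0, 4.0, 0.0, 0.0]])
print(mixed)  # [[1.0, 2.0]]
```

In the CEM formulation the mixed vectors are then concatenated and passed to a linear task predictor, which is why the mixed size is half of the exogenous latent size.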
References
Espinosa Zarlenga et al. “Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off”, NeurIPS 2022. https://arxiv.org/abs/2209.09056
- forward(endogenous: Tensor, exogenous: Tensor) -> Tensor
Forward pass through the predictor.
- Parameters:
endogenous – Concept endogenous of shape (batch_size, n_concepts).
exogenous – Concept exogenous of shape (batch_size, n_concepts, emb_size).
- Returns:
Task predictions of shape (batch_size, out_features).
- Return type:
Tensor
- class HyperLinearCUC(in_features_endogenous: int, in_features_exogenous: int, embedding_size: int, in_activation: Callable = <function HyperLinearCUC.<lambda>>, use_bias: bool = True, init_bias_mean: float = 0.0, init_bias_std: float = 0.01, min_std: float = 1e-06)
Bases: BasePredictor
Hypernetwork-based linear predictor for concept-based models.
This predictor uses a hypernetwork to generate per-sample weights from exogenous features, enabling sample-adaptive predictions. It also supports stochastic biases with learnable mean and standard deviation.
- hypernet
Hypernetwork that generates weights.
- Type:
nn.Module
- Parameters:
in_features_endogenous – Number of input concept endogenous.
in_features_exogenous – Number of exogenous input features.
embedding_size – Hidden dimension of hypernetwork.
in_activation – Activation function for concepts (default: identity).
use_bias – Whether to add stochastic bias (default: True).
init_bias_mean – Initial mean for bias distribution (default: 0.0).
init_bias_std – Initial std for bias distribution (default: 0.01).
min_std – Minimum std to ensure stability (default: 1e-6).
Example
>>> import torch
>>> from torch_concepts.nn import HyperLinearCUC
>>>
>>> # Create hypernetwork predictor
>>> predictor = HyperLinearCUC(
...     in_features_endogenous=10,  # 10 concepts
...     in_features_exogenous=128,  # 128-dim context features
...     embedding_size=64,  # Hidden dim of hypernet
...     use_bias=True
... )
>>>
>>> # Generate random inputs
>>> concept_endogenous = torch.randn(4, 10)  # batch_size=4, n_concepts=10
>>> exogenous = torch.randn(4, 3, 128)  # batch_size=4, n_tasks=3, exogenous_dim=128
>>>
>>> # Forward pass - generates per-sample weights via hypernetwork
>>> task_endogenous = predictor(endogenous=concept_endogenous, exogenous=exogenous)
>>> print(task_endogenous.shape)  # torch.Size([4, 3])
>>>
>>> # The hypernetwork generates different weights for each sample
>>> # This enables sample-adaptive predictions
>>>
>>> # Example without bias
>>> predictor_no_bias = HyperLinearCUC(
...     in_features_endogenous=10,
...     in_features_exogenous=128,
...     embedding_size=64,
...     use_bias=False
... )
>>>
>>> task_endogenous = predictor_no_bias(endogenous=concept_endogenous, exogenous=exogenous)
>>> print(task_endogenous.shape)  # torch.Size([4, 3])
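The idea of per-sample weights can be sketched without a framework. The code below is an illustration only, assuming one context vector per task and a hypernetwork that turns each context vector into one weight per concept; `hyper_linear_forward` and `toy_hypernet` are hypothetical names, and the stochastic bias is left out.

```python
def hyper_linear_forward(concepts, exogenous, hypernet):
    """Hypothetical sketch of a hypernetwork-driven linear predictor.

    concepts:  batch of concept vectors, shape (batch, n_concepts)
    exogenous: batch of per-task context vectors, shape (batch, n_tasks, exog_dim)
    hypernet:  callable mapping one context vector to n_concepts weights
    Returns predictions of shape (batch, n_tasks).
    """
    predictions = []
    for sample_concepts, sample_context in zip(concepts, exogenous):
        row = []
        for task_context in sample_context:
            weights = hypernet(task_context)  # weights depend on the sample
            row.append(sum(w * c for w, c in zip(weights, sample_concepts)))
        predictions.append(row)
    return predictions

# Toy stand-in for the hypernetwork: scale a fixed pattern by the context sum.
def toy_hypernet(context):
    scale = sum(context)
    return [scale, -scale]

concepts = [[1.0, 0.0], [0.0, 1.0]]       # batch=2, n_concepts=2
exogenous = [[[1.0, 1.0]], [[1.0, 1.0]]]  # batch=2, n_tasks=1, exog_dim=2
preds = hyper_linear_forward(concepts, exogenous, toy_hypernet)
print(preds)  # [[2.0], [-2.0]]
```

Because the weights are produced from the exogenous input rather than stored in the layer, the number of outputs follows the number of context vectors supplied per sample.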
References
Debot et al. “Interpretable Concept-Based Memory Reasoning”, NeurIPS 2024. https://arxiv.org/abs/2407.15527
- forward(endogenous: Tensor, exogenous: Tensor) -> Tensor
Forward pass through hypernetwork predictor.
- Parameters:
endogenous – Concept endogenous of shape (batch_size, n_concepts).
exogenous – Exogenous features of shape (batch_size, out_features, exog_dim), with one context vector per output, as in the class example above.
- Returns:
Task predictions of shape (batch_size, out_features).
- Return type:
Tensor
- class CallableCC(func: Callable, in_activation: Callable = <function CallableCC.<lambda>>, use_bias: bool = True, init_bias_mean: float = 0.0, init_bias_std: float = 0.01, min_std: float = 1e-06)
Bases: BasePredictor
A predictor that applies a custom callable function to concept representations.
This predictor allows flexible task prediction by accepting any callable function that operates on concept representations. It optionally includes learnable stochastic bias parameters (mean and standard deviation) that are added to the output using the reparameterization trick for gradient-based learning.
The module can be used to write custom layers for standard Structural Causal Models (SCMs).
- Parameters:
func – Callable function that takes concept probabilities and returns task predictions. Should accept a tensor of shape (batch_size, n_concepts) and return a tensor of shape (batch_size, out_features).
in_activation – Activation function to apply to input endogenous before passing to func. Default is identity (lambda x: x).
use_bias – Whether to add learnable stochastic bias to the output. Default is True.
init_bias_mean – Initial value for the bias mean parameter. Default is 0.0.
init_bias_std – Initial value for the bias standard deviation. Default is 0.01.
min_std – Minimum standard deviation floor for numerical stability. Default is 1e-6.
Examples
>>> import torch
>>> from torch_concepts.nn import CallableCC
>>>
>>> # Generate sample data
>>> batch_size = 32
>>> n_concepts = 3
>>> endogenous = torch.randn(batch_size, n_concepts)
>>>
>>> # Define a polynomial function with fixed weights for 3 inputs, 2 outputs
>>> def quadratic_predictor(probs):
...     c0, c1, c2 = probs[:, 0:1], probs[:, 1:2], probs[:, 2:3]
...     output1 = 0.5*c0**2 + 1.0*c1**2 + 1.5*c2
...     output2 = 2.0*c0 - 1.0*c1**2 + 0.5*c2**3
...     return torch.cat([output1, output2], dim=1)
>>>
>>> predictor = CallableCC(
...     func=quadratic_predictor,
...     use_bias=True
... )
>>> predictions = predictor(endogenous)
>>> print(predictions.shape)  # torch.Size([32, 2])
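The stochastic bias mentioned in the class description relies on the reparameterization trick: noise is drawn independently of the parameters and then shifted and scaled, so the sampled bias stays a differentiable function of the mean and standard deviation. The sketch below illustrates this for one sample with plain Python. How the library parameterizes the standard deviation and whether the bias is shared across outputs are not stated here, so the `max(std, min_std)` floor, the shared scalar bias, and the names `sample_bias` and `callable_forward` are all assumptions.

```python
import random

def sample_bias(mean, std, min_std=1e-6, rng=random):
    """Hypothetical sketch: draw a bias via the reparameterization trick.

    The noise eps is sampled independently of the parameters, so the result
    is a differentiable function of mean and std.
    """
    eps = rng.gauss(0.0, 1.0)
    return mean + max(std, min_std) * eps

def callable_forward(func, inputs, mean=0.0, std=0.01, rng=random):
    """Apply func to one sample, then add one sampled bias to every output."""
    bias = sample_bias(mean, std, rng=rng)
    return [y + bias for y in func(inputs)]

rng = random.Random(0)  # seeded for repeatability
outputs = callable_forward(lambda c: [c[0] + c[1], c[0] * c[1]], [1.0, 2.0], rng=rng)
print(len(outputs))  # 2
```

With a small standard deviation such as the default 0.01, the outputs stay close to what `func` returns on its own.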
References
Pearl, J. “Causality”, Cambridge University Press (2009).
- forward(endogenous: Tensor, *args, **kwargs) -> Tensor
Forward pass through the concept layer.
Must be implemented by subclasses.
- Returns:
Output tensor.
- Return type:
Tensor
- Raises:
NotImplementedError – This is an abstract method.