Functional API
This module provides functional operations for concept-based computations.
Summary
Concept Operations
grouped_concept_exogenous_mixture: Vectorized version of grouped concept exogenous mixture.
selection_eval: Evaluate concept selection by computing weighted predictions.
confidence_selection: Selects concepts with confidence above a selected threshold.
soft_select: Soft selection function, a special activation that rescales a network's outputs so that, if they are uniformly distributed, only half of them are selected.
Linear and Logic Operations
linear_equation_eval: Function to evaluate a set of linear equations with concept predictions.
linear_equation_expl: Extract linear equations from decoded equation embeddings as strings. Parameters: concept_weights, equation embeddings with shape (batch_size, memory_size, n_concepts, n_tasks); bias, bias term to add to the linear models, with shape (batch_size, memory_size, n_tasks); concept_names, concept and task names (if the bias is included, the concept names should include the bias name).
logic_rule_eval: Use concept weights to make predictions based on logic rules.
logic_memory_reconstruction: Reconstruct tasks based on concept reconstructions, ground truth concepts and ground truth tasks.
Extracts rules from rule concept weights as strings.
Evaluation Metrics
completeness_score: Calculate the completeness score for the given predictions and true labels.
intervention_score: Compute the effect of concept interventions on downstream task predictions.
cace_score: Compute the Average Causal Effect (ACE), also known as the Causal Concept Effect (CaCE) score.
residual_concept_causal_effect: Compute the residual concept causal effect between two concepts.
Calibration and Selection
selective_calibration: Selects concepts based on confidence scores and target coverage.
Graph Utilities
Model Utilities
Return a new nn.Linear where inputs (dim=0) or outputs (dim=1) have been pruned according to mask.
Function Documentation
Concept Operations
- grouped_concept_exogenous_mixture(c_emb: Tensor, c_scores: Tensor, groups: list[int]) → Tensor
Vectorized version of grouped concept exogenous mixture.
Extends to handle grouped concepts where some groups may contain multiple related concepts. Adapted from “Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off” (Espinosa Zarlenga et al., 2022).
- Parameters:
c_emb – Concept exogenous of shape (B, n_concepts, emb_size).
c_scores – Concept scores of shape (B, sum(groups)).
groups – List of group sizes (e.g., [3, 4] for two groups).
- Returns:
Mixed exogenous of shape (B, len(groups), emb_size // 2).
- Return type:
Tensor
- Raises:
AssertionError – If group sizes don’t sum to n_concepts.
AssertionError – If exogenous dimension is not even.
References
Espinosa Zarlenga et al. “Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off”, NeurIPS 2022. https://arxiv.org/abs/2209.09056
Example
>>> import torch
>>> from torch_concepts.nn.functional import grouped_concept_exogenous_mixture
>>>
>>> # 10 concepts in 3 groups: [3, 4, 3]
>>> # Embedding size = 20 (must be even)
>>> batch_size = 4
>>> n_concepts = 10
>>> emb_size = 20
>>> groups = [3, 4, 3]
>>>
>>> # Generate random latent and scores
>>> c_emb = torch.randn(batch_size, n_concepts, emb_size)
>>> c_scores = torch.rand(batch_size, n_concepts)  # Probabilities
>>>
>>> # Apply grouped mixture
>>> mixed = grouped_concept_exogenous_mixture(c_emb, c_scores, groups)
>>> print(mixed.shape)  # torch.Size([4, 3, 10])
>>> # Output shape: (batch_size, n_groups, emb_size // 2)
>>>
>>> # Singleton groups use two-half mixture
>>> # Multi-concept groups use weighted average of base exogenous
- selection_eval(selection_weights: Tensor, *predictions: Tensor) → Tensor
Evaluate concept selection by computing weighted predictions.
- Parameters:
selection_weights – Weights for selecting between predictions.
*predictions – Variable number of prediction tensors to combine.
- Returns:
Weighted combination of predictions.
- Return type:
Tensor
- confidence_selection(c_confidence: Tensor, theta: Tensor) → Tensor
Selects concepts with confidence above a selected threshold.
- Parameters:
c_confidence – Concept confidence scores.
theta – Threshold to select confident predictions.
- Returns:
Mask selecting confident predictions.
- Return type:
Tensor
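The thresholding described for confidence_selection can be sketched as an element-wise comparison. This is a minimal illustration and not the library's implementation: the helper name `confidence_mask` is hypothetical, the strict `>` comparison and the per-concept broadcasting of `theta` are assumptions, and NumPy arrays stand in for tensors so the snippet runs on its own.

```python
import numpy as np

def confidence_mask(c_confidence, theta):
    """Hypothetical stand-in: True where confidence exceeds the threshold."""
    # Assumption: strict comparison, with theta broadcast against c_confidence.
    return np.asarray(c_confidence) > np.asarray(theta)

c_confidence = np.array([[0.9, 0.4, 0.7],
                         [0.2, 0.8, 0.6]])
theta = np.array([0.5, 0.5, 0.65])  # one threshold per concept (assumed layout)
mask = confidence_mask(c_confidence, theta)
print(mask)
```

The resulting Boolean mask has the same shape as `c_confidence` and can be used to keep only the confident concept predictions.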
- soft_select(values, temperature, dim=1) → Tensor
Soft selection function: a special activation that rescales the network's outputs so that, if they are uniformly distributed, only half of them are selected. A higher temperature selects more concepts; a lower temperature selects fewer.
- Parameters:
values – Output of the network.
temperature – Temperature for the softmax function; any real value in (-inf, +inf).
dim – Dimension along which to apply the softmax function. Default is 1.
- Returns:
Soft selection scores.
- Return type:
Tensor
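One formulation consistent with the soft_select description is a sigmoid applied to log-softmax scores after subtracting their temperature-scaled mean. The sketch below is an assumption about how such a function could look, not the confirmed implementation; the name `soft_select_sketch` is hypothetical and NumPy arrays stand in for tensors.

```python
import numpy as np

def soft_select_sketch(values, temperature, dim=1):
    """Hypothetical formulation: sigmoid of mean-centred log-softmax scores."""
    values = np.asarray(values, dtype=float)
    # Numerically stable log-softmax along `dim`.
    shifted = values - values.max(axis=dim, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=dim, keepdims=True))
    # Compare each score with the temperature-scaled mean score.
    centred = log_probs - temperature * log_probs.mean(axis=dim, keepdims=True)
    return 1.0 / (1.0 + np.exp(-centred))

values = np.array([[2.0, 0.0, -2.0, 0.0]])
low = soft_select_sketch(values, temperature=0.5)
high = soft_select_sketch(values, temperature=2.0)
print(low, high)
```

Because the log-softmax scores are negative, raising the temperature raises every selection score in this formulation, which matches the stated behaviour that a higher temperature selects more concepts. With temperature 1 and all-equal inputs, every score equals 0.5.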
Linear and Logic Operations
- linear_equation_eval(concept_weights: Tensor, c_pred: Tensor, bias: Tensor | None = None) → Tensor
Function to evaluate a set of linear equations with concept predictions. In this case we have one equation (concept_weights) for each sample in the batch.
- Parameters:
concept_weights – Parameters representing the weights of multiple linear models with shape (batch_size, memory_size, n_concepts, n_classes).
c_pred – Concept predictions with shape (batch_size, n_concepts).
bias – Bias term to add to the linear models (batch_size, memory_size, n_classes).
- Returns:
Predictions made by the linear models with shape (batch_size, n_classes, memory_size).
- Return type:
Tensor
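The shapes documented for linear_equation_eval suggest a batched contraction over the concept axis. The following is a minimal sketch under that assumption, not the library's code: the name `linear_equation_eval_sketch` is hypothetical, the einsum layout is inferred from the documented input and output shapes, and NumPy arrays stand in for tensors.

```python
import numpy as np

def linear_equation_eval_sketch(concept_weights, c_pred, bias=None):
    """Hypothetical sketch: evaluate one set of linear models per sample."""
    # concept_weights: (B, M, C, T), c_pred: (B, C) -> output: (B, T, M)
    y = np.einsum("bmct,bc->btm", concept_weights, c_pred)
    if bias is not None:
        # bias: (B, M, T) -> (B, T, M) so it lines up with the output.
        y = y + np.transpose(bias, (0, 2, 1))
    return y

rng = np.random.default_rng(0)
B, M, C, T = 2, 3, 4, 5  # batch, memory size, concepts, classes
concept_weights = rng.normal(size=(B, M, C, T))
c_pred = rng.uniform(size=(B, C))
bias = rng.normal(size=(B, M, T))
y = linear_equation_eval_sketch(concept_weights, c_pred, bias)
print(y.shape)
```

Each entry `y[b, t, m]` is the dot product between the concept predictions of sample `b` and the weights of equation `m` for class `t`, plus the matching bias term.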
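- linear_equation_expl(concept_weights: Tensor, bias: Tensor | None = None, concept_names: Dict[int, List[str]] | None = None) → List[Dict[str, Dict[str, str]]]
Extract linear equations from decoded equation embeddings as strings.
- Parameters:
concept_weights – Equation embeddings with shape (batch_size, memory_size, n_concepts, n_tasks).
bias – Bias term to add to the linear models, with shape (batch_size, memory_size, n_tasks).
concept_names – Concept and task names. If the bias is included, the concept names should include the bias name.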
- logic_rule_eval(concept_weights: Tensor, c_pred: Tensor, memory_idxs: Tensor | None = None, semantic=<torch_concepts.nn.modules.low.semantic.CMRSemantic object>) → Tensor
Use concept weights to make predictions based on logic rules.
- Parameters:
concept_weights – Concept weights with shape (batch_size, memory_size, n_concepts, n_tasks, n_roles) with n_roles=3.
c_pred – Concept predictions with shape (batch_size, n_concepts).
memory_idxs – Indices of rules to evaluate with shape (batch_size, n_tasks). Default is None (evaluate all).
semantic – Semantic function to use for rule evaluation.
- Returns:
Rule predictions with shape (batch_size, n_tasks, memory_size).
- Return type:
Tensor
- logic_memory_reconstruction(concept_weights: Tensor, c_true: Tensor, y_true: Tensor) → Tensor
Reconstruct tasks based on concept reconstructions, ground truth concepts and ground truth tasks.
- Parameters:
concept_weights – Concept reconstructions with shape (batch_size, memory_size, n_concepts, n_tasks).
c_true – Concept ground truth with shape (batch_size, n_concepts).
y_true – Task ground truth with shape (batch_size, n_tasks).
- Returns:
Reconstructed tasks with shape (batch_size, n_tasks, memory_size).
- Return type:
Tensor
Evaluation Metrics
- completeness_score(y_true, y_pred_blackbox, y_pred_whitebox, scorer=roc_auc_score, average='macro')
Calculate the completeness score for the given predictions and true labels. Main reference: “On Completeness-aware Concept-Based Explanations in Deep Neural Networks”
- Parameters:
y_true (torch.Tensor) – True labels.
y_pred_blackbox (torch.Tensor) – Predictions from the blackbox model.
y_pred_whitebox (torch.Tensor) – Predictions from the whitebox model.
scorer (function) – Scoring function to evaluate predictions. Default is roc_auc_score.
average (str) – Type of averaging to use. Default is ‘macro’.
- Returns:
Completeness score.
- Return type:
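In the cited paper, completeness is defined as the white-box model's gain over a random baseline, divided by the black-box model's gain over the same baseline. The sketch below applies that ratio to two already-computed scores. The helper name `completeness_from_scores` and the baseline of 0.5 (the ROC AUC of a random classifier) are assumptions, and whether the library normalises in exactly this way is not stated in this page.

```python
def completeness_from_scores(whitebox_score, blackbox_score, random_score=0.5):
    """Hypothetical helper: normalised ratio of gains over a random baseline."""
    denom = blackbox_score - random_score
    if denom == 0:
        raise ValueError("black-box score equals the random baseline")
    return (whitebox_score - random_score) / denom

# Example: ROC AUC of 0.85 for the concept-based model, 0.90 for the black box.
score = completeness_from_scores(whitebox_score=0.85, blackbox_score=0.90)
print(score)
```

A value close to 1 indicates that the concept-based (white-box) predictions recover most of the black-box model's performance.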
- intervention_score(y_predictor: torch.nn.Module, c_pred: torch.Tensor, c_true: torch.Tensor, y_true: torch.Tensor, intervention_groups: List[List[int]], activation: Callable = torch.sigmoid, scorer: Callable = roc_auc_score, average: str = 'macro', auc: bool = True) → float | List[float]
Compute the effect of concept interventions on downstream task predictions.
Given a set of intervention groups, the intervention score measures the effectiveness of each intervention group on the model’s task predictions.
Main reference: “Concept Bottleneck Models”
- Parameters:
y_predictor (torch.nn.Module) – Model that predicts downstream task labels.
c_pred (torch.Tensor) – Predicted concept values.
c_true (torch.Tensor) – Ground truth concept values.
y_true (torch.Tensor) – Ground truth task labels.
intervention_groups (List[List[int]]) – List of intervention groups.
activation (Callable) – Activation function to apply to the model’s predictions. Default is torch.sigmoid.
scorer (Callable) – Scoring function to evaluate predictions. Default is roc_auc_score.
average (str) – Type of averaging to use. Default is ‘macro’.
auc (bool) – Whether to return the average score across all intervention groups. Default is True.
- Returns:
The intervention effectiveness for each intervention group, or the average score across all groups.
- Return type:
float | List[float]
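The core step of a concept intervention, as introduced in the cited work, is to overwrite the predicted values of a group of concepts with their ground truth before running the task predictor. The sketch below shows only that step. The helper name `intervene` is hypothetical, NumPy arrays stand in for tensors, and each group is applied independently here, whereas this page does not state whether the library applies groups independently or cumulatively.

```python
import numpy as np

def intervene(c_pred, c_true, group):
    """Hypothetical helper: overwrite the concepts in `group` with ground truth."""
    c_int = np.array(c_pred, dtype=float, copy=True)  # leave c_pred untouched
    c_int[:, group] = np.asarray(c_true, dtype=float)[:, group]
    return c_int

c_pred = np.array([[0.2, 0.9, 0.4],
                   [0.7, 0.1, 0.6]])
c_true = np.array([[1.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0]])
intervention_groups = [[0], [1, 2]]

# One intervened copy of the concept predictions per group (assumed independent).
intervened = [intervene(c_pred, c_true, g) for g in intervention_groups]
print(intervened[1])
```

Each intervened array would then be passed to the task predictor and scored against `y_true` with the chosen scorer.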
- cace_score(y_pred_c0, y_pred_c1)
Compute the Average Causal Effect (ACE) also known as the Causal Concept Effect (CaCE) score.
The ACE/CaCE score measures the causal effect of a concept on the predictions of a model. It is computed as the absolute difference between the expected predictions when the concept is inactive (c0) and active (c1).
Main reference: “Explaining Classifiers with Causal Concept Effect (CaCE)”
- Parameters:
y_pred_c0 (torch.Tensor) – Predictions of the model when the concept is inactive. Shape: (batch_size, num_classes).
y_pred_c1 (torch.Tensor) – Predictions of the model when the concept is active. Shape: (batch_size, num_classes).
- Returns:
The ACE/CaCE score for each class. Shape: (num_classes,).
- Return type:
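The definition above (absolute difference between the expected predictions with the concept inactive and active) can be sketched directly. This is an illustration and not the library's code: the name `cace_sketch` is hypothetical, the expectation is approximated by the batch mean, and NumPy arrays stand in for tensors.

```python
import numpy as np

def cace_sketch(y_pred_c0, y_pred_c1):
    """Hypothetical sketch: per-class |E[y | c active] - E[y | c inactive]|."""
    # Expectations are approximated by averaging over the batch axis.
    return np.abs(np.mean(y_pred_c1, axis=0) - np.mean(y_pred_c0, axis=0))

# Predictions with shape (batch_size, num_classes).
y_pred_c0 = np.array([[0.2, 0.8],
                      [0.4, 0.6]])
y_pred_c1 = np.array([[0.7, 0.3],
                      [0.9, 0.1]])
cace = cace_sketch(y_pred_c0, y_pred_c1)
print(cace)
```

The output has one entry per class, so a larger value means that switching the concept on changes the average prediction for that class more strongly.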
- residual_concept_causal_effect(cace_before, cace_after)
Compute the residual concept causal effect between two concepts.
- Parameters:
cace_before – ConceptCausalEffect metric before the do-intervention on the inner concept.
cace_after – ConceptCausalEffect metric after the do-intervention on the inner concept.
Calibration and Selection
- selective_calibration(c_confidence: Tensor, target_coverage: float) → Tensor
Selects concepts based on confidence scores and target coverage.
- Parameters:
c_confidence – Concept confidence scores.
target_coverage – Target coverage.
- Returns:
Thresholds to select confident predictions.
- Return type:
Tensor
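One way to obtain thresholds that reach a target coverage is to take a per-concept quantile of the confidence scores. The sketch below follows that idea. It is an assumption about how selective_calibration could work, not the confirmed implementation; the name `selective_thresholds` is hypothetical and NumPy arrays stand in for tensors.

```python
import numpy as np

def selective_thresholds(c_confidence, target_coverage):
    """Hypothetical sketch: per-concept quantile thresholds."""
    # Roughly `target_coverage` of the samples lie above the (1 - coverage) quantile.
    return np.quantile(c_confidence, 1.0 - target_coverage, axis=0, keepdims=True)

rng = np.random.default_rng(0)
c_confidence = rng.uniform(size=(1000, 3))  # (n_samples, n_concepts)
theta = selective_thresholds(c_confidence, target_coverage=0.8)

# Fraction of samples per concept whose confidence exceeds its threshold.
coverage = (c_confidence > theta).mean(axis=0)
print(theta.shape, coverage)
```

The returned thresholds can be passed as `theta` to a confidence-based mask such as confidence_selection.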