Functional API

This module provides functional operations for concept-based computations.

Summary

Concept Operations

grouped_concept_exogenous_mixture

Vectorized version of grouped concept exogenous mixture.

selection_eval

Evaluate concept selection by computing weighted predictions.

confidence_selection

Selects concepts with confidence above a given threshold.

soft_select

Soft selection function: an activation function that rescales the network output so that, if the outputs are uniformly distributed, only half of them are selected.

Linear and Logic Operations

linear_equation_eval

Evaluate a set of linear equations on concept predictions.

linear_equation_expl

Extract linear equations from decoded equation embeddings as strings.

logic_rule_eval

Use concept weights to make predictions based on logic rules.

logic_memory_reconstruction

Reconstruct tasks based on concept reconstructions, ground truth concepts and ground truth tasks.

logic_rule_explanations

Extracts rules from rule concept weights as strings.

Evaluation Metrics

completeness_score

Calculate the completeness score for the given predictions and true labels.

intervention_score

Compute the effect of concept interventions on downstream task predictions.

cace_score

Compute the Average Causal Effect (ACE), also known as the Causal Concept Effect (CaCE) score.

residual_concept_causal_effect

Compute the residual concept causal effect between two concepts.

Calibration and Selection

selective_calibration

Selects concepts based on confidence scores and target coverage.

Graph Utilities

Model Utilities

prune_linear_layer

Return a new nn.Linear where inputs (dim=0) or outputs (dim=1) have been pruned according to mask.

Function Documentation

Concept Operations

grouped_concept_exogenous_mixture(c_emb: Tensor, c_scores: Tensor, groups: list[int]) → Tensor

Vectorized version of grouped concept exogenous mixture.

Extends the concept embedding mixture to grouped concepts, where a group may contain multiple related concepts. Adapted from “Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off” (Espinosa Zarlenga et al., 2022).

Parameters:
  • c_emb – Concept exogenous embeddings of shape (B, n_concepts, emb_size).

  • c_scores – Concept scores of shape (B, sum(groups)).

  • groups – List of group sizes (e.g., [3, 4] for two groups).

Returns:

Mixed exogenous embeddings of shape (B, len(groups), emb_size // 2).

Return type:

Tensor

Raises:

References

Espinosa Zarlenga et al. “Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off”, NeurIPS 2022. https://arxiv.org/abs/2209.09056

Example

>>> import torch
>>> from torch_concepts.nn.functional import grouped_concept_exogenous_mixture
>>>
>>> # 10 concepts in 3 groups: [3, 4, 3]
>>> # Embedding size = 20 (must be even)
>>> batch_size = 4
>>> n_concepts = 10
>>> emb_size = 20
>>> groups = [3, 4, 3]
>>>
>>> # Generate random embeddings and scores
>>> c_emb = torch.randn(batch_size, n_concepts, emb_size)
>>> c_scores = torch.rand(batch_size, n_concepts)  # Probabilities
>>>
>>> # Apply grouped mixture; output shape is (batch_size, n_groups, emb_size // 2)
>>> mixed = grouped_concept_exogenous_mixture(c_emb, c_scores, groups)
>>> print(mixed.shape)
torch.Size([4, 3, 10])
>>>
>>> # Singleton groups use the two-half mixture
>>> # Multi-concept groups use a weighted average of the base embeddings
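For singleton groups, the two-half mixture follows the concept embedding idea from the reference above: each concept embedding is split into an "active" half and an "inactive" half, and the concept score interpolates between them. The sketch below assumes the first half is the active one and covers singleton groups only; two_half_mixture is a hypothetical helper for illustration, not part of torch_concepts.

```python
import torch


def two_half_mixture(c_emb: torch.Tensor, c_scores: torch.Tensor) -> torch.Tensor:
    """Mix each concept's two embedding halves by its score (singleton groups only).

    c_emb: (B, n_concepts, emb_size) with emb_size even.
    c_scores: (B, n_concepts) probabilities in [0, 1].
    Returns a tensor of shape (B, n_concepts, emb_size // 2).
    """
    emb_size = c_emb.shape[-1]
    if emb_size % 2 != 0:
        raise ValueError("emb_size must be even")
    half = emb_size // 2
    # Assumption: the first half encodes "concept active", the second "concept inactive".
    emb_active, emb_inactive = c_emb[..., :half], c_emb[..., half:]
    p = c_scores.unsqueeze(-1)  # (B, n_concepts, 1), broadcasts over the half-embedding
    return p * emb_active + (1 - p) * emb_inactive


c_emb = torch.randn(4, 3, 20)
c_scores = torch.rand(4, 3)
mixed = two_half_mixture(c_emb, c_scores)  # shape (4, 3, 10)
```

A score of 1 returns the active half unchanged and a score of 0 returns the inactive half, which is what makes the mixed embedding respond to concept interventions.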

selection_eval(selection_weights: Tensor, *predictions: Tensor) → Tensor

Evaluate concept selection by computing weighted predictions.

Parameters:
  • selection_weights – Weights for selecting between predictions.

  • *predictions – Variable number of prediction tensors to combine.

Returns:

Weighted combination of predictions.

Return type:

Tensor
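One plausible reading of the weighted combination, assuming selection_weights holds one weight per prediction tensor along its last dimension: stack the predictions and take a weighted sum. weighted_selection is a hypothetical name, and the library's exact broadcasting rules may differ.

```python
import torch


def weighted_selection(selection_weights: torch.Tensor, *predictions: torch.Tensor) -> torch.Tensor:
    # Stack candidate predictions along a new last axis: (..., n_predictions).
    stacked = torch.stack(predictions, dim=-1)
    # Assumption: selection_weights has shape (..., n_predictions) and broadcasts with stacked.
    return (selection_weights * stacked).sum(dim=-1)


y_a = torch.tensor([[1.0, 0.0]])
y_b = torch.tensor([[0.0, 1.0]])
# Put 75% of the weight on y_a and 25% on y_b for every output entry.
w = torch.tensor([[[0.75, 0.25], [0.75, 0.25]]])
combined = weighted_selection(w, y_a, y_b)  # values [[0.75, 0.25]]
```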

confidence_selection(c_confidence: Tensor, theta: Tensor) → Tensor

Selects concepts with confidence above a given threshold.

Parameters:
  • c_confidence – Concept confidence scores.

  • theta – Threshold to select confident predictions.

Returns:

Mask selecting confident predictions.

Return type:

Tensor
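A minimal sketch of the thresholding, assuming "above" means strictly greater and that theta holds one threshold per concept that broadcasts over the batch; threshold_mask is a hypothetical helper.

```python
import torch


def threshold_mask(c_confidence: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # Assumption: "above" means strictly greater; theta broadcasts over the batch axis.
    return c_confidence > theta


c_confidence = torch.tensor([[0.9, 0.4], [0.6, 0.8]])  # (batch, n_concepts)
theta = torch.tensor([0.5, 0.7])  # one threshold per concept
mask = threshold_mask(c_confidence, theta)  # [[True, False], [True, True]]
```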

soft_select(values, temperature, dim=1) → Tensor

Soft selection function: an activation function that rescales the network output so that, if the outputs are uniformly distributed, only half of them are selected. A higher temperature selects more concepts; a lower temperature selects fewer.

Parameters:
  • values – Output of the network.

  • temperature – Temperature for the softmax function [-inf, +inf].

  • dim – Dimension along which to apply the softmax function. Default is 1.

Returns:

Soft selection scores.

Return type:

Tensor
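Both properties described above (half selection for uniform inputs, more selection at higher temperature) are satisfied by the formulation below, given as an assumption rather than the library's confirmed code: shift each log-softmax score by the temperature-scaled mean log-softmax score, then apply a sigmoid. soft_select_sketch is a hypothetical name.

```python
import torch


def soft_select_sketch(values: torch.Tensor, temperature: float, dim: int = 1) -> torch.Tensor:
    # Log-probabilities of each entry along `dim`.
    log_probs = torch.log_softmax(values, dim=dim)
    # Shift by the temperature-scaled mean log-probability, then squash to (0, 1).
    shifted = log_probs - temperature * log_probs.mean(dim=dim, keepdim=True)
    return torch.sigmoid(shifted)


uniform = torch.zeros(2, 4)
scores_t1 = soft_select_sketch(uniform, temperature=1.0)  # every entry is 0.5
scores_t2 = soft_select_sketch(uniform, temperature=2.0)  # every entry is above 0.5
```

For n uniform entries every log-softmax value is -log n, so each score equals sigmoid((temperature - 1) * log n): exactly 0.5 at temperature 1, and increasing with the temperature.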

Linear and Logic Operations

linear_equation_eval(concept_weights: Tensor, c_pred: Tensor, bias: Tensor | None = None) → Tensor

Evaluate a set of linear equations on concept predictions. Each sample in the batch has its own set of equations (concept_weights).

Parameters:
  • concept_weights – Parameters representing the weights of multiple linear models with shape (batch_size, memory_size, n_concepts, n_classes).

  • c_pred – Concept predictions with shape (batch_size, n_concepts).

  • bias – Bias term to add to the linear models (batch_size, memory_size, n_classes).

Returns:

Predictions made by the linear models with shape (batch_size, n_classes, memory_size).

Return type:

Tensor
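The shapes above pin down the computation up to the contraction order. A sketch with torch.einsum, assuming the output is the concept-weighted sum for every (sample, class, memory slot) plus the bias rearranged from (batch_size, memory_size, n_classes) to (batch_size, n_classes, memory_size); linear_eval_sketch is hypothetical.

```python
from typing import Optional

import torch


def linear_eval_sketch(
    concept_weights: torch.Tensor, c_pred: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # (B, M, C, K) weights and (B, C) concepts -> (B, K, M): sum over concepts.
    y = torch.einsum("bmck,bc->bkm", concept_weights, c_pred)
    if bias is not None:
        # Bias is (B, M, K); swap the last two axes to align it with y.
        y = y + bias.permute(0, 2, 1)
    return y


torch.manual_seed(0)
B, M, C, K = 2, 3, 5, 4
weights = torch.randn(B, M, C, K)
c_pred = torch.rand(B, C)
bias = torch.randn(B, M, K)
y = linear_eval_sketch(weights, c_pred, bias)  # shape (2, 4, 3)
```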

linear_equation_expl(concept_weights: Tensor, bias: Tensor | None = None, concept_names: Dict[int, List[str]] | None = None) → List[Dict[str, Dict[str, str]]]

Extract linear equations from decoded equation embeddings as strings.

Parameters:
  • concept_weights – Equation embeddings with shape (batch_size, memory_size, n_concepts, n_tasks).

  • bias – Bias term to add to the linear models (batch_size, memory_size, n_tasks).

  • concept_names – Concept and task names. If the bias is included, the concept names should include the bias name.

Returns:

List of predicted equations as strings.

Return type:

List[Dict[str, Dict[str, str]]]
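The exact nesting of the returned dictionaries is not documented here, so the sketch below only shows how a single equation could be rendered from one weight vector, assuming a "+w * name" term format in which zero weights are omitted. format_linear_equation and its output format are hypothetical.

```python
from typing import List, Optional


def format_linear_equation(
    weights: List[float], names: List[str], bias: Optional[float] = None, decimals: int = 2
) -> str:
    terms = []
    for w, name in zip(weights, names):
        if w == 0:
            continue  # omit concepts that do not contribute
        terms.append(f"{w:+.{decimals}f} * {name}")
    if bias is not None:
        terms.append(f"{bias:+.{decimals}f}")
    return " ".join(terms) if terms else "0"


eq = format_linear_equation([0.5, 0.0, -1.25], ["red", "round", "small"], bias=0.1)
# eq == "+0.50 * red -1.25 * small +0.10"
```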

logic_rule_eval(concept_weights: Tensor, c_pred: Tensor, memory_idxs: Tensor | None = None, semantic=torch_concepts.nn.modules.low.semantic.CMRSemantic()) → Tensor

Use concept weights to make predictions based on logic rules.

Parameters:
  • concept_weights – Concept weights with shape (batch_size, memory_size, n_concepts, n_tasks, n_roles) with n_roles=3.

  • c_pred – Concept predictions with shape (batch_size, n_concepts).

  • memory_idxs – Indices of rules to evaluate with shape (batch_size, n_tasks). Default is None (evaluate all).

  • semantic – Semantic function to use for rule evaluation.

Returns:

Rule predictions with shape (batch_size, n_tasks, memory_size).

Return type:

torch.Tensor
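The default semantic is CMRSemantic, whose exact formulas are not shown here. As a stand-in, the sketch below evaluates rules with a product t-norm and assumes the three roles are ordered (positive, negative, irrelevant): a positive literal contributes c, a negative literal 1 - c, an irrelevant concept 1, and the rule value is the product over concepts. It ignores memory_idxs; logic_rule_eval_sketch is hypothetical.

```python
import torch


def logic_rule_eval_sketch(concept_weights: torch.Tensor, c_pred: torch.Tensor) -> torch.Tensor:
    # concept_weights: (B, M, C, K, 3); assumed role order 0 = positive, 1 = negative, 2 = irrelevant.
    pos, neg, irr = concept_weights.unbind(dim=-1)  # each (B, M, C, K)
    c = c_pred[:, None, :, None]  # (B, 1, C, 1), broadcasts over memory slots and tasks
    # Soft truth value of each literal.
    literal = pos * c + neg * (1 - c) + irr
    # Product t-norm conjunction over concepts: (B, M, K) -> (B, K, M).
    return literal.prod(dim=2).permute(0, 2, 1)


# One rule for one task over two concepts: "c0 is true" (c1 is irrelevant).
weights = torch.zeros(1, 1, 2, 1, 3)
weights[0, 0, 0, 0, 0] = 1.0  # c0 plays the positive role
weights[0, 0, 1, 0, 2] = 1.0  # c1 is irrelevant
fires = logic_rule_eval_sketch(weights, torch.tensor([[1.0, 0.0]]))  # 1.0
silent = logic_rule_eval_sketch(weights, torch.tensor([[0.0, 1.0]]))  # 0.0
```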

logic_memory_reconstruction(concept_weights: Tensor, c_true: Tensor, y_true: Tensor) → Tensor

Reconstruct tasks based on concept reconstructions, ground truth concepts and ground truth tasks.

Parameters:
  • concept_weights – Concept reconstructions with shape (batch_size, memory_size, n_concepts, n_tasks).

  • c_true – Concept ground truth with shape (batch_size, n_concepts).

  • y_true – Task ground truth with shape (batch_size, n_tasks).

Returns:

Reconstructed tasks with shape (batch_size, n_tasks, memory_size).

Return type:

torch.Tensor

logic_rule_explanations(concept_logic_weights: Tensor, concept_names: Dict[int, List[str]] | None = None) → List[Dict[str, Dict[str, str]]]

Extracts rules from rule concept weights as strings.

Parameters:
  • concept_logic_weights – Rule embeddings with shape (batch_size, memory_size, n_concepts, n_tasks, 3).

  • concept_names – Concept and task names.

Returns:

Rules as strings.

Return type:

List[Dict[str, Dict[str, str]]]
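Under an assumed role order of (positive, negative, irrelevant) for the last dimension of size 3, a single rule could be rendered by taking the most likely role per concept. format_rule, the "&" and "~" syntax, and the "True" fallback for empty rules are illustrative choices, not the library's output format.

```python
from typing import List

import torch


def format_rule(role_weights: torch.Tensor, concept_names: List[str]) -> str:
    # role_weights: (n_concepts, 3); assumed role order 0 = positive, 1 = negative, 2 = irrelevant.
    roles = role_weights.argmax(dim=-1).tolist()
    literals = []
    for role, name in zip(roles, concept_names):
        if role == 0:
            literals.append(name)
        elif role == 1:
            literals.append(f"~{name}")
        # role 2 (irrelevant) adds no literal
    return " & ".join(literals) if literals else "True"


role_weights = torch.tensor([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]])
rule = format_rule(role_weights, ["red", "round", "small"])  # "red & ~round"
```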

Evaluation Metrics

completeness_score(y_true, y_pred_blackbox, y_pred_whitebox, scorer=roc_auc_score, average='macro')

Calculate the completeness score for the given predictions and true labels.

Main reference: “On Completeness-aware Concept-Based Explanations in Deep Neural Networks”

Parameters:
  • y_true (torch.Tensor) – True labels.

  • y_pred_blackbox (torch.Tensor) – Predictions from the blackbox model.

  • y_pred_whitebox (torch.Tensor) – Predictions from the whitebox model.

  • scorer (function) – Scoring function to evaluate predictions. Default is roc_auc_score.

  • average (str) – Type of averaging to use. Default is ‘macro’.

Returns:

Completeness score.

Return type:

float
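A minimal sketch, assuming the completeness score is the ratio of the white-box score to the black-box score. Yeh et al. additionally subtract a random-guess baseline from both scores, which this sketch omits. To stay dependency-free it uses a hypothetical accuracy scorer instead of roc_auc_score and drops the average argument; completeness_sketch is hypothetical.

```python
import torch


def accuracy(y_true: torch.Tensor, y_pred: torch.Tensor) -> float:
    # Hypothetical scorer: fraction of rounded predictions that match the labels.
    return (y_pred.round() == y_true).float().mean().item()


def completeness_sketch(y_true, y_pred_blackbox, y_pred_whitebox, scorer=accuracy) -> float:
    blackbox_score = scorer(y_true, y_pred_blackbox)
    whitebox_score = scorer(y_true, y_pred_whitebox)
    # Assumption: completeness is the white-box score relative to the black-box score.
    return whitebox_score / blackbox_score


y_true = torch.tensor([1.0, 0.0, 1.0, 0.0])
y_blackbox = torch.tensor([0.9, 0.1, 0.8, 0.2])  # all correct
y_whitebox = torch.tensor([0.9, 0.6, 0.8, 0.2])  # one mistake
score = completeness_sketch(y_true, y_blackbox, y_whitebox)  # 0.75
```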

intervention_score(y_predictor: torch.nn.Module, c_pred: torch.Tensor, c_true: torch.Tensor, y_true: torch.Tensor, intervention_groups: List[List[int]], activation: Callable = torch.sigmoid, scorer: Callable = roc_auc_score, average: str = 'macro', auc: bool = True) → float | List[float]

Compute the effect of concept interventions on downstream task predictions.

Given a set of intervention groups, the intervention score measures the effect of each intervention group on the model’s task predictions.

Main reference: “Concept Bottleneck Models”

Parameters:
  • y_predictor (torch.nn.Module) – Model that predicts downstream task labels.

  • c_pred (torch.Tensor) – Predicted concept values.

  • c_true (torch.Tensor) – Ground truth concept values.

  • y_true (torch.Tensor) – Ground truth task labels.

  • intervention_groups (List[List[int]]) – List of intervention groups.

  • activation (Callable) – Activation function to apply to the model’s predictions. Default is torch.sigmoid.

  • scorer (Callable) – Scoring function to evaluate predictions. Default is roc_auc_score.

  • average (str) – Type of averaging to use. Default is ‘macro’.

  • auc (bool) – Whether to return the average score across all intervention groups. Default is True.

Returns:

The intervention score for each intervention group if auc is False, otherwise the average score across all groups.

Return type:

Union[float, List[float]]
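A sketch of the intervention loop under two assumptions: interventions accumulate (each group's concepts stay replaced by ground truth in later steps), and each step is scored on the activated task predictions. A hypothetical accuracy scorer stands in for roc_auc_score, and the average argument is omitted; intervention_score_sketch is hypothetical.

```python
from typing import Callable, List, Union

import torch
from torch import nn


def accuracy(y_true: torch.Tensor, y_prob: torch.Tensor) -> float:
    # Hypothetical scorer: thresholded predictions compared with the labels.
    return ((y_prob > 0.5).float() == y_true).float().mean().item()


def intervention_score_sketch(
    y_predictor: nn.Module,
    c_pred: torch.Tensor,
    c_true: torch.Tensor,
    y_true: torch.Tensor,
    intervention_groups: List[List[int]],
    activation: Callable = torch.sigmoid,
    scorer: Callable = accuracy,
    auc: bool = True,
) -> Union[float, List[float]]:
    c_int = c_pred.clone()  # leave the caller's predictions untouched
    scores = []
    with torch.no_grad():
        for group in intervention_groups:
            # Assumption: interventions accumulate across groups.
            c_int[:, group] = c_true[:, group]
            scores.append(scorer(y_true, activation(y_predictor(c_int))))
    return sum(scores) / len(scores) if auc else scores


# Toy task y = c0 AND c1, implemented by a fixed linear layer.
predictor = nn.Linear(2, 1)
with torch.no_grad():
    predictor.weight.copy_(torch.tensor([[10.0, 10.0]]))
    predictor.bias.fill_(-15.0)
c_true = torch.tensor([[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]])
y_true = torch.tensor([[1.0], [0.0], [0.0], [0.0]])
c_pred = torch.zeros_like(c_true)  # a concept predictor that always says "absent"
per_group = intervention_score_sketch(predictor, c_pred, c_true, y_true, [[0], [1]], auc=False)
```

Fixing c0 alone leaves the task prediction wrong on the first sample; fixing c1 as well recovers every label, so the per-group scores rise from 0.75 to 1.0.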

cace_score(y_pred_c0, y_pred_c1)

Compute the Average Causal Effect (ACE), also known as the Causal Concept Effect (CaCE) score.

The ACE/CaCE score measures the causal effect of a concept on the predictions of a model. It is computed as the absolute difference between the expected predictions when the concept is inactive (c0) and active (c1).

Main reference: “Explaining Classifiers with Causal Concept Effect (CaCE)”

Parameters:
  • y_pred_c0 (torch.Tensor) – Predictions of the model when the concept is inactive. Shape: (batch_size, num_classes).

  • y_pred_c1 (torch.Tensor) – Predictions of the model when the concept is active. Shape: (batch_size, num_classes).

Returns:

The ACE/CaCE score for each class. Shape: (num_classes,).

Return type:

torch.Tensor
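Following the description above, the score is the absolute difference between the batch-averaged predictions under the two interventions, computed per class. cace_sketch is a hypothetical name; dropping .abs() keeps the sign of the effect.

```python
import torch


def cace_sketch(y_pred_c0: torch.Tensor, y_pred_c1: torch.Tensor) -> torch.Tensor:
    # Expected prediction per class under each intervention, estimated by the batch mean.
    return (y_pred_c1.mean(dim=0) - y_pred_c0.mean(dim=0)).abs()


y_c0 = torch.tensor([[0.2, 0.6], [0.4, 0.8]])  # concept forced inactive
y_c1 = torch.tensor([[0.9, 0.5], [0.7, 0.3]])  # concept forced active
effect = cace_sketch(y_c0, y_c1)  # approximately [0.5, 0.3]
```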

residual_concept_causal_effect(cace_before, cace_after)

Compute the residual concept causal effect between two concepts.

Parameters:
  • cace_before – ConceptCausalEffect metric before the do-intervention on the inner concept.

  • cace_after – ConceptCausalEffect metric after the do-intervention on the inner concept.

Calibration and Selection

selective_calibration(c_confidence: Tensor, target_coverage: float) → Tensor

Selects concepts based on confidence scores and target coverage.

Parameters:
  • c_confidence – Concept confidence scores.

  • target_coverage – Target coverage, i.e. the desired fraction of predictions to keep (between 0 and 1).

Returns:

Thresholds to select confident predictions.

Return type:

Tensor
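One way to turn a target coverage into thresholds, given as an assumption: for each concept, take the (1 - target_coverage) quantile of the confidence scores over the batch, so that roughly a target_coverage fraction of samples lies above it. The resulting thresholds could then serve as theta in a mask such as c_confidence > theta. coverage_thresholds is hypothetical.

```python
import torch


def coverage_thresholds(c_confidence: torch.Tensor, target_coverage: float) -> torch.Tensor:
    # Per-concept (1 - target_coverage) quantile over the batch axis, so that roughly
    # a target_coverage fraction of samples has confidence above the threshold.
    return torch.quantile(c_confidence, 1.0 - target_coverage, dim=0)


c_confidence = torch.linspace(0.0, 1.0, 11).unsqueeze(1)  # (11 samples, 1 concept)
theta = coverage_thresholds(c_confidence, target_coverage=0.5)  # approximately [0.5]
kept = (c_confidence > theta).float().mean()  # close to half of the samples
```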

Graph Utilities

edge_type(graph, i, j)

Model Utilities

prune_linear_layer(linear: Linear, mask: Tensor, dim: int = 0) → Linear

Return a new nn.Linear where inputs (dim=0) or outputs (dim=1) have been pruned according to mask.

Parameters:
  • linear (nn.Linear) – Layer to prune.

  • mask (1D Tensor[bool] or 0/1) – Mask over features. True/1 = keep, False/0 = drop.
      - If dim=0: length == in_features
      - If dim=1: length == out_features

  • dim (int) – Dimension to prune:
      - 0: prune input features (columns of weight)
      - 1: prune output units (rows of weight)
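A minimal sketch of the pruning described above, assuming the bias is kept unchanged when pruning inputs and sliced together with the rows when pruning outputs. prune_linear_sketch is hypothetical, and it does not carry over device or dtype settings from the original layer.

```python
import torch
from torch import nn


def prune_linear_sketch(linear: nn.Linear, mask: torch.Tensor, dim: int = 0) -> nn.Linear:
    keep = mask.bool()  # accepts bool or 0/1 masks
    has_bias = linear.bias is not None
    if dim == 0:
        # Drop input features: keep the selected columns of the weight matrix.
        pruned = nn.Linear(int(keep.sum()), linear.out_features, bias=has_bias)
        with torch.no_grad():
            pruned.weight.copy_(linear.weight[:, keep])
            if has_bias:
                pruned.bias.copy_(linear.bias)
    elif dim == 1:
        # Drop output units: keep the selected rows of the weight matrix and bias.
        pruned = nn.Linear(linear.in_features, int(keep.sum()), bias=has_bias)
        with torch.no_grad():
            pruned.weight.copy_(linear.weight[keep, :])
            if has_bias:
                pruned.bias.copy_(linear.bias[keep])
    else:
        raise ValueError("dim must be 0 (inputs) or 1 (outputs)")
    return pruned


layer = nn.Linear(4, 3)
in_mask = torch.tensor([1, 0, 1, 1])  # drop the second input feature
pruned_in = prune_linear_sketch(layer, in_mask, dim=0)  # nn.Linear(3, 3)
```

Pruning inputs is equivalent to zeroing the dropped features before the original layer, which gives a quick way to check the result.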