Metrics
Concept-aware metrics with automatic routing and flexible tracking.
Summary
Metrics Classes
ConceptMetrics | Metrics manager for concept-based models with automatic type-aware routing.
Functional Metrics
Class Documentation
- class ConceptMetrics(annotations: Annotations, fn_collection: GroupConfig, summary_metrics: bool = True, perconcept_metrics: bool | List[str] = False)
Bases: Module
Metrics manager for concept-based models with automatic type-aware routing.
This class organizes and manages metrics for different concept types (binary, categorical, continuous) with support for both summary metrics (aggregated across all concepts of a type) and per-concept metrics (individual tracking per concept).
The class automatically routes predictions to the appropriate metrics based on concept types defined in the annotations, handles different metric instantiation patterns, and maintains independent metric tracking across train/val/test splits.
- Parameters:
annotations (Annotations) – Concept annotations containing labels, types, and cardinalities. Should include axis 1 (concept axis) with metadata specifying concept types as ‘discrete’ or ‘continuous’.
fn_collection (GroupConfig) –
Metric configurations organized by concept type (‘binary’, ‘categorical’, ‘continuous’). Each metric can be specified in three ways:
Pre-instantiated metric: Pass an already instantiated metric object for full control over all parameters.
Example:
'accuracy': torchmetrics.classification.BinaryAccuracy(threshold=0.6)
Class with user kwargs: Pass a tuple of (MetricClass, kwargs_dict) to provide custom parameters while letting ConceptMetrics handle concept-specific parameters like num_classes automatically.
Example:
'accuracy': (torchmetrics.classification.MulticlassAccuracy, {'average': 'macro'})
Class only: Pass just the metric class and let ConceptMetrics handle all instantiation with appropriate concept-specific parameters.
Example:
'accuracy': torchmetrics.classification.MulticlassAccuracy
summary_metrics (bool, optional) – Whether to compute summary metrics that aggregate performance across all concepts of each type. Defaults to True.
perconcept_metrics (Union[bool, List[str]], optional) –
Controls per-concept metric tracking. Options:
False: No per-concept tracking (default)
True: Track all concepts individually
List[str]: Track only the specified concept names
- train_metrics
Metrics for training split
- Type:
MetricCollection
- val_metrics
Metrics for validation split
- Type:
MetricCollection
- test_metrics
Metrics for test split
- Type:
MetricCollection
- Raises:
NotImplementedError – If continuous concepts are found (not yet supported)
ValueError – If metric configuration doesn’t match concept types, or if user provides num_classes when it should be set automatically
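The three accepted values of perconcept_metrics can be pictured with a small helper. This is a sketch only: select_tracked_concepts is a hypothetical name that does not exist in torch_concepts, and raising ValueError for unknown concept names is an assumption, not documented behaviour.

```python
from typing import List, Sequence, Union


def select_tracked_concepts(
    perconcept_metrics: Union[bool, List[str]],
    concept_names: Sequence[str],
) -> List[str]:
    """Hypothetical helper: expand the option into a list of concept names."""
    if perconcept_metrics is False:
        return []                        # no per-concept tracking
    if perconcept_metrics is True:
        return list(concept_names)       # track every concept
    # Otherwise a list of names. Rejecting unknown names is an assumption.
    unknown = [n for n in perconcept_metrics if n not in concept_names]
    if unknown:
        raise ValueError(f"Unknown concept names: {unknown}")
    # Keep the order in which concepts appear in the annotations.
    return [n for n in concept_names if n in perconcept_metrics]


names = ("binary1", "binary2", "category")
print(select_tracked_concepts(["category", "binary1"], names))
# ['binary1', 'category']
```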
Example
Basic usage with pre-instantiated metrics:
    import torch
    import torchmetrics
    from torch_concepts import Annotations, AxisAnnotation
    from torch_concepts.nn.modules.metrics import ConceptMetrics
    from torch_concepts.nn.modules.utils import GroupConfig

    # Define concept structure
    annotations = Annotations({
        1: AxisAnnotation(
            labels=('round', 'smooth'),
            cardinalities=[1, 1],
            metadata={
                'round': {'type': 'discrete'},
                'smooth': {'type': 'discrete'}
            }
        )
    })

    # Create metrics with pre-instantiated objects
    metrics = ConceptMetrics(
        annotations=annotations,
        fn_collection=GroupConfig(
            binary={
                'accuracy': torchmetrics.classification.BinaryAccuracy(),
                'f1': torchmetrics.classification.BinaryF1Score()
            }
        ),
        summary_metrics=True,
        perconcept_metrics=False
    )

    # Simulate training batch
    predictions = torch.randn(32, 2)         # endogenous predictions
    targets = torch.randint(0, 2, (32, 2))   # binary targets

    # Update metrics (the keyword is `preds`, matching the update() signature)
    metrics.update(preds=predictions, target=targets, split='train')

    # Compute at epoch end
    results = metrics.compute('train')
    print(results)
    # {'train/SUMMARY-binary_accuracy': ..., 'train/SUMMARY-binary_f1': ...}

    # Reset for next epoch
    metrics.reset('train')
Using class + kwargs for flexible configuration:
    # Mixed concept types with custom metric parameters
    annotations = Annotations({
        1: AxisAnnotation(
            labels=('binary1', 'binary2', 'category'),
            cardinalities=[1, 1, 5],
            metadata={
                'binary1': {'type': 'discrete'},
                'binary2': {'type': 'discrete'},
                'category': {'type': 'discrete'}
            }
        )
    })

    metrics = ConceptMetrics(
        annotations=annotations,
        fn_collection=GroupConfig(
            binary={
                # Custom threshold
                'accuracy': (torchmetrics.classification.BinaryAccuracy, {'threshold': 0.6})
            },
            categorical={
                # Custom averaging, num_classes added automatically
                'accuracy': (torchmetrics.classification.MulticlassAccuracy, {'average': 'macro'})
            }
        ),
        summary_metrics=True,
        perconcept_metrics=True  # Track all concepts individually
    )

    # Predictions: 2 binary + 5 categorical = 7 dimensions
    predictions = torch.randn(16, 7)
    targets = torch.cat([
        torch.randint(0, 2, (16, 2)),  # binary
        torch.randint(0, 5, (16, 1))   # categorical
    ], dim=1)

    # The keyword is `preds`, matching the update() signature
    metrics.update(preds=predictions, target=targets, split='train')
    results = metrics.compute('train')

    # Results include both summary and per-concept metrics:
    # 'train/SUMMARY-binary_accuracy'
    # 'train/SUMMARY-categorical_accuracy'
    # 'train/binary1_accuracy'
    # 'train/binary2_accuracy'
    # 'train/category_accuracy'
Selective per-concept tracking:
    # Track only specific concepts
    metrics = ConceptMetrics(
        annotations=annotations,
        fn_collection=GroupConfig(
            binary={'accuracy': torchmetrics.classification.BinaryAccuracy}
        ),
        summary_metrics=True,
        perconcept_metrics=['binary1']  # Only track binary1 individually
    )
Integration with PyTorch Lightning:
    import pytorch_lightning as pl

    class ConceptModel(pl.LightningModule):
        def __init__(self, annotations):
            super().__init__()
            self.model = ...  # your model
            self.metrics = ConceptMetrics(
                annotations=annotations,
                fn_collection=GroupConfig(
                    binary={'accuracy': torchmetrics.classification.BinaryAccuracy}
                ),
                summary_metrics=True
            )

        def training_step(self, batch, batch_idx):
            x, concepts = batch
            preds = self.model(x)
            loss = ...  # compute your training loss here

            # Update metrics
            self.metrics.update(preds=preds, target=concepts, split='train')
            return loss

        def on_train_epoch_end(self):
            # Compute and log metrics
            metrics_dict = self.metrics.compute('train')
            self.log_dict(metrics_dict)
            self.metrics.reset('train')
Note
Continuous concepts are not yet supported and will raise NotImplementedError
For categorical concepts, ConceptMetrics automatically handles padding to the maximum cardinality when computing summary metrics
User-provided ‘num_classes’ parameter for categorical metrics will raise an error as it’s set automatically based on concept cardinalities
Each split (train/val/test) maintains independent metric state
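To make the three specification styles and the num_classes rule concrete, here is a minimal sketch of how such a resolver could work. It is not the library's implementation: build_metric and DummyMetric are hypothetical names, and the exact error message and kwarg handling in ConceptMetrics may differ.

```python
from typing import Any, Dict, Optional


class DummyMetric:
    """Stand-in for a metric class; records the kwargs it was built with."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


def build_metric(spec: Any, auto_kwargs: Optional[Dict[str, Any]] = None) -> Any:
    """Hypothetical resolver for the three specification styles."""
    auto_kwargs = dict(auto_kwargs or {})
    if isinstance(spec, tuple):                  # (MetricClass, kwargs_dict)
        metric_cls, user_kwargs = spec
        clash = set(user_kwargs) & set(auto_kwargs)
        if clash:                                # e.g. the user passed num_classes
            raise ValueError(f"{sorted(clash)} are set automatically")
        return metric_cls(**user_kwargs, **auto_kwargs)
    if isinstance(spec, type):                   # class only
        return spec(**auto_kwargs)
    return spec                                  # already instantiated, used as-is


categorical = build_metric((DummyMetric, {"average": "macro"}), {"num_classes": 5})
print(categorical.kwargs)   # {'average': 'macro', 'num_classes': 5}
```

In this sketch a pre-instantiated metric is returned untouched, which mirrors the "full control" wording above.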
See also
torch_concepts.nn.modules.utils.GroupConfig: Configuration helper
torch_concepts.annotations.Annotations: Concept annotations
TorchMetrics Documentation: Available metrics and their parameters
- __init__(annotations: Annotations, fn_collection: GroupConfig, summary_metrics: bool = True, perconcept_metrics: bool | List[str] = False)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- get(key: str, default=None)
Get a metric collection by key (dict-like interface).
- Parameters:
key (str) – Collection key (‘train_metrics’, ‘val_metrics’, ‘test_metrics’).
default – Default value to return if key not found.
- Returns:
MetricCollection or default value.
- update(preds: Tensor, target: Tensor, split: str = 'train')
Update metrics with predictions and targets for a given split.
This method automatically routes predictions to the appropriate metrics based on concept types. For summary metrics, it aggregates all concepts of each type. For per-concept metrics, it extracts individual concept predictions.
The preds tensor should be in the endogenous space (after applying the concept distributions’ transformations), and the target tensor should contain the ground truth concept values.
- Parameters:
preds (torch.Tensor) –
Model predictions in endogenous space. Shape depends on concept types:
Binary concepts: (batch_size, n_binary_concepts)
Categorical concepts: (batch_size, sum of cardinalities)
Mixed: (batch_size, n_binary + sum of cat cardinalities)
target (torch.Tensor) –
Ground truth concept values. Shape (batch_size, n_concepts) where each column corresponds to a concept:
Binary concepts: float values in {0, 1}
Categorical concepts: integer class indices in {0, …, cardinality-1}
Continuous concepts: float values (not yet supported)
split (str, optional) –
Which data split to update. Must be one of:
’train’: Training split
’val’ or ‘validation’: Validation split
’test’: Test split
Defaults to ‘train’.
- Raises:
ValueError – If split is not one of ‘train’, ‘val’, ‘validation’, or ‘test’
NotImplementedError – If continuous concepts are encountered
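The column layout of preds described above can be worked out from the cardinalities alone. The sketch below assumes that a binary concept (cardinality 1) occupies one column and a categorical concept occupies as many columns as its cardinality, in annotation order. concept_slices is a hypothetical helper, not a torch_concepts function.

```python
from typing import Dict, Sequence


def concept_slices(labels: Sequence[str], cardinalities: Sequence[int]) -> Dict[str, slice]:
    """Map each concept name to its columns in the prediction tensor."""
    slices: Dict[str, slice] = {}
    start = 0
    for name, card in zip(labels, cardinalities):
        slices[name] = slice(start, start + card)
        start += card
    return slices


layout = concept_slices(("binary1", "binary2", "category"), [1, 1, 5])
row = list(range(7))              # one row of a (batch_size, 7) prediction
print(row[layout["binary2"]])     # [1]
print(row[layout["category"]])    # [2, 3, 4, 5, 6]
```

The target tensor, by contrast, always has one column per concept, so its width is the number of labels (3 here) rather than the sum of cardinalities (7).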
Example
Basic update:
    # Binary concepts only
    predictions = torch.randn(32, 3)         # 3 binary concepts
    targets = torch.randint(0, 2, (32, 3))   # binary ground truth
    metrics.update(preds=predictions, target=targets, split='train')
Mixed concept types:
    # 2 binary + 1 categorical (3 classes)
    # Endogenous space: 2 binary + 3 categorical = 5 dims
    predictions = torch.randn(32, 5)
    targets = torch.cat([
        torch.randint(0, 2, (32, 2)),  # binary targets
        torch.randint(0, 3, (32, 1))   # categorical target
    ], dim=1)
    metrics.update(preds=predictions, target=targets, split='train')
Validation split:
    val_predictions = model(val_data)
    metrics.update(preds=val_predictions, target=val_targets, split='val')
- compute(split: str = 'train')
Compute final metric values from accumulated state for a split.
This method calculates the final metric values using all data accumulated through update() calls since the last reset(). It does not reset the metric state, allowing you to log results before resetting.
- Parameters:
split (str, optional) – Which data split to compute metrics for. Must be one of ‘train’, ‘val’, ‘validation’, or ‘test’. Defaults to ‘train’.
- Returns:
Dictionary mapping metric names (with split prefix) to computed values. Keys follow the format:
Summary metrics: ‘{split}/SUMMARY-{type}_{metric_name}’
Per-concept metrics: ‘{split}/{concept_name}_{metric_name}’
Values are torch.Tensor objects containing the computed metric values.
- Return type:
- Raises:
ValueError – If split is not one of the valid options
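The key format above is easy to reproduce when looking up or filtering results. A small sketch follows; metric_key is a hypothetical helper written for illustration and is not provided by the library.

```python
from typing import Optional


def metric_key(
    split: str,
    metric_name: str,
    *,
    concept_type: Optional[str] = None,
    concept_name: Optional[str] = None,
) -> str:
    """Build a result key in the documented format (hypothetical helper)."""
    if (concept_type is None) == (concept_name is None):
        raise ValueError("Pass exactly one of concept_type or concept_name")
    if concept_type is not None:
        return f"{split}/SUMMARY-{concept_type}_{metric_name}"   # summary metric
    return f"{split}/{concept_name}_{metric_name}"               # per-concept metric


print(metric_key("val", "accuracy", concept_type="binary"))   # val/SUMMARY-binary_accuracy
print(metric_key("val", "accuracy", concept_name="round"))    # val/round_accuracy
```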
Example
Basic compute:
    # After updating with training data
    train_results = metrics.compute('train')
    print(train_results)
    # {
    #     'train/SUMMARY-binary_accuracy': tensor(0.8500),
    #     'train/SUMMARY-binary_f1': tensor(0.8234),
    #     'train/concept1_accuracy': tensor(0.9000),
    #     'train/concept2_accuracy': tensor(0.8000)
    # }
Compute multiple splits:
    train_metrics = metrics.compute('train')
    val_metrics = metrics.compute('val')

    # Log to wandb or tensorboard
    logger.log_metrics(train_metrics)
    logger.log_metrics(val_metrics)
Extract specific metrics:
    results = metrics.compute('val')
    accuracy = results['val/SUMMARY-binary_accuracy'].item()
    print(f"Validation accuracy: {accuracy:.2%}")
Note
This method can be called multiple times without resetting
Always call reset() after logging to start fresh for the next epoch
Returned tensors are on the same device as the metric state
- reset(split: str | None = None)
Reset metric state for one or all splits.
This method resets the accumulated metric state, clearing all data from previous update() calls. Call this after computing and logging metrics to prepare for the next epoch.
- Parameters:
split (Optional[str], optional) –
Which split to reset. Options:
’train’: Reset only training metrics
’val’ or ‘validation’: Reset only validation metrics
’test’: Reset only test metrics
None: Reset all splits simultaneously (default)
- Raises:
ValueError – If split is not None and not a valid split name
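The accepted split names, including the 'validation' alias and the None case, can be summarised in a few lines. This is a sketch under assumptions: resolve_splits and the alias table are hypothetical, and the library may validate split names differently.

```python
from typing import List, Optional

# Assumed alias table: 'validation' is treated as a synonym for 'val'.
_ALIASES = {"train": "train", "val": "val", "validation": "val", "test": "test"}


def resolve_splits(split: Optional[str] = None) -> List[str]:
    """Return the canonical split names a call would affect (hypothetical helper)."""
    if split is None:
        return ["train", "val", "test"]      # None means every split
    if split not in _ALIASES:
        raise ValueError(f"Unknown split {split!r}; expected one of {sorted(_ALIASES)}")
    return [_ALIASES[split]]


print(resolve_splits("validation"))   # ['val']
print(resolve_splits())               # ['train', 'val', 'test']
```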
Example
Reset single split:
    # At end of training epoch
    train_metrics = metrics.compute('train')
    logger.log_metrics(train_metrics)
    metrics.reset('train')  # Reset only training
Reset all splits:
    # At end of validation
    train_metrics = metrics.compute('train')
    val_metrics = metrics.compute('val')
    logger.log_metrics({**train_metrics, **val_metrics})
    metrics.reset()  # Reset both train and val
Typical training loop:
    for epoch in range(num_epochs):
        # Training
        for x, targets in train_loader:
            preds = model(x)
            metrics.update(preds, targets, split='train')

        # Validation
        for x, targets in val_loader:
            preds = model(x)
            metrics.update(preds, targets, split='val')

        # Compute and log
        train_results = metrics.compute('train')
        val_results = metrics.compute('val')
        log_metrics({**train_results, **val_results})

        # Reset for next epoch
        metrics.reset()  # Resets all splits, including train and val
Note
Resetting is essential to avoid mixing data from different epochs
Each split can be reset independently
Resetting does not affect the metric configuration, only the state
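The update/compute/reset contract described in these notes can be illustrated with a toy accumulator. SplitAccuracy below is a hypothetical stand-in that uses plain Python lists instead of tensors. It only mirrors the documented semantics (independent state per split, compute() leaving state untouched, reset() clearing one or all splits) and says nothing about how ConceptMetrics is implemented.

```python
from typing import Dict, List, Optional


class SplitAccuracy:
    """Toy accumulator with one independent state per split."""

    def __init__(self) -> None:
        # One [correct, total] counter per split.
        self._state: Dict[str, List[int]] = {s: [0, 0] for s in ("train", "val", "test")}

    def update(self, preds: List[int], target: List[int], split: str = "train") -> None:
        state = self._state[split]
        state[0] += sum(int(p == t) for p, t in zip(preds, target))
        state[1] += len(target)

    def compute(self, split: str = "train") -> float:
        correct, total = self._state[split]
        return correct / total if total else float("nan")   # state is not modified

    def reset(self, split: Optional[str] = None) -> None:
        for name in ([split] if split is not None else list(self._state)):
            self._state[name] = [0, 0]


acc = SplitAccuracy()
acc.update([1, 0, 1, 1], [1, 0, 0, 1], split="train")
acc.update([0, 0], [0, 1], split="val")
print(acc.compute("train"), acc.compute("train"))   # 0.75 0.75
acc.reset("train")
print(acc.compute("val"))                           # 0.5
```

Calling compute() twice returns the same value, and resetting the training split leaves the validation state intact.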
Functional Metrics
completeness_score | Calculate the completeness score for the given predictions and true labels.
intervention_score | Compute the effect of concept interventions on downstream task predictions.
cace_score | Compute the Average Causal Effect (ACE) also known as the Causal Concept Effect (CaCE) score.
- completeness_score(y_true, y_pred_blackbox, y_pred_whitebox, scorer=roc_auc_score, average='macro')
Calculate the completeness score for the given predictions and true labels. Main reference: “On Completeness-aware Concept-Based Explanations in Deep Neural Networks”
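As a rough illustration of the idea, completeness compares how well the whitebox (concept-based) predictions score relative to the blackbox predictions. The sketch below is an assumption-laden simplification, not the library's formula: it uses plain accuracy instead of roc_auc_score, works on Python lists, and takes a hypothetical baseline argument standing in for chance-level performance.

```python
from typing import Callable, Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of matching labels; used here in place of roc_auc_score."""
    return sum(int(t == p) for t, p in zip(y_true, y_pred)) / len(y_true)


def completeness_sketch(
    y_true: Sequence[int],
    y_pred_blackbox: Sequence[int],
    y_pred_whitebox: Sequence[int],
    scorer: Callable[[Sequence[int], Sequence[int]], float] = accuracy,
    baseline: float = 0.0,
) -> float:
    """Illustrative ratio only; the exact formula in the library may differ."""
    blackbox = scorer(y_true, y_pred_blackbox)
    whitebox = scorer(y_true, y_pred_whitebox)
    # Undefined when the blackbox score equals the baseline (division by zero).
    return (whitebox - baseline) / (blackbox - baseline)


y_true = [1, 0, 1, 1, 0, 1, 0, 0]
blackbox = [1, 0, 1, 1, 0, 1, 0, 1]   # 7 of 8 correct
whitebox = [1, 0, 1, 0, 0, 1, 1, 1]   # 5 of 8 correct
print(completeness_sketch(y_true, blackbox, whitebox, baseline=0.5))
```

With these toy labels the whitebox recovers one third of the blackbox's margin over the 0.5 baseline.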
- Parameters:
y_true (torch.Tensor) – True labels.
y_pred_blackbox (torch.Tensor) – Predictions from the blackbox model.
y_pred_whitebox (torch.Tensor) – Predictions from the whitebox model.
scorer (function) – Scoring function to evaluate predictions. Default is roc_auc_score.
average (str) – Type of averaging to use. Default is ‘macro’.
- Returns:
Completeness score.
- Return type:
- intervention_score(y_predictor: torch.nn.Module, c_pred: torch.Tensor, c_true: torch.Tensor, y_true: torch.Tensor, intervention_groups: List[List[int]], activation: Callable = torch.sigmoid, scorer: Callable = roc_auc_score, average: str = 'macro', auc: bool = True) -> float | List[float]
Compute the effect of concept interventions on downstream task predictions.
Given a set of intervention groups, the intervention score measures the effectiveness of each intervention group on the model’s task predictions.
Main reference: “Concept Bottleneck Models”
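The mechanics of an intervention can be sketched without any model: replace the predicted values of the concepts in a group with their ground truth, re-run the task predictor, and score the result. The code below is a simplified illustration under stated assumptions: each group is applied independently to the original predictions, accuracy replaces roc_auc_score, a plain Python function replaces the torch.nn.Module predictor and its activation, and intervene and intervention_sketch are hypothetical names.

```python
from typing import Callable, List, Sequence

Row = List[float]


def intervene(c_pred: Sequence[Row], c_true: Sequence[Row], group: Sequence[int]) -> List[Row]:
    """Copy c_pred, overwriting the columns in `group` with ground truth."""
    fixed = [list(row) for row in c_pred]
    for fixed_row, true_row in zip(fixed, c_true):
        for j in group:
            fixed_row[j] = true_row[j]
    return fixed


def intervention_sketch(
    predict: Callable[[Row], int],
    c_pred: Sequence[Row],
    c_true: Sequence[Row],
    y_true: Sequence[int],
    intervention_groups: Sequence[Sequence[int]],
) -> List[float]:
    """Task accuracy after intervening on each group independently (an assumption)."""
    scores = []
    for group in intervention_groups:
        y_hat = [predict(row) for row in intervene(c_pred, c_true, group)]
        scores.append(sum(int(a == b) for a, b in zip(y_hat, y_true)) / len(y_true))
    return scores


def predict(row: Row) -> int:
    """Toy task: the label is 1 only when both concepts are active."""
    return int(row[0] > 0.5 and row[1] > 0.5)


c_true = [[1.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
c_pred = [[0.2, 0.9], [0.8, 0.7], [0.1, 0.6], [0.9, 0.3]]
y_true = [1, 0, 0, 1]
print(intervention_sketch(predict, c_pred, c_true, y_true, [[], [0], [0, 1]]))
# [0.25, 0.5, 1.0]
```

In this toy setup, correcting more concepts steadily improves task accuracy, which is the effect the intervention score is meant to quantify.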
- Parameters:
y_predictor (torch.nn.Module) – Model that predicts downstream task labels.
c_pred (torch.Tensor) – Predicted concept values.
c_true (torch.Tensor) – Ground truth concept values.
y_true (torch.Tensor) – Ground truth task labels.
intervention_groups (List[List[int]]) – List of intervention groups.
activation (Callable) – Activation function to apply to the model’s predictions. Default is torch.sigmoid.
scorer (Callable) – Scoring function to evaluate predictions. Default is roc_auc_score.
average (str) – Type of averaging to use. Default is ‘macro’.
auc (bool) – Whether to return the average score across all intervention groups. Default is True.
- Returns:
The intervention effectiveness for each intervention group, or the average score across all groups.
- Return type:
float | List[float]
- cace_score(y_pred_c0, y_pred_c1)
Compute the Average Causal Effect (ACE) also known as the Causal Concept Effect (CaCE) score.
The ACE/CaCE score measures the causal effect of a concept on the predictions of a model. It is computed as the absolute difference between the expected predictions when the concept is inactive (c0) and active (c1).
Main reference: “Explaining Classifiers with Causal Concept Effect (CaCE)”
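The definition above amounts to averaging each prediction matrix over the batch and taking the absolute difference per class. A minimal sketch with nested Python lists standing in for tensors of shape (batch_size, num_classes) follows. cace_sketch and column_means are hypothetical names, and the library function operates on torch.Tensor inputs instead.

```python
from typing import List, Sequence


def column_means(rows: Sequence[Sequence[float]]) -> List[float]:
    """Mean over the batch dimension, one value per class."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def cace_sketch(
    y_pred_c0: Sequence[Sequence[float]],
    y_pred_c1: Sequence[Sequence[float]],
) -> List[float]:
    """Absolute difference of expected predictions, concept inactive vs. active."""
    mean_c0 = column_means(y_pred_c0)
    mean_c1 = column_means(y_pred_c1)
    return [abs(m1 - m0) for m0, m1 in zip(mean_c0, mean_c1)]


y_pred_c0 = [[0.75, 0.25], [0.25, 0.75]]   # concept inactive, (batch_size=2, num_classes=2)
y_pred_c1 = [[0.25, 0.75], [0.25, 0.75]]   # concept active
print(cace_sketch(y_pred_c0, y_pred_c1))   # [0.25, 0.25]
```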
- Parameters:
y_pred_c0 (torch.Tensor) – Predictions of the model when the concept is inactive. Shape: (batch_size, num_classes).
y_pred_c1 (torch.Tensor) – Predictions of the model when the concept is active. Shape: (batch_size, num_classes).
- Returns:
The ACE/CaCE score for each class. Shape: (num_classes,).
- Return type: