Metrics

Concept-aware metrics with automatic routing and flexible tracking.

Summary

Metrics Classes

ConceptMetrics

Metrics manager for concept-based models with automatic type-aware routing.

Functional Metrics

Class Documentation

class ConceptMetrics(annotations: Annotations, fn_collection: GroupConfig, summary_metrics: bool = True, perconcept_metrics: bool | List[str] = False)[source]

Bases: Module

Metrics manager for concept-based models with automatic type-aware routing.

This class organizes and manages metrics for different concept types (binary, categorical, continuous) with support for both summary metrics (aggregated across all concepts of a type) and per-concept metrics (individual tracking per concept).

The class automatically routes predictions to the appropriate metrics based on concept types defined in the annotations, handles different metric instantiation patterns, and maintains independent metric tracking across train/val/test splits.
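The routing rule can be pictured with a small, dependency-free sketch. It assumes, going by the examples on this page, that a 'discrete' concept with cardinality 1 is treated as binary and one with a larger cardinality as categorical. The helpers `concept_group` and `group_concepts` are hypothetical and not part of the library.

```python
from typing import Dict, List, Sequence


def concept_group(concept_type: str, cardinality: int) -> str:
    """Map annotation metadata to a metric group (assumed rule, see above)."""
    if concept_type == 'continuous':
        # Mirrors the documented behaviour: continuous concepts are unsupported.
        raise NotImplementedError('continuous concepts are not yet supported')
    if concept_type != 'discrete':
        raise ValueError(f'unknown concept type: {concept_type!r}')
    return 'binary' if cardinality == 1 else 'categorical'


def group_concepts(labels: Sequence[str],
                   cardinalities: Sequence[int],
                   metadata: Dict[str, Dict[str, str]]) -> Dict[str, List[str]]:
    """Collect concept names per metric group, keeping annotation order."""
    groups: Dict[str, List[str]] = {'binary': [], 'categorical': []}
    for name, card in zip(labels, cardinalities):
        groups[concept_group(metadata[name]['type'], card)].append(name)
    return groups


groups = group_concepts(
    labels=('binary1', 'binary2', 'category'),
    cardinalities=[1, 1, 5],
    metadata={
        'binary1': {'type': 'discrete'},
        'binary2': {'type': 'discrete'},
        'category': {'type': 'discrete'},
    },
)
print(groups)
```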

Parameters:
  • annotations (Annotations) – Concept annotations containing labels, types, and cardinalities. Should include axis 1 (concept axis) with metadata specifying concept types as ‘discrete’ or ‘continuous’.

  • fn_collection (GroupConfig) –

    Metric configurations organized by concept type (‘binary’, ‘categorical’, ‘continuous’). Each metric can be specified in three ways:

    1. Pre-instantiated metric: Pass an already instantiated metric object for full control over all parameters.

      Example:

      'accuracy': torchmetrics.classification.BinaryAccuracy(threshold=0.6)
      
    2. Class with user kwargs: Pass a tuple of (MetricClass, kwargs_dict) to provide custom parameters while letting ConceptMetrics handle concept-specific parameters like num_classes automatically.

      Example:

      'accuracy': (torchmetrics.classification.MulticlassAccuracy,
                  {'average': 'macro'})
      
    3. Class only: Pass just the metric class and let ConceptMetrics handle all instantiation with appropriate concept-specific parameters.

      Example:

      'accuracy': torchmetrics.classification.MulticlassAccuracy
      

  • summary_metrics (bool, optional) – Whether to compute summary metrics that aggregate performance across all concepts of each type. Defaults to True.

  • perconcept_metrics (Union[bool, List[str]], optional) –

    Controls per-concept metric tracking. Options:

    • False: No per-concept tracking (default)

    • True: Track all concepts individually

    • List[str]: Track only the specified concept names
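As a rough illustration of the three fn_collection specification styles described above, the sketch below resolves each style into a metric instance. `DummyAccuracy` is a stand-in class (not a torchmetrics class) and `resolve_metric` is a hypothetical helper; the actual instantiation logic inside ConceptMetrics may differ.

```python
import inspect
from typing import Any, Dict, Optional


class DummyAccuracy:
    """Stand-in for a metric class, used only for illustration."""

    def __init__(self, num_classes: Optional[int] = None, average: str = 'micro'):
        self.num_classes = num_classes
        self.average = average


def resolve_metric(spec: Any, auto_kwargs: Dict[str, Any]) -> Any:
    """Hypothetical resolver for the three specification styles."""
    if isinstance(spec, tuple):
        # Style 2: (MetricClass, kwargs_dict); automatic kwargs are merged in.
        metric_cls, user_kwargs = spec
        clash = set(user_kwargs) & set(auto_kwargs)
        if clash:
            raise ValueError(f'{sorted(clash)} are set automatically')
        return metric_cls(**user_kwargs, **auto_kwargs)
    if inspect.isclass(spec):
        # Style 3: class only; every parameter is filled in automatically.
        return spec(**auto_kwargs)
    # Style 1: already instantiated; returned untouched.
    return spec


auto = {'num_classes': 5}
pre_instantiated = resolve_metric(DummyAccuracy(average='macro'), auto)
with_kwargs = resolve_metric((DummyAccuracy, {'average': 'macro'}), auto)
class_only = resolve_metric(DummyAccuracy, auto)
```

In this sketch, passing 'num_classes' in the user kwargs raises ValueError, which matches the documented restriction for categorical metrics.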

n_concepts

Total number of concepts

Type:

int

concept_names

Names of all concepts

Type:

Tuple[str]

cardinalities

Number of classes for each concept

Type:

List[int]

summary_metrics

Whether summary metrics are computed

Type:

bool

perconcept_metrics

Per-concept tracking configuration

Type:

Union[bool, List[str]]

train_metrics

Metrics for training split

Type:

MetricCollection

val_metrics

Metrics for validation split

Type:

MetricCollection

test_metrics

Metrics for test split

Type:

MetricCollection

Raises:
  • NotImplementedError – If continuous concepts are found (not yet supported)

  • ValueError – If the metric configuration doesn’t match the concept types, or if the user provides num_classes when it should be set automatically

Example

Basic usage with pre-instantiated metrics:

import torch
import torchmetrics
from torch_concepts import Annotations, AxisAnnotation
from torch_concepts.nn.modules.metrics import ConceptMetrics
from torch_concepts.nn.modules.utils import GroupConfig

# Define concept structure
annotations = Annotations({
    1: AxisAnnotation(
        labels=('round', 'smooth'),
        cardinalities=[1, 1],
        metadata={
            'round': {'type': 'discrete'},
            'smooth': {'type': 'discrete'}
        }
    )
})

# Create metrics with pre-instantiated objects
metrics = ConceptMetrics(
    annotations=annotations,
    fn_collection=GroupConfig(
        binary={
            'accuracy': torchmetrics.classification.BinaryAccuracy(),
            'f1': torchmetrics.classification.BinaryF1Score()
        }
    ),
    summary_metrics=True,
    perconcept_metrics=False
)

# Simulate training batch
predictions = torch.randn(32, 2)  # endogenous predictions
targets = torch.randint(0, 2, (32, 2))  # binary targets

# Update metrics
metrics.update(preds=predictions, target=targets, split='train')

# Compute at epoch end
results = metrics.compute('train')
print(results)  # {'train/SUMMARY-binary_accuracy': ..., 'train/SUMMARY-binary_f1': ...}

# Reset for next epoch
metrics.reset('train')

Using class + kwargs for flexible configuration:

# Mixed concept types with custom metric parameters
annotations = Annotations({
    1: AxisAnnotation(
        labels=('binary1', 'binary2', 'category'),
        cardinalities=[1, 1, 5],
        metadata={
            'binary1': {'type': 'discrete'},
            'binary2': {'type': 'discrete'},
            'category': {'type': 'discrete'}
        }
    )
})

metrics = ConceptMetrics(
    annotations=annotations,
    fn_collection=GroupConfig(
        binary={
            # Custom threshold
            'accuracy': (torchmetrics.classification.BinaryAccuracy,
                       {'threshold': 0.6})
        },
        categorical={
            # Custom averaging, num_classes added automatically
            'accuracy': (torchmetrics.classification.MulticlassAccuracy,
                       {'average': 'macro'})
        }
    ),
    summary_metrics=True,
    perconcept_metrics=True  # Track all concepts individually
)

# Predictions: 2 binary + 5 categorical = 7 dimensions
predictions = torch.randn(16, 7)
targets = torch.cat([
    torch.randint(0, 2, (16, 2)),  # binary
    torch.randint(0, 5, (16, 1))   # categorical
], dim=1)

metrics.update(preds=predictions, target=targets, split='train')
results = metrics.compute('train')

# Results include both summary and per-concept metrics:
# 'train/SUMMARY-binary_accuracy'
# 'train/SUMMARY-categorical_accuracy'
# 'train/binary1_accuracy'
# 'train/binary2_accuracy'
# 'train/category_accuracy'

Selective per-concept tracking:

# Track only specific concepts
metrics = ConceptMetrics(
    annotations=annotations,
    fn_collection=GroupConfig(
        binary={'accuracy': torchmetrics.classification.BinaryAccuracy}
    ),
    summary_metrics=True,
    perconcept_metrics=['binary1']  # Only track binary1 individually
)
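The three accepted values of perconcept_metrics (False, True, or a list of names) can be read as a selection over the concept names. The following is a minimal sketch of that reading; `tracked_concepts` is a hypothetical helper, and rejecting unknown names with ValueError is an assumption rather than documented behaviour.

```python
from typing import List, Sequence, Union


def tracked_concepts(perconcept_metrics: Union[bool, List[str]],
                     concept_names: Sequence[str]) -> List[str]:
    """Return the concept names that would be tracked individually."""
    if perconcept_metrics is True:
        return list(concept_names)      # track every concept
    if perconcept_metrics is False:
        return []                       # no per-concept tracking
    # Otherwise a list of names: keep only the requested concepts.
    unknown = [n for n in perconcept_metrics if n not in concept_names]
    if unknown:
        # Assumption: unknown names are treated as a configuration error.
        raise ValueError(f'unknown concept names: {unknown}')
    return list(perconcept_metrics)


names = ('binary1', 'binary2', 'category')
none_tracked = tracked_concepts(False, names)
all_tracked = tracked_concepts(True, names)
some_tracked = tracked_concepts(['binary1'], names)
```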

Integration with PyTorch Lightning:

import pytorch_lightning as pl

class ConceptModel(pl.LightningModule):
    def __init__(self, annotations):
        super().__init__()
        self.model = ... # your model
        self.metrics = ConceptMetrics(
            annotations=annotations,
            fn_collection=GroupConfig(
                binary={'accuracy': torchmetrics.classification.BinaryAccuracy}
            ),
            summary_metrics=True
        )

    def training_step(self, batch, batch_idx):
        x, concepts = batch
        preds = self.model(x)
        loss = ...  # your loss, computed from preds and concepts

        # Update metrics
        self.metrics.update(preds=preds, target=concepts, split='train')
        return loss

    def on_train_epoch_end(self):
        # Compute and log metrics
        metrics_dict = self.metrics.compute('train')
        self.log_dict(metrics_dict)
        self.metrics.reset('train')

Note

  • Continuous concepts are not yet supported and will raise NotImplementedError

  • For categorical concepts, ConceptMetrics automatically handles padding to the maximum cardinality when computing summary metrics

  • Passing ‘num_classes’ for categorical metrics raises a ValueError, because it is set automatically from the concept cardinalities

  • Each split (train/val/test) maintains independent metric state

__init__(annotations: Annotations, fn_collection: GroupConfig, summary_metrics: bool = True, perconcept_metrics: bool | List[str] = False)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

__repr__() str[source]

Return repr(self).

get(key: str, default=None)[source]

Get a metric collection by key (dict-like interface).

Parameters:
  • key (str) – Collection key (‘train_metrics’, ‘val_metrics’, ‘test_metrics’).

  • default – Default value to return if key not found.

Returns:

MetricCollection or default value.

update(preds: Tensor, target: Tensor, split: str = 'train')[source]

Update metrics with predictions and targets for a given split.

This method automatically routes predictions to the appropriate metrics based on concept types. For summary metrics, it aggregates all concepts of each type. For per-concept metrics, it extracts individual concept predictions.

The preds tensor should be in the endogenous space (after applying the concept distributions’ transformations), and the target tensor should contain the ground truth concept values.

Parameters:
  • preds (torch.Tensor) –

    Model predictions in endogenous space. Shape depends on concept types:

    • Binary concepts: (batch_size, n_binary_concepts)

    • Categorical concepts: (batch_size, sum of cardinalities)

    • Mixed: (batch_size, n_binary + sum of cat cardinalities)

  • target (torch.Tensor) –

    Ground truth concept values. Shape (batch_size, n_concepts) where each column corresponds to a concept:

    • Binary concepts: float values in {0, 1}

    • Categorical concepts: integer class indices in {0, …, cardinality-1}

    • Continuous concepts: float values (not yet supported)

  • split (str, optional) –

    Which data split to update. Must be one of:

    • ’train’: Training split

    • ’val’ or ‘validation’: Validation split

    • ’test’: Test split

    Defaults to ‘train’.

Raises:
  • ValueError – If split is not one of ‘train’, ‘val’, ‘validation’, or ‘test’

  • NotImplementedError – If continuous concepts are encountered
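To make the expected width of preds concrete, the sketch below computes which columns of the endogenous tensor belong to each concept. It assumes that columns are laid out in annotation order, with one column per binary concept and `cardinality` columns per categorical concept; `concept_slices` is a hypothetical helper.

```python
from typing import Dict, Sequence, Tuple


def concept_slices(names: Sequence[str],
                   cardinalities: Sequence[int]) -> Dict[str, Tuple[int, int]]:
    """Return the [start, stop) column range of each concept in preds."""
    slices: Dict[str, Tuple[int, int]] = {}
    start = 0
    for name, card in zip(names, cardinalities):
        slices[name] = (start, start + card)
        start += card
    return slices


# 2 binary concepts + 1 categorical concept with 3 classes
names = ('binary1', 'binary2', 'category')
cardinalities = [1, 1, 3]

slices = concept_slices(names, cardinalities)
preds_width = sum(cardinalities)   # number of columns expected in preds
target_width = len(names)          # one target column per concept
print(slices, preds_width, target_width)
```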

Example

Basic update:

# Binary concepts only
predictions = torch.randn(32, 3)  # 3 binary concepts
targets = torch.randint(0, 2, (32, 3))  # binary ground truth

metrics.update(preds=predictions, target=targets, split='train')

Mixed concept types:

# 2 binary + 1 categorical (3 classes)
# Endogenous space: 2 binary + 3 categorical = 5 dims
predictions = torch.randn(32, 5)
targets = torch.cat([
    torch.randint(0, 2, (32, 2)),  # binary targets
    torch.randint(0, 3, (32, 1))   # categorical target
], dim=1)

metrics.update(preds=predictions, target=targets, split='train')

Validation split:

val_predictions = model(val_data)
metrics.update(preds=val_predictions, target=val_targets, split='val')

Note

  • This method accumulates metric state across multiple batches

  • Call compute() to calculate final metric values

  • Call reset() after computing to start fresh for next epoch

  • Each split maintains independent state
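The accumulate, compute, reset lifecycle can be illustrated with a toy accumulator that keeps separate state per split. `SplitAccuracy` is purely illustrative and is not how ConceptMetrics is implemented; it only mirrors the behaviour listed in the notes above.

```python
class SplitAccuracy:
    """Toy accuracy accumulator with independent state per split."""

    SPLITS = ('train', 'val', 'test')

    def __init__(self):
        self._state = {s: {'correct': 0, 'total': 0} for s in self.SPLITS}

    def update(self, preds, target, split='train'):
        # Accumulate counts; nothing is computed yet.
        state = self._state[split]
        state['correct'] += sum(int(p == t) for p, t in zip(preds, target))
        state['total'] += len(target)

    def compute(self, split='train'):
        # Read-only: calling this does not clear the accumulated state.
        state = self._state[split]
        return state['correct'] / state['total'] if state['total'] else 0.0

    def reset(self, split=None):
        # None resets every split; otherwise only the named one.
        for s in (self.SPLITS if split is None else (split,)):
            self._state[s] = {'correct': 0, 'total': 0}


acc = SplitAccuracy()
acc.update([1, 0, 1, 1], [1, 0, 0, 1], split='train')  # 3 of 4 correct
acc.update([0, 0], [0, 1], split='train')              # 1 of 2 correct
acc.update([1, 1], [1, 1], split='val')                # separate state

train_acc = acc.compute('train')    # accumulated over both train batches
val_acc = acc.compute('val')
acc.reset('train')                  # clears train only
train_after_reset = acc.compute('train')
val_after_reset = acc.compute('val')
```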

compute(split: str = 'train')[source]

Compute final metric values from accumulated state for a split.

This method calculates the final metric values using all data accumulated through update() calls since the last reset(). It does not reset the metric state, allowing you to log results before resetting.

Parameters:

split (str, optional) – Which data split to compute metrics for. Must be one of ‘train’, ‘val’, ‘validation’, or ‘test’. Defaults to ‘train’.

Returns:

Dictionary mapping metric names (with split prefix) to computed values. Keys follow the format:

  • Summary metrics: ‘{split}/SUMMARY-{type}_{metric_name}’

  • Per-concept metrics: ‘{split}/{concept_name}_{metric_name}’

Values are torch.Tensor objects containing the computed metric values.

Return type:

dict

Raises:

ValueError – If split is not one of the valid options
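The key formats listed above are plain strings, so results can be built or filtered with ordinary string handling. The helpers below (`summary_key`, `perconcept_key`, `split_results`) are hypothetical conveniences and not part of the library; the example values are made up and use floats instead of tensors to stay dependency-free.

```python
from typing import Dict, Tuple


def summary_key(split: str, concept_type: str, metric_name: str) -> str:
    return f'{split}/SUMMARY-{concept_type}_{metric_name}'


def perconcept_key(split: str, concept_name: str, metric_name: str) -> str:
    return f'{split}/{concept_name}_{metric_name}'


def split_results(results: Dict[str, float]) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Separate summary entries from per-concept entries by their key."""
    summary, per_concept = {}, {}
    for key, value in results.items():
        _, _, name = key.partition('/')   # drop the '{split}/' prefix
        target = summary if name.startswith('SUMMARY-') else per_concept
        target[key] = value
    return summary, per_concept


# Made-up values shaped like the output of compute('train')
results = {
    summary_key('train', 'binary', 'accuracy'): 0.85,
    perconcept_key('train', 'concept1', 'accuracy'): 0.90,
    perconcept_key('train', 'concept2', 'accuracy'): 0.80,
}
summary, per_concept = split_results(results)
```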

Example

Basic compute:

# After updating with training data
train_results = metrics.compute('train')
print(train_results)
# {
#     'train/SUMMARY-binary_accuracy': tensor(0.8500),
#     'train/SUMMARY-binary_f1': tensor(0.8234),
#     'train/concept1_accuracy': tensor(0.9000),
#     'train/concept2_accuracy': tensor(0.8000)
# }

Compute multiple splits:

train_metrics = metrics.compute('train')
val_metrics = metrics.compute('val')

# Log to wandb or tensorboard
logger.log_metrics(train_metrics)
logger.log_metrics(val_metrics)

Extract specific metrics:

results = metrics.compute('val')
accuracy = results['val/SUMMARY-binary_accuracy'].item()
print(f"Validation accuracy: {accuracy:.2%}")

Note

  • This method can be called multiple times without resetting

  • Always call reset() after logging to start fresh for next epoch

  • Returned tensors are on the same device as the metric state

reset(split: str | None = None)[source]

Reset metric state for one or all splits.

This method resets the accumulated metric state, clearing all data from previous update() calls. Call this after computing and logging metrics to prepare for the next epoch.

Parameters:

split (Optional[str], optional) –

Which split to reset. Options:

  • ’train’: Reset only training metrics

  • ’val’ or ‘validation’: Reset only validation metrics

  • ’test’: Reset only test metrics

  • None: Reset all splits simultaneously (default)

Raises:

ValueError – If split is not None and not a valid split name
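The accepted split names, including the ‘validation’ alias and the None case, can be summarized in a short sketch. `splits_to_reset` is a hypothetical helper that only mirrors the documented options; the library's internal handling may differ.

```python
from typing import List, Optional

# 'validation' is documented as an alias of 'val'.
_CANONICAL = {'train': 'train', 'val': 'val', 'validation': 'val', 'test': 'test'}


def splits_to_reset(split: Optional[str] = None) -> List[str]:
    """Return the canonical split names a reset call would affect."""
    if split is None:
        return ['train', 'val', 'test']   # default: every split
    if split not in _CANONICAL:
        raise ValueError(
            f"split must be one of {sorted(_CANONICAL)} or None, got {split!r}"
        )
    return [_CANONICAL[split]]


all_splits = splits_to_reset()
only_val = splits_to_reset('validation')
only_train = splits_to_reset('train')
```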

Example

Reset single split:

# At end of training epoch
train_metrics = metrics.compute('train')
logger.log_metrics(train_metrics)
metrics.reset('train')  # Reset only training

Reset all splits:

# At end of validation
train_metrics = metrics.compute('train')
val_metrics = metrics.compute('val')
logger.log_metrics({**train_metrics, **val_metrics})
metrics.reset()  # Reset all splits (train, val, and test)

Typical training loop:

for epoch in range(num_epochs):
    # Training
    for x, targets in train_loader:
        preds = model(x)
        metrics.update(preds, targets, split='train')

    # Validation
    for x, targets in val_loader:
        preds = model(x)
        metrics.update(preds, targets, split='val')

    # Compute and log
    train_results = metrics.compute('train')
    val_results = metrics.compute('val')
    log_metrics({**train_results, **val_results})

    # Reset for next epoch
    metrics.reset()  # Resets all splits (train, val, and test)

Note

  • Resetting is essential to avoid mixing data from different epochs

  • Each split can be reset independently

  • Resetting does not affect the metric configuration, only the state

training: bool

Functional Metrics

completeness_score

Calculate the completeness score for the given predictions and true labels.

intervention_score

Compute the effect of concept interventions on downstream task predictions.

cace_score

Compute the Average Causal Effect (ACE), also known as the Causal Concept Effect (CaCE) score.

completeness_score(y_true, y_pred_blackbox, y_pred_whitebox, scorer=<function roc_auc_score>, average='macro')[source]

Calculate the completeness score for the given predictions and true labels.

Main reference: “On Completeness-aware Concept-Based Explanations in Deep Neural Networks”

Parameters:
  • y_true (torch.Tensor) – True labels.

  • y_pred_blackbox (torch.Tensor) – Predictions from the blackbox model.

  • y_pred_whitebox (torch.Tensor) – Predictions from the whitebox model.

  • scorer (function) – Scoring function to evaluate predictions. Default is roc_auc_score.

  • average (str) – Type of averaging to use. Default is ‘macro’.

Returns:

Completeness score.

Return type:

float
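This page does not spell out how the two scores are combined. As a hedged sketch, the referenced paper defines completeness as a baseline-normalized ratio of the concept-based (whitebox) score to the original (blackbox) score. The function `completeness_from_scores` below is hypothetical, takes already-computed scores, and is not necessarily the normalization used by this library.

```python
def completeness_from_scores(whitebox_score: float,
                             blackbox_score: float,
                             baseline: float = 0.0) -> float:
    """Baseline-normalized ratio of whitebox to blackbox performance.

    `baseline` stands for the score of a random predictor (an assumption
    taken from the paper's general form, not from this library).
    """
    denominator = blackbox_score - baseline
    if denominator == 0:
        raise ZeroDivisionError('blackbox score equals the baseline')
    return (whitebox_score - baseline) / denominator


# Made-up scores, e.g. ROC AUC values where 0.5 is the chance level.
score = completeness_from_scores(whitebox_score=0.85,
                                 blackbox_score=0.90,
                                 baseline=0.5)
print(round(score, 3))
```

A value near 1 would indicate that the concept-based model recovers most of the blackbox model's above-baseline performance.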

intervention_score(y_predictor: ~torch.nn.modules.module.Module, c_pred: ~torch.Tensor, c_true: ~torch.Tensor, y_true: ~torch.Tensor, intervention_groups: ~typing.List[~typing.List[int]], activation: ~typing.Callable = <built-in method sigmoid of type object>, scorer: ~typing.Callable = <function roc_auc_score>, average: str = 'macro', auc: bool = True) float | List[float][source]

Compute the effect of concept interventions on downstream task predictions.

Given a set of intervention groups, the intervention score measures the effect of each intervention group on the model’s task predictions.

Main reference: “Concept Bottleneck Models”

Parameters:
  • y_predictor (torch.nn.Module) – Model that predicts downstream task labels.

  • c_pred (torch.Tensor) – Predicted concept values.

  • c_true (torch.Tensor) – Ground truth concept values.

  • y_true (torch.Tensor) – Ground truth task labels.

  • intervention_groups (List[List[int]]) – List of intervention groups.

  • activation (Callable) – Activation function to apply to the model’s predictions. Default is torch.sigmoid.

  • scorer (Callable) – Scoring function to evaluate predictions. Default is roc_auc_score.

  • average (str) – Type of averaging to use. Default is ‘macro’.

  • auc (bool) – Whether to return the average score across all intervention groups. Default is True.

Returns:

The intervention effectiveness for each intervention group or the average score across all groups.

Return type:

Union[float, List[float]]
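The sketch below illustrates the two ideas behind this function on plain Python lists. It assumes, following the Concept Bottleneck Models paper, that intervening on a group means replacing the predicted values of those concepts with their ground truth. Whether the library applies groups independently or cumulatively is not stated here, and `intervene` and `aggregate_scores` are hypothetical helpers.

```python
from typing import List, Sequence, Union


def intervene(c_pred: Sequence[Sequence[float]],
              c_true: Sequence[Sequence[float]],
              group: Sequence[int]) -> List[List[float]]:
    """Replace predicted concepts at the indices in `group` with ground truth."""
    indices = set(group)
    return [
        [t if j in indices else p for j, (p, t) in enumerate(zip(pred_row, true_row))]
        for pred_row, true_row in zip(c_pred, c_true)
    ]


def aggregate_scores(scores: Sequence[float], auc: bool = True) -> Union[float, List[float]]:
    """Average the per-group scores when `auc` is True, else return them all."""
    return sum(scores) / len(scores) if auc else list(scores)


c_pred = [[0.2, 0.9, 0.4]]
c_true = [[1.0, 1.0, 0.0]]
c_intervened = intervene(c_pred, c_true, group=[0, 2])

# Made-up task scores obtained after intervening on three groups.
per_group = aggregate_scores([0.7, 0.8, 0.9], auc=False)
averaged = aggregate_scores([0.7, 0.8, 0.9], auc=True)
```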

cace_score(y_pred_c0, y_pred_c1)[source]

Compute the Average Causal Effect (ACE), also known as the Causal Concept Effect (CaCE) score.

The ACE/CaCE score measures the causal effect of a concept on the predictions of a model. It is computed as the absolute difference between the expected predictions when the concept is inactive (c0) and active (c1).

Main reference: “Explaining Classifiers with Causal Concept Effect (CaCE)”

Parameters:
  • y_pred_c0 (torch.Tensor) – Predictions of the model when the concept is inactive. Shape: (batch_size, num_classes).

  • y_pred_c1 (torch.Tensor) – Predictions of the model when the concept is active. Shape: (batch_size, num_classes).

Returns:

The ACE/CaCE score for each class. Shape: (num_classes,).

Return type:

torch.Tensor
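The description above amounts to a per-class absolute difference of batch means. The following dependency-free sketch applies that computation to nested lists instead of tensors; `cace_from_lists` is a hypothetical name, and the library function itself operates on torch.Tensor inputs.

```python
from typing import List, Sequence


def cace_from_lists(y_pred_c0: Sequence[Sequence[float]],
                    y_pred_c1: Sequence[Sequence[float]]) -> List[float]:
    """Per-class |E[y | c=1] - E[y | c=0]|, with expectations taken over the batch."""
    num_classes = len(y_pred_c0[0])

    def class_means(rows: Sequence[Sequence[float]]) -> List[float]:
        return [sum(row[k] for row in rows) / len(rows) for k in range(num_classes)]

    mean_c0 = class_means(y_pred_c0)
    mean_c1 = class_means(y_pred_c1)
    return [abs(m1 - m0) for m0, m1 in zip(mean_c0, mean_c1)]


# Shape (batch_size=2, num_classes=2) for both inputs
y_pred_c0 = [[0.2, 0.8], [0.4, 0.6]]   # concept inactive
y_pred_c1 = [[0.9, 0.1], [0.7, 0.3]]   # concept active
effect = cace_from_lists(y_pred_c0, y_pred_c1)
print([round(e, 3) for e in effect])
```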