# Source code for torch_concepts.nn.modules.metrics

"""
Metrics module for concept-based model evaluation.

This module provides the :class:`ConceptMetrics` class for evaluating concept-based models
with automatic handling of different concept types (binary, categorical, continuous).
It integrates seamlessly with TorchMetrics and PyTorch Lightning, providing flexible
metric tracking at both aggregate and per-concept levels.

Key Features:
    - Automatic routing of concept predictions to appropriate metrics based on type
    - Summary metrics: aggregated performance across all concepts of each type
    - Per-concept metrics: individual tracking for specific concepts
    - Flexible metric specification: pre-instantiated, class+kwargs, or class-only
    - Independent tracking across train/validation/test splits
    - Integration with PyTorch Lightning training loops

Classes:
    ConceptMetrics: Main metrics manager for concept-based models

Example:
    Basic usage with binary and categorical concepts::

        import torch
        import torchmetrics
        from torch_concepts import Annotations, AxisAnnotation
        from torch_concepts.nn.modules.metrics import ConceptMetrics
        from torch_concepts.nn.modules.utils import GroupConfig

        # Define concept structure
        annotations = Annotations({
            1: AxisAnnotation(
                labels=['is_round', 'is_smooth', 'color'],
                cardinalities=[1, 1, 3],  # binary, binary, categorical
                metadata={
                    'is_round': {'type': 'discrete'},
                    'is_smooth': {'type': 'discrete'},
                    'color': {'type': 'discrete'}
                }
            )
        })

        # Configure metrics
        metrics = ConceptMetrics(
            annotations=annotations,
            fn_collection=GroupConfig(
                binary={'accuracy': torchmetrics.classification.BinaryAccuracy()},
                categorical={'accuracy': torchmetrics.classification.MulticlassAccuracy}
            ),
            summary_metrics=True,
            perconcept_metrics=True
        )

        # During training
        predictions = torch.randn(32, 5)  # 2 binary + 3 categorical (endogenous space)
        targets = torch.cat([
            torch.randint(0, 2, (32, 2)),  # binary concepts
            torch.randint(0, 3, (32, 1))   # categorical concept
        ], dim=1)

        metrics.update(preds=predictions, target=targets, split='train')
        results = metrics.compute('train')
        metrics.reset('train')

See Also:
    - :doc:`/guides/using_metrics`: Comprehensive guide to using metrics
    - :doc:`/modules/nn.loss`: Loss functions for concept-based models
    - :class:`torch_concepts.nn.modules.utils.GroupConfig`: Metric configuration helper
"""
from typing import Optional, Union, List
import torch
from torch import nn
from torchmetrics import Metric, MetricCollection
from torchmetrics.collections import _remove_prefix
from yaml import warnings

from ...annotations import Annotations, AxisAnnotation
from ...nn.modules.utils import GroupConfig
from ...nn.modules.utils import check_collection, get_concept_groups


class ConceptMetrics(nn.Module):
    """Metrics manager for concept-based models with automatic type-aware routing.

    This class organizes and manages metrics for different concept types
    (binary, categorical, continuous) with support for both summary metrics
    (aggregated across all concepts of a type) and per-concept metrics
    (individual tracking per concept).

    The class automatically routes predictions to the appropriate metrics based
    on concept types defined in the annotations, handles different metric
    instantiation patterns, and maintains independent metric tracking across
    train/val/test splits.

    Args:
        annotations (Annotations): Concept annotations containing labels, types,
            and cardinalities. Should include axis 1 (concept axis) with
            metadata specifying concept types as 'discrete' or 'continuous'.
        fn_collection (GroupConfig): Metric configurations organized by concept
            type ('binary', 'categorical', 'continuous'). Each metric can be
            specified in three ways:

            1. **Pre-instantiated metric**: Pass an already instantiated metric
               object for full control over all parameters. Example::

                   'accuracy': torchmetrics.classification.BinaryAccuracy(threshold=0.6)

            2. **Class with user kwargs**: Pass a tuple of
               (MetricClass, kwargs_dict) to provide custom parameters while
               letting ConceptMetrics handle concept-specific parameters like
               num_classes automatically. Example::

                   'accuracy': (torchmetrics.classification.MulticlassAccuracy,
                                {'average': 'macro'})

            3. **Class only**: Pass just the metric class and let
               ConceptMetrics handle all instantiation with appropriate
               concept-specific parameters. Example::

                   'accuracy': torchmetrics.classification.MulticlassAccuracy

        summary_metrics (bool, optional): Whether to compute summary metrics
            that aggregate performance across all concepts of each type.
            Defaults to True.
        perconcept_metrics (Union[bool, List[str]], optional): Controls
            per-concept metric tracking. Options:

            - False: No per-concept tracking (default)
            - True: Track all concepts individually
            - List[str]: Track only the specified concept names

    Attributes:
        n_concepts (int): Total number of concepts
        concept_names (Tuple[str]): Names of all concepts
        cardinalities (List[int]): Number of classes for each concept
        summary_metrics (bool): Whether summary metrics are computed
        perconcept_metrics (Union[bool, List[str]]): Per-concept tracking
            configuration
        max_card (Optional[int]): Largest cardinality among categorical
            concepts, or None when no categorical metrics are configured
        train_metrics (MetricCollection): Metrics for training split
        val_metrics (MetricCollection): Metrics for validation split
        test_metrics (MetricCollection): Metrics for test split

    Raises:
        NotImplementedError: If continuous concepts are found (not yet supported)
        ValueError: If metric configuration doesn't match concept types, if
            user provides num_classes when it should be set automatically, or
            if ``perconcept_metrics`` names a concept that does not exist

    Example:
        **Basic usage with pre-instantiated metrics**::

            import torch
            import torchmetrics
            from torch_concepts import Annotations, AxisAnnotation
            from torch_concepts.nn.modules.metrics import ConceptMetrics
            from torch_concepts.nn.modules.utils import GroupConfig

            # Define concept structure
            annotations = Annotations({
                1: AxisAnnotation(
                    labels=('round', 'smooth'),
                    cardinalities=[1, 1],
                    metadata={
                        'round': {'type': 'discrete'},
                        'smooth': {'type': 'discrete'}
                    }
                )
            })

            # Create metrics with pre-instantiated objects
            metrics = ConceptMetrics(
                annotations=annotations,
                fn_collection=GroupConfig(
                    binary={
                        'accuracy': torchmetrics.classification.BinaryAccuracy(),
                        'f1': torchmetrics.classification.BinaryF1Score()
                    }
                ),
                summary_metrics=True,
                perconcept_metrics=False
            )

            # Simulate training batch
            predictions = torch.randn(32, 2)  # endogenous predictions
            targets = torch.randint(0, 2, (32, 2))  # binary targets

            # Update metrics
            metrics.update(preds=predictions, target=targets, split='train')

            # Compute at epoch end
            results = metrics.compute('train')
            print(results)
            # {'train/SUMMARY-binary_accuracy': ..., 'train/SUMMARY-binary_f1': ...}

            # Reset for next epoch
            metrics.reset('train')

        **Using class + kwargs for flexible configuration**::

            # Mixed concept types with custom metric parameters
            annotations = Annotations({
                1: AxisAnnotation(
                    labels=('binary1', 'binary2', 'category'),
                    cardinalities=[1, 1, 5],
                    metadata={
                        'binary1': {'type': 'discrete'},
                        'binary2': {'type': 'discrete'},
                        'category': {'type': 'discrete'}
                    }
                )
            })

            metrics = ConceptMetrics(
                annotations=annotations,
                fn_collection=GroupConfig(
                    binary={
                        # Custom threshold
                        'accuracy': (torchmetrics.classification.BinaryAccuracy,
                                     {'threshold': 0.6})
                    },
                    categorical={
                        # Custom averaging, num_classes added automatically
                        'accuracy': (torchmetrics.classification.MulticlassAccuracy,
                                     {'average': 'macro'})
                    }
                ),
                summary_metrics=True,
                perconcept_metrics=True  # Track all concepts individually
            )

            # Predictions: 2 binary + 5 categorical = 7 dimensions
            predictions = torch.randn(16, 7)
            targets = torch.cat([
                torch.randint(0, 2, (16, 2)),  # binary
                torch.randint(0, 5, (16, 1))   # categorical
            ], dim=1)

            metrics.update(preds=predictions, target=targets, split='train')
            results = metrics.compute('train')

            # Results include both summary and per-concept metrics:
            # 'train/SUMMARY-binary_accuracy'
            # 'train/SUMMARY-categorical_accuracy'
            # 'train/binary1_accuracy'
            # 'train/binary2_accuracy'
            # 'train/category_accuracy'

        **Selective per-concept tracking**::

            # Track only specific concepts
            metrics = ConceptMetrics(
                annotations=annotations,
                fn_collection=GroupConfig(
                    binary={'accuracy': torchmetrics.classification.BinaryAccuracy}
                ),
                summary_metrics=True,
                perconcept_metrics=['binary1']  # Only track binary1 individually
            )

        **Integration with PyTorch Lightning**::

            import pytorch_lightning as pl

            class ConceptModel(pl.LightningModule):
                def __init__(self, annotations):
                    super().__init__()
                    self.model = ...  # your model
                    self.metrics = ConceptMetrics(
                        annotations=annotations,
                        fn_collection=GroupConfig(
                            binary={'accuracy': torchmetrics.classification.BinaryAccuracy}
                        ),
                        summary_metrics=True
                    )

                def training_step(self, batch, batch_idx):
                    x, concepts = batch
                    preds = self.model(x)
                    loss = ...  # compute your loss

                    # Update metrics
                    self.metrics.update(preds=preds, target=concepts, split='train')
                    return loss

                def on_train_epoch_end(self):
                    # Compute and log metrics
                    metrics_dict = self.metrics.compute('train')
                    self.log_dict(metrics_dict)
                    self.metrics.reset('train')

    Note:
        - Continuous concepts are not yet supported and will raise
          NotImplementedError
        - For categorical concepts, ConceptMetrics automatically handles
          padding to the maximum cardinality when computing summary metrics
        - User-provided 'num_classes' parameter for categorical metrics will
          raise an error as it's set automatically based on concept
          cardinalities
        - Each split (train/val/test) maintains independent metric state

    See Also:
        - :class:`torch_concepts.nn.modules.utils.GroupConfig`: Configuration helper
        - :class:`torch_concepts.annotations.Annotations`: Concept annotations
        - `TorchMetrics Documentation <https://lightning.ai/docs/torchmetrics>`_:
          Available metrics and their parameters
    """

    def __init__(
        self,
        annotations: Annotations,
        fn_collection: GroupConfig,
        summary_metrics: bool = True,
        perconcept_metrics: Union[bool, List[str]] = False
    ):
        """Validate the configuration and build the per-split metric collections.

        Args:
            annotations (Annotations): Concept annotations; only axis 1 is used.
            fn_collection (GroupConfig): Metric specifications per concept type.
            summary_metrics (bool): Whether to create summary metrics.
            perconcept_metrics (Union[bool, List[str]]): Per-concept tracking
                configuration (see the class documentation).

        Raises:
            NotImplementedError: If the annotations contain continuous concepts.
            ValueError: If ``perconcept_metrics`` is neither a bool nor a
                list/tuple of known concept names.
        """
        super().__init__()

        self.summary_metrics = summary_metrics
        self.perconcept_metrics = perconcept_metrics

        # Extract and validate annotations
        annotations = annotations.get_axis_annotation(axis=1)
        self.concept_annotations = annotations
        self.concept_names = annotations.labels
        self.n_concepts = len(self.concept_names)
        self.cardinalities = annotations.cardinalities
        self.metadata = annotations.metadata
        self.types = [self.metadata[name]['type'] for name in self.concept_names]

        # Get concept groups
        self.groups = get_concept_groups(annotations)

        # Validate that continuous concepts are not used
        if self.groups['continuous_labels']:
            raise NotImplementedError(
                f"Continuous concepts are not yet supported. "
                f"Found continuous concepts: {self.groups['continuous_labels']}."
            )

        # Validate and filter metrics configuration
        self.fn_collection = check_collection(annotations, fn_collection, 'metrics')

        # Pre-compute max cardinality for categorical concepts. The attribute is
        # always defined so that readers never hit an AttributeError.
        self.max_card = None
        if self.fn_collection.get('categorical'):
            self.max_card = max(
                self.cardinalities[i] for i in self.groups['categorical_idx']
            )

        # Setup metric collections
        self._setup_metric_collections()

    def __repr__(self) -> str:
        """Return a compact summary of the configured metrics."""

        def spec_name(spec):
            # A spec is a Metric instance, a (class, kwargs) pair, or a class.
            if isinstance(spec, Metric):
                return spec.__class__.__name__
            if isinstance(spec, (tuple, list)):
                return spec[0].__name__
            return spec.__name__

        metric_info = {
            group: [spec_name(spec) for spec in specs.values()]
            for group, specs in self.fn_collection.items()
            if specs
        }
        metrics_str = ', '.join(
            f"{group}=[{','.join(names)}]" for group, names in metric_info.items()
        )
        return (f"{self.__class__.__name__}(n_concepts={self.n_concepts}, "
                f"metrics={{{metrics_str}}}, summary={self.summary_metrics}, "
                f"perconcept={self.perconcept_metrics})")

    @staticmethod
    def _clone_metric(metric):
        """Clone and reset a metric for independent tracking across splits."""
        metric = metric.clone()
        metric.reset()
        return metric

    def _instantiate_metric(self, metric_spec, concept_specific_kwargs=None):
        """Instantiate a metric from either an instance or a class+kwargs tuple/list.

        Args:
            metric_spec: Either a Metric instance, a tuple/list
                (MetricClass, kwargs_dict), or a MetricClass (will be
                instantiated with concept_specific_kwargs only).
            concept_specific_kwargs (dict): Concept-specific parameters to
                merge with user kwargs.

        Returns:
            Metric: Instantiated metric.

        Raises:
            ValueError: If user provides 'num_classes' in kwargs (it's set
                automatically).
            TypeError: If a tuple/list spec does not have exactly two elements.
        """
        auto_kwargs = concept_specific_kwargs or {}

        if isinstance(metric_spec, Metric):
            # Already instantiated
            return metric_spec

        if isinstance(metric_spec, (tuple, list)):
            if len(metric_spec) != 2:
                # Previously this fell through to calling the tuple itself,
                # which failed with an opaque "'tuple' object is not callable".
                raise TypeError(
                    f"A metric specification given as a tuple/list must be "
                    f"(MetricClass, kwargs_dict); got {len(metric_spec)} element(s)."
                )
            metric_class, user_kwargs = metric_spec
            user_kwargs = user_kwargs or {}

            # Check if user provided num_classes when it will be set automatically
            if 'num_classes' in user_kwargs and 'num_classes' in auto_kwargs:
                raise ValueError(
                    f"'num_classes' should not be provided in metric kwargs. "
                    f"ConceptMetrics automatically sets 'num_classes' based on concept cardinality."
                )
            # User kwargs take precedence over the automatically derived ones.
            return metric_class(**{**auto_kwargs, **user_kwargs})

        # Just a class, use concept_specific_kwargs only
        return metric_spec(**auto_kwargs)

    @staticmethod
    def _metric_group(c_type, card):
        """Map a concept's (type, cardinality) to its metric group name.

        Returns 'binary', 'categorical', 'continuous', or None when the
        combination is not recognized.
        """
        if c_type == 'discrete':
            if card == 1:
                return 'binary'
            if card > 1:
                return 'categorical'
            return None
        if c_type == 'continuous':
            return 'continuous'
        return None

    def _concepts_to_track(self):
        """Return the concept names that need per-concept metrics.

        Raises:
            ValueError: If ``perconcept_metrics`` has an unsupported type or
                names a concept that is not in ``concept_names``.
        """
        tracked = self.perconcept_metrics
        if not tracked:
            return []
        if isinstance(tracked, bool):
            return list(self.concept_names)
        if isinstance(tracked, (list, tuple)):
            unknown = [name for name in tracked if name not in self.concept_names]
            if unknown:
                raise ValueError(
                    f"perconcept_metrics contains unknown concept names: {unknown}. "
                    f"Available concepts: {list(self.concept_names)}."
                )
            return list(tracked)
        raise ValueError(
            "perconcept_metrics must be either a bool or a list of concept names."
        )

    def _build_split_collection(self, all_metrics, prefix):
        """Create one MetricCollection holding fresh clones of ``all_metrics``."""
        if not all_metrics:
            return MetricCollection({})
        return MetricCollection(
            metrics={k: self._clone_metric(m) for k, m in all_metrics.items()},
            prefix=prefix
        )

    def _setup_metric_collections(self):
        """Setup MetricCollections for train/val/test splits.

        Creates metric collections with appropriate prefixes and cloned metrics
        for each split to ensure independent tracking. Also records which
        concept every per-concept key belongs to, so that :meth:`update` does
        not have to recover the concept name by parsing the key (which is
        ambiguous when both concept and metric names contain underscores).
        """
        # Build dictionary of all metrics (summary + per-concept)
        all_metrics = {}
        key_to_concept = {}

        # Add summary metrics
        if self.summary_metrics:
            # Only categorical metrics need an automatically derived argument.
            summary_kwargs = {
                'binary': None,
                'categorical': {'num_classes': self.max_card},
                'continuous': None,
            }
            for group, auto_kwargs in summary_kwargs.items():
                specs = self.fn_collection.get(group)
                if not specs:
                    continue
                for metric_name, metric_spec in specs.items():
                    key = f"SUMMARY-{group}_{metric_name}"
                    all_metrics[key] = self._instantiate_metric(
                        metric_spec,
                        concept_specific_kwargs=auto_kwargs
                    )

        # Add per-concept metrics
        for concept_name in self._concepts_to_track():
            c_idx = self.concept_names.index(concept_name)
            card = self.cardinalities[c_idx]
            group = self._metric_group(self.types[c_idx], card)
            if group is None:
                continue

            metrics_dict = self.fn_collection.get(group, {}) or {}
            concept_kwargs = {'num_classes': card} if group == 'categorical' else {}

            for metric_name, metric_spec in metrics_dict.items():
                key = f"{concept_name}_{metric_name}"
                all_metrics[key] = self._instantiate_metric(
                    metric_spec,
                    concept_specific_kwargs=concept_kwargs
                )
                key_to_concept[key] = concept_name

        self._metric_key_to_concept = key_to_concept

        # Create MetricCollections for each split with cloned metrics
        self.train_metrics = self._build_split_collection(all_metrics, "train/")
        self.val_metrics = self._build_split_collection(all_metrics, "val/")
        self.test_metrics = self._build_split_collection(all_metrics, "test/")

    def get(self, key: str, default=None):
        """Get a metric collection by key (dict-like interface).

        Args:
            key (str): Collection key ('train_metrics', 'val_metrics',
                'test_metrics').
            default: Default value to return if key not found.

        Returns:
            MetricCollection or default value.
        """
        collections = {
            'train_metrics': self.train_metrics,
            'val_metrics': self.val_metrics,
            'test_metrics': self.test_metrics
        }
        return collections.get(key, default)

    def _get_collection(self, split: str) -> MetricCollection:
        """Get the metric collection for a specific split.

        Args:
            split (str): One of 'train', 'val', 'validation', or 'test'.

        Returns:
            MetricCollection: The collection for the specified split.

        Raises:
            ValueError: If ``split`` is not a recognized split name.
        """
        if split == 'train':
            return self.train_metrics
        if split in ('val', 'validation'):
            return self.val_metrics
        if split == 'test':
            return self.test_metrics
        raise ValueError(
            f"Unknown split: {split}. Must be 'train', 'val', 'validation', or 'test'."
        )

    def _binary_summary_inputs(self, preds, target):
        """Select the columns of all binary concepts for summary metrics."""
        binary_pred = preds[:, self.groups['binary_endogenous_idx']]
        binary_target = target[:, self.groups['binary_idx']].float()
        return binary_pred, binary_target

    def _categorical_summary_inputs(self, preds, target):
        """Stack all categorical concepts into one padded classification batch.

        Each concept's block of endogenous values is right-padded with ``-inf``
        up to ``max_card`` and the blocks are concatenated along the batch
        dimension; targets are flattened in the same concept-major order.
        """
        cat_idx = self.groups['categorical_idx']
        blocks = torch.split(
            preds[:, self.groups['categorical_endogenous_idx']],
            [self.cardinalities[i] for i in cat_idx],
            dim=1
        )
        padded = [
            nn.functional.pad(
                block,
                (0, self.max_card - block.shape[1]),
                value=float('-inf')
            )
            for block in blocks
        ]
        cat_pred = torch.cat(padded, dim=0)
        cat_target = target[:, cat_idx].T.reshape(-1).long()
        return cat_pred, cat_target

    def update(self, preds: torch.Tensor, target: torch.Tensor, split: str = 'train'):
        """Update metrics with predictions and targets for a given split.

        This method automatically routes predictions to the appropriate metrics
        based on concept types. For summary metrics, it aggregates all concepts
        of each type. For per-concept metrics, it extracts individual concept
        predictions.

        The preds tensor should be in the endogenous space (after applying the
        concept distributions' transformations), and the target tensor should
        contain the ground truth concept values.

        Args:
            preds (torch.Tensor): Model predictions in endogenous space. Shape
                depends on concept types:

                - Binary concepts: (batch_size, n_binary_concepts)
                - Categorical concepts: (batch_size, sum of cardinalities)
                - Mixed: (batch_size, n_binary + sum of cat cardinalities)
            target (torch.Tensor): Ground truth concept values. Shape
                (batch_size, n_concepts) where each column corresponds to a
                concept:

                - Binary concepts: float values in {0, 1}
                - Categorical concepts: integer class indices in
                  {0, ..., cardinality-1}
                - Continuous concepts: float values (not yet supported)
            split (str, optional): Which data split to update. Must be one of:

                - 'train': Training split
                - 'val' or 'validation': Validation split
                - 'test': Test split

                Defaults to 'train'.

        Raises:
            ValueError: If split is not one of 'train', 'val', 'validation',
                or 'test'
            NotImplementedError: If continuous concepts are encountered

        Example:
            **Basic update**::

                # Binary concepts only
                predictions = torch.randn(32, 3)  # 3 binary concepts
                targets = torch.randint(0, 2, (32, 3))  # binary ground truth

                metrics.update(preds=predictions, target=targets, split='train')

            **Mixed concept types**::

                # 2 binary + 1 categorical (3 classes)
                # Endogenous space: 2 binary + 3 categorical = 5 dims
                predictions = torch.randn(32, 5)
                targets = torch.cat([
                    torch.randint(0, 2, (32, 2)),  # binary targets
                    torch.randint(0, 3, (32, 1))   # categorical target
                ], dim=1)

                metrics.update(preds=predictions, target=targets, split='train')

            **Validation split**::

                val_predictions = model(val_data)
                metrics.update(preds=val_predictions, target=val_targets, split='val')

        Note:
            - This method accumulates metric state across multiple batches
            - Call :meth:`compute` to calculate final metric values
            - Call :meth:`reset` after computing to start fresh for next epoch
            - Each split maintains independent state
        """
        # Skip empty batches to avoid errors in underlying metric libraries
        if preds.shape[0] == 0:
            return

        metric_collection = self._get_collection(split)
        prefix = metric_collection.prefix

        # Summary inputs are identical for every metric of a group, so they are
        # built at most once per call instead of once per metric.
        binary_inputs = None
        categorical_inputs = None

        for key in metric_collection:
            # Normalize to the un-prefixed key used when the metrics were created.
            base_key = key[len(prefix):] if prefix and key.startswith(prefix) else key

            # Update summary metrics
            if self.summary_metrics:
                if base_key.startswith('SUMMARY-binary_') and self.groups['binary_labels']:
                    if binary_inputs is None:
                        binary_inputs = self._binary_summary_inputs(preds, target)
                    metric_collection[key].update(*binary_inputs)
                    continue

                if base_key.startswith('SUMMARY-categorical_') and self.groups['categorical_labels']:
                    if categorical_inputs is None:
                        categorical_inputs = self._categorical_summary_inputs(preds, target)
                    metric_collection[key].update(*categorical_inputs)
                    continue

                if base_key.startswith('SUMMARY-continuous_') and self.groups['continuous_labels']:
                    raise NotImplementedError("Continuous concepts not yet implemented.")

            # Update per-concept metrics
            if self.perconcept_metrics:
                # Look the concept up instead of parsing the key: splitting on
                # '_' is ambiguous when names themselves contain underscores.
                concept_name = self._metric_key_to_concept.get(base_key)
                if concept_name is None:
                    # Not a per-concept key (e.g. a summary key with no
                    # matching concepts); nothing to update.
                    continue

                endogenous_idx = self.concept_annotations.get_endogenous_idx([concept_name])
                c_idx = self.concept_annotations.get_index(concept_name)
                c_type = self.types[c_idx]
                group = self._metric_group(c_type, self.cardinalities[c_idx])

                if group == 'binary':
                    concept_target = target[:, c_idx:c_idx + 1].float()
                elif group == 'categorical':
                    concept_target = target[:, c_idx].long()
                elif group == 'continuous':
                    concept_target = target[:, c_idx:c_idx + 1]
                else:
                    raise ValueError(
                        f"ConceptMetrics.update(): Unknown concept "
                        f"type '{c_type}' for concept '{concept_name}'."
                    )

                metric_collection[key].update(preds[:, endogenous_idx], concept_target)

    def compute(self, split: str = 'train'):
        """Compute final metric values from accumulated state for a split.

        This method calculates the final metric values using all data
        accumulated through :meth:`update` calls since the last :meth:`reset`.
        It does not reset the metric state, allowing you to log results before
        resetting.

        Args:
            split (str, optional): Which data split to compute metrics for.
                Must be one of 'train', 'val', 'validation', or 'test'.
                Defaults to 'train'.

        Returns:
            dict: Dictionary mapping metric names (with split prefix) to
                computed values. Keys follow the format:

                - Summary metrics: '{split}/SUMMARY-{type}_{metric_name}'
                - Per-concept metrics: '{split}/{concept_name}_{metric_name}'

                Values are torch.Tensor objects containing the computed metric
                values.

        Raises:
            ValueError: If split is not one of the valid options

        Example:
            **Basic compute**::

                # After updating with training data
                train_results = metrics.compute('train')
                print(train_results)
                # {
                #     'train/SUMMARY-binary_accuracy': tensor(0.8500),
                #     'train/SUMMARY-binary_f1': tensor(0.8234),
                #     'train/concept1_accuracy': tensor(0.9000),
                #     'train/concept2_accuracy': tensor(0.8000)
                # }

            **Compute multiple splits**::

                train_metrics = metrics.compute('train')
                val_metrics = metrics.compute('val')

                # Log to wandb or tensorboard
                logger.log_metrics(train_metrics)
                logger.log_metrics(val_metrics)

            **Extract specific metrics**::

                results = metrics.compute('val')
                accuracy = results['val/SUMMARY-binary_accuracy'].item()
                print(f"Validation accuracy: {accuracy:.2%}")

        Note:
            - This method can be called multiple times without resetting
            - Always call :meth:`reset` after logging to start fresh for next
              epoch
            - Returned tensors are on the same device as the metric state
        """
        return self._get_collection(split).compute()

    def reset(self, split: Optional[str] = None):
        """Reset metric state for one or all splits.

        This method resets the accumulated metric state, clearing all data from
        previous :meth:`update` calls. Call this after computing and logging
        metrics to prepare for the next epoch.

        Args:
            split (Optional[str], optional): Which split to reset. Options:

                - 'train': Reset only training metrics
                - 'val' or 'validation': Reset only validation metrics
                - 'test': Reset only test metrics
                - None: Reset all splits simultaneously (default)

        Raises:
            ValueError: If split is not None and not a valid split name

        Example:
            **Reset single split**::

                # At end of training epoch
                train_metrics = metrics.compute('train')
                logger.log_metrics(train_metrics)
                metrics.reset('train')  # Reset only training

            **Reset all splits**::

                # At end of validation
                train_metrics = metrics.compute('train')
                val_metrics = metrics.compute('val')
                logger.log_metrics({**train_metrics, **val_metrics})
                metrics.reset()  # Reset both train and val

            **Typical training loop**::

                for epoch in range(num_epochs):
                    # Training
                    for batch in train_loader:
                        preds = model(batch)
                        metrics.update(preds, targets, split='train')

                    # Validation
                    for batch in val_loader:
                        preds = model(batch)
                        metrics.update(preds, targets, split='val')

                    # Compute and log
                    train_results = metrics.compute('train')
                    val_results = metrics.compute('val')
                    log_metrics({**train_results, **val_results})

                    # Reset for next epoch
                    metrics.reset()  # Resets both train and val

        Note:
            - Resetting is essential to avoid mixing data from different epochs
            - Each split can be reset independently
            - Resetting does not affect the metric configuration, only the state
        """
        if split is None:
            self.train_metrics.reset()
            self.val_metrics.reset()
            self.test_metrics.reset()
        else:
            self._get_collection(split).reset()
# class ConceptCausalEffect(Metric): # """ # Concept Causal Effect (CaCE) metric for measuring causal effects. # # CaCE measures the causal effect between concept pairs or between a concept # and the task by comparing predictions under interventions do(C=1) vs do(C=0). # # Note: Currently only works on binary concepts. # # Attributes: # preds_do_1 (Tensor): Accumulated predictions under do(C=1). # preds_do_0 (Tensor): Accumulated predictions under do(C=0). # total (Tensor): Total number of samples processed. # # Example: # >>> import torch # >>> from torch_concepts.nn.modules.metrics import ConceptCausalEffect # >>> # >>> # Create metric # >>> cace = ConceptCausalEffect() # >>> # >>> # Update with predictions under interventions # >>> preds_do_1 = torch.tensor([[0.1, 0.9], [0.2, 0.8]]) # P(Y|do(C=1)) # >>> preds_do_0 = torch.tensor([[0.8, 0.2], [0.7, 0.3]]) # P(Y|do(C=0)) # >>> cace.update(preds_do_1, preds_do_0) # >>> # >>> # Compute causal effect # >>> effect = cace.compute() # >>> print(f"Causal effect: {effect:.3f}") # # References: # Goyal et al. "Explaining Classifiers with Causal Concept Effect (CaCE)", # arXiv 2019. https://arxiv.org/abs/1907.07165 # """ # def __init__(self): # super().__init__() # self.add_state("preds_do_1", default=torch.tensor(0.), dist_reduce_fx="sum") # self.add_state("preds_do_0", default=torch.tensor(0.), dist_reduce_fx="sum") # self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") # def update(self, # preds_do_1: torch.Tensor, # preds_do_0: torch.Tensor): # """ # Update metric state with predictions under interventions. # # Args: # preds_do_1: Predictions when intervening C=1, shape (batch_size, n_classes). # preds_do_0: Predictions when intervening C=0, shape (batch_size, n_classes). 
# """ # _check_same_shape(preds_do_1, preds_do_0) # # expected value = 1*p(output=1|do(1)) + 0*(1-p(output=1|do(1)) # self.preds_do_1 += preds_do_1[:,1].sum() # # expected value = 1*p(output=1|do(0)) + 0*(1-p(output=1|do(0)) # self.preds_do_0 += preds_do_0[:,1].sum() # self.total += preds_do_1.size()[0] # def compute(self): # """ # Compute the Causal Concept Effect (CaCE). # # Returns: # torch.Tensor: The average causal effect E[Y|do(C=1)] - E[Y|do(C=0)]. # """ # return (self.preds_do_1.float() / self.total) - (self.preds_do_0.float() / self.total)