Source code for torch_concepts.nn.modules.metrics


from typing import Optional, Union, List, Tuple
import torch
from torch import nn
from torchmetrics import Metric, MetricCollection
from copy import deepcopy

from ...annotations import Annotations
from .outputs import ModelOutput
from .utils import GroupConfig, check_collection


def clone_metric(metric):
    """Return a reset copy of *metric*.

    The copy is produced by ``metric.clone()`` and then cleared with
    ``reset()``; the object passed in is not reset by this function.

    Args:
        metric: Object exposing ``clone()`` and ``reset()``.

    Returns:
        The cloned object, after ``reset()`` has been called on it.
    """
    fresh = metric.clone()
    fresh.reset()
    return fresh


class ConceptMetrics(nn.Module):
    """Type-aware metric manager for concept-based models.

    Automatically routes predictions to the correct metrics based on concept
    type (binary / categorical) as defined in the annotations. Supports
    summary metrics (aggregated per type) and per-concept metrics, with
    independent state for each data split (train / val / test) obtained via
    :meth:`clone`.

    Args:
        annotations (Annotations): Concept annotations (axis 1) with labels,
            cardinalities, and types (``'binary'``, ``'categorical'``, or
            ``'continuous'``).
        binary: Mapping ``name -> metric spec`` for binary concepts
            (cardinality 1).
        categorical: Mapping ``name -> metric spec`` for categorical concepts
            (cardinality > 1).
        continuous: Mapping ``name -> metric spec`` for continuous concepts
            (not yet supported).
        summary (bool): Compute summary metrics aggregated across all
            concepts of each type. Default ``True``.
        per_concept (bool | list[str]): ``False`` (default) disables
            per-concept tracking; ``True`` tracks every concept; a list of
            names tracks only those concepts.
        prefix (str | None): Optional prefix for every metric key. When
            given, keys start with ``"<prefix>/"``.

    Each metric spec can be:

    * A pre-instantiated ``torchmetrics.Metric``.
    * A ``(MetricClass, kwargs)`` tuple — ``num_classes`` is injected
      automatically for categorical concepts.
    * A non-instantiated ``MetricClass``.

    Raises:
        NotImplementedError: If the annotations contain continuous concepts.
        ValueError: If ``per_concept`` names unknown concepts or is neither
            a bool nor a list.

    Example::

        metrics = ConceptMetrics(
            annotations=annotations,
            binary={"accuracy": BinaryAccuracy()},
            categorical={"accuracy": (MulticlassAccuracy, {"average": "micro"})},
        )
        metrics.update(preds, target)
        results = metrics.compute()  # {"SUMMARY-binary_accuracy": ..., ...}
        metrics.reset()
    """

    def __init__(
        self,
        annotations: "Annotations",
        binary: Optional[dict] = None,
        categorical: Optional[dict] = None,
        continuous: Optional[dict] = None,
        summary: bool = True,
        per_concept: Union[bool, List[str]] = False,
        prefix: Optional[str] = None,
    ):
        super().__init__()
        self.summary = summary
        self.per_concept = per_concept

        # Extract and validate annotations
        self.concept_annotations = annotations
        self.concept_names = annotations.labels
        self.n_concepts = len(self.concept_names)
        self.cardinalities = annotations.cardinalities
        self.metadata = annotations.metadata
        self.types = list(annotations.types)

        # Use cached type_groups from Annotations
        self.groups = annotations.type_groups

        # Validate that continuous concepts are not used
        if self.groups['continuous']['labels']:
            raise NotImplementedError(
                f"Continuous concepts are not yet supported. "
                f"Found continuous concepts: {self.groups['continuous']['labels']}."
            )

        # Validate and filter metrics configuration
        fn_collection = GroupConfig(
            binary=binary, categorical=categorical, continuous=continuous
        )
        self.fn_collection = check_collection(annotations, fn_collection, 'metrics')

        # Pre-compute max cardinality for categorical concepts; it is the
        # class count used by the summary categorical metrics and the width
        # every categorical logit block is padded to.
        if self.fn_collection.get('categorical'):
            self.max_card = max(
                self.cardinalities[i]
                for i in self.groups['categorical']['concept_idx']
            )

        # Determine which concepts to track for per-concept metrics
        if self.per_concept:
            if isinstance(self.per_concept, bool):
                self._concepts_to_trace = list(self.concept_names)
            elif isinstance(self.per_concept, list):
                invalid = [n for n in self.per_concept if n not in self.concept_names]
                if invalid:
                    raise ValueError(
                        f"Concept names not found in annotations: {invalid}"
                    )
                # Copy so later mutation of the caller's list cannot affect us.
                self._concepts_to_trace = list(self.per_concept)
            else:
                raise ValueError(
                    "per_concept must be either a bool or a list of concept names."
                )
        else:
            self._concepts_to_trace = []

        # Setup separate MetricCollections per type and per concept
        pfx = f"{prefix}/" if prefix else ""
        self._prefix = pfx
        summary_b, summary_c, summary_cont, per_concept_dict = self._setup_metrics()

        # Summary collections: one MetricCollection per concept type
        self.binary = MetricCollection(
            metrics=summary_b, prefix=f"{pfx}SUMMARY-binary_"
        ) if summary_b else MetricCollection({})
        self.categorical = MetricCollection(
            metrics=summary_c, prefix=f"{pfx}SUMMARY-categorical_"
        ) if summary_c else MetricCollection({})
        self.continuous = MetricCollection(
            metrics=summary_cont, prefix=f"{pfx}SUMMARY-continuous_"
        ) if summary_cont else MetricCollection({})

        # Per-concept collections: one MetricCollection per tracked concept
        self._per_concept = nn.ModuleDict({
            name: MetricCollection(metrics=metrics, prefix=f"{pfx}{name}_")
            for name, metrics in per_concept_dict.items()
        })

    @staticmethod
    def _spec_name(spec) -> str:
        """Return the metric class name for any supported metric spec form."""
        if isinstance(spec, Metric):
            return spec.__class__.__name__
        if isinstance(spec, (tuple, list)):
            return spec[0].__name__
        return spec.__name__

    def __repr__(self) -> str:
        metric_info = {
            k: [self._spec_name(m) for m in v.values()]
            for k, v in self.fn_collection.items()
            if v
        }
        metrics_str = ', '.join(
            f"{k}=[{','.join(v)}]" for k, v in metric_info.items()
        )
        return (f"{self.__class__.__name__}(n_concepts={self.n_concepts}, "
                f"metrics={{{metrics_str}}}, summary={self.summary}, "
                f"per_concept={self.per_concept})")

    @property
    def collection(self):
        """Return all non-empty sub-collections as a dict.

        Summary collections are keyed ``'binary'`` / ``'categorical'`` /
        ``'continuous'``; per-concept collections are keyed by concept name.
        """
        result = {}
        if len(self.binary):
            result['binary'] = self.binary
        if len(self.categorical):
            result['categorical'] = self.categorical
        if len(self.continuous):
            result['continuous'] = self.continuous
        for name, coll in self._per_concept.items():
            if len(coll):
                result[name] = coll
        return result

    def clone(self, prefix=None):
        """Create an independent copy with fresh state and optional new prefix.

        Args:
            prefix: New prefix for metric keys. If None, keeps the original.
                An empty string removes the prefix.

        Returns:
            ConceptMetrics: Deep copy of this object, reset before return.
        """
        cloned = deepcopy(self)
        if prefix is not None:
            pfx = f"{prefix}/" if prefix else ""
            cloned._prefix = pfx
            if len(cloned.binary):
                cloned.binary.prefix = f"{pfx}SUMMARY-binary_"
            if len(cloned.categorical):
                cloned.categorical.prefix = f"{pfx}SUMMARY-categorical_"
            if len(cloned.continuous):
                cloned.continuous.prefix = f"{pfx}SUMMARY-continuous_"
            for name, coll in cloned._per_concept.items():
                coll.prefix = f"{pfx}{name}_"
        cloned.reset()
        return cloned

    def _instantiate_metric(self, metric_spec, concept_specific_kwargs=None):
        """Instantiate a metric from either an instance or a class+kwargs tuple/list.

        Args:
            metric_spec: Either a Metric instance, a tuple/list
                (MetricClass, kwargs_dict), or a MetricClass (will be
                instantiated with concept_specific_kwargs only).
            concept_specific_kwargs (dict): Concept-specific parameters to
                merge with user kwargs.

        Returns:
            Metric: Instantiated metric. A Metric instance is returned as a
            clone, so the caller's object is never shared.

        Raises:
            ValueError: If user provides 'num_classes' in kwargs (it's set
                automatically).
        """
        if isinstance(metric_spec, Metric):
            return metric_spec.clone()

        if isinstance(metric_spec, (tuple, list)) and len(metric_spec) == 2:
            # (MetricClass, user_kwargs)
            metric_class, user_kwargs = metric_spec
            # Accept ``(MetricClass, None)`` as "no extra kwargs".
            if user_kwargs is None:
                user_kwargs = {}
            # Check if user provided num_classes when it will be set automatically
            if ('num_classes' in user_kwargs
                    and concept_specific_kwargs
                    and 'num_classes' in concept_specific_kwargs):
                raise ValueError(
                    f"'num_classes' should not be provided in metric kwargs. "
                    f"ConceptMetrics automatically sets 'num_classes' based on concept cardinality."
                )
            merged_kwargs = {**(concept_specific_kwargs or {}), **user_kwargs}
            return metric_class(**merged_kwargs)

        # Just a class, use concept_specific_kwargs only
        return metric_spec(**(concept_specific_kwargs or {}))

    def _setup_metrics(self):
        """Instantiate metrics, separated into summary and per-concept groups.

        Returns:
            Tuple of (summary_binary, summary_categorical, summary_continuous,
            per_concept) where per_concept maps concept name to metric dict.
        """
        summary_binary = {}
        summary_categorical = {}
        summary_continuous = {}
        per_concept = {}

        # Summary metrics (keyed by metric name; prefix added by MetricCollection)
        if self.summary:
            if self.fn_collection.get('binary'):
                for name, spec in self.fn_collection['binary'].items():
                    summary_binary[name] = self._instantiate_metric(spec)
            if self.fn_collection.get('categorical'):
                for name, spec in self.fn_collection['categorical'].items():
                    summary_categorical[name] = self._instantiate_metric(
                        spec, concept_specific_kwargs={'num_classes': self.max_card}
                    )
            if self.fn_collection.get('continuous'):
                for name, spec in self.fn_collection['continuous'].items():
                    summary_continuous[name] = self._instantiate_metric(spec)

        # Per-concept metrics (one dict per concept)
        for concept_name in self._concepts_to_trace:
            c_idx = self.concept_names.index(concept_name)
            c_type = self.types[c_idx]
            card = self.cardinalities[c_idx]

            if c_type not in ('binary', 'categorical', 'continuous'):
                continue

            # ``or {}`` mirrors the truthiness guard of the summary path, so a
            # missing or None entry simply yields no metrics for this concept.
            specs = self.fn_collection.get(c_type) or {}
            # Only categorical metrics need the concept's own class count.
            extra = {'num_classes': card} if c_type == 'categorical' else None

            concept_metrics = {}
            for name, spec in specs.items():
                if extra is None:
                    concept_metrics[name] = self._instantiate_metric(spec)
                else:
                    concept_metrics[name] = self._instantiate_metric(
                        spec, concept_specific_kwargs=extra
                    )

            if concept_metrics:
                per_concept[concept_name] = concept_metrics

        return summary_binary, summary_categorical, summary_continuous, per_concept

    def _prepare_categorical(self, preds, target):
        """Pad and stack categorical logits/targets for summary metrics.

        Each categorical concept's logit block is right-padded with ``-inf``
        to ``self.max_card`` columns, and the blocks are concatenated along
        the batch axis (all rows of the first concept, then the second, ...).
        Targets are flattened in the same concept-major order.

        Returns:
            Tuple ``(cat_pred, cat_target)``: stacked logits with
            ``self.max_card`` columns and the matching flat ``long`` targets.
        """
        cat_concept_idx = self.groups['categorical']['concept_idx']
        split_tuple = torch.split(
            preds[:, self.groups['categorical']['logits_idx']],
            [self.cardinalities[i] for i in cat_concept_idx],
            dim=1
        )
        padded_logits = [
            nn.functional.pad(
                logits, (0, self.max_card - logits.shape[1]), value=float('-inf')
            )
            for logits in split_tuple
        ]
        cat_pred = torch.cat(padded_logits, dim=0)
        # Transpose before flattening so targets follow the concept-major
        # order produced by the concatenation above.
        cat_target = target[:, cat_concept_idx].T.reshape(-1).long()
        return cat_pred, cat_target

    def update(self, preds, target: torch.Tensor = None):
        """Update metrics by routing predictions to the correct type collection.

        Summary metrics receive aggregated data for all concepts of a type.
        Per-concept metrics receive individual concept data. An empty batch
        (``preds.shape[0] == 0``) is ignored.

        Args:
            preds: Model predictions (logits) as a ``torch.Tensor`` of shape
                ``(batch, logits_dim)``, or a ``ModelOutput`` whose
                ``.logits`` and ``.target`` fields will be used.
            target: Ground truth values. Shape ``(batch, n_concepts)``.
                Required when *preds* is a plain tensor; ignored when
                *preds* is a ``ModelOutput``.

        Raises:
            TypeError: If no target is available while metrics are registered.
            NotImplementedError: If continuous summary metrics are registered.
        """
        if isinstance(preds, ModelOutput):
            target = preds.target
            preds = preds.logits

        if preds.shape[0] == 0:
            return

        # Fail with a clear message instead of "'NoneType' object is not
        # subscriptable" further down. TypeError keeps the exception class the
        # implicit failure had.
        if target is None and self.collection:
            raise TypeError(
                "ConceptMetrics.update() requires 'target' when 'preds' is a "
                "plain tensor (or a ModelOutput carrying a target)."
            )

        # Summary metrics — one MetricCollection.update() call per type
        if self.summary:
            if self.groups['binary']['labels'] and len(self.binary):
                binary_pred = preds[:, self.groups['binary']['logits_idx']]
                binary_target = target[:, self.groups['binary']['concept_idx']].float()
                self.binary.update(binary_pred, binary_target)

            if self.groups['categorical']['labels'] and len(self.categorical):
                cat_pred, cat_target = self._prepare_categorical(preds, target)
                self.categorical.update(cat_pred, cat_target)

            if self.groups['continuous']['labels'] and len(self.continuous):
                raise NotImplementedError("Continuous concepts not yet implemented.")

        # Per-concept metrics — one MetricCollection.update() call per concept
        for concept_name, collection in self._per_concept.items():
            logits_slice = self.concept_annotations.get_slice(concept_name)
            c_idx = self.concept_annotations.get_index(concept_name)
            c_type = self.types[c_idx]

            if c_type == 'binary':
                collection.update(preds[:, logits_slice],
                                  target[:, c_idx:c_idx+1].float())
            elif c_type == 'categorical':
                collection.update(preds[:, logits_slice],
                                  target[:, c_idx].long())
            elif c_type == 'continuous':
                collection.update(preds[:, logits_slice],
                                  target[:, c_idx:c_idx+1])

    def compute(self):
        """Compute all metrics and return as a flat dict."""
        results = {}
        if len(self.binary):
            results.update(self.binary.compute())
        if len(self.categorical):
            results.update(self.categorical.compute())
        if len(self.continuous):
            results.update(self.continuous.compute())
        for collection in self._per_concept.values():
            results.update(collection.compute())
        return results

    def reset(self):
        """Reset all metric state."""
        self.binary.reset()
        self.categorical.reset()
        self.continuous.reset()
        for collection in self._per_concept.values():
            collection.reset()
@torch.no_grad()
def compute_cace(
    model,
    dataloader,
    source_concept: str,
    target_concept: str,
    prob_high: Union[float, torch.Tensor] = 1.0,
    prob_low: Union[float, torch.Tensor] = 0.0,
) -> torch.Tensor:
    """Compute the Causal Concept Effect of *source_concept* on *target_concept*.

    Runs ``do(source = prob_high)`` vs ``do(source = prob_low)`` over the
    dataloader and returns the mean difference on the target.

    Values are in **probability space** (0–1 for binary concepts). They are
    converted to logits internally via ``torch.logit``.

    * **Binary** (default): ``prob_high=1, prob_low=0``.
    * **Categorical**: pass probability vectors, e.g.
      ``prob_high=tensor([0, 0, 1])``, ``prob_low=tensor([1, 0, 0])``.

    The model is switched to eval mode for the duration of the call and is
    put back into training mode afterwards if it was training before, even
    when an error occurs. The dataloader is iterated exactly once.

    Args:
        model: A high-level concept model (e.g.
            :class:`~torch_concepts.nn.ConceptBottleneckModel`).
        dataloader: Iterable yielding batch dicts with
            ``{'inputs': {'x': Tensor}}``.
        source_concept: Concept to intervene on.
        target_concept: Concept whose prediction is measured.
        prob_high: Probability for the *high* regime (default 1.0).
        prob_low: Probability for the *low* regime (default 0.0).

    Returns:
        Scalar tensor with the CaCE score.

    Raises:
        ValueError: If the dataloader yields no batches.

    Example::

        >>> cace = compute_cace(  # doctest: +SKIP
        ...     model=cbm,
        ...     dataloader=test_loader,
        ...     source_concept="c1",
        ...     target_concept="task",
        ... )
    """
    from ..modules.low.inference.intervention import DoIntervention, intervention
    from ..modules.low.policy.uniform import UniformPolicy
    from ...nn.functional import cace_score

    cpds = model.model.probabilistic_model.parametric_cpds

    was_training = model.training
    model.eval()

    all_high, all_low = [], []
    try:
        # Convert probabilities → logits (the intervention operates in logit
        # space); eps clamps exact 0/1 so the logits stay finite.
        eps = 1e-6
        if not torch.is_tensor(prob_high):
            prob_high = torch.tensor(prob_high)
        if not torch.is_tensor(prob_low):
            prob_low = torch.tensor(prob_low)
        logit_high = torch.logit(prob_high.float(), eps=eps)
        logit_low = torch.logit(prob_low.float(), eps=eps)

        strategy_high = DoIntervention(model=cpds, constants=logit_high)
        strategy_low = DoIntervention(model=cpds, constants=logit_low)

        for batch in dataloader:
            x = batch["inputs"]["x"]

            with intervention(
                policies=UniformPolicy(out_concepts=1),
                strategies=strategy_high,
                target_concepts=[source_concept],
            ):
                out_high = model(x=x, query=[target_concept])

            with intervention(
                policies=UniformPolicy(out_concepts=1),
                strategies=strategy_low,
                target_concepts=[source_concept],
            ):
                out_low = model(x=x, query=[target_concept])

            all_high.append(out_high.probs)
            all_low.append(out_low.probs)
    finally:
        # Always restore the caller's mode, including on failure.
        if was_training:
            model.train()

    # Emptiness is detected after the single pass rather than by probing the
    # dataloader up front, which would consume a batch of a one-shot iterable.
    if not all_high:
        raise ValueError("Dataloader yielded no batches.")

    return cace_score(torch.cat(all_low), torch.cat(all_high)).squeeze()