Contributing a New Metric

This guide explains how to add a new metric to PyC. ConceptMetrics wraps a torchmetrics.MetricCollection and routes predictions to the right metric based on concept type (binary or categorical) as declared in the Annotations.

The recommended path for most metrics is to reuse an existing torchmetrics metric. Write a custom subclass only when no existing metric covers your need.

Using Existing torchmetrics Metrics

Pass metrics as a dict mapping a name string to a metric spec. There are three supported spec forms, listed in order of preference.

Form 1 — Class only

Pass the class directly. ConceptMetrics instantiates it with no extra arguments. Use this for metrics that need no configuration (e.g. BinaryAccuracy):

from torchmetrics.classification import BinaryAccuracy
import torch_concepts as pyc
from torch_concepts.nn import ConceptMetrics

ann = pyc.Annotations(
    labels=["is_round", "color", "label"],
    cardinalities=[1, 3, 1],
    types=["binary", "categorical", "binary"],
)

metrics = ConceptMetrics(
    annotations=ann,
    binary={"accuracy": BinaryAccuracy},
)

Form 2 — (Class, kwargs) tuple

Pass a (MetricClass, kwargs) tuple to supply extra constructor arguments. For categorical metrics, num_classes is injected automatically from the annotations — do not include it in kwargs:

from torchmetrics.classification import (
    BinaryAccuracy, MulticlassAccuracy, MulticlassF1Score,
)

metrics = ConceptMetrics(
    annotations=ann,
    binary={"accuracy": BinaryAccuracy},
    categorical={
        "accuracy": (MulticlassAccuracy, {"average": "micro"}),
        "f1":       (MulticlassF1Score,   {"average": "macro"}),
    },
)

Form 3 — Pre-instantiated metric

Pass an already-constructed torchmetrics.Metric object. You retain full control, but you must set num_classes yourself for categorical metrics:

from torchmetrics.classification import MulticlassAccuracy

# num_classes must be set manually — ConceptMetrics cannot inject it.
cat_acc = MulticlassAccuracy(num_classes=3, average="weighted")

metrics = ConceptMetrics(
    annotations=ann,
    categorical={"accuracy": cat_acc},
)
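
The three forms can be pictured as inputs to a single dispatcher. The following is a minimal sketch of how such a dispatcher could behave, not PyC's actual code: resolve_spec, DummyBinaryMetric, and DummyMulticlassMetric are hypothetical names introduced only for illustration, and the injection rule is an assumption based on the descriptions above.

```python
import inspect


class DummyBinaryMetric:
    """Hypothetical stand-in for a metric that needs no configuration."""


class DummyMulticlassMetric:
    """Hypothetical stand-in for a metric that requires num_classes."""

    def __init__(self, num_classes, average="micro"):
        self.num_classes = num_classes
        self.average = average


def resolve_spec(spec, num_classes=None):
    """Turn any of the three spec forms into a metric instance.

    Assumption: num_classes is injected only when the caller supplies it,
    i.e. for categorical metrics.
    """
    injected = {} if num_classes is None else {"num_classes": num_classes}
    if inspect.isclass(spec):                  # Form 1: class only
        return spec(**injected)
    if isinstance(spec, tuple):                # Form 2: (class, kwargs)
        metric_cls, kwargs = spec
        if injected and "num_classes" in kwargs:
            raise ValueError("num_classes is injected; remove it from kwargs")
        return metric_cls(**injected, **kwargs)
    return spec                                # Form 3: already constructed


form1 = resolve_spec(DummyBinaryMetric)
form2 = resolve_spec((DummyMulticlassMetric, {"average": "macro"}), num_classes=3)
form3 = resolve_spec(DummyMulticlassMetric(num_classes=3, average="weighted"))
print(type(form1).__name__, form2.average, form3.average)
# DummyBinaryMetric macro weighted
```

The sketch also shows why Form 3 leaves num_classes to you: an object that is already constructed is returned unchanged, so there is nothing left to inject.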

Calling the metrics object

import torch

batch = 8
logits = torch.randn(batch, 5)        # 1 + 3 + 1 logits
target = torch.randint(0, 2, (batch, 3)).float()  # one column per concept

metrics.update(logits, target)
results = metrics.compute()
# {"SUMMARY-binary_accuracy": tensor(...),
#  "SUMMARY-categorical_accuracy": tensor(...), ...}
metrics.reset()
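
The "1 + 3 + 1" comment above reflects how the logit columns relate to the cardinalities. As a rough illustration, the sketch below maps each concept to the columns it would own, assuming the logits are laid out contiguously in annotation order. That layout is an assumption made for illustration, and concept_slices is a hypothetical helper, not part of PyC.

```python
from itertools import accumulate

labels = ["is_round", "color", "label"]
cardinalities = [1, 3, 1]


def concept_slices(labels, cardinalities):
    """Map each concept name to the (start, stop) logit columns it owns."""
    stops = list(accumulate(cardinalities))   # running column totals
    starts = [0] + stops[:-1]
    return {name: (a, b) for name, a, b in zip(labels, starts, stops)}


slices = concept_slices(labels, cardinalities)
print(slices)
# {'is_round': (0, 1), 'color': (1, 4), 'label': (4, 5)}

row = [0.2, -1.0, 0.5, 2.0, -0.7]   # one row of 5 logits
start, stop = slices["color"]
print(row[start:stop])               # [-1.0, 0.5, 2.0]
```

Under this assumption, the binary concepts each contribute one logit, and the categorical concept contributes one logit per class.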

Per-concept tracking

Set per_concept=True to get one metric entry per concept, or pass a list of concept names to track a subset:

metrics = ConceptMetrics(
    annotations=ann,
    binary={"accuracy": BinaryAccuracy},
    categorical={"accuracy": (MulticlassAccuracy, {"average": "micro"})},
    per_concept=["is_round", "color"],  # only these two
)

metrics.update(logits, target)
results = metrics.compute()
# Includes "is_round_accuracy", "color_accuracy", plus SUMMARY keys.

Split tracking

Use clone(prefix=...) to get independent metric objects for train, validation, and test splits, each with its own accumulated state:

train_metrics = metrics.clone(prefix="train")
val_metrics   = metrics.clone(prefix="val")
test_metrics  = metrics.clone(prefix="test")

Custom Metric

When no existing torchmetrics metric fits your need, subclass torchmetrics.Metric directly. The three methods to implement are __init__, update, and compute.

The example below computes a Concept Alignment Score (CAS): the accuracy of each binary concept is computed independently, then averaged across concepts. It illustrates the full torchmetrics.Metric contract.

# torch_concepts/nn/modules/metrics.py  (add near end of file)
import torch
from torchmetrics import Metric


class ConceptAlignmentScore(Metric):
    """Mean per-concept accuracy across all binary concepts.

    Accumulates the number of correct predictions and total samples per
    concept, then returns the mean accuracy.

    Args:
        n_concepts (int): Number of binary concepts to track.

    Example::

        cas = ConceptAlignmentScore(n_concepts=3)
        cas.update(preds=torch.tensor([[0.9, -0.3, 0.1]]),
                   target=torch.tensor([[1.0,  0.0, 0.0]]))
        score = cas.compute()  # tensor(0.6667)
    """

    # Higher is better (for logging direction hints in Lightning).
    higher_is_better = True

    def __init__(self, n_concepts: int, **kwargs):
        super().__init__(**kwargs)
        # add_state registers tensors that reset() restores to their defaults;
        # dist_reduce_fx="sum" sums them across processes when syncing.
        self.add_state(
            "correct",
            default=torch.zeros(n_concepts),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "total",
            default=torch.zeros(n_concepts),
            dist_reduce_fx="sum",
        )
        self.n_concepts = n_concepts

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate correct predictions.

        Args:
            preds: Logit tensor, shape ``(batch, n_concepts)``.
            target: Ground-truth labels, shape ``(batch, n_concepts)``.
                Values are ``0.0`` or ``1.0``.
        """
        predicted = (preds > 0).float()              # threshold at 0
        correct   = (predicted == target).float()    # (batch, n_concepts)
        self.correct += correct.sum(dim=0)
        self.total   += torch.full_like(self.total, preds.shape[0])

    def compute(self) -> torch.Tensor:
        """Return mean per-concept accuracy (scalar)."""
        per_concept_acc = self.correct / self.total.clamp(min=1)
        return per_concept_acc.mean()
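
To sanity-check the arithmetic, the same accumulation can be traced in plain Python without torch. This is only a reference sketch: cas_reference is a hypothetical helper that mirrors the update and compute logic above on nested lists, and the second batch is made-up illustrative data.

```python
def cas_reference(batches):
    """Plain-Python mirror of ConceptAlignmentScore over a list of batches.

    Each batch is a (preds, target) pair of equally shaped nested lists.
    """
    n_concepts = len(batches[0][0][0])
    correct = [0] * n_concepts
    total = [0] * n_concepts
    for preds, target in batches:
        for pred_row, target_row in zip(preds, target):
            for j, (logit, label) in enumerate(zip(pred_row, target_row)):
                predicted = 1.0 if logit > 0 else 0.0   # threshold at 0
                correct[j] += int(predicted == label)
                total[j] += 1
    # Per-concept accuracy first, then the mean across concepts.
    per_concept = [c / max(t, 1) for c, t in zip(correct, total)]
    return sum(per_concept) / n_concepts


batch_1 = ([[0.9, -0.3, 0.1]], [[1.0, 0.0, 0.0]])   # docstring example
batch_2 = ([[0.5, 0.4, -2.0]], [[1.0, 1.0, 0.0]])   # all three correct
print(round(cas_reference([batch_1]), 4))            # 0.6667
print(round(cas_reference([batch_1, batch_2]), 4))   # 0.8333
```

After both batches, the per-concept accuracies are 1.0, 1.0, and 0.5, which average to 0.8333. State accumulates across update calls until reset() is called.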

Plugging the custom metric into ConceptMetrics

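ConceptAlignmentScore requires an n_concepts argument, which ConceptMetrics does not inject (only num_classes is injected, and only for categorical metrics). Form 1 (class only) therefore does not apply. Supply n_concepts yourself, either through the kwargs of Form 2 or, as below, with a pre-instantiated metric (Form 3):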

import torch
import torch_concepts as pyc
from torch_concepts.nn import ConceptMetrics
from torch_concepts.nn.modules.metrics import ConceptAlignmentScore

ann = pyc.Annotations(
    labels=["is_round", "shiny", "label"],
    cardinalities=[1, 1, 1],
    types=["binary", "binary", "binary"],
)

# ConceptMetrics only auto-injects num_classes (for categorical metrics).
# ConceptAlignmentScore takes n_concepts instead, so count the binary
# concepts here and pass a pre-instantiated metric (Form 3).
n_binary = len(ann.type_groups["binary"]["labels"])
cas = ConceptAlignmentScore(n_concepts=n_binary)

metrics = ConceptMetrics(
    annotations=ann,
    binary={"cas": cas},   # Form 3: pre-instantiated
)

logits = torch.randn(16, 3)
target = torch.randint(0, 2, (16, 3)).float()
metrics.update(logits, target)
print(metrics.compute())
# {"SUMMARY-binary_cas": tensor(...)}

Registering

Once the metric works locally, complete the four steps below. The first two register the metric; the last two document and test it.

1. Module file

Add the class to torch_concepts/nn/modules/metrics.py. Place it near existing metric classes so related code stays together.

2. Public API export

Extend the existing import and the __all__ list in torch_concepts/nn/__init__.py:

# in torch_concepts/nn/__init__.py
from .modules.metrics import (
    ConceptMetrics,
    compute_cace,
    ConceptAlignmentScore,
)

__all__ = [
    ...
    "ConceptAlignmentScore",
]

After this, users can do from torch_concepts.nn import ConceptAlignmentScore.

3. API reference page (optional but recommended)

Add an autoclass directive to the appropriate reference page so the docstring appears in the rendered documentation:

.. autoclass:: torch_concepts.nn.ConceptAlignmentScore
   :members:
   :undoc-members:
   :show-inheritance:

4. Tests

Add a test in tests/ that constructs a ConceptMetrics with your metric, calls update and compute, and checks the output value. Mirror the existing tests in tests/test_metrics.py for the expected structure.

Next Steps

  • Read the ConceptMetrics API reference for the full list of constructor arguments, the clone / reset / compute lifecycle, and the output key naming scheme.

  • Browse torchmetrics documentation at https://torchmetrics.readthedocs.io to find existing metrics before writing a custom one.

  • Open a pull request to dev — see Contributing for the full workflow.