# Source code for torch_concepts.annotations

"""
Concept annotations for tensors.

This module provides annotation structures for concept-based tensors, allowing
semantic labeling of a tensor's concept axis and its components. It supports both
simple (flat) and nested (hierarchical) concept structures.
"""

import warnings
import torch

from dataclasses import dataclass, field
from functools import cached_property
from typing import Dict, List, Tuple, Union, Optional, Any, Sequence


#: The canonical concept-type vocabulary. Every concept has exactly one of
#: these types; the order here is also the key order of per-type groupings.
_CONCEPT_TYPES = (
    'binary',
    'categorical',
    'continuous',
)


@dataclass(frozen=True)
class Concept:
    """Frozen, per-concept view of a single column of an :class:`Annotations`.

    Collects the properties that ``Annotations`` stores in parallel lists
    into one named object. Callers can then write
    ``annotations.concept('color').cardinality`` rather than
    ``int(annotations.cardinalities[annotations.get_index('color')])``.
    The values are read from the owning ``Annotations`` when the view is
    built, and no storage is duplicated. The parallel lists remain the
    canonical representation.

    Attributes:
        name (str): The concept's label.
        index (int): Concept-level position on the concept axis.
        cardinality (int): Number of states (1 for binary/continuous scalars).
        type (str): One of ``'binary'``, ``'categorical'`` or ``'continuous'``.
        states (Optional[List[str]]): State labels of this concept, if any.
        slice (slice): Columns this concept occupies in the flattened
            (logit) tensor.
        metadata (dict): Raw per-concept metadata (escape hatch for extra
            keys). Defaults to a fresh empty dict.
    """

    name: str
    index: int
    cardinality: int
    type: str
    states: Optional[List[str]]
    slice: slice
    metadata: dict = field(default_factory=dict)

    @property
    def is_binary(self) -> bool:
        """True when the concept's type is ``'binary'`` (a single 0/1 value)."""
        return self.type == 'binary'

    @property
    def is_categorical(self) -> bool:
        """True when the concept's type is ``'categorical'`` (multi-state)."""
        return self.type == 'categorical'

    @property
    def is_continuous(self) -> bool:
        """True when the concept's type is ``'continuous'``."""
        return self.type == 'continuous'
@dataclass
class Annotations:
    """
    Annotations for the concept axis of a tensor.

    This class provides semantic labeling for the concept dimension (axis 1) of
    a tensor, supporting both simple binary concepts and nested multi-state
    concepts. Axis 0 is the (unannotated) batch dimension.

    Attributes:
        labels (list[str]): Ordered, unique concept labels.
        states (Optional[list[list[str]]]): State labels for each concept (if nested).
        cardinalities (Optional[list[int]]): Cardinality of each concept.
        types (Optional[list[str]]): ``'binary'`` / ``'categorical'`` /
            ``'continuous'`` per concept.
        metadata (Optional[Dict[str, Dict]]): Additional metadata for each label.
        concept_space (bool): Whether this annotation describes a concept-space
            tensor (one integer-coded column per concept) rather than the
            logit space.
        is_nested (bool): Whether the axis has nested/hierarchical structure.

    Args:
        labels: List of concept names.
        states: Optional list of state lists for nested concepts.
        cardinalities: Optional list of cardinalities per concept.
        types: Optional concept types per concept.
        metadata: Optional metadata dictionary keyed by label names.
        concept_space: Relax the ``categorical requires cardinality > 1``
            invariant (see :meth:`to_concept_space`).

    Example:
        >>> from torch_concepts import Annotations
        >>>
        >>> # Simple binary concepts
        >>> ann_binary = Annotations(
        ...     labels=['has_wheels', 'has_windows', 'is_red']
        ... )
        >>> print(ann_binary.labels)
        ['has_wheels', 'has_windows', 'is_red']
        >>> print(ann_binary.is_nested)
        False
        >>> print(ann_binary.cardinalities)
        [1, 1, 1]
        >>> print(ann_binary.shape)
        (-1, 3)
        >>>
        >>> # Nested concepts with explicit states
        >>> ann_nested = Annotations(
        ...     labels=['color', 'shape'],
        ...     states=[['red', 'green', 'blue'], ['circle', 'square']],
        ... )
        >>> print(ann_nested.labels)
        ['color', 'shape']
        >>> print(ann_nested.is_nested)
        True
        >>> print(ann_nested.cardinalities)
        [3, 2]
        >>> print(ann_nested.states[0])
        ['red', 'green', 'blue']
        >>> print(ann_nested.shape)
        (-1, 5)
        >>>
        >>> # With cardinalities only (auto-generates state labels)
        >>> ann_cards = Annotations(
        ...     labels=['size', 'material'],
        ...     cardinalities=[3, 4]
        ... )
        >>> print(ann_cards.cardinalities)
        [3, 4]
        >>> print(ann_cards.states[0])
        ['0', '1', '2']
        >>>
        >>> # Access methods
        >>> idx = ann_binary.get_index('has_wheels')
        >>> print(idx)
        0
        >>> label = ann_binary.get_label(1)
        >>> print(label)
        has_windows
    """

    labels: List[str]
    states: Optional[List[List[str]]] = field(default=None)
    cardinalities: Optional[List[int]] = field(default=None)
    types: Optional[List[str]] = field(default=None)  # 'binary' | 'categorical' | 'continuous'
    metadata: Optional[Dict[str, Dict]] = field(default=None)
    # Concept-space annotation: each concept occupies a single integer-coded
    # column regardless of its type (so all cardinalities are 1). This describes
    # a ground-truth concept tensor, not the model's logit space. When True the
    # ``categorical requires cardinality > 1`` invariant is relaxed (a categorical
    # concept's column holds an integer class index). Build one from a normal
    # (logit-space) annotation via :meth:`to_concept_space`.
    concept_space: bool = field(default=False)

    def __setattr__(self, key, value):
        # `metadata` may change after construction, so it is
        # freely reassignable. The structural fields (labels, states,
        # cardinalities, types) remain write-once.
        if key == 'metadata':
            super().__setattr__(key, value)
            return
        if key in self.__dict__ and self.__dict__[key] is not None:
            raise AttributeError(f"'{key}' is write-once and already set")
        super().__setattr__(key, value)

    def __post_init__(self):
        """Validate consistency, infer is_nested and eventually states, and cardinalities."""
        # Initialize states and cardinalities based on what's provided
        if self.states is not None and self.cardinalities is None:
            # Infer cardinalities from states
            self.cardinalities = [len(state_tuple) for state_tuple in self.states]
        elif self.states is None and self.cardinalities is not None:
            # Generate default state labels from cardinalities
            self.states = [
                [str(i) for i in range(card)] if card > 1 else ['0']
                for card in self.cardinalities
            ]
        elif self.states is None and self.cardinalities is None:
            # Neither provided - assume binary
            warnings.warn(
                "Annotations: neither 'states' nor 'cardinalities' provided; "
                "assuming all concepts are binary."
            )
            self.cardinalities = [1 for _ in self.labels]
            self.states = [['0'] for _ in self.labels]
        else:
            # Both provided - use as-is for now, will validate below
            pass

        # Validate consistency now that both are populated
        if len(self.states) != len(self.labels):
            raise ValueError(
                f"Number of state tuples ({len(self.states)}) must match "
                f"number of labels ({len(self.labels)})"
            )
        if len(self.cardinalities) != len(self.labels):
            raise ValueError(
                f"Number of cardinalities ({len(self.cardinalities)}) must match "
                f"number of labels ({len(self.labels)})"
            )

        # Verify states length matches cardinalities
        # (list(...) so tuple cardinalities compare equal too)
        inferred_cardinalities = [len(state_tuple) for state_tuple in self.states]
        if list(self.cardinalities) != inferred_cardinalities:
            raise ValueError(
                f"Provided cardinalities {self.cardinalities} don't match "
                f"inferred cardinalities {inferred_cardinalities} from states"
            )

        # Validate optional per-concept lists line up with the labels.
        if self.types is not None and len(self.types) != len(self.labels):
            raise ValueError(
                f"Number of types ({len(self.types)}) must match "
                f"number of labels ({len(self.labels)})"
            )

        # Canonicalise the concept type. It is one of 'binary' / 'categorical' /
        # 'continuous'. When omitted, default discrete concepts from cardinality
        # (binary if card==1 else categorical); 'continuous' must be declared
        # explicitly (it cannot be inferred). Then enforce the type<->cardinality
        # invariant so the two never drift.
        if self.types is None:
            resolved_types = [
                'binary' if card == 1 else 'categorical' for card in self.cardinalities
            ]
        else:
            resolved_types = list(self.types)
        for label, t, card in zip(self.labels, resolved_types, self.cardinalities):
            if t not in _CONCEPT_TYPES:
                raise ValueError(
                    f"Concept {label!r}: type must be one of {_CONCEPT_TYPES}, got {t!r}."
                )
            if t == 'binary' and card != 1:
                raise ValueError(
                    f"Concept {label!r}: 'binary' requires cardinality 1, got {card}."
                )
            # In concept-space every concept is a single integer-coded column, so a
            # categorical concept legitimately has cardinality 1 (the column holds a
            # class index); only enforce the >1 invariant in logit-space annotations.
            if t == 'categorical' and card <= 1 and not self.concept_space:
                raise ValueError(
                    f"Concept {label!r}: 'categorical' requires cardinality > 1, got {card}."
                )
        # Bypass the write-once guard: `types` may already hold the user's list.
        object.__setattr__(self, 'types', resolved_types)

        # Determine is_nested from cardinalities
        is_nested = any(card > 1 for card in self.cardinalities)
        object.__setattr__(self, 'is_nested', is_nested)

        # Consistency checks on metadata
        if self.metadata is not None:
            if not isinstance(self.metadata, dict):
                raise ValueError("metadata must be a dictionary")
            # Only validate if metadata is non-empty
            if self.metadata:
                for label in self.labels:
                    if label not in self.metadata:
                        raise ValueError(f"Metadata missing for label {label!r}")

    @property
    def size(self) -> int:
        """Flattened concept dimension: ``sum(cardinalities)``.

        Equals ``len(labels)`` when non-nested (all cardinalities are 1). This
        is the size of axis 1 of the annotated (logit-space) tensor.
        """
        return sum(self.cardinalities)

    @property
    def shape(self) -> Tuple[int, int]:
        """Annotated tensor shape ``(B, sum(cardinalities))``.

        Axis 0 is the unknown batch dimension B, returned as ``-1``; axis 1 is
        the flattened concept dimension :attr:`size`.
        """
        return (-1, self.size)

    def has_metadata(self, key) -> bool:
        """Check if metadata contains a specific key for all labels.

        A label whose metadata entry is missing or ``None`` counts as not
        having the key.
        """
        if self.metadata is None:
            return False
        return all(key in (self.metadata.get(label) or {}) for label in self.labels)

    def groupby_metadata(self, key, layout: str = 'labels') -> dict:
        """Group concepts by the value stored under ``key`` in their metadata.

        Labels whose metadata lacks ``key`` (or is missing / ``None``) are
        skipped. Groups appear in first-seen order; members follow axis order.

        Args:
            key: Metadata key whose value defines the group.
            layout: ``'labels'`` to collect concept names, ``'indices'`` to
                collect concept-level indices.

        Returns:
            dict: Mapping from metadata value to the list of labels/indices.
            Empty when no metadata is attached.

        Raises:
            ValueError: If ``layout`` is not ``'labels'`` or ``'indices'``.
        """
        # Validate up front so a bad layout is reported even when no label
        # carries ``key`` (previously that case silently returned ``{}``).
        if layout not in ('labels', 'indices'):
            raise ValueError(f"Unknown layout {layout}")
        if self.metadata is None:
            return {}
        result = {}
        for idx, label in enumerate(self.labels):
            meta = self.metadata.get(label) or {}
            if key in meta:
                result.setdefault(meta[key], []).append(
                    label if layout == 'labels' else idx
                )
        return result

    def __len__(self) -> int:
        """Return number of labels."""
        return len(self.labels)

    def __getitem__(self, key: Union[int, str]) -> Union[str, "Concept"]:
        """
        Index by position or by name.

        - ``annotations[int]`` returns the label at that index (``str``).
        - ``annotations[str]`` returns the :class:`Concept` view for that label.
        """
        if isinstance(key, str):
            return self.concept(key)
        if not (0 <= key < len(self.labels)):
            raise IndexError(f"Index {key} out of range")
        return self.labels[key]

    @cached_property
    def label_to_index(self) -> Dict[str, int]:
        """Precomputed mapping from concept name to concept-level index.

        Provides O(1) lookup for concept indices, useful for efficient
        concept extraction operations.

        Example:
            >>> ann = Annotations(labels=['color', 'shape', 'size'])
            >>> ann.label_to_index['shape']
            1
        """
        return {name: i for i, name in enumerate(self.labels)}

    def get_index(self, label: str) -> int:
        """Get index of a label.

        Raises:
            ValueError: If ``label`` is not one of :attr:`labels`.
        """
        try:
            return self.label_to_index[label]
        except KeyError:
            raise ValueError(f"Label {label!r} not found in labels {self.labels}") from None

    def get_label(self, idx: int) -> str:
        """Get label at given index.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(labels))``.
        """
        if not (0 <= idx < len(self.labels)):
            raise IndexError(f"Index {idx} out of range with {len(self.labels)} labels")
        return self.labels[idx]

    def get_total_cardinality(self) -> Optional[int]:
        """Get total cardinality for nested annotations, or number of labels otherwise."""
        if self.is_nested:
            if self.cardinalities is not None:
                return sum(self.cardinalities)
            else:
                raise ValueError("Cardinalities are not defined for this nested annotation")
        else:
            return len(self.labels)

    # =========================================================================
    # State navigation
    # =========================================================================
    def get_label_states(self, label: str) -> List[str]:
        """Get the ordered state labels of a concept."""
        return self.states[self.get_index(label)]

    def get_label_state(self, label: str, idx: int) -> str:
        """Get the state label at position ``idx`` of a concept."""
        return self.states[self.get_index(label)][idx]

    def get_state_index(self, label: str, state: str) -> int:
        """Get the index of a state label for a concept."""
        states = self.states[self.get_index(label)]
        try:
            return states.index(state)
        except ValueError:
            raise ValueError(f"State {state!r} not found for concept {label!r}")

    # =========================================================================
    # Cached index properties for efficient tensor slicing
    # =========================================================================
    @cached_property
    def cumulative_cardinalities(self) -> List[int]:
        """Precomputed cumulative cardinalities for O(1) slicing.

        Returns a list where cumulative_cardinalities[i] is the starting position
        of concept i in the flattened tensor, and cumulative_cardinalities[i+1]
        is the ending position (exclusive).

        Example:
            >>> ann = Annotations(labels=['color', 'shape', 'size'], cardinalities=[3, 2, 1])
            >>> ann.cumulative_cardinalities
            [0, 3, 5, 6]
        """
        cum = [0]
        for c in self.cardinalities:
            cum.append(cum[-1] + c)
        return cum

    @cached_property
    def concept_slices(self) -> Dict[str, slice]:
        """Precomputed mapping from concept name to slice in flattened tensor.

        Example:
            >>> ann = Annotations(labels=['color', 'shape', 'size'], cardinalities=[3, 2, 1])
            >>> ann.concept_slices['color']
            slice(0, 3, None)
        """
        cum = self.cumulative_cardinalities
        return {name: slice(cum[i], cum[i + 1]) for i, name in enumerate(self.labels)}

    @cached_property
    def labels_by_type(self) -> Dict[str, List[str]]:
        """Mapping from concept type to the ordered list of labels of that type.

        Only non-empty types are included. Derived from :attr:`type_groups` so
        the grouping logic lives in one place.

        Example:
            >>> ann = Annotations(
            ...     labels=['a', 'b', 'c'],
            ...     cardinalities=[1, 1, 3],
            ... )
            >>> ann.labels_by_type
            {'binary': ['a', 'b'], 'categorical': ['c']}
        """
        return {t: g['labels'] for t, g in self.type_groups.items() if g['labels']}

    @cached_property
    def type_groups(self) -> Dict[str, Dict[str, List]]:
        """Precomputed type-based groupings at both concept and logit levels.

        Returns a dict with keys 'binary', 'categorical', 'continuous', each
        containing:
            - 'labels': list of concept names
            - 'concept_idx': list of concept-level indices
            - 'logits_idx': list of logit-level indices

        Example:
            >>> ann = Annotations(
            ...     labels=['size', 'color', 'temp'],
            ...     cardinalities=[1, 3, 1],
            ...     types=['binary', 'categorical', 'continuous'],
            ... )
            >>> ann.type_groups['binary']['labels']
            ['size']
            >>> ann.type_groups['categorical']['logits_idx']
            [1, 2, 3]
        """
        cum = self.cumulative_cardinalities
        groups = {t: {'labels': [], 'concept_idx': [], 'logits_idx': []}
                  for t in _CONCEPT_TYPES}
        for i, label in enumerate(self.labels):
            group = groups[self.types[i]]
            group['labels'].append(label)
            group['concept_idx'].append(i)
            group['logits_idx'].extend(range(cum[i], cum[i + 1]))
        return groups

    def concept(self, name: str) -> "Concept":
        """Return a read-only :class:`Concept` view for ``name``.

        Groups the concept's per-column properties (cardinality, type, states,
        logit slice) into one object, so callers can write
        ``annotations.concept('color').cardinality`` instead of the index-dance
        over the parallel lists. Built fresh on each call so it always reflects
        the current (mutable) ``metadata``.
        """
        i = self.get_index(name)
        meta = (self.metadata.get(name, {}) if self.metadata else {}) or {}
        return Concept(
            name=name,
            index=i,
            cardinality=int(self.cardinalities[i]),
            type=self.types[i],
            states=self.states[i] if self.states is not None else None,
            slice=self.concept_slices[name],
            metadata=meta,
        )

    @property
    def concepts(self) -> List["Concept"]:
        """All concepts as :class:`Concept` views, in axis order.

        Views are read from the canonical parallel lists (no duplicated
        storage); useful for one-pass iteration. Not cached, so it reflects the
        current (mutable) ``metadata``.
        """
        return [self.concept(name) for name in self.labels]

    def slice_tensor(self, tensor: torch.Tensor, concepts: List[str]) -> torch.Tensor:
        """Extract and concatenate columns for specified concepts.

        Args:
            tensor: Input tensor of shape (batch, total_logits)
            concepts: List of concept names to extract, in desired output order

        Returns:
            Tensor with columns for specified concepts concatenated

        Example:
            >>> import torch
            >>> ann = Annotations(labels=['color', 'shape'], cardinalities=[3, 2])
            >>> predictions = torch.rand(4, 5)
            >>> reordered = ann.slice_tensor(predictions, ann.labels)
            >>> reordered.shape
            torch.Size([4, 5])
        """
        pieces = [tensor[:, self.concept_slices[c]] for c in concepts]
        return torch.cat(pieces, dim=1)

    def get_slice(self, labels: Union[str, List[str]]) -> Union[slice, List[int]]:
        """Get slice or indices for concept(s) in the flattened tensor.

        Unified method for accessing concept positions:
        - Single concept name → returns slice object for tensor indexing
        - List of concept names → returns flattened list of indices

        Uses precomputed concept_slices for O(1) per-concept lookup.

        Args:
            labels: Single concept name (str) or list of concept names.

        Returns:
            - slice: If labels is a single string
            - List[int]: If labels is a list of strings

        Raises:
            ValueError: If any label is not found in the axis labels.

        Example:
            >>> ann = Annotations(
            ...     labels=['color', 'shape', 'size'],
            ...     cardinalities=[3, 2, 1]
            ... )
            >>> # Single concept → slice
            >>> ann.get_slice('color')
            slice(0, 3, None)
            >>> # Multiple concepts → flattened indices
            >>> ann.get_slice(['color', 'size'])
            [0, 1, 2, 5]
        """
        slices = self.concept_slices  # Use cached property

        # Single concept → return slice directly
        if isinstance(labels, str):
            if labels not in slices:
                raise ValueError(f"Label '{labels}' not found in axis labels {self.labels}")
            return slices[labels]

        # Multiple concepts → return flattened indices
        logits_indices = []
        for label in labels:
            if label not in slices:
                raise ValueError(f"Label '{label}' not found in axis labels {self.labels}")
            s = slices[label]
            logits_indices.extend(range(s.start, s.stop))
        return logits_indices

    def get_logits_idx(self, labels: List[str]) -> List[int]:
        """Alias for get_slice(labels) when labels is a list.

        Deprecated: Use get_slice() instead.
        """
        return self.get_slice(labels)

    @classmethod
    def empty(
        cls,
        n: int,
        cardinalities: Optional[Union[int, List[int]]] = None,
        types: Optional[Union[str, List[str]]] = None
    ) -> "Annotations":
        """Create an Annotations with *n* anonymous labels ``c_0 … c_{n-1}``.

        Args:
            n: Number of labels.
            cardinalities: Optional cardinality shared by all labels (an ``int``
                is broadcast to all ``n`` labels) or a per-label list. When
                omitted, every label is binary (cardinality 1).
            types: Optional concept type shared by all labels (a ``str`` is
                broadcast) or a per-label list. When omitted, types are
                inferred from the cardinalities.

        Returns:
            A new :class:`Annotations` with labels ``['c_0', ..., 'c_{n-1}']``.

        Example:
            >>> ann = Annotations.empty(4)
            >>> ann.labels
            ['c_0', 'c_1', 'c_2', 'c_3']
        """
        cardinalities = [cardinalities] * n if isinstance(cardinalities, int) else cardinalities
        types = [types] * n if isinstance(types, str) else types  # broadcast single str to list
        return cls(
            labels=[f"c_{i}" for i in range(n)],
            cardinalities=cardinalities,
            types=types
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a JSON-serializable dictionary.

        Returns
        -------
        dict
            Keys ``labels``, ``is_nested``, ``states``, ``cardinalities``,
            ``types``, ``metadata`` and ``concept_space``. ``metadata`` is
            included as-is, so the result is only JSON-serializable if the
            metadata values are. :meth:`from_dict` inverts this.
        """
        result = {
            'labels': list(self.labels),
            'is_nested': self.is_nested,
            'states': [list(s) for s in self.states] if self.states else None,
            'cardinalities': list(self.cardinalities) if self.cardinalities else None,
            'types': list(self.types) if self.types else None,
            'metadata': self.metadata,
            # Needed for a faithful round-trip: concept-space annotations may
            # hold categorical concepts with cardinality 1, which only validate
            # when ``concept_space`` is True.
            'concept_space': self.concept_space,
        }
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Annotations':
        """
        Create Annotations from dictionary.

        Parameters
        ----------
        data : dict
            Dictionary with serialized Annotations data (as produced by
            :meth:`to_dict`). Only ``labels`` is required.

        Returns
        -------
        Annotations
            Reconstructed Annotations object.
        """
        # Keep as lists (native format)
        labels = data['labels']
        states = [list(s) for s in data['states']] if data.get('states') else None
        cardinalities = data.get('cardinalities')

        return cls(
            labels=labels,
            states=states,
            cardinalities=cardinalities,
            types=data.get('types'),
            metadata=data.get('metadata'),
            # Payloads without this key are treated as logit-space (the default).
            concept_space=data.get('concept_space', False),
        )

    def subset(self, keep_labels: Sequence[str]) -> "Annotations":
        """
        Return a new Annotations restricted to `keep_labels`
        (order follows the order in `keep_labels`).

        Raises
        ------
        ValueError
            if any requested label is missing.
        """
        # 1) validate + map to indices, preserving requested order
        label_set = set(self.labels)
        missing = [lab for lab in keep_labels if lab not in label_set]
        if missing:
            raise ValueError(f"Unknown labels for subset: {missing}")
        idxs = [self.get_index(lab) for lab in keep_labels]

        # 2) slice labels / states / cardinalities / types
        new_labels = [self.labels[i] for i in idxs]
        if self.states is not None:
            new_states = [self.states[i] for i in idxs]
            new_cards = [len(s) for s in new_states]
        else:
            new_states = None
            new_cards = None
        new_types = [self.types[i] for i in idxs]

        # 3) slice metadata (if present). An empty dict is valid metadata
        # (validation only runs when it is non-empty), and ``metadata`` may be
        # reassigned after construction, so a kept label may have no entry.
        new_metadata = None
        if self.metadata is not None:
            if self.metadata:
                new_metadata = {lab: self.metadata.get(lab, {}) for lab in keep_labels}
            else:
                new_metadata = {}

        # 4) build a fresh object
        return Annotations(
            labels=new_labels,
            states=new_states,
            cardinalities=new_cards,
            types=new_types,
            metadata=new_metadata,
            concept_space=self.concept_space,
        )

    def to_concept_space(self) -> "Annotations":
        """Return a concept-space view: one integer-coded column per concept.

        Cardinalities collapse to 1 (each concept becomes a single column) while
        labels and types are preserved, so a ground-truth concept tensor of
        shape ``(batch, n_concepts)`` (integer class indices for categorical
        concepts, 0/1 for binary) can be wrapped as an
        :class:`~torch_concepts.tensor.AnnotatedTensor`. The result's
        :attr:`size` equals the number of concepts, and label-based slicing /
        :meth:`labels_by_type` operate per concept.

        Returns ``self`` unchanged if this annotation is already concept-space.
        """
        if self.concept_space:
            return self
        return Annotations(
            labels=list(self.labels),
            cardinalities=[1] * len(self.labels),
            types=list(self.types),
            metadata=self.metadata,
            concept_space=True,
        )

    def union_with(self, other: "Annotations") -> "Annotations":
        """Return the union of this annotation and ``other``.

        Labels of ``self`` come first (in order), followed by the labels of
        ``other`` that ``self`` lacks. For overlapping labels ``self`` wins:
        states, types and metadata are taken from ``self``.

        The result is concept-space only if both operands are concept-space.

        Args:
            other: Annotations to merge in.

        Returns:
            Annotations: A new, merged annotation.
        """
        left = list(self.labels)
        left_set = set(left)  # built once instead of once per right label
        right_only = [lab for lab in other.labels if lab not in left_set]
        labels = left + right_only

        def _merge(left_values, right_values):
            """Left values + right-only values (left wins for overlapping labels)."""
            return list(left_values) + [
                right_values[other.get_index(lab)] for lab in right_only
            ]

        # ``states`` / ``types`` are always populated after construction, so we
        # carry them through directly (cardinalities re-infer from states). This
        # keeps categorical concepts' cardinalities intact through a union.
        new_states = _merge(self.states, other.states)
        new_types = _merge(self.types, other.types)

        # merge metadata left-wins
        meta = None
        if self.metadata or other.metadata:
            meta = {}
            if self.metadata:
                meta.update(self.metadata)
            if other.metadata:
                for k, v in other.metadata.items():
                    if k not in meta:
                        meta[k] = v

        return Annotations(
            labels=labels,
            states=new_states,
            cardinalities=None,
            types=new_types,
            metadata=meta,
            # Without this, the union of two concept-space annotations holding
            # categorical concepts (cardinality 1) would fail the logit-space
            # ``categorical requires cardinality > 1`` check.
            concept_space=self.concept_space and other.concept_space,
        )