Annotations

Containers for model configuration and type information.

Summary

Annotation Classes

AxisAnnotation

Annotations for a single axis of a tensor.

Annotations

Multi-axis annotation container for concept tensors.

Class Documentation

class AxisAnnotation(labels: List[str], states: List[List[str]] | None = None, cardinalities: List[int] | None = None, metadata: Dict[str, Dict] | None = None)

Bases: object

Annotations for a single axis of a tensor.

This class provides semantic labeling for one dimension of a tensor, supporting both simple binary concepts and nested multi-state concepts.

labels

Ordered, unique labels for this axis.

Type:

list[str]

states

State labels for each concept (if nested).

Type:

Optional[list[list[str]]]

cardinalities

Cardinality of each concept.

Type:

Optional[list[int]]

metadata

Additional metadata for each label.

Type:

Optional[Dict[str, Dict]]

is_nested

Whether this axis has nested/hierarchical structure.

Type:

bool

Parameters:
  • labels – List of concept names for this axis.

  • states – Optional list of state lists for nested concepts.

  • cardinalities – Optional list of cardinalities per concept.

  • metadata – Optional metadata dictionary keyed by label names.

Example

>>> from torch_concepts import AxisAnnotation
>>>
>>> # Simple binary concepts
>>> axis_binary = AxisAnnotation(
...     labels=['has_wheels', 'has_windows', 'is_red']
... )
>>> print(axis_binary.labels)  # ['has_wheels', 'has_windows', 'is_red']
>>> print(axis_binary.is_nested)  # False
>>> print(axis_binary.cardinalities)  # [1, 1, 1] - binary concepts
>>>
>>> # Nested concepts with explicit states
>>> axis_nested = AxisAnnotation(
...     labels=['color', 'shape'],
...     states=[['red', 'green', 'blue'], ['circle', 'square']],
... )
>>> print(axis_nested.labels)  # ['color', 'shape']
>>> print(axis_nested.is_nested)  # True
>>> print(axis_nested.cardinalities)  # [3, 2]
>>> print(axis_nested.states[0])  # ['red', 'green', 'blue']
>>>
>>> # With cardinalities only (auto-generates state labels)
>>> axis_cards = AxisAnnotation(
...     labels=['size', 'material'],
...     cardinalities=[3, 4]  # 3 sizes, 4 materials
... )
>>> print(axis_cards.cardinalities)  # [3, 4]
>>> print(axis_cards.states[0])  # ['0', '1', '2']
>>>
>>> # Access methods
>>> idx = axis_binary.get_index('has_wheels')
>>> print(idx)  # 0
>>> label = axis_binary.get_label(1)
>>> print(label)  # 'has_windows'
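
The relationship between `states` and `cardinalities` shown in the example above can be sketched in plain Python. This only reproduces the documented outputs and is not the library's implementation. The helper names `cardinalities_from_states` and `states_from_cardinalities` are hypothetical, and how the library labels the states of a binary concept is not covered here.

```python
from typing import List


def cardinalities_from_states(states: List[List[str]]) -> List[int]:
    # One cardinality per concept: the number of states it can take.
    return [len(concept_states) for concept_states in states]


def states_from_cardinalities(cardinalities: List[int]) -> List[List[str]]:
    # Placeholder state labels '0', '1', ... for each concept.
    return [[str(i) for i in range(card)] for card in cardinalities]


derived_cards = cardinalities_from_states(
    [['red', 'green', 'blue'], ['circle', 'square']]
)
derived_states = states_from_cardinalities([3, 4])
print(derived_cards)      # [3, 2]
print(derived_states[0])  # ['0', '1', '2']
```
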
labels: List[str]
states: List[List[str]] | None = None
cardinalities: List[int] | None = None
metadata: Dict[str, Dict] | None = None
property shape: int | Tuple[int, ...]

Return the size of this axis: an int (the number of labels) for a non-nested axis, or a tuple of ints (the cardinalities) for a nested axis.

has_metadata(key) → bool

Check if metadata contains a specific key for all labels.

groupby_metadata(key, layout: str = 'labels') → dict

Group the labels of this axis by their value for the given metadata key.
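
As a rough illustration of per-label metadata, the sketch below works on a plain dictionary shaped like the `metadata` argument (`Dict[str, Dict]`, keyed by label). It does not call the library. The helper names `all_labels_have_key` and `group_labels_by_key` and the `'group'` metadata key are hypothetical, and the grouping shown is only one plausible reading of what `groupby_metadata` returns with `layout='labels'`.

```python
from collections import defaultdict
from typing import Any, Dict, List

labels = ['has_wheels', 'has_windows', 'is_red']
# Hypothetical metadata: one dict per label, keyed by label name.
metadata = {
    'has_wheels': {'group': 'parts'},
    'has_windows': {'group': 'parts'},
    'is_red': {'group': 'appearance'},
}


def all_labels_have_key(labels: List[str], metadata: Dict[str, Dict], key: str) -> bool:
    # True only if every label carries the key in its metadata dict.
    return all(key in metadata.get(label, {}) for label in labels)


def group_labels_by_key(labels: List[str], metadata: Dict[str, Dict], key: str) -> Dict[Any, List[str]]:
    # Map each distinct metadata value to the labels that share it.
    groups: Dict[Any, List[str]] = defaultdict(list)
    for label in labels:
        groups[metadata[label][key]].append(label)
    return dict(groups)


has_group = all_labels_have_key(labels, metadata, 'group')
groups = group_labels_by_key(labels, metadata, 'group')
print(has_group)  # True
print(groups)     # {'parts': ['has_wheels', 'has_windows'], 'appearance': ['is_red']}
```
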

get_index(label: str) → int

Get index of a label in this axis.

get_label(idx: int) → str

Get label at given index in this axis.

get_total_cardinality() → int | None

Get total cardinality for nested axis, or None if not nested.

get_endogenous_idx(labels: List[str]) → List[int]

Get endogenous (logit-level) indices for a list of concept labels.

This method returns the flattened tensor indices where the logits/values for the specified concepts appear, accounting for each concept’s cardinality.

Parameters:

labels – List of concept label names to get indices for.

Returns:

List of endogenous indices in the flattened tensor, in the order corresponding to the input labels.

Raises:

ValueError – If any label is not found in the axis labels.

Example

>>> # Concepts: ['color', 'shape', 'size'] with cardinalities [3, 2, 1]
>>> # Flattened tensor has 6 positions: [c0, c1, c2, s0, s1, sz]
>>> axis = AxisAnnotation(
...     labels=['color', 'shape', 'size'],
...     cardinalities=[3, 2, 1]
... )
>>> axis.get_endogenous_idx(['color', 'size'])
[0, 1, 2, 5]  # color takes positions 0-2, size takes position 5
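
The index arithmetic behind this example can be sketched without the library: each concept starts at the sum of the cardinalities that precede it and occupies as many positions as its own cardinality. The function name `endogenous_indices` is hypothetical, and this is a minimal sketch of the described behaviour rather than the library's code.

```python
from itertools import accumulate
from typing import List


def endogenous_indices(labels: List[str], cardinalities: List[int], wanted: List[str]) -> List[int]:
    # Start offset of each concept = sum of the cardinalities before it.
    offsets = [0] + list(accumulate(cardinalities))[:-1]
    result: List[int] = []
    for name in wanted:
        if name not in labels:
            raise ValueError(f"Label {name!r} not found in axis labels")
        i = labels.index(name)
        # A concept with cardinality k occupies k consecutive positions.
        result.extend(range(offsets[i], offsets[i] + cardinalities[i]))
    return result


labels = ['color', 'shape', 'size']
cardinalities = [3, 2, 1]
idx = endogenous_indices(labels, cardinalities, ['color', 'size'])
total = sum(cardinalities)  # length of the flattened tensor axis
print(idx)    # [0, 1, 2, 5]
print(total)  # 6
```

Passing the labels in a different order, for example `['size', 'color']`, yields `[5, 0, 1, 2]` in this sketch, because the result follows the order of the requested labels.
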
to_dict() → Dict[str, Any]

Convert to JSON-serializable dictionary.

Returns:

Dictionary with all attributes, converting DataFrame to dict format.

Return type:

dict

classmethod from_dict(data: Dict[str, Any]) → AxisAnnotation

Create AxisAnnotation from dictionary.

Parameters:

data (dict) – Dictionary with serialized AxisAnnotation data.

Returns:

Reconstructed AxisAnnotation object.

Return type:

AxisAnnotation

subset(keep_labels: Sequence[str]) → AxisAnnotation

Return a new AxisAnnotation restricted to keep_labels (order follows the order in keep_labels).

Raises:

ValueError – If any requested label is missing.
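
The selection rule described for `subset` (output order follows `keep_labels`, missing labels raise `ValueError`) can be sketched on plain lists. The function name `subset_axis` is hypothetical and the sketch does not use the library. Carrying the matching cardinalities along with the labels is an assumption about what a restricted axis would contain.

```python
from typing import List, Sequence, Tuple


def subset_axis(labels: List[str], cardinalities: List[int],
                keep_labels: Sequence[str]) -> Tuple[List[str], List[int]]:
    # Reject the request if any label is unknown.
    missing = [label for label in keep_labels if label not in labels]
    if missing:
        raise ValueError(f"Labels not found: {missing}")
    # Look up positions in the order given by keep_labels.
    positions = [labels.index(label) for label in keep_labels]
    return list(keep_labels), [cardinalities[p] for p in positions]


kept_labels, kept_cards = subset_axis(
    ['color', 'shape', 'size'], [3, 2, 1], ['size', 'color']
)
print(kept_labels)  # ['size', 'color']
print(kept_cards)   # [1, 3]
```
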

union_with(other: AxisAnnotation) → AxisAnnotation

class Annotations(axis_annotations: List | Dict[int, AxisAnnotation] | None = None)

Bases: object

Multi-axis annotation container for concept tensors.

This class manages annotations for multiple tensor dimensions, providing a unified interface for working with concept-based tensors that may have different semantic meanings along different axes.

_axis_annotations

Map from axis index to annotation.

Type:

Dict[int, AxisAnnotation]

Parameters:

axis_annotations – Either a list of AxisAnnotations (indexed 0, 1, 2, …) or a dict mapping axis numbers to AxisAnnotations.
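
The two accepted input forms can be pictured as being normalised to a single axis-to-annotation mapping. The sketch below uses strings as stand-ins for `AxisAnnotation` objects and a hypothetical helper `normalize_axis_annotations`; it illustrates the documented behaviour and is not the constructor's actual code.

```python
from typing import Any, Dict, List, Optional, Union


def normalize_axis_annotations(
    axis_annotations: Optional[Union[List[Any], Dict[int, Any]]] = None,
) -> Dict[int, Any]:
    if axis_annotations is None:
        return {}
    if isinstance(axis_annotations, dict):
        # Already keyed by axis number.
        return dict(axis_annotations)
    # A list is indexed 0, 1, 2, ... in order.
    return dict(enumerate(axis_annotations))


from_list = normalize_axis_annotations(['batch_ann', 'concept_ann'])
from_dict = normalize_axis_annotations({2: 'task_ann', 1: 'concept_ann'})
print(from_list)                 # {0: 'batch_ann', 1: 'concept_ann'}
print(tuple(sorted(from_dict)))  # (1, 2)
```
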

Example

>>> from torch_concepts import Annotations, AxisAnnotation
>>>
>>> # Create annotations for a concept tensor
>>> # Axis 0: batch (typically not annotated)
>>> # Axis 1: concepts
>>> concept_ann = AxisAnnotation(
...     labels=['color', 'shape', 'size'],
...     cardinalities=[3, 2, 1]  # 3 colors, 2 shapes, 1 binary size
... )
>>>
>>> # Create annotations object
>>> annotations = Annotations({1: concept_ann})
>>>
>>> # Access concept labels
>>> print(annotations.get_axis_labels(1))  # ['color', 'shape', 'size']
>>>
>>> # Get index of a concept
>>> idx = annotations.get_index(1, 'color')
>>> print(idx)  # 0
>>>
>>> # Check if axis is nested
>>> print(annotations.is_axis_nested(1))  # True
>>>
>>> # Get cardinalities
>>> print(annotations.get_axis_cardinalities(1))  # [3, 2, 1]
>>>
>>> # Access via indexing
>>> print(annotations[1].labels)  # ['color', 'shape', 'size']
>>>
>>> # Multiple axes example
>>> task_ann = AxisAnnotation(labels=['task1', 'task2', 'task3'])
>>> multi_ann = Annotations({
...     1: concept_ann,
...     2: task_ann
... })
>>> print(multi_ann.annotated_axes)  # (1, 2)
annotate_axis(axis_annotation: AxisAnnotation, axis: int) → None

Add or update annotation for an axis.

property shape: Tuple[int, ...]

Get shape of the annotated tensor based on annotations.

property num_annotated_axes: int

Number of annotated axes.

property annotated_axes: Tuple[int, ...]

Tuple of annotated axis numbers (sorted).

has_axis(axis: int) → bool

Check if an axis is annotated.

get_axis_annotation(axis: int) → AxisAnnotation

Get annotation for a specific axis.

get_axis_labels(axis: int) → List[str]

Get ordered labels for an axis.

get_axis_cardinalities(axis: int) → List[int] | None

Get cardinalities for an axis (if nested), or None.

is_axis_nested(axis: int) → bool

Check if an axis has nested structure.

get_index(axis: int, label: str) → int

Get index of a label within an axis.

get_label(axis: int, idx: int) → str

Get label at index within an axis.

get_states(axis: int) → List[List[str]] | None

Get states for a nested axis, or None.

get_label_states(axis: int, label: str) → List[str]

Get states of a concept in a nested axis.

get_label_state(axis: int, label: str, idx: int) → str

Get the state at index idx of a concept in a nested axis.

get_state_index(axis: int, label: str, state: str) → int

Get index of a state label for a concept in a nested axis.
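
The three state lookups for a single nested axis amount to indexing into parallel `labels` and `states` lists. The sketch below uses plain lists and hypothetical helper names (`label_states`, `label_state`, `state_index`); it leaves out the `axis` argument and does not call the library.

```python
from typing import List

labels = ['color', 'shape']
states = [['red', 'green', 'blue'], ['circle', 'square']]


def label_states(label: str) -> List[str]:
    # All states of one concept.
    return states[labels.index(label)]


def label_state(label: str, idx: int) -> str:
    # The state at a given position within one concept.
    return label_states(label)[idx]


def state_index(label: str, state: str) -> int:
    # The position of a state within one concept.
    return label_states(label).index(state)


print(label_states('shape'))          # ['circle', 'square']
print(label_state('color', 2))        # blue
print(state_index('color', 'green'))  # 1
```
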

keys()

Return axis numbers (dict-like interface).

values()

Return AxisAnnotation objects (dict-like interface).

items()

Return (axis, AxisAnnotation) pairs (dict-like interface).

property axis_annotations: Dict[int, AxisAnnotation]

Access to the underlying axis annotations dictionary.

select(axis: int, keep_labels: Sequence[str]) → Annotations

Return a new Annotations where only keep_labels are kept on axis. Other axes are unchanged.

select_many(labels_by_axis: Dict[int, Sequence[str]]) → Annotations

Return a new Annotations applying independent label filters per axis.
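
Per-axis filtering can be sketched with a dictionary that maps each axis number to its list of labels. The helper `select_many_labels` is hypothetical. The sketch assumes that axes absent from `labels_by_axis` are left unchanged, and that the order of kept labels and the handling of missing labels follow `AxisAnnotation.subset`; none of this is taken from the library's code.

```python
from typing import Dict, List, Sequence


def select_many_labels(
    labels_per_axis: Dict[int, List[str]],
    labels_by_axis: Dict[int, Sequence[str]],
) -> Dict[int, List[str]]:
    selected: Dict[int, List[str]] = {}
    for axis, labels in labels_per_axis.items():
        if axis not in labels_by_axis:
            # No filter requested for this axis: keep it as is.
            selected[axis] = list(labels)
            continue
        keep = list(labels_by_axis[axis])
        missing = [label for label in keep if label not in labels]
        if missing:
            raise ValueError(f"Axis {axis}: labels not found: {missing}")
        selected[axis] = keep
    return selected


annotated = {1: ['color', 'shape', 'size'], 2: ['task1', 'task2', 'task3']}
filtered = select_many_labels(annotated, {1: ['size', 'color']})
print(filtered)  # {1: ['size', 'color'], 2: ['task1', 'task2', 'task3']}
```
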

join_union(other: Annotations, axis: int) → Annotations

to_dict() → Dict[str, Any]

Convert to JSON-serializable dictionary.

Returns:

Dictionary with axis annotations.

Return type:

dict

classmethod from_dict(data: Dict[str, Any]) → Annotations

Create Annotations from dictionary.

Parameters:

data (dict) – Dictionary with serialized Annotations data.

Returns:

Reconstructed Annotations object.

Return type:

Annotations