torch_concepts.annotations.Annotations

class Annotations(axis_annotations: List | Dict[int, AxisAnnotation] | None = None)

Multi-axis annotation container for concept tensors.

This class manages annotations for several tensor dimensions at once. It provides a single interface for concept-based tensors whose axes carry different semantic meanings (for example, batch, concepts, and tasks).

_axis_annotations

Map from axis index to annotation.

Type:

Dict[int, AxisAnnotation]

Parameters:

axis_annotations – Either a list of AxisAnnotations (indexed 0, 1, 2, …) or a dict mapping axis numbers to AxisAnnotations.
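The two accepted forms of axis_annotations are interchangeable: a list is read positionally, so its i-th entry annotates axis i. The sketch below illustrates that normalization in plain Python. The helper name to_axis_dict is hypothetical, strings stand in for AxisAnnotation objects, and treating None as "no annotated axes" is an assumption rather than documented behavior.

```python
from typing import Dict, List, Optional, TypeVar, Union

T = TypeVar("T")


def to_axis_dict(axis_annotations: Optional[Union[List[T], Dict[int, T]]]) -> Dict[int, T]:
    """Hypothetical helper: normalize the constructor argument to {axis: annotation}."""
    if axis_annotations is None:
        # Assumption: None means that no axis is annotated.
        return {}
    if isinstance(axis_annotations, list):
        # A list is positional: entry i annotates axis i (0, 1, 2, ...).
        return dict(enumerate(axis_annotations))
    if isinstance(axis_annotations, dict):
        # A dict already maps axis numbers to annotations; return a copy.
        return dict(axis_annotations)
    raise TypeError("expected a list, a dict, or None")


print(to_axis_dict(["batch_ann", "concept_ann"]))  # {0: 'batch_ann', 1: 'concept_ann'}
print(to_axis_dict({1: "concept_ann"}))            # {1: 'concept_ann'}
```

In this sketch a list fills axes 0, 1, 2, ... in order, so annotating only axis 1 (as in the example below) uses the dict form.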

Example

>>> from torch_concepts import Annotations, AxisAnnotation
>>>
>>> # Create annotations for a concept tensor
>>> # Axis 0: batch (typically not annotated)
>>> # Axis 1: concepts
>>> concept_ann = AxisAnnotation(
...     labels=['color', 'shape', 'size'],
...     cardinalities=[3, 2, 1]  # color: 3 states, shape: 2 states, size: binary (cardinality 1)
... )
>>>
>>> # Create annotations object
>>> annotations = Annotations({1: concept_ann})
>>>
>>> # Access concept labels
>>> print(annotations.get_axis_labels(1))  # ['color', 'shape', 'size']
>>>
>>> # Get index of a concept
>>> idx = annotations.get_index(1, 'color')
>>> print(idx)  # 0
>>>
>>> # Check if axis is nested
>>> print(annotations.is_axis_nested(1))  # True
>>>
>>> # Get cardinalities
>>> print(annotations.get_axis_cardinalities(1))  # [3, 2, 1]
>>>
>>> # Access via indexing
>>> print(annotations[1].labels)  # ['color', 'shape', 'size']
>>>
>>> # Multiple axes example
>>> task_ann = AxisAnnotation(labels=['task1', 'task2', 'task3'])
>>> multi_ann = Annotations({
...     1: concept_ann,
...     2: task_ann
... })
>>> print(multi_ann.annotated_axes)  # (1, 2)
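In the example, the concept axis is nested: each concept owns as many states as its cardinality. One way to picture this, assumed here rather than taken from the library, is a flattened axis whose length is the sum of the cardinalities, with each concept occupying a contiguous slice in label order. The hypothetical helper concept_slices computes those slices for cardinalities=[3, 2, 1].

```python
from itertools import accumulate
from typing import Dict, List


def concept_slices(labels: List[str], cardinalities: List[int]) -> Dict[str, slice]:
    """Hypothetical helper: map each concept to its slice in a flattened axis.

    Assumes states are laid out concept by concept, in label order.
    """
    if len(labels) != len(cardinalities):
        raise ValueError("labels and cardinalities must have the same length")
    # Running totals give each concept's end offset; start = end - cardinality.
    ends = list(accumulate(cardinalities))
    starts = [end - card for end, card in zip(ends, cardinalities)]
    return {label: slice(s, e) for label, s, e in zip(labels, starts, ends)}


slices = concept_slices(["color", "shape", "size"], [3, 2, 1])
print(slices["color"], slices["shape"], slices["size"])
# slice(0, 3, None) slice(3, 5, None) slice(5, 6, None)
```

Under this assumed layout the flattened concept axis has length 3 + 2 + 1 = 6.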

__init__(axis_annotations: List | Dict[int, AxisAnnotation] | None = None)

Initialize Annotations container.

Parameters:

axis_annotations – Either a list or dict of AxisAnnotation objects.

Methods

__init__([axis_annotations])

Initialize Annotations container.

annotate_axis(axis_annotation, axis)

Add or update annotation for an axis.

from_dict(data)

Create Annotations from dictionary.

get_axis_annotation(axis)

Get annotation for a specific axis.

get_axis_cardinalities(axis)

Get cardinalities for an axis (if nested), or None.

get_axis_labels(axis)

Get ordered labels for an axis.

get_index(axis, label)

Get index of a label within an axis.

get_label(axis, idx)

Get label at index within an axis.

get_label_state(axis, label, idx)

Get the state at index idx of a concept in a nested axis.

get_label_states(axis, label)

Get states of a concept in a nested axis.

get_state_index(axis, label, state)

Get index of a state label for a concept in a nested axis.

get_states(axis)

Get states for a nested axis, or None.

has_axis(axis)

Check if an axis is annotated.

is_axis_nested(axis)

Check if an axis has nested structure.

items()

Return (axis, AxisAnnotation) pairs (dict-like interface).

join_union(other, axis)

keys()

Return axis numbers (dict-like interface).

select(axis, keep_labels)

Return a new Annotations that keeps only the labels in keep_labels on axis.

select_many(labels_by_axis)

Return a new Annotations applying independent label filters per axis.

to_dict()

Convert to JSON-serializable dictionary.

values()

Return AxisAnnotation objects (dict-like interface).
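select and select_many return a new Annotations rather than modifying the existing one. The sketch below models that contract on a plain {axis: labels} dict. The function names, the choice to preserve the original label order (rather than the order of keep_labels), and raising KeyError for unknown labels are all assumptions for illustration, not the library's documented behavior.

```python
from typing import Dict, Iterable, List


def select_labels(labels_by_axis: Dict[int, List[str]], axis: int,
                  keep_labels: Iterable[str]) -> Dict[int, List[str]]:
    """Hypothetical model of select(axis, keep_labels) on {axis: labels}."""
    keep = set(keep_labels)
    unknown = keep - set(labels_by_axis[axis])
    if unknown:
        # Assumption: asking for a label that is not on the axis is an error.
        raise KeyError(f"labels not on axis {axis}: {sorted(unknown)}")
    # Build a new mapping so the input is left untouched.
    result = {ax: list(labels) for ax, labels in labels_by_axis.items()}
    # Keep the original label order on the filtered axis.
    result[axis] = [label for label in labels_by_axis[axis] if label in keep]
    return result


def select_many_labels(labels_by_axis: Dict[int, List[str]],
                       filters: Dict[int, Iterable[str]]) -> Dict[int, List[str]]:
    """Hypothetical model of select_many: one independent filter per axis."""
    result = {ax: list(labels) for ax, labels in labels_by_axis.items()}
    for axis, keep_labels in filters.items():
        result = select_labels(result, axis, keep_labels)
    return result


original = {1: ["color", "shape", "size"], 2: ["task1", "task2", "task3"]}
print(select_labels(original, 1, ["size", "color"]))
# {1: ['color', 'size'], 2: ['task1', 'task2', 'task3']}
print(select_many_labels(original, {1: ["shape"], 2: ["task3"]}))
# {1: ['shape'], 2: ['task3']}
print(original[1])  # ['color', 'shape', 'size'] (unchanged)
```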

Attributes

annotated_axes

Tuple of annotated axis numbers (sorted).

axis_annotations

Access to the underlying axis annotations dictionary.

num_annotated_axes

Number of annotated axes.

shape

Get the shape of the annotated tensor, as implied by the annotations.
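annotated_axes and num_annotated_axes follow directly from the axis-to-annotation mapping. A minimal sketch over a plain dict, with strings standing in for AxisAnnotation objects; the function names are hypothetical models of the two attributes:

```python
from typing import Dict, Tuple


def annotated_axes(axis_annotations: Dict[int, object]) -> Tuple[int, ...]:
    """Hypothetical model of the annotated_axes attribute: sorted axis numbers."""
    return tuple(sorted(axis_annotations))


def num_annotated_axes(axis_annotations: Dict[int, object]) -> int:
    """Hypothetical model of num_annotated_axes: how many axes carry an annotation."""
    return len(axis_annotations)


# Insertion order does not matter: axes are reported in sorted order.
mapping = {2: "task_ann", 1: "concept_ann"}
print(annotated_axes(mapping))      # (1, 2)
print(num_annotated_axes(mapping))  # 2
```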