torch_concepts.annotations.AxisAnnotation

class AxisAnnotation(labels: List[str], states: List[List[str]] | None = None, cardinalities: List[int] | None = None, metadata: Dict[str, Dict] | None = None)

Annotations for a single axis of a tensor.

This class provides semantic labeling for one dimension of a tensor, supporting both simple binary concepts and nested multi-state concepts.

labels

Ordered, unique labels for this axis.

Type:

list[str]

states

State labels for each concept (if nested).

Type:

Optional[list[list[str]]]

cardinalities

Cardinality of each concept.

Type:

Optional[list[int]]

metadata

Additional metadata for each label.

Type:

Optional[Dict[str, Dict]]

is_nested

Whether this axis has nested/hierarchical structure.

Type:

bool

Parameters:
  • labels – List of concept names for this axis.

  • states – Optional list of state lists for nested concepts.

  • cardinalities – Optional list of cardinalities per concept.

  • metadata – Optional metadata dictionary keyed by label names.
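The parameters interact with each other. The examples that follow indicate three cases. Plain labels yield binary concepts with cardinality 1. Explicit states fix each concept's cardinality to the number of its states. Cardinalities alone produce placeholder state names '0', '1', and so on. The sketch below reproduces that reconciliation on plain Python lists. It illustrates these assumptions and is not the library's implementation. Several details are assumptions not stated here: `resolve_axis` is a hypothetical helper, its validation checks (unique labels, matching lengths, states agreeing with cardinalities) are assumed, and returning `None` states for binary axes is also assumed.

```python
from typing import List, Optional, Tuple


def resolve_axis(
    labels: List[str],
    states: Optional[List[List[str]]] = None,
    cardinalities: Optional[List[int]] = None,
) -> Tuple[Optional[List[List[str]]], List[int], bool]:
    """Hypothetical helper: return (states, cardinalities, is_nested)."""
    if len(set(labels)) != len(labels):
        raise ValueError("labels must be unique")
    if states is not None:
        if len(states) != len(labels):
            raise ValueError("need one state list per label")
        derived = [len(s) for s in states]
        # Assumption: explicit cardinalities, if also given, must agree with states.
        if cardinalities is not None and list(cardinalities) != derived:
            raise ValueError("cardinalities disagree with states")
        return [list(s) for s in states], derived, True
    if cardinalities is not None:
        if len(cardinalities) != len(labels):
            raise ValueError("need one cardinality per label")
        # Placeholder state names '0', '1', ..., matching the documented example.
        generated = [[str(i) for i in range(c)] for c in cardinalities]
        # Assumption: a cardinalities-only axis counts as nested.
        return generated, list(cardinalities), True
    # Plain labels: binary concepts with one logit each and no nested structure.
    return None, [1] * len(labels), False
```

For example, `resolve_axis(['size', 'material'], cardinalities=[3, 4])` returns states whose first entry is `['0', '1', '2']` together with cardinalities `[3, 4]`. The sketch treats states as the source of truth whenever they are present, which keeps the two descriptions of each concept from drifting apart.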

Example

>>> from torch_concepts import AxisAnnotation
>>>
>>> # Simple binary concepts
>>> axis_binary = AxisAnnotation(
...     labels=['has_wheels', 'has_windows', 'is_red']
... )
>>> print(axis_binary.labels)  # ['has_wheels', 'has_windows', 'is_red']
>>> print(axis_binary.is_nested)  # False
>>> print(axis_binary.cardinalities)  # [1, 1, 1] - binary concepts
>>>
>>> # Nested concepts with explicit states
>>> axis_nested = AxisAnnotation(
...     labels=['color', 'shape'],
...     states=[['red', 'green', 'blue'], ['circle', 'square']],
... )
>>> print(axis_nested.labels)  # ['color', 'shape']
>>> print(axis_nested.is_nested)  # True
>>> print(axis_nested.cardinalities)  # [3, 2]
>>> print(axis_nested.states[0])  # ['red', 'green', 'blue']
>>>
>>> # With cardinalities only (auto-generates state labels)
>>> axis_cards = AxisAnnotation(
...     labels=['size', 'material'],
...     cardinalities=[3, 4]  # 3 sizes, 4 materials
... )
>>> print(axis_cards.cardinalities)  # [3, 4]
>>> print(axis_cards.states[0])  # ['0', '1', '2']
>>>
>>> # Access methods
>>> idx = axis_binary.get_index('has_wheels')
>>> print(idx)  # 0
>>> label = axis_binary.get_label(1)
>>> print(label)  # 'has_windows'

__init__(labels: List[str], states: List[List[str]] | None = None, cardinalities: List[int] | None = None, metadata: Dict[str, Dict] | None = None) → None

Methods

__init__(labels[, states, cardinalities, ...])

from_dict(data)

Create AxisAnnotation from dictionary.

get_endogenous_idx(labels)

Get endogenous (logit-level) indices for a list of concept labels.

get_index(label)

Get index of a label in this axis.

get_label(idx)

Get label at given index in this axis.

get_total_cardinality()

Get total cardinality for nested axis, or None if not nested.

groupby_metadata(key[, layout])

Group labels by the value they store under the given metadata key.

has_metadata(key)

Check if metadata contains a specific key for all labels.

subset(keep_labels)

Return a new AxisAnnotation restricted to keep_labels, ordered as in keep_labels.

to_dict()

Convert to JSON-serializable dictionary.

union_with(other)
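Two of the methods above, get_endogenous_idx and get_total_cardinality, imply a flat, logit-level layout in which each concept occupies as many positions as its cardinality. The sketch below assumes those positions are contiguous and follow label order. Under that assumption, a concept's block starts at the sum of the cardinalities of the concepts before it. The names `endogenous_layout` and `endogenous_idx` are hypothetical, and the layout itself is an assumption rather than confirmed behavior of get_endogenous_idx.

```python
from itertools import accumulate
from typing import Dict, List


def endogenous_layout(labels: List[str], cardinalities: List[int]) -> Dict[str, range]:
    """Hypothetical: map each label to its contiguous block of logit positions."""
    # Start offset of each concept = running sum of the cardinalities before it.
    starts = [0, *accumulate(cardinalities)][:-1]
    return {lab: range(s, s + c) for lab, s, c in zip(labels, starts, cardinalities)}


def endogenous_idx(labels: List[str], cardinalities: List[int], query: List[str]) -> List[int]:
    """Hypothetical: flatten the logit positions of the queried labels, in query order."""
    layout = endogenous_layout(labels, cardinalities)
    return [i for lab in query for i in layout[lab]]
```

With `labels=['color', 'shape']` and `cardinalities=[3, 2]`, 'color' occupies positions 0 to 2 and 'shape' occupies positions 3 and 4. Querying `['shape']` therefore gives `[3, 4]`, and the total width of the layout is `sum(cardinalities) == 5`. For a binary axis every block has width 1, so logit indices coincide with label indices. Note that get_total_cardinality is documented to return None for non-nested axes.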

Attributes

cardinalities

metadata

shape

Return the size of this axis.

states

labels
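Two of the listed methods can be illustrated on plain Python data. The first is subset(keep_labels), which returns an annotation restricted to keep_labels in keep_labels order. The second is has_metadata(key), which checks that every label's metadata contains key. The sketch below assumes that states, cardinalities, and metadata are carried along per label and that unknown labels raise an error. Both `subset_axis` and `has_metadata_for_all` are hypothetical stand-ins, not the library's methods.

```python
from typing import Dict, List, Optional


def subset_axis(
    labels: List[str],
    keep_labels: List[str],
    states: Optional[List[List[str]]] = None,
    cardinalities: Optional[List[int]] = None,
    metadata: Optional[Dict[str, Dict]] = None,
) -> dict:
    """Hypothetical: restrict per-label fields to keep_labels, in keep_labels order."""
    pos = {lab: i for i, lab in enumerate(labels)}
    missing = [lab for lab in keep_labels if lab not in pos]
    if missing:
        # Assumption: unknown labels are an error rather than silently dropped.
        raise KeyError(f"unknown labels: {missing}")
    order = [pos[lab] for lab in keep_labels]
    return {
        "labels": list(keep_labels),
        "states": None if states is None else [states[i] for i in order],
        "cardinalities": None if cardinalities is None else [cardinalities[i] for i in order],
        "metadata": None if metadata is None
        else {lab: metadata[lab] for lab in keep_labels if lab in metadata},
    }


def has_metadata_for_all(labels: List[str], metadata: Optional[Dict[str, Dict]], key: str) -> bool:
    """Hypothetical: True only if every label's metadata dict contains key."""
    return metadata is not None and all(key in metadata.get(lab, {}) for lab in labels)
```

Reordering through a single index list (`order`) keeps labels, states, and cardinalities aligned, which is the property a subset operation has to preserve.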