torch_concepts.Annotations

class Annotations(labels: List[str], states: List[List[str]] | None = None, cardinalities: List[int] | None = None, types: List[str] | None = None, metadata: Dict[str, Dict] | None = None, concept_space: bool = False)

Annotations for the concept axis of a tensor.

This class provides semantic labeling for the concept dimension (axis 1) of a tensor, supporting both simple binary concepts and nested multi-state concepts. Axis 0 is the (unannotated) batch dimension.

labels

Ordered, unique concept labels.

Type:

list[str]

states

State labels for each concept (if nested).

Type:

Optional[list[list[str]]]

cardinalities

Cardinality of each concept.

Type:

Optional[list[int]]

types

Type of each concept: 'binary', 'categorical', or 'continuous'.

Type:

Optional[list[str]]

metadata

Additional metadata for each label.

Type:

Optional[Dict[str, Dict]]

is_nested

Whether the axis has nested/hierarchical structure.

Type:

bool

Parameters:
  • labels – List of concept names.

  • states – Optional list of state lists for nested concepts.

  • cardinalities – Optional list of cardinalities per concept.

  • types – Optional concept types per concept.

  • metadata – Optional metadata dictionary keyed by label names.
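The following plain-Python sketch shows how the optional arguments combine into per-concept cardinalities and the annotated shape. It only reproduces the behaviour visible in the example below: explicit states give len(states), cardinalities alone auto-generate states '0', '1', ..., and passing neither yields binary concepts of cardinality 1. The helper name resolve_concept_axis is hypothetical and not part of torch_concepts. Letting states take precedence when both arguments are given is also an assumption.

```python
from typing import List, Optional, Tuple


def resolve_concept_axis(
    labels: List[str],
    states: Optional[List[List[str]]] = None,
    cardinalities: Optional[List[int]] = None,
) -> Tuple[Optional[List[List[str]]], List[int], Tuple[int, int]]:
    """Hypothetical helper: derive states, cardinalities and shape for axis 1."""
    if len(set(labels)) != len(labels):
        raise ValueError("labels must be unique")
    if states is not None:
        # Assumption: explicit states take precedence and define the cardinalities.
        if len(states) != len(labels):
            raise ValueError("need one state list per label")
        cardinalities = [len(s) for s in states]
    elif cardinalities is not None:
        if len(cardinalities) != len(labels):
            raise ValueError("need one cardinality per label")
        # Auto-generate state labels '0', '1', ..., as in the ann_cards example.
        states = [[str(i) for i in range(c)] for c in cardinalities]
    else:
        # Plain binary concepts: one column each; this sketch keeps no state labels.
        cardinalities = [1] * len(labels)
    # The batch axis (axis 0) is unannotated and reported as -1.
    shape = (-1, sum(cardinalities))
    return states, cardinalities, shape


states, cards, shape = resolve_concept_axis(
    ["color", "shape"], states=[["red", "green", "blue"], ["circle", "square"]]
)
print(cards, shape)  # [3, 2] (-1, 5)
```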

Example

>>> from torch_concepts import Annotations
>>>
>>> # Simple binary concepts
>>> ann_binary = Annotations(
...     labels=['has_wheels', 'has_windows', 'is_red']
... )
>>> print(ann_binary.labels)
['has_wheels', 'has_windows', 'is_red']
>>> print(ann_binary.is_nested)
False
>>> print(ann_binary.cardinalities)
[1, 1, 1]
>>> print(ann_binary.shape)
(-1, 3)
>>>
>>> # Nested concepts with explicit states
>>> ann_nested = Annotations(
...     labels=['color', 'shape'],
...     states=[['red', 'green', 'blue'], ['circle', 'square']],
... )
>>> print(ann_nested.labels)
['color', 'shape']
>>> print(ann_nested.is_nested)
True
>>> print(ann_nested.cardinalities)
[3, 2]
>>> print(ann_nested.states[0])
['red', 'green', 'blue']
>>> print(ann_nested.shape)
(-1, 5)
>>>
>>> # With cardinalities only (auto-generates state labels)
>>> ann_cards = Annotations(
...     labels=['size', 'material'],
...     cardinalities=[3, 4]
... )
>>> print(ann_cards.cardinalities)
[3, 4]
>>> print(ann_cards.states[0])
['0', '1', '2']
>>>
>>> # Access methods
>>> idx = ann_binary.get_index('has_wheels')
>>> print(idx)
0
>>> label = ann_binary.get_label(1)
>>> print(label)
has_windows
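Each nested concept occupies cardinality consecutive columns, so its position in the flattened concept axis follows from a running sum of cardinalities. The sketch below applies that idea, which underlies cumulative_cardinalities and concept_slices, to the ann_nested layout. The function name concept_slices_from is hypothetical. The exact return values of the library's get_slice may differ.

```python
from itertools import accumulate
from typing import Dict, List


def concept_slices_from(labels: List[str], cardinalities: List[int]) -> Dict[str, slice]:
    """Hypothetical helper: map each concept name to its columns on axis 1."""
    # Offsets [0, c0, c0 + c1, ...]: where each concept's block starts and ends.
    offsets = [0, *accumulate(cardinalities)]
    return {
        label: slice(start, stop)
        for label, start, stop in zip(labels, offsets[:-1], offsets[1:])
    }


slices = concept_slices_from(["color", "shape"], [3, 2])
print(slices["color"], slices["shape"])  # slice(0, 3, None) slice(3, 5, None)
```

Once the offsets are precomputed, looking up a concept's block is a dictionary access rather than a scan over the preceding concepts.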

__init__(labels: List[str], states: List[List[str]] | None = None, cardinalities: List[int] | None = None, types: List[str] | None = None, metadata: Dict[str, Dict] | None = None, concept_space: bool = False) → None

Methods

__init__(labels[, states, cardinalities, ...])

concept(name)

Return a read-only Concept view for name.

empty(n[, cardinalities, types])

Create an Annotations object with n anonymous binary labels c_0, ..., c_{n-1}.

from_dict(data)

Create Annotations from dictionary.

get_index(label)

Get index of a label.

get_label(idx)

Get label at given index.

get_label_state(label, idx)

Get the state label at position idx of a concept.

get_label_states(label)

Get the ordered state labels of a concept.

get_logits_idx(labels)

Alias for get_slice(labels) when labels is a list.

get_slice(labels)

Get slice or indices for concept(s) in the flattened tensor.

get_state_index(label, state)

Get the index of a state label for a concept.

get_total_cardinality()

Get total cardinality for nested annotations, or number of labels otherwise.

groupby_metadata(key[, layout])

Group labels by their value for the metadata entry key.

has_metadata(key)

Check if metadata contains a specific key for all labels.

slice_tensor(tensor, concepts)

Extract and concatenate columns for specified concepts.

subset(keep_labels)

Return a new Annotations restricted to keep_labels, ordered as in keep_labels.

to_concept_space()

Return a concept-space view: one integer-coded column per concept.

to_dict()

Convert to JSON-serializable dictionary.

union_with(other)
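The column-level operations behind slice_tensor and to_concept_space can be illustrated on plain Python rows. This is a minimal sketch under stated assumptions. slice_rows and to_integer_codes are hypothetical names. Lists of rows stand in for a (batch, size) tensor. Two further points are assumptions about what an "integer-coded column" holds, not documented library behaviour: the code is the argmax within each concept's block, and a single binary column is thresholded at 0.5.

```python
from typing import Dict, List

Row = List[float]


def slice_rows(rows: List[Row], slices: Dict[str, slice], concepts: List[str]) -> List[Row]:
    """Hypothetical analogue of slice_tensor: keep and concatenate the columns of `concepts`."""
    return [[x for name in concepts for x in row[slices[name]]] for row in rows]


def to_integer_codes(rows: List[Row], slices: Dict[str, slice]) -> List[List[int]]:
    """Hypothetical: one integer-coded column per concept, in slice order."""
    coded = []
    for row in rows:
        codes = []
        for sl in slices.values():
            block = row[sl]
            if len(block) == 1:
                # Assumption: a single binary column is thresholded at 0.5.
                codes.append(int(block[0] > 0.5))
            else:
                # Index of the highest-scoring state within the concept's block.
                codes.append(max(range(len(block)), key=block.__getitem__))
        coded.append(codes)
    return coded


# Layout of ann_nested: color -> columns 0..2, shape -> columns 3..4.
slices = {"color": slice(0, 3), "shape": slice(3, 5)}
batch = [
    [0.1, 0.7, 0.2, 0.9, 0.1],  # green, circle
    [0.8, 0.1, 0.1, 0.3, 0.7],  # red, square
]
print(slice_rows(batch, slices, ["shape"]))  # [[0.9, 0.1], [0.3, 0.7]]
print(to_integer_codes(batch, slices))       # [[1, 0], [0, 1]]
```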

Attributes

cardinalities

concept_slices

Precomputed mapping from concept name to slice in flattened tensor.

concept_space

concepts

All concepts as Concept views, in axis order.

cumulative_cardinalities

Precomputed cumulative cardinalities for O(1) slicing.

label_to_index

Precomputed mapping from concept name to concept-level index.

labels_by_type

Mapping from concept type to the ordered list of labels of that type.

metadata

shape

Annotated tensor shape (B, sum(cardinalities)); the unannotated batch dimension B is reported as -1.

size

Total width of the concept axis, i.e. sum(cardinalities).

states

type_groups

Precomputed type-based groupings at both concept and logit levels.

types

labels
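The labels_by_type attribute groups labels by their entry in types and preserves axis order within each group. The following is a minimal sketch of that grouping. It uses a hypothetical helper, group_labels_by_type, and made-up labels. It does not cover how the library handles a missing types list.

```python
from collections import defaultdict
from typing import Dict, List


def group_labels_by_type(labels: List[str], types: List[str]) -> Dict[str, List[str]]:
    """Hypothetical analogue of labels_by_type: type -> labels, in axis order."""
    if len(labels) != len(types):
        raise ValueError("need one type per label")
    groups: Dict[str, List[str]] = defaultdict(list)
    for label, kind in zip(labels, types):
        groups[kind].append(label)
    # Types appear in order of first occurrence along the axis.
    return dict(groups)


print(group_labels_by_type(
    ["has_wheels", "color", "is_red", "weight"],
    ["binary", "categorical", "binary", "continuous"],
))
# {'binary': ['has_wheels', 'is_red'], 'categorical': ['color'], 'continuous': ['weight']}
```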