torch_concepts.annotations.Annotations
- class Annotations(axis_annotations: List | Dict[int, AxisAnnotation] | None = None)
Multi-axis annotation container for concept tensors.
This class manages annotations for multiple tensor dimensions, providing a unified interface for working with concept-based tensors that may have different semantic meanings along different axes.
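To illustrate why per-axis labels are useful, here is a minimal sketch. It uses plain Python lists in place of a tensor and a hypothetical helper `column_by_label` (not part of torch_concepts). With a flat, label-annotated axis, code can address a slice by name instead of by a hard-coded position:

```python
from typing import Dict, List

# Hypothetical stand-in: a (batch=2, tasks=3) "tensor" as nested lists,
# with axis 1 annotated by task labels. No torch or torch_concepts needed.
axis_labels: Dict[int, List[str]] = {1: ["task1", "task2", "task3"]}

scores = [
    [0.9, 0.1, 0.4],  # sample 0
    [0.2, 0.8, 0.7],  # sample 1
]

def column_by_label(rows: List[List[float]],
                    labels: Dict[int, List[str]],
                    label: str) -> List[float]:
    """Return one task's values across the batch, looked up by name on axis 1."""
    idx = labels[1].index(label)  # label -> position along axis 1
    return [row[idx] for row in rows]

print(column_by_label(scores, axis_labels, "task2"))  # [0.1, 0.8]
```

If the label order on the axis changes, only the annotation needs updating. Lookups written this way keep working.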
- _axis_annotations
Map from axis index to annotation.
- Type:
Dict[int, AxisAnnotation]
- Parameters:
axis_annotations – Either a list of AxisAnnotations (indexed 0, 1, 2, …) or a dict mapping axis numbers to AxisAnnotations.
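The two accepted input forms can be reduced to the same axis-to-annotation mapping. Below is a minimal sketch under these assumptions: strings stand in for AxisAnnotation objects, `normalize_axis_annotations` is a hypothetical helper rather than the library's code, and `None` is assumed to mean "no axes annotated yet":

```python
from typing import Dict, List, Optional, Union

# Strings stand in for AxisAnnotation objects so the sketch runs on its own.
AxisInput = Optional[Union[List[object], Dict[int, object]]]

def normalize_axis_annotations(axis_annotations: AxisInput = None) -> Dict[int, object]:
    """Hypothetical helper: map either accepted input form to {axis: annotation}."""
    if axis_annotations is None:
        return {}  # assumption: None means no axis is annotated yet
    if isinstance(axis_annotations, dict):
        return dict(axis_annotations)  # keys are already axis numbers
    if isinstance(axis_annotations, list):
        # List position becomes the axis number: 0, 1, 2, ...
        return dict(enumerate(axis_annotations))
    raise TypeError("expected a list, a dict, or None")

from_list = normalize_axis_annotations(["batch_ann", "concept_ann"])
from_dict = normalize_axis_annotations({1: "concept_ann"})
print(from_list)                 # {0: 'batch_ann', 1: 'concept_ann'}
print(tuple(sorted(from_dict)))  # (1,)
```

Because the list form always starts at axis 0, the dict form is the natural choice when the leading batch axis should stay unannotated and only axis 1 carries labels.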
Example
>>> from torch_concepts import Annotations, AxisAnnotation
>>>
>>> # Create annotations for a concept tensor
>>> # Axis 0: batch (typically not annotated)
>>> # Axis 1: concepts
>>> concept_ann = AxisAnnotation(
...     labels=['color', 'shape', 'size'],
...     cardinalities=[3, 2, 1]  # 3 colors, 2 shapes, 1 binary size
... )
>>>
>>> # Create annotations object
>>> annotations = Annotations({1: concept_ann})
>>>
>>> # Access concept labels
>>> print(annotations.get_axis_labels(1))  # ['color', 'shape', 'size']
>>>
>>> # Get index of a concept
>>> idx = annotations.get_index(1, 'color')
>>> print(idx)  # 0
>>>
>>> # Check if axis is nested
>>> print(annotations.is_axis_nested(1))  # True
>>>
>>> # Get cardinalities
>>> print(annotations.get_axis_cardinalities(1))  # [3, 2, 1]
>>>
>>> # Access via indexing
>>> print(annotations[1].labels)  # ['color', 'shape', 'size']
>>>
>>> # Multiple axes example
>>> task_ann = AxisAnnotation(labels=['task1', 'task2', 'task3'])
>>> multi_ann = Annotations({
...     1: concept_ann,
...     2: task_ann
... })
>>> print(multi_ann.annotated_axes)  # (1, 2)
- __init__(axis_annotations: List | Dict[int, AxisAnnotation] | None = None)
Initialize Annotations container.
- Parameters:
axis_annotations – Either a list or dict of AxisAnnotation objects.
Methods
- __init__([axis_annotations]): Initialize the Annotations container.
- annotate_axis(axis_annotation, axis): Add or update the annotation for an axis.
- from_dict(data): Create Annotations from a dictionary.
- get_axis_annotation(axis): Get the annotation for a specific axis.
- get_axis_cardinalities(axis): Get the cardinalities for an axis (if nested), or None.
- get_axis_labels(axis): Get the ordered labels for an axis.
- get_index(axis, label): Get the index of a label within an axis.
- get_label(axis, idx): Get the label at an index within an axis.
- get_label_state(axis, label, idx): Get the state at index idx of a concept in a nested axis.
- get_label_states(axis, label): Get the states of a concept in a nested axis.
- get_state_index(axis, label, state): Get the index of a state label for a concept in a nested axis.
- get_states(axis): Get the states for a nested axis, or None.
- has_axis(axis): Check whether an axis is annotated.
- is_axis_nested(axis): Check whether an axis has a nested structure.
- items(): Return (axis, AxisAnnotation) pairs (dict-like interface).
- join_union(other, axis)
- keys(): Return the axis numbers (dict-like interface).
- select(axis, keep_labels): Return a new Annotations in which only keep_labels are kept on axis.
- select_many(labels_by_axis): Return a new Annotations, applying independent label filters per axis.
- to_dict(): Convert to a JSON-serializable dictionary.
- values(): Return the AxisAnnotation objects (dict-like interface).
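The select and select_many entries describe label filtering that returns a new object. Below is a minimal sketch of that behaviour. The following are assumptions, not confirmed by this page: a plain dict with "labels" and "cardinalities" keys stands in for AxisAnnotation, kept labels retain their original order, and unknown labels raise KeyError. The functions are hypothetical stand-ins, not the library's implementation:

```python
import copy
from typing import Dict, Iterable, List, Optional

# Hypothetical stand-in for AxisAnnotation: labels plus optional cardinalities.
AxisAnn = Dict[str, Optional[List]]

def select(annotations: Dict[int, AxisAnn], axis: int,
           keep_labels: Iterable[str]) -> Dict[int, AxisAnn]:
    """Return a new mapping in which only keep_labels survive on `axis`."""
    keep = set(keep_labels)
    ann = annotations[axis]
    missing = keep - set(ann["labels"])
    if missing:
        raise KeyError(f"unknown labels on axis {axis}: {sorted(missing)}")
    # Keep original positions/order so labels and cardinalities stay aligned.
    positions = [i for i, lab in enumerate(ann["labels"]) if lab in keep]
    cards = ann["cardinalities"]
    new_ann = {
        "labels": [ann["labels"][i] for i in positions],
        "cardinalities": None if cards is None else [cards[i] for i in positions],
    }
    result = copy.deepcopy(annotations)  # the input is never mutated
    result[axis] = new_ann
    return result

def select_many(annotations: Dict[int, AxisAnn],
                labels_by_axis: Dict[int, Iterable[str]]) -> Dict[int, AxisAnn]:
    """Apply an independent select() filter on each listed axis."""
    result = copy.deepcopy(annotations)
    for axis, keep_labels in labels_by_axis.items():
        result = select(result, axis, keep_labels)
    return result

anns = {
    1: {"labels": ["color", "shape", "size"], "cardinalities": [3, 2, 1]},
    2: {"labels": ["task1", "task2", "task3"], "cardinalities": None},
}
picked = select_many(anns, {1: ["size", "color"], 2: ["task3"]})
print(picked[1])  # {'labels': ['color', 'size'], 'cardinalities': [3, 1]}
print(picked[2])  # {'labels': ['task3'], 'cardinalities': None}
```

Filtering by position rather than by the order of keep_labels keeps each surviving label paired with its own cardinality.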
Attributes
- annotated_axes: Tuple of annotated axis numbers (sorted).
- Access to the underlying axis annotations dictionary.
- Number of annotated axes.
- Shape of the annotated tensor, derived from the annotations.
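The last attribute is documented only as the shape of the annotated tensor, derived from the annotations. One plausible rule is sketched below purely as an assumption; the helpers `axis_size`, `annotated_shape`, and `concept_slices` are hypothetical, not library API. Under this rule, a flat axis has one entry per label, and a nested axis spends `cardinality` entries on each concept, with a binary concept using one. Unannotated axes, such as the batch axis, are reported as None:

```python
from typing import Dict, List, Optional, Tuple

AxisSpec = Tuple[List[str], Optional[List[int]]]  # (labels, cardinalities or None)

def axis_size(labels: List[str], cardinalities: Optional[List[int]]) -> int:
    """Assumed rule: flat axis -> len(labels); nested axis -> sum of cardinalities."""
    if cardinalities is None:
        return len(labels)
    if len(cardinalities) != len(labels):
        raise ValueError("need exactly one cardinality per label")
    return sum(cardinalities)

def annotated_shape(annotations: Dict[int, AxisSpec], ndim: int) -> Tuple[Optional[int], ...]:
    """Per-axis sizes; None marks an axis that carries no annotation."""
    return tuple(
        axis_size(*annotations[axis]) if axis in annotations else None
        for axis in range(ndim)
    )

def concept_slices(labels: List[str], cardinalities: List[int]) -> Dict[str, slice]:
    """Where each concept would live along a nested axis under the same rule."""
    out, start = {}, 0
    for label, card in zip(labels, cardinalities):
        out[label] = slice(start, start + card)
        start += card
    return out

anns = {
    1: (["color", "shape", "size"], [3, 2, 1]),
    2: (["task1", "task2", "task3"], None),
}
print(annotated_shape(anns, ndim=3))  # (None, 6, 3)
print(concept_slices(*anns[1])["shape"])  # slice(3, 5, None)
```

Consult the library source before relying on this layout. The page above does not state how nested axes map to tensor entries.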