Annotations
Containers for model configuration and type information.
Summary
Annotation Classes
AxisAnnotation | Annotations for a single axis of a tensor.
Annotations | Multi-axis annotation container for concept tensors.
Class Documentation
- class AxisAnnotation(labels: List[str], states: List[List[str]] | None = None, cardinalities: List[int] | None = None, metadata: Dict[str, Dict] | None = None)
Bases: object
Annotations for a single axis of a tensor.
This class provides semantic labeling for one dimension of a tensor, supporting both simple binary concepts and nested multi-state concepts.
- Parameters:
labels – List of concept names for this axis.
states – Optional list of state lists for nested concepts.
cardinalities – Optional list of cardinalities per concept.
metadata – Optional metadata dictionary keyed by label names.
Example
>>> from torch_concepts import AxisAnnotation
>>>
>>> # Simple binary concepts
>>> axis_binary = AxisAnnotation(
...     labels=['has_wheels', 'has_windows', 'is_red']
... )
>>> print(axis_binary.labels)  # ['has_wheels', 'has_windows', 'is_red']
>>> print(axis_binary.is_nested)  # False
>>> print(axis_binary.cardinalities)  # [1, 1, 1] - binary concepts
>>>
>>> # Nested concepts with explicit states
>>> axis_nested = AxisAnnotation(
...     labels=['color', 'shape'],
...     states=[['red', 'green', 'blue'], ['circle', 'square']],
... )
>>> print(axis_nested.labels)  # ['color', 'shape']
>>> print(axis_nested.is_nested)  # True
>>> print(axis_nested.cardinalities)  # [3, 2]
>>> print(axis_nested.states[0])  # ['red', 'green', 'blue']
>>>
>>> # With cardinalities only (auto-generates state labels)
>>> axis_cards = AxisAnnotation(
...     labels=['size', 'material'],
...     cardinalities=[3, 4]  # 3 sizes, 4 materials
... )
>>> print(axis_cards.cardinalities)  # [3, 4]
>>> print(axis_cards.states[0])  # ['0', '1', '2']
>>>
>>> # Access methods
>>> idx = axis_binary.get_index('has_wheels')
>>> print(idx)  # 0
>>> label = axis_binary.get_label(1)
>>> print(label)  # 'has_windows'
- property shape: int | Tuple[int, ...]
Return the size of this axis: an int (the number of labels) for a non-nested axis, or a tuple of ints (the cardinalities) for a nested axis.
- groupby_metadata(key, layout: str = 'labels') → dict
Check if metadata contains a specific key for all labels.
- get_total_cardinality() → int | None
Get total cardinality for nested axis, or None if not nested.
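The relationship between `shape`, the cardinalities, and the total cardinality can be sketched in plain Python without the library. This is not the library's implementation: `axis_shape` and `total_cardinality` are hypothetical helpers, and treating the total as the sum of the per-concept cardinalities is an assumption, based on the flattened layout in the `get_endogenous_idx` example (cardinalities [3, 2, 1] give 6 positions).

```python
from typing import List, Optional, Tuple, Union


def axis_shape(labels: List[str],
               cardinalities: Optional[List[int]] = None,
               nested: bool = False) -> Union[int, Tuple[int, ...]]:
    """Hypothetical helper mirroring the documented `shape` contract."""
    if nested:
        # Nested axis: one entry per concept, holding its number of states.
        return tuple(cardinalities or [])
    # Non-nested axis: just the number of labels.
    return len(labels)


def total_cardinality(cardinalities: Optional[List[int]],
                      nested: bool) -> Optional[int]:
    """Hypothetical helper: summed cardinalities (an assumption), or None if not nested."""
    if not nested or cardinalities is None:
        return None
    return sum(cardinalities)


print(axis_shape(['has_wheels', 'has_windows', 'is_red']))
print(axis_shape(['color', 'shape'], [3, 2], nested=True))
print(total_cardinality([3, 2, 1], nested=True))
```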
- get_endogenous_idx(labels: List[str]) → List[int]
Get endogenous (logit-level) indices for a list of concept labels.
This method returns the flattened tensor indices where the logits/values for the specified concepts appear, accounting for each concept’s cardinality.
- Parameters:
labels – List of concept label names to get indices for.
- Returns:
List of endogenous indices in the flattened tensor, in the order corresponding to the input labels.
- Raises:
ValueError – If any label is not found in the axis labels.
Example
>>> # Concepts: ['color', 'shape', 'size'] with cardinalities [3, 2, 1]
>>> # Flattened tensor has 6 positions: [c0, c1, c2, s0, s1, sz]
>>> axis = AxisAnnotation(
...     labels=['color', 'shape', 'size'],
...     cardinalities=[3, 2, 1]
... )
>>> axis.get_endogenous_idx(['color', 'size'])
[0, 1, 2, 5]  # color takes positions 0-2, size takes position 5
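The index arithmetic behind this example can be reproduced in a few lines of plain Python. The sketch below is a library-free illustration, not the library's code: `endogenous_idx` is a hypothetical helper that takes the axis labels and cardinalities explicitly and assumes concepts are laid out contiguously in label order.

```python
from itertools import accumulate
from typing import List, Sequence


def endogenous_idx(all_labels: Sequence[str],
                   cardinalities: Sequence[int],
                   query: Sequence[str]) -> List[int]:
    """Hypothetical helper: flattened positions occupied by the queried concepts."""
    # Start offset of each concept = sum of the cardinalities that precede it.
    starts = [0] + list(accumulate(cardinalities))[:-1]
    start_of = dict(zip(all_labels, starts))
    size_of = dict(zip(all_labels, cardinalities))

    indices: List[int] = []
    for label in query:
        if label not in start_of:
            raise ValueError(f"Label {label!r} not found in axis labels")
        # Each concept contributes a contiguous run of `cardinality` positions.
        indices.extend(range(start_of[label], start_of[label] + size_of[label]))
    return indices


print(endogenous_idx(['color', 'shape', 'size'], [3, 2, 1], ['color', 'size']))
```

Because the output follows the order of the queried labels, asking for `['size', 'color']` in this sketch yields the `size` position first.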
- to_dict() → Dict[str, Any]
Convert to JSON-serializable dictionary.
- Returns:
Dictionary with all attributes, converting DataFrame to dict format.
- Return type:
Dict[str, Any]
- classmethod from_dict(data: Dict[str, Any]) → AxisAnnotation
Create AxisAnnotation from dictionary.
- Parameters:
data (dict) – Dictionary with serialized AxisAnnotation data.
- Returns:
Reconstructed AxisAnnotation object.
- Return type:
AxisAnnotation
- subset(keep_labels: Sequence[str]) → AxisAnnotation
Return a new AxisAnnotation restricted to keep_labels (order follows the order in keep_labels).
- Raises:
ValueError – If any requested label is missing.
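The documented contract of `subset` (output order follows `keep_labels`, missing labels raise `ValueError`) can be illustrated with a minimal sketch on plain lists. `subset_axis` is a hypothetical stand-in that operates on labels and cardinalities only; how the real method carries over states and metadata is not shown here.

```python
from typing import List, Sequence, Tuple


def subset_axis(labels: Sequence[str],
                cardinalities: Sequence[int],
                keep_labels: Sequence[str]) -> Tuple[List[str], List[int]]:
    """Hypothetical helper: restrict an axis to keep_labels, in keep_labels order."""
    position = {label: i for i, label in enumerate(labels)}

    # Fail early if any requested label is absent from the axis.
    missing = [label for label in keep_labels if label not in position]
    if missing:
        raise ValueError(f"Labels not found in axis: {missing}")

    kept_labels = list(keep_labels)
    kept_cards = [cardinalities[position[label]] for label in kept_labels]
    return kept_labels, kept_cards


print(subset_axis(['color', 'shape', 'size'], [3, 2, 1], ['size', 'color']))
```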
- union_with(other: AxisAnnotation) → AxisAnnotation
- class Annotations(axis_annotations: List | Dict[int, AxisAnnotation] | None = None)
Bases: object
Multi-axis annotation container for concept tensors.
This class manages annotations for multiple tensor dimensions, providing a unified interface for working with concept-based tensors that may have different semantic meanings along different axes.
- _axis_annotations
Map from axis index to annotation.
- Type:
Dict[int, AxisAnnotation]
- Parameters:
axis_annotations – Either a list of AxisAnnotations (indexed 0, 1, 2, …) or a dict mapping axis numbers to AxisAnnotations.
Example
>>> from torch_concepts import Annotations, AxisAnnotation
>>>
>>> # Create annotations for a concept tensor
>>> # Axis 0: batch (typically not annotated)
>>> # Axis 1: concepts
>>> concept_ann = AxisAnnotation(
...     labels=['color', 'shape', 'size'],
...     cardinalities=[3, 2, 1]  # 3 colors, 2 shapes, 1 binary size
... )
>>>
>>> # Create annotations object
>>> annotations = Annotations({1: concept_ann})
>>>
>>> # Access concept labels
>>> print(annotations.get_axis_labels(1))  # ['color', 'shape', 'size']
>>>
>>> # Get index of a concept
>>> idx = annotations.get_index(1, 'color')
>>> print(idx)  # 0
>>>
>>> # Check if axis is nested
>>> print(annotations.is_axis_nested(1))  # True
>>>
>>> # Get cardinalities
>>> print(annotations.get_axis_cardinalities(1))  # [3, 2, 1]
>>>
>>> # Access via indexing
>>> print(annotations[1].labels)  # ['color', 'shape', 'size']
>>>
>>> # Multiple axes example
>>> task_ann = AxisAnnotation(labels=['task1', 'task2', 'task3'])
>>> multi_ann = Annotations({
...     1: concept_ann,
...     2: task_ann
... })
>>> print(multi_ann.annotated_axes)  # (1, 2)
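The constructor accepts either a list or a dict of axis annotations. One plausible way to normalize both forms into the axis-indexed mapping described for `_axis_annotations` is sketched below. `normalize_axis_annotations` is a hypothetical helper, the placeholder strings stand in for `AxisAnnotation` objects, and mapping `None` to an empty dict is an assumption.

```python
from typing import Any, Dict, List, Optional, Union


def normalize_axis_annotations(
    axis_annotations: Optional[Union[List[Any], Dict[int, Any]]] = None,
) -> Dict[int, Any]:
    """Hypothetical helper: turn a list or dict of annotations into {axis: annotation}."""
    if axis_annotations is None:
        # Assumption: no argument means no annotated axes.
        return {}
    if isinstance(axis_annotations, dict):
        # Already keyed by axis number; copy so the caller's dict is not shared.
        return dict(axis_annotations)
    # A list is indexed positionally: element 0 -> axis 0, element 1 -> axis 1, ...
    return dict(enumerate(axis_annotations))


print(normalize_axis_annotations(['batch_ann', 'concept_ann']))
print(normalize_axis_annotations({1: 'concept_ann'}))
```

Passing a dict is the natural choice when only some axes are annotated, as in the example above where the batch axis (0) is skipped.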
- annotate_axis(axis_annotation: AxisAnnotation, axis: int) → None
Add or update annotation for an axis.
- get_axis_annotation(axis: int) → AxisAnnotation
Get annotation for a specific axis.
- get_axis_cardinalities(axis: int) → List[int] | None
Get cardinalities for an axis (if nested), or None.
- get_label_states(axis: int, label: str) → List[str]
Get states of a concept in a nested axis.
- get_label_state(axis: int, label: str, idx: int) → str
Get the state at position idx of a concept in a nested axis.
- get_state_index(axis: int, label: str, state: str) → int
Get index of a state label for a concept in a nested axis.
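The three state accessors are inverse lookups over the same per-concept state lists. The following sketch shows that relationship on plain lists, reusing the `color`/`shape` states from the `AxisAnnotation` example. The function names are hypothetical stand-ins without the `axis` argument, and the error behavior (a `ValueError` from `list.index`) is a property of this sketch, not a documented guarantee of the library.

```python
from typing import List, Sequence


def label_states(labels: Sequence[str], states: Sequence[List[str]],
                 label: str) -> List[str]:
    """Hypothetical helper: all states of one concept."""
    # list.index raises ValueError when the label is unknown.
    return states[list(labels).index(label)]


def label_state(labels: Sequence[str], states: Sequence[List[str]],
                label: str, idx: int) -> str:
    """Hypothetical helper: the state name at position idx of one concept."""
    return label_states(labels, states, label)[idx]


def state_index(labels: Sequence[str], states: Sequence[List[str]],
                label: str, state: str) -> int:
    """Hypothetical helper: the position of a state name within one concept."""
    return label_states(labels, states, label).index(state)


labels = ['color', 'shape']
states = [['red', 'green', 'blue'], ['circle', 'square']]

print(label_states(labels, states, 'color'))
print(label_state(labels, states, 'color', 2))
print(state_index(labels, states, 'shape', 'square'))
```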
- property axis_annotations: Dict[int, AxisAnnotation]
Access to the underlying axis annotations dictionary.
- select(axis: int, keep_labels: Sequence[str]) → Annotations
Return a new Annotations where only keep_labels are kept on axis. Other axes are unchanged.
- select_many(labels_by_axis: Dict[int, Sequence[str]]) → Annotations
Return a new Annotations applying independent label filters per axis.
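Per-axis filtering can be sketched as one independent label filter per requested axis, with all other axes passed through unchanged. The sketch below models each axis as a plain list of labels. `select_many_labels` is a hypothetical helper, and raising `ValueError` for unknown labels or unannotated axes is an assumption carried over from the documented behavior of `subset`.

```python
from typing import Dict, List, Sequence


def select_many_labels(labels_per_axis: Dict[int, List[str]],
                       labels_by_axis: Dict[int, Sequence[str]]) -> Dict[int, List[str]]:
    """Hypothetical helper: filter labels on several axes, returning a new mapping."""
    # Start from a copy so the input mapping and its lists are left untouched.
    result = {axis: list(labels) for axis, labels in labels_per_axis.items()}

    for axis, keep_labels in labels_by_axis.items():
        if axis not in result:
            raise ValueError(f"Axis {axis} is not annotated")
        missing = [label for label in keep_labels if label not in result[axis]]
        if missing:
            raise ValueError(f"Labels not found on axis {axis}: {missing}")
        # Order follows keep_labels, as documented for subset.
        result[axis] = list(keep_labels)
    return result


axes = {1: ['color', 'shape', 'size'], 2: ['task1', 'task2', 'task3']}
print(select_many_labels(axes, {1: ['size', 'color']}))
```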
- join_union(other: Annotations, axis: int) → Annotations
- to_dict() → Dict[str, Any]
Convert to JSON-serializable dictionary.
- Returns:
Dictionary with axis annotations.
- Return type:
Dict[str, Any]