torch_concepts.annotations.AxisAnnotation
- class AxisAnnotation(labels: List[str], states: List[List[str]] | None = None, cardinalities: List[int] | None = None, metadata: Dict[str, Dict] | None = None)
Annotations for a single axis of a tensor.
This class provides semantic labeling for one dimension of a tensor, supporting both simple binary concepts and nested multi-state concepts.
- Parameters:
labels – List of concept names for this axis.
states – Optional list of state-name lists, one per concept, for nested (multi-state) concepts.
cardinalities – Optional list giving the number of states of each concept.
metadata – Optional metadata dictionary keyed by label names.
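The parameter descriptions leave implicit how labels, states, and cardinalities interact. Judging from the Example, binary concepts get cardinality 1, explicit states fix the cardinalities, and cardinalities alone yield auto-generated state labels '0', '1', and so on. The sketch below reproduces those rules in plain Python. resolve_axis is a hypothetical helper, not part of torch_concepts, and its handling of conflicting arguments and of states on a binary axis is an assumption.

```python
from typing import List, Optional, Tuple


def resolve_axis(
    labels: List[str],
    states: Optional[List[List[str]]] = None,
    cardinalities: Optional[List[int]] = None,
) -> Tuple[Optional[List[List[str]]], List[int], bool]:
    """Hypothetical helper: derive (states, cardinalities, is_nested) for one axis."""
    if states is not None:
        if len(states) != len(labels):
            raise ValueError("need one state list per label")
        derived = [len(s) for s in states]
        # Assumption: explicit cardinalities, if also given, must agree with states.
        if cardinalities is not None and list(cardinalities) != derived:
            raise ValueError("cardinalities do not match states")
        return [list(s) for s in states], derived, True
    if cardinalities is not None:
        if len(cardinalities) != len(labels):
            raise ValueError("need one cardinality per label")
        # Auto-generate state labels '0', '1', ... as shown in the Example.
        auto = [[str(i) for i in range(c)] for c in cardinalities]
        return auto, list(cardinalities), True
    # Neither given: every concept is binary and occupies a single slot.
    return None, [1] * len(labels), False


states, cards, nested = resolve_axis(["size", "material"], cardinalities=[3, 4])
print(cards, states[0], nested)  # [3, 4] ['0', '1', '2'] True
```

Raising on a states/cardinalities mismatch is a design choice of this sketch; the real constructor may validate differently.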
Example
>>> from torch_concepts import AxisAnnotation
>>>
>>> # Simple binary concepts
>>> axis_binary = AxisAnnotation(
...     labels=['has_wheels', 'has_windows', 'is_red']
... )
>>> print(axis_binary.labels)         # ['has_wheels', 'has_windows', 'is_red']
>>> print(axis_binary.is_nested)      # False
>>> print(axis_binary.cardinalities)  # [1, 1, 1] - binary concepts
>>>
>>> # Nested concepts with explicit states
>>> axis_nested = AxisAnnotation(
...     labels=['color', 'shape'],
...     states=[['red', 'green', 'blue'], ['circle', 'square']],
... )
>>> print(axis_nested.labels)         # ['color', 'shape']
>>> print(axis_nested.is_nested)      # True
>>> print(axis_nested.cardinalities)  # [3, 2]
>>> print(axis_nested.states[0])      # ['red', 'green', 'blue']
>>>
>>> # With cardinalities only (auto-generates state labels)
>>> axis_cards = AxisAnnotation(
...     labels=['size', 'material'],
...     cardinalities=[3, 4]  # 3 sizes, 4 materials
... )
>>> print(axis_cards.cardinalities)   # [3, 4]
>>> print(axis_cards.states[0])       # ['0', '1', '2']
>>>
>>> # Access methods
>>> idx = axis_binary.get_index('has_wheels')
>>> print(idx)                        # 0
>>> label = axis_binary.get_label(1)
>>> print(label)                      # 'has_windows'
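The Example does not pass metadata. Since metadata is keyed by label names and has_metadata (listed under Methods) checks whether a key is present for all labels, that check amounts to the plain-Python sketch below. has_metadata_key and the metadata keys "type" and "group" are hypothetical, and how the real method treats a missing metadata argument is an assumption.

```python
from typing import Dict, List, Optional


def has_metadata_key(
    labels: List[str], metadata: Optional[Dict[str, Dict]], key: str
) -> bool:
    """Hypothetical stand-in for AxisAnnotation.has_metadata.

    Returns True only if every label has a metadata entry containing `key`.
    """
    if metadata is None:
        return False  # assumption: no metadata means no key is present
    return all(key in metadata.get(label, {}) for label in labels)


labels = ["has_wheels", "has_windows", "is_red"]
# Hypothetical metadata; the keys "type" and "group" are illustrative only.
metadata = {
    "has_wheels": {"type": "discrete"},
    "has_windows": {"type": "discrete"},
    "is_red": {"type": "discrete", "group": "color"},
}
print(has_metadata_key(labels, metadata, "type"))   # True
print(has_metadata_key(labels, metadata, "group"))  # False
```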
- __init__(labels: List[str], states: List[List[str]] | None = None, cardinalities: List[int] | None = None, metadata: Dict[str, Dict] | None = None) → None
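Every constructor argument in this signature is a list, string, integer, or dict, so a to_dict/from_dict pair like the one listed under Methods can round-trip through JSON. The sketch below shows that idea with a hypothetical AxisSpec dataclass whose dictionary keys simply mirror the constructor arguments; the key names the real to_dict emits are not documented on this page and may differ.

```python
import json
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional


@dataclass
class AxisSpec:
    """Hypothetical container mirroring AxisAnnotation's constructor arguments."""

    labels: List[str]
    states: Optional[List[List[str]]] = None
    cardinalities: Optional[List[int]] = None
    metadata: Optional[Dict[str, Dict]] = None

    def to_dict(self) -> dict:
        # asdict deep-copies fields into plain lists/dicts, which json accepts.
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "AxisSpec":
        # Assumes keys match the constructor argument names exactly.
        return cls(**data)


spec = AxisSpec(
    labels=["color", "shape"],
    states=[["red", "green", "blue"], ["circle", "square"]],
)
payload = json.dumps(spec.to_dict())
restored = AxisSpec.from_dict(json.loads(payload))
print(restored == spec)  # True
```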
Methods
__init__(labels[, states, cardinalities, ...])
from_dict(data) – Create AxisAnnotation from dictionary.
get_endogenous_idx(labels) – Get endogenous (logit-level) indices for a list of concept labels.
get_index(label) – Get index of a label in this axis.
get_label(idx) – Get label at given index in this axis.
Get total cardinality for nested axis, or None if not nested.
groupby_metadata(key[, layout]) – Check if metadata contains a specific key for all labels.
has_metadata(key) – Check if metadata contains a specific key for all labels.
subset(keep_labels) – Return a new AxisAnnotation restricted to keep_labels (order follows the order in keep_labels).
to_dict() – Convert to JSON-serializable dictionary.
union_with(other)
Attributes
Return the size of this axis.
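The get_endogenous_idx and subset entries under Methods are index bookkeeping over labels and cardinalities. The sketch below assumes concepts are laid out contiguously in label order in the logit vector, each occupying as many logit slots as its cardinality (1 for a binary concept). That layout, and the helper names endogenous_indices and subset_axis, are hypothetical and not confirmed by this page.

```python
from itertools import accumulate
from typing import List, Tuple


def endogenous_indices(
    labels: List[str], cardinalities: List[int], wanted: List[str]
) -> List[int]:
    """Hypothetical: flat logit positions covered by each concept in `wanted`.

    Assumes each concept occupies `cardinality` consecutive logits, in label order.
    """
    # Start offset of every concept: 0, c0, c0 + c1, ...
    starts = [0, *accumulate(cardinalities)][:-1]
    position = {label: i for i, label in enumerate(labels)}
    out: List[int] = []
    for label in wanted:
        i = position[label]  # raises KeyError for an unknown label
        out.extend(range(starts[i], starts[i] + cardinalities[i]))
    return out


def subset_axis(
    labels: List[str], cardinalities: List[int], keep_labels: List[str]
) -> Tuple[List[str], List[int]]:
    """Hypothetical: restrict an axis to keep_labels, in keep_labels order."""
    idx = [labels.index(label) for label in keep_labels]
    return [labels[i] for i in idx], [cardinalities[i] for i in idx]


# Nested axis from the Example: color has 3 states, shape has 2.
print(endogenous_indices(["color", "shape"], [3, 2], ["shape"]))  # [3, 4]
print(subset_axis(["color", "shape"], [3, 2], ["shape", "color"]))  # (['shape', 'color'], [2, 3])
```

Under this assumed layout, a binary axis maps each label to exactly one logit, so logit-level and concept-level indices coincide.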