torch_concepts.Annotations
- class Annotations(labels: List[str], states: List[List[str]] | None = None, cardinalities: List[int] | None = None, types: List[str] | None = None, metadata: Dict[str, Dict] | None = None, concept_space: bool = False)
Annotations for the concept axis of a tensor.
This class provides semantic labeling for the concept dimension (axis 1) of a tensor, supporting both simple binary concepts and nested multi-state concepts. Axis 0 is the (unannotated) batch dimension.
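Each concept occupies a contiguous block of columns along axis 1. The width of that block is the concept's cardinality, so the flattened axis has sum(cardinalities) columns (5 for cardinalities [3, 2]). The standalone sketch below illustrates this layout. It assumes concepts are laid out in label order and that the cumulative offsets start at 0. The helper build_slices is hypothetical and is not part of torch_concepts.

```python
from itertools import accumulate
from typing import Dict, List, Tuple


def build_slices(
    labels: List[str], cardinalities: List[int]
) -> Tuple[List[int], Dict[str, slice]]:
    """Hypothetical helper: map each concept to its column block on axis 1."""
    # Cumulative offsets [0, c0, c0 + c1, ...] make each lookup O(1).
    cumulative = [0] + list(accumulate(cardinalities))
    slices = {
        label: slice(cumulative[i], cumulative[i + 1])
        for i, label in enumerate(labels)
    }
    return cumulative, slices


cumulative, slices = build_slices(["color", "shape"], [3, 2])
print(cumulative)       # [0, 3, 5]
print(slices["shape"])  # slice(3, 5, None)
print(cumulative[-1])   # 5, the flattened width of axis 1
```

Binary concepts are the special case where every cardinality is 1, so each concept maps to a single column.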
- Parameters:
labels – List of concept names.
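states – Optional list of state names for each concept; a concept with k states occupies k columns of the concept axis.
cardinalities – Optional number of states per concept. When states is omitted, state labels '0' to 'k-1' are generated automatically.
types – Optional type for each concept.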
metadata – Optional metadata dictionary keyed by label names.
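The per-concept cardinalities come from one of three sources. Explicit states give a cardinality equal to the number of states. Cardinalities alone produce generated state names '0' to 'k-1'. Neither gives binary concepts with cardinality 1. The sketch below shows one way to reconcile these inputs. resolve_states is a hypothetical name, and several behaviors are our assumptions rather than documented ones: states taking precedence, mismatches raising ValueError, and binary concepts carrying no state list.

```python
from typing import List, Optional, Tuple


def resolve_states(
    labels: List[str],
    states: Optional[List[List[str]]] = None,
    cardinalities: Optional[List[int]] = None,
) -> Tuple[List[int], Optional[List[List[str]]]]:
    """Hypothetical sketch: derive (cardinalities, states) from the inputs."""
    if states is not None:
        if len(states) != len(labels):
            raise ValueError("need one state list per label")
        derived = [len(s) for s in states]
        # Assumption: explicit cardinalities must agree with the states.
        if cardinalities is not None and list(cardinalities) != derived:
            raise ValueError("cardinalities disagree with states")
        return derived, [list(s) for s in states]
    if cardinalities is not None:
        if len(cardinalities) != len(labels):
            raise ValueError("need one cardinality per label")
        # Auto-generate state names '0', '1', ..., 'k-1' for each concept.
        generated = [[str(i) for i in range(k)] for k in cardinalities]
        return list(cardinalities), generated
    # Neither given: every concept is binary and occupies one column.
    return [1] * len(labels), None


print(resolve_states(["size", "material"], cardinalities=[3, 4]))
# ([3, 4], [['0', '1', '2'], ['0', '1', '2', '3']])
print(resolve_states(["has_wheels", "is_red"]))
# ([1, 1], None)
```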
Example
>>> from torch_concepts import Annotations
>>>
>>> # Simple binary concepts
>>> ann_binary = Annotations(
...     labels=['has_wheels', 'has_windows', 'is_red']
... )
>>> print(ann_binary.labels)
['has_wheels', 'has_windows', 'is_red']
>>> print(ann_binary.is_nested)
False
>>> print(ann_binary.cardinalities)
[1, 1, 1]
>>> print(ann_binary.shape)
(-1, 3)
>>>
>>> # Nested concepts with explicit states
>>> ann_nested = Annotations(
...     labels=['color', 'shape'],
...     states=[['red', 'green', 'blue'], ['circle', 'square']],
... )
>>> print(ann_nested.labels)
['color', 'shape']
>>> print(ann_nested.is_nested)
True
>>> print(ann_nested.cardinalities)
[3, 2]
>>> print(ann_nested.states[0])
['red', 'green', 'blue']
>>> print(ann_nested.shape)
(-1, 5)
>>>
>>> # With cardinalities only (auto-generates state labels)
>>> ann_cards = Annotations(
...     labels=['size', 'material'],
...     cardinalities=[3, 4]
... )
>>> print(ann_cards.cardinalities)
[3, 4]
>>> print(ann_cards.states[0])
['0', '1', '2']
>>>
>>> # Access methods
>>> idx = ann_binary.get_index('has_wheels')
>>> print(idx)
0
>>> label = ann_binary.get_label(1)
>>> print(label)
has_windows
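The positions exposed by these access methods are what column extraction in the style of slice_tensor depends on. The sketch below uses plain Python lists instead of torch tensors so it runs without dependencies. slice_rows is a hypothetical name. The hard-coded slices assume the flattened layout of the color/shape example, with color in columns 0 to 2 and shape in columns 3 to 4. The sketch also assumes that requested concepts are concatenated in the order given.

```python
from typing import Dict, List

Row = List[float]


def slice_rows(batch: List[Row], slices: Dict[str, slice], concepts: List[str]) -> List[Row]:
    """Hypothetical stand-in for slice_tensor on a list-of-rows batch."""
    # For every sample, concatenate the column blocks of the requested
    # concepts in the order they are listed.
    return [
        [value for name in concepts for value in row[slices[name]]]
        for row in batch
    ]


# Flattened layout for labels ['color', 'shape'] with cardinalities [3, 2].
slices = {"color": slice(0, 3), "shape": slice(3, 5)}
batch = [
    [0.1, 0.7, 0.2, 0.9, 0.1],  # sample 0
    [0.8, 0.1, 0.1, 0.3, 0.7],  # sample 1
]
print(slice_rows(batch, slices, ["shape"]))
# [[0.9, 0.1], [0.3, 0.7]]
print(slice_rows(batch, slices, ["shape", "color"]))
# [[0.9, 0.1, 0.1, 0.7, 0.2], [0.3, 0.7, 0.8, 0.1, 0.1]]
```

With a torch tensor x of shape (B, 5), the same extraction can be written as torch.cat([x[:, slices[n]] for n in concepts], dim=1).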
- __init__(labels: List[str], states: List[List[str]] | None = None, cardinalities: List[int] | None = None, types: List[str] | None = None, metadata: Dict[str, Dict] | None = None, concept_space: bool = False) → None
Methods
__init__(labels[, states, cardinalities, ...])
concept(name) – Return a read-only Concept view for name.
empty(n[, cardinalities, types]) – Create an Annotations with n anonymous binary labels c_0 … c_{n-1}.
from_dict(data) – Create Annotations from a dictionary.
get_index(label) – Get the index of a label.
get_label(idx) – Get the label at the given index.
get_label_state(label, idx) – Get the state label at position idx of a concept.
get_label_states(label) – Get the ordered state labels of a concept.
get_logits_idx(labels) – Alias for get_slice(labels) when labels is a list.
get_slice(labels) – Get the slice or indices for one or more concepts in the flattened tensor.
get_state_index(label, state) – Get the index of a state label for a concept.
get_total_cardinality() – Get the total cardinality for nested annotations, or the number of labels otherwise.
groupby_metadata(key[, layout]) – Group labels by the value of metadata key key.
has_metadata(key) – Check whether the metadata contains a specific key for all labels.
slice_tensor(tensor, concepts) – Extract and concatenate the columns for the specified concepts.
subset(keep_labels) – Return a new Annotations restricted to keep_labels (order follows keep_labels).
to_concept_space() – Return a concept-space view: one integer-coded column per concept.
to_dict() – Convert to a JSON-serializable dictionary.
union_with(other)
Attributes
concept_slices – Precomputed mapping from concept name to slice in the flattened tensor.
concept_space
concepts – All concepts as Concept views, in axis order.
cumulative_cardinalities – Precomputed cumulative cardinalities for O(1) slicing.
label_to_index – Precomputed mapping from concept name to concept-level index.
labels_by_type – Mapping from concept type to the ordered list of labels of that type.
shape – Annotated tensor shape (B, sum(cardinalities)).
size – sum(cardinalities).
type_groups – Precomputed type-based groupings at both concept and logit levels.
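to_concept_space() describes data with one integer-coded column per concept instead of one column per state. How raw values are collapsed into those codes is not documented here. The sketch below makes two assumptions: binary columns hold probabilities that are thresholded at 0.5, and multi-state blocks are coded by the index of their largest entry (argmax). to_concept_codes and its threshold parameter are hypothetical and are not part of torch_concepts. The slices dictionary plays the role of concept_slices.

```python
from typing import Dict, List


def to_concept_codes(row: List[float], slices: Dict[str, slice], threshold: float = 0.5) -> List[int]:
    """Hypothetical: collapse one flattened row to one integer per concept."""
    codes = []
    for block in slices.values():  # dicts keep insertion (axis) order
        values = row[block]
        if len(values) == 1:
            # Binary concept (cardinality 1): assume a probability and threshold it.
            codes.append(int(values[0] >= threshold))
        else:
            # Multi-state concept: code is the position of the largest entry.
            codes.append(max(range(len(values)), key=values.__getitem__))
    return codes


# Layout: color (3 states), shape (2 states), is_red (binary) -> width 6.
slices = {"color": slice(0, 3), "shape": slice(3, 5), "is_red": slice(5, 6)}
row = [0.1, 0.7, 0.2, 0.9, 0.1, 0.8]
print(to_concept_codes(row, slices))  # [1, 0, 1]: green, circle, is_red
```

The flattened row of width 6 (sum of cardinalities) becomes a row of width 3, one entry per concept.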