torch_concepts.AnnotatedTensor
- class AnnotatedTensor(data: Tensor, annotation: Annotations)
A tensor annotated along its second axis (axis 1).
Wraps a torch.Tensor together with an Annotations object that describes the semantics of axis 1. Supports:
Label-based slicing: select columns by concept name:
    sliced = t["cat", "dog"]    # keeps only the 'cat' and 'dog' columns
    sliced = t[["cat", "dog"]]  # same, via list syntax
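Label-based slicing comes down to translating concept labels into column positions on axis 1. Below is a minimal, hypothetical sketch of that lookup in plain Python. resolve_labels is an illustrative name, not part of torch_concepts. Two behaviours are assumptions: accepting a single string, and returning indices in the requested order. Python passes t["cat", "dog"] to __getitem__ as the tuple ("cat", "dog"), so the tuple and list forms can share one code path. On a real tensor, the resulting indices would then select columns, for example data[:, idx].

```python
# Hypothetical sketch: map concept labels to column indices on axis 1.
# Not the torch_concepts implementation; resolve_labels is an illustrative name.
from typing import List, Sequence, Tuple, Union

LabelKey = Union[str, Tuple[str, ...], List[str]]


def resolve_labels(labels: Sequence[str], key: LabelKey) -> List[int]:
    """Return the axis-1 positions of the requested labels."""
    # t["cat", "dog"] arrives as a tuple, t[["cat", "dog"]] as a list;
    # accepting a single string as well is an assumption of this sketch.
    wanted = [key] if isinstance(key, str) else list(key)
    position = {label: i for i, label in enumerate(labels)}
    missing = [label for label in wanted if label not in position]
    if missing:
        raise KeyError(f"unknown concept labels: {missing}")
    # Assumption: indices follow the order in which labels were requested.
    return [position[label] for label in wanted]


labels = ["cat", "dog", "bird"]
print(resolve_labels(labels, ("cat", "dog")))   # [0, 1]
print(resolve_labels(labels, ["bird", "cat"]))  # [2, 0]
```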
Annotation-preserving operations: any tensor operation that leaves the size of axis 1 unchanged automatically returns a new AnnotatedTensor carrying the same annotation (or a subset of it):
    t.sum(dim=0)         # aggregation over the batch axis → still annotated on axis 1
    t.mean(dim=-1)       # aggregation over the last axis (3-D or higher tensor) → still annotated on axis 1
    t.reshape(8, 3, -1)  # reshape that keeps the axis-1 size → still annotated
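The "axis 1 unchanged" rule can be stated purely in terms of shapes. The following hypothetical sketch checks only shapes and never touches real tensors. It assumes the result must still have at least 2 dimensions, since the constructor rejects anything smaller. keeps_annotation is an illustrative name, not a library function.

```python
# Hypothetical sketch of the "axis 1 unchanged" rule, checked on shapes only.
from typing import Tuple


def keeps_annotation(in_shape: Tuple[int, ...], out_shape: Tuple[int, ...]) -> bool:
    # Assumption: the result must still have an axis 1 (at least 2 dims, as the
    # constructor requires) and that axis must keep its original size.
    return len(out_shape) >= 2 and out_shape[1] == in_shape[1]


shape = (4, 3, 4)  # 4 samples, 3 annotated concepts, 4 trailing features
print(keeps_annotation(shape, (4, 3)))     # mean over the last axis -> True
print(keeps_annotation(shape, (8, 3, 2)))  # reshape(8, 3, -1) -> True
print(keeps_annotation(shape, (4, 12)))    # flatten axes 1-2 -> False
print(keeps_annotation((4, 3), (4,)))      # mean(dim=-1) on a 2-D tensor -> False
```

A size-only check like this one cannot tell whether the concepts actually stayed on axis 1. For example, swapping axes 1 and 2 of a (4, 3, 3) tensor leaves axis 1 at size 3, even though a different axis now occupies that position.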
Transparent tensor proxy: all tensor attributes and methods not defined on this class (shape, dtype, .detach(), …) are forwarded to the underlying tensor via __getattr__.
torch.* function protocol: module-level functions such as torch.sum(t, dim=0) also propagate the annotation when axis 1 is unchanged.
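The transparent proxy relies on a standard Python rule: __getattr__ runs only after normal attribute lookup has failed. Names defined on the wrapper therefore take precedence, and everything else falls through to the wrapped object. Below is a minimal, hypothetical sketch of that pattern. FakeTensor stands in for torch.Tensor so the example runs without PyTorch, and LabeledProxy is an illustrative name. In this simplified version, forwarded results are returned as-is, not re-wrapped. Only the explicitly defined to() keeps the labels.

```python
# Hypothetical sketch of __getattr__ forwarding; not the library's code.
class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, shape, dtype="float32"):
        self.shape = shape
        self.dtype = dtype

    def detach(self):
        return FakeTensor(self.shape, self.dtype)

    def to(self, dtype):
        return FakeTensor(self.shape, dtype)


class LabeledProxy:
    def __init__(self, data, labels):
        self.tensor = data
        self.labels = list(labels)

    def __getattr__(self, name):
        # Python calls __getattr__ only after normal lookup fails, so names
        # defined on this class (tensor, labels, to) are never forwarded.
        if name == "tensor":
            # Guard against infinite recursion if tensor is not set yet.
            raise AttributeError(name)
        return getattr(self.tensor, name)

    def to(self, *args, **kwargs):
        # Defined explicitly so the result keeps the labels.
        return LabeledProxy(self.tensor.to(*args, **kwargs), self.labels)


p = LabeledProxy(FakeTensor((4, 3)), ["cat", "dog", "bird"])
print(p.shape, p.dtype)           # (4, 3) float32  (forwarded)
print(type(p.detach()).__name__)  # FakeTensor  (forwarded, not re-wrapped)
q = p.to("float64")
print(type(q).__name__, q.dtype, q.labels)  # LabeledProxy float64 ['cat', 'dog', 'bird']
```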
- Parameters:
data – The underlying tensor. Must have at least 2 dimensions.
annotation – Annotation for axis 1. annotation.size must equal data.shape[1].
- Raises:
ValueError – If data.dim() < 2 or the annotation size does not match data.shape[1].
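The two conditions under Raises can be checked before any wrapping takes place. Below is a hypothetical sketch of those checks. A plain list of labels stands in for Annotations, and treating len(labels) as the equivalent of annotation.size is an assumption of this sketch. check_annotated is an illustrative name.

```python
# Hypothetical sketch of the constructor checks listed under Raises.
# A plain list of labels stands in for Annotations; len(labels) plays the
# role of annotation.size (an assumption of this sketch).
from typing import Sequence, Tuple


def check_annotated(shape: Tuple[int, ...], labels: Sequence[str]) -> None:
    if len(shape) < 2:
        raise ValueError(f"data must have at least 2 dimensions, got {len(shape)}")
    if len(labels) != shape[1]:
        raise ValueError(
            f"annotation size {len(labels)} does not match data.shape[1]={shape[1]}"
        )


check_annotated((4, 3), ["cat", "dog", "bird"])  # valid: no exception
for bad_shape in [(3,), (4, 2)]:
    try:
        check_annotated(bad_shape, ["cat", "dog", "bird"])
    except ValueError as err:
        print(err)
```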
Example
>>> import torch
>>> from torch_concepts import Annotations
>>> from torch_concepts.tensor import AnnotatedTensor
>>>
>>> ann = Annotations(labels=["cat", "dog", "bird"])
>>> t = AnnotatedTensor(torch.rand(4, 3), ann)
>>>
>>> # Label-based slicing
>>> sliced = t["cat", "dog"]
>>> sliced.annotation.labels
['cat', 'dog']
>>> sliced.tensor.shape
torch.Size([4, 2])
- __init__(data: Tensor, annotation: Annotations)
Methods
__init__(data, annotation)
split_by_type([concept_type]) – If concept_type is given, return the sub-tensor of concepts of concept_type.
to(*args, **kwargs) – Move/cast the underlying tensor, preserving the annotation.
union_with(*others) – Concatenate this tensor with one or more AnnotatedTensor instances along axis 1 (the annotated axis), merging their annotations.
Attributes
annotation – The Annotations describing axis 1.
device – Device of the underlying tensor.
tensor – The underlying torch.Tensor.
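Apart from the tensor data itself, union_with and split_by_type mostly involve bookkeeping on axis-1 labels. On the tensor side, the concatenation would presumably be something like torch.cat(..., dim=1), and the split would select columns. Below is a hypothetical sketch of the label side only. The function names, the duplicate-label check, and the concept-type strings "binary" and "categorical" are all assumptions for illustration. They are not the torch_concepts API.

```python
# Hypothetical sketch of the label bookkeeping behind union_with and
# split_by_type; names, the duplicate check, and concept types are assumptions.
from typing import Dict, List, Sequence, Tuple


def union_labels(*label_lists: Sequence[str]) -> Tuple[List[str], List[int]]:
    """Concatenate label lists in order; also return each input's column offset."""
    merged: List[str] = []
    offsets: List[int] = []
    for labels in label_lists:
        offsets.append(len(merged))  # where this input's columns start on axis 1
        merged.extend(labels)
    # Assumption: duplicates are rejected so each label names exactly one column.
    if len(set(merged)) != len(merged):
        raise ValueError("duplicate concept labels after union")
    return merged, offsets


def indices_of_type(labels: Sequence[str], types: Dict[str, str], wanted: str) -> List[int]:
    """Axis-1 positions of concepts whose (hypothetical) type equals wanted."""
    return [i for i, label in enumerate(labels) if types[label] == wanted]


merged, offsets = union_labels(["cat", "dog"], ["bird"])
print(merged, offsets)  # ['cat', 'dog', 'bird'] [0, 2]
types = {"cat": "binary", "dog": "binary", "bird": "categorical"}
print(indices_of_type(merged, types, "binary"))  # [0, 1]
```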