torch_concepts.AnnotatedTensor

class AnnotatedTensor(data: Tensor, annotation: Annotations)

A tensor annotated along its second axis (axis 1).

Wraps a torch.Tensor together with an Annotations object that describes the semantics of axis 1. It supports:

  • Label-based slicing: columns can be selected by concept name.

    sliced = t["cat", "dog"]   # keeps only 'cat' and 'dog' columns
    sliced = t[["cat", "dog"]] # same via list syntax
    
  • Annotation-preserving operations: any tensor operation that leaves the size of axis 1 unchanged returns a new AnnotatedTensor carrying the same annotation (or a subset of it). For example:

    t.sum(dim=0, keepdim=True)  # aggregation over batch; keepdim keeps axis 1 in place → still annotated on axis 1
    t.mean(dim=-1)  # aggregation over last axis (tensors with 3+ dims) → still annotated on axis 1
    t.reshape(8, 3, -1)  # reshape that keeps axis-1 size → still annotated
    
  • Transparent tensor proxy: tensor attributes and methods not defined on this class (shape, dtype, .detach(), …) are forwarded to the underlying tensor via __getattr__.

  • torch.* function protocol: module-level functions such as torch.sum(t, dim=0, keepdim=True) also propagate the annotation when the size of axis 1 is unchanged.
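The four behaviours listed above can be illustrated with a small stand-in class. This is a minimal sketch, not the library's implementation: MiniAnnotatedTensor is a hypothetical name, a plain list of labels stands in for Annotations, only label keys are handled by slicing, and the re-wrapping rule (result has at least 2 dimensions and the same axis-1 size) is an assumption read off the description above. It uses PyTorch's documented classmethod __torch_function__ hook for tensor-like types; the real class may implement propagation differently.

```python
import torch


class MiniAnnotatedTensor:
    """Hypothetical stand-in for AnnotatedTensor: a tensor plus a label list for axis 1."""

    def __init__(self, data: torch.Tensor, labels):
        if data.dim() < 2:
            raise ValueError("data must have at least 2 dimensions")
        if len(labels) != data.shape[1]:
            raise ValueError("len(labels) must equal data.shape[1]")
        self.tensor = data
        self.labels = list(labels)

    def __getitem__(self, key):
        # t["cat", "dog"] arrives as a tuple, t[["cat", "dog"]] as a list, t["cat"] as a str.
        names = [key] if isinstance(key, str) else list(key)
        idx = [self.labels.index(n) for n in names]  # list.index raises ValueError on unknown labels
        return MiniAnnotatedTensor(self.tensor[:, idx], names)

    @classmethod
    def _wrap_if_annotated(cls, out, labels):
        # Assumed rule: re-attach labels only if axis 1 still exists with the annotated size.
        if isinstance(out, torch.Tensor) and out.dim() >= 2 and out.shape[1] == len(labels):
            return cls(out, labels)
        return out

    def __getattr__(self, name):
        # Only called when normal lookup fails; the guard avoids recursion before __init__ ends.
        if name in ("tensor", "labels"):
            raise AttributeError(name)
        attr = getattr(self.tensor, name)
        if not callable(attr):
            return attr  # plain attributes such as shape, dtype, device

        def method(*args, **kwargs):
            # Forward the method call, then re-wrap if axis 1 is unchanged.
            return self._wrap_if_annotated(attr(*args, **kwargs), self.labels)

        return method

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        found = []  # label lists of every annotated argument, in order

        def unwrap(x):
            if isinstance(x, cls):
                found.append(x.labels)
                return x.tensor
            if type(x) in (list, tuple):
                return type(x)(unwrap(y) for y in x)
            return x

        out = func(*unwrap(args), **{k: unwrap(v) for k, v in kwargs.items()})
        # Assumption: the first annotated argument supplies the labels of the result.
        return cls._wrap_if_annotated(out, found[0]) if found else out


t = MiniAnnotatedTensor(torch.rand(4, 3), ["cat", "dog", "bird"])
print(t["cat", "dog"].labels, t["cat", "dog"].shape)
print(type(torch.sum(t, dim=0, keepdim=True)).__name__)
```

Because the wrapper is not a torch.Tensor subclass, anything it does not intercept falls back to a plain tensor, which makes the "annotated only when axis 1 is unchanged" rule explicit in one place.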

Parameters:
  • data – The underlying tensor. Must have at least 2 dimensions.

  • annotation – Annotation for axis 1. annotation.size must equal data.shape[1].

Raises:

ValueError – If data.dim() < 2 or the annotation size does not match data.shape[1].
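The two constructor checks can be expressed as a short validation function. This is a sketch under assumptions: check_annotated_shape is a hypothetical helper, and an integer annotation_size stands in for annotation.size; the library's actual error messages are not reproduced here.

```python
import torch


def check_annotated_shape(data: torch.Tensor, annotation_size: int) -> None:
    """Hypothetical helper mirroring the documented constructor checks."""
    if data.dim() < 2:
        raise ValueError(f"data must have at least 2 dimensions, got {data.dim()}")
    if annotation_size != data.shape[1]:
        raise ValueError(
            f"annotation size {annotation_size} does not match data.shape[1]={data.shape[1]}"
        )


check_annotated_shape(torch.rand(4, 3), 3)  # passes: 2-D, axis 1 has 3 entries

# A 1-D tensor and a size mismatch are both rejected.
for bad_data, size in [(torch.rand(4), 1), (torch.rand(4, 2), 3)]:
    try:
        check_annotated_shape(bad_data, size)
    except ValueError as err:
        print("rejected:", err)
```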

Example

>>> import torch
>>> from torch_concepts import Annotations
>>> from torch_concepts.tensor import AnnotatedTensor
>>>
>>> ann = Annotations(labels=["cat", "dog", "bird"])
>>> t = AnnotatedTensor(torch.rand(4, 3), ann)
>>>
>>> # Label-based slicing
>>> sliced = t["cat", "dog"]
>>> sliced.annotation.labels
['cat', 'dog']
>>> sliced.tensor.shape
torch.Size([4, 2])

__init__(data: Tensor, annotation: Annotations)

Methods

__init__(data, annotation)

split_by_type([concept_type])

If concept_type is given, return the sub-tensor containing only the concepts of that type.

to(*args, **kwargs)

Move/cast the underlying tensor, preserving the annotation.
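to() re-wraps the converted tensor with the existing annotation. A minimal sketch of that pattern, assuming a hypothetical LabeledTensor class with a label list in place of Annotations (not the library's code); the device property mirrors the documented device attribute by reading it from the underlying tensor.

```python
import torch


class LabeledTensor:
    """Hypothetical minimal wrapper: a tensor plus labels for axis 1."""

    def __init__(self, data: torch.Tensor, labels):
        self.tensor = data
        self.labels = list(labels)

    @property
    def device(self) -> torch.device:
        # Device is read straight from the wrapped tensor.
        return self.tensor.device

    def to(self, *args, **kwargs) -> "LabeledTensor":
        # torch.Tensor.to accepts a device, a dtype, or both; the labels are reused unchanged.
        return LabeledTensor(self.tensor.to(*args, **kwargs), self.labels)


x = LabeledTensor(torch.rand(4, 3), ["cat", "dog", "bird"])
y = x.to(torch.float64)
print(type(y).__name__, y.tensor.dtype, y.labels, y.device)
```

Only the data moves or changes dtype; the labels are shared as metadata, since a device or dtype change never alters the size of axis 1.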

union_with(*others)

Concatenate this tensor with one or more AnnotatedTensor instances along axis 1 (the annotated axis), merging their annotations.
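Concatenating along the annotated axis amounts to torch.cat(..., dim=1) on the data plus an ordered merge of the labels. A minimal sketch under stated assumptions: (tensor, labels) pairs stand in for AnnotatedTensor, union_sketch is a hypothetical name, labels are merged in argument order, and duplicate labels are rejected. How the real union_with treats duplicates or other Annotations metadata is not covered here.

```python
import torch
from typing import List, Tuple

Labeled = Tuple[torch.Tensor, List[str]]


def union_sketch(first: Labeled, *others: Labeled) -> Labeled:
    """Hypothetical: concatenate (tensor, labels) pairs along axis 1."""
    tensors = [first[0]] + [data for data, _ in others]
    labels = list(first[1])
    for _, other_labels in others:
        labels.extend(other_labels)
    # Assumption: a concept name may appear only once in the merged annotation.
    if len(set(labels)) != len(labels):
        raise ValueError(f"duplicate labels after union: {labels}")
    # torch.cat requires all dimensions except dim=1 to match; axis-1 sizes add up.
    return torch.cat(tensors, dim=1), labels


animals = (torch.zeros(4, 2), ["cat", "dog"])
birds = (torch.ones(4, 1), ["bird"])
data, labels = union_sketch(animals, birds)
print(data.shape, labels)
```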

Attributes

annotation

The Annotations describing axis 1.

device

Device of the underlying tensor.

tensor

The underlying torch.Tensor.