torch_concepts.EmbeddingVariable

class EmbeddingVariable(names: str | List[str], distribution=None, shape: int | Tuple | Size | List | None = None, dist_kwargs: dict | List[dict | None] | None = None, size: int | List[int] | None = None, members: List[str] | None = None)

A non-interpretable embedding variable.

The variable may be observed, latent, or deterministic (via dist.Delta); the engine decides on each call whether it is observed.

__init__(names: str | List[str], distribution=None, shape: int | Tuple | Size | None = None, dist_kwargs: dict | List[dict | None] | None = None, size: int | List[int] | None = None, members: List[str] | None = None)

Methods

__init__(names[, distribution, shape, ...])

column_of(member)

Column span of member within this variable's event (last) dimension.

get_slice(labels)

Get the slice or indices for one or more concepts in the flattened tensor.

member(name)

Return a handle to a single member that can be used as a parent, creating an edge to that member only.
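To make the member-level methods concrete, here is a minimal pure-Python sketch of how column_of, get_slice and member could relate to one another. It assumes that members are laid out side by side along the last (event) dimension in declaration order; that layout is an assumption, not something this page states. ToyVariable, MemberHandle, the widths argument and select are hypothetical names, and plain lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence, Union


class ToyVariable:
    """Hypothetical stand-in for a variable with named members.

    Assumption: members sit side by side along the last dimension,
    in the order given, each occupying `width` columns.
    """

    def __init__(self, members: Sequence[str], widths: Sequence[int]):
        if len(members) != len(widths):
            raise ValueError("members and widths must have the same length")
        self._spans: Dict[str, slice] = {}
        start = 0
        for name, width in zip(members, widths):
            self._spans[name] = slice(start, start + width)
            start += width
        self.size = start  # total number of columns across all members

    def column_of(self, member: str) -> slice:
        # Column span of one member within the last dimension.
        return self._spans[member]

    def get_slice(self, labels: Union[str, List[str]]) -> Union[slice, List[int]]:
        # One label gives a contiguous slice; several labels give a flat index list.
        if isinstance(labels, str):
            return self._spans[labels]
        indices: List[int] = []
        for label in labels:
            span = self._spans[label]
            indices.extend(range(span.start, span.stop))
        return indices

    def member(self, name: str) -> "MemberHandle":
        # A lightweight handle that refers to a single member of this variable.
        return MemberHandle(self, name)


@dataclass(frozen=True)
class MemberHandle:
    """Hypothetical handle: when used as a parent, it exposes only one member's columns."""

    variable: ToyVariable
    name: str

    def select(self, row: List[float]) -> List[float]:
        # Pick out just this member's columns from a full row of the variable.
        return row[self.variable.column_of(self.name)]


if __name__ == "__main__":
    v = ToyVariable(["color", "shape"], [3, 2])
    print(v.column_of("shape"))                              # slice(3, 5, None)
    print(v.get_slice(["color", "shape"]))                   # [0, 1, 2, 3, 4]
    print(v.member("shape").select([0.1, 0.2, 0.3, 0.4, 0.5]))  # [0.4, 0.5]
```

Returning a slice for a single label keeps indexing cheap and contiguous, while a list of indices is the natural fallback when several, possibly non-adjacent, members are requested at once.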

Attributes

concept_slices

Precomputed mapping from concept name to its slice in the flattened tensor.

is_plate

Whether this variable was created with explicitly named members.

param_sizes

Per-parameter output sizes for this variable's distribution.

plate

The plate this variable belongs to.

shape

Event shape as a torch.Size, e.g. torch.Size([4]) or torch.Size([3, 4]).

size

Total number of elements in one event, i.e. math.prod(self.shape).

variable_type

Short string tag identifying the variable kind.
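The shape, size and concept_slices attributes fit together roughly as follows. This is a minimal pure-Python sketch, not the library's implementation: normalize_shape, event_size and build_concept_slices are hypothetical helpers, and plain tuples stand in for torch.Size (which subclasses tuple, so math.prod works on it unchanged). The back-to-back layout of concepts in the flattened tensor is likewise an assumption.

```python
import math
from typing import Dict, Sequence, Tuple, Union

# Hypothetical alias for the int / tuple / list forms the constructor accepts.
ShapeLike = Union[int, Sequence[int]]


def normalize_shape(shape: ShapeLike) -> Tuple[int, ...]:
    # An int means a 1-D event; a tuple or list is kept dimension by dimension.
    if isinstance(shape, int):
        return (shape,)
    return tuple(int(d) for d in shape)


def event_size(shape: ShapeLike) -> int:
    # Number of scalar entries in one event: the product of the event dimensions.
    return math.prod(normalize_shape(shape))


def build_concept_slices(
    names: Sequence[str], shapes: Sequence[ShapeLike]
) -> Dict[str, slice]:
    # Assumption: concepts are flattened and placed back to back; record each span.
    slices: Dict[str, slice] = {}
    start = 0
    for name, shape in zip(names, shapes):
        width = event_size(shape)
        slices[name] = slice(start, start + width)
        start += width
    return slices


if __name__ == "__main__":
    print(normalize_shape(4))        # (4,)
    print(event_size([3, 4]))        # 12
    print(build_concept_slices(["z", "w"], [4, (3, 4)]))
    # {'z': slice(0, 4, None), 'w': slice(4, 16, None)}
```

Precomputing the name-to-slice mapping once means later lookups are plain dictionary accesses rather than repeated size arithmetic.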