torch_concepts.EmbeddingVariable
- class EmbeddingVariable(names: str | List[str], distribution=None, shape: int | Tuple | Size | List | None = None, dist_kwargs: dict | List[dict | None] | None = None, size: int | List[int] | None = None, members: List[str] | None = None)
A non-interpretable embedding variable.
It may be observed, latent, or deterministic (via dist.Delta); the engine decides on a per-call basis whether the variable is observed.

- __init__(names: str | List[str], distribution=None, shape: int | Tuple | Size | None = None, dist_kwargs: dict | List[dict | None] | None = None, size: int | List[int] | None = None, members: List[str] | None = None)
Methods
__init__(names[, distribution, shape, ...])
column_of(member): Column span of member within this variable's event (last) dimension.
get_slice(labels): Get the slice or indices for one or more concepts in the flattened tensor.
member(name): A handle to a single member, usable as a parent (an edge to that member only).
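The column-span idea behind column_of, get_slice and concept_slices can be illustrated with a minimal sketch. It assumes, without confirmation from this page, that members are concatenated in order along the event (last) dimension and that each member's width comes from a per-member size list. member_spans and select are hypothetical helpers written for illustration only; they are not part of torch_concepts.

```python
import torch
from typing import Dict, List, Union

# Hypothetical helper (not the torch_concepts API): assign each member a
# contiguous column span, assuming members are concatenated in order
# along the last (event) dimension.
def member_spans(members: List[str], sizes: List[int]) -> Dict[str, slice]:
    spans: Dict[str, slice] = {}
    offset = 0
    for name, width in zip(members, sizes):
        spans[name] = slice(offset, offset + width)
        offset += width
    return spans

# Hypothetical helper: a single label yields a slice view; several labels
# yield their concatenated column indices, gathered with advanced indexing.
def select(x: torch.Tensor, spans: Dict[str, slice],
           labels: Union[str, List[str]]) -> torch.Tensor:
    if isinstance(labels, str):
        return x[..., spans[labels]]
    idx = [i for lab in labels for i in range(spans[lab].start, spans[lab].stop)]
    return x[..., idx]

spans = member_spans(["color", "texture"], [3, 4])
x = torch.randn(2, 7)  # batch of 2, event dim of 3 + 4 = 7 columns
color = select(x, spans, "color")              # shape (2, 3)
both = select(x, spans, ["texture", "color"])  # shape (2, 7), columns reordered
```

Returning a plain slice for one label keeps the result a view of the original tensor, while a list of labels needs index-based gathering, which copies. That is one plausible reason an API would return "slice or indices" depending on the request.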
Attributes
concept_slices: Precomputed mapping from concept name to its slice in the flattened tensor.
is_plate: Whether this variable was created with explicit named members.
param_sizes: Per-parameter output sizes for this variable's distribution.
plate: The plate this variable belongs to.
shape: Event shape as a torch.Size, e.g. torch.Size([4]) or torch.Size([3, 4]).
size: Number of elements in the event, i.e. math.prod(self.shape).
variable_type: Short string tag identifying the variable kind.
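A small sketch of how the shape and size attributes relate, using only torch and math. normalize_shape is a hypothetical helper, and treating a bare int shape as a 1-D event of that length is an assumption; this page only lists the accepted types, not how they are normalized.

```python
import math
import torch
from typing import List, Tuple, Union

# Hypothetical helper: coerce the accepted shape forms into torch.Size.
# Treating a bare int as a 1-D event of that length is an assumption.
def normalize_shape(shape: Union[int, Tuple[int, ...], List[int], torch.Size]) -> torch.Size:
    if isinstance(shape, int):
        return torch.Size([shape])
    return torch.Size(shape)

for raw in (4, (3, 4)):
    event_shape = normalize_shape(raw)
    size = math.prod(event_shape)  # the quantity the size attribute reports
    print(event_shape, size)

# Flattening the event dimensions of a batch leaves `size` columns per row.
batch = torch.randn(8, 3, 4)
flat = batch.reshape(batch.shape[0], -1)
```

Because torch.Size is a tuple subclass, math.prod applies to it directly, so size is 4 for torch.Size([4]) and 12 for torch.Size([3, 4]).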