Source code for torch_concepts.nn.modules.low.priors
"""Prior modules for root CPDs.
Two types are provided:
* :class:`LearnablePrior` — a trainable ``nn.Parameter`` (optimised during
training);
* :class:`FixedPrior` — values known a priori, held as a non-learnable buffer
(never updated by the optimizer).
The output of the prior modules is an *unconstrained* parameter.
If the parameter that has to be learned is a probability, an activation function (e.g. ``torch.sigmoid``)
must be applied to the output of the prior module to map it to the correct domain.
"""
from __future__ import annotations
from typing import Sequence, Union
import torch
import torch.nn as nn
[docs]
class LearnablePrior(nn.Module):
"""Learnable parameter module for root (parent-less) CPDs.
Wraps a single ``nn.Parameter`` of the requested ``size`` and returns it on
``forward()``, making it a drop-in parametrization for a root CPD. The
parameter is randomly initialised from a standard normal distribution.
Parameters
----------
size : int
Length of the parameter vector. This must match the per-parameter size
the target distribution expects (e.g. ``1`` for a Bernoulli ``logits``,
``k`` for a ``k``-way OneHotCategorical ``logits``).
"""
[docs]
def __init__(self, size: int) -> None:
super().__init__()
self.param = nn.Parameter(torch.randn(size))
def forward(self) -> torch.Tensor:
return self.param
[docs]
class FixedPrior(nn.Module):
"""Non-learnable prior module holding parameter values known a priori.
Mirrors :class:`LearnablePrior` but the parameter is **fixed**: the supplied
values are registered as a buffer, so they carry no gradient and are never
touched by the optimizer, while still moving with ``.to(device)`` and being
saved in the module ``state_dict``. Use this when a root distribution's
parameter is known in advance (e.g. a fixed prior probability) rather than
learned.
Parameters
----------
values : torch.Tensor or sequence of float
The fixed parameter values. Their length must match the per-parameter
size the target distribution expects. Coerced to a 1-D float tensor.
"""
[docs]
def __init__(self, values: Union[torch.Tensor, Sequence[float]]) -> None:
super().__init__()
if isinstance(values, torch.Tensor):
tensor = values.detach().clone().float()
else:
tensor = torch.as_tensor(values, dtype=torch.float)
self.register_buffer("values", tensor)
def forward(self) -> torch.Tensor:
return self.values