torch_concepts.nn.ModelOutput

class ModelOutput(params: Dict[str, Dict[str, torch.Tensor]] = <factory>, guide_params: Dict[str, Dict[str, torch.Tensor]] = <factory>, samples: Dict[str, torch.Tensor] = <factory>, probabilities: torch.Tensor | None = None, logits: torch.Tensor | None = None, target: torch.Tensor | None = None, extra: Dict[str, torch.Tensor] | None = None)

Structured output from a high-level model’s forward() method.
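The `<factory>` defaults in the class signature are how Python's `dataclasses` module renders fields declared with `field(default_factory=...)`. With this pattern, every instance gets its own fresh dict instead of sharing one mutable default. The sketch below mirrors the documented fields in a stand-in class. `ModelOutputSketch` is hypothetical. It assumes `ModelOutput` is a dataclass with `dict` factories, which the signature suggests but this page does not state outright.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch

# Hypothetical stand-in mirroring the documented ModelOutput fields;
# this is NOT the torch_concepts class itself.
ParamDict = Dict[str, torch.Tensor]


@dataclass
class ModelOutputSketch:
    # Dict fields use default_factory, rendered as <factory> in signatures.
    params: Dict[str, ParamDict] = field(default_factory=dict)
    guide_params: Dict[str, ParamDict] = field(default_factory=dict)
    samples: Dict[str, torch.Tensor] = field(default_factory=dict)
    # Optional tensor fields default to None.
    probabilities: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    target: Optional[torch.Tensor] = None
    extra: Optional[Dict[str, torch.Tensor]] = None


a = ModelOutputSketch()
b = ModelOutputSketch()
a.params["c"] = {"probs": torch.tensor([0.2, 0.9])}

# b is unaffected: each instance received its own empty dict.
print(b.params)  # {}
print(a.logits)  # None
```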

params

Per-variable named parameter tensors of the model-side distribution (e.g. {'c': {'probs': ...}}).

Type:

dict[str, ParamDict]
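Each `params` entry is a `ParamDict`, which maps parameter names to tensors. The `{'c': {'probs': ...}}` example suggests that a concept variable `c` may be Bernoulli-distributed. Under that assumption, its `ParamDict` can be unpacked directly into `torch.distributions.Bernoulli`. In the sketch below, the plain dict `out_params` stands in for `ModelOutput.params`. The tensor values are made up for illustration.

```python
import torch
from torch.distributions import Bernoulli

# Stand-in for ModelOutput.params: variable name -> {parameter name -> tensor}.
# Illustrative values: a batch of 2 examples with 3 binary concepts each.
out_params = {
    "c": {"probs": torch.tensor([[0.1, 0.8, 0.5],
                                 [0.9, 0.2, 0.7]])},
}

# Assumption: 'probs' parameterises a Bernoulli, so the ParamDict maps
# directly onto the distribution's keyword arguments.
dist_c = Bernoulli(**out_params["c"])

# Threshold the probabilities to obtain hard concept predictions.
hard_c = (dist_c.probs > 0.5).float()
print(dist_c.mean.shape)  # torch.Size([2, 3])
print(hard_c)
```

Because the parameter names are the keys of the inner dict, the same unpacking pattern would carry over to other distribution families, provided their key names match the constructor's keyword arguments.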

guide_params

Per-latent named parameter tensors of the variational guide.

Type:

dict[str, ParamDict]
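Because the guide and model-side parameters are returned side by side, the two distributions can be compared directly. A common example is the KL divergence term of a variational objective. The sketch below makes an illustrative assumption: a latent `z` is Bernoulli under both the guide and the model, and both `ParamDict`s use the key `'probs'`. The dicts and numbers are hypothetical.

```python
import torch
from torch.distributions import Bernoulli, kl_divergence

# Hypothetical stand-ins for ModelOutput.guide_params and ModelOutput.params.
guide_params = {"z": {"probs": torch.tensor([0.5, 0.5])}}
model_params = {"z": {"probs": torch.tensor([0.5, 0.25])}}

q = Bernoulli(**guide_params["z"])  # variational guide q(z)
p = Bernoulli(**model_params["z"])  # model-side distribution p(z)

# Elementwise KL(q || p); zero wherever the two distributions coincide.
kl = kl_divergence(q, p)
print(kl)
```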

samples

Per-variable sampled values.

Type:

dict[str, torch.Tensor]

probabilities

Joint conditional probabilities for a fully realised query batch.

Type:

torch.Tensor or None

__init__(params: Dict[str, Dict[str, torch.Tensor]] = <factory>, guide_params: Dict[str, Dict[str, torch.Tensor]] = <factory>, samples: Dict[str, torch.Tensor] = <factory>, probabilities: torch.Tensor | None = None, logits: torch.Tensor | None = None, target: torch.Tensor | None = None, extra: Dict[str, torch.Tensor] | None = None) → None

Methods

__init__([params, guide_params, samples, ...])

Attributes