Source code for torch_concepts.nn.modules.mid.inference.torch.independent

"""Independent training inference."""

import logging

from ...models.bayesian_network import BayesianNetwork
from .deterministic import DeterministicInference

logger = logging.getLogger(__name__)


[docs] class IndependentInference(DeterministicInference): """ Independent training inference. This is a convenience subclass of :class:`DeterministicInference` that forces ``p_int=1``, so ground truth concepts are always propagated to downstream predictors during training. Equivalent to ``DeterministicInference(..., p_int=1.0)``. ``activate_before_propagation`` is forwarded to :class:`DeterministicInference`. """
[docs] def __init__(self, pgm: BayesianNetwork, activate_before_propagation: bool = True): super().__init__( pgm, activate_before_propagation=activate_before_propagation, p_int=1.0, )