Source code for torch_concepts.nn.modules.low.predictors.call
import torch
from ..base.layer import BasePredictor
from typing import Callable
[docs]
class CallableCC(BasePredictor):
"""
A predictor that applies a custom callable function to concept representations.
This predictor allows flexible task prediction by accepting any callable function
that operates on concept representations. It optionally includes learnable stochastic
bias parameters (mean and standard deviation) that are added to the output using
the reparameterization trick for gradient-based learning.
The module can be used to write custom layers for standard Structural Causal Models (SCMs).
Args:
func: Callable function that takes concept probabilities and returns task predictions.
Should accept a tensor of shape (batch_size, n_concepts) and return
a tensor of shape (batch_size, out_features).
in_activation: Activation function to apply to input endogenous before passing to func.
Default is identity (lambda x: x).
use_bias: Whether to add learnable stochastic bias to the output. Default is True.
init_bias_mean: Initial value for the bias mean parameter. Default is 0.0.
init_bias_std: Initial value for the bias standard deviation. Default is 0.01.
min_std: Minimum standard deviation floor for numerical stability. Default is 1e-6.
Examples:
>>> import torch
>>> from torch_concepts.nn import CallableCC
>>>
>>> # Generate sample data
>>> batch_size = 32
>>> n_concepts = 3
>>> endogenous = torch.randn(batch_size, n_concepts)
>>>
>>> # Define a polynomial function with fixed weights for 3 inputs, 2 outputs
>>> def quadratic_predictor(probs):
... c0, c1, c2 = probs[:, 0:1], probs[:, 1:2], probs[:, 2:3]
... output1 = 0.5*c0**2 + 1.0*c1**2 + 1.5*c2
... output2 = 2.0*c0 - 1.0*c1**2 + 0.5*c2**3
... return torch.cat([output1, output2], dim=1)
>>>
>>> predictor = CallableCC(
... func=quadratic_predictor,
... use_bias=True
... )
>>> predictions = predictor(endogenous)
>>> print(predictions.shape) # torch.Size([32, 2])
References
Pearl, J. "Causality", Cambridge University Press (2009).
"""
[docs]
def __init__(
self,
func: Callable,
in_activation: Callable = lambda x: x,
use_bias : bool = True,
init_bias_mean: float = 0.0,
init_bias_std: float = 0.01,
min_std: float = 1e-6
):
super().__init__(
in_features_endogenous=-1,
out_features=-1,
in_activation=in_activation,
)
self.use_bias = use_bias
self.min_std = float(min_std)
self.func = func
# Learnable distribution params for the stochastic bias (scalar, broadcasts to (B, Y))
if self.use_bias:
self.bias_mean = torch.nn.Parameter(torch.tensor(float(init_bias_mean)))
# raw_std is unconstrained; softplus(raw_std) -> positive std
# initialize so that softplus(raw_std) ~= init_bias_std
init_raw_std = torch.log(torch.exp(torch.tensor(float(init_bias_std))) - 1.0).item()
self.bias_raw_std = torch.nn.Parameter(torch.tensor(init_raw_std))
else:
# Keep attributes for shape/device consistency even if unused
self.register_buffer("bias_mean", torch.tensor(0.0))
self.register_buffer("bias_raw_std", torch.tensor(0.0))
def _bias_std(self) -> torch.Tensor:
"""
Compute the bias standard deviation using softplus activation.
Returns:
torch.Tensor: Positive standard deviation value with minimum floor applied.
"""
# softplus to ensure positivity; add small floor for stability
return torch.nn.functional.softplus(self.bias_raw_std) + self.min_std
[docs]
def forward(
self,
endogenous: torch.Tensor,
*args,
**kwargs
) -> torch.Tensor:
in_probs = self.in_activation(endogenous)
out_endogenous = self.func(in_probs, *args, **kwargs)
if self.use_bias:
# Reparameterized sampling so mean/std are learnable
eps = torch.randn_like(out_endogenous) # ~ N(0,1)
std = self._bias_std().to(out_endogenous.dtype).to(out_endogenous.device) # scalar -> broadcast
mean = self.bias_mean.to(out_endogenous.dtype).to(out_endogenous.device) # scalar -> broadcast
out_endogenous = out_endogenous + mean + std * eps
return out_endogenous