Source code for torch_concepts.nn.modules.low.predictors.linear

"""
Linear predictor modules for concept-based models.

This module provides linear prediction layers that transform concept
representations into new concept representations using a linear layer.
"""
import torch

from ..base.layer import BasePredictor
from typing import Callable

from ....functional import prune_linear_layer


class LinearCC(BasePredictor):
    """
    Linear concept predictor.

    This predictor transforms input concept endogenous into other concept
    endogenous. It first applies ``in_activation`` to the input endogenous,
    then passes the result through a single linear layer. No activation is
    applied to the output, so the returned values are unnormalized
    endogenous (logits).

    Attributes:
        in_features_endogenous (int): Number of input logit features.
        out_features (int): Number of output concept features.
        in_activation (Callable): Activation function for inputs
            (default: sigmoid).
        predictor (nn.Sequential): The prediction network; index 0 is the
            ``torch.nn.Linear`` layer.

    Args:
        in_features_endogenous: Number of input logit features.
        out_features: Number of output concept features.
        in_activation: Activation function to apply to input endogenous
            (default: torch.sigmoid).

    Example:
        >>> import torch
        >>> from torch_concepts.nn import LinearCC
        >>>
        >>> # Create predictor
        >>> predictor = LinearCC(
        ...     in_features_endogenous=10,
        ...     out_features=5
        ... )
        >>>
        >>> # Forward pass
        >>> in_endogenous = torch.randn(2, 10)  # batch_size=2, in_features=10
        >>> out_endogenous = predictor(in_endogenous)
        >>> print(out_endogenous.shape)
        torch.Size([2, 5])

    References:
        Koh et al. "Concept Bottleneck Models", ICML 2020.
        https://arxiv.org/pdf/2007.04612
    """

    def __init__(
        self,
        in_features_endogenous: int,
        out_features: int,
        in_activation: Callable = torch.sigmoid,
    ):
        """
        Initialize the linear predictor.

        Args:
            in_features_endogenous: Number of input logit features.
            out_features: Number of output concept features.
            in_activation: Activation function for inputs
                (default: torch.sigmoid).
        """
        super().__init__(
            in_features_endogenous=in_features_endogenous,
            out_features=out_features,
            in_activation=in_activation,
        )
        self.predictor = torch.nn.Sequential(
            torch.nn.Linear(in_features_endogenous, out_features),
            # Unflattening the last dim (size out_features) into
            # (out_features,) leaves the shape unchanged. It is kept so the
            # Sequential layout stays as before; prune() relies on index 0
            # being the Linear layer.
            torch.nn.Unflatten(-1, (out_features,)),
        )

    def forward(self, endogenous: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the predictor.

        Args:
            endogenous: Input endogenous of shape
                (batch_size, in_features_endogenous).

        Returns:
            torch.Tensor: Output endogenous (raw linear outputs, no output
            activation) of shape (batch_size, out_features).
        """
        # The activation is applied to the *inputs*, not to the outputs.
        activated_inputs = self.in_activation(endogenous)
        out_endogenous = self.predictor(activated_inputs)
        return out_endogenous

    def prune(self, mask: torch.Tensor):
        """
        Prune input features based on a binary mask.

        Removes input features where mask is False/0, reducing model
        complexity, and updates ``in_features_endogenous`` to the number of
        kept features.

        Args:
            mask: Binary mask of shape (in_features_endogenous,) indicating
                which features to keep (True/1) or remove (False/0).

        Example:
            >>> import torch
            >>> from torch_concepts.nn import LinearCC
            >>>
            >>> predictor = LinearCC(in_features_endogenous=10, out_features=5)
            >>>
            >>> # Prune first 3 features
            >>> mask = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=torch.bool)
            >>> predictor.prune(mask)
            >>>
            >>> # Now only accepts 7 input features
            >>> endogenous = torch.randn(2, 7)
            >>> probs = predictor(endogenous)
            >>> print(probs.shape)
            torch.Size([2, 5])
        """
        # Store a plain Python int. The builtin ``sum`` over a tensor yields a
        # 0-d tensor, which contradicts the documented ``int`` attribute type.
        self.in_features_endogenous = int(mask.int().sum().item())
        self.predictor[0] = prune_linear_layer(self.predictor[0], mask, dim=0)