# Source code for torch_concepts.nn.modules.low.base.layer

"""
Base layer classes for concept-based neural networks.

This module provides abstract base classes for building concept layers,
including encoders and predictors.
"""
from abc import ABC
from typing import Callable, Optional

import torch


class BaseConceptLayer(ABC, torch.nn.Module):
    """
    Abstract base class for concept layers.

    This class provides the foundation for all concept-based layers,
    defining the interface and basic structure for concept encoders and
    predictors. Subclasses must override :meth:`forward`.

    NOTE: no method is declared with ``@abstractmethod``, so ``ABC`` does not
    prevent direct instantiation; calling :meth:`forward` on an instance that
    does not override it raises ``NotImplementedError``.

    Attributes:
        in_features_endogenous (Optional[int]): Number of input logit features.
        in_features (Optional[int]): Number of input latent features.
        in_features_exogenous (Optional[int]): Number of exogenous input features.
        out_features (int): Number of output features.

    Args:
        out_features: Number of output features.
        in_features_endogenous: Number of input logit features (optional).
        in_features: Number of input latent features (optional).
        in_features_exogenous: Number of exogenous input features (optional).
        *args: Accepted so subclasses may pass extra arguments; ignored.
        **kwargs: Accepted so subclasses may pass extra arguments; ignored.

    Example:
        >>> import torch
        >>> from torch_concepts.nn import BaseConceptLayer
        >>>
        >>> # Create a custom concept layer
        >>> class MyConceptLayer(BaseConceptLayer):
        ...     def __init__(self, out_features, in_features_endogenous):
        ...         super().__init__(
        ...             out_features=out_features,
        ...             in_features_endogenous=in_features_endogenous
        ...         )
        ...         self.linear = torch.nn.Linear(in_features_endogenous, out_features)
        ...
        ...     def forward(self, endogenous):
        ...         return torch.sigmoid(self.linear(endogenous))
        >>>
        >>> # Example usage
        >>> layer = MyConceptLayer(out_features=5, in_features_endogenous=10)
        >>>
        >>> # Generate random input
        >>> endogenous = torch.randn(2, 10)  # batch_size=2, in_features_endogenous=10
        >>>
        >>> # Forward pass
        >>> output = layer(endogenous)
        >>> print(output.shape)
        torch.Size([2, 5])
    """

    def __init__(
        self,
        out_features: int,
        in_features_endogenous: Optional[int] = None,
        in_features: Optional[int] = None,
        in_features_exogenous: Optional[int] = None,
        *args,
        **kwargs,
    ):
        # Extra *args/**kwargs are not forwarded to torch.nn.Module.__init__;
        # they are currently discarded.
        super().__init__()
        self.in_features_endogenous = in_features_endogenous
        self.in_features = in_features
        self.in_features_exogenous = in_features_exogenous
        self.out_features = out_features

    def forward(
        self,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass through the concept layer.

        Must be implemented by subclasses.

        Returns:
            torch.Tensor: Output tensor.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement forward()."
        )
class BaseEncoder(BaseConceptLayer):
    """
    Abstract base class for concept encoder layers.

    Encoders transform input features (latent or exogenous variables) into
    concept representations. ``in_features_endogenous`` is always set to
    ``None`` for encoders.

    Args:
        out_features: Number of output concept features.
        in_features: Number of input latent features (optional).
        in_features_exogenous: Number of exogenous input features (optional).

    Example:
        >>> import torch
        >>> from torch_concepts.nn import BaseEncoder
        >>>
        >>> # Create a custom encoder
        >>> class MyEncoder(BaseEncoder):
        ...     def __init__(self, out_features, in_features):
        ...         super().__init__(
        ...             out_features=out_features,
        ...             in_features=in_features
        ...         )
        ...         self.net = torch.nn.Sequential(
        ...             torch.nn.Linear(in_features, 128),
        ...             torch.nn.ReLU(),
        ...             torch.nn.Linear(128, out_features)
        ...         )
        ...
        ...     def forward(self, latent):
        ...         return self.net(latent)
        >>>
        >>> # Example usage
        >>> encoder = MyEncoder(out_features=10, in_features=784)
        >>>
        >>> # Generate random image latent (e.g., flattened MNIST)
        >>> x = torch.randn(4, 784)  # batch_size=4, pixels=784
        >>>
        >>> # Encode to concepts
        >>> concepts = encoder(x)
        >>> print(concepts.shape)
        torch.Size([4, 10])
    """

    def __init__(
        self,
        out_features: int,
        in_features: Optional[int] = None,
        in_features_exogenous: Optional[int] = None,
    ):
        super().__init__(
            in_features_endogenous=None,  # encoders take no endogenous input
            in_features=in_features,
            in_features_exogenous=in_features_exogenous,
            out_features=out_features,
        )
class BasePredictor(BaseConceptLayer):
    """
    Abstract base class for concept predictor layers.

    Predictors take concept representations (plus latent or exogenous
    variables) and predict other concept representations. The base class
    only stores ``in_activation``; subclasses decide where to apply it.

    Attributes:
        in_activation (Callable): Activation function for input
            (default: ``torch.sigmoid``).

    Args:
        out_features: Number of output concept features.
        in_features_endogenous: Number of input logit features.
        in_features: Number of input latent features (optional).
        in_features_exogenous: Number of exogenous input features (optional).
        in_activation: Activation function for input (default: torch.sigmoid).

    Example:
        >>> import torch
        >>> from torch_concepts.nn import BasePredictor
        >>>
        >>> # Create a custom predictor
        >>> class MyPredictor(BasePredictor):
        ...     def __init__(self, out_features, in_features_endogenous):
        ...         super().__init__(
        ...             out_features=out_features,
        ...             in_features_endogenous=in_features_endogenous,
        ...             in_activation=torch.sigmoid
        ...         )
        ...         self.linear = torch.nn.Linear(in_features_endogenous, out_features)
        ...
        ...     def forward(self, endogenous):
        ...         # Apply activation to input endogenous
        ...         probs = self.in_activation(endogenous)
        ...         # Predict next concepts
        ...         return self.linear(probs)
        >>>
        >>> # Example usage
        >>> predictor = MyPredictor(out_features=3, in_features_endogenous=10)
        >>>
        >>> # Generate random concept endogenous
        >>> concept_endogenous = torch.randn(4, 10)  # batch_size=4, n_concepts=10
        >>>
        >>> # Predict task labels from concepts
        >>> task_endogenous = predictor(concept_endogenous)
        >>> print(task_endogenous.shape)
        torch.Size([4, 3])
        >>>
        >>> # Get task predictions
        >>> task_probs = torch.sigmoid(task_endogenous)
        >>> print(task_probs.shape)
        torch.Size([4, 3])
    """

    def __init__(
        self,
        out_features: int,
        in_features_endogenous: int,
        in_features: Optional[int] = None,
        in_features_exogenous: Optional[int] = None,
        in_activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
    ):
        super().__init__(
            in_features_endogenous=in_features_endogenous,
            in_features=in_features,
            in_features_exogenous=in_features_exogenous,
            out_features=out_features,
        )
        self.in_activation = in_activation

    def prune(self, mask: torch.Tensor):
        """
        Prune the predictor by removing connections based on the given mask.

        This method removes unnecessary connections in the predictor layer
        based on a binary mask, which can help reduce model complexity and
        improve interpretability.

        Args:
            mask: A binary mask indicating which connections to keep (1) or
                remove (0).

        Raises:
            NotImplementedError: Must be implemented by subclasses that
                support pruning.
        """
        raise NotImplementedError(
            f"Pruning is not yet supported for {self.__class__.__name__}."
        )