Source code for torch_concepts.nn.modules.low.dense_layers

"""Simple fully-connected neural network layers.

This module provides Dense, MLP, and ResidualMLP layers adapted from the 
torch-spatiotemporal library. These layers serve as building blocks for 
neural network architectures in concept-based models.

Reference: https://torch-spatiotemporal.readthedocs.io/en/latest/
"""

import torch
from torch import nn
import torch.nn.functional as F

_torch_activations_dict = {
    'elu': 'ELU',
    'leaky_relu': 'LeakyReLU',
    'prelu': 'PReLU',
    'relu': 'ReLU',
    'rrelu': 'RReLU',
    'selu': 'SELU',
    'celu': 'CELU',
    'gelu': 'GELU',
    'glu': 'GLU',
    'mish': 'Mish',
    'sigmoid': 'Sigmoid',
    'softplus': 'Softplus',
    'tanh': 'Tanh',
    'silu': 'SiLU',
    'swish': 'SiLU',
    'linear': 'Identity'
}

def get_layer_activation(activation):
    """Get PyTorch activation layer class from string name.
    
    Args:
        activation (str or None): Activation function name (case-insensitive).
            Supported: 'elu', 'leaky_relu', 'prelu', 'relu', 'rrelu', 'selu',
            'celu', 'gelu', 'glu', 'mish', 'sigmoid', 'softplus', 'tanh', 
            'silu', 'swish', 'linear'. None returns Identity.
    
    Returns:
        torch.nn.Module: Activation layer class (uninstantiated).
        
    Raises:
        ValueError: If activation name is not recognized.
        
    Example:
        >>> from torch_concepts.nn.modules.low.dense_layers import get_layer_activation
        >>> act_class = get_layer_activation('relu')
        >>> activation = act_class()  # ReLU()
        >>> act_class = get_layer_activation(None)
        >>> activation = act_class()  # Identity()
    """
    if activation is None:
        return nn.Identity
    activation = activation.lower()
    if activation in _torch_activations_dict:
        return getattr(nn, _torch_activations_dict[activation])
    raise ValueError(f"Activation '{activation}' not valid.")



[docs] class Dense(nn.Module): r"""A simple fully-connected layer implementing .. math:: \mathbf{x}^{\prime} = \sigma\left(\boldsymbol{\Theta}\mathbf{x} + \mathbf{b}\right) where :math:`\mathbf{x} \in \mathbb{R}^{d_{in}}, \mathbf{x}^{\prime} \in \mathbb{R}^{d_{out}}` are the input and output features, respectively, :math:`\boldsymbol{\Theta} \in \mathbb{R}^{d_{out} \times d_{in}} \mathbf{b} \in \mathbb{R}^{d_{out}}` are trainable parameters, and :math:`\sigma` is an activation function. Args: in_features (int): Number of input features. out_features (int): Number of output features. activation (str, optional): Activation function to be used. (default: :obj:`'relu'`) dropout (float, optional): The dropout rate. (default: :obj:`0`) bias (bool, optional): If :obj:`True`, then the bias vector is used. (default: :obj:`True`) """
[docs] def __init__(self, in_features: int, out_features: int, activation: str = 'relu', dropout: float = 0., bias: bool = True): super(Dense, self).__init__() self.affinity = nn.Linear(in_features, out_features, bias=bias) self.activation = get_layer_activation(activation)() self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
def reset_parameters(self) -> None: """Reset layer parameters to initial random values.""" self.affinity.reset_parameters() def forward(self, x): """Apply linear transformation, activation, and dropout. Args: x (torch.Tensor): Input tensor of shape (..., in_features). Returns: torch.Tensor: Output tensor of shape (..., out_features). """ out = self.activation(self.affinity(x)) return self.dropout(out)
[docs] class MLP(nn.Module): """Simple Multi-layer Perceptron encoder with optional linear readout. Args: input_size (int): Input size. hidden_size (int): Units in the hidden layers. output_size (int, optional): Size of the optional readout. n_layers (int, optional): Number of hidden layers. (default: 1) activation (str, optional): Activation function. (default: `relu`) dropout (float, optional): Dropout probability. """
[docs] def __init__(self, input_size, hidden_size, output_size=None, n_layers=1, activation='relu', dropout=0.): super(MLP, self).__init__() layers = [ Dense(in_features=input_size if i == 0 else hidden_size, out_features=hidden_size, activation=activation, dropout=dropout) for i in range(n_layers) ] self.mlp = nn.Sequential(*layers) if output_size is not None: self.readout = nn.Linear(hidden_size, output_size) else: self.register_parameter('readout', None)
def reset_parameters(self) -> None: """Reset all layer parameters to initial random values.""" for module in self.mlp._modules.values(): module.reset_parameters() if self.readout is not None: self.readout.reset_parameters() def forward(self, x): """Forward pass through MLP layers with optional readout. Args: x (torch.Tensor): Input tensor of shape (..., input_size). Returns: torch.Tensor: Output tensor of shape (..., output_size) if readout is defined, else (..., hidden_size). """ out = self.mlp(x) if self.readout is not None: return self.readout(out) return out
[docs] class ResidualMLP(nn.Module): """Multi-layer Perceptron with residual connections. Args: input_size (int): Input size. hidden_size (int): Units in the hidden layers. output_size (int, optional): Size of the optional readout. n_layers (int, optional): Number of hidden layers. (default: 1) activation (str, optional): Activation function. (default: `relu`) dropout (float, optional): Dropout probability. (default: 0.) parametrized_skip (bool, optional): Whether to use parametrized skip connections for the residuals. """
[docs] def __init__(self, input_size, hidden_size, output_size=None, n_layers=1, activation='relu', dropout=0., parametrized_skip=False): super(ResidualMLP, self).__init__() self.layers = nn.ModuleList([ nn.Sequential( Dense(in_features=input_size if i == 0 else hidden_size, out_features=hidden_size, activation=activation, dropout=dropout), nn.Linear(hidden_size, hidden_size)) for i in range(n_layers) ]) self.skip_connections = nn.ModuleList() for i in range(n_layers): if i == 0 and input_size != output_size: self.skip_connections.append(nn.Linear(input_size, hidden_size)) elif parametrized_skip: self.skip_connections.append( nn.Linear(hidden_size, hidden_size)) else: self.skip_connections.append(nn.Identity()) if output_size is not None: self.readout = nn.Linear(hidden_size, output_size) else: self.register_parameter('readout', None)
def forward(self, x): """Forward pass with residual connections. Args: x (torch.Tensor): Input tensor of shape (..., input_size). Returns: torch.Tensor: Output tensor of shape (..., output_size) if readout is defined, else (..., hidden_size). Note: Each layer applies: x = layer(x) + skip(x), where skip is either Identity, a projection layer, or a parametrized transformation. """ for layer, skip in zip(self.layers, self.skip_connections): x = layer(x) + skip(x) if self.readout is not None: return self.readout(x) return x
[docs] class LinearEmbeddingEncoder(torch.nn.Module): """ Linear encoder that transforms embeddings into a set of embeddings. Applies a single linear projection from ``in_features`` to ``n_embeddings * out_features``, then unflattens the last dimension to ``(n_embeddings, out_features)``. Attributes: out_shape (Tuple[int, int]): Target shape used by ``nn.Unflatten``. encoder (nn.Sequential): ``Linear -> Unflatten`` encoder. Args: in_features (int): Number of input features. out_features (int): Feature dimension of each output embedding. n_embeddings (int, optional): Number of output embeddings. Defaults to ``1``. Example: >>> import torch >>> from torch_concepts.nn import LinearEmbeddingEncoder >>> >>> encoder = LinearEmbeddingEncoder( ... in_features=128, ... out_features=16, ... n_embeddings=5, ... ) >>> embeddings = torch.randn(4, 128) >>> out = encoder(embeddings) >>> out.shape torch.Size([4, 5, 16]) References: Espinosa Zarlenga et al. "Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off", NeurIPS 2022. https://arxiv.org/abs/2209.09056 """
[docs] def __init__( self, in_features: int, out_features: int, n_embeddings: int = 1, ): """ Initialize the linear embedding encoder. Args: in_features: Number of input features. out_features: Dimension of each output embedding. n_embeddings: Number of output embeddings. """ super().__init__() self.out_shape = (n_embeddings, out_features) self.encoder = torch.nn.Sequential( torch.nn.Linear( in_features, n_embeddings * out_features ), torch.nn.Unflatten(-1, self.out_shape) )
def forward(self, x: torch.Tensor): """ Encode into a set of embeddings. Args: x: Input tensor of shape ``(..., in_features)``. Returns: torch.Tensor: Embeddings of shape ``(..., n_embeddings, out_features)``. """ return self.encoder(x)
[docs] class SelectorEmbeddingEncoder(torch.nn.Module): """ Memory-based selector for embeddings with attention mechanism. This module maintains a learnable memory bank of embeddings and uses an attention mechanism to select relevant embeddings based on input. It supports both soft (weighted) and hard (Gumbel-softmax) selection. Attributes: temperature (float): Temperature for softmax/Gumbel-softmax. memory_size (int): Number of memory slots. out_features (int): Feature dimension of each output embedding. memory (nn.Embedding): Learnable memory bank. selector (nn.Sequential): Attention network for memory selection. Args: in_features: Number of input features. out_features: Feature dimension of each output embedding. n_embeddings: Number of output embeddings. Defaults to ``1``. memory_size: Number of memory slots. Defaults to ``2``. temperature: Temperature parameter for selection. Defaults to ``1.0``. *args: Additional arguments for the linear layer. **kwargs: Additional keyword arguments for the linear layer. Example: >>> import torch >>> from torch_concepts.nn import SelectorEmbeddingEncoder >>> >>> # Create memory selector >>> selector = SelectorEmbeddingEncoder( ... in_features=64, ... out_features=32, ... n_embeddings=5, ... memory_size=10, ... temperature=0.5 ... ) >>> >>> # Forward pass with soft selection >>> embeddings = torch.randn(4, 64) # batch_size=4 >>> selected = selector(embeddings, sampling=False) >>> print(selected.shape) torch.Size([4, 5, 32]) >>> >>> # Forward pass with hard selection (Gumbel-softmax) >>> selected_hard = selector(embeddings, sampling=True) >>> print(selected_hard.shape) torch.Size([4, 5, 32]) """
[docs] def __init__( self, in_features: int, out_features: int, n_embeddings: int = 1, memory_size: int = 2, temperature: float = 1.0, *args, **kwargs, ): """ Initialize the memory selector. Args: in_features: Number of input features. out_features: Feature dimension of each output embedding. n_embeddings: Number of output embeddings. Defaults to ``1``. memory_size: Number of memory slots. Defaults to ``2``. temperature: Temperature for selection. Defaults to ``1.0``. *args: Additional arguments for the linear layer. **kwargs: Additional keyword arguments for the linear layer. """ super().__init__() self.temperature = temperature self.memory_size = memory_size self.out_features = out_features self._selector_out_shape = (n_embeddings, memory_size) self._selector_out_dim = torch.tensor(self._selector_out_shape).prod().item() # init memory of embeddings [n_embeddings, memory_size * out_features] self.memory = torch.nn.Embedding(n_embeddings, memory_size * out_features) # init selector [B, n_embeddings] self.selector = torch.nn.Sequential( torch.nn.Linear(in_features, out_features), torch.nn.LeakyReLU(), torch.nn.Linear( out_features, self._selector_out_dim, *args, **kwargs, ), torch.nn.Unflatten(-1, self._selector_out_shape), )
def forward( self, x: torch.Tensor, sampling: bool = False, ) -> torch.Tensor: """ Select memory embeddings based on input. Computes attention weights over memory slots and returns a weighted combination of memory embeddings. Can use soft attention or hard selection via Gumbel-softmax. Args: x: Input tensor of shape ``(..., in_features)``. sampling: If True, use Gumbel-softmax for hard selection; if False, use soft attention (default: False). Returns: torch.Tensor: Selected embeddings of shape ``(..., n_embeddings, out_features)``. """ memory = self.memory.weight.view(-1, self.memory_size, self.out_features) mixing_coeff = self.selector(x) if sampling: mixing_probs = F.gumbel_softmax(mixing_coeff, dim=1, tau=self.temperature, hard=True) else: mixing_probs = torch.softmax(mixing_coeff / self.temperature, dim=1) embeddings = torch.einsum("btm,tme->bte", mixing_probs, memory) # [Batch x Task x Memory] x [Task x Memory x Emb] -> [Batch x Task x Emb] return embeddings