Source code for torch_concepts.nn.modules.low.dense_layers

"""Simple fully-connected neural network layers.

This module provides Dense, MLP, and ResidualMLP layers adapted from the 
torch-spatiotemporal library. These layers serve as building blocks for 
neural network architectures in concept-based models.

Reference: https://torch-spatiotemporal.readthedocs.io/en/latest/
"""

from torch import nn


_torch_activations_dict = {
    'elu': 'ELU',
    'leaky_relu': 'LeakyReLU',
    'prelu': 'PReLU',
    'relu': 'ReLU',
    'rrelu': 'RReLU',
    'selu': 'SELU',
    'celu': 'CELU',
    'gelu': 'GELU',
    'glu': 'GLU',
    'mish': 'Mish',
    'sigmoid': 'Sigmoid',
    'softplus': 'Softplus',
    'tanh': 'Tanh',
    'silu': 'SiLU',
    'swish': 'SiLU',
    'linear': 'Identity'
}

def get_layer_activation(activation):
    """Get PyTorch activation layer class from string name.
    
    Args:
        activation (str or None): Activation function name (case-insensitive).
            Supported: 'elu', 'leaky_relu', 'prelu', 'relu', 'rrelu', 'selu',
            'celu', 'gelu', 'glu', 'mish', 'sigmoid', 'softplus', 'tanh', 
            'silu', 'swish', 'linear'. None returns Identity.
    
    Returns:
        torch.nn.Module: Activation layer class (uninstantiated).
        
    Raises:
        ValueError: If activation name is not recognized.
        
    Example:
        >>> from torch_concepts.nn.modules.low.dense_layers import get_layer_activation
        >>> act_class = get_layer_activation('relu')
        >>> activation = act_class()  # ReLU()
        >>> act_class = get_layer_activation(None)
        >>> activation = act_class()  # Identity()
    """
    if activation is None:
        return nn.Identity
    activation = activation.lower()
    if activation in _torch_activations_dict:
        return getattr(nn, _torch_activations_dict[activation])
    raise ValueError(f"Activation '{activation}' not valid.")



[docs] class Dense(nn.Module): r"""A simple fully-connected layer implementing .. math:: \mathbf{x}^{\prime} = \sigma\left(\boldsymbol{\Theta}\mathbf{x} + \mathbf{b}\right) where :math:`\mathbf{x} \in \mathbb{R}^{d_{in}}, \mathbf{x}^{\prime} \in \mathbb{R}^{d_{out}}` are the input and output features, respectively, :math:`\boldsymbol{\Theta} \in \mathbb{R}^{d_{out} \times d_{in}} \mathbf{b} \in \mathbb{R}^{d_{out}}` are trainable parameters, and :math:`\sigma` is an activation function. Args: input_size (int): Number of input features. output_size (int): Number of output features. activation (str, optional): Activation function to be used. (default: :obj:`'relu'`) dropout (float, optional): The dropout rate. (default: :obj:`0`) bias (bool, optional): If :obj:`True`, then the bias vector is used. (default: :obj:`True`) """
[docs] def __init__(self, input_size: int, output_size: int, activation: str = 'relu', dropout: float = 0., bias: bool = True): super(Dense, self).__init__() self.affinity = nn.Linear(input_size, output_size, bias=bias) self.activation = get_layer_activation(activation)() self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
[docs] def reset_parameters(self) -> None: """Reset layer parameters to initial random values.""" self.affinity.reset_parameters()
[docs] def forward(self, x): """Apply linear transformation, activation, and dropout. Args: x (torch.Tensor): Input tensor of shape (batch_size, input_size). Returns: torch.Tensor: Output tensor of shape (batch_size, output_size). """ out = self.activation(self.affinity(x)) return self.dropout(out)
[docs] class MLP(nn.Module): """Simple Multi-layer Perceptron encoder with optional linear readout. Args: input_size (int): Input size. hidden_size (int): Units in the hidden layers. output_size (int, optional): Size of the optional readout. n_layers (int, optional): Number of hidden layers. (default: 1) activation (str, optional): Activation function. (default: `relu`) dropout (float, optional): Dropout probability. """
[docs] def __init__(self, input_size, hidden_size=64, output_size=None, n_layers=1, activation='relu', dropout=0.): super(MLP, self).__init__() layers = [ Dense(input_size=input_size if i == 0 else hidden_size, output_size=hidden_size, activation=activation, dropout=dropout) for i in range(n_layers) ] self.mlp = nn.Sequential(*layers) if output_size is not None: self.readout = nn.Linear(hidden_size, output_size) else: self.register_parameter('readout', None)
[docs] def reset_parameters(self) -> None: """Reset all layer parameters to initial random values.""" for module in self.mlp._modules.values(): module.reset_parameters() if self.readout is not None: self.readout.reset_parameters()
[docs] def forward(self, x): """Forward pass through MLP layers with optional readout. Args: x (torch.Tensor): Input tensor of shape (batch_size, input_size). Returns: torch.Tensor: Output tensor of shape (batch_size, output_size) if readout is defined, else (batch_size, hidden_size). """ out = self.mlp(x) if self.readout is not None: return self.readout(out) return out
[docs] class ResidualMLP(nn.Module): """Multi-layer Perceptron with residual connections. Args: input_size (int): Input size. hidden_size (int): Units in the hidden layers. output_size (int, optional): Size of the optional readout. n_layers (int, optional): Number of hidden layers. (default: 1) activation (str, optional): Activation function. (default: `relu`) dropout (float, optional): Dropout probability. (default: 0.) parametrized_skip (bool, optional): Whether to use parametrized skip connections for the residuals. """
[docs] def __init__(self, input_size, hidden_size, output_size=None, n_layers=1, activation='relu', dropout=0., parametrized_skip=False): super(ResidualMLP, self).__init__() self.layers = nn.ModuleList([ nn.Sequential( Dense(input_size=input_size if i == 0 else hidden_size, output_size=hidden_size, activation=activation, dropout=dropout), nn.Linear(hidden_size, hidden_size)) for i in range(n_layers) ]) self.skip_connections = nn.ModuleList() for i in range(n_layers): if i == 0 and input_size != output_size: self.skip_connections.append(nn.Linear(input_size, hidden_size)) elif parametrized_skip: self.skip_connections.append( nn.Linear(hidden_size, hidden_size)) else: self.skip_connections.append(nn.Identity()) if output_size is not None: self.readout = nn.Linear(hidden_size, output_size) else: self.register_parameter('readout', None)
[docs] def forward(self, x): """Forward pass with residual connections. Args: x (torch.Tensor): Input tensor of shape (batch_size, input_size). Returns: torch.Tensor: Output tensor of shape (batch_size, output_size) if readout is defined, else (batch_size, hidden_size). Note: Each layer applies: x = layer(x) + skip(x), where skip is either Identity, a projection layer, or a parametrized transformation. """ for layer, skip in zip(self.layers, self.skip_connections): x = layer(x) + skip(x) if self.readout is not None: return self.readout(x) return x