Source code for torch_concepts.nn.modules.low.encoders.selector

"""
Memory selector module.

This module provides a memory-based selector that learns to attend over
a learnable memory bank of exogenous embeddings, one bank per concept.
"""
import numpy as np
import torch
import torch.nn.functional as F


from ..base.layer import BaseEncoder


class SelectorZU(BaseEncoder):
    """
    Memory-based selector for concept exogenous with attention mechanism.

    This module maintains a learnable memory bank of exogenous embeddings
    (``memory_size`` slots for each of the ``out_features`` concepts) and
    uses an attention network to pick, per concept, a combination of its
    memory slots based on the input. It supports both soft (weighted) and
    hard (Gumbel-softmax) selection.

    Attributes:
        temperature (float): Temperature for softmax/Gumbel-softmax.
        memory_size (int): Number of memory slots per concept.
        exogenous_size (int): Dimension of each memory exogenous.
        memory (nn.Embedding): Learnable memory bank with weight of shape
            (out_features, memory_size * exogenous_size).
        selector (nn.Sequential): Attention network mapping the input to
            logits of shape (batch_size, out_features, memory_size).

    Args:
        in_features: Number of input latent features.
        memory_size: Number of memory slots per concept.
        exogenous_size: Dimension of each memory exogenous.
        out_features: Number of output concepts.
        temperature: Temperature parameter for selection (default: 1.0).
        *args: Additional positional arguments forwarded to the last
            linear layer of the selector network.
        **kwargs: Additional keyword arguments forwarded to the last
            linear layer of the selector network.

    Example:
        >>> import torch
        >>> from torch_concepts.nn import SelectorZU
        >>>
        >>> # Create memory selector
        >>> selector = SelectorZU(
        ...     in_features=64,
        ...     memory_size=10,
        ...     exogenous_size=32,
        ...     out_features=5,
        ...     temperature=0.5
        ... )
        >>>
        >>> # Forward pass with soft selection
        >>> latent = torch.randn(4, 64)  # batch_size=4
        >>> selected = selector(latent, sampling=False)
        >>> print(selected.shape)
        torch.Size([4, 5, 32])
        >>>
        >>> # Forward pass with hard selection (Gumbel-softmax)
        >>> selected_hard = selector(latent, sampling=True)
        >>> print(selected_hard.shape)
        torch.Size([4, 5, 32])

    References:
        Debot et al. "Interpretable Concept-Based Memory Reasoning",
        NeurIPS 2024. https://arxiv.org/abs/2407.15527
    """

    def __init__(
        self,
        in_features: int,
        memory_size: int,
        exogenous_size: int,
        out_features: int,
        temperature: float = 1.0,
        *args,
        **kwargs,
    ):
        """
        Initialize the memory selector.

        Args:
            in_features: Number of input latent features.
            memory_size: Number of memory slots per concept.
            exogenous_size: Dimension of each memory exogenous.
            out_features: Number of output concepts.
            temperature: Temperature for selection (default: 1.0).
            *args: Additional positional arguments forwarded to the last
                linear layer of the selector network.
            **kwargs: Additional keyword arguments forwarded to the last
                linear layer of the selector network.
        """
        super().__init__(
            in_features=in_features,
            out_features=out_features,
        )
        self.temperature = temperature
        self.memory_size = memory_size
        self.exogenous_size = exogenous_size

        self._annotation_out_features = out_features
        self._exogenous_out_features = memory_size * exogenous_size
        self._selector_out_shape = (self._annotation_out_features, memory_size)
        # Flat width of the selector head: one logit per (concept, slot) pair.
        self._selector_out_features = out_features * memory_size

        # Memory of exogenous: [out_features, memory_size * exogenous_size].
        # Stored flat per concept; forward() views it as
        # [out_features, memory_size, exogenous_size].
        self.memory = torch.nn.Embedding(
            self._annotation_out_features,
            self._exogenous_out_features,
        )

        # Selector: [B, in_features] -> [B, out_features, memory_size] logits.
        self.selector = torch.nn.Sequential(
            torch.nn.Linear(in_features, exogenous_size),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(
                exogenous_size,
                self._selector_out_features,
                *args,
                **kwargs,
            ),
            torch.nn.Unflatten(-1, self._selector_out_shape),
        )

    def forward(
        self,
        input: torch.Tensor = None,
        sampling: bool = False,
    ) -> torch.Tensor:
        """
        Select memory exogenous based on the input.

        Computes, for every concept, attention weights over that concept's
        memory slots and returns the weighted combination of the slots.
        Can use soft attention or hard selection via Gumbel-softmax.

        Args:
            input: Input latent of shape (batch_size, in_features). The
                ``None`` default exists only in the signature; a tensor
                must be supplied for the selector network to run.
            sampling: If True, use Gumbel-softmax for hard selection; if
                False, use soft attention (default: False).

        Returns:
            torch.Tensor: Selected exogenous of shape
                (batch_size, out_features, exogenous_size).
        """
        # [Task x Memory x Emb]
        memory = self.memory.weight.view(
            -1, self.memory_size, self.exogenous_size
        )
        # [Batch x Task x Memory] logits
        mixing_coeff = self.selector(input)

        # Normalise over the memory axis (last dim) so that each concept's
        # weights over its own slots sum to one. Normalising over dim=1
        # would instead make concepts compete for each slot, which does not
        # match the contraction over `m` below.
        if sampling:
            mixing_probs = F.gumbel_softmax(
                mixing_coeff, tau=self.temperature, hard=True, dim=-1
            )
        else:
            mixing_probs = torch.softmax(
                mixing_coeff / self.temperature, dim=-1
            )

        # [Batch x Task x Memory] x [Task x Memory x Emb]
        #   -> [Batch x Task x Emb]
        return torch.einsum("btm,tme->bte", mixing_probs, memory)