# Source code for torch_concepts.nn.modules.low.encoders.linear

"""
Linear encoder modules for concept prediction from embeddings.

These modules provide encoder layers that transform embeddings into concept representations.
"""
from typing import Union

import torch

from torch_concepts import Annotations
from ..base.layer import BaseConceptLayer


class LinearEmbeddingToConcept(BaseConceptLayer):
    """
    Linear encoder that maps embeddings to concept representations.

    The whole encoder is a single ``torch.nn.Linear`` layer applied to the
    last dimension of the input embeddings.

    Attributes:
        encoder (torch.nn.Linear): The projection layer. Its input size is
            ``self.in_embeddings_shape`` and its output size is
            ``self.out_concepts_shape``.

    Args:
        in_embeddings: Size of the input embedding dimension, as an int or an
            ``Annotations`` object. It is forwarded unchanged to
            ``BaseConceptLayer``.
        out_concepts: Size of the concept output, as an int or an
            ``Annotations`` object. It is forwarded unchanged to
            ``BaseConceptLayer``.
        *args: Extra positional arguments for ``torch.nn.Linear``, placed
            after ``in_features`` and ``out_features``.
        **kwargs: Extra keyword arguments for ``torch.nn.Linear`` (for
            example ``bias``, ``device`` or ``dtype``).

    Example:
        >>> import torch
        >>> from torch_concepts.nn import LinearEmbeddingToConcept
        >>>
        >>> encoder = LinearEmbeddingToConcept(
        ...     in_embeddings=128,
        ...     out_concepts=10
        ... )
        >>> embeddings = torch.randn(4, 128)  # batch_size=4, embedding_dim=128
        >>> concepts = encoder(embeddings)
        >>> print(concepts.shape)
        torch.Size([4, 10])

    References:
        Koh et al. "Concept Bottleneck Models", ICML 2020.
        https://arxiv.org/pdf/2007.04612
    """

    def __init__(
        self,
        in_embeddings: Union[int, Annotations],
        out_concepts: Union[int, Annotations],
        *args,
        **kwargs,
    ):
        """
        Register the sizes with the base layer and build the projection.

        Args:
            in_embeddings: Input embedding size (int or ``Annotations``).
            out_concepts: Output concept size (int or ``Annotations``).
            *args: Positional arguments forwarded to ``torch.nn.Linear``.
            **kwargs: Keyword arguments forwarded to ``torch.nn.Linear``.
        """
        super().__init__(
            in_embeddings=in_embeddings,
            out_concepts=out_concepts,
        )

        # This class never assigns these two attributes. Presumably
        # BaseConceptLayer.__init__ derives them from the arguments above
        # (not visible here; confirm against ..base.layer).
        in_features = self.in_embeddings_shape
        out_features = self.out_concepts_shape

        # Shape contract: (..., in_features) -> (..., out_features).
        self.encoder = torch.nn.Linear(in_features, out_features, *args, **kwargs)

    def forward(
        self,
        embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Project embeddings into the concept space.

        Args:
            embeddings: Tensor whose last dimension equals the encoder's
                input size. ``torch.nn.Linear`` keeps any leading dimensions.

        Returns:
            torch.Tensor: Concept representations of shape
            (..., out_concepts).
        """
        concept_repr = self.encoder(embeddings)
        return concept_repr