# Source code for torch_concepts.nn.modules.low.predictors.mix

import torch
import numpy as np

from torch_concepts import Annotations
from ..base.layer import BaseConceptLayer
from ....functional import grouped_concept_exogenous_mixture, replace_expand_cols
from typing import List, Union


class MixConceptEmbeddingToConcept(BaseConceptLayer):
    """
    Concept predictor that mixes concept activations with embeddings.

    This predictor implements the Concept Embedding Model (CEM) task predictor
    that combines concept activations with learned embeddings using a mixture
    operation. Binary concepts stored as a single column (cardinality 1 and a
    type other than ``'continuous'``) are expanded into a two-way categorical
    before mixing, so that every concept group contributes one mixed embedding.

    Main reference: "Concept Embedding Models: Beyond the
    Accuracy-Explainability Trade-Off" (Espinosa Zarlenga et al., NeurIPS 2022).

    Attributes:
        cardinalities_expanded (torch.Tensor): Per-group cardinalities taken
            from ``in_concepts``, with binary cardinality-1 groups set to 2.
        binary_mask (torch.Tensor): Boolean mask that is ``True`` for every
            concept whose type is not ``'continuous'``.
        mask_cardinality_1 (torch.Tensor): Start column, in the flat concept
            tensor, of each binary cardinality-1 concept.
        bernoulli_to_categorical_embedding_splitter (nn.Module): Maps one
            embedding vector to two (one per Bernoulli outcome).
        predictor (nn.Module): Linear predictor module.

    Args:
        in_concepts: Annotations describing the input concepts. Its
            ``cardinalities`` and ``types`` define the concept groups.
        in_embeddings: Number of embedding features per concept column.
        out_concepts: Number of output concepts.
        **kwargs: Accepted for interface compatibility; not used.

    Example:
        >>> import torch
        >>> from torch_concepts.nn import MixConceptEmbeddingToConcept
        >>> from torch_concepts import Annotations
        >>>
        >>> # Create predictor: 3 concepts (cardinalities 3, 4, 3), 10 embedding dims, 2 outputs
        >>> in_ann = Annotations(labels=['color', 'shape', 'size'], cardinalities=[3, 4, 3])
        >>> predictor = MixConceptEmbeddingToConcept(
        ...     in_concepts=in_ann,
        ...     in_embeddings=10,
        ...     out_concepts=2,
        ... )
        >>>
        >>> # Generate random inputs
        >>> concepts = torch.randn(4, 10)  # batch_size=4, total logits (3+4+3=10)
        >>> embeddings = torch.randn(4, 10, 10)  # (batch, total_cardinality, emb_size)
        >>>
        >>> # Forward pass
        >>> output = predictor(concepts=concepts, embeddings=embeddings)
        >>> print(output.shape)
        torch.Size([4, 2])

    References:
        Espinosa Zarlenga et al. "Concept Embedding Models: Beyond the
        Accuracy-Explainability Trade-Off", NeurIPS 2022.
        https://arxiv.org/abs/2209.09056
    """

    def __init__(
        self,
        in_concepts: Annotations,
        in_embeddings: Union[int, Annotations],
        out_concepts: Union[int, Annotations],
        **kwargs,
    ):
        super().__init__(
            in_concepts=in_concepts,
            in_embeddings=in_embeddings,
            out_concepts=out_concepts,
        )
        # NOTE(review): earlier docs said ``in_embeddings`` must be even; nothing
        # here enforces that — confirm against grouped_concept_exogenous_mixture.

        # Find the binary concepts stored as a single column (cardinality 1) so
        # that ``_mix`` can split them into a 2-way categorical.
        self.cardinalities_expanded = torch.tensor(in_concepts.cardinalities)
        self.binary_mask = torch.from_numpy(
            np.array(in_concepts.types) != 'continuous'
        )
        cumsum = torch.cumsum(self.cardinalities_expanded, dim=0)
        # Column index where each concept group starts in the flat tensor.
        start_positions = cumsum - self.cardinalities_expanded
        # Parenthesised on purpose: ``&`` binds tighter than ``==`` in Python,
        # so ``a == 1 & m`` would parse as ``a == (1 & m)``.
        bernoulli_mask = (self.cardinalities_expanded == 1) & self.binary_mask
        self.mask_cardinality_1 = start_positions[bernoulli_mask]
        # After the split, each such concept occupies two categorical slots.
        self.cardinalities_expanded[bernoulli_mask] = 2

        # (..., emb) -> (..., 2, emb): one learned embedding per Bernoulli outcome.
        self.bernoulli_to_categorical_embedding_splitter = torch.nn.Sequential(
            torch.nn.Linear(self.in_embeddings_shape, self.in_embeddings_shape * 2),
            torch.nn.LeakyReLU(),
            torch.nn.Unflatten(-1, (-1, self.in_embeddings_shape)),
        )
        # The predictor consumes one mixed embedding per concept group.
        self.predictor = torch.nn.Linear(
            self.in_embeddings_shape * len(in_concepts.cardinalities),
            self.out_concepts_shape,
        )

    def _mix(
        self,
        concepts: torch.Tensor,
        embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Preprocess inputs and compute per-group mixed embeddings.

        Handles the Bernoulli→Categorical expansion for cardinality-1 concepts
        and returns ``c_mix`` of shape ``(batch, n_groups, in_embeddings)``.
        Subclasses can call this and only vary the final aggregation step.

        Args:
            concepts: Concept activations, one column per concept entry.
            embeddings: Per-column concept embeddings, indexed along dim 1
                with the same column positions as ``concepts``.

        Returns:
            torch.Tensor: Mixed embeddings, one per concept group.
        """
        if len(self.mask_cardinality_1) > 0:
            # Each binary column c becomes the pair (c, 1 - c), and its
            # embedding is split into two learned embeddings to match.
            # NOTE(review): ``1 - c`` is a complement only if ``concepts`` holds
            # probabilities; the class example feeds ``torch.randn`` values
            # described as logits — confirm the expected input scale.
            embeddings_split = self.bernoulli_to_categorical_embedding_splitter(
                embeddings[:, self.mask_cardinality_1]
            )
            concepts_split = torch.cat(
                [
                    concepts[:, self.mask_cardinality_1[:, None]],
                    1 - concepts[:, self.mask_cardinality_1[:, None]],
                ],
                dim=-1,
            )
            embeddings = replace_expand_cols(
                embeddings, self.mask_cardinality_1, embeddings_split
            )
            concepts = replace_expand_cols(
                concepts, self.mask_cardinality_1, concepts_split
            )
        return grouped_concept_exogenous_mixture(
            embeddings,
            concepts,
            groups=list(self.cardinalities_expanded),
        )

    def forward(
        self,
        concepts: torch.Tensor,
        embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass through the predictor.

        Args:
            concepts: Concept activations of shape
                ``(batch_size, sum(in_concepts.cardinalities))``.
            embeddings: Concept embeddings of shape
                ``(batch_size, sum(in_concepts.cardinalities), in_embeddings)``.

        Returns:
            torch.Tensor: Output concepts of shape ``(batch_size, out_concepts)``.
        """
        c_mix = self._mix(concepts, embeddings)  # (batch, n_groups, in_embeddings)
        c_mix = c_mix.flatten(start_dim=1)  # (batch, n_groups * in_embeddings)
        return self.predictor(c_mix)
class MixSumConceptEmbeddingToConcept(MixConceptEmbeddingToConcept):
    """Like :class:`MixConceptEmbeddingToConcept` but aggregates group
    embeddings by **summing** across groups instead of flattening.

    The predictor therefore maps ``(batch, in_embeddings)`` →
    ``(batch, out_concepts)`` rather than
    ``(batch, n_groups × in_embeddings)`` → ``(batch, out_concepts)``,
    which makes it group-count invariant and more parameter-efficient.

    Args:
        in_concepts: Number of concept columns. Used only to build the default
            ``cardinalities`` (all binary) when none are given.
        in_embeddings: Number of embedding features per concept column.
        out_concepts: Number of output concepts.
        cardinalities: Optional per-group cardinalities. Defaults to
            ``[1] * in_concepts``. Groups of cardinality 1 are annotated as
            ``'binary'``; all others as ``'categorical'``.
        bias: Whether the output linear layer has a bias term.
        **kwargs: Forwarded to the parent constructor.
    """

    def __init__(
        self,
        in_concepts: int,
        in_embeddings: int,
        out_concepts: int,
        cardinalities: Union[List[int], None] = None,
        bias: bool = True,
        **kwargs,
    ):
        # FIXME: update to use Annotations for in_concepts and out_concepts,
        # and remove cardinalities
        if cardinalities is None:
            cardinalities = [1] * in_concepts
        else:
            # Copy so any sequence is accepted and the caller's list is not
            # shared with the Annotations object built below.
            cardinalities = list(cardinalities)
        n_groups = len(cardinalities)
        types = ['binary' if c == 1 else 'categorical' for c in cardinalities]
        annotation = Annotations(
            labels=[f"c{i}" for i in range(n_groups)],
            cardinalities=cardinalities,
            types=types,
        )
        super().__init__(
            in_concepts=annotation,
            in_embeddings=in_embeddings,
            out_concepts=out_concepts,
            **kwargs,
        )
        # Replace the parent's flatten-based predictor: after summing over
        # groups the input is a single ``in_embeddings``-sized vector.
        self.predictor = torch.nn.Linear(in_embeddings, out_concepts, bias=bias)

    def forward(self, concepts: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        """Mix per-group embeddings, sum them over groups, and predict.

        Args:
            concepts: Concept activations, as accepted by the parent ``_mix``.
            embeddings: Per-column concept embeddings, as accepted by the
                parent ``_mix``.

        Returns:
            torch.Tensor: Output of shape ``(batch, out_concepts)``.
        """
        c_mix = self._mix(concepts, embeddings)  # same as CEM-layer (batch, n_groups, in_embeddings)
        c_mix = c_mix.sum(dim=1)  # (batch, in_embeddings)
        return self.predictor(c_mix)