torch_concepts.nn.functional.grouped_concept_exogenous_mixture

grouped_concept_exogenous_mixture(c_emb: Tensor, c_scores: Tensor, groups: list[int]) → Tensor

Vectorized implementation of the grouped concept exogenous mixture.

Extends the concept exogenous mixture to grouped concepts, where a group may contain several related concepts. Adapted from “Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off” (Espinosa Zarlenga et al., 2022).
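In the original CEM formulation, each concept i has a positive embedding c_i^+ and a negative embedding c_i^-, which are blended using the predicted concept probability p_i. Assuming this function stores the two embeddings as the two halves of each emb_size-dimensional vector (an assumption about the memory layout, consistent with the emb_size // 2 output width), the singleton-group case would read:

```latex
\hat{\mathbf{c}}_i = p_i\,\mathbf{c}_i^{+} + (1 - p_i)\,\mathbf{c}_i^{-},
\qquad \mathbf{c}_i^{+},\ \mathbf{c}_i^{-} \in \mathbb{R}^{\mathrm{emb\_size}/2}
```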

Parameters:
  • c_emb – Concept exogenous embeddings of shape (B, n_concepts, emb_size), where n_concepts = sum(groups) and emb_size must be even.

  • c_scores – Concept scores of shape (B, sum(groups)).

  • groups – List of group sizes (e.g., [3, 4] for two groups).
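The groups list partitions the sum(groups) score columns into consecutive blocks, one per group. A minimal sketch of that partitioning, assuming concepts are laid out contiguously in group order (an assumption, not stated by this page), uses torch.split with a list of section sizes:

```python
import torch

# Shapes matching the example below: 10 concepts split into 3 groups.
# Assumption: the concepts of each group occupy consecutive columns.
groups = [3, 4, 3]
c_scores = torch.rand(4, sum(groups))  # (B, sum(groups))

# torch.split accepts a list of section sizes along a dimension and
# returns one (B, group_size) view per group.
score_groups = torch.split(c_scores, groups, dim=1)

for size, g in zip(groups, score_groups):
    print(size, tuple(g.shape))
# prints:
# 3 (4, 3)
# 4 (4, 4)
# 3 (4, 3)
```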

Returns:

Mixed exogenous embeddings, one per group, of shape (B, len(groups), emb_size // 2).

Return type:

Tensor

References

Espinosa Zarlenga et al. “Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off”, NeurIPS 2022. https://arxiv.org/abs/2209.09056

Example

>>> import torch
>>> from torch_concepts.nn.functional import grouped_concept_exogenous_mixture
>>>
>>> # 10 concepts in 3 groups: [3, 4, 3]
>>> # Embedding size = 20 (must be even)
>>> batch_size = 4
>>> n_concepts = 10
>>> emb_size = 20
>>> groups = [3, 4, 3]
>>>
>>> # Generate random embeddings and scores
>>> c_emb = torch.randn(batch_size, n_concepts, emb_size)
>>> c_scores = torch.rand(batch_size, n_concepts)  # Probabilities
>>>
>>> # Apply grouped mixture
>>> mixed = grouped_concept_exogenous_mixture(c_emb, c_scores, groups)
>>> print(mixed.shape)  # (batch_size, n_groups, emb_size // 2)
torch.Size([4, 3, 10])
>>>
>>> # Singleton groups use the two-half mixture;
>>> # multi-concept groups use a weighted average of the base exogenous embeddings.
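The two branches named in the example's closing comments can be sketched as a loop over groups. This is a minimal, non-vectorized reference written under stated assumptions, not the library's implementation: the name grouped_mixture_sketch is hypothetical; the first half of each embedding is assumed to be the positive part and the second half the negative part; and "weighted average of base exogenous" is interpreted, as a guess, as averaging the positive halves of a group with weights given by the scores normalized within that group.

```python
import torch
from torch import Tensor


def grouped_mixture_sketch(c_emb: Tensor, c_scores: Tensor, groups: list[int]) -> Tensor:
    """Hypothetical loop-based reference; NOT the library's implementation.

    Assumptions: concepts are stored contiguously in group order; the first
    half of each embedding is the positive part and the second half the
    negative part; multi-concept groups average the positive halves, weighted
    by the scores normalized within the group.
    """
    B, n_concepts, emb_size = c_emb.shape
    if emb_size % 2 != 0:
        raise ValueError("emb_size must be even")
    if sum(groups) != n_concepts or c_scores.shape != (B, n_concepts):
        raise ValueError("sum(groups) must match n_concepts and c_scores")
    half = emb_size // 2
    pos, neg = c_emb[..., :half], c_emb[..., half:]  # each (B, n_concepts, half)

    outputs = []
    start = 0
    for size in groups:
        end = start + size
        p = c_scores[:, start:end]  # (B, size)
        if size == 1:
            # Two-half CEM mixture: p * c+ + (1 - p) * c-
            mixed = p * pos[:, start] + (1 - p) * neg[:, start]
        else:
            # Assumed weighting: scores normalized to sum to 1 within the group
            w = p / p.sum(dim=1, keepdim=True).clamp(min=1e-8)
            mixed = (w.unsqueeze(-1) * pos[:, start:end]).sum(dim=1)
        outputs.append(mixed)  # (B, half)
        start = end
    return torch.stack(outputs, dim=1)  # (B, len(groups), half)


# Usage with the same shapes as the example above
torch.manual_seed(0)
out = grouped_mixture_sketch(torch.randn(4, 10, 20), torch.rand(4, 10), [3, 4, 3])
print(out.shape)  # torch.Size([4, 3, 10])
```

The loop makes the per-group branching explicit at the cost of speed; a vectorized version would instead handle all groups of the same kind at once.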