torch_concepts.nn.functional.grouped_concept_exogenous_mixture
- grouped_concept_exogenous_mixture(c_emb: Tensor, c_scores: Tensor, groups: list[int]) → Tensor
Vectorized version of the grouped concept exogenous mixture.
Extends the concept exogenous mixture to grouped concepts, where some groups contain several related concepts instead of a single one. Adapted from “Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off” (Espinosa Zarlenga et al., 2022).
- Parameters:
c_emb – Concept exogenous embeddings of shape (B, n_concepts, emb_size).
c_scores – Concept scores (probabilities) of shape (B, sum(groups)); sum(groups) must equal n_concepts.
groups – List of group sizes (e.g., [3, 4] for two groups of 3 and 4 concepts).
- Returns:
Mixed exogenous embeddings of shape (B, len(groups), emb_size // 2).
- Return type:
Tensor
- Raises:
AssertionError – If the group sizes don’t sum to n_concepts.
AssertionError – If the exogenous embedding size (emb_size) is not even.
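As a rough illustration of the shape contract and checks above, here is a loop-based PyTorch sketch. It is not the library's vectorized implementation: the name `grouped_mixture_sketch` is hypothetical, and two details are assumptions rather than facts from this page, namely which half of each embedding is the "active" part, and how multi-concept groups weight their members (here, scores normalized to sum to 1 within the group).

```python
import torch
from torch import Tensor


def grouped_mixture_sketch(c_emb: Tensor, c_scores: Tensor, groups: list[int]) -> Tensor:
    # Hypothetical loop-based sketch, not the library implementation.
    # ASSUMPTIONS: the first half of each embedding is the "active" part and the
    # second half the "inactive" part; multi-concept groups weight the active
    # halves by their scores normalized to sum to 1 within the group.
    _, n_concepts, emb_size = c_emb.shape
    assert sum(groups) == n_concepts, "group sizes must sum to n_concepts"
    assert emb_size % 2 == 0, "emb_size must be even"
    half = emb_size // 2

    outputs = []
    start = 0
    for size in groups:
        emb = c_emb[:, start:start + size]        # (B, size, emb_size)
        scores = c_scores[:, start:start + size]  # (B, size)
        if size == 1:
            # Two-half mixture: s * active + (1 - s) * inactive.
            # scores has shape (B, 1) and broadcasts over the half dimension.
            mixed = scores * emb[:, 0, :half] + (1 - scores) * emb[:, 0, half:]
        else:
            # Assumed weighting: normalize scores within the group,
            # then take the weighted average of the active halves.
            weights = scores / scores.sum(dim=1, keepdim=True).clamp_min(1e-8)
            mixed = (weights.unsqueeze(-1) * emb[:, :, :half]).sum(dim=1)
        outputs.append(mixed)                     # (B, half)
        start += size
    return torch.stack(outputs, dim=1)            # (B, len(groups), half)


# Same shapes as in the Example section.
c_emb = torch.randn(4, 10, 20)
c_scores = torch.rand(4, 10)
print(grouped_mixture_sketch(c_emb, c_scores, [3, 4, 3]).shape)  # torch.Size([4, 3, 10])
```

The Python loop over groups keeps the logic easy to follow; since the library function is described as vectorized, expect it to avoid a per-group loop like this one.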
References
Espinosa Zarlenga et al. “Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off”, NeurIPS 2022. https://arxiv.org/abs/2209.09056
Example
>>> import torch
>>> from torch_concepts.nn.functional import grouped_concept_exogenous_mixture
>>>
>>> # 10 concepts in 3 groups: [3, 4, 3]
>>> # Embedding size = 20 (must be even)
>>> batch_size = 4
>>> n_concepts = 10
>>> emb_size = 20
>>> groups = [3, 4, 3]
>>>
>>> # Generate random embeddings and scores
>>> c_emb = torch.randn(batch_size, n_concepts, emb_size)
>>> c_scores = torch.rand(batch_size, n_concepts)  # Probabilities
>>>
>>> # Apply grouped mixture
>>> mixed = grouped_concept_exogenous_mixture(c_emb, c_scores, groups)
>>> print(mixed.shape)
torch.Size([4, 3, 10])
>>> # Output shape: (batch_size, n_groups, emb_size // 2)
>>>
>>> # Singleton groups use the two-half mixture
>>> # Multi-concept groups use a weighted average of the base exogenous embeddings
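To make the singleton case concrete, the following pure-Python sketch works through the two-half mixture for a single concept and a single sample. The helper `two_half_mixture` is hypothetical, and it assumes (by analogy with CEM's positive/negative embeddings) that the first half is weighted by the score and the second half by one minus the score; the library's exact layout may differ.

```python
def two_half_mixture(embedding: list[float], score: float) -> list[float]:
    # Hypothetical helper, not part of torch_concepts.
    # ASSUMPTION: first half = "active" embedding, second half = "inactive" one.
    assert len(embedding) % 2 == 0, "embedding size must be even"
    half = len(embedding) // 2
    active, inactive = embedding[:half], embedding[half:]
    # Output has half the input size: score * active + (1 - score) * inactive
    return [score * a + (1 - score) * b for a, b in zip(active, inactive)]


# emb_size = 4 -> output size 2; 0.25 * [1, 2] + 0.75 * [3, 4]
print(two_half_mixture([1.0, 2.0, 3.0, 4.0], 0.25))  # [2.5, 3.5]
```

A score of 1 returns the first half unchanged and a score of 0 returns the second half, which is why the output size is emb_size // 2.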