Source code for torch_concepts.nn.modules.utils

from typing import Optional, Dict, Union, List, Any
import warnings
import logging
import torch

from ...annotations import Annotations

logger = logging.getLogger(__name__)

[docs] class GroupConfig: """Container for storing classes organized by concept type groups. This class acts as a convenient wrapper around a dictionary that maps concept type names to their corresponding classes or configurations. Attributes: _config (Dict[str, Any]): Internal dictionary storing the configuration. Args: binary: Configuration for binary concepts. If provided alone, applies to all concept types. categorical: Configuration for categorical concepts. continuous: Configuration for continuous concepts. **kwargs: Additional group configurations. Example: >>> from torch_concepts.nn.modules.utils import GroupConfig >>> from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss >>> loss_config = GroupConfig(binary=CrossEntropyLoss()) >>> # Equivalent to: {'binary': CrossEntropyLoss()} >>> >>> # Different configurations per type >>> loss_config = GroupConfig( ... binary=BCEWithLogitsLoss(), ... categorical=CrossEntropyLoss(), ... continuous=MSELoss() ... ) >>> >>> # Access configurations >>> default_loss = MSELoss() >>> binary_loss = loss_config['binary'] >>> loss_config.get('continuous', default_loss) MSELoss() >>> >>> # Check what's configured >>> 'binary' in loss_config True >>> list(loss_config.keys()) ['binary', 'categorical', 'continuous'] """
[docs] def __init__( self, binary: Optional[Any] = None, categorical: Optional[Any] = None, continuous: Optional[Any] = None, **kwargs ): self._config: Dict[str, Any] = {} # Build config from all provided arguments if binary is not None: self._config['binary'] = binary if categorical is not None: self._config['categorical'] = categorical if continuous is not None: self._config['continuous'] = continuous # Add any additional groups self._config.update(kwargs)
def __getitem__(self, key: str) -> Any: """Get configuration for a specific group.""" return self._config[key] def __setitem__(self, key: str, value: Any) -> None: """Set configuration for a specific group.""" self._config[key] = value def __contains__(self, key: str) -> bool: """Check if a group is configured.""" return key in self._config def __len__(self) -> int: """Return number of configured groups.""" return len(self._config) def __repr__(self) -> str: """String representation.""" return f"GroupConfig({self._config})" def get(self, key: str, default: Any = None) -> Any: """Get configuration for a group with optional default.""" return self._config.get(key, default) def keys(self): """Return configured group names.""" return self._config.keys() def values(self): """Return configured values.""" return self._config.values() def items(self): """Return (group, config) pairs.""" return self._config.items() def to_dict(self) -> Dict[str, Any]: """Convert to plain dictionary.""" return self._config.copy() @classmethod def from_dict(cls, config_dict: Dict[str, Any]) -> 'GroupConfig': """Create GroupConfig from dictionary. Args: config_dict: Dictionary mapping group names to configurations. Returns: GroupConfig instance. """ return cls(**config_dict)
def check_collection(annotations: Annotations, collection: GroupConfig, collection_name: str) -> GroupConfig: """Validate loss/metric configurations against concept annotations. Ensures that: 1. Required losses/metrics are present for each concept type 2. Annotation structure (nested vs dense) matches concept types 3. Unused configurations are warned about Args: annotations (Annotations): Concept annotations with metadata. collection (GroupConfig): Configuration object with losses or metrics. collection_name (str): Either 'loss' or 'metrics' for error messages. Returns: GroupConfig: Filtered configuration containing only the needed concept types. Raises: ValueError: If validation fails (missing required configs, incompatible annotation structure). Example: >>> from torch_concepts.nn.modules.utils import GroupConfig, check_collection >>> from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss >>> from torch_concepts import Annotations >>> loss_config = GroupConfig( ... binary=BCEWithLogitsLoss(), ... categorical=CrossEntropyLoss() ... ) >>> concept_annotations = Annotations( ... labels=['c1', 'c2', 'c3'], ... cardinalities=[1, 3, 2], ... types=['binary', 'categorical', 'categorical'], ... ) >>> filtered_config = check_collection( ... concept_annotations, ... loss_config, ... 'loss' ... ) """ assert collection_name in ['loss', 'metrics'], ( f"collection_name must be 'loss' or 'metrics', got '{collection_name}'" ) # Use cached type_groups from Annotations groups = annotations.type_groups has_binary = len(groups['binary']['labels']) > 0 has_categorical = len(groups['categorical']['labels']) > 0 has_continuous = len(groups['continuous']['labels']) > 0 # Raise error if continuous concepts are present if has_continuous: raise NotImplementedError("Continuous concepts not yet implemented.") # Extract items from collection binary = collection.get('binary') categorical = collection.get('categorical') continuous = collection.get('continuous') # Validation rules errors = [] # Check nested/dense compatibility all_binary = has_binary and not has_categorical and not has_continuous all_categorical = has_categorical and not has_binary and not has_continuous all_continuous = has_continuous and not has_binary and not has_categorical if all_binary: if annotations.is_nested: errors.append("Annotations for all-binary concepts should NOT be nested.") elif all_categorical: if not annotations.is_nested: errors.append("Annotations for all-categorical concepts should be nested.") elif all_continuous: if annotations.is_nested: errors.append("Annotations for all-continuous concepts should NOT be nested.") elif has_binary or has_categorical: if not annotations.is_nested: errors.append("Annotations for mixed concepts should be nested.") # Check required items are present if has_binary and binary is None: errors.append(f"{collection_name} missing 'binary' for binary concepts.") if has_categorical and categorical is None: errors.append(f"{collection_name} missing 'categorical' for categorical concepts.") if has_continuous and continuous is None: errors.append(f"{collection_name} missing 'continuous' for continuous concepts.") if errors: raise ValueError(f"{collection_name} validation failed:\n" + "\n".join(f" - {e}" for e in errors)) # Warnings for unused items if not has_binary and binary is not None: warnings.warn(f"Binary {collection_name} will be ignored (no binary concepts).") if not has_categorical and categorical is not None: warnings.warn(f"Categorical {collection_name} will be ignored (no categorical concepts).") if not has_continuous and continuous is not None: warnings.warn(f"Continuous {collection_name} will be ignored (no continuous concepts).") # Build filtered GroupConfig with only needed items filtered = GroupConfig() if has_binary: filtered['binary'] = binary if has_categorical: filtered['categorical'] = categorical if has_continuous: filtered['continuous'] = continuous return filtered def indices_to_mask( c_idxs: Union[List[int], torch.Tensor], c_vals: Union[List[float], torch.Tensor], n_concepts: int, batch_size: int = 1, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Convert index-based interventions to mask-based format. This helper translates interventions specified as (indices, values) pairs into (mask, target) tensors, enabling uniform "mask-space" processing while supporting intuitive index-based specifications for inference/practice. Args: c_idxs: Concept indices to intervene on. Can be a list or tensor of shape [K]. c_vals: Intervention values for each concept. Can be a list or tensor of shape [K] (same value for all batches) or [B, K] (per-batch values). n_concepts: Total number of concepts (F). batch_size: Batch size (B). Default: 1. device: Target device for output tensors. Default: None (CPU). dtype: Target dtype for output tensors. Default: None (float32). Returns: tuple: (mask, target) where: - mask: Binary tensor of shape [B, F] where 0 indicates intervention, 1 keeps prediction. - target: Target tensor of shape [B, F] with intervention values at specified indices. Non-intervened positions are set to 0.0 (arbitrary, as they're masked out). Example: >>> from torch_concepts.nn.modules.utils import indices_to_mask >>> # Intervene on concepts 0 and 2, setting them to 1.0 and 0.5 >>> mask, target = indices_to_mask( ... c_idxs=[0, 2], ... c_vals=[1.0, 0.5], ... n_concepts=5, ... batch_size=2 ... ) >>> print(mask.shape, target.shape) torch.Size([2, 5]) torch.Size([2, 5]) >>> print(mask[0]) # [0, 1, 0, 1, 1] - intervene on 0 and 2 tensor([0., 1., 0., 1., 1.]) >>> print(target[0]) # [1.0, 0, 0.5, 0, 0] tensor([1.0000, 0.0000, 0.5000, 0.0000, 0.0000]) """ if dtype is None: dtype = torch.float32 # Convert indices to tensor if not isinstance(c_idxs, torch.Tensor): c_idxs = torch.tensor(c_idxs, dtype=torch.long, device=device) else: c_idxs = c_idxs.to(dtype=torch.long, device=device) # Convert values to tensor if not isinstance(c_vals, torch.Tensor): c_vals = torch.tensor(c_vals, dtype=dtype, device=device) else: c_vals = c_vals.to(dtype=dtype, device=device) # Validate indices K = c_idxs.numel() if K == 0: # No interventions - return all-ones mask and zeros target mask = torch.ones((batch_size, n_concepts), dtype=dtype, device=device) target = torch.zeros((batch_size, n_concepts), dtype=dtype, device=device) return mask, target if c_idxs.dim() != 1: raise ValueError(f"c_idxs must be 1-D, got shape {c_idxs.shape}") if torch.any(c_idxs < 0) or torch.any(c_idxs >= n_concepts): raise ValueError(f"All indices must be in range [0, {n_concepts}), got {c_idxs}") # Handle c_vals shape: [K] or [B, K] if c_vals.dim() == 1: if c_vals.numel() != K: raise ValueError(f"c_vals length {c_vals.numel()} must match c_idxs length {K}") # Broadcast to [B, K] c_vals = c_vals.unsqueeze(0).expand(batch_size, -1) elif c_vals.dim() == 2: B_vals, K_vals = c_vals.shape if K_vals != K: raise ValueError(f"c_vals second dim {K_vals} must match c_idxs length {K}") if B_vals != batch_size: raise ValueError(f"c_vals first dim {B_vals} must match batch_size {batch_size}") else: raise ValueError(f"c_vals must be 1-D or 2-D, got shape {c_vals.shape}") # Initialize mask (1 = keep prediction, 0 = replace with target) mask = torch.ones((batch_size, n_concepts), dtype=dtype, device=device) # Initialize target (arbitrary values for non-intervened positions) target = torch.zeros((batch_size, n_concepts), dtype=dtype, device=device) # Set mask to 0 at intervention indices mask[:, c_idxs] = 0.0 # Set target values at intervention indices target[:, c_idxs] = c_vals return mask, target # ============================================================================= # Training mode utilities (reusable for other models) # ============================================================================= from .high.base.learner import BaseLearner _CLASS_CACHE = {} def with_training_mode(cls, lightning: bool = False): """Create a combined class with BaseLearner mixin for Lightning training. This utility adds the BaseLearner mixin to any model class when lightning=True. The BaseLearner provides the Lightning training loop. Inference engine selection (train vs eval) is handled automatically by the ``inference`` property on BaseModel, which returns ``train_inference`` or ``eval_inference`` depending on the module's train/eval mode (toggled by ``.train()``/``.eval()``). Parameters ---------- cls : type The base model class. lightning : bool, default False If True, adds BaseLearner mixin for Lightning training. If False, returns the original class (pure PyTorch module). Returns ------- type Combined class with BaseLearner mixin, or original class if lightning is False. """ # No training mode = pure PyTorch module (no learner mixin) if not lightning: return cls # Add BaseLearner mixin for Lightning training cache_key = (cls, True) if cache_key not in _CLASS_CACHE: _CLASS_CACHE[cache_key] = type( f'{cls.__name__}_Module', (cls, BaseLearner), {} ) return _CLASS_CACHE[cache_key]