torch_concepts.GroupConfig
- class GroupConfig(binary: Any | None = None, categorical: Any | None = None, continuous: Any | None = None, **kwargs)
Container for classes or configurations grouped by concept type.
This class acts as a convenient wrapper around a dictionary that maps concept type names to their corresponding classes or configurations.
- Parameters:
binary – Configuration for binary concepts. If provided alone, applies to all concept types.
categorical – Configuration for categorical concepts.
continuous – Configuration for continuous concepts.
**kwargs – Additional group configurations.
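The parameters above can be pictured as entries in a plain dictionary keyed by group name. The following is a minimal sketch of that idea, not the library's implementation: `collect_groups` is a hypothetical helper, and dropping arguments left at `None` is an assumption made here for illustration.

```python
from typing import Any, Dict, Optional


def collect_groups(
    binary: Optional[Any] = None,
    categorical: Optional[Any] = None,
    continuous: Optional[Any] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical helper: gather per-group configurations into one dict."""
    named = {
        "binary": binary,
        "categorical": categorical,
        "continuous": continuous,
    }
    # Assumption: groups left at None are treated as "not configured".
    groups = {name: cfg for name, cfg in named.items() if cfg is not None}
    # Extra keyword arguments become additional named groups.
    groups.update(kwargs)
    return groups


only_binary = collect_groups(binary="bce")
with_extra = collect_groups(binary="bce", ordinal="rank_loss")
```

Here `ordinal` stands in for any additional group passed through `**kwargs`; the string values are placeholders for real loss objects or other configurations.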
Example
>>> from torch_concepts.nn.modules.utils import GroupConfig
>>> from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
>>> loss_config = GroupConfig(binary=CrossEntropyLoss())
>>> # Equivalent to: {'binary': CrossEntropyLoss()}
>>>
>>> # Different configurations per type
>>> loss_config = GroupConfig(
...     binary=BCEWithLogitsLoss(),
...     categorical=CrossEntropyLoss(),
...     continuous=MSELoss()
... )
>>>
>>> # Access configurations
>>> default_loss = MSELoss()
>>> binary_loss = loss_config['binary']
>>> loss_config.get('continuous', default_loss)
MSELoss()
>>>
>>> # Check what's configured
>>> 'binary' in loss_config
True
>>> list(loss_config.keys())
['binary', 'categorical', 'continuous']
- __init__(binary: Any | None = None, categorical: Any | None = None, continuous: Any | None = None, **kwargs)
Methods
__init__([binary, categorical, continuous])
from_dict(config_dict) – Create GroupConfig from dictionary.
get(key[, default]) – Get configuration for a group with optional default.
items() – Return (group, config) pairs.
keys() – Return configured group names.
to_dict() – Convert to plain dictionary.
values() – Return configured values.
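To make the method list concrete, here is a small self-contained sketch of a dictionary-backed container exposing the same method names. `GroupConfigSketch` is a hypothetical stand-in written for this illustration; it does not reproduce the library's source, and details such as returning a copy from `to_dict()` are assumptions.

```python
from typing import Any, Dict


class GroupConfigSketch:
    """Hypothetical stand-in that mirrors the documented method names only."""

    def __init__(self, **groups: Any) -> None:
        # Assumption: groups set to None are not stored.
        self._groups: Dict[str, Any] = {
            name: cfg for name, cfg in groups.items() if cfg is not None
        }

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "GroupConfigSketch":
        # Build the container from an existing mapping.
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        # Return a shallow copy so callers cannot mutate internal state.
        return dict(self._groups)

    def get(self, key: str, default: Any = None) -> Any:
        return self._groups.get(key, default)

    def keys(self):
        return self._groups.keys()

    def values(self):
        return self._groups.values()

    def items(self):
        return self._groups.items()

    def __getitem__(self, key: str) -> Any:
        return self._groups[key]

    def __contains__(self, key: object) -> bool:
        return key in self._groups


# Placeholder strings stand in for loss objects or other configurations.
config = GroupConfigSketch.from_dict({"binary": "bce", "categorical": "ce"})
round_trip = config.to_dict()
fallback = config.get("continuous", "mse")
pairs = list(config.items())
```

With the real class, the same calls would be made on a `GroupConfig` instance; the round trip through `from_dict` and `to_dict` is useful when configurations are loaded from or written to plain dictionaries.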