torch_concepts.GroupConfig

class GroupConfig(binary: Any | None = None, categorical: Any | None = None, continuous: Any | None = None, **kwargs)

Container for storing classes organized by concept type groups.

This class acts as a convenient wrapper around a dictionary that maps concept type names to their corresponding classes or configurations.

_config

Internal dictionary storing the configuration.

Type: Dict[str, Any]

Parameters:
  • binary – Configuration for binary concepts. If provided alone, applies to all concept types.

  • categorical – Configuration for categorical concepts.

  • continuous – Configuration for continuous concepts.

  • **kwargs – Additional group configurations.
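
The parameter handling described above can be sketched in plain Python. `collect_groups` is a hypothetical helper, not part of torch_concepts. It assumes that groups left as None are omitted (consistent with the single-group example below being equivalent to a one-entry dictionary) and that extra keyword arguments become additional groups. It does not attempt to model how a lone binary configuration is applied to other concept types, and the string values are placeholders for loss modules.

```python
from typing import Any, Dict, Optional


def collect_groups(
    binary: Optional[Any] = None,
    categorical: Optional[Any] = None,
    continuous: Optional[Any] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical helper (not library code): gather group configs into a dict."""
    named = {
        "binary": binary,
        "categorical": categorical,
        "continuous": continuous,
    }
    # Assumption: groups left at None are not stored at all.
    config = {name: value for name, value in named.items() if value is not None}
    # Assumption: extra keyword arguments are stored as additional groups.
    config.update(kwargs)
    return config


# Placeholder strings stand in for loss modules such as BCEWithLogitsLoss().
only_binary = collect_groups(binary="bce")
mixed = collect_groups(binary="bce", continuous="mse", ordinal="ce")
print(only_binary)
print(mixed)
```

Here `ordinal` is an invented group name used only to show a value passing through `**kwargs`.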

Example

>>> from torch_concepts.nn.modules.utils import GroupConfig
>>> from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
>>> loss_config = GroupConfig(binary=CrossEntropyLoss())
>>> # Equivalent to: {'binary': CrossEntropyLoss()}
>>>
>>> # Different configurations per type
>>> loss_config = GroupConfig(
...     binary=BCEWithLogitsLoss(),
...     categorical=CrossEntropyLoss(),
...     continuous=MSELoss()
... )
>>>
>>> # Access configurations
>>> default_loss = MSELoss()
>>> binary_loss = loss_config['binary']
>>> loss_config.get('continuous', default_loss)
MSELoss()
>>>
>>> # Check what's configured
>>> 'binary' in loss_config
True
>>> list(loss_config.keys())
['binary', 'categorical', 'continuous']

__init__(binary: Any | None = None, categorical: Any | None = None, continuous: Any | None = None, **kwargs)

Methods

  • __init__([binary, categorical, continuous])

  • from_dict(config_dict) – Create GroupConfig from dictionary.

  • get(key[, default]) – Get configuration for a group with optional default.

  • items() – Return (group, config) pairs.

  • keys() – Return configured group names.

  • to_dict() – Convert to plain dictionary.

  • values() – Return configured values.
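
To illustrate how the methods listed above fit together, here is a minimal stand-in written in plain Python. `GroupConfigSketch` is a hypothetical class, not the torch_concepts implementation: it assumes the methods simply delegate to the internal `_config` dictionary, that `from_dict` forwards the entries to the constructor, and that `to_dict` returns a copy. The string values are placeholders for real configurations.

```python
from typing import Any, Dict


class GroupConfigSketch:
    """Illustrative stand-in for a dict-backed group container (not library code)."""

    def __init__(self, **groups: Any) -> None:
        # Assumption: groups passed as None are not stored.
        self._config: Dict[str, Any] = {
            name: value for name, value in groups.items() if value is not None
        }

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "GroupConfigSketch":
        # Assumption: dictionary entries are forwarded as keyword arguments.
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        # Return a shallow copy so callers cannot mutate the internal dict.
        return dict(self._config)

    def get(self, key: str, default: Any = None) -> Any:
        return self._config.get(key, default)

    def keys(self):
        return self._config.keys()

    def values(self):
        return self._config.values()

    def items(self):
        return self._config.items()

    def __getitem__(self, key: str) -> Any:
        return self._config[key]

    def __contains__(self, key: object) -> bool:
        return key in self._config


cfg = GroupConfigSketch.from_dict({"binary": "bce", "continuous": "mse"})
round_trip = cfg.to_dict()               # plain dict with the same entries
fallback = cfg.get("categorical", "ce")  # group not configured, default returned
pairs = list(cfg.items())                # (group, config) pairs in insertion order
```

Returning a copy from `to_dict` is a design choice of this sketch; whether the real class copies or exposes its internal dictionary is not stated in this reference.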