torch_concepts.data.base.datamodule.ConceptDataModule

class ConceptDataModule(dataset: ConceptDataset, val_size: float = 0.1, test_size: float = 0.2, batch_size: int = 64, backbone: str | Callable[[Tensor], Tensor] | None = None, precompute_embs: bool = False, force_recompute: bool = False, scalers: Mapping | None = None, splitter: object | None = None, workers: int = 0, pin_memory: bool = False)

PyTorch Lightning DataModule for concept-based datasets.

Handles the complete data pipeline for concept-based learning:

  1. Data splitting: Train/validation/test splits using configurable splitters

  2. Embedding precomputation: Optional backbone feature extraction with caching

  3. Data scaling: Optional normalization through configurable scalers

  4. DataLoader creation: Efficient data loading with proper configurations

When embedding precomputation is enabled, the datamodule caches the computed embeddings to disk, so subsequent runs reload them instead of recomputing.
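The caching behaviour described above follows the usual load-or-compute pattern. The following is a minimal sketch of that pattern only, not the library's implementation: `load_or_compute` and `fake_embeddings` are hypothetical names, and `pickle` stands in for the tensor serialization a real datamodule would use, so that the example runs with the standard library alone.

```python
import pickle
import tempfile
from pathlib import Path


def load_or_compute(cache_path, compute_fn, force_recompute=False):
    """Return the cached result if present, otherwise compute and store it."""
    cache_path = Path(cache_path)
    if cache_path.exists() and not force_recompute:
        # Cache hit: skip the expensive computation entirely.
        with cache_path.open("rb") as f:
            return pickle.load(f)
    result = compute_fn()
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with cache_path.open("wb") as f:
        pickle.dump(result, f)
    return result


calls = []  # records how often the expensive step actually runs


def fake_embeddings():
    """Hypothetical stand-in for a backbone forward pass over the dataset."""
    calls.append(1)
    return [[0.1, 0.2], [0.3, 0.4]]


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "embs.pkl"
    first = load_or_compute(path, fake_embeddings)   # computes and writes
    second = load_or_compute(path, fake_embeddings)  # reads from the cache
    third = load_or_compute(path, fake_embeddings, force_recompute=True)

print(len(calls))  # the compute step ran on the first and third call only
```

A flag such as `force_recompute` matters because a cache keyed only by file path cannot tell that the dataset or backbone has changed since the file was written.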

Parameters:
  • dataset (ConceptDataset) – Complete dataset to be split and processed.

  • val_size (float, optional) – Validation set fraction (0.0 to 1.0). Default is 0.1.

  • test_size (float, optional) – Test set fraction (0.0 to 1.0). Default is 0.2.

  • batch_size (int, optional) – Mini-batch size for DataLoaders. Default is 64.

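  • backbone (str, Callable, or None, optional) –

    Feature extraction model, given either as a model name or as a callable that maps an input Tensor to an embedding Tensor. A model name can be: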

    • HuggingFace model: ‘facebook/dinov2-base’, ‘google/vit-base-patch16-224’

    • torchvision model: ‘resnet18’, ‘resnet50’, ‘vgg16’, ‘efficientnet_b0’

    If provided with precompute_embs=True, embeddings are computed and cached to disk. Default is None.

  • precompute_embs (bool, optional) – If True and backbone is provided, precompute and cache backbone embeddings before training. Embeddings are saved to {dataset.root_dir}/{backbone_filename}.pt. Default is False.

  • force_recompute (bool, optional) – If True, recompute embeddings even if cached file exists. Useful when the dataset or backbone changes. Default is False.

  • scalers (Mapping or None, optional) – Dictionary of custom scalers for data normalization. Keys should match target keys in the batch (e.g., ‘input’, ‘concepts’). If None, no scaling is applied. Default is None.

  • splitter (object or None, optional) – Custom splitter for train/val/test splits. Must implement a split(dataset) method that sets train_idxs, val_idxs, and test_idxs attributes. If None, uses RandomSplitter with the specified val_size and test_size. Default is None.

  • workers (int, optional) – Number of subprocesses for data loading. 0 means data will be loaded in the main process. Default is 0.

  • pin_memory (bool, optional) – If True, the data loader will copy Tensors into pinned memory before returning them. Useful for GPU training. Default is False.
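The splitter contract described above (a split(dataset) method that sets train_idxs, val_idxs, and test_idxs) can be illustrated with a small class. This is a sketch under assumptions: `FixedIndexSplitter` is a hypothetical name, and the source does not say whether the datamodule expects anything beyond these three attributes (for example, a particular base class).

```python
class FixedIndexSplitter:
    """Hypothetical splitter: the caller supplies validation and test indices,
    and every remaining index goes to the training set."""

    def __init__(self, val_idxs, test_idxs):
        self._val = sorted(set(val_idxs))
        self._test = sorted(set(test_idxs))
        overlap = set(self._val) & set(self._test)
        if overlap:
            raise ValueError(f"val/test indices overlap: {sorted(overlap)}")
        # Populated by split(), as the documented contract requires.
        self.train_idxs = None
        self.val_idxs = None
        self.test_idxs = None

    def split(self, dataset):
        n = len(dataset)
        held_out = set(self._val) | set(self._test)
        if any(i < 0 or i >= n for i in held_out):
            raise IndexError("split index out of range for dataset")
        self.val_idxs = list(self._val)
        self.test_idxs = list(self._test)
        self.train_idxs = [i for i in range(n) if i not in held_out]


splitter = FixedIndexSplitter(val_idxs=[0, 1], test_idxs=[8, 9])
splitter.split(list(range(10)))  # any object with __len__ works in this sketch
print(splitter.train_idxs)
```

An instance like this would then be passed as `splitter=splitter` when constructing the datamodule, in which case val_size and test_size should no longer determine the split.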

dataset

The underlying concept dataset.

Type:

ConceptDataset

trainset

Training subset after setup().

Type:

Subset or None

valset

Validation subset after setup().

Type:

Subset or None

testset

Test subset after setup().

Type:

Subset or None

backbone

The backbone wrapper for feature extraction.

Type:

Backbone or None

scalers

Dictionary of scalers for data normalization.

Type:

dict

splitter

The splitter used for data splitting.

Type:

object

Examples

Basic usage with random splitting:

>>> from torch_concepts.data.datasets import ToyDataset
>>> dataset = ToyDataset(dataset='xor', n_gen=1000)
>>> dm = ConceptDataModule(
...     dataset=dataset,
...     val_size=0.1,
...     test_size=0.2,
...     batch_size=32
... )
>>> dm.setup('fit')
>>> print(f"Train: {dm.train_len}, Val: {dm.val_len}, Test: {dm.test_len}")
Train: 700, Val: 100, Test: 200
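The counts printed above follow directly from the fractions. The sketch below reproduces the arithmetic; `split_sizes` is a hypothetical helper, and rounding of fractional counts is an assumption, since the source does not state how RandomSplitter handles sizes that do not divide evenly.

```python
def split_sizes(n_samples, val_size=0.1, test_size=0.2):
    """Translate split fractions into sample counts (illustrative only)."""
    if not 0.0 <= val_size + test_size < 1.0:
        raise ValueError("val_size + test_size must be in [0, 1)")
    # Assumption: fractional counts are rounded; the real splitter may differ.
    n_val = round(n_samples * val_size)
    n_test = round(n_samples * test_size)
    n_train = n_samples - n_val - n_test  # training set takes the remainder
    return n_train, n_val, n_test


sizes = split_sizes(1000)
print(sizes)  # (700, 100, 200)
```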

Using backbone for embedding precomputation:

>>> dm = ConceptDataModule(
...     dataset=image_dataset,
...     backbone='resnet50',
...     precompute_embs=True,
...     batch_size=64,
...     workers=4
... )
>>> dm.setup('fit')  # Computes and caches embeddings
>>> # On subsequent runs, embeddings are loaded from cache

Using HuggingFace backbone:

>>> dm = ConceptDataModule(
...     dataset=image_dataset,
...     backbone='facebook/dinov2-base',
...     precompute_embs=True
... )

See also

Backbone

Feature extraction wrapper class.

ConceptDataset

Base dataset class for concept data.

RandomSplitter

Default splitter for train/val/test splits.

NativeSplitter

Splitter using dataset’s native splits.

__init__(dataset: ConceptDataset, val_size: float = 0.1, test_size: float = 0.2, batch_size: int = 64, backbone: str | Callable[[Tensor], Tensor] | None = None, precompute_embs: bool = False, force_recompute: bool = False, scalers: Mapping | None = None, splitter: object | None = None, workers: int = 0, pin_memory: bool = False)

prepare_data_per_node

If True, prepare_data() is called on LOCAL_RANK=0 of every node. Otherwise it is called only on NODE_RANK=0, LOCAL_RANK=0.

allow_zero_length_dataloader_with_multiple_devices

If True, a dataloader with zero length on a local rank is allowed. Default is False.

Methods

__init__(dataset[, val_size, test_size, ...])

from_datasets([train_dataset, val_dataset, ...])

Create an instance from torch.utils.data.Dataset.

get_dataloader([split, shuffle, batch_size])

Get the DataLoader for a specific split.

load_from_checkpoint(checkpoint_path[, ...])

Primary way of loading a datamodule from a checkpoint.

load_state_dict(state_dict)

Called when loading a checkpoint. Implement to reload the datamodule state from the given state_dict.

on_after_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

on_before_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

on_exception(exception)

Called when the trainer execution is interrupted by an exception.

predict_dataloader()

An iterable or collection of iterables specifying prediction samples.

prepare_data()

Use this to download and prepare data.

remove_ignored_hparams(ignore_list)

Remove ignored hyperparameters from the stored state.

save_hyperparameters(*args[, ignore, frame, ...])

Save arguments to hparams attribute.

setup([stage, backbone_device, verbose])

Prepare the data for training, validation, or testing.

state_dict()

Called when saving a checkpoint. Implement to generate and save the datamodule state.

teardown(stage)

Called at the end of fit (train + validate), validate, test, or predict.

test_dataloader([shuffle, batch_size])

Get the test DataLoader.

train_dataloader([shuffle, batch_size])

Get the training DataLoader.

transfer_batch_to_device(batch, device, ...)

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

val_dataloader([shuffle, batch_size])

Get the validation DataLoader.

Attributes

CHECKPOINT_HYPER_PARAMS_KEY

CHECKPOINT_HYPER_PARAMS_NAME

CHECKPOINT_HYPER_PARAMS_TYPE

backbone

The backbone model wrapper for feature extraction.

hparams

The collection of hyperparameters saved with save_hyperparameters(). It is mutable by the user.

hparams_initial

The read-only collection of initial hyperparameters saved with save_hyperparameters().

n_samples

Total number of samples in the dataset.

name

test_len

Number of samples in the test set.

testset

The test subset.

train_len

Number of samples in the training set.

trainset

The training subset.

val_len

Number of samples in the validation set.

valset

The validation subset.