torch_concepts.data.base.datamodule.ConceptDataModule
- class ConceptDataModule(dataset: ConceptDataset, val_size: float = 0.1, test_size: float = 0.2, batch_size: int = 64, backbone: str | Callable[[Tensor], Tensor] | None = None, precompute_embs: bool = False, force_recompute: bool = False, scalers: Mapping | None = None, splitter: object | None = None, workers: int = 0, pin_memory: bool = False)
PyTorch Lightning DataModule for concept-based datasets.
Handles the complete data pipeline for concept-based learning:
Data splitting: Train/validation/test splits using configurable splitters
Embedding precomputation: Optional backbone feature extraction with caching
Data scaling: Optional normalization through configurable scalers
DataLoader creation: DataLoaders configured with the given batch_size, workers, and pin_memory settings
When a backbone is provided with precompute_embs=True, the datamodule caches the computed embeddings to disk, so subsequent runs reload them instead of recomputing them.
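The caching behavior boils down to a cache-or-compute check. The sketch below illustrates the idea under stated assumptions: load_or_compute_embeddings and toy_backbone are hypothetical names, the sketch uses pickle instead of the .pt cache file described under precompute_embs so that it runs without PyTorch, and it mirrors the force_recompute flag conceptually rather than reproducing the library's code.

```python
import os
import pickle
import tempfile
from typing import Any, Callable, List, Sequence


def load_or_compute_embeddings(
    samples: Sequence[Any],
    backbone: Callable[[Any], Any],
    cache_path: str,
    force_recompute: bool = False,
) -> List[Any]:
    """Hypothetical helper: return cached embeddings if a cache file exists,
    otherwise compute them with the backbone and write the cache."""
    # Reuse the cache unless the caller explicitly asks to recompute.
    if os.path.exists(cache_path) and not force_recompute:
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    # Run the backbone once per sample, then persist the results
    # (pickle stands in for the .pt file the datamodule writes).
    embeddings = [backbone(x) for x in samples]
    with open(cache_path, "wb") as f:
        pickle.dump(embeddings, f)
    return embeddings


calls: List[Any] = []


def toy_backbone(x: float) -> List[float]:
    """Hypothetical stand-in feature extractor that records every call."""
    calls.append(x)
    return [x * 2.0]


with tempfile.TemporaryDirectory() as root_dir:
    path = os.path.join(root_dir, "toy_backbone.pkl")
    first = load_or_compute_embeddings([1, 2, 3], toy_backbone, path)   # computes and caches
    second = load_or_compute_embeddings([1, 2, 3], toy_backbone, path)  # loads from cache
    print(first == second, len(calls))  # True 3
```

The backbone runs only three times across both calls: the second call is served from the cache file, which is the behavior that force_recompute=True is meant to bypass.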
- Parameters:
dataset (ConceptDataset) – Complete dataset to be split and processed.
val_size (float, optional) – Validation set fraction (0.0 to 1.0). Default is 0.1.
test_size (float, optional) – Test set fraction (0.0 to 1.0). Default is 0.2.
batch_size (int, optional) – Mini-batch size for DataLoaders. Default is 64.
backbone (str, callable, or None, optional) – Feature extractor. Either a model name or a callable that maps an input Tensor to an embedding Tensor. Model names can be, for example:
HuggingFace models: ‘facebook/dinov2-base’, ‘google/vit-base-patch16-224’
torchvision models: ‘resnet18’, ‘resnet50’, ‘vgg16’, ‘efficientnet_b0’
If provided together with precompute_embs=True, embeddings are computed and cached to disk. Default is None.
precompute_embs (bool, optional) – If True and backbone is provided, precompute and cache backbone embeddings before training. Embeddings are saved to {dataset.root_dir}/{backbone_filename}.pt. Default is False.
force_recompute (bool, optional) – If True, recompute embeddings even if a cached file exists. Useful when the dataset or backbone changes. Default is False.
scalers (Mapping or None, optional) – Dictionary of custom scalers for data normalization. Keys should match target keys in the batch (e.g., ‘input’, ‘concepts’). If None, no scaling is applied. Default is None.
splitter (object or None, optional) – Custom splitter for train/val/test splits. Must implement a split(dataset) method that sets the train_idxs, val_idxs, and test_idxs attributes. If None, uses RandomSplitter with the specified val_size and test_size. Default is None.
workers (int, optional) – Number of subprocesses for data loading. 0 means data will be loaded in the main process. Default is 0.
pin_memory (bool, optional) – If True, the data loader will copy Tensors into pinned memory before returning them. Useful for GPU training. Default is False.
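The splitter protocol is small enough to implement directly. As a minimal sketch, assuming the datamodule relies only on split(dataset) and the three index attributes named in the splitter description, the hypothetical FixedIndexSplitter below applies precomputed index lists (for example, indices loaded from a split file) and checks that they are in range and disjoint. It is not part of torch_concepts.

```python
from typing import List, Optional, Sequence, Sized


class FixedIndexSplitter:
    """Hypothetical splitter that applies precomputed index lists.

    Follows the documented protocol: split(dataset) sets the
    train_idxs, val_idxs and test_idxs attributes.
    """

    def __init__(self, train: Sequence[int], val: Sequence[int], test: Sequence[int]):
        self._splits = (list(train), list(val), list(test))
        self.train_idxs: Optional[List[int]] = None
        self.val_idxs: Optional[List[int]] = None
        self.test_idxs: Optional[List[int]] = None

    def split(self, dataset: Sized) -> None:
        n = len(dataset)
        train, val, test = self._splits
        all_idxs = train + val + test
        # Every index must point inside the dataset.
        if any(i < 0 or i >= n for i in all_idxs):
            raise IndexError("split index out of range for dataset of length %d" % n)
        # The three splits must not share any sample.
        if len(set(all_idxs)) != len(all_idxs):
            raise ValueError("train/val/test indices overlap")
        self.train_idxs, self.val_idxs, self.test_idxs = train, val, test


splitter = FixedIndexSplitter(train=range(0, 6), val=[6, 7], test=[8, 9])
splitter.split(list(range(10)))  # any object with __len__ works here
print(splitter.train_idxs, splitter.val_idxs, splitter.test_idxs)
# [0, 1, 2, 3, 4, 5] [6, 7] [8, 9]
```

An instance would be passed as splitter=FixedIndexSplitter(...); whether ConceptDataModule expects any other methods or attributes on the splitter is not covered on this page.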
- dataset
The underlying concept dataset.
- Type:
ConceptDataset
- trainset
Training subset after setup().
- Type:
Subset or None
- valset
Validation subset after setup().
- Type:
Subset or None
- testset
Test subset after setup().
- Type:
Subset or None
- backbone
The backbone wrapper for feature extraction.
- Type:
Backbone or None
Examples
Basic usage with random splitting:
>>> from torch_concepts.data.base.datamodule import ConceptDataModule
>>> from torch_concepts.data.datasets import ToyDataset
>>> dataset = ToyDataset(dataset='xor', n_gen=1000)
>>> dm = ConceptDataModule(
...     dataset=dataset,
...     val_size=0.1,
...     test_size=0.2,
...     batch_size=32
... )
>>> dm.setup('fit')
>>> print(f"Train: {dm.train_len}, Val: {dm.val_len}, Test: {dm.test_len}")
Train: 700, Val: 100, Test: 200
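The printed counts follow from the fractions: with 1000 samples, 0.2 × 1000 = 200 go to the test set, 0.1 × 1000 = 100 to the validation set, and the remaining 700 to training. The sketch below reproduces that arithmetic; the helper name split_sizes and its round-down rule are assumptions for illustration, and RandomSplitter may round non-integer sizes differently.

```python
from typing import Tuple


def split_sizes(n: int, val_size: float = 0.1, test_size: float = 0.2) -> Tuple[int, int, int]:
    """Hypothetical helper: return (n_train, n_val, n_test) for n samples."""
    if not (0.0 <= val_size <= 1.0 and 0.0 <= test_size <= 1.0):
        raise ValueError("fractions must lie in [0.0, 1.0]")
    if val_size + test_size > 1.0:
        raise ValueError("val_size + test_size must not exceed 1.0")
    # Round the held-out sizes down (an assumption) so they never exceed n.
    n_test = int(n * test_size)
    n_val = int(n * val_size)
    # Whatever is not held out for validation or testing is used for training.
    n_train = n - n_val - n_test
    return n_train, n_val, n_test


print(split_sizes(1000))  # (700, 100, 200)
```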
Using backbone for embedding precomputation:
>>> dm = ConceptDataModule(
...     dataset=image_dataset,
...     backbone='resnet50',
...     precompute_embs=True,
...     batch_size=64,
...     workers=4
... )
>>> dm.setup('fit')  # Computes and caches embeddings
>>> # On subsequent runs, embeddings are loaded from cache
Using HuggingFace backbone:
>>> dm = ConceptDataModule(
...     dataset=image_dataset,
...     backbone='facebook/dinov2-base',
...     precompute_embs=True
... )
See also
Backbone: Feature extraction wrapper class.
ConceptDataset: Base dataset class for concept data.
RandomSplitter: Default splitter for train/val/test splits.
NativeSplitter: Splitter using the dataset’s native splits.
- __init__(dataset: ConceptDataset, val_size: float = 0.1, test_size: float = 0.2, batch_size: int = 64, backbone: str | Callable[[Tensor], Tensor] | None = None, precompute_embs: bool = False, force_recompute: bool = False, scalers: Mapping | None = None, splitter: object | None = None, workers: int = 0, pin_memory: bool = False)
- prepare_data_per_node
If True, the LOCAL_RANK=0 process on each node calls prepare_data(). Otherwise, only the process with NODE_RANK=0 and LOCAL_RANK=0 prepares data.
- allow_zero_length_dataloader_with_multiple_devices
If True, a dataloader with zero length on a local rank is allowed. Default is False.
Methods
__init__(dataset[, val_size, test_size, ...])
from_datasets([train_dataset, val_dataset, ...]): Create an instance from torch.utils.data.Dataset.
get_dataloader([split, shuffle, batch_size]): Get the DataLoader for a specific split.
load_from_checkpoint(checkpoint_path[, ...]): Primary way of loading a datamodule from a checkpoint.
load_state_dict(state_dict): Called when loading a checkpoint; implement to reload datamodule state given the datamodule state_dict.
on_after_batch_transfer(batch, dataloader_idx): Override to alter or apply batch augmentations to your batch after it is transferred to the device.
on_before_batch_transfer(batch, dataloader_idx): Override to alter or apply batch augmentations to your batch before it is transferred to the device.
on_exception(exception): Called when the trainer execution is interrupted by an exception.
predict_dataloader(): An iterable or collection of iterables specifying prediction samples.
prepare_data(): Use this to download and prepare data.
remove_ignored_hparams(ignore_list): Remove ignored hyperparameters from the stored state.
save_hyperparameters(*args[, ignore, frame, ...]): Save arguments to the hparams attribute.
setup([stage, backbone_device, verbose]): Prepare the data for training, validation, or testing.
state_dict(): Called when saving a checkpoint; implement to generate and save datamodule state.
teardown(stage): Called at the end of fit (train + validate), validate, test, or predict.
test_dataloader([shuffle, batch_size]): Get the test DataLoader.
train_dataloader([shuffle, batch_size]): Get the training DataLoader.
transfer_batch_to_device(batch, device, ...): Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
val_dataloader([shuffle, batch_size]): Get the validation DataLoader.
Attributes
CHECKPOINT_HYPER_PARAMS_KEY
CHECKPOINT_HYPER_PARAMS_NAME
CHECKPOINT_HYPER_PARAMS_TYPE
backbone: The backbone model wrapper for feature extraction.
hparams: The collection of hyperparameters saved with save_hyperparameters().
hparams_initial: The collection of hyperparameters saved with save_hyperparameters().
Total number of samples in the dataset.
name
test_len: Number of samples in the test set.
testset: The test subset.
train_len: Number of samples in the training set.
trainset: The training subset.
val_len: Number of samples in the validation set.
valset: The validation subset.