torch_concepts.data.CelebADataModule

class CelebADataModule(root: str | None = None, splitter: Splitter = NativeSplitter(train_size=None, val_size=None, test_size=None), val_size: int | float = 0.1, test_size: int | float = 0.2, batch_size: int = 512, backbone: str | Callable[[Tensor], Tensor] | None = None, precompute_embs: bool = True, force_recompute: bool = False, concept_subset: list | None = None, label_descriptions: dict | None = None, workers: int = 0, **kwargs)

DataModule for the CelebA dataset, with support for concept-based learning.

Handles data loading, splitting, and batching for the CelebA (Large-scale CelebFaces Attributes) dataset. Supports precomputing backbone embeddings and flexible train/val/test splitting strategies.

Parameters:
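  • seed (int) – Random seed for reproducibility in data splitting and sampling. Not part of the explicit signature; presumably accepted through **kwargs.

  • name (str) – Dataset identifier used for caching and logging. Not part of the explicit signature; presumably accepted through **kwargs.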

  • root (str or None, optional) – Root directory where the CelebA dataset is stored or will be downloaded. Default: None

  • splitter (Splitter, optional) – Splitting strategy for train/val/test partitioning. Default: NativeSplitter(), which uses CelebA’s native split.

  • val_size (int or float, optional) – Validation set size. If a float, interpreted as a fraction of the training data; if an int, as an absolute number of samples. Default: 0.1

  • test_size (int or float, optional) – Test set size. If a float, interpreted as a fraction of the data; if an int, as an absolute number of samples. Default: 0.2

  • batch_size (int, optional) – Number of samples per batch. Default: 512

  • backbone (str, callable, or None, optional) – Backbone model for feature extraction (e.g., ResNet, ViT), given either as a string or as a callable mapping a Tensor to a Tensor. If provided, it can be used to precompute embeddings. Default: None

  • precompute_embs (bool, optional) – Whether to precompute and cache backbone embeddings for faster training. Requires a backbone. Default: True

  • force_recompute (bool, optional) – If True, recompute embeddings even if a cached version exists. Default: False

  • concept_subset (list of str, optional) – Subset of concept/attribute names to use. If None, uses all 40 CelebA attributes. Default: None

  • label_descriptions (dict, optional) – Dictionary mapping attribute names to human-readable descriptions. Default: None

  • workers (int, optional) – Number of worker processes for data loading. Default: 0 (main process only)

  • **kwargs – Additional arguments passed to the parent class, ConceptDataModule.
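The documented semantics of val_size and test_size (a float is a fraction, an int is an absolute count, and val_size is taken from the training portion) can be sketched as follows. resolve_split_size and split_counts are hypothetical helpers, not part of torch_concepts. Carving out the test set first and rounding down are assumptions, and whether the default NativeSplitter consults these values at all is not stated here.

```python
import math


def resolve_split_size(size, n_total):
    """Hypothetical helper: convert an int-or-float split size to a sample count.

    A float is read as a fraction of n_total, an int as an absolute count.
    Rounding down is an assumption; the library may round differently.
    """
    if isinstance(size, bool):  # bool is a subclass of int; reject it explicitly
        raise TypeError("size must be an int or a float")
    if isinstance(size, float):
        if not 0.0 <= size <= 1.0:
            raise ValueError("a float size must be a fraction in [0, 1]")
        return math.floor(size * n_total)
    if isinstance(size, int):
        if not 0 <= size <= n_total:
            raise ValueError("an int size must lie in [0, n_total]")
        return size
    raise TypeError("size must be an int or a float")


def split_counts(n_total, val_size=0.1, test_size=0.2):
    """Hypothetical helper returning (n_train, n_val, n_test).

    Assumption: the test set is carved from all data first, and the
    validation set is then taken from the remaining training portion.
    """
    n_test = resolve_split_size(test_size, n_total)
    n_val = resolve_split_size(val_size, n_total - n_test)
    return n_total - n_test - n_val, n_val, n_test


print(split_counts(1000))                              # (720, 80, 200)
print(split_counts(1000, val_size=50, test_size=100))  # (850, 50, 100)
```

With the defaults, a float val_size of 0.1 yields 10% of what remains after the test split, not 10% of the whole dataset.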

dataset

The underlying CelebA dataset instance.

Type:

CelebADataset

train_dataset

Training split of the dataset.

Type:

Dataset

val_dataset

Validation split of the dataset.

Type:

Dataset

test_dataset

Test split of the dataset.

Type:

Dataset

Examples

Basic usage with default settings:

>>> from torch_concepts.data import CelebADataModule
>>>
>>> dm = CelebADataModule(
...     seed=42,
...     root='./data/celeba',
...     batch_size=64
... )
>>> dm.setup()
>>> train_loader = dm.train_dataloader()
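Assuming the loader keeps the final partial batch (drop_last=False, the default of torch.utils.data.DataLoader), one epoch over the training set yields ceil(train_len / batch_size) batches. The arithmetic is sketched below. batches_per_epoch is a hypothetical helper, whether this datamodule sets drop_last is not documented here, and the 162,770 figure assumes the training split matches CelebA's official training partition; in practice you would read dm.train_len after setup().

```python
def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Hypothetical helper: batches yielded by one pass of a DataLoader-style loader."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        return n_samples // batch_size  # the incomplete last batch is discarded
    return (n_samples + batch_size - 1) // batch_size  # integer ceiling division


# Assumption: the training split has CelebA's official 162,770 training images.
n_train = 162_770
print(batches_per_epoch(n_train, 64))                  # 2544 (the last batch holds 18 samples)
print(batches_per_epoch(n_train, 64, drop_last=True))  # 2543
print(batches_per_epoch(n_train, 512))                 # 318 with the default batch_size
```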

With backbone for precomputed embeddings:

>>> from torchvision.models import resnet18, ResNet18_Weights
>>>
>>> backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
>>> dm = CelebADataModule(
...     seed=42,
...     root='./data/celeba',
...     backbone=backbone,
...     precompute_embs=True,
...     concept_subset=['Smiling', 'Male', 'Young']
... )
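CelebA annotates each image with 40 binary attributes, stored as 1 or -1 in list_attr_celeba.txt. Passing concept_subset=['Smiling', 'Male', 'Young'] amounts to keeping only those attribute columns, which the sketch below illustrates. concept_indices, to_binary and select_concepts are hypothetical helpers; that the datamodule preserves the order given in concept_subset and maps -1 to 0 are assumptions, not documented behavior.

```python
# The 40 CelebA attribute names, in the column order of list_attr_celeba.txt.
CELEBA_ATTRIBUTES = [
    "5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", "Bags_Under_Eyes",
    "Bald", "Bangs", "Big_Lips", "Big_Nose", "Black_Hair", "Blond_Hair",
    "Blurry", "Brown_Hair", "Bushy_Eyebrows", "Chubby", "Double_Chin",
    "Eyeglasses", "Goatee", "Gray_Hair", "Heavy_Makeup", "High_Cheekbones",
    "Male", "Mouth_Slightly_Open", "Mustache", "Narrow_Eyes", "No_Beard",
    "Oval_Face", "Pale_Skin", "Pointy_Nose", "Receding_Hairline", "Rosy_Cheeks",
    "Sideburns", "Smiling", "Straight_Hair", "Wavy_Hair", "Wearing_Earrings",
    "Wearing_Hat", "Wearing_Lipstick", "Wearing_Necklace", "Wearing_Necktie",
    "Young",
]


def concept_indices(concept_subset=None, all_concepts=CELEBA_ATTRIBUTES):
    """Hypothetical helper: column indices for a concept subset (None means all 40)."""
    if concept_subset is None:
        return list(range(len(all_concepts)))
    unknown = [c for c in concept_subset if c not in all_concepts]
    if unknown:
        raise KeyError(f"unknown CelebA attribute(s): {unknown}")
    # Assumption: the order given in concept_subset is preserved.
    return [all_concepts.index(c) for c in concept_subset]


def to_binary(row):
    """Map CelebA's {-1, 1} attribute encoding to {0, 1}."""
    return [(v + 1) // 2 for v in row]


def select_concepts(row, indices):
    """Keep only the selected attribute columns of one annotation row."""
    return [row[i] for i in indices]


idx = concept_indices(["Smiling", "Male", "Young"])
print(idx)  # [31, 20, 39]

row = [-1] * 40
row[31] = 1  # a smiling face; every other attribute is absent
print(select_concepts(to_binary(row), idx))  # [1, 0, 0]
```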

See also

CelebADataset

The underlying dataset class

ConceptDataModule

Parent class with common datamodule functionality

__init__(root: str | None = None, splitter: Splitter = NativeSplitter(train_size=None, val_size=None, test_size=None), val_size: int | float = 0.1, test_size: int | float = 0.2, batch_size: int = 512, backbone: str | Callable[[Tensor], Tensor] | None = None, precompute_embs: bool = True, force_recompute: bool = False, concept_subset: list | None = None, label_descriptions: dict | None = None, workers: int = 0, **kwargs)

prepare_data_per_node

If True, the process with LOCAL_RANK=0 on each node calls prepare_data(). Otherwise, only the process with NODE_RANK=0 and LOCAL_RANK=0 calls it.
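The rule for prepare_data_per_node can be restated as a predicate on a process's ranks. This is an illustrative sketch only: calls_prepare_data is a hypothetical name, and PyTorch Lightning applies the rule internally, so user code does not normally need it.

```python
def calls_prepare_data(node_rank, local_rank, prepare_data_per_node=True):
    """Hypothetical predicate: does the process with these ranks run prepare_data()?"""
    if prepare_data_per_node:
        # One process per node: the one with LOCAL_RANK == 0.
        return local_rank == 0
    # A single process in the whole job: NODE_RANK == 0 and LOCAL_RANK == 0.
    return node_rank == 0 and local_rank == 0


# Two nodes with two processes each, identified by (NODE_RANK, LOCAL_RANK).
ranks = [(node, local) for node in range(2) for local in range(2)]
print([r for r in ranks if calls_prepare_data(*r)])                               # [(0, 0), (1, 0)]
print([r for r in ranks if calls_prepare_data(*r, prepare_data_per_node=False)])  # [(0, 0)]
```

Per-node preparation fits nodes with separate local storage, where each node needs its own copy of CelebA; a single preparing process fits a filesystem shared by all nodes.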

allow_zero_length_dataloader_with_multiple_devices

If True, a dataloader of zero length on a local rank is allowed. Default: False.

Methods

__init__([root, splitter, val_size, ...])

from_datasets([train_dataset, val_dataset, ...])

Create an instance from torch.utils.data.Dataset.

get_dataloader([split, shuffle, batch_size])

Get the DataLoader for a specific split.

load_from_checkpoint(checkpoint_path[, ...])

Primary way of loading a datamodule from a checkpoint.

load_state_dict(state_dict)

Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict.

on_after_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

on_before_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

on_exception(exception)

Called when the trainer execution is interrupted by an exception.

predict_dataloader()

An iterable or collection of iterables specifying prediction samples.

prepare_data()

Use this to download and prepare data.

remove_ignored_hparams(ignore_list)

Remove ignored hyperparameters from the stored state.

save_hyperparameters(*args[, ignore, frame, ...])

Save arguments to hparams attribute.

setup([stage, backbone_device, verbose])

Prepare the data for training, validation, or testing.

state_dict()

Called when saving a checkpoint, implement to generate and save datamodule state.

teardown(stage)

Called at the end of fit (train + validate), validate, test, or predict.

test_dataloader([shuffle, batch_size])

Get the test DataLoader.

train_dataloader([shuffle, batch_size])

Get the training DataLoader.

transfer_batch_to_device(batch, device, ...)

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

val_dataloader([shuffle, batch_size])

Get the validation DataLoader.
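The precompute_embs and force_recompute parameters, together with the backbone_device argument of setup(), suggest that backbone embeddings are computed once and then cached. How the library stores them is not documented here; the sketch below only illustrates the documented meaning of force_recompute (recompute even if a cached version exists). load_or_compute_embeddings and toy_backbone are hypothetical, and pickle on a temporary path stands in for whatever cache the library really uses.

```python
import os
import pickle
import tempfile


def load_or_compute_embeddings(samples, backbone, cache_path, force_recompute=False):
    """Hypothetical sketch: reuse cached backbone outputs unless forced to recompute.

    samples:  iterable of inputs (plain Python values here, tensors in practice)
    backbone: any callable mapping one input to its embedding
    """
    if not force_recompute and os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)  # cache hit: the backbone is not called
    embeddings = [backbone(x) for x in samples]
    with open(cache_path, "wb") as f:
        pickle.dump(embeddings, f)  # write (or overwrite) the cache
    return embeddings


calls = []


def toy_backbone(x):
    """Stand-in backbone that records each call and returns a 2-d 'embedding'."""
    calls.append(x)
    return [x, x * x]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "embs.pkl")
    first = load_or_compute_embeddings([1, 2, 3], toy_backbone, path)   # computes
    second = load_or_compute_embeddings([1, 2, 3], toy_backbone, path)  # served from cache
    third = load_or_compute_embeddings([1, 2, 3], toy_backbone, path, force_recompute=True)

print(first)       # [[1, 1], [2, 4], [3, 9]]
print(len(calls))  # 6: three calls for the first run, none for the second, three for the third
```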

Attributes

CHECKPOINT_HYPER_PARAMS_KEY

CHECKPOINT_HYPER_PARAMS_NAME

CHECKPOINT_HYPER_PARAMS_TYPE

backbone

The backbone model wrapper for feature extraction.

hparams

The collection of hyperparameters saved with save_hyperparameters().

hparams_initial

The hyperparameters as initially saved with save_hyperparameters(). These contents are read-only; make manual updates through hparams instead.

n_samples

Total number of samples in the dataset.

name

test_len

Number of samples in the test set.

testset

The test subset.

train_len

Number of samples in the training set.

trainset

The training subset.

val_len

Number of samples in the validation set.

valset

The validation subset.