torch_concepts.data.ToyDAGDataModule

class ToyDAGDataModule(variables: List[str], cardinalities: Dict[str, int], dag: List[Tuple[str, str]], conditional_probs: Dict[Tuple[str, str] | Tuple[str], ndarray | list] | None = None, seed: int = 42, generation_seed: int = 42, root: str | None = None, val_size: int | float = 0.1, test_size: int | float = 0.2, batch_size: int = 512, backbone: str | Callable[[Tensor], Tensor] | None = None, precompute_embs: bool = False, force_recompute: bool = False, n_gen: int = 10000, target_variable: str | None = None, latent_variables: List[str] | None = None, concept_subset: list | None = None, label_descriptions: dict | None = None, autoencoder_kwargs: dict | None = None, workers: int = 0, **kwargs)

DataModule for ToyDAG synthetic datasets.

Handles data loading, splitting, and batching for DAG-based synthetic datasets with support for concept-based learning.

This datamodule wraps the ToyDAGDataset and provides standard train/val/test splits along with optional backbone feature extraction and embedding caching.
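To make "DAG-based synthetic dataset" concrete, the following is a minimal sketch of ancestral sampling from a small discrete DAG, using only the standard library. It does not call ToyDAGDataset. The variable names, the `cpts` layout, and the `sample_rows` helper are hypothetical and chosen for illustration; the library's own table format may differ.

```python
import random
from graphlib import TopologicalSorter

# Hypothetical toy DAG: rain -> wet_grass <- sprinkler
variables = ["rain", "sprinkler", "wet_grass"]
cardinalities = {"rain": 2, "sprinkler": 2, "wet_grass": 2}
dag = [("rain", "wet_grass"), ("sprinkler", "wet_grass")]  # (parent, child)

# Hypothetical CPT layout (NOT the library's format): for each variable,
# a dict from the tuple of parent values to a distribution over its states.
cpts = {
    "rain": {(): [0.8, 0.2]},
    "sprinkler": {(): [0.6, 0.4]},
    "wet_grass": {
        (0, 0): [0.99, 0.01],
        (0, 1): [0.2, 0.8],
        (1, 0): [0.1, 0.9],
        (1, 1): [0.01, 0.99],
    },
}


def sample_rows(n, seed):
    """Draw n joint samples by visiting variables parents-first."""
    parents = {v: [p for p, c in dag if c == v] for v in variables}
    order = list(TopologicalSorter(parents).static_order())
    rng = random.Random(seed)  # a fixed seed makes generation reproducible
    rows = []
    for _ in range(n):
        row = {}
        for v in order:
            key = tuple(row[p] for p in parents[v])
            states = range(cardinalities[v])
            row[v] = rng.choices(states, weights=cpts[v][key])[0]
        rows.append(row)
    return rows


rows = sample_rows(100, seed=42)
```

Calling `sample_rows` twice with the same seed yields identical rows, which is the kind of reproducibility a dedicated generation seed is meant to give.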

Parameters:
  • variables – List of all variable names in the DAG.

  • cardinalities – Dictionary mapping variable names to their cardinality.

  • dag – List of edges representing the DAG structure as (parent, child) tuples.

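  • conditional_probs – Optional dictionary of conditional probability tables, keyed by a one- or two-element tuple of variable names, with each table given as an ndarray or a list. Defaults to None.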

  • seed – Random seed for the train/val/test split.

  • generation_seed – Random seed for data generation.

  • root – Root directory to store/load the dataset.

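  • val_size – Validation set size, given as a fraction of the dataset (float) or an absolute number of samples (int). Defaults to 0.1.

  • test_size – Test set size, given as a fraction of the dataset (float) or an absolute number of samples (int). Defaults to 0.2.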

  • batch_size – Batch size for dataloaders.

  • backbone – Backbone used for feature extraction, given as a string or as a callable that maps a Tensor to a Tensor. Defaults to None.

  • precompute_embs – Whether to precompute embeddings from backbone.

  • force_recompute – Force recomputation of cached embeddings.

  • n_gen – Total number of samples to generate.

  • target_variable – Name of the target variable (optional).

  • latent_variables – List of latent variable names.

  • concept_subset – Subset of concepts to use.

  • label_descriptions – Dictionary mapping concept names to descriptions.

  • autoencoder_kwargs – Configuration for autoencoder-based feature extraction.

  • workers – Number of workers for dataloaders.
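The three required arguments have to agree with one another: every edge endpoint must appear in `variables`, every variable needs a cardinality, and the edges must form an acyclic graph. Below is a minimal sketch of such a check, written with the standard library only. The helper `check_dag_spec` is hypothetical and not part of torch_concepts; whether the library performs similar validation itself is not stated here.

```python
from graphlib import CycleError, TopologicalSorter


def check_dag_spec(variables, cardinalities, dag):
    """Return a parents-first ordering, or raise ValueError on a bad spec."""
    known = set(variables)
    missing = known - set(cardinalities)
    if missing:
        raise ValueError(f"no cardinality for: {sorted(missing)}")
    for parent, child in dag:
        if parent not in known or child not in known:
            raise ValueError(f"edge ({parent!r}, {child!r}) uses an unknown variable")
    # Map each variable to the set of its parents, then sort topologically.
    parents = {v: set() for v in variables}
    for parent, child in dag:
        parents[child].add(parent)
    try:
        return list(TopologicalSorter(parents).static_order())
    except CycleError as err:
        raise ValueError(f"graph has a cycle: {err.args[1]}") from err


# Hypothetical three-variable chain A -> B -> C.
variables = ["A", "B", "C"]
cardinalities = {"A": 2, "B": 3, "C": 2}
dag = [("A", "B"), ("B", "C")]
order = check_dag_spec(variables, cardinalities, dag)

# Per the signature above, the checked values would then be passed on as
# ToyDAGDataModule(variables=variables, cardinalities=cardinalities, dag=dag)
```

Running a check like this before constructing the datamodule reports a malformed edge list up front, rather than during data generation.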

prepare_data_per_node

If True, prepare_data() is called on LOCAL_RANK=0 of every node. Otherwise, it is called only on NODE_RANK=0, LOCAL_RANK=0.

allow_zero_length_dataloader_with_multiple_devices

If True, a dataloader with zero length within a local rank is allowed. Defaults to False.

Methods

__init__(variables, cardinalities, dag[, ...])

from_datasets([train_dataset, val_dataset, ...])

Create an instance from torch.utils.data.Dataset.

get_dataloader([split, shuffle, batch_size])

Get the DataLoader for a specific split.

load_from_checkpoint(checkpoint_path[, ...])

Primary way of loading a datamodule from a checkpoint.

load_state_dict(state_dict)

Called when loading a checkpoint. Implement it to reload the datamodule state from the given state_dict.

on_after_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

on_before_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

on_exception(exception)

Called when the trainer execution is interrupted by an exception.

predict_dataloader()

An iterable or collection of iterables specifying prediction samples.

prepare_data()

Use this to download and prepare data.

remove_ignored_hparams(ignore_list)

Remove ignored hyperparameters from the stored state.

save_hyperparameters(*args[, ignore, frame, ...])

Save arguments to hparams attribute.

setup([stage, backbone_device, verbose])

Prepare the data for training, validation, or testing.

state_dict()

Called when saving a checkpoint. Implement it to generate and save the datamodule state.

teardown(stage)

Called at the end of fit (train + validate), validate, test, or predict.

test_dataloader([shuffle, batch_size])

Get the test DataLoader.

train_dataloader([shuffle, batch_size])

Get the training DataLoader.

transfer_batch_to_device(batch, device, ...)

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

val_dataloader([shuffle, batch_size])

Get the validation DataLoader.
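When the datamodule is used outside a trainer, the methods listed above are typically called in the order prepare_data(), setup(), then the per-split dataloader getters. The sketch below shows that order against a stand-in object, so it runs without torch_concepts installed. `peek_first_batches` and `_StubDataModule` are hypothetical helpers, and the batch contents are placeholders rather than the library's actual batch format.

```python
def peek_first_batches(dm):
    """Prepare and set up a datamodule, then return the first batch per split."""
    dm.prepare_data()
    dm.setup()
    loaders = {
        "train": dm.train_dataloader(),
        "val": dm.val_dataloader(),
        "test": dm.test_dataloader(),
    }
    return {split: next(iter(loader)) for split, loader in loaders.items()}


class _StubDataModule:
    """Stand-in that mimics only the method names listed above."""

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        # Plain lists of batches stand in for real DataLoader objects.
        self._splits = {"train": [[0, 1], [2, 3]], "val": [[4, 5]], "test": [[6, 7]]}

    def train_dataloader(self):
        return self._splits["train"]

    def val_dataloader(self):
        return self._splits["val"]

    def test_dataloader(self):
        return self._splits["test"]


first = peek_first_batches(_StubDataModule())
```

With a real ToyDAGDataModule instance in place of the stub, the same function would return one batch from each split, which is a quick way to inspect shapes before training.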

Attributes

CHECKPOINT_HYPER_PARAMS_KEY

CHECKPOINT_HYPER_PARAMS_NAME

CHECKPOINT_HYPER_PARAMS_TYPE

backbone

The backbone model wrapper for feature extraction.

hparams

The collection of hyperparameters saved with save_hyperparameters(). It is mutable by the user.

hparams_initial

The initial collection of hyperparameters saved with save_hyperparameters(). Its contents are read-only.

n_samples

Total number of samples in the dataset.

name

test_len

Number of samples in the test set.

testset

The test subset.

train_len

Number of samples in the training set.

trainset

The training subset.

val_len

Number of samples in the validation set.

valset

The validation subset.
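The split-length attributes follow from n_gen, val_size, and test_size. The sketch below shows one plausible way that a "fraction or absolute count" value could resolve into lengths. The helpers `resolve_size` and `split_lengths` are hypothetical, and the rounding rule is an assumption; the library's exact behavior may differ.

```python
def resolve_size(size, n_total):
    """Turn a fraction (float) or an absolute count (int) into a sample count."""
    if isinstance(size, float):
        if not 0.0 <= size < 1.0:
            raise ValueError(f"fraction must be in [0, 1), got {size}")
        return round(size * n_total)  # assumption: round to the nearest integer
    if not 0 <= size <= n_total:
        raise ValueError(f"count must be in [0, {n_total}], got {size}")
    return int(size)


def split_lengths(n_total, val_size=0.1, test_size=0.2):
    """Return (train_len, val_len, test_len); train takes what is left over."""
    val_len = resolve_size(val_size, n_total)
    test_len = resolve_size(test_size, n_total)
    train_len = n_total - val_len - test_len
    if train_len < 0:
        raise ValueError("validation and test sets exceed the dataset size")
    return train_len, val_len, test_len


# Defaults from the signature: n_gen=10000, val_size=0.1, test_size=0.2.
train_len, val_len, test_len = split_lengths(10_000)
```

Under these assumptions the defaults give 7000 training, 1000 validation, and 2000 test samples. Passing integers instead, for example `split_lengths(10_000, val_size=500, test_size=1500)`, gives 8000, 500, and 1500.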