Source code for torch_concepts.data.datamodules.categorical_toy_dag

"""
DataModule for ToyDAG synthetic datasets.
"""
import os
from typing import Dict, List, Tuple, Optional, Union
import numpy as np

from ..datasets.categorical_toy_dag import ToyDAGDataset
from ..base.datamodule import ConceptDataModule
from ...typing import BackboneType


[docs] class ToyDAGDataModule(ConceptDataModule): """DataModule for ToyDAG synthetic datasets. Handles data loading, splitting, and batching for DAG-based synthetic datasets with support for concept-based learning. This datamodule wraps the ToyDAGDataset and provides standard train/val/test splits along with optional backbone feature extraction and embedding caching. Args: variables: List of all variable names in the DAG. cardinalities: Dictionary mapping variable names to their cardinality. dag: List of edges representing the DAG structure as (parent, child) tuples. conditional_probs: Dictionary mapping variables to their conditional probability tables. seed: Random seed for the train/val/test split. generation_seed: Random seed for data generation. root: Root directory to store/load the dataset. val_size: Validation set size (fraction or absolute count). test_size: Test set size (fraction or absolute count). batch_size: Batch size for dataloaders. backbone: Model backbone to use (if applicable). precompute_embs: Whether to precompute embeddings from backbone. force_recompute: Force recomputation of cached embeddings. n_gen: Total number of samples to generate. target_variable: Name of the target variable (optional). latent_variables: List of latent variable names. concept_subset: Subset of concepts to use. label_descriptions: Dictionary mapping concept names to descriptions. autoencoder_kwargs: Configuration for autoencoder-based feature extraction. workers: Number of workers for dataloaders. """
[docs] def __init__( self, variables: List[str], cardinalities: Dict[str, int], dag: List[Tuple[str, str]], conditional_probs: Optional[Dict[Union[Tuple[str, str], Tuple[str]], Union[np.ndarray, list]]] = None, seed: int = 42, generation_seed: int = 42, root: str = None, val_size: int | float = 0.1, test_size: int | float = 0.2, batch_size: int = 512, backbone: BackboneType = None, precompute_embs: bool = False, force_recompute: bool = False, n_gen: int = 10000, target_variable: Optional[str] = None, latent_variables: Optional[List[str]] = None, concept_subset: list | None = None, label_descriptions: dict | None = None, autoencoder_kwargs: dict | None = None, workers: int = 0, **kwargs ): # Create dataset dataset = ToyDAGDataset( variables=variables, cardinalities=cardinalities, dag=dag, conditional_probs=conditional_probs, root=root, seed=generation_seed, n_gen=n_gen, target_variable=target_variable, latent_variables=latent_variables, concept_subset=concept_subset, label_descriptions=label_descriptions, autoencoder_kwargs=autoencoder_kwargs, ) # Initialize parent class with default behavior super().__init__( dataset=dataset, val_size=val_size, test_size=test_size, batch_size=batch_size, backbone=backbone, precompute_embs=precompute_embs, force_recompute=force_recompute, workers=workers, seed=seed, )