Source code for torch_concepts.data.datamodules.bnlearn
import os
from ..datasets import BnLearnDataset
from ..base.datamodule import ConceptDataModule
from ...typing import BackboneType
[docs]
class BnLearnDataModule(ConceptDataModule):
"""DataModule for all Bayesian Network datasets.
Handles data loading, splitting, and batching for all Bayesian Network datasets
with support for concept-based learning.
Args:
seed: Random seed for data generation and splitting.
val_size: Validation set size (fraction or absolute count).
test_size: Test set size (fraction or absolute count).
batch_size: Batch size for dataloaders.
n_samples: Total number of samples to generate.
autoencoder_kwargs: Configuration for autoencoder-based feature extraction.
concept_subset: Subset of concepts to use. If None, uses all concepts.
label_descriptions: Dictionary mapping concept names to descriptions.
backbone: Model backbone to use (if applicable).
workers: Number of workers for dataloaders.
"""
[docs]
def __init__(
self,
seed: int, # seed for data generation
name: str, # name of the bnlearn DAG
root: str = None,
val_size: int | float = 0.1,
test_size: int | float = 0.2,
batch_size: int = 512,
backbone: BackboneType = None,
precompute_embs: bool = False,
force_recompute: bool = False,
n_gen: int = 10000,
concept_subset: list | None = None,
label_descriptions: dict | None = None,
autoencoder_kwargs: dict | None = None,
workers: int = 0,
**kwargs
):
dataset = BnLearnDataset(
name=name,
root=root,
seed=seed,
n_gen=n_gen,
concept_subset=concept_subset,
label_descriptions=label_descriptions,
autoencoder_kwargs=autoencoder_kwargs
)
super().__init__(
dataset=dataset,
val_size=val_size,
test_size=test_size,
batch_size=batch_size,
backbone=backbone,
precompute_embs=precompute_embs,
force_recompute=force_recompute,
workers=workers
)