Source code for torch_concepts.data.datamodules.completeness
import os
from ..datasets.toy import CompletenessDataset
from ..base.datamodule import ConceptDataModule
from ...typing import BackboneType
[docs]
class CompletenessDataModule(ConceptDataModule):
    """Data module wrapping the synthetic ``CompletenessDataset``.

    The constructor builds a :class:`CompletenessDataset` from the given
    generation parameters. It then hands that dataset to
    :class:`ConceptDataModule` together with the split, batching, backbone
    and worker settings.

    The two seeds are routed to different places:

    * ``generation_seed`` is passed to the dataset as its ``seed``.
    * ``seed`` is passed to the base data module.

    The generated data and the base module's randomness (presumably the
    train/val/test split -- confirm in ``ConceptDataModule``) can therefore
    be controlled independently.
    """

    def __init__(
        self,
        name: str,
        root: str,
        seed: int = 42,
        generation_seed: int = 42,
        p: int = 2,
        n_views: int = 10,
        n_concepts: int = 2,
        n_hidden_concepts: int = 0,
        n_tasks: int = 1,
        n_gen: int = 10000,
        val_size: int | float = 0.1,
        test_size: int | float = 0.2,
        batch_size: int = 512,
        backbone: BackboneType = None,
        precompute_embs: bool = False,
        force_recompute: bool = False,
        concept_subset: list | None = None,
        label_descriptions: dict | None = None,
        workers: int = 0,
        **kwargs,
    ) -> None:
        """Generate the completeness dataset and initialise the data module.

        Args:
            name: Identifier forwarded to ``CompletenessDataset`` as ``name``.
                NOTE(review): the original comment called this "name of the
                bnlearn DAG", which looks copy-pasted from another module;
                confirm its meaning in ``CompletenessDataset``.
            root: Directory forwarded to ``CompletenessDataset`` as ``root``.
            seed: Seed forwarded to ``ConceptDataModule`` (not used for
                data generation).
            generation_seed: Seed forwarded to ``CompletenessDataset`` as
                its ``seed``, i.e. the seed used when generating the data.
            p: Dimensionality of each view.
            n_views: Number of views.
            n_concepts: Number of concepts.
            n_hidden_concepts: Number of hidden concepts.
            n_tasks: Number of tasks.
            n_gen: Forwarded to ``CompletenessDataset``; presumably the
                number of generated samples -- confirm.
            val_size: Validation split size forwarded to the base module
                (presumably a fraction when float, a count when int).
            test_size: Test split size forwarded to the base module
                (presumably a fraction when float, a count when int).
            batch_size: Batch size forwarded to the base module.
            backbone: Backbone forwarded to the base module.
            precompute_embs: Forwarded to the base module.
            force_recompute: Forwarded to the base module.
            concept_subset: Optional concept subset forwarded to
                ``CompletenessDataset``.
            label_descriptions: Accepted but currently NOT used by this
                constructor.
            workers: Number of workers forwarded to the base module.
            **kwargs: Accepted but ignored.
        """
        # NOTE(review): `label_descriptions` and `**kwargs` are not forwarded
        # anywhere. They are kept in the signature so existing callers that
        # pass them keep working. Confirm whether `label_descriptions` should
        # reach the dataset or the base module.

        # Build the dataset with its own generation seed, kept separate from
        # the `seed` handed to the base module below.
        dataset = CompletenessDataset(
            name=name,
            root=root,
            seed=generation_seed,
            p=p,
            n_views=n_views,
            n_concepts=n_concepts,
            n_hidden_concepts=n_hidden_concepts,
            n_tasks=n_tasks,
            n_gen=n_gen,
            concept_subset=concept_subset,
        )

        super().__init__(
            dataset=dataset,
            val_size=val_size,
            test_size=test_size,
            batch_size=batch_size,
            backbone=backbone,
            precompute_embs=precompute_embs,
            force_recompute=force_recompute,
            workers=workers,
            seed=seed,
        )