Contributing a New Dataset
This guide explains how to add a new dataset to PyC.
Every dataset in PyC follows the same four-method contract: download fetches raw
files, build processes them into tensors, load_raw reads the processed files
from disk, and load adds any final preprocessing before handing data to the rest
of the library. The caching layer in the base class means each step runs only once
unless the on-disk files are missing.
Dataset Class
Inheritance
Every dataset extends ConceptDataset, which
is found in torch_concepts.data.base:
ConceptDataset (torch.utils.data.Dataset)
└── YourDataset
The base class owns input_data, concepts, annotations, and graph.
Your job is to load these in your __init__ and pass them to super().__init__
as its last step.
The four abstract methods
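| Method / property | What it must do |
|---|---|
| raw_filenames | Return a list of raw file names (empty if nothing needs to be downloaded). |
| processed_filenames | Return a list of processed file names written by build(). |
| download() | Fetch raw files into self.root_dir. |
| build() | Read/generate data, create Annotations, and save the processed files. |
| load_raw() | Call maybe_build(), then read the processed files from disk. |
| load() | Call load_raw(), then apply any final preprocessing. |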
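The caching pipeline is: load() → load_raw() → maybe_build() →
build() → maybe_download() → download(). You never call build() or download()
directly: your load_raw() calls maybe_build(), and the base-class helpers decide
whether the underlying step still needs to run.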
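To make the build-once behaviour concrete, here is a minimal, standard-library-only sketch of the same idea. It is a toy stand-in, not the PyC base class: the name ToyCachedDataset, the text-file format, and the exact place where the raw-file check happens are all assumptions made for illustration.

```python
import os
import tempfile


class ToyCachedDataset:
    """Toy illustration of build-once caching; not the PyC base class."""

    raw_filenames = ["raw.txt"]
    processed_filenames = ["processed.txt"]

    def __init__(self, root):
        self.root = root
        self.calls = []  # records which expensive steps actually ran
        self.data = self.load()

    def _paths(self, names):
        return [os.path.join(self.root, name) for name in names]

    def _all_exist(self, names):
        return all(os.path.exists(p) for p in self._paths(names))

    def maybe_download(self):
        # Skip the download when every raw file is already on disk.
        if not self._all_exist(self.raw_filenames):
            self.download()

    def maybe_build(self):
        # Skip the build when every processed file is already on disk.
        if not self._all_exist(self.processed_filenames):
            self.maybe_download()  # assumption: raw files are ensured here
            self.build()

    def download(self):
        self.calls.append("download")
        os.makedirs(self.root, exist_ok=True)
        with open(self._paths(self.raw_filenames)[0], "w") as f:
            f.write("1,2,3")

    def build(self):
        self.calls.append("build")
        with open(self._paths(self.raw_filenames)[0]) as f:
            values = [int(v) * 10 for v in f.read().split(",")]
        with open(self._paths(self.processed_filenames)[0], "w") as f:
            f.write(",".join(str(v) for v in values))

    def load_raw(self):
        self.maybe_build()
        with open(self._paths(self.processed_filenames)[0]) as f:
            return [int(v) for v in f.read().split(",")]

    def load(self):
        return self.load_raw()


with tempfile.TemporaryDirectory() as tmp:
    first = ToyCachedDataset(tmp)   # cold cache: downloads and builds
    second = ToyCachedDataset(tmp)  # warm cache: only reads the processed file
    first_calls, second_calls = first.calls, second.calls
    second_data = second.data
```

The second instance records no calls at all, which is the property the real pipeline relies on: deleting the processed files is the way to force a rebuild.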
Complete worked example
The dataset below is a minimal but complete example. It generates synthetic data with three concept types (binary, categorical, continuous), a downstream task, and a causal graph, and can serve as a copy-paste starting point.
# torch_concepts/data/datasets/my_dataset.py
import os
import logging
import pandas as pd
import torch
from typing import List, Optional
from torch_concepts.annotations import Annotations
from torch_concepts.data.base import ConceptDataset
logger = logging.getLogger(__name__)
class MyDataset(ConceptDataset):
    """Synthetic dataset with mixed concept types.

    Parameters
    ----------
    root : str, optional
        Directory where processed files are stored. Defaults to
        ``./data/my_dataset`` in the current working directory.
    seed : int, default 42
        Controls data generation. Also baked into processed file names so
        that different seeds produce independent caches on disk.
    n_gen : int, default 5000
        Number of samples to generate.
    concept_subset : list of str, optional
        If provided, only this subset of concepts is kept after loading.
    """

    def __init__(
        self,
        root: Optional[str] = None,
        seed: int = 42,
        n_gen: int = 5000,
        concept_subset: Optional[List[str]] = None,
    ):
        self.seed = seed
        self.n_gen = n_gen
        if root is None:
            root = os.path.join(os.getcwd(), "data", "my_dataset")
        self.root = root
        input_data, concepts, annotations, graph = self.load()
        super().__init__(
            input_data=input_data,
            concepts=concepts,
            annotations=annotations,
            graph=graph,
            concept_names_subset=concept_subset,
            name="MyDataset",
        )
    # ------------------------------------------------------------------
    # File lists
    # ------------------------------------------------------------------
    @property
    def raw_filenames(self) -> List[str]:
        # Nothing to download: data is generated programmatically.
        return []

    @property
    def processed_filenames(self) -> List[str]:
        # Encode seed and n_gen so different runs have independent caches.
        return [
            f"inputs_N_{self.n_gen}_seed_{self.seed}.pt",
            f"concepts_N_{self.n_gen}_seed_{self.seed}.h5",
            "annotations.pt",
            "graph.h5",
        ]

    # ------------------------------------------------------------------
    # Download (nothing to do for synthetic data)
    # ------------------------------------------------------------------
    def download(self):
        pass  # no remote files
    # ------------------------------------------------------------------
    # Build: generate data, create Annotations, save everything
    # ------------------------------------------------------------------
    def build(self):
        logger.info(f"Generating MyDataset (n={self.n_gen}, seed={self.seed})")
        torch.manual_seed(self.seed)

        # --- generate raw tensors -----------------------------------------
        n = self.n_gen
        # binary concept: 0 or 1
        smoker = torch.bernoulli(torch.full((n,), 0.4))
        # categorical concept: 3 genotype states, stored as a class index (0-2)
        genotype = torch.multinomial(
            torch.tensor([0.5, 0.3, 0.2]).expand(n, -1),
            num_samples=1,
        ).squeeze(1).float()
        # continuous concept: tar level
        tar = torch.randn(n) * 0.5 + smoker * 2.0
        # binary task: cancer risk
        logit = smoker * 1.5 + tar * 0.8 + (genotype == 2).float() * 1.0
        cancer = torch.bernoulli(torch.sigmoid(logit))
        # input features: the three concept values plus three pure-noise columns
        inputs = torch.stack(
            [smoker, genotype, tar,
             torch.randn(n), torch.randn(n), torch.randn(n)],
            dim=1,
        )
        # concepts table: one column per concept/task
        concepts_df = pd.DataFrame({
            "smoker": smoker.numpy(),
            "genotype": genotype.numpy(),
            "tar": tar.numpy(),
            "cancer": cancer.numpy(),
        })

        # --- build Annotations --------------------------------------------
        # cardinality: 1 = binary or scalar continuous,
        #              K = K-class categorical
        concept_names = ["smoker", "genotype", "tar", "cancer"]
        cardinalities = [1, 3, 1, 1]
        types = ["binary", "categorical", "continuous", "binary"]
        # states: human-readable labels for each state of each concept
        # (None for continuous concepts)
        states = [
            ["non-smoker", "smoker"],     # smoker
            ["wild-type", "het", "hom"],  # genotype
            None,                         # tar (continuous)
            ["no cancer", "cancer"],      # cancer
        ]
        # per-concept metadata (optional free-form dict)
        metadata = {
            "smoker": {"source": "self-report"},
            "genotype": {"source": "WGS"},
            "tar": {"unit": "mg/cigarette"},
            "cancer": {"icd10": "C34"},
        }
        annotations = Annotations(
            labels=concept_names,
            cardinalities=cardinalities,
            types=types,
            states=states,
            metadata=metadata,
        )

        # --- causal graph -------------------------------------------------
        # rows = causes, columns = effects; 1 means "causes"
        graph = pd.DataFrame(
            [[0, 0, 1, 1],   # smoker   -> tar, cancer
             [0, 0, 0, 1],   # genotype -> cancer
             [0, 0, 0, 1],   # tar      -> cancer
             [0, 0, 0, 0]],  # cancer (sink)
            index=concept_names,
            columns=concept_names,
        )

        # --- save all four components -------------------------------------
        os.makedirs(self.root_dir, exist_ok=True)
        torch.save(inputs, self.processed_paths[0])
        concepts_df.to_hdf(self.processed_paths[1], key="concepts", mode="w")
        torch.save(annotations, self.processed_paths[2])
        graph.to_hdf(self.processed_paths[3], key="graph", mode="w")
    # ------------------------------------------------------------------
    # Load
    # ------------------------------------------------------------------
    def load_raw(self):
        self.maybe_build()
        logger.info(f"Loading MyDataset from {self.root_dir}")
        inputs = torch.load(self.processed_paths[0], weights_only=False)
        concepts = pd.read_hdf(self.processed_paths[1], "concepts")
        annotations = torch.load(self.processed_paths[2], weights_only=False)
        graph = pd.read_hdf(self.processed_paths[3], "graph")
        return inputs, concepts, annotations, graph

    def load(self):
        # load_raw already handles all the work for this dataset.
        # Override here if you need to add preprocessing (e.g., normalization,
        # autoencoder embedding extraction) on top of the raw tensors.
        return self.load_raw()
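The graph in build() uses the convention rows = causes, columns = effects. A quick way to check a hand-written adjacency matrix is to list the edges it encodes and confirm there are no cycles. The sketch below does this with plain Python lists. The helper names edges and is_acyclic are hypothetical and not part of PyC, and the acyclicity check assumes you intend the graph to be a DAG.

```python
from collections import deque

concept_names = ["smoker", "genotype", "tar", "cancer"]
adjacency = [
    [0, 0, 1, 1],  # smoker   -> tar, cancer
    [0, 0, 0, 1],  # genotype -> cancer
    [0, 0, 0, 1],  # tar      -> cancer
    [0, 0, 0, 0],  # cancer (sink)
]


def edges(names, matrix):
    """List (cause, effect) pairs; rows are causes, columns are effects."""
    return [
        (names[i], names[j])
        for i, row in enumerate(matrix)
        for j, value in enumerate(row)
        if value
    ]


def is_acyclic(matrix):
    """Kahn's algorithm: the graph is a DAG iff every node can be removed."""
    n = len(matrix)
    in_degree = [sum(matrix[i][j] for i in range(n)) for j in range(n)]
    queue = deque(j for j in range(n) if in_degree[j] == 0)
    removed = 0
    while queue:
        i = queue.popleft()
        removed += 1
        for j in range(n):
            if matrix[i][j]:
                in_degree[j] -= 1
                if in_degree[j] == 0:
                    queue.append(j)
    return removed == n


edge_list = edges(concept_names, adjacency)
for cause, effect in edge_list:
    print(f"{cause} -> {effect}")
print("acyclic:", is_acyclic(adjacency))
```

If the printed edges do not match the comments next to each row, the matrix has most likely been written with rows and columns swapped.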
Notes on processed file names
For synthetic datasets the seed and sample count must be part of the file names
(e.g., inputs_N_5000_seed_42.pt). This ensures that changing seed or
n_gen triggers a fresh build() rather than silently re-using an old cache.
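The effect of this naming scheme can be checked without touching the disk. The sketch below copies the file-name pattern from MyDataset.processed_filenames into a free function (a hypothetical helper, used here only for illustration) and compares two configurations.

```python
from typing import List


def processed_filenames(n_gen: int, seed: int) -> List[str]:
    """Same naming scheme as MyDataset.processed_filenames above."""
    return [
        f"inputs_N_{n_gen}_seed_{seed}.pt",
        f"concepts_N_{n_gen}_seed_{seed}.h5",
        "annotations.pt",
        "graph.h5",
    ]


a = set(processed_filenames(n_gen=5000, seed=42))
b = set(processed_filenames(n_gen=5000, seed=7))

shared = sorted(a & b)    # files both configurations would reuse
distinct = sorted(a - b)  # files specific to seed=42
print(shared)    # ['annotations.pt', 'graph.h5']
print(distinct)  # ['concepts_N_5000_seed_42.h5', 'inputs_N_5000_seed_42.pt']
```

Sharing annotations.pt and graph.h5 across seeds is safe in this example only because neither depends on seed or n_gen. If your annotations or graph vary with a constructor argument, encode that argument in their file names as well.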
Downloading a remote file
If your dataset comes from a URL, use download_url from
torch_concepts.data.io inside download():
from torch_concepts.data.io import download_url

# Both definitions below go inside your dataset class.
def download(self):
    url = "https://example.com/my_data.csv.gz"
    download_url(url, self.root_dir)  # saves to self.root_dir/<filename>

@property
def raw_filenames(self) -> List[str]:
    return ["my_data.csv.gz"]
maybe_download() (called automatically by build()) skips the download if
all paths in self.raw_paths already exist on disk.
Verifying the dataset
After writing the class, check it interactively before registering it:
from torch_concepts.data.datasets.my_dataset import MyDataset
ds = MyDataset(seed=0, n_gen=200)
print(ds)
# MyDataset(n_samples=200, n_features=(6,), n_concepts=4)
print(ds.concept_names)
# ['smoker', 'genotype', 'tar', 'cancer']
print(ds.annotations.types)
# ('binary', 'categorical', 'continuous', 'binary')
sample = ds[0]
print(sample["inputs"]["x"].shape) # torch.Size([6])
print(sample["concepts"]["c"].shape) # torch.Size([4])
print(ds.graph)
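Before running the dataset end to end, it can also help to check that the four lists passed to Annotations agree with each other. The sketch below encodes the conventions used in the worked example (cardinality 1 with two state labels for binary, K labels for a K-class categorical, states=None for continuous). check_annotations is a hypothetical helper, not a PyC function, and PyC may enforce different or additional rules.

```python
def check_annotations(names, cardinalities, types, states):
    """Return a list of human-readable problems (empty when consistent)."""
    if not (len(names) == len(cardinalities) == len(types) == len(states)):
        return ["names, cardinalities, types and states differ in length"]
    problems = []
    for name, card, kind, labels in zip(names, cardinalities, types, states):
        if kind == "continuous":
            if labels is not None:
                problems.append(f"{name}: continuous concepts take states=None")
        elif kind == "binary":
            if card != 1 or labels is None or len(labels) != 2:
                problems.append(f"{name}: binary needs cardinality 1 and 2 states")
        elif kind == "categorical":
            if labels is None or len(labels) != card:
                problems.append(f"{name}: expected {card} state labels")
        else:
            problems.append(f"{name}: unknown type {kind!r}")
    return problems


concept_names = ["smoker", "genotype", "tar", "cancer"]
cardinalities = [1, 3, 1, 1]
types = ["binary", "categorical", "continuous", "binary"]
states = [
    ["non-smoker", "smoker"],
    ["wild-type", "het", "hom"],
    None,
    ["no cancer", "cancer"],
]

problems = check_annotations(concept_names, cardinalities, types, states)
print(problems or "annotations look consistent")
```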
DataModule
A DataModule wraps a dataset and handles splitting, batching, and dataloaders for
PyTorch Lightning (and plain PyTorch). Extend
ConceptDataModule from
torch_concepts.data.base.datamodule.
Your __init__ only needs to instantiate the dataset and call
super().__init__(dataset=dataset, ...). Everything else — train_dataloader,
val_dataloader, test_dataloader, split logic — is inherited.
# torch_concepts/data/datamodules/my_dataset.py
from typing import Optional

from ..base.datamodule import ConceptDataModule
from ..datasets.my_dataset import MyDataset


class MyDataModule(ConceptDataModule):
    """DataModule for MyDataset.

    Parameters
    ----------
    seed : int
        Random seed for the train/val/test split (independent of the
        generation seed passed to the dataset).
    root : str, optional
        Directory forwarded to :class:`MyDataset`.
    generation_seed : int, default 42
        Seed forwarded to :class:`MyDataset` for data generation.
    n_gen : int, default 5000
        Number of samples forwarded to :class:`MyDataset`.
    val_size : float or int, default 0.1
        Fraction (float) or absolute count (int) for the validation split.
    test_size : float or int, default 0.2
        Fraction (float) or absolute count (int) for the test split.
    batch_size : int, default 256
        Batch size for all dataloaders.
    workers : int, default 0
        Number of dataloader worker processes.
    """

    def __init__(
        self,
        seed: int,
        root: Optional[str] = None,
        generation_seed: int = 42,
        n_gen: int = 5000,
        val_size: float = 0.1,
        test_size: float = 0.2,
        batch_size: int = 256,
        workers: int = 0,
        **kwargs,  # note: unused in this minimal example
    ):
        dataset = MyDataset(
            root=root,
            seed=generation_seed,
            n_gen=n_gen,
        )
        super().__init__(
            dataset=dataset,
            val_size=val_size,
            test_size=test_size,
            batch_size=batch_size,
            workers=workers,
            seed=seed,
        )
Using the DataModule
dm = MyDataModule(seed=1, generation_seed=42, n_gen=5000, batch_size=128)
dm.setup()
for batch in dm.train_dataloader():
    x = batch["inputs"]["x"]    # (128, 6)
    c = batch["concepts"]["c"]  # (128, 4)
    break
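The val_size and test_size arguments accept either a fraction or an absolute count, and the split has its own seed. The sketch below shows one plausible way such arguments can be resolved into index lists. It is not ConceptDataModule's implementation: the helper names, the rounding rule, and the use of random.Random are assumptions made for illustration.

```python
import random


def resolve_size(size, n_total):
    """Interpret a float as a fraction of n_total and an int as a count."""
    if isinstance(size, float):
        return int(round(size * n_total))
    return int(size)


def split_indices(n_total, val_size, test_size, seed):
    """Return (train, val, test) index lists from a seeded shuffle."""
    n_val = resolve_size(val_size, n_total)
    n_test = resolve_size(test_size, n_total)
    if n_val + n_test >= n_total:
        raise ValueError("validation and test splits leave no training data")
    indices = list(range(n_total))
    random.Random(seed).shuffle(indices)  # split seed, not the generation seed
    val = indices[:n_val]
    test = indices[n_val:n_val + n_test]
    train = indices[n_val + n_test:]
    return train, val, test


train, val, test = split_indices(5000, val_size=0.1, test_size=0.2, seed=1)
print(len(train), len(val), len(test))  # 3500 500 1000
```

Keeping the split seed separate from the generation seed means the split can be changed without regenerating, and therefore re-caching, the data.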
Registering the Dataset
Once the dataset and DataModule work locally, complete the following four steps.
1. Module files
Place the dataset class in torch_concepts/data/datasets/my_dataset.py and the
DataModule in torch_concepts/data/datamodules/my_dataset.py.
2. Public API exports
Add imports and __all__ entries to torch_concepts/data/__init__.py:
# in torch_concepts/data/__init__.py
from .datasets.my_dataset import MyDataset
from .datamodules.my_dataset import MyDataModule
__all__ = [
    ...
    "MyDataset",
    "MyDataModule",
]
After this, users can do from torch_concepts.data import MyDataset.
3. API reference page (optional but recommended)
Add autoclass directives to doc/modules/data_api.rst so the docstrings
appear in the rendered documentation:
.. autoclass:: torch_concepts.data.MyDataset
   :members:
   :undoc-members:
   :show-inheritance:

.. autoclass:: torch_concepts.data.MyDataModule
   :members:
   :undoc-members:
   :show-inheritance:
4. Tests
Add a test in tests/ that instantiates the dataset, checks the output shapes,
and verifies the annotations. Mirror the existing tests in tests/test_data.py
for the expected structure.
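As a starting point, the checks described above can be collected in one helper that works on any dataset-like object. The sketch below is hypothetical and does not mirror the layout of tests/test_data.py. It assumes the dataset supports len() and returns samples with the ["inputs"]["x"] and ["concepts"]["c"] keys shown earlier, and it uses a stand-in FakeDataset so that it runs without PyC installed.

```python
from types import SimpleNamespace


def check_dataset(ds, n_samples, n_features, concept_names, types):
    """Assertions a dataset test might make on any dataset-like object."""
    assert len(ds) == n_samples
    assert list(ds.concept_names) == list(concept_names)
    assert tuple(ds.annotations.types) == tuple(types)
    sample = ds[0]
    assert len(sample["inputs"]["x"]) == n_features
    assert len(sample["concepts"]["c"]) == len(concept_names)


class FakeDataset:
    """Stand-in with the same surface as MyDataset(seed=0, n_gen=200)."""

    concept_names = ["smoker", "genotype", "tar", "cancer"]
    annotations = SimpleNamespace(
        types=("binary", "categorical", "continuous", "binary")
    )

    def __len__(self):
        return 200

    def __getitem__(self, index):
        return {"inputs": {"x": [0.0] * 6}, "concepts": {"c": [0.0] * 4}}


expected_names = ["smoker", "genotype", "tar", "cancer"]
expected_types = ("binary", "categorical", "continuous", "binary")
check_dataset(FakeDataset(), 200, 6, expected_names, expected_types)
```

In a real test you would replace FakeDataset() with MyDataset(root=..., seed=0, n_gen=200), pointing root at a temporary directory (for example pytest's tmp_path fixture) so the test does not write into the working tree.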
Next Steps
- Read the ConceptDataset API reference for the full list of inherited properties (n_samples, n_features, concept_names, graph, etc.).
- Explore existing datasets in torch_concepts/data/datasets/ to see how different data sources (remote files, bnlearn graphs, real image datasets) are handled.
- See Contributing for the full pull-request workflow.