torch_concepts.data.datasets.toy.ToyDataset
- class ToyDataset(dataset: str, root: str | None = None, seed: int = 42, n_gen: int = 10000, concept_subset: list | None = None)
Synthetic datasets for concept-based learning experiments.
This class provides several toy datasets with known ground-truth concept relationships and causal structures. Each dataset includes input features, binary concepts, tasks, and a directed acyclic graph (DAG) representing concept-to-task relationships.
Available Datasets
xor: Simple XOR dataset with 2 input features, 2 concepts (C1, C2), and 1 task (xor). The task is the XOR of the two concepts.
trigonometry: Dataset with 7 trigonometric input features derived from 3 hidden variables, 3 concepts (C1, C2, C3) representing the signs of the hidden variables, and 1 task (sumGreaterThan1).
dot: Dataset with 4 input features, 2 concepts based on dot products (dotV1V2GreaterThan0, dotV3V4GreaterThan0), and 1 task (dotV1V3GreaterThan0).
checkmark: Dataset with 4 input features and 4 concepts (A, B, C, D), where C = NOT B and D = A AND C, demonstrating causal relationships.
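The logical relationships stated above can be checked with a small standalone sketch. It does not call ToyDataset: sampling A and B as independent coin flips and the helper names `xor_task` and `checkmark_concepts` are assumptions made for illustration, not the library's generation code.

```python
import random


def xor_task(c1: int, c2: int) -> int:
    """Task label for the xor dataset: XOR of the two binary concepts."""
    return c1 ^ c2


def checkmark_concepts(a: int, b: int) -> dict:
    """Derive C and D from A and B using C = NOT B and D = A AND C."""
    c = 1 - b      # C = NOT B
    d = a & c      # D = A AND C
    return {"A": a, "B": b, "C": c, "D": d}


# Assumption: A and B are drawn independently; the real sampling may differ.
rng = random.Random(42)
rows = [checkmark_concepts(rng.randint(0, 1), rng.randint(0, 1)) for _ in range(8)]
for row in rows:
    print(row)

# Full truth table of the xor task over both concepts.
xor_table = {(c1, c2): xor_task(c1, c2) for c1 in (0, 1) for c2 in (0, 1)}
print(xor_table)
```

Because C and D are deterministic functions of A and B, every generated row satisfies both identities regardless of the seed.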
- param dataset:
Name of the toy dataset to load. Must be one of: ‘xor’, ‘trigonometry’, ‘dot’, or ‘checkmark’.
- type dataset:
str
- param root:
Root directory to store/load the dataset files. If None, defaults to ‘./data/toy_datasets/{dataset_name}’. Default: None
- type root:
str, optional
- param seed:
Random seed for reproducible data generation. Default: 42
- type seed:
int, optional
- param n_gen:
Number of samples to generate. Default: 10000
- type n_gen:
int, optional
- param concept_subset:
Subset of concept names to use. If provided, only the specified concepts will be included in the dataset. Default: None (use all concepts)
- type concept_subset:
list of str, optional
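One way to picture the effect of concept_subset is as a column selection over named concepts. The sketch below is a plain-Python illustration using lists instead of tensors; `select_concepts` is a hypothetical helper, and raising an error for unknown names is an assumption rather than documented ToyDataset behavior.

```python
from typing import List, Optional, Sequence, Tuple


def select_concepts(
    names: Sequence[str],
    rows: Sequence[Sequence[int]],
    subset: Optional[Sequence[str]] = None,
) -> Tuple[List[str], List[List[int]]]:
    """Keep only the columns whose names appear in `subset` (all if None)."""
    if subset is None:
        return list(names), [list(r) for r in rows]
    missing = [n for n in subset if n not in names]
    if missing:
        # Assumption: unknown names are treated as an error in this sketch.
        raise ValueError(f"unknown concepts: {missing}")
    idx = [names.index(n) for n in subset]
    return list(subset), [[r[i] for i in idx] for r in rows]


names = ["A", "B", "C", "D"]
rows = [[1, 0, 1, 1], [0, 1, 0, 0]]
kept_names, kept_rows = select_concepts(names, rows, subset=["A", "D"])
print(kept_names, kept_rows)
```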
- input_data
Input features tensor of shape (n_samples, n_features).
- Type:
- concepts
Concepts and tasks tensor of shape (n_samples, n_concepts + n_tasks). Note: This includes both concepts and tasks concatenated.
- Type:
- annotations
Metadata about concept names, cardinalities, and types.
- Type:
- graph
Directed acyclic graph representing concept-to-task relationships. Stored as an adjacency matrix with concept/task names as indices.
- Type:
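For the checkmark dataset, the relationships C = NOT B and D = A AND C suggest the edges B -> C, A -> D, and C -> D. The sketch below writes these as a name-indexed adjacency matrix using nested dicts and checks that the result is acyclic. The nested-dict layout, the edge list, and the `topological_order` helper are illustrative assumptions, not the object stored in the graph attribute.

```python
from collections import deque

names = ["A", "B", "C", "D"]
edges = [("B", "C"), ("A", "D"), ("C", "D")]  # assumed from C = NOT B, D = A AND C

# adj[src][dst] == 1 means there is an edge src -> dst.
adj = {src: {dst: 0 for dst in names} for src in names}
for src, dst in edges:
    adj[src][dst] = 1


def topological_order(adj: dict) -> list:
    """Return a topological order (Kahn's algorithm); raise if a cycle exists."""
    indegree = {n: 0 for n in adj}
    for src in adj:
        for dst, flag in adj[src].items():
            indegree[dst] += flag
    queue = deque(n for n in adj if indegree[n] == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for dst, flag in adj[node].items():
            if flag:
                indegree[dst] -= 1
                if indegree[dst] == 0:
                    queue.append(dst)
    if len(order) != len(adj):
        raise ValueError("graph contains a cycle")
    return order


order = topological_order(adj)
print(order)  # ['A', 'B', 'C', 'D']
```

A valid topological order exists only for acyclic graphs, so the check doubles as a DAG test.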
Examples
Basic usage with XOR dataset:
>>> from torch_concepts.data.datasets import ToyDataset
>>>
>>> # Create XOR dataset with 1000 samples
>>> dataset = ToyDataset(dataset='xor', seed=42, n_gen=1000)
>>> print(f"Dataset size: {len(dataset)}")
>>> print(f"Input features: {dataset.n_features}")
>>> print(f"Concepts: {dataset.concept_names}")
>>>
>>> # Access a single sample
>>> sample = dataset[0]
>>> x = sample['inputs']['x']  # input features
>>> c = sample['concepts']['c']  # concepts and task
>>>
>>> # Get concept graph
>>> print(dataset.graph)
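Because the concepts tensor concatenates concepts and tasks, downstream code typically needs to separate the two. A minimal sketch follows, assuming the task columns come last: the documented shape (n_samples, n_concepts + n_tasks) suggests this ordering but does not state it. Plain lists stand in for tensors, and `split_concepts_and_tasks` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def split_concepts_and_tasks(
    row: Sequence[int], n_tasks: int = 1
) -> Tuple[List[int], List[int]]:
    """Split one concatenated row into (concepts, tasks).

    Assumption: the last `n_tasks` entries are the task labels.
    """
    if not 0 < n_tasks < len(row):
        raise ValueError("n_tasks must be between 1 and len(row) - 1")
    return list(row[:-n_tasks]), list(row[-n_tasks:])


# Hypothetical xor row laid out as [C1, C2, xor].
row = [1, 0, 1]
concepts, tasks = split_concepts_and_tasks(row)
print(concepts, tasks)
```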
See also
CompletenessDataset: Synthetic dataset for concept completeness experiments.
- __init__(dataset: str, root: str | None = None, seed: int = 42, n_gen: int = 10000, concept_subset: list | None = None)
Methods
__init__(dataset[, root, seed, n_gen, ...])
add_exogenous(name, value[, convert_precision])
add_scaler(key, scaler): Add a scaler for preprocessing a specific tensor.
build(): Generate synthetic data and save to disk.
download(): No download needed for toy datasets.
load(): Load the dataset (wraps load_raw).
load_raw(): Load the generated dataset from disk.
maybe_build()
maybe_download()
maybe_reduce_annotations(annotations[, ...]): Set concept and labels for the dataset. annotations is the Annotations object for all concepts; concept_names_subset is a list of strings naming the subset of concepts to use (if None, all concepts are used).
remove_exogenous(name)
set_concepts(concepts): Set concept annotations for the dataset.
set_graph(graph): Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.
Attributes
annotations: Annotations for the concepts in the dataset.
concept_names: List of concept names in the dataset.
exogenous: Mapping of dataset's exogenous variables.
graph: Adjacency matrix of the causal graph between concepts.
has_concepts: Whether the dataset has concept annotations.
has_exogenous: Whether the dataset has exogenous information.
Number of concepts in the dataset.
n_exogenous: Number of exogenous variables in the dataset.
n_features: Shape of features in dataset's input (excluding number of samples).
n_samples: Number of samples in the dataset.
List of processed filenames that will be created during build step.
processed_paths: The absolute paths of the processed files that must be present in order to skip building.
No raw files needed - data is generated.
raw_paths: The absolute paths of the raw files that must be present in order to skip downloading.
root_dir
shape: Shape of the input tensor.
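
The processed_paths and raw_paths descriptions imply a cache check: building or downloading is skipped when the required files already exist. A minimal sketch of that pattern follows. The `build_if_missing` function, the file name `data.txt`, and the build callback are hypothetical stand-ins and do not reproduce the library's implementation.

```python
import tempfile
from pathlib import Path
from typing import Callable, Sequence


def build_if_missing(processed_paths: Sequence[Path], build: Callable[[], None]) -> bool:
    """Run `build` only when some processed file is absent.

    Returns True if a build was triggered, False if it was skipped.
    """
    if all(p.exists() for p in processed_paths):
        return False
    build()
    return True


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "data.txt"  # hypothetical processed file

    def build() -> None:
        # Stand-in for synthetic data generation.
        target.write_text("generated")

    first = build_if_missing([target], build)   # file absent, so build runs
    second = build_if_missing([target], build)  # file present, so build is skipped

print(first, second)
```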