torch_concepts.data.datasets.toy.ToyDataset

class ToyDataset(dataset: str, root: str | None = None, seed: int = 42, n_gen: int = 10000, concept_subset: list | None = None)

Synthetic datasets for concept-based learning experiments.

This class provides several toy datasets with known ground-truth concept relationships and causal structures. Each dataset includes input features, binary concepts, tasks, and a directed acyclic graph (DAG) representing concept-to-task relationships.

Available Datasets

  • xor: Simple XOR dataset with 2 input features, 2 concepts (C1, C2), and 1 task (xor). The task is the XOR of the two concepts.

  • trigonometry: Dataset with 7 trigonometric input features derived from 3 hidden variables, 3 concepts (C1, C2, C3) representing the signs of the hidden variables, and 1 task (sumGreaterThan1).

  • dot: Dataset with 4 input features, 2 concepts based on dot products (dotV1V2GreaterThan0, dotV3V4GreaterThan0), and 1 task (dotV1V3GreaterThan0).

  • checkmark: Dataset with 4 input features and 4 concepts (A, B, C, D), where C = NOT B and D = A AND C, demonstrating causal relationships.
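The logical relationships listed above can be made concrete with a small sketch in plain Python. It is illustrative only: make_xor and checkmark_concepts are hypothetical helpers, the rule that each XOR concept is its input feature thresholded at 0.5 is an assumption, and the sampling of A and B for checkmark is omitted. Only the rules themselves (task = C1 XOR C2; C = NOT B; D = A AND C) come from the descriptions above.

```python
import random
from typing import List, Tuple

Sample = Tuple[Tuple[float, float], Tuple[int, int, int]]

def make_xor(n_gen: int, seed: int = 42) -> List[Sample]:
    """Hypothetical XOR generator: 2 features, 2 concepts, 1 task."""
    rng = random.Random(seed)  # local RNG: the seed alone controls reproducibility
    rows = []
    for _ in range(n_gen):
        x1, x2 = rng.random(), rng.random()
        # Assumption: each concept is its input feature thresholded at 0.5
        c1, c2 = int(x1 > 0.5), int(x2 > 0.5)
        y = c1 ^ c2  # task: XOR of the two concepts
        rows.append(((x1, x2), (c1, c2, y)))
    return rows

def checkmark_concepts(a: int, b: int) -> Tuple[int, int, int, int]:
    """Derive C and D from A and B using C = NOT B and D = A AND C."""
    c = 1 - b
    d = a & c
    return a, b, c, d

print(make_xor(n_gen=3, seed=42)[0])
print(checkmark_concepts(1, 0))  # (1, 0, 1, 1)
```

Using a local random.Random(seed) instance keeps the output fixed for a given seed and n_gen without touching the global random state, which matches the intent of the seed parameter.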

param dataset:

Name of the toy dataset to load. Must be one of: ‘xor’, ‘trigonometry’, ‘dot’, or ‘checkmark’.

type dataset:

str

param root:

Root directory to store/load the dataset files. If None, defaults to ‘./data/toy_datasets/{dataset_name}’. Default: None

type root:

str, optional

param seed:

Random seed for reproducible data generation. Default: 42

type seed:

int, optional

param n_gen:

Number of samples to generate. Default: 10000

type n_gen:

int, optional

param concept_subset:

Subset of concept names to use. If provided, only the specified concepts will be included in the dataset. Default: None (use all concepts)

type concept_subset:

list of str, optional
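A minimal sketch of how the parameters above could be handled, using plain Python lists in place of tensors. resolve_root and select_concepts are hypothetical helpers, not ToyDataset methods; the error type raised for an unknown name and the column order after subsetting (here, the order given in concept_subset) are assumptions.

```python
import os
from typing import List, Optional, Sequence, Tuple

# Names accepted by the `dataset` parameter, as documented above.
VALID_DATASETS = ("xor", "trigonometry", "dot", "checkmark")

def resolve_root(dataset: str, root: Optional[str] = None) -> str:
    """Hypothetical helper: validate `dataset` and apply the documented `root` default."""
    if dataset not in VALID_DATASETS:
        raise ValueError(f"dataset must be one of {VALID_DATASETS}, got {dataset!r}")
    if root is None:
        # Documented default: ./data/toy_datasets/{dataset_name}
        return os.path.join(".", "data", "toy_datasets", dataset)
    return root

def select_concepts(
    concept_names: Sequence[str],
    rows: List[List[int]],
    concept_subset: Optional[Sequence[str]] = None,
) -> Tuple[List[str], List[List[int]]]:
    """Hypothetical helper: keep only the columns named in `concept_subset`."""
    if concept_subset is None:
        # None means "use all concepts"
        return list(concept_names), [list(r) for r in rows]
    unknown = [n for n in concept_subset if n not in concept_names]
    if unknown:
        raise ValueError(f"unknown concept names: {unknown}")
    idx = [list(concept_names).index(n) for n in concept_subset]
    return list(concept_subset), [[r[i] for i in idx] for r in rows]

print(resolve_root("xor"))  # ./data/toy_datasets/xor (on POSIX)
print(select_concepts(["C1", "C2", "xor"], [[0, 1, 1], [1, 1, 0]], ["C2", "xor"]))
# (['C2', 'xor'], [[1, 1], [1, 0]])
```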

input_data

Input features tensor of shape (n_samples, n_features).

Type:

torch.Tensor

concepts

Tensor of concept and task labels, concatenated along the last dimension, with shape (n_samples, n_concepts + n_tasks).

Type:

torch.Tensor

annotations

Metadata about concept names, cardinalities, and types.

Type:

Annotations

graph

Directed acyclic graph of concept-to-task relationships, stored as an adjacency matrix whose rows and columns are labeled with the concept and task names.

Type:

pandas.DataFrame

concept_names

Names of all concepts and tasks in the dataset.

Type:

list of str

n_concepts

Total number of concepts and tasks combined.

Type:

int

n_features

Shape of the input features, excluding the sample dimension.

Type:

tuple or int
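Because graph is an adjacency matrix stored as a pandas DataFrame labeled by concept and task names, the parents of a node can be read off its column. The sketch below assumes that a nonzero graph.loc[src, dst] encodes an edge src -> dst (the orientation is an assumption), and it builds the matrix by hand from the checkmark rules rather than loading it from ToyDataset; parents and roots are hypothetical helpers.

```python
import pandas as pd

def parents(graph: pd.DataFrame, node: str) -> list:
    """Nodes with an edge into `node`, assuming graph.loc[src, dst] != 0 means src -> dst."""
    col = graph[node]
    return col[col != 0].index.tolist()

def roots(graph: pd.DataFrame) -> list:
    """Nodes without incoming edges (all-zero columns)."""
    in_degree = graph.sum(axis=0)
    return in_degree[in_degree == 0].index.tolist()

# Hand-built adjacency matrix for the edges implied by the checkmark rules
# C = NOT B and D = A AND C (hypothetical, not loaded from the library).
names = ["A", "B", "C", "D"]
graph = pd.DataFrame(0, index=names, columns=names)
graph.loc["B", "C"] = 1
graph.loc["A", "D"] = 1
graph.loc["C", "D"] = 1

print(parents(graph, "D"))  # ['A', 'C']
print(roots(graph))         # ['A', 'B']
```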

Examples

Basic usage with the XOR dataset:

>>> from torch_concepts.data.datasets import ToyDataset
>>>
>>> # Create XOR dataset with 1000 samples
>>> dataset = ToyDataset(dataset='xor', seed=42, n_gen=1000)
>>> print(f"Dataset size: {len(dataset)}")
>>> print(f"Input features: {dataset.n_features}")
>>> print(f"Concepts: {dataset.concept_names}")
>>>
>>> # Access a single sample
>>> sample = dataset[0]
>>> x = sample['inputs']['x']  # input features
>>> c = sample['concepts']['c']  # concepts and task
>>>
>>> # Get concept graph
>>> print(dataset.graph)
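Because c concatenates concept and task labels, it is often useful to separate the two. The helper below is a hypothetical sketch on plain lists; it assumes the task names are known (for XOR, the single task xor) and that the column order matches concept_names.

```python
from typing import List, Sequence, Set, Tuple

def split_concepts_and_tasks(
    row: Sequence[int],
    concept_names: Sequence[str],
    task_names: Set[str],
) -> Tuple[List[int], List[int]]:
    """Split one concatenated label vector into (concepts, tasks) by column name."""
    if len(row) != len(concept_names):
        raise ValueError("row length must match the number of names")
    concept_idx = [i for i, n in enumerate(concept_names) if n not in task_names]
    task_idx = [i for i, n in enumerate(concept_names) if n in task_names]
    return [row[i] for i in concept_idx], [row[i] for i in task_idx]

# Hypothetical XOR label vector: C1=1, C2=0, xor=1
print(split_concepts_and_tasks([1, 0, 1], ["C1", "C2", "xor"], {"xor"}))
# ([1, 0], [1])
```

With a torch.Tensor, the same index lists can be applied directly through advanced indexing, e.g. c[..., [0, 1]] for the two XOR concepts under the assumed ordering.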


See also

CompletenessDataset

Synthetic dataset for concept completeness experiments

__init__(dataset: str, root: str | None = None, seed: int = 42, n_gen: int = 10000, concept_subset: list | None = None)

Methods

__init__(dataset[, root, seed, n_gen, ...])

add_exogenous(name, value[, convert_precision])

add_scaler(key, scaler)

Add a scaler for preprocessing a specific tensor.

build()

Generate synthetic data and save to disk.

download()

No download needed for toy datasets.

load()

Load the dataset (wraps load_raw).

load_raw()

Load the generated dataset from disk.

maybe_build()

maybe_download()

maybe_reduce_annotations(annotations[, ...])

Set concepts and labels for the dataset. annotations: Annotations object for all concepts. concept_names_subset: list of strings naming the subset of concepts to use; if None, all concepts are used.

remove_exogenous(name)

set_concepts(concepts)

Set concept annotations for the dataset.

set_graph(graph)

Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.
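The processed_paths attribute lists the files that must exist to skip building, which suggests that maybe_build only calls build() when some of them are missing. A minimal standalone sketch of that check follows; maybe_build here is a hypothetical free function taking the paths and a build callback, not the library's method, and the filename is made up.

```python
import os
import tempfile
from typing import Callable, Sequence

def maybe_build(processed_paths: Sequence[str], build: Callable[[], None]) -> bool:
    """Call `build` only if a processed file is missing; return True if it ran."""
    if all(os.path.exists(p) for p in processed_paths):
        return False  # every processed file is already on disk: skip generation
    build()
    return True

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_data.pt")  # hypothetical processed filename

    def build() -> None:
        # Stand-in for data generation: just write a placeholder file
        with open(path, "w") as f:
            f.write("generated")

    print(maybe_build([path], build))  # True: the file was missing, so build ran
    print(maybe_build([path], build))  # False: the file now exists
```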

Attributes

annotations

Annotations for the concepts in the dataset.

concept_names

List of concept names in the dataset.

exogenous

Mapping of dataset's exogenous variables.

graph

Adjacency matrix of the causal graph between concepts.

has_concepts

Whether the dataset has concept annotations.

has_exogenous

Whether the dataset has exogenous information.

n_concepts

Number of concepts in the dataset.

n_exogenous

Number of exogenous variables in the dataset.

n_features

Shape of features in dataset's input (excluding number of samples).

n_samples

Number of samples in the dataset.

processed_filenames

List of processed filenames that will be created during build step.

processed_paths

The absolute paths of the processed files that must be present in order to skip building.

raw_filenames

No raw files are needed; the data is generated.

raw_paths

The absolute paths of the raw files that must be present in order to skip downloading.

root_dir

shape

Shape of the input tensor.