Datasets

This module provides dataset implementations for concept-based learning.

Summary

Bayesian Network Datasets

BnLearnDataset

Dataset class for the Asia dataset from bnlearn.

Toy Datasets

ToyDataset

Synthetic datasets for concept-based learning experiments.

CompletenessDataset

Synthetic dataset for concept bottleneck completeness experiments.

MNIST Variants

ColorMNISTDataset

The color MNIST dataset is a modified version of the MNIST dataset where each digit is colored either red or green.

MNISTAddition

The MNIST addition dataset is a modified version of the MNIST dataset where each image is a concatenation of two MNIST images and the target label is the sum of the two digits.

PartialMNISTAddition

The partial MNIST addition dataset is a modified version of the MNIST addition dataset where the concept annotation is partial.

MNISTEvenOdd

The MNIST even-odd dataset is a modified version of the MNIST dataset where the task is to predict whether the digit is even or odd.

Image Datasets

celeba.CelebADataset

Dataset class for CelebA.

cub.CUBDataset

TODO

awa2.AwA2Dataset

A PyTorch-compatible Dataset for the AwA2 dataset.

Other Datasets

cebab.CEBaBDataset

traffic.TrafficLights

Synthetic traffic dataset

Class Documentation

Bayesian Network Datasets

class BnLearnDataset(name: str, root: str | None = None, seed: int = 42, n_gen: int = 10000, concept_subset: list | None = None, label_descriptions: dict | None = None, autoencoder_kwargs: dict | None = None)[source]

Bases: ConceptDataset

Dataset class for the Asia dataset from bnlearn.

This dataset represents a small expert system that models the relationship between traveling to Asia, smoking habits, and various lung diseases.
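As a rough illustration of what sampling from such a network involves, the sketch below draws samples from a four-node fragment (asia -> tub, smoke -> lung) by ancestral sampling. The node names and all probabilities are illustrative placeholders, not values read from bnlearn or from this class.

```python
import random

# Illustrative conditional probabilities (assumed, not taken from bnlearn).
P_ASIA = 0.01
P_SMOKE = 0.5
P_TUB_GIVEN_ASIA = {True: 0.05, False: 0.01}
P_LUNG_GIVEN_SMOKE = {True: 0.10, False: 0.01}


def sample_fragment(n_gen, seed=42):
    """Ancestral sampling: draw parents first, then children given parents."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n_gen):
        asia = rng.random() < P_ASIA
        smoke = rng.random() < P_SMOKE
        tub = rng.random() < P_TUB_GIVEN_ASIA[asia]
        lung = rng.random() < P_LUNG_GIVEN_SMOKE[smoke]
        rows.append({"asia": asia, "smoke": smoke, "tub": tub, "lung": lung})
    return rows


samples = sample_fragment(n_gen=1000)
smoke_rate = sum(r["smoke"] for r in samples) / len(samples)
print(f"{len(samples)} samples, smoking rate {smoke_rate:.2f}")
```

Fixing the seed makes the generated rows reproducible, which mirrors the role of the seed and n_gen arguments in the signature above.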

property raw_filenames: List[str]

List of raw filenames that need to be present in the raw directory for the dataset to be considered present.

property processed_filenames: List[str]

List of processed filenames that will be created during build step.

download()[source]

Downloads the dataset’s files to the self.root_dir folder.

build()[source]

Builds the dataset from raw data into the self.root_dir folder, if needed.

load_raw()[source]

Loads raw dataset without any data preprocessing.

load()[source]

Loads the raw dataset and preprocesses the data. Defaults to load_raw.
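The four methods above suggest a download, build, load lifecycle. The sketch below shows one way such a lifecycle can be wired together. SketchDataset and maybe_build are hypothetical names, and the real ConceptDataset base class may order or guard these steps differently.

```python
import json
import os
import tempfile


class SketchDataset:
    """Hypothetical stand-in; not the library's ConceptDataset."""

    def __init__(self, root_dir):
        self.root_dir = root_dir
        os.makedirs(root_dir, exist_ok=True)

    @property
    def processed_filenames(self):
        return ["data.json"]

    def _path(self, name):
        return os.path.join(self.root_dir, name)

    def download(self):
        # Nothing to fetch in this sketch; a real dataset would download here.
        pass

    def build(self):
        # Write the processed file that load_raw() expects.
        with open(self._path("data.json"), "w") as f:
            json.dump({"x": [0, 1, 2]}, f)

    def maybe_build(self):
        # Rebuild only when a processed file is missing.
        if not all(os.path.exists(self._path(n)) for n in self.processed_filenames):
            self.download()
            self.build()

    def load_raw(self):
        self.maybe_build()
        with open(self._path("data.json")) as f:
            return json.load(f)

    def load(self):
        # No extra preprocessing, so defer to load_raw().
        return self.load_raw()


with tempfile.TemporaryDirectory() as tmp:
    data = SketchDataset(tmp).load()
    print(data)
```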

Toy Datasets

class ToyDataset(dataset: str, root: str | None = None, seed: int = 42, n_gen: int = 10000, concept_subset: list | None = None)[source]

Bases: ConceptDataset

Synthetic datasets for concept-based learning experiments.

This class provides several toy datasets with known ground-truth concept relationships and causal structures. Each dataset includes input features, binary concepts, tasks, and a directed acyclic graph (DAG) representing concept-to-task relationships.

Available Datasets

  • xor: Simple XOR dataset with 2 input features, 2 concepts (C1, C2), and 1 task (xor). The task is the XOR of the two concepts.

  • trigonometry: Dataset with 7 trigonometric input features derived from 3 hidden variables, 3 concepts (C1, C2, C3) representing the signs of the hidden variables, and 1 task (sumGreaterThan1).

  • dot: Dataset with 4 input features, 2 concepts based on dot products (dotV1V2GreaterThan0, dotV3V4GreaterThan0), and 1 task (dotV1V3GreaterThan0).

  • checkmark: Dataset with 4 input features and 4 concepts (A, B, C, D), where C = NOT B and D = A AND C, demonstrating causal relationships.
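The concept rules quoted above are simple enough to write out directly. The following sketch encodes the xor task and the checkmark relations as plain Python functions. It illustrates the logic only and is not the generator that ToyDataset uses.

```python
from itertools import product


def xor_task(c1, c2):
    """Task for the xor dataset: XOR of the two binary concepts."""
    return int(c1 != c2)


def checkmark_concepts(a, b):
    """Derive C and D from A and B: C = NOT B, D = A AND C."""
    c = int(not b)
    d = int(a and c)
    return {"A": a, "B": b, "C": c, "D": d}


for c1, c2 in product([0, 1], repeat=2):
    print(f"C1={c1} C2={c2} -> xor={xor_task(c1, c2)}")

for a, b in product([0, 1], repeat=2):
    print(checkmark_concepts(a, b))
```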

Parameters:
  • dataset (str) – Name of the toy dataset to load. Must be one of: ‘xor’, ‘trigonometry’, ‘dot’, or ‘checkmark’.

  • root (str, optional) – Root directory to store/load the dataset files. If None, defaults to ‘./data/toy_datasets/{dataset_name}’. Default: None

  • seed (int, optional) – Random seed for reproducible data generation. Default: 42

  • n_gen (int, optional) – Number of samples to generate. Default: 10000

  • concept_subset (list of str, optional) – Subset of concept names to use. If provided, only the specified concepts will be included in the dataset. Default: None (use all concepts)

input_data

Input features tensor of shape (n_samples, n_features).

Type:

torch.Tensor

concepts

Concepts and tasks tensor of shape (n_samples, n_concepts + n_tasks). Note: This includes both concepts and tasks concatenated.

Type:

torch.Tensor
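Because concept and task columns share one tensor, downstream code usually has to split them again. A minimal sketch using nested lists in place of a tensor, assuming (this is not stated above) that the task columns come last:

```python
def split_concepts_and_tasks(rows, n_tasks):
    """Split each row into (concept values, task values).

    Assumes the last n_tasks columns are tasks; check dataset.concept_names
    to confirm the column order before relying on this.
    """
    if n_tasks <= 0:
        raise ValueError("n_tasks must be positive")
    concepts = [row[:-n_tasks] for row in rows]
    tasks = [row[-n_tasks:] for row in rows]
    return concepts, tasks


# Two xor-style samples with columns [C1, C2, xor].
rows = [[0, 1, 1], [1, 1, 0]]
concepts, tasks = split_concepts_and_tasks(rows, n_tasks=1)
print(concepts, tasks)
```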

annotations

Metadata about concept names, cardinalities, and types.

Type:

Annotations

graph

Directed acyclic graph representing concept-to-task relationships. Stored as an adjacency matrix with concept/task names as indices.

Type:

pandas.DataFrame
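An adjacency matrix indexed by names can be pictured as a nested mapping. The sketch below builds one for the xor dataset (edges C1 -> xor and C2 -> xor, inferred from the dataset description rather than read from the library) and checks that it is acyclic with the standard-library graphlib module.

```python
from graphlib import CycleError, TopologicalSorter

names = ["C1", "C2", "xor"]

# adjacency[src][dst] == 1 means there is an edge src -> dst.
adjacency = {src: {dst: 0 for dst in names} for src in names}
adjacency["C1"]["xor"] = 1
adjacency["C2"]["xor"] = 1


def is_dag(adj):
    """Return True if the adjacency mapping has no directed cycle."""
    # TopologicalSorter expects node -> set of predecessors.
    predecessors = {
        node: {src for src in adj if adj[src][node]} for node in adj
    }
    try:
        list(TopologicalSorter(predecessors).static_order())
    except CycleError:
        return False
    return True


print(is_dag(adjacency))
```

With a pandas.DataFrame the same lookup would be written with label-based indexing on the row and column names.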

concept_names

Names of all concepts and tasks in the dataset.

Type:

list of str

n_concepts

Total number of concepts and tasks (includes both).

Type:

int

n_features

Dimensionality of input features.

Type:

tuple or int

Examples

Basic usage with XOR dataset:

>>> from torch_concepts.data.datasets import ToyDataset
>>>
>>> # Create XOR dataset with 1000 samples
>>> dataset = ToyDataset(dataset='xor', seed=42, n_gen=1000)
>>> print(f"Dataset size: {len(dataset)}")
>>> print(f"Input features: {dataset.n_features}")
>>> print(f"Concepts: {dataset.concept_names}")
>>>
>>> # Access a single sample
>>> sample = dataset[0]
>>> x = sample['inputs']['x']  # input features
>>> c = sample['concepts']['c']  # concepts and task
>>>
>>> # Get concept graph
>>> print(dataset.graph)

See also

CompletenessDataset

Synthetic dataset for concept completeness experiments

property raw_filenames: List[str]

No raw files needed - data is generated.

property processed_filenames: List[str]

List of processed filenames that will be created during build step.

download()[source]

No download needed for toy datasets.

build()[source]

Generate synthetic data and save to disk.

load_raw()[source]

Load the generated dataset from disk.

load()[source]

Load the dataset (wraps load_raw).

class CompletenessDataset(name: str, root: str | None = None, seed: int = 42, n_gen: int = 10000, p: int = 2, n_views: int = 10, n_concepts: int = 2, n_hidden_concepts: int = 0, n_tasks: int = 1, concept_subset: list | None = None)[source]

Bases: ConceptDataset

Synthetic dataset for concept bottleneck completeness experiments.

This dataset generates synthetic data to study complete vs. incomplete concept bottlenecks. Data is generated using randomly initialized multi-layer perceptrons with ReLU activations. Input features are sampled from a multivariate normal distribution, and concepts are derived through nonlinear transformations. Hidden concepts can be included to simulate incomplete bottlenecks.

The dataset uses a two-stage generation process:

  1. Map inputs X to concepts C (both observed and hidden) via a nonlinear function g.

  2. Map concepts C to tasks Y via a nonlinear function f.
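The two-stage process can be sketched in a few lines of plain Python. Everything specific here is an assumption made for illustration: a single random linear layer with ReLU stands in for each MLP, inputs are drawn from independent standard normals, and values are binarized by thresholding each column at its median. The actual generator may differ on each point.

```python
import random
from statistics import median


def random_layer(rng, n_in, n_out):
    """Random weight matrix with n_out rows and n_in columns."""
    return [[rng.gauss(0, 1) for _ in range(n_in)] for _ in range(n_out)]


def relu_layer(weights, x):
    """Apply a linear map followed by ReLU."""
    return [max(0.0, sum(w * v for w, v in zip(row, x))) for row in weights]


def binarize(rows):
    """Threshold every column at its median value."""
    cuts = [median(col) for col in zip(*rows)]
    return [[int(v > cut) for v, cut in zip(row, cuts)] for row in rows]


def generate(n_gen=200, p=2, n_views=10, n_concepts=2, n_hidden_concepts=1,
             n_tasks=1, seed=42):
    rng = random.Random(seed)
    n_all = n_concepts + n_hidden_concepts
    g = random_layer(rng, p * n_views, n_all)   # stage 1: X -> C
    f = random_layer(rng, n_all, n_tasks)       # stage 2: C -> Y
    x = [[rng.gauss(0, 1) for _ in range(p * n_views)] for _ in range(n_gen)]
    c_all = binarize([relu_layer(g, row) for row in x])
    y = binarize([relu_layer(f, row) for row in c_all])
    # Hidden concepts feed into y but are dropped from the returned concepts.
    c_observed = [row[:n_concepts] for row in c_all]
    return x, c_observed, y


x, c, y = generate()
print(len(x), len(x[0]), len(c[0]), len(y[0]))
```

Setting n_hidden_concepts above zero is what makes the returned bottleneck incomplete: the task depends on columns that the caller never sees.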

Parameters:
  • name (str) – Name identifier for the dataset (used for file storage).

  • root (str, optional) – Root directory to store/load the dataset files. If None, defaults to ‘./data/completeness_datasets/{name}’. Default: None

  • seed (int, optional) – Random seed for reproducible data generation. Default: 42

  • n_gen (int, optional) – Number of samples to generate. Default: 10000

  • p (int, optional) – Dimensionality of each view (feature group). Default: 2

  • n_views (int, optional) – Number of views/feature groups. Total input features = p * n_views. Default: 10

  • n_concepts (int, optional) – Number of observable concepts (not including hidden concepts). Default: 2

  • n_hidden_concepts (int, optional) – Number of hidden concepts not observable in the bottleneck. Use this to simulate incomplete concept bottlenecks. Default: 0

  • n_tasks (int, optional) – Number of downstream tasks to predict. Default: 1

  • concept_subset (list of str, optional) – Subset of concept names to use. If provided, only the specified concepts will be included. Concept names follow format ‘C0’, ‘C1’, etc. Default: None
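The tensor shapes implied by these parameters can be worked out up front from the formulas documented for this class. A small sketch; the helper name expected_shapes is hypothetical:

```python
def expected_shapes(n_gen=10000, p=2, n_views=10, n_concepts=2, n_tasks=1):
    """Return the (rows, columns) shapes implied by the documented formulas."""
    n_features = p * n_views            # total input features
    n_columns = n_concepts + n_tasks    # hidden concepts are excluded
    return {
        "input_data": (n_gen, n_features),
        "concepts": (n_gen, n_columns),
    }


shapes = expected_shapes()
print(shapes)
```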

input_data

Input features tensor of shape (n_samples, p * n_views).

Type:

torch.Tensor

concepts

Concepts and tasks tensor of shape (n_samples, n_concepts + n_tasks). Note: Hidden concepts are NOT included in this tensor.

Type:

torch.Tensor

annotations

Metadata about concept names, cardinalities, and types.

Type:

Annotations

graph

Directed acyclic graph representing concept-to-task relationships. All concepts influence all tasks in this dataset.

Type:

pandas.DataFrame

concept_names

Names of all concepts and tasks. Format: [‘C0’, ‘C1’, …, ‘y’]

Type:

list of str

n_concepts

Total number of observable concepts and tasks (includes both, excludes hidden).

Type:

int

n_features

Dimensionality of input features (p * n_views).

Type:

tuple or int

Examples

Basic usage with complete bottleneck:

>>> from torch_concepts.data.datasets import CompletenessDataset
>>>
>>> # Create dataset with complete bottleneck (no hidden concepts)
>>> dataset = CompletenessDataset(
...     name='complete_exp',
...     n_gen=5000,
...     n_concepts=5,
...     n_hidden_concepts=0,
...     seed=42
... )
>>> print(f"Dataset size: {len(dataset)}")
>>> print(f"Input features: {dataset.n_features}")
>>> print(f"Concepts: {dataset.concept_names}")

Creating incomplete bottleneck with hidden concepts:

>>> from torch_concepts.data.datasets import CompletenessDataset
>>>
>>> # Create dataset with incomplete bottleneck
>>> dataset = CompletenessDataset(
...     name='incomplete_exp',
...     n_gen=5000,
...     n_concepts=3,          # 3 observable concepts
...     n_hidden_concepts=2,   # 2 hidden concepts (not in bottleneck)
...     seed=42
... )
>>> # The hidden concepts affect tasks but are not observable
>>> print(f"Observable concepts: {dataset.n_concepts}")

property raw_filenames: List[str]

No raw files needed - data is generated.

property processed_filenames: List[str]

List of processed filenames that will be created during build step.

download()[source]

No download needed for synthetic datasets.

build()[source]

Generate synthetic completeness data and save to disk.

load_raw()[source]

Load the generated dataset from disk.

load()[source]

Load the dataset (wraps load_raw).

MNIST Variants

class ColorMNISTDataset(root: str, train: bool = False, transform=None, target_transform=None, download: bool = False, random: bool = True)[source]

Bases: MNIST

The color MNIST dataset is a modified version of the MNIST dataset where each digit is colored either red or green. The concept labels are the digit and the color of the digit. The task is to predict whether the digit is even or odd.
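To make the coloring concrete, here is a minimal sketch that turns a grayscale image into a red or green RGB image and derives the parity task. The channel layout and the way a color is chosen are assumptions; the class may color digits differently, for example depending on the random flag.

```python
def colorize(gray, color):
    """Place a 2D grayscale image in the red or green channel of an RGB image.

    Returns a nested list shaped (3, H, W); unused channels are zero.
    """
    channel = {"red": 0, "green": 1}[color]
    height, width = len(gray), len(gray[0])
    rgb = [[[0.0] * width for _ in range(height)] for _ in range(3)]
    rgb[channel] = [list(row) for row in gray]
    return rgb


def parity_task(digit):
    """Task label: 'even' or 'odd'."""
    return "even" if digit % 2 == 0 else "odd"


gray = [[0.0, 1.0], [0.5, 0.0]]  # tiny 2x2 stand-in for a 28x28 digit
image = colorize(gray, "green")
print(len(image), parity_task(7))
```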

root

The root directory where the dataset is stored.

train

Whether to load the training or test split. Default is False.

transform

The transformations to apply to the images. Default is None.

target_transform

The transformations to apply to the target labels. Default is None.

download

Whether to download the dataset if it does not exist. Default is False.

random

Whether to colorize the digits randomly. Default is True.

class MNISTAddition(root, train, target_transform=None, download=True)[source]

Bases: MNIST

The MNIST addition dataset is a modified version of the MNIST dataset where each image is a concatenation of two MNIST images and the target label is the sum of the two digits. The concept label is a one-hot encoding of the two digits.
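A sketch of how one such sample could be assembled, using nested lists instead of tensors. The concept layout (ten left-digit slots followed by ten right-digit slots) follows the concept_names listed for this class. The function name and the side-by-side concatenation along the width are assumptions, chosen to be consistent with the documented input_shape of (1, 28, 56).

```python
def make_addition_sample(left_img, left_digit, right_img, right_digit):
    """Join two images side by side and build concept and task labels."""
    # Concatenate row by row, so the width doubles and the height is unchanged.
    image = [l_row + r_row for l_row, r_row in zip(left_img, right_img)]
    # 20 concepts: one-hot left digit, then one-hot right digit.
    concepts = [0] * 20
    concepts[left_digit] = 1
    concepts[10 + right_digit] = 1
    # The sum ranges over 0..18, giving 19 possible task classes.
    task = left_digit + right_digit
    return image, concepts, task


blank = [[0.0] * 28 for _ in range(28)]
image, concepts, task = make_addition_sample(blank, 3, blank, 9)
print(len(image), len(image[0]), sum(concepts), task)
```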

concept_names

The names of the concept labels.

task_names

The names of the task labels.

root

The root directory where the dataset is stored.

train

Whether to load the training or test split.

transform

The transformations to apply to the images. Default is None.

target_transform

The transformations to apply to the target labels. Default is None.

download

Whether to download the dataset if it does not exist. Default is True.

name = 'mnist_addition'
n_concepts = 20
n_tasks = 19
concept_names = ['0_left', '1_left', '2_left', '3_left', '4_left', '5_left', '6_left', '7_left', '8_left', '9_left', '0_right', '1_right', '2_right', '3_right', '4_right', '5_right', '6_right', '7_right', '8_right', '9_right']
task_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18']
transform = Compose(ToTensor(), Normalize(mean=(0.1307,), std=(0.3081,)))
input_shape = (1, 28, 56)
input_dim = 1568

class PartialMNISTAddition(root, train, target_transform=None, download=True)[source]

Bases: MNISTAddition

The partial MNIST addition dataset is a modified version of the MNIST addition dataset where the concept annotation is partial. The concept associated with the second digit is not provided.
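In terms of the 20-element concept vector of MNISTAddition, dropping the second digit amounts to keeping the first ten entries. A one-function sketch, assuming the left-digit concepts come first as in concept_names:

```python
def drop_right_digit(concepts):
    """Keep only the left-digit one-hot block of a 20-element concept vector."""
    if len(concepts) != 20:
        raise ValueError("expected 20 concepts (10 left + 10 right)")
    return concepts[:10]


full = [0] * 20
full[3] = 1    # left digit is 3
full[19] = 1   # right digit is 9
partial = drop_right_digit(full)
print(partial)
```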

name = 'partial_mnist_addition'
n_concepts = 10
concept_names = ['0_left', '1_left', '2_left', '3_left', '4_left', '5_left', '6_left', '7_left', '8_left', '9_left']

class MNISTEvenOdd(root, train, target_transform=None, download=True)[source]

Bases: MNIST

The MNIST even-odd dataset is a modified version of the MNIST dataset where the task is to predict whether the digit is even or odd. The concept label is a one-hot encoding of the digit.

concept_names

The names of the concept labels.

task_names

The names of the task labels.

root

The root directory where the dataset is stored.

train

Whether to load the training or test split.

transform

The transformations to apply to the images. Default is None.

target_transform

The transformations to apply to the target labels. Default is None.

download

Whether to download the dataset if it does not exist. Default is True.

name = 'mnist_even_odd'
n_concepts = 10
n_tasks = 2
concept_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
task_names = ['odd', 'even']
transform = Compose(ToTensor(), Normalize(mean=(0.1307,), std=(0.3081,)))
input_shape = (1, 28, 28)
input_dim = 784

Image Datasets

class CelebADataset(root: str | None = None, concept_subset: list | None = None, label_descriptions: dict | None = None)[source]

Bases: ConceptDataset

Dataset class for CelebA.

CelebA is a large-scale face attributes dataset with more than 200K celebrity images, each with 40 attribute annotations. This class wraps torchvision’s CelebA dataset to work with the ConceptDataset framework. The dataset can be downloaded from the official website: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html.

Parameters:
  • root – Root directory where the dataset is stored or will be downloaded.

  • split – The split of the dataset to use (‘train’, ‘valid’, or ‘test’). Default is ‘train’.

  • transform – The transformations to apply to the images. Default is None.

  • download – Whether to download the dataset if it does not exist. Default is False.

  • task_label – The attribute(s) to use for the task. Default is ‘Attractive’.

  • concept_subset – Optional subset of concept labels to use.

  • label_descriptions – Optional dict mapping concept names to descriptions.

property raw_filenames: List[str]

List of raw filenames that must be present to skip downloading.

property processed_filenames: List[str]

List of processed filenames that will be created during build step.

download()[source]

Download CelebA images zip and annotation files from Google Drive.

Downloads the aligned and cropped face images archive and annotation files. Extraction is handled separately by the maybe_extract() method.

Note: Requires gdown package for Google Drive downloads.

maybe_extract()[source]

Extract the CelebA images archive.

Extracts img_align_celeba.zip to the raw celeba folder.

maybe_download()[source]

Download and extract the dataset if needed.

build()[source]

Build processed dataset: save concepts, annotations and splits metadata.

Images are not saved as they are already in the downloaded folder and will be loaded on-the-fly in __getitem__.

load_raw()[source]

Load raw processed files for the current split.

load()[source]

Load and optionally preprocess dataset.

property n_samples: int

Number of samples in the dataset.

property n_features: tuple

Shape of features in dataset’s input (excluding number of samples).

CelebA images are 218x178x3 (H x W x C) reordered to (C, H, W).
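The reordering from (H, W, C) to (C, H, W) can be illustrated on a nested list. This is only a sketch of the index mapping; with tensors the same result is normally obtained with a single permute or transpose call.

```python
def hwc_to_chw(image):
    """Reorder a nested list from (H, W, C) to (C, H, W)."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [
        [[image[h][w][c] for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


# A 2x1 image with 3 channels per pixel.
hwc = [[[1, 2, 3]], [[4, 5, 6]]]
chw = hwc_to_chw(hwc)
print(chw)
```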

property shape: tuple

Shape of the input tensor (n_samples, C, H, W).

class CUBDataset(split='train', uncertain_concept_labels=False, root='./CUB200/', path_transform=None, sample_transform=None, concept_transform=None, label_transform=None, uncertainty_based_random_labels=False, unc_map=[{0: 0.5, 1: 0.5, 2: 0.5, 3: 0.75, 4: 1.0}, {0: 0.5, 1: 0.5, 2: 0.5, 3: 0.75, 4: 1.0}], selected_concepts=None, training_augment=True)[source]

Bases: Dataset

TODO

concept_weights()[source]

Calculate the class imbalance ratio for binary attribute labels.
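One common way to define an imbalance ratio for a binary attribute is the number of negative samples divided by the number of positive samples, which is the form often passed as a positive-class weight to a binary cross-entropy loss. The sketch below implements that definition. Whether concept_weights() uses exactly this formula is an assumption.

```python
def imbalance_ratios(labels):
    """Per-attribute ratio of negatives to positives.

    labels is a list of rows, each a list of 0/1 attribute values.
    Attributes with no positive sample get a ratio of float('inf').
    """
    n_samples = len(labels)
    ratios = []
    for column in zip(*labels):
        positives = sum(column)
        negatives = n_samples - positives
        ratios.append(negatives / positives if positives else float("inf"))
    return ratios


labels = [[1, 0], [0, 0], [0, 1], [0, 0]]
print(imbalance_ratios(labels))
```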

class AwA2Dataset(root, training_augment=True, split='train', image_size=224, concept_transform=None, sample_transform=None, selected_concepts=None, seed=42)[source]

Bases: Dataset

A PyTorch-compatible Dataset for the AwA2 dataset.

Other Datasets

class CEBaBDataset(pre_trained_transformer='bert-base-uncased', batch_size=32)[source]

Bases: object

preprocess_function(examples)[source]

collator()[source]

class TrafficLights(root_dir='./data_cache/', split='train', concept_transform=None, class_dtype=float, img_transform=None, regenerate=False, sym_link=None, use_absolute_path=False, num_threads=4, n_samples=1000, seed=42, position_para_noise=50, position_perp_noise=20, error_probability=0.1, p_ambulance=0.2, min_num_cars=0, max_num_cars=7, resize_final_image=0.15, car_colors=['black', 'blue', 'burgundy', 'green', 'pink', 'purple', 'silver', 'white'], possible_starting_directions=['north', 'east', 'south', 'west'], thickness=100, light_scale=1.5, use_lights_sprites=False, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2, test_config_override_values=None, val_config_override_values=None, selected_concepts=None)[source]

Bases: Dataset

Synthetic traffic dataset
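The signature exposes train_ratio, val_ratio and test_ratio (0.6, 0.2 and 0.2 by default). A hedged sketch of how sample indices could be partitioned with such ratios follows. The shuffling and rounding choices are assumptions, not the documented behavior of this class.

```python
import random


def split_indices(n_samples, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2, seed=42):
    """Shuffle indices reproducibly and cut them into three splits."""
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-9:
        raise ValueError("ratios must sum to 1")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_train = int(n_samples * train_ratio)
    n_val = int(n_samples * val_ratio)
    return {
        "train": indices[:n_train],
        "val": indices[n_train:n_train + n_val],
        "test": indices[n_train + n_val:],  # remainder absorbs rounding
    }


splits = split_indices(1000)
print({name: len(idx) for name, idx in splits.items()})
```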

sample_array(real_idx)[source]