torch_concepts.data.MNISTArithmeticDataset

class MNISTArithmeticDataset(root: str | None = None, num_train_samples: int = 10000, num_test_samples: int = 2000, val_size: float = 0.1, img_size: int = 224, seed: int = 42, concept_subset: list | None = None, label_descriptions: dict | None = None)

MNIST Arithmetic dataset for regression with concept annotations.

Composite images of two MNIST digits with an arithmetic operator between them. The concepts are the two digit values (treated as continuous). The regression task is the arithmetic result.

Images in the training/validation splits are composed from MNIST train digits, while test images are composed from MNIST test digits, ensuring no digit-level leakage between train and test.
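The composition rule described above can be sketched in plain Python. This is only an illustration, not the library's implementation: the pools, the `make_samples` helper, and the operator set (`+`, `-`, `*`) are assumptions, since the description only says "an arithmetic operator". It shows that drawing train/validation samples and test samples from separate pools keeps their source digits disjoint, and that the concepts are the two digit values while the target is the arithmetic result.

```python
import operator
import random

# Stand-ins for the digit labels of the two MNIST splits (hypothetical data;
# a real pipeline would carry the digit images alongside the labels).
train_pool = [("train", i, i % 10) for i in range(50)]  # (split, index, digit)
test_pool = [("test", i, i % 10) for i in range(20)]

# Assumed operator set; the source only says "an arithmetic operator".
OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul}


def make_samples(pool, num_samples, seed):
    """Draw digit pairs from a single pool and attach concepts and target."""
    rng = random.Random(seed)  # seeded generator for reproducible sampling
    samples = []
    for _ in range(num_samples):
        left, right = rng.choice(pool), rng.choice(pool)
        symbol = rng.choice(sorted(OPS))
        samples.append({
            "sources": (left[:2], right[:2]),               # which digits were used
            "concepts": (float(left[2]), float(right[2])),  # digit values as floats
            "operator": symbol,
            "target": float(OPS[symbol](left[2], right[2])),
        })
    return samples


train_val = make_samples(train_pool, num_samples=8, seed=42)
test = make_samples(test_pool, num_samples=4, seed=43)
```

Because each call reads from exactly one pool, no digit used in `train_val` can appear in `test`, and rerunning with the same seed reproduces the same samples.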

Parameters:
  • root (str, optional) – Root directory to store/load the dataset. If None, './data/mnist_arithmetic' is used. Default: None.

  • num_train_samples (int, optional) – Number of composite samples from MNIST train split (used for train + validation). Default: 10000.

  • num_test_samples (int, optional) – Number of composite samples from MNIST test split. Default: 2000.

  • val_size (float, optional) – Fraction of the train pool to use as validation. Default: 0.1.

  • img_size (int, optional) – Output image size (square). Default: 224.

  • seed (int, optional) – Random seed for reproducible generation. Default: 42.

  • label_descriptions (dict, optional) – Dictionary mapping concept names to descriptions. Default: None.
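To make the interaction between num_train_samples and val_size concrete, here is a minimal sketch of the split arithmetic. The `split_sizes` helper is hypothetical, and rounding the validation count down is an assumption; the library may round differently.

```python
def split_sizes(num_train_samples, val_size):
    """Return (n_train, n_val) for a train pool of num_train_samples items."""
    if not 0.0 <= val_size < 1.0:
        raise ValueError("val_size must be in [0, 1)")
    n_val = int(num_train_samples * val_size)  # assumption: round down
    return num_train_samples - n_val, n_val


# With the documented defaults, the train pool of 10000 composite samples
# is divided between the training and validation splits.
n_train, n_val = split_sizes(num_train_samples=10000, val_size=0.1)
print(n_train, n_val)
```

The test split is separate: its size is set directly by num_test_samples and does not depend on val_size.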

__init__(root: str | None = None, num_train_samples: int = 10000, num_test_samples: int = 2000, val_size: float = 0.1, img_size: int = 224, seed: int = 42, concept_subset: list | None = None, label_descriptions: dict | None = None)

Methods

__init__([root, num_train_samples, ...])

add_exogenous(name, value[, convert_precision])

add_scaler(key, scaler)

Add a scaler for preprocessing a specific tensor.

build()

Generate composite arithmetic images from both MNIST splits and save metadata.

collate(samples)

Collate samples into a batch, re-annotating the ground-truth concepts.

download()

Set up the MNIST root directory and trigger the MNIST download.

load()

Load the raw dataset and preprocess the data.

load_raw()

Load raw processed files for the current split.

maybe_build()

maybe_download()

Download and extract the dataset if needed.

remove_exogenous(name)

set_concepts(concepts)

Set concept annotations for the dataset.

set_graph(graph)

Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.

Attributes

annotations

Annotations for the concepts in the dataset.

concept_names

List of concept names in the dataset.

exogenous

Mapping of dataset's exogenous variables.

graph

Adjacency matrix of the causal graph between concepts.

has_concepts

Whether the dataset has concept annotations.

has_exogenous

Whether the dataset has exogenous information.

n_concepts

Number of concepts in the dataset.

n_exogenous

Number of exogenous variables in the dataset.

n_features

Shape of features in dataset's input (excluding number of samples).

n_samples

Number of samples in the dataset.

processed_filenames

The list of processed filenames in the self.root_dir folder that must be present in order to skip build().

processed_paths

The absolute paths of the processed files that must be present in order to skip building.

raw_filenames

The list of raw filenames in the self.root_dir folder that must be present in order to skip download().

raw_paths

The absolute paths of the raw files that must be present in order to skip downloading.

root_dir

shape

Shape of the input tensor.
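The raw_paths and processed_paths attributes describe a "skip if already present" pattern for maybe_download() and maybe_build(). The sketch below illustrates that pattern in isolation; the `maybe_run` helper and the file names are hypothetical and are not taken from the library.

```python
import os
import tempfile


def maybe_run(required_paths, action):
    """Run action() only if a required path is missing; return True if it ran."""
    if all(os.path.exists(path) for path in required_paths):
        return False  # everything is already on disk, so skip the work
    action()
    return True


with tempfile.TemporaryDirectory() as root_dir:
    # Hypothetical processed filenames, for illustration only.
    processed_paths = [
        os.path.join(root_dir, name) for name in ("images.bin", "meta.json")
    ]

    def build():
        # Placeholder for the expensive generation step.
        for path in processed_paths:
            with open(path, "w") as handle:
                handle.write("placeholder")

    first = maybe_run(processed_paths, build)   # files missing, build runs
    second = maybe_run(processed_paths, build)  # files present, build skipped
```

Under this pattern, deleting any one of the required files is enough to trigger a rebuild on the next call.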