torch_concepts.data.MNISTArithmeticDataset
- class MNISTArithmeticDataset(root: str | None = None, num_train_samples: int = 10000, num_test_samples: int = 2000, val_size: float = 0.1, img_size: int = 224, seed: int = 42, concept_subset: list | None = None, label_descriptions: dict | None = None)
MNIST Arithmetic dataset for regression with concept annotations.
Each sample is a composite image of two MNIST digits with an arithmetic operator between them. The concepts are the two digit values (treated as continuous), and the regression target is the result of the arithmetic operation.
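The following is a minimal, self-contained sketch of how one such sample could be assembled; it is not the library's implementation. The operator set (`+`, `-`, `*`), the side-by-side layout (left digit, operator glyph, right digit), the nearest-neighbour resize to `img_size`, and the helper names `compose_strip`, `resize_nearest` and `make_sample` are all assumptions made for illustration. Images are plain nested lists so the example needs only the standard library.

```python
import random

# Assumed operator set (hypothetical); the real dataset's operators are not listed here.
OPERATORS = {
    "+": lambda a, b: a + b,
    "-": lambda a, b: a - b,
    "*": lambda a, b: a * b,
}


def compose_strip(left, glyph, right):
    """Concatenate three images of equal height (lists of rows) side by side."""
    return [l + g + r for l, g, r in zip(left, glyph, right)]


def resize_nearest(img, size):
    """Nearest-neighbour resize of a 2-D image to a square size x size image."""
    h, w = len(img), len(img[0])
    return [[img[y * h // size][x * w // size] for x in range(size)] for y in range(size)]


def make_sample(digit_a, img_a, digit_b, img_b, op, glyphs, img_size):
    """Build (image, concepts, target) for one hypothetical composite sample."""
    image = resize_nearest(compose_strip(img_a, glyphs[op], img_b), img_size)
    concepts = [float(digit_a), float(digit_b)]  # digit values treated as continuous
    target = float(OPERATORS[op](digit_a, digit_b))  # regression target
    return image, concepts, target


# Usage with blank 28x28 placeholders standing in for real MNIST digits.
rng = random.Random(42)
blank = [[0] * 28 for _ in range(28)]
glyphs = {op: blank for op in OPERATORS}
a, b = rng.randrange(10), rng.randrange(10)
op = rng.choice(sorted(OPERATORS))
image, concepts, target = make_sample(a, blank, b, blank, op, glyphs, img_size=224)
print(len(image), len(image[0]))  # 224 224
```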
Images in the training/validation splits are composed from MNIST train digits, while test images are composed from MNIST test digits, ensuring no digit-level leakage between train and test.
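A minimal sketch of the split-aware sampling idea, assuming digits are drawn uniformly with replacement from each MNIST split; the function `sample_digit_pairs` and the `(split, index)` tagging are hypothetical. Because train-pool composites only reference MNIST train images and test composites only reference MNIST test images, no individual digit image can appear on both sides of the train/test boundary. The pool sizes 60,000 and 10,000 are the standard MNIST train and test split sizes.

```python
import random

MNIST_TRAIN_SIZE = 60_000  # standard MNIST train split size
MNIST_TEST_SIZE = 10_000   # standard MNIST test split size


def sample_digit_pairs(source_split, pool_size, n_samples, rng):
    """Draw n_samples (left, right) digit indices from a single MNIST split.

    Each index is tagged with its source split so that provenance stays explicit
    (index 5 in "train" and index 5 in "test" are different images).
    """
    return [
        ((source_split, rng.randrange(pool_size)), (source_split, rng.randrange(pool_size)))
        for _ in range(n_samples)
    ]


def source_splits(pairs):
    """Return the set of MNIST splits referenced by a list of digit pairs."""
    return {split for pair in pairs for split, _ in pair}


rng = random.Random(42)
train_pool = sample_digit_pairs("train", MNIST_TRAIN_SIZE, 10_000, rng)  # train + validation
test_pool = sample_digit_pairs("test", MNIST_TEST_SIZE, 2_000, rng)

print(source_splits(train_pool), source_splits(test_pool))  # {'train'} {'test'}
```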
- Parameters:
root (str, optional) – Root directory to store/load the dataset. Default: './data/mnist_arithmetic'.
num_train_samples (int, optional) – Number of composite samples from MNIST train split (used for train + validation). Default: 10000.
num_test_samples (int, optional) – Number of composite samples from MNIST test split. Default: 2000.
val_size (float, optional) – Fraction of the train pool to use as validation. Default: 0.1.
img_size (int, optional) – Output image size (square). Default: 224.
seed (int, optional) – Random seed for reproducible generation. Default: 42.
label_descriptions (dict, optional) – Dict mapping concept names to descriptions. Default: None.
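A sketch of how `val_size` and `seed` could interact when carving validation samples out of the train pool. It assumes the pool indices are shuffled with a seeded generator and that the first `int(num_train_samples * val_size)` shuffled indices become validation; the helper `split_train_val` and the truncating rounding rule are assumptions, not documented behaviour.

```python
import random


def split_train_val(num_train_samples, val_size, seed):
    """Hypothetical seeded train/validation split of the composite train pool."""
    indices = list(range(num_train_samples))
    random.Random(seed).shuffle(indices)  # same seed -> same permutation
    n_val = int(num_train_samples * val_size)  # truncation is an assumption
    return indices[n_val:], indices[:n_val]


# With the documented defaults (10000 samples, val_size=0.1, seed=42):
train_idx, val_idx = split_train_val(10_000, 0.1, seed=42)
print(len(train_idx), len(val_idx))  # 9000 1000
```

Seeding a dedicated `random.Random` instance, rather than the global generator, keeps the split reproducible even if other code draws random numbers in between.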
- __init__(root: str | None = None, num_train_samples: int = 10000, num_test_samples: int = 2000, val_size: float = 0.1, img_size: int = 224, seed: int = 42, concept_subset: list | None = None, label_descriptions: dict | None = None)
Methods
__init__([root, num_train_samples, ...])
add_exogenous(name, value[, convert_precision])
add_scaler(key, scaler) – Add a scaler for preprocessing a specific tensor.
build() – Generate composite arithmetic images from both MNIST splits and save metadata.
collate(samples) – Collate samples into a batch, re-annotating the ground-truth concepts.
download() – Set up the MNIST root and trigger the MNIST download.
load() – Load the raw dataset and preprocess the data.
load_raw() – Load raw processed files for the current split.
maybe_build()
maybe_download() – Download and extract the dataset if needed.
remove_exogenous(name)
set_concepts(concepts) – Set concept annotations for the dataset.
set_graph(graph) – Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.
Attributes
annotations – Annotations for the concepts in the dataset.
concept_names – List of concept names in the dataset.
exogenous – Mapping of dataset's exogenous variables.
graph – Adjacency matrix of the causal graph between concepts.
has_concepts – Whether the dataset has concept annotations.
has_exogenous – Whether the dataset has exogenous information.
n_concepts – Number of concepts in the dataset.
n_exogenous – Number of exogenous variables in the dataset.
n_features – Shape of features in dataset's input (excluding number of samples).
n_samples – Number of samples in the dataset.
processed_filenames – The list of processed filenames in the self.root_dir folder that must be present in order to skip build().
processed_paths – The absolute paths of the processed files that must be present in order to skip building.
raw_filenames – The list of raw filenames in the self.root_dir folder that must be present in order to skip download().
raw_paths – The absolute paths of the raw files that must be present in order to skip downloading.
root_dir
shape – Shape of the input tensor.
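The processed_paths and raw_paths attributes imply a skip-if-present pattern: build() (or download()) only needs to run when some expected file is missing. A minimal sketch of that pattern follows, using a hypothetical `CachedBuildSketch` class and made-up filenames `train.pt` and `test.pt`; it mirrors the described behaviour of maybe_build() but is not the library's code.

```python
import os
import tempfile


class CachedBuildSketch:
    """Hypothetical stand-in illustrating the skip-if-present build logic."""

    processed_filenames = ["train.pt", "test.pt"]  # hypothetical filenames

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.build_calls = 0  # counts how often the expensive build actually runs

    @property
    def processed_paths(self):
        # Absolute-style paths of every file that must exist to skip build().
        return [os.path.join(self.root_dir, f) for f in self.processed_filenames]

    def build(self):
        self.build_calls += 1
        os.makedirs(self.root_dir, exist_ok=True)
        for path in self.processed_paths:
            with open(path, "wb") as fh:
                fh.write(b"")  # placeholder for real processed data

    def maybe_build(self):
        # Rebuild only when at least one processed file is missing.
        if not all(os.path.exists(p) for p in self.processed_paths):
            self.build()


with tempfile.TemporaryDirectory() as tmp:
    ds = CachedBuildSketch(os.path.join(tmp, "mnist_arithmetic"))
    ds.maybe_build()  # files missing -> build runs
    ds.maybe_build()  # files present -> build skipped
    print(ds.build_calls)  # 1
```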