torch_concepts.data.DSpritesRegressionDataset
- class DSpritesRegressionDataset(formulas: Dict[str, str], root: str | None = None, seed: int = 42, concept_subset: list | None = None, label_descriptions: dict | None = None)
DSprites regression dataset with sympy formula-based targets.
Each sample is a 64x64 grayscale image of a simple shape with known generative factors (concepts). A per-shape sympy formula over the concept values produces the regression target.
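The following minimal sketch uses sympy directly to show how a per-shape formula can turn concept values into a regression target. The concept names (`scale`, `orientation`, `pos_x`, `pos_y`), the formulas, and the `regression_target` helper are hypothetical. The dataset's actual concept column names and its internal evaluation code may differ.

```python
import sympy

# Hypothetical per-shape formulas; the variable names are assumed concept
# column names, not necessarily the ones the dataset actually exposes.
FORMULAS = {
    "square": "2*scale + pos_x",
    "circle": "scale*pos_y",
    "heart": "sin(orientation) + pos_x - pos_y",
}

# Parse each formula string into a sympy expression once, up front.
PARSED = {shape: sympy.sympify(expr) for shape, expr in FORMULAS.items()}


def regression_target(shape: str, concepts: dict) -> float:
    """Evaluate the formula selected by `shape` at the given concept values."""
    expr = PARSED[shape]
    # Map each free symbol in the expression to the concept value of the same name.
    values = {sym: concepts[sym.name] for sym in expr.free_symbols}
    return float(expr.subs(values))


sample = {"scale": 0.5, "orientation": 0.0, "pos_x": 0.25, "pos_y": 0.75}
print(regression_target("square", sample))  # 2*0.5 + 0.25 = 1.25
```

The shape acts as a selector: it picks which formula applies, and only the symbols that formula references are read from the concept values.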
- Parameters:
formulas (dict) – Mapping from shape name (‘square’, ‘circle’, ‘heart’) to a sympy formula string using the concept column names as variables. Must be provided for all three shapes.
root (str, optional) – Root directory for caching. Default: './data/dsprites_regression'.
num_samples (int, optional) – Number of samples to subsample. Default: None (all).
seed (int, optional) – Random seed. Default: 42.
concept_subset (list of str, optional) – Subset of concept names. Default: None.
label_descriptions (dict, optional) – Mapping from concept names to descriptions. Default: None.
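The `formulas` argument must cover all three shapes and may only use concept column names as variables. A pre-flight check along these lines can catch both mistakes before the dataset is built. The helper `check_formulas` and the concept names are hypothetical and not part of torch_concepts. Sympy function names such as `sin` are not free symbols, so they are not flagged as unknown variables.

```python
import sympy

REQUIRED_SHAPES = {"square", "circle", "heart"}


def check_formulas(formulas: dict, concept_names: list) -> None:
    """Hypothetical validation of a `formulas` mapping.

    Raises ValueError if a required shape is missing or if a formula
    references a variable that is not one of `concept_names`.
    """
    missing = REQUIRED_SHAPES - set(formulas)
    if missing:
        raise ValueError(f"missing formulas for shapes: {sorted(missing)}")
    allowed = set(concept_names)
    for shape, expr_str in formulas.items():
        # free_symbols holds only variables, not functions like sin or exp.
        used = {sym.name for sym in sympy.sympify(expr_str).free_symbols}
        unknown = used - allowed
        if unknown:
            raise ValueError(f"{shape!r} formula uses unknown names: {sorted(unknown)}")


concepts = ["scale", "orientation", "pos_x", "pos_y"]  # assumed names
check_formulas(
    {"square": "scale + pos_x", "circle": "sin(orientation)", "heart": "pos_y**2"},
    concepts,
)  # passes silently
```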
- __init__(formulas: Dict[str, str], root: str | None = None, seed: int = 42, concept_subset: list | None = None, label_descriptions: dict | None = None)
Methods
__init__(formulas[, root, seed, ...])
add_exogenous(name, value[, convert_precision])
add_scaler(key, scaler) – Add a scaler for preprocessing a specific tensor.
build() – Extract concepts, compute formula targets, and save them to disk.
collate(samples) – Collate samples into a batch, re-annotating the ground-truth concepts.
download() – Download the dSprites dataset from the original source and save it to the root directory.
load() – Load and optionally preprocess the dataset.
load_raw() – Load the raw dataset without any preprocessing.
maybe_build()
maybe_download()
remove_exogenous(name)
set_concepts(concepts) – Set concept annotations for the dataset.
set_graph(graph) – Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.
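`set_graph` expects the causal graph as a pandas DataFrame adjacency matrix. Here is a minimal sketch of building one. It assumes that rows and columns are both indexed by concept names and that entry [i, j] = 1 marks an edge from concept i to concept j. The orientation convention, the node names, and the `target` node in particular are assumptions for illustration only.

```python
import pandas as pd

# Hypothetical node names; in practice these would come from `concept_names`.
concepts = ["shape", "scale", "pos_x", "target"]

# Start from an all-zero square matrix indexed by concept name on both axes.
adj = pd.DataFrame(0, index=concepts, columns=concepts)

# Assumed convention: adj.loc[parent, child] == 1 encodes the edge parent -> child.
for parent in ["shape", "scale", "pos_x"]:
    adj.loc[parent, "target"] = 1

# e.g. dataset.set_graph(adj) on an existing DSpritesRegressionDataset instance
print(adj)
```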
Attributes
annotations – Annotations for the concepts in the dataset.
concept_names – List of concept names in the dataset.
exogenous – Mapping of the dataset's exogenous variables.
graph – Adjacency matrix of the causal graph between concepts.
has_concepts – Whether the dataset has concept annotations.
has_exogenous – Whether the dataset has exogenous information.
n_concepts – Number of concepts in the dataset.
n_exogenous – Number of exogenous variables in the dataset.
n_features – Shape of the features in the dataset's input (excluding the number of samples).
n_samples – Number of samples in the dataset.
processed_filenames – The list of processed filenames in the self.root_dir folder that must be present in order to skip build().
processed_paths – The absolute paths of the processed files that must be present in order to skip building.
raw_filenames – The list of raw filenames in the self.root_dir folder that must be present in order to skip download().
raw_paths – The absolute paths of the raw files that must be present in order to skip downloading.
root_dir
shape – Shape of the input tensor.
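The `raw_paths` and `processed_paths` attributes determine when `download()` and `build()` can be skipped. The class below is a hypothetical stand-in that sketches this skip-if-present logic, which is presumably what `maybe_download()` and `maybe_build()` implement. The filenames and the placeholder file writes are assumptions, not the library's implementation.

```python
import os
import tempfile


class CachedDatasetSketch:
    """Hypothetical stand-in for the caching logic; not torch_concepts code."""

    def __init__(self, root_dir, raw_filenames, processed_filenames):
        self.root_dir = root_dir
        self.raw_filenames = raw_filenames
        self.processed_filenames = processed_filenames
        self.calls = []  # records which expensive steps actually ran

    @property
    def raw_paths(self):
        return [os.path.abspath(os.path.join(self.root_dir, f)) for f in self.raw_filenames]

    @property
    def processed_paths(self):
        return [os.path.abspath(os.path.join(self.root_dir, f)) for f in self.processed_filenames]

    def download(self):
        self.calls.append("download")
        for path in self.raw_paths:
            open(path, "w").close()  # placeholder for fetching the raw data

    def build(self):
        self.calls.append("build")
        for path in self.processed_paths:
            open(path, "w").close()  # placeholder for concept extraction + targets

    def maybe_download(self):
        # Skip the download only when every raw file is already present.
        if not all(os.path.exists(p) for p in self.raw_paths):
            self.download()

    def maybe_build(self):
        # Skip the build only when every processed file is already present.
        if not all(os.path.exists(p) for p in self.processed_paths):
            self.build()


with tempfile.TemporaryDirectory() as root:
    ds = CachedDatasetSketch(root, ["raw.npz"], ["processed.pt"])  # assumed filenames
    for _ in range(2):  # the second pass finds every file and does nothing
        ds.maybe_download()
        ds.maybe_build()
    print(ds.calls)  # ['download', 'build']
```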