torch_concepts.data.DSpritesRegressionDataset

class DSpritesRegressionDataset(formulas: Dict[str, str], root: str | None = None, seed: int = 42, concept_subset: list | None = None, label_descriptions: dict | None = None)

dSprites regression dataset with sympy formula-based targets.

Each sample is a 64x64 grayscale image of a simple shape with known generative factors (concepts). A per-shape sympy formula over the concept values produces the regression target.

Parameters:
  • formulas (dict) – Mapping from shape name (‘square’, ‘circle’, ‘heart’) to a sympy formula string using the concept column names as variables. Must be provided for all three shapes.

  • root (str, optional) – Root directory for caching. Default: './data/dsprites_regression'.

  • num_samples (int, optional) – Number of samples to subsample. Default: None (all).

  • seed (int, optional) – Random seed. Default: 42.

  • concept_subset (list of str, optional) – Subset of concept names. Default: None.

  • label_descriptions (dict, optional) – Optional dict mapping concept names to descriptions.
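The role of formulas can be illustrated with a minimal sketch of how a per-shape sympy formula string might be turned into a regression target. This is not the library's implementation: the concept names (scale, pos_x, pos_y), the example formulas, and the helper target() are assumptions made for illustration only, and the real concept column names come from the dataset itself.

```python
import sympy

# Hypothetical concept column names; the real names are defined by the dataset.
concept_names = ["scale", "pos_x", "pos_y"]

# One formula string per shape, in the form the `formulas` argument describes.
formulas = {
    "square": "2*scale + pos_x",
    "circle": "scale**2 - pos_y",
    "heart": "pos_x * pos_y",
}

symbols = sympy.symbols(concept_names)
local_names = dict(zip(concept_names, symbols))

# Parse each string once and compile it to a plain Python callable.
# Passing `locals` makes sure every concept name is read as a symbol.
compiled = {
    shape: sympy.lambdify(
        symbols, sympy.sympify(expr, locals=local_names), modules="math"
    )
    for shape, expr in formulas.items()
}


def target(shape, concepts):
    """Evaluate the formula for `shape` on a dict of concept values."""
    args = [concepts[name] for name in concept_names]
    return float(compiled[shape](*args))


sample = {"scale": 0.5, "pos_x": 0.25, "pos_y": 0.75}
print(target("square", sample))  # 2*0.5 + 0.25 = 1.25
```

Because every shape has its own expression, the same concept values can yield different targets depending on the shape, which is why a formula is required for all three shapes.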

__init__(formulas: Dict[str, str], root: str | None = None, seed: int = 42, concept_subset: list | None = None, label_descriptions: dict | None = None)

Methods

__init__(formulas[, root, seed, ...])

add_exogenous(name, value[, convert_precision])

add_scaler(key, scaler)

Add a scaler for preprocessing a specific tensor.

build()

Extract concepts, compute formula targets, save to disk.

collate(samples)

Collate samples into a batch, re-annotating the ground-truth concepts.

download()

Download the dSprites dataset from the original source and save it to the root directory.

load()

Load and optionally preprocess the dataset.

load_raw()

Load the raw dataset without any preprocessing.

maybe_build()

maybe_download()

remove_exogenous(name)

set_concepts(concepts)

Set concept annotations for the dataset.

set_graph(graph)

Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.
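As a rough illustration of the kind of object set_graph(graph) accepts, the sketch below builds a small adjacency matrix as a pandas DataFrame. The concept names and the convention that rows are parents and columns are children are assumptions, not something this page specifies; check the library before relying on either.

```python
import pandas as pd

# Hypothetical concept names; use the dataset's own concept_names in practice.
concept_names = ["shape", "scale", "pos_x"]

# Square 0/1 matrix indexed by concept name on both axes.
graph = pd.DataFrame(0, index=concept_names, columns=concept_names)

# Assumed convention: graph.loc[parent, child] = 1 marks an edge parent -> child.
graph.loc["shape", "scale"] = 1
graph.loc["scale", "pos_x"] = 1

# Under that assumed convention, the parents of a concept are the rows
# holding a 1 in its column.
parents_of_pos_x = graph.index[graph["pos_x"] == 1].tolist()
print(parents_of_pos_x)

# With a constructed dataset instance, the matrix would then be passed as:
# dataset.set_graph(graph)
```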

Attributes

annotations

Annotations for the concepts in the dataset.

concept_names

List of concept names in the dataset.

exogenous

Mapping of dataset's exogenous variables.

graph

Adjacency matrix of the causal graph between concepts.

has_concepts

Whether the dataset has concept annotations.

has_exogenous

Whether the dataset has exogenous information.

n_concepts

Number of concepts in the dataset.

n_exogenous

Number of exogenous variables in the dataset.

n_features

Shape of features in dataset's input (excluding number of samples).

n_samples

Number of samples in the dataset.

processed_filenames

The list of processed filenames in the self.root_dir folder that must be present in order to skip build().

processed_paths

The absolute paths of the processed files that must be present in order to skip building.

raw_filenames

The list of raw filenames in the self.root_dir folder that must be present in order to skip download().

raw_paths

The absolute paths of the raw files that must be present in order to skip downloading.

root_dir

shape

Shape of the input tensor.
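The descriptions of raw_paths and processed_paths suggest a simple caching rule: a step is skipped when all of its required files already exist. The sketch below shows that rule in isolation. The helper maybe_run(), the file name raw.npz, and fake_download() are hypothetical and stand in for whatever maybe_download() and maybe_build() actually do.

```python
from pathlib import Path
import tempfile


def maybe_run(required_paths, step):
    """Run `step` only if a required file is missing; return True if it ran."""
    if all(Path(p).exists() for p in required_paths):
        return False
    step()
    return True


with tempfile.TemporaryDirectory() as root_dir:
    # Hypothetical raw file name inside the root directory.
    raw_paths = [Path(root_dir) / "raw.npz"]

    def fake_download():
        # Stand-in for a real download: create empty placeholder files.
        for path in raw_paths:
            path.write_bytes(b"")

    first = maybe_run(raw_paths, fake_download)   # files missing, step runs
    second = maybe_run(raw_paths, fake_download)  # files present, step skipped
    print(first, second)
```

Under these assumptions, the first call performs the work and the second returns immediately, which mirrors how cached raw or processed files avoid repeated downloads and builds.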