torch_concepts.data.datasets.celeba.CelebADataset

class CelebADataset(root: str | None = None, concept_subset: list | None = None, label_descriptions: dict | None = None)

Dataset class for CelebA.

CelebA is a large-scale face attributes dataset with more than 200K celebrity images, each with 40 attribute annotations. This class wraps torchvision’s CelebA dataset to work with the ConceptDataset framework. The dataset can be downloaded from the official website: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html.
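CelebA's raw attribute file (list_attr_celeba.txt) stores each of the 40 attributes as 1 (present) or -1 (absent), and torchvision's CelebA remaps these values to 1 and 0. The sketch below shows that remapping with NumPy on a toy array of two rows and three attributes. The rows are made-up illustrative values, and whether CelebADataset exposes its concepts in exactly this {0, 1} form is an assumption, not something this page states.

```python
import numpy as np

# Toy rows in the layout of list_attr_celeba.txt: one column per attribute,
# each entry is 1 (attribute present) or -1 (absent). Only 3 of the 40 are shown.
ATTR_NAMES = ["Attractive", "Smiling", "Young"]
raw_attrs = np.array(
    [
        [1, -1, 1],
        [-1, 1, 1],
    ],
    dtype=np.int64,
)


def to_binary(attrs: np.ndarray) -> np.ndarray:
    """Map CelebA's {-1, 1} encoding to {0, 1}, as torchvision's CelebA does."""
    return (attrs + 1) // 2


binary_attrs = to_binary(raw_attrs)
print(binary_attrs.tolist())  # [[1, 0, 1], [0, 1, 1]]
```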

Parameters:
  • root – Root directory where the dataset is stored or will be downloaded.

  • split – The split of the dataset to use (‘train’, ‘valid’, or ‘test’). Default is ‘train’.

  • transform – The transformations to apply to the images. Default is None.

  • download – Whether to download the dataset if it does not exist. Default is False.

  • task_label – The attribute(s) to use for the task. Default is ‘Attractive’.

  • concept_subset – Optional subset of concept labels to use.

  • label_descriptions – Optional dict mapping concept names to descriptions.
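The parameter list implies that one attribute serves as the task label while the remaining attributes act as concepts, optionally restricted by concept_subset and described by label_descriptions. The following is a minimal pandas sketch of that split under stated assumptions: the helper split_concepts_and_task is hypothetical (not part of torch_concepts), it assumes the task attribute is excluded from the concepts, and it assumes a concept without a description falls back to its attribute name.

```python
from typing import Optional

import pandas as pd


def split_concepts_and_task(
    attrs: pd.DataFrame,
    task_label: str = "Attractive",
    concept_subset: Optional[list] = None,
    label_descriptions: Optional[dict] = None,
):
    """Hypothetical helper: separate the task attribute from the concept columns."""
    if task_label not in attrs.columns:
        raise KeyError(f"unknown task label: {task_label!r}")
    # Assumption: every attribute other than the task is a candidate concept.
    candidates = [c for c in attrs.columns if c != task_label]
    if concept_subset is None:
        names = candidates  # None means: keep all remaining attributes
    else:
        missing = [c for c in concept_subset if c not in candidates]
        if missing:
            raise KeyError(f"unknown concepts: {missing}")
        names = list(concept_subset)
    # Assumption: fall back to the raw attribute name when no description is given.
    descriptions = {n: (label_descriptions or {}).get(n, n) for n in names}
    return attrs[names], attrs[task_label], descriptions


attrs = pd.DataFrame({"Attractive": [1, 0], "Smiling": [0, 1], "Young": [1, 1]})
concepts, task, desc = split_concepts_and_task(
    attrs,
    concept_subset=["Smiling"],
    label_descriptions={"Smiling": "person is smiling"},
)
print(list(concepts.columns), task.tolist(), desc)
# ['Smiling'] [1, 0] {'Smiling': 'person is smiling'}
```

Raising on unknown names, rather than silently dropping them, is a design choice of this sketch that surfaces typos in concept_subset early.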

__init__(root: str | None = None, concept_subset: list | None = None, label_descriptions: dict | None = None)

Methods

__init__([root, concept_subset, ...])

add_exogenous(name, value[, convert_precision])

add_scaler(key, scaler)

Add a scaler for preprocessing a specific tensor.

build()

Build the processed dataset: save the concepts, annotations, and split metadata.

download()

Download the CelebA image zip archive and the annotation files from Google Drive.

load()

Load and optionally preprocess the dataset.

load_raw()

Load raw processed files for the current split.

maybe_build()

maybe_download()

Download and extract the dataset if needed.

maybe_extract()

Extract the CelebA images archive.

maybe_reduce_annotations(annotations[, ...])

Set concepts and labels for the dataset.

Parameters:
  • annotations – Annotations object for all concepts.

  • concept_names_subset – List of strings naming the subset of concepts to use. If None, all concepts are used.

remove_exogenous(name)

set_concepts(concepts)

Set concept annotations for the dataset.

set_graph(graph)

Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.
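set_graph takes the causal graph between concepts as a pandas DataFrame adjacency matrix. The sketch below builds such a frame from an edge list. The edges are hypothetical, and the row-is-parent, column-is-child convention and the use of concept names as both index and columns are assumptions rather than documented requirements of set_graph.

```python
import numpy as np
import pandas as pd

concept_names = ["Smiling", "Mouth_Slightly_Open", "High_Cheekbones"]

# Hypothetical causal edges (parent -> child); not taken from the library.
edges = [("Smiling", "Mouth_Slightly_Open"), ("Smiling", "High_Cheekbones")]

# Square 0/1 adjacency matrix; assumed convention: rows are parents, columns are children.
graph = pd.DataFrame(
    np.zeros((len(concept_names), len(concept_names)), dtype=int),
    index=concept_names,
    columns=concept_names,
)
for parent, child in edges:
    graph.loc[parent, child] = 1

# Sanity checks one might run before passing the frame to set_graph().
assert graph.shape == (len(concept_names), len(concept_names))
assert list(graph.index) == list(graph.columns) == concept_names
print(graph.loc["Smiling"].tolist())  # [0, 1, 1]
```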

Attributes

annotations

Annotations for the concepts in the dataset.

concept_names

List of concept names in the dataset.

exogenous

Mapping of the dataset's exogenous variables.

graph

Adjacency matrix of the causal graph between concepts.

has_concepts

Whether the dataset has concept annotations.

has_exogenous

Whether the dataset has exogenous information.

n_concepts

Number of concepts in the dataset.

n_exogenous

Number of exogenous variables in the dataset.

n_features

Shape of the features in the dataset's input (excluding the number of samples).

n_samples

Number of samples in the dataset.

processed_filenames

List of processed filenames that will be created during the build step.

processed_paths

The absolute paths of the processed files that must be present in order to skip building.

raw_filenames

List of raw filenames that must be present to skip downloading.

raw_paths

The absolute paths of the raw files that must be present in order to skip downloading.

root_dir

shape

Shape of the input tensor (n_samples, C, H, W).
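The raw_paths and processed_paths attributes determine when downloading and building can be skipped: every listed file must already be present. A minimal sketch of that skip-if-present check follows. The function maybe_run and the file layout are hypothetical stand-ins for maybe_download and maybe_build, not torch_concepts code; the fake download only writes a placeholder file.

```python
import tempfile
from pathlib import Path
from typing import Callable, Sequence


def maybe_run(paths: Sequence[Path], step: Callable[[], None]) -> bool:
    """Run `step` unless every path already exists; return True if it ran."""
    if all(p.exists() for p in paths):
        return False  # all required files present: skip the step
    step()
    return True


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical raw layout for illustration only.
    raw_paths = [Path(tmp) / "raw" / "list_attr_celeba.txt"]

    def fake_download() -> None:
        # Stand-in for a real download: just create the expected files.
        for p in raw_paths:
            p.parent.mkdir(parents=True, exist_ok=True)
            p.write_text("placeholder\n")

    first = maybe_run(raw_paths, fake_download)   # files missing: step runs
    second = maybe_run(raw_paths, fake_download)  # files present: step skipped
    print(first, second)  # True False
```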