torch_concepts.data.datasets.celeba.CelebADataset
- class CelebADataset(root: str | None = None, concept_subset: list | None = None, label_descriptions: dict | None = None)
Dataset class for CelebA.
CelebA is a large-scale face attributes dataset with more than 200K celebrity images, each with 40 attribute annotations. This class wraps torchvision’s CelebA dataset to work with the ConceptDataset framework. The dataset can be downloaded from the official website: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html.
- Parameters:
root – Root directory where the dataset is stored or will be downloaded.
split – The split of the dataset to use (‘train’, ‘valid’, or ‘test’). Default is ‘train’.
transform – The transformations to apply to the images. Default is None.
download – Whether to download the dataset if it does not exist. Default is False.
task_label – The attribute(s) to use for the task. Default is ‘Attractive’.
concept_subset – Optional subset of concept labels to use.
label_descriptions – Optional dict mapping concept names to descriptions.
- __init__(root: str | None = None, concept_subset: list | None = None, label_descriptions: dict | None = None)
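A minimal sketch of how the constructor arguments might be prepared, using only the three parameters that appear in the signature (root, concept_subset, label_descriptions). The attribute names are the 40 annotation columns shipped with CelebA; that the class accepts them verbatim as concept names is an assumption, not something this page confirms. `build_kwargs` and `make_dataset` are hypothetical helpers, and `data/celeba` is a placeholder path.

```python
# The 40 attribute names from CelebA's annotation file.
CELEBA_ATTRIBUTES = [
    "5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", "Bags_Under_Eyes",
    "Bald", "Bangs", "Big_Lips", "Big_Nose", "Black_Hair", "Blond_Hair",
    "Blurry", "Brown_Hair", "Bushy_Eyebrows", "Chubby", "Double_Chin",
    "Eyeglasses", "Goatee", "Gray_Hair", "Heavy_Makeup", "High_Cheekbones",
    "Male", "Mouth_Slightly_Open", "Mustache", "Narrow_Eyes", "No_Beard",
    "Oval_Face", "Pale_Skin", "Pointy_Nose", "Receding_Hairline",
    "Rosy_Cheeks", "Sideburns", "Smiling", "Straight_Hair", "Wavy_Hair",
    "Wearing_Earrings", "Wearing_Hat", "Wearing_Lipstick",
    "Wearing_Necklace", "Wearing_Necktie", "Young",
]


def build_kwargs(root, concept_subset=None, label_descriptions=None):
    """Check concept names and assemble constructor arguments (hypothetical helper)."""
    names = list(concept_subset or []) + list(label_descriptions or {})
    unknown = [name for name in names if name not in CELEBA_ATTRIBUTES]
    if unknown:
        raise ValueError(f"Unknown CelebA attributes: {unknown}")
    return {
        "root": root,
        "concept_subset": concept_subset,
        "label_descriptions": label_descriptions,
    }


def make_dataset(**kwargs):
    """Instantiate the dataset; the import is lazy so build_kwargs works without the library."""
    from torch_concepts.data.datasets.celeba import CelebADataset

    return CelebADataset(**kwargs)


kwargs = build_kwargs(
    root="data/celeba",  # placeholder path, not a library default
    concept_subset=["Smiling", "Eyeglasses", "Young"],
    label_descriptions={"Smiling": "The person is smiling."},
)
```

Under these assumptions, `make_dataset(**kwargs)` would construct the dataset. It is not called here because instantiation may require the dataset files to be present.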
Methods
__init__([root, concept_subset, ...])
add_exogenous(name, value[, convert_precision])
add_scaler(key, scaler) – Add a scaler for preprocessing a specific tensor.
build() – Build processed dataset: save concepts, annotations and splits metadata.
download() – Download CelebA images zip and annotation files from Google Drive.
load() – Load and optionally preprocess dataset.
load_raw() – Load raw processed files for the current split.
maybe_build() – Download and extract the dataset if needed.
Extract the CelebA images archive.
maybe_reduce_annotations(annotations[, ...]) – Set the concepts and labels for the dataset. annotations: Annotations object for all concepts. concept_names_subset: list of strings naming the subset of concepts to use; if None, all concepts are used.
remove_exogenous(name)
set_concepts(concepts) – Set concept annotations for the dataset.
set_graph(graph) – Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.
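Since set_graph takes the adjacency matrix as a pandas DataFrame, the sketch below shows one way such a matrix could be assembled from a list of directed edges. The orientation (row = parent, column = child), the labelling of both axes with concept names, and the example edges are all assumptions made for illustration; this page does not specify them. `adjacency_rows` and `to_dataframe` are hypothetical helpers.

```python
def adjacency_rows(concept_names, edges):
    """Return a square 0/1 matrix as a list of rows; rows[i][j] == 1 means i -> j."""
    index = {name: i for i, name in enumerate(concept_names)}
    rows = [[0] * len(concept_names) for _ in concept_names]
    for parent, child in edges:
        rows[index[parent]][index[child]] = 1
    return rows


def to_dataframe(concept_names, rows):
    """Wrap the matrix in a DataFrame labelled by concept name on both axes."""
    import pandas as pd  # imported lazily; only this step needs pandas

    return pd.DataFrame(rows, index=concept_names, columns=concept_names)


concepts = ["Young", "Gray_Hair", "Bald"]
# Illustrative edges only; not a claim about the causal structure of CelebA.
edges = [("Young", "Gray_Hair"), ("Young", "Bald")]
rows = adjacency_rows(concepts, edges)
```

With a dataset instance at hand, the call would then look like `dataset.set_graph(to_dataframe(concepts, rows))`, assuming the orientation above matches what the library expects.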
Attributes
annotations – Annotations for the concepts in the dataset.
concept_names – List of concept names in the dataset.
exogenous – Mapping of dataset's exogenous variables.
graph – Adjacency matrix of the causal graph between concepts.
has_concepts – Whether the dataset has concept annotations.
has_exogenous – Whether the dataset has exogenous information.
n_concepts – Number of concepts in the dataset.
n_exogenous – Number of exogenous variables in the dataset.
Shape of features in dataset's input (excluding number of samples).
Number of samples in the dataset.
List of processed filenames that will be created during build step.
processed_paths – The absolute paths of the processed files that must be present in order to skip building.
List of raw filenames that must be present to skip downloading.
raw_paths – The absolute paths of the raw files that must be present in order to skip downloading.
root_dir
Shape of the input tensor (n_samples, C, H, W).
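The raw and processed path attributes describe a check-before-work pattern: downloading is skipped when every raw file is present, and building is skipped when every processed file is present. The sketch below illustrates that pattern with the standard library only. It is not the library's implementation, and the filenames and the `files_present` and `plan_steps` helpers are hypothetical.

```python
import tempfile
from pathlib import Path


def files_present(directory, filenames):
    """Return True only if every expected file exists under the directory."""
    return all((Path(directory) / name).exists() for name in filenames)


def plan_steps(raw_dir, raw_filenames, processed_dir, processed_filenames):
    """List the steps still needed, checking raw and processed files independently."""
    steps = []
    if not files_present(raw_dir, raw_filenames):
        steps.append("download")
    if not files_present(processed_dir, processed_filenames):
        steps.append("build")
    return steps


with tempfile.TemporaryDirectory() as tmp:
    raw_dir = Path(tmp) / "raw"
    processed_dir = Path(tmp) / "processed"
    raw_dir.mkdir()
    processed_dir.mkdir()

    raw_names = ["images.zip", "attributes.txt"]  # hypothetical filenames
    processed_names = ["concepts.pt"]  # hypothetical filename

    # Nothing on disk yet: both steps are required.
    fresh = plan_steps(raw_dir, raw_names, processed_dir, processed_names)

    # Simulate a finished download by creating the raw files.
    for name in raw_names:
        (raw_dir / name).touch()
    after_download = plan_steps(raw_dir, raw_names, processed_dir, processed_names)

    # Simulate a finished build by creating the processed files.
    for name in processed_names:
        (processed_dir / name).touch()
    after_build = plan_steps(raw_dir, raw_names, processed_dir, processed_names)
```

Checking for every file, not just the first one, means a partially completed download or build is repeated instead of being mistaken for a finished one.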