torch_concepts.data.CelebADataset
- class CelebADataset(root: str | None = None, concept_subset: list | None = None, label_descriptions: dict | None = None)
Dataset class for CelebA.
CelebA is a large-scale face attributes dataset with more than 200K celebrity images, each with 40 attribute annotations. This class wraps torchvision’s CelebA dataset to work with the ConceptDataset framework. The dataset can be downloaded from the official website: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html.
- Parameters:
root – Root directory where the dataset is stored or will be downloaded.
concept_subset – Optional subset of concept labels to use.
label_descriptions – Optional dict mapping concept names to descriptions.
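A minimal usage sketch built only from the signature and attributes documented on this page. The concept names in `CONCEPT_SUBSET` are standard CelebA attribute names, but whether the library exposes them under exactly these names is an assumption, as is the `root` path. The helper names `make_celeba` and `summarize` are hypothetical and not part of the library. Constructing the dataset may trigger a download if the raw files are missing.

```python
# Hypothetical selection: assumes the library uses the standard CelebA
# attribute names as concept names.
CONCEPT_SUBSET = ["Smiling", "Eyeglasses", "Male"]

# Free-text descriptions, one per selected concept (illustrative wording).
LABEL_DESCRIPTIONS = {
    "Smiling": "The person is smiling.",
    "Eyeglasses": "The person is wearing eyeglasses.",
    "Male": "The person is annotated as male.",
}


def make_celeba(root="./data/celeba"):
    """Build the dataset using only the three documented constructor arguments."""
    # Imported lazily so this snippet can be loaded without the library installed.
    from torch_concepts.data import CelebADataset

    return CelebADataset(
        root=root,  # assumed location; any writable directory should do
        concept_subset=CONCEPT_SUBSET,
        label_descriptions=LABEL_DESCRIPTIONS,
    )


def summarize(dataset):
    """Collect a few of the attributes listed on this page into a dict."""
    return {
        "n_samples": dataset.n_samples,
        "n_concepts": dataset.n_concepts,
        "concept_names": list(dataset.concept_names),
        "shape": tuple(dataset.shape),  # documented as (n_samples, C, H, W)
    }
```

Calling `summarize(make_celeba())` would then give a quick overview of what was loaded, provided the library is installed and the data is available.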
- __init__(root: str | None = None, concept_subset: list | None = None, label_descriptions: dict | None = None)
Methods
__init__([root, concept_subset, ...])
add_exogenous(name, value[, convert_precision])
add_scaler(key, scaler) – Add a scaler for preprocessing a specific tensor.
build() – Build processed dataset: save concepts, annotations and splits metadata.
collate(samples) – Collate samples into a batch, re-annotating the ground-truth concepts.
download() – Download CelebA images zip and annotation files from Google Drive.
load() – Load and optionally preprocess dataset.
load_raw() – Load raw processed files for the current split.
maybe_build()
maybe_download() – Download and extract the dataset if needed.
maybe_extract() – Extract the CelebA images archive.
remove_exogenous(name)
set_concepts(concepts) – Set concept annotations for the dataset.
set_graph(graph) – Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.
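As an illustration of the kind of input `set_graph` describes, the sketch below builds a square adjacency matrix over concept names and converts it to a pandas DataFrame. The edges are purely hypothetical, and the convention that rows are parents and columns are children is an assumption; this page does not state the orientation the library expects. The helpers `build_adjacency` and `to_dataframe` are not part of the library.

```python
def build_adjacency(names, edges):
    """Return a len(names) x len(names) 0/1 matrix with matrix[parent][child] = 1."""
    index = {name: i for i, name in enumerate(names)}
    matrix = [[0] * len(names) for _ in names]
    for parent, child in edges:
        matrix[index[parent]][index[child]] = 1
    return matrix


def to_dataframe(names, matrix):
    """Label rows and columns with concept names (requires pandas)."""
    import pandas as pd  # imported lazily; only needed for the conversion

    return pd.DataFrame(matrix, index=names, columns=names)


# Hypothetical concepts and edges, chosen only to show the layout.
NAMES = ["Smiling", "Mouth_Slightly_Open", "High_Cheekbones"]
EDGES = [("Smiling", "Mouth_Slightly_Open"), ("Smiling", "High_Cheekbones")]

ADJ = build_adjacency(NAMES, EDGES)
```

Under these assumptions, `dataset.set_graph(to_dataframe(NAMES, ADJ))` would attach the graph, and it would then be readable through the `graph` attribute.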
Attributes
annotations – Annotations for the concepts in the dataset.
concept_names – List of concept names in the dataset.
exogenous – Mapping of dataset's exogenous variables.
graph – Adjacency matrix of the causal graph between concepts.
has_concepts – Whether the dataset has concept annotations.
has_exogenous – Whether the dataset has exogenous information.
n_concepts – Number of concepts in the dataset.
n_exogenous – Number of exogenous variables in the dataset.
n_features – Shape of features in dataset's input (excluding number of samples).
n_samples – Number of samples in the dataset.
processed_filenames – List of processed filenames that will be created during build step.
processed_paths – The absolute paths of the processed files that must be present in order to skip building.
raw_filenames – List of raw filenames that must be present to skip downloading.
raw_paths – The absolute paths of the raw files that must be present in order to skip downloading.
root_dir
shape – Shape of the input tensor (n_samples, C, H, W).
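The descriptions of `raw_paths` and `processed_paths` suggest a skip-if-present pattern: a step runs only when one of its required files is missing. The sketch below shows that pattern in isolation. It is not the library's implementation, and `missing`, `maybe_run`, `fake_build`, and the file name `concepts.bin` are all hypothetical.

```python
import tempfile
from pathlib import Path


def missing(paths):
    """Return the subset of paths that do not exist on disk."""
    return [p for p in map(Path, paths) if not p.exists()]


def maybe_run(required_paths, step):
    """Run step() only if a required path is missing; return True if it ran."""
    if not missing(required_paths):
        return False
    step()
    return True


# Small demonstration in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "concepts.bin"  # hypothetical processed file
    calls = []

    def fake_build():
        # Stand-in for an expensive build step that produces the file.
        calls.append("build")
        target.write_text("done")

    first = maybe_run([target], fake_build)   # file absent, so the step runs
    second = maybe_run([target], fake_build)  # file present, so it is skipped
```

Read this way, `maybe_download()` would check the raw files and `maybe_build()` the processed files, so repeated instantiation with the same `root` should not redo completed work.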