torch_concepts.data.CUBDataset
- class CUBDataset(root: str | None = None, image_size: int = 224, concept_subset: list | None = None, label_descriptions: Mapping | None = None)
Dataset class for CUB-200-2011 (Caltech-UCSD Birds).
CUB-200-2011 contains 11,788 bird images across 200 species classes, annotated with 112 binary semantic attributes selected by Koh et al. [CBM Paper] from the full set of 312 CUB attributes.
The official train / val / test splits from the pre-processed pickle files are preserved; use NativeSplitter in the corresponding datamodule to use these splits.
The concept vector of each sample contains:
- columns 0-111: 112 binary semantic attributes (cardinality 1 each)
- column 112: bird species index 0-199 (cardinality 200)
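The column layout above can be illustrated with a small pure-Python sketch. The helper split_concept_vector and the constant CONCEPT_CARDINALITIES are hypothetical names, not part of torch_concepts; the sketch only assumes the documented 113-column layout (112 binary attributes followed by one species index).

```python
from typing import List, Sequence, Tuple

N_ATTRIBUTES = 112  # binary semantic attributes, columns 0-111
N_SPECIES = 200     # species classes, stored as one index in column 112

# Hypothetical per-column cardinalities: 1 per binary attribute, 200 for species.
CONCEPT_CARDINALITIES: List[int] = [1] * N_ATTRIBUTES + [N_SPECIES]


def split_concept_vector(row: Sequence[int]) -> Tuple[List[int], int]:
    """Split one 113-column concept vector into (attributes, species_index)."""
    if len(row) != N_ATTRIBUTES + 1:
        raise ValueError(f"expected {N_ATTRIBUTES + 1} columns, got {len(row)}")
    attributes = [int(v) for v in row[:N_ATTRIBUTES]]
    if any(v not in (0, 1) for v in attributes):
        raise ValueError("attribute columns must be binary (0 or 1)")
    species = int(row[N_ATTRIBUTES])
    if not 0 <= species < N_SPECIES:
        raise ValueError(f"species index must be in [0, {N_SPECIES - 1}], got {species}")
    return attributes, species


# A made-up row: attributes 0 and 5 active, species index 17.
example = [0] * N_ATTRIBUTES + [17]
example[0] = 1
example[5] = 1
attrs, species = split_concept_vector(example)
print(sum(attrs), species)         # 2 17
print(len(CONCEPT_CARDINALITIES))  # 113
```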
- Parameters:
root (str, optional) – Root directory that contains class_attr_data_10/ and CUB_200_2011/. Defaults to ./data/CUB200.
image_size (int, optional) – Side length, in pixels, to which images are resized. Defaults to 224.
concept_subset (list of str, optional) – Subset of concept names to retain. None (the default) keeps all 113 concepts.
label_descriptions (dict, optional) – Mapping from concept name to human-readable description.
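The root and concept_subset parameters can be sketched with two stand-alone helpers. Both missing_subdirs and select_concepts are hypothetical, as are the placeholder concept names; only the default root, the two required sub-directories, and the "None keeps all 113" rule come from the parameter descriptions. Rejecting unknown names and keeping the dataset's column order are assumptions, and CUBDataset may behave differently.

```python
import tempfile
from pathlib import Path
from typing import List, Optional, Sequence

DEFAULT_ROOT = "./data/CUB200"  # documented default for `root`
REQUIRED_SUBDIRS = ("class_attr_data_10", "CUB_200_2011")  # documented contents of `root`


def missing_subdirs(root: Optional[str] = None) -> List[str]:
    """Return the documented sub-directories that are absent under `root`."""
    base = Path(root if root is not None else DEFAULT_ROOT)
    return [name for name in REQUIRED_SUBDIRS if not (base / name).is_dir()]


def select_concepts(
    all_names: Sequence[str], subset: Optional[Sequence[str]] = None
) -> List[str]:
    """Return the retained concept names; None keeps every concept."""
    if subset is None:
        return list(all_names)
    unknown = [name for name in subset if name not in all_names]
    if unknown:
        # Assumption: unknown names are rejected rather than silently ignored.
        raise KeyError(f"unknown concept names: {unknown}")
    wanted = set(subset)
    # Assumption: keep the dataset's column order, not the order given in `subset`.
    return [name for name in all_names if name in wanted]


# Placeholder names: the real CUB concept names are not listed on this page.
names = [f"attr_{i}" for i in range(112)] + ["species"]
print(len(select_concepts(names)))                    # 113
print(select_concepts(names, ["species", "attr_3"]))  # ['attr_3', 'species']

# A throwaway directory stands in for the real dataset root.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "CUB_200_2011").mkdir()
    print(missing_subdirs(tmp))  # ['class_attr_data_10']
```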
- __init__(root: str | None = None, image_size: int = 224, concept_subset: list | None = None, label_descriptions: Mapping | None = None)
Methods
__init__([root, image_size, concept_subset, ...])
add_exogenous(name, value[, convert_precision])
add_scaler(key, scaler) – Add a scaler for preprocessing a specific tensor.
build() – Process raw CUB pickle files and save cached dataset artefacts.
collate(samples) – Collate samples into a batch, re-annotating the ground-truth concepts.
download() – Download the CUB dataset if it is not already present.
load() – Load the raw dataset and preprocess the data.
load_raw() – Load processed artefacts from disk.
maybe_build()
maybe_download()
remove_exogenous(name)
set_concepts(concepts) – Set concept annotations for the dataset.
set_graph(graph) – Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.
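The method list, together with the attribute descriptions stating which files must be present to skip download() and build(), suggests a download, build, then load lifecycle in which maybe_download() and maybe_build() skip work that has already been done. The sketch below imitates that pattern with an in-memory set standing in for the files under root_dir. The class LazyDatasetSketch, its prepare() method, and the placeholder filenames are hypothetical and are not torch_concepts code.

```python
from typing import List, Set


class LazyDatasetSketch:
    """Hypothetical stand-in imitating a download -> build lifecycle."""

    raw_filenames = ["raw.pkl"]             # placeholder names, not CUB's real files
    processed_filenames = ["processed.pt"]

    def __init__(self, disk: Set[str]) -> None:
        self.disk = disk            # in-memory stand-in for files under root_dir
        self.calls: List[str] = []  # records which expensive steps actually ran

    def download(self) -> None:
        self.calls.append("download")
        self.disk.update(self.raw_filenames)

    def build(self) -> None:
        self.calls.append("build")
        self.disk.update(self.processed_filenames)

    def maybe_download(self) -> None:
        # Skip download() when every raw file is already present.
        if not all(name in self.disk for name in self.raw_filenames):
            self.download()

    def maybe_build(self) -> None:
        # Skip build() when every processed file is already present.
        if not all(name in self.disk for name in self.processed_filenames):
            self.build()

    def prepare(self) -> None:
        # Hypothetical orchestration: fetch raw data, then cache processed artefacts.
        self.maybe_download()
        self.maybe_build()


disk: Set[str] = set()
first = LazyDatasetSketch(disk)
first.prepare()
print(first.calls)   # ['download', 'build']

second = LazyDatasetSketch(disk)  # same "disk", so the artefacts already exist
second.prepare()
print(second.calls)  # []
```

Keeping the presence checks in maybe_download() and maybe_build() lets download() and build() stay unconditional, which makes it easy to force a rebuild by calling build() directly.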
Attributes
annotations – Annotations for the concepts in the dataset.
concept_names – List of concept names in the dataset.
exogenous – Mapping of the dataset's exogenous variables.
graph – Adjacency matrix of the causal graph between concepts.
has_concepts – Whether the dataset has concept annotations.
has_exogenous – Whether the dataset has exogenous information.
n_concepts – Number of concepts in the dataset.
n_exogenous – Number of exogenous variables in the dataset.
n_features – Shape of the dataset's input features (excluding the number of samples).
n_samples – Number of samples in the dataset.
processed_filenames – The list of processed filenames in the self.root_dir folder that must be present in order to skip build().
processed_paths – The absolute paths of the processed files that must be present in order to skip building.
raw_filenames – The list of raw filenames in the self.root_dir folder that must be present in order to skip download().
raw_paths – The absolute paths of the raw files that must be present in order to skip downloading.
root_dir
shape – Shape of the input tensor.
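The relationship between the *_filenames and *_paths attributes can be sketched as follows, assuming each path is formed by joining a filename onto root_dir and making it absolute. The helpers absolute_paths and can_skip and the placeholder filename are hypothetical; the real attribute implementations may differ.

```python
import os
import tempfile
from typing import List, Sequence


def absolute_paths(root_dir: str, filenames: Sequence[str]) -> List[str]:
    """Join each filename onto root_dir and make the result absolute."""
    return [os.path.abspath(os.path.join(root_dir, name)) for name in filenames]


def can_skip(root_dir: str, filenames: Sequence[str]) -> bool:
    """True when every listed file already exists under root_dir."""
    return all(os.path.isfile(path) for path in absolute_paths(root_dir, filenames))


with tempfile.TemporaryDirectory() as root_dir:
    names = ["placeholder_raw.pkl"]  # hypothetical filename, not CUB's real list
    print(can_skip(root_dir, names))  # False
    open(os.path.join(root_dir, names[0]), "w").close()  # create an empty file
    print(can_skip(root_dir, names))  # True
```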