torch_concepts.data.CUBDataset

class CUBDataset(root: str | None = None, image_size: int = 224, concept_subset: list | None = None, label_descriptions: Mapping | None = None)

Dataset class for CUB-200-2011 (Caltech-UCSD Birds).

CUB-200-2011 contains 11,788 bird images across 200 species classes, annotated with 112 binary semantic attributes selected by Koh et al. [CBM Paper] from the full set of 312 CUB attributes.

The train / val / test splits defined in the pre-processed pickle files are preserved; to use them, select NativeSplitter in the corresponding datamodule.

The concept vector per sample contains:

  • columns 0-111: 112 binary semantic attributes (cardinality 1 each)

  • column 112: bird species index 0-199 (cardinality 200)
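The concept-vector layout above can be sketched in plain Python. The constants and the helpers `split_concept_vector` and `expanded_width` are hypothetical illustrations, not part of torch_concepts. The sketch assumes that a concept vector is a flat sequence of 113 numbers in the column order listed, and that a concept of cardinality k would occupy k slots if expanded (one slot per binary attribute, 200 for the species).

```python
from typing import List, Sequence, Tuple

# Hypothetical constants mirroring the layout described above.
N_ATTRIBUTES = 112  # columns 0-111: binary semantic attributes
N_SPECIES = 200     # column 112: species index in 0-199

# Cardinality of each concept column: 1 for every binary attribute, 200 for species.
CARDINALITIES: List[int] = [1] * N_ATTRIBUTES + [N_SPECIES]


def split_concept_vector(concepts: Sequence[float]) -> Tuple[List[int], int]:
    """Split one 113-column concept vector into (binary attributes, species index)."""
    if len(concepts) != len(CARDINALITIES):
        raise ValueError(f"expected {len(CARDINALITIES)} columns, got {len(concepts)}")
    attributes = [int(v) for v in concepts[:N_ATTRIBUTES]]
    species = int(concepts[N_ATTRIBUTES])
    if not 0 <= species < N_SPECIES:
        raise ValueError(f"species index {species} outside 0-{N_SPECIES - 1}")
    return attributes, species


def expanded_width(cardinalities: Sequence[int]) -> int:
    """Total width if each concept were expanded to `cardinality` slots (assumption)."""
    return sum(cardinalities)


if __name__ == "__main__":
    example = [0] * N_ATTRIBUTES + [17]  # made-up vector: no attribute active, species 17
    attrs, species = split_concept_vector(example)
    print(len(attrs), species)            # 112 17
    print(expanded_width(CARDINALITIES))  # 312
```

Keeping the species as a single integer column, rather than 200 one-hot columns, is what makes the vector 113 columns wide while the cardinalities still sum to 312.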

Parameters:
  • root (str, optional) – Root directory that contains class_attr_data_10/ and CUB_200_2011/. Defaults to ./data/CUB200.

  • image_size (int, optional) – Target side length, in pixels, to which images are resized. Defaults to 224.

  • concept_subset (list of str, optional) – Subset of concept names to retain. If None, all 113 concepts are kept.

  • label_descriptions (dict, optional) – Mapping from concept name to human-readable description.
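How a concept_subset might narrow the 113 columns can be sketched as follows. The function `select_concept_columns` is hypothetical, and the names in the example are placeholders rather than the dataset's actual concept names. The sketch also assumes two behaviours the description above does not specify: unknown names raise an error, and retained columns keep their original order.

```python
from typing import List, Optional, Sequence


def select_concept_columns(
    concept_names: Sequence[str], concept_subset: Optional[Sequence[str]] = None
) -> List[int]:
    """Return the column indices retained for `concept_subset` (hypothetical helper).

    None keeps every column; otherwise columns are returned in their original order.
    """
    if concept_subset is None:
        return list(range(len(concept_names)))
    missing = set(concept_subset) - set(concept_names)
    if missing:
        raise KeyError(f"unknown concept names: {sorted(missing)}")
    wanted = set(concept_subset)
    return [i for i, name in enumerate(concept_names) if name in wanted]


if __name__ == "__main__":
    # Placeholder names; the real dataset has 112 attribute names plus a species column.
    names = ["attr_a", "attr_b", "attr_c", "species"]
    print(select_concept_columns(names))                         # [0, 1, 2, 3]
    print(select_concept_columns(names, ["species", "attr_b"]))  # [1, 3]
```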

__init__(root: str | None = None, image_size: int = 224, concept_subset: list | None = None, label_descriptions: Mapping | None = None)

Methods

__init__([root, image_size, concept_subset, ...])

add_exogenous(name, value[, convert_precision])

add_scaler(key, scaler)

Add a scaler for preprocessing a specific tensor.

build()

Process raw CUB pickle files and save cached dataset artefacts.

collate(samples)

Collate samples into a batch, re-annotating the ground-truth concepts.

download()

Download the CUB dataset if it is not already present.

load()

Load the raw dataset and preprocess the data.

load_raw()

Load processed artefacts from disk.

maybe_build()

maybe_download()

remove_exogenous(name)

set_concepts(concepts)

Set concept annotations for the dataset.

set_graph(graph)

Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.
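set_graph takes the adjacency matrix as a pandas DataFrame. A minimal sketch of building one, with concept names as both index and columns, is shown below. The helper `adjacency_from_edges`, the concept names and the edges are all made up. The convention that a 1 at row `parent`, column `child` encodes an edge parent → child is an assumption, since the orientation is not stated above. The final call is commented out because it needs a constructed CUBDataset.

```python
from typing import Iterable, Sequence, Tuple

import pandas as pd


def adjacency_from_edges(
    concept_names: Sequence[str], edges: Iterable[Tuple[str, str]]
) -> pd.DataFrame:
    """Build a square 0/1 adjacency DataFrame indexed by concept name.

    Assumed convention: adj.loc[parent, child] == 1 means an edge parent -> child.
    """
    names = list(concept_names)
    adj = pd.DataFrame(0, index=names, columns=names, dtype=int)
    for parent, child in edges:
        if parent not in adj.index or child not in adj.columns:
            raise KeyError(f"unknown concept in edge ({parent!r}, {child!r})")
        adj.loc[parent, child] = 1
    return adj


if __name__ == "__main__":
    # Placeholder names and edges, for illustration only.
    names = ["concept_a", "concept_b", "species"]
    graph = adjacency_from_edges(names, [("concept_a", "species"), ("concept_b", "species")])
    print(graph)
    # dataset.set_graph(graph)  # hypothetical: `dataset` would be a CUBDataset instance
```

Using concept names as labels, rather than bare integer positions, keeps the matrix readable and lets a mismatched name surface as an error instead of a silently misplaced edge.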

Attributes

annotations

Annotations for the concepts in the dataset.

concept_names

List of concept names in the dataset.

exogenous

Mapping of the dataset's exogenous variables.

graph

Adjacency matrix of the causal graph between concepts.

has_concepts

Whether the dataset has concept annotations.

has_exogenous

Whether the dataset has exogenous information.

n_concepts

Number of concepts in the dataset.

n_exogenous

Number of exogenous variables in the dataset.

n_features

Shape of the features in the dataset's input (excluding the number of samples).

n_samples

Number of samples in the dataset.

processed_filenames

The list of processed filenames in the self.root_dir folder that must be present in order to skip build().

processed_paths

The absolute paths of the processed files that must be present in order to skip building.

raw_filenames

The list of raw filenames in the self.root_dir folder that must be present in order to skip download().

raw_paths

The absolute paths of the raw files that must be present in order to skip downloading.

root_dir

shape

Shape of the input tensor.
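The raw_filenames / processed_filenames attributes describe a cache check: if every listed file already exists under root_dir, the download or build step can be skipped. A minimal sketch of that check in plain Python follows. The helpers `resolve_paths` and `all_files_present` are hypothetical, the filenames in the example are made up, and this is not the library's maybe_download() / maybe_build() implementation.

```python
import os
from typing import List, Sequence


def resolve_paths(root_dir: str, filenames: Sequence[str]) -> List[str]:
    """Absolute paths for `filenames` under `root_dir`, analogous to raw_paths / processed_paths."""
    return [os.path.abspath(os.path.join(root_dir, name)) for name in filenames]


def all_files_present(root_dir: str, filenames: Sequence[str]) -> bool:
    """True only if every listed file exists, i.e. the expensive step could be skipped."""
    return all(os.path.exists(p) for p in resolve_paths(root_dir, filenames))


if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as root:
        wanted = ["processed_a.pt", "processed_b.pt"]  # made-up filenames
        print(all_files_present(root, wanted))  # False: nothing cached yet
        for name in wanted:
            open(os.path.join(root, name), "w").close()
        print(all_files_present(root, wanted))  # True: build step could be skipped
```

Requiring every file, rather than any one of them, means a partially written cache triggers a rebuild instead of being loaded in an inconsistent state.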