Source code for torch_concepts.data.datamodules.cub
from ..datasets.cub import CUBDataset
from ..base.datamodule import ConceptDataModule
from ...typing import BackboneType
from ..base.splitter import Splitter
from ..splitters.native import NativeSplitter
[docs]
class CUBDataModule(ConceptDataModule):
"""DataModule for CUB-200-2011 (Caltech-UCSD Birds).
Handles data loading, splitting, and batching for the CUB-200-2011 dataset
with support for concept-based learning. CUB-200-2011 provides official
train / val / test splits via the Koh et al. pre-processed pickle files,
so :class:`~torch_concepts.data.splitters.NativeSplitter` is used by
default.
.. note::
CUB-200-2011 must be **manually downloaded** before use.
See :class:`~torch_concepts.data.datasets.CUBDataset` for instructions.
Parameters
----------
root : str, optional
Root directory containing ``class_attr_data_10/`` and
``CUB_200_2011/``. Default: ``None`` (auto-creates ``./data/CUB200``).
image_size : int, optional
Side length (px) to resize images to. Default: 224.
splitter : Splitter, optional
Splitting strategy. Default: ``NativeSplitter()`` (uses the official
train / val / test splits from the pickle files).
batch_size : int, optional
Number of samples per batch. Default: 512.
backbone : BackboneType, optional
Backbone model for feature extraction (e.g. ``'resnet50'``).
Default: ``None``.
precompute_embs : bool, optional
Whether to precompute and cache backbone embeddings. Default: ``True``.
force_recompute : bool, optional
Recompute embeddings even if a cache exists. Default: ``False``.
concept_subset : list of str, optional
Subset of concept names to retain. Default: ``None`` (all 113).
label_descriptions : dict, optional
Mapping from concept name to human-readable description.
workers : int, optional
Number of data-loading worker processes. Default: 0.
Examples
--------
>>> from torch_concepts.data import CUBDataModule
>>>
>>> dm = CUBDataModule(
... root="./data/CUB200",
... backbone="resnet50",
... precompute_embs=True,
... batch_size=64,
... )
>>> dm.setup()
>>> train_loader = dm.train_dataloader()
See Also
--------
CUBDataset : The underlying dataset class.
ConceptDataModule : Parent class with common datamodule functionality.
"""
[docs]
def __init__(
self,
root: str = None,
image_size: int = 224,
splitter: Splitter = NativeSplitter(),
batch_size: int = 512,
backbone: BackboneType = None,
precompute_embs: bool = True,
force_recompute: bool = False,
concept_subset: list | None = None,
label_descriptions: dict | None = None,
workers: int = 0,
**kwargs,
):
dataset = CUBDataset(
root=root,
image_size=image_size,
concept_subset=concept_subset,
label_descriptions=label_descriptions,
)
super().__init__(
dataset=dataset,
batch_size=batch_size,
backbone=backbone,
precompute_embs=precompute_embs,
force_recompute=force_recompute,
workers=workers,
splitter=splitter,
)