torch_concepts.data.AWA2DataModule
- class AWA2DataModule(root: str | None = None, seed: int = 42, image_size: int = 224, val_size: float = 0.1, test_size: float = 0.2, splitter: Splitter = RandomSplitter(train_size=None, val_size=None, test_size=None), batch_size: int = 512, backbone: str | Callable[[Tensor], Tensor] | None = None, precompute_embs: bool = True, force_recompute: bool = False, concept_subset: list | None = None, label_descriptions: dict | None = None, workers: int = 0, **kwargs)
DataModule for Animals with Attributes 2 (AwA2).
Handles data loading, splitting, and batching for the AwA2 dataset with support for concept-based learning. Since AwA2 has no official train/val/test split, splitting is performed by the datamodule using RandomSplitter by default.
- Parameters:
root (str, optional) – Root directory where the AwA2 data is stored. Default: None (auto-creates ./data/AWA2).
seed (int, optional) – Random seed for train / val / test split. Default: 42.
image_size (int, optional) – Side length (px) to resize images to. Default: 224.
val_size (float, optional) – Fraction of samples for validation. Default: 0.1.
test_size (float, optional) – Fraction of samples for test. Default: 0.2.
splitter (Splitter, optional) – Splitting strategy. Default: RandomSplitter() (no official split exists for AwA2, so the datamodule owns the split).
batch_size (int, optional) – Number of samples per batch. Default: 512.
backbone (BackboneType, optional) – Backbone model for feature extraction (e.g. 'resnet50'). Default: None.
precompute_embs (bool, optional) – Whether to precompute and cache backbone embeddings. Default: True.
force_recompute (bool, optional) – Recompute embeddings even if a cache exists. Default: False.
concept_subset (list of str, optional) – Subset of concept names to retain. Default: None (all 86).
label_descriptions (dict, optional) – Mapping from concept name to human-readable description.
workers (int, optional) – Number of data-loading worker processes. Default: 0.
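To illustrate how seed, val_size, and test_size could interact, here is a minimal sketch of a seeded random split written with the Python standard library only. The helper name split_indices and its rounding rule are assumptions made for illustration; this is not the library's RandomSplitter, whose internals are not described here.

```python
import random


def split_indices(n_samples, val_size=0.1, test_size=0.2, seed=42):
    """Hypothetical helper: shuffle sample indices with a fixed seed,
    then cut them into train / val / test index lists."""
    if val_size < 0 or test_size < 0 or val_size + test_size >= 1.0:
        raise ValueError("val_size and test_size must be >= 0 and sum to < 1")

    indices = list(range(n_samples))
    # A dedicated Random instance keeps the split reproducible
    # without touching the global random state.
    random.Random(seed).shuffle(indices)

    # Assumed rounding rule: round each fraction to the nearest sample count.
    n_val = int(round(n_samples * val_size))
    n_test = int(round(n_samples * test_size))

    val = indices[:n_val]
    test = indices[n_val:n_val + n_test]
    train = indices[n_val + n_test:]  # everything left over
    return train, val, test


train_idx, val_idx, test_idx = split_indices(1000)
print(len(train_idx), len(val_idx), len(test_idx))  # 700 100 200
```

With the default fractions, 70% of the samples remain for training. Calling the helper twice with the same seed returns identical index lists, which is the property a fixed seed is meant to provide.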
Examples
>>> from torch_concepts.data import AWA2DataModule
>>>
>>> dm = AWA2DataModule(
...     root="./data/AWA2",
...     backbone="resnet50",
...     precompute_embs=True,
...     batch_size=64,
... )
>>> dm.setup()
>>> train_loader = dm.train_dataloader()
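The example sets precompute_embs=True. The sketch below shows one way a compute-once cache controlled by precompute_embs and force_recompute might behave. The function load_or_compute, the pickle file format, and the fake_backbone stand-in are all assumptions made for illustration; the library's actual cache location and format are not documented here.

```python
import os
import pickle
import tempfile


def load_or_compute(cache_path, compute_fn, precompute_embs=True,
                    force_recompute=False):
    """Hypothetical helper: return cached embeddings when allowed,
    otherwise compute them and (when caching is on) store the result."""
    if not precompute_embs:
        # Caching disabled: this sketch simply computes on demand.
        return compute_fn()

    if os.path.exists(cache_path) and not force_recompute:
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    # Cache miss, or a recompute was requested: compute and overwrite.
    embeddings = compute_fn()
    with open(cache_path, "wb") as f:
        pickle.dump(embeddings, f)
    return embeddings


calls = []


def fake_backbone():
    """Stand-in for an expensive backbone pass; records each invocation."""
    calls.append(1)
    return [[0.1, 0.2], [0.3, 0.4]]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "embs.pkl")
    load_or_compute(path, fake_backbone)                        # computes, writes cache
    load_or_compute(path, fake_backbone)                        # served from cache
    load_or_compute(path, fake_backbone, force_recompute=True)  # recomputes

print(len(calls))  # 2
```

The stand-in backbone runs twice across three calls: once to fill the cache and once because force_recompute=True bypasses it.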
See also
AWA2Dataset – The underlying dataset class.
ConceptDataModule – Parent class with common datamodule functionality.
- __init__(root: str | None = None, seed: int = 42, image_size: int = 224, val_size: float = 0.1, test_size: float = 0.2, splitter: Splitter = RandomSplitter(train_size=None, val_size=None, test_size=None), batch_size: int = 512, backbone: str | Callable[[Tensor], Tensor] | None = None, precompute_embs: bool = True, force_recompute: bool = False, concept_subset: list | None = None, label_descriptions: dict | None = None, workers: int = 0, **kwargs)
- prepare_data_per_node
If True, each LOCAL_RANK=0 will call prepare data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
- allow_zero_length_dataloader_with_multiple_devices
If True, dataloader with zero length within local rank is allowed. Default value is False.
Methods
__init__([root, seed, image_size, val_size, ...])
from_datasets([train_dataset, val_dataset, ...]) – Create an instance from torch.utils.data.Dataset.
get_dataloader([split, shuffle, batch_size]) – Get the DataLoader for a specific split.
load_from_checkpoint(checkpoint_path[, ...]) – Primary way of loading a datamodule from a checkpoint.
load_state_dict(state_dict) – Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict.
on_after_batch_transfer(batch, dataloader_idx) – Override to alter or apply batch augmentations to your batch after it is transferred to the device.
on_before_batch_transfer(batch, dataloader_idx) – Override to alter or apply batch augmentations to your batch before it is transferred to the device.
on_exception(exception) – Called when the trainer execution is interrupted by an exception.
predict_dataloader() – An iterable or collection of iterables specifying prediction samples.
prepare_data() – Use this to download and prepare data.
remove_ignored_hparams(ignore_list) – Remove ignored hyperparameters from the stored state.
save_hyperparameters(*args[, ignore, frame, ...]) – Save arguments to hparams attribute.
setup([stage, backbone_device, verbose]) – Prepare the data for training, validation, or testing.
state_dict() – Called when saving a checkpoint, implement to generate and save datamodule state.
teardown(stage) – Called at the end of fit (train + validate), validate, test, or predict.
test_dataloader([shuffle, batch_size]) – Get the test DataLoader.
train_dataloader([shuffle, batch_size]) – Get the training DataLoader.
transfer_batch_to_device(batch, device, ...) – Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
val_dataloader([shuffle, batch_size]) – Get the validation DataLoader.
Attributes
CHECKPOINT_HYPER_PARAMS_KEY
CHECKPOINT_HYPER_PARAMS_NAME
CHECKPOINT_HYPER_PARAMS_TYPE
backbone – The backbone model wrapper for feature extraction.
hparams – The collection of hyperparameters saved with save_hyperparameters().
hparams_initial – The collection of hyperparameters saved with save_hyperparameters().
n_samples – Total number of samples in the dataset.
name
test_len – Number of samples in the test set.
testset – The test subset.
train_len – Number of samples in the training set.
trainset – The training subset.
val_len – Number of samples in the validation set.
valset – The validation subset.