Source code for torch_concepts.data.datamodules.pendulum

from ..datasets.pendulum import PendulumDataset

from ..base.datamodule import ConceptDataModule
from ...typing import BackboneType
from ..base.splitter import Splitter
from ..splitters import RandomSplitter


[docs] class PendulumDataModule(ConceptDataModule): """DataModule for Pendulum dataset with concept-based learning support. Handles data loading, splitting, and batching for the procedurally generated pendulum scene dataset. Supports precomputing backbone embeddings and flexible train/val/test splitting strategies. Parameters ---------- root : str, optional Root directory where the dataset is stored or will be generated. Default: None (auto-creates ``./data/pendulum``). n_theta : int, optional Number of theta angle steps for generation. Default: 100 n_phi : int, optional Number of phi angle steps for generation. Default: 1000 seed : int, optional Random seed for the train/val/test split. Default: 42 generation_seed : int, optional Random seed for data generation. Default: 42 splitter : Splitter, optional Splitting strategy for train/val/test partitioning. Default: RandomSplitter() val_size : int or float, optional Validation set size. Default: 0.1 test_size : int or float, optional Test set size. Default: 0.2 batch_size : int, optional Number of samples per batch. Default: 512 backbone : BackboneType, optional Backbone model for feature extraction. Default: None precompute_embs : bool, optional Whether to precompute and cache backbone embeddings. Default: True force_recompute : bool, optional If True, recompute embeddings even if cached. Default: False concept_subset : list of str, optional Subset of concept names to use. Default: None workers : int, optional Number of data loading workers. Default: 0 Examples -------- >>> from torch_concepts.data import PendulumDataModule >>> >>> dm = PendulumDataModule( ... root='./data/pendulum', ... n_theta=10, n_phi=10, ... batch_size=32, seed=42, ... ) >>> dm.setup() >>> train_loader = dm.train_dataloader() See Also -------- PendulumDataset : The underlying dataset class ConceptDataModule : Parent class with common datamodule functionality """
[docs] def __init__( self, root: str = None, n_theta: int = 100, n_phi: int = 1000, seed: int = 42, generation_seed: int = 42, splitter: Splitter = RandomSplitter(), val_size: int | float = 0.1, test_size: int | float = 0.2, batch_size: int = 512, backbone: BackboneType = None, precompute_embs: bool = True, force_recompute: bool = False, concept_subset: list | None = None, label_descriptions: dict | None = None, workers: int = 0, **kwargs ): dataset = PendulumDataset( root=root, n_theta=n_theta, n_phi=n_phi, seed=generation_seed, concept_subset=concept_subset, label_descriptions=label_descriptions, ) super().__init__( dataset=dataset, val_size=val_size, test_size=test_size, batch_size=batch_size, backbone=backbone, precompute_embs=precompute_embs, force_recompute=force_recompute, workers=workers, splitter=splitter, seed=seed, )