torch_concepts.data.base.splitter.Splitter
class Splitter
Abstract base class for dataset splitting strategies.
Splitters divide a ConceptDataset into train, validation, and test splits. They store indices for each split and provide properties to access split sizes and indices. All concrete splitter implementations should inherit from this class and implement the fit() method.
Example
>>> class CustomSplitter(Splitter):
...     def fit(self, dataset):
...         n = len(dataset)
...         self.set_indices(
...             train=list(range(int(0.7*n))),
...             val=list(range(int(0.7*n), int(0.9*n))),
...             test=list(range(int(0.9*n), n))
...         )
...         self._fitted = True
>>>
>>> splitter = CustomSplitter()
>>> splitter.fit(my_dataset)
>>> print(f"Train: {splitter.train_len}, Val: {splitter.val_len}")
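The example above assigns contiguous index ranges, which preserves whatever ordering the dataset already has. A shuffled split is a common alternative. The sketch below is not part of torch_concepts: `random_split_indices` is a hypothetical helper that uses only the standard library, and its parameter names (`val_frac`, `test_frac`, `seed`) are assumptions chosen for illustration. It only shows how the three disjoint index lists might be computed.

```python
import random


def random_split_indices(n, val_frac=0.2, test_frac=0.1, seed=0):
    """Hypothetical helper: shuffle range(n) and cut it into three index lists."""
    if n < 0:
        raise ValueError("n must be non-negative")
    if val_frac < 0 or test_frac < 0 or val_frac + test_frac > 1:
        raise ValueError("val_frac and test_frac must be >= 0 and sum to at most 1")

    indices = list(range(n))
    # A dedicated Random instance keeps the shuffle reproducible
    # without touching the global random state.
    random.Random(seed).shuffle(indices)

    n_val = int(val_frac * n)
    n_test = int(test_frac * n)
    n_train = n - n_val - n_test  # the remainder goes to training

    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test


train, val, test = random_split_indices(100)
```

Inside a custom `fit()`, the three lists could presumably be passed on as `self.set_indices(train=train, val=val, test=test)`, mirroring the keyword arguments used in the example above.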
Methods
__init__()
fit(dataset): Split the dataset into train/val/test sets.
reset()
set_indices([train, val, test])
split(dataset)
Attributes