torch_concepts.data.base.splitter.Splitter

class Splitter

Abstract base class for dataset splitting strategies.

Splitters divide a ConceptDataset into train, validation, and test splits. They store indices for each split and provide properties to access split sizes and indices. All concrete splitter implementations should inherit from this class and implement the fit() method.

train_idxs

Training set indices.

Type: list

val_idxs

Validation set indices.

Type: list

test_idxs

Test set indices.

Type: list
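
Since the three attributes above are plain index lists, a quick sanity check on them can catch overlapping or out-of-range splits. The helper below is a minimal sketch and not part of torch_concepts: the name `check_split_indices` and its `n_samples` parameter are hypothetical and introduced only for illustration.

```python
def check_split_indices(train_idxs, val_idxs, test_idxs, n_samples):
    """Hypothetical helper: True if the splits are disjoint and in range."""
    all_idxs = list(train_idxs) + list(val_idxs) + list(test_idxs)
    # No index may appear in more than one split (or twice in the same one).
    no_overlap = len(set(all_idxs)) == len(all_idxs)
    # Every index must address a valid sample position.
    in_range = all(0 <= i < n_samples for i in all_idxs)
    return no_overlap and in_range


ok = check_split_indices([0, 1, 2], [3], [4], n_samples=5)
bad = check_split_indices([0, 1, 2], [2], [4], n_samples=5)  # index 2 is reused
print(ok, bad)
```

With a fitted splitter, the same check could be called on its `train_idxs`, `val_idxs`, and `test_idxs` together with `len(dataset)`.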

Example

>>> # Sequential 70/20/10 split: indices stay in dataset order (no shuffling).
>>> class CustomSplitter(Splitter):
...     def fit(self, dataset):
...         n = len(dataset)
...         self.set_indices(
...             train=list(range(int(0.7*n))),
...             val=list(range(int(0.7*n), int(0.9*n))),
...             test=list(range(int(0.9*n), n))
...         )
...         self._fitted = True  # mark the splitter as fitted
>>>
>>> splitter = CustomSplitter()
>>> splitter.fit(my_dataset)  # my_dataset: an existing ConceptDataset
>>> print(f"Train: {splitter.train_len}, Val: {splitter.val_len}")
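
The example above slices indices in dataset order. A shuffled variant can be sketched without the library installed, as follows. Everything here is an assumption for illustration: `SplitterStandIn` is a hypothetical stand-in that mimics only the names shown on this page (`set_indices`, `train_idxs`, `train_len`, `val_len`), not the real base class, and `ShuffledSplitter` with its `val_frac`, `test_frac`, and `seed` parameters is likewise hypothetical.

```python
import random


class SplitterStandIn:
    """Hypothetical stand-in for Splitter; NOT the torch_concepts implementation."""

    def __init__(self):
        self.train_idxs, self.val_idxs, self.test_idxs = [], [], []

    def set_indices(self, train=None, val=None, test=None):
        # Store only the index lists that were actually supplied.
        if train is not None:
            self.train_idxs = list(train)
        if val is not None:
            self.val_idxs = list(val)
        if test is not None:
            self.test_idxs = list(test)

    @property
    def train_len(self):
        return len(self.train_idxs)

    @property
    def val_len(self):
        return len(self.val_idxs)


class ShuffledSplitter(SplitterStandIn):
    """Hypothetical splitter: shuffles indices with a fixed seed, then slices."""

    def __init__(self, val_frac=0.2, test_frac=0.1, seed=0):
        super().__init__()
        self.val_frac = val_frac
        self.test_frac = test_frac
        self.seed = seed

    def fit(self, dataset):
        n = len(dataset)
        idxs = list(range(n))
        # A private Random instance keeps the split reproducible
        # without touching the global random state.
        random.Random(self.seed).shuffle(idxs)
        n_val = round(self.val_frac * n)
        n_test = round(self.test_frac * n)
        n_train = n - n_val - n_test  # the remainder goes to training
        self.set_indices(
            train=idxs[:n_train],
            val=idxs[n_train:n_train + n_val],
            test=idxs[n_train + n_val:],
        )


splitter = ShuffledSplitter(val_frac=0.2, test_frac=0.1, seed=42)
splitter.fit(list(range(100)))  # any object with __len__ works in this sketch
print(f"Train: {splitter.train_len}, Val: {splitter.val_len}")
```

Seeding the shuffle makes the split reproducible across runs; with the real base class, the same `fit()` body would presumably subclass `Splitter` directly instead of the stand-in.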

__init__()

Methods

__init__()

fit(dataset)

Split the dataset into train/val/test sets.

reset()

set_indices([train, val, test])

split(dataset)

Attributes