Source code for torch_concepts.data.base.splitter

"""Abstract base class for dataset splitting strategies.

This module defines the Splitter interface for dividing datasets into
train/val/test splits. Splitters manage indices and ensure reproducible
splits through random seeds.
"""

from abc import ABC, abstractmethod

from .dataset import ConceptDataset

[docs] class Splitter(ABC): """Abstract base class for dataset splitting strategies. Splitters divide a ConceptDataset into train, validation, and test splits. They store indices for each split and provide properties to access split sizes and indices. All concrete splitter implementations should inherit from this class and implement the fit() method. Attributes: train_idxs (list): Training set indices. val_idxs (list): Validation set indices. test_idxs (list): Test set indices. Example: >>> class CustomSplitter(Splitter): ... def fit(self, dataset): ... n = len(dataset) ... self.set_indices( ... train=list(range(int(0.7*n))), ... val=list(range(int(0.7*n), int(0.9*n))), ... test=list(range(int(0.9*n), n)) ... ) ... self._fitted = True >>> >>> splitter = CustomSplitter() >>> splitter.fit(my_dataset) >>> print(f"Train: {splitter.train_len}, Val: {splitter.val_len}") """
[docs] def __init__(self): self.__indices = dict() self._fitted = False self.reset()
@property def indices(self): return self.__indices @property def fitted(self): return self._fitted @property def train_idxs(self): return self.__indices.get('train') @property def val_idxs(self): return self.__indices.get('val') @property def test_idxs(self): return self.__indices.get('test') @property def train_len(self): return len(self.train_idxs) if self.train_idxs is not None else None @property def val_len(self): return len(self.val_idxs) if self.val_idxs is not None else None @property def test_len(self): return len(self.test_idxs) if self.test_idxs is not None else None
[docs] def set_indices(self, train=None, val=None, test=None): if train is not None: self.__indices['train'] = train if val is not None: self.__indices['val'] = val if test is not None: self.__indices['test'] = test
[docs] def reset(self): self.__indices = dict(train=None, val=None, test=None) self._fitted = False
[docs] @abstractmethod def fit(self, dataset: ConceptDataset): """Split the dataset into train/val/test sets. This method should set the following attributes: - self.train_idxs: List of training indices - self.val_idxs: List of validation indices - self.test_idxs: List of test indices Args: dataset: The dataset to split. """ raise NotImplementedError
[docs] def split(self, dataset: ConceptDataset) -> None: if self.fitted: return self.indices else: return self.fit(dataset)