Source code for torch_concepts.data.splitters.random
"""Random data splitting for train/validation/test splits.
This module provides RandomSplitter for randomly dividing datasets into
standard train/val/test splits.
"""
from typing import Union
import numpy as np
from ..utils import resolve_size
from ..base.dataset import ConceptDataset
from ..base.splitter import Splitter
[docs]
class RandomSplitter(Splitter):
    """Random splitting strategy for datasets.

    Randomly divides a dataset into train, validation, and test splits.
    The permutation is drawn with ``np.random.permutation`` (numpy's global
    random state), so results are reproducible when numpy's random seed is
    set externally before calling fit().

    The splitting is done in the following order:
        1. Test (if test_size > 0)
        2. Validation (if val_size > 0)
        3. Training (remaining samples)

    Args:
        val_size (Union[int, float], optional): Size of validation set.
            If float, represents fraction of dataset. If int, represents
            absolute number of samples. Defaults to 0.1.
        test_size (Union[int, float], optional): Size of test set.
            If float, represents fraction of dataset. If int, represents
            absolute number of samples. Defaults to 0.2.

    Example:
        >>> # 70% train, 10% val, 20% test (here on a 1000-sample dataset)
        >>> splitter = RandomSplitter(val_size=0.1, test_size=0.2)
        >>> splitter.fit(dataset)
        >>> print(f"Train: {splitter.train_len}, Val: {splitter.val_len}, Test: {splitter.test_len}")
        Train: 700, Val: 100, Test: 200
    """

    def __init__(
        self,
        val_size: Union[int, float] = 0.1,
        test_size: Union[int, float] = 0.2,
    ):
        """Initialize the RandomSplitter.

        The sizes are only stored here; they are resolved against the
        dataset length when fit() is called.

        Args:
            val_size: Size of validation set. If float, represents fraction
                of dataset. If int, represents absolute number of samples.
                Defaults to 0.1.
            test_size: Size of test set. If float, represents fraction
                of dataset. If int, represents absolute number of samples.
                Defaults to 0.2.
        """
        super().__init__()
        self.val_size = val_size
        self.test_size = test_size

    def fit(self, dataset: ConceptDataset) -> None:
        """Randomly split the dataset into train/val/test sets.

        Creates a random permutation of dataset indices and divides them
        according to specified split sizes. Sets the _fitted flag to True
        upon completion.

        Args:
            dataset: The ConceptDataset to split. Only ``len(dataset)`` is
                used; the samples themselves are not read.

        Raises:
            ValueError: If a resolved split size is negative, or if the
                split sizes together exceed the dataset size.
        """
        n_samples = len(dataset)

        # Resolve all sizes to absolute numbers
        n_val = resolve_size(self.val_size, n_samples)
        n_test = resolve_size(self.test_size, n_samples)

        # A negative count would pass the "exceeds dataset size" check below
        # and then be interpreted by the slices as "count from the end",
        # yielding mis-sized splits. Reject it explicitly instead of relying
        # on the trailing assert, which is stripped under ``python -O``.
        if n_val < 0 or n_test < 0:
            raise ValueError(
                f"Split sizes must be non-negative. "
                f"(val={n_val}, test={n_test})"
            )

        # Validate that splits don't exceed dataset size
        total_split = n_val + n_test
        if total_split > n_samples:
            raise ValueError(
                f"Split sizes sum to {total_split} but dataset has only "
                f"{n_samples} samples. "
                f"(val={n_val}, test={n_test})"
            )
        n_train = n_samples - total_split

        # Create random permutation of indices
        indices = np.random.permutation(n_samples)

        # Split indices in order: test, val, train
        test_idxs = indices[:n_test]
        val_idxs = indices[n_test:n_test + n_val]
        train_idxs = indices[n_test + n_val:]

        # Store indices as plain Python lists
        self.set_indices(
            train=train_idxs.tolist(),
            val=val_idxs.tolist(),
            test=test_idxs.tolist()
        )
        # NOTE(review): the flag is set before the sanity check on purpose;
        # the base-class accessors may depend on it — confirm before moving.
        self._fitted = True

        # Sanity check
        assert len(self.train_idxs) == n_train, \
            f"Expected {n_train} training samples, got {len(self.train_idxs)}"

    def __repr__(self) -> str:
        """Return a summary string showing the length of each split.

        Reads ``train_len``, ``val_len`` and ``test_len`` from the instance.
        NOTE(review): these are presumably provided by the Splitter base
        class and may require fit() to have been called — confirm.
        """
        return (
            f"{self.__class__.__name__}("
            f"train_size={self.train_len}, "
            f"val_size={self.val_len}, "
            f"test_size={self.test_len})"
        )