torch_concepts.data.splitters.coloring.ColoringSplitter
- class ColoringSplitter(root: str, seed: int | None = None, val_size: int | float = 0.1, test_size: int | float = 0.2)
Coloring-based splitting strategy for distribution shift experiments.
Divides a dataset into train/val/test splits based on a pre-computed coloring scheme stored in a JSON file. This ensures that training and validation sets contain samples with ‘training’ coloring, while test sets contain samples with ‘test’ coloring.
This is useful for:
- Out-of-distribution (OOD) evaluation
- Domain adaptation experiments
- Controlled distribution shift scenarios
Note: The splitter assumes that the dataset is already shuffled and that a coloring file exists at {root}/coloring_mode_seed_{seed}.json.
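A coloring file at the expected path could be produced along the following lines. This is a minimal sketch: write_coloring_file is a hypothetical helper, not part of torch_concepts, and the rule used here to assign colors (first eight samples 'training', last two 'test') is purely illustrative. The key/value layout follows the format shown in the class example.

```python
import json
import tempfile
from pathlib import Path


def write_coloring_file(root, seed, colors):
    """Write a sample-index -> color mapping to the expected location.

    Hypothetical helper for illustration; not part of torch_concepts.
    """
    path = Path(root) / f"coloring_mode_seed_{seed}.json"
    path.parent.mkdir(parents=True, exist_ok=True)
    # JSON object keys must be strings, so sample indices are stringified.
    mapping = {str(i): color for i, color in enumerate(colors)}
    path.write_text(json.dumps(mapping))
    return path


# Illustrative assignment only: eight 'training' samples, then two 'test' samples.
colors = ["training"] * 8 + ["test"] * 2
root = tempfile.mkdtemp()  # stand-in for a real dataset root directory
coloring_path = write_coloring_file(root, seed=42, colors=colors)

# Read the file back to confirm that it round-trips.
loaded = json.loads(coloring_path.read_text())
```

The resulting directory can then be passed as root, with the same seed, when constructing the splitter.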
- Parameters:
root (str) – Root directory containing the coloring mode JSON file.
seed (int, optional) – Random seed used to identify the coloring file. Defaults to None.
val_size (Union[int, float], optional) – Validation set size (from ‘training’ colored samples). Defaults to 0.1.
test_size (Union[int, float], optional) – Test set size (from ‘test’ colored samples). Defaults to 0.2.
Example
>>> from torch_concepts.data.splitters.coloring import ColoringSplitter
>>>
>>> # Create a coloring file first: coloring_mode_seed_42.json
>>> # Format: {"0": "training", "1": "training", "2": "test", ...}
>>>
>>> splitter = ColoringSplitter(
...     root='data/my_dataset',
...     seed=42,
...     val_size=0.1,
...     test_size=0.2
... )
>>> splitter.fit(dataset)
>>> # Train/val from 'training' samples, test from 'test' samples
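The partitioning described in the class summary can be sketched without the library. The snippet below is an illustration under stated assumptions, not the actual implementation of fit: partition_by_coloring is a hypothetical helper, and taking validation samples from the end of the 'training' pool is an assumption made for the example.

```python
def partition_by_coloring(coloring):
    """Group sample indices by their color label.

    Hypothetical helper; it only mirrors the documented file format
    ({"<index>": "training" | "test"}).
    """
    training_pool, test_pool = [], []
    for key, color in coloring.items():
        if color == "training":
            training_pool.append(int(key))
        elif color == "test":
            test_pool.append(int(key))
    return sorted(training_pool), sorted(test_pool)


coloring = {
    "0": "training",
    "1": "training",
    "2": "test",
    "3": "training",
    "4": "test",
}
training_pool, test_pool = partition_by_coloring(coloring)

# Assumption: validation samples are taken from the end of the training pool.
n_val = 1
train_idxs = training_pool[:-n_val]
val_idxs = training_pool[-n_val:]
test_idxs = test_pool
```

Whatever the selection rule, the property that matters for distribution shift experiments is that no 'test'-colored index ends up in the train or validation split.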
- __init__(root: str, seed: int | None = None, val_size: int | float = 0.1, test_size: int | float = 0.2)
Initialize the ColoringSplitter.
- Parameters:
root (str) – Root directory containing coloring mode JSON file.
seed (int, optional) – Random seed to identify coloring file. File expected at {root}/coloring_mode_seed_{seed}.json. Defaults to None.
val_size (Union[int, float], optional) – Validation set size (from ‘training’ samples). A float is interpreted as a fraction, an int as an absolute sample count. Defaults to 0.1.
test_size (Union[int, float], optional) – Test set size (from ‘test’ samples). A float is interpreted as a fraction, an int as an absolute sample count. Defaults to 0.2.
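The int-or-float convention for val_size and test_size might be resolved to a sample count as sketched below. resolve_size is a hypothetical helper, and the assumption that a float is a fraction of the pool passed in (rather than of the whole dataset) is not confirmed by this page. Rounding down is likewise an assumption.

```python
def resolve_size(size, pool_len):
    """Convert an int-or-float size specification into a sample count.

    Hypothetical helper. Assumes a float is a fraction of pool_len,
    which the documentation does not specify.
    """
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(size, bool):
        raise TypeError("size must be an int or a float")
    if isinstance(size, float):
        if not 0.0 <= size <= 1.0:
            raise ValueError("a fractional size must lie in [0, 1]")
        return int(size * pool_len)  # assumption: round down
    if isinstance(size, int):
        if not 0 <= size <= pool_len:
            raise ValueError("an absolute size must lie in [0, pool_len]")
        return size
    raise TypeError("size must be an int or a float")


n_from_fraction = resolve_size(0.5, 10)  # half of a pool of 10 samples
n_from_count = resolve_size(3, 10)       # exactly 3 samples
```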
Methods
__init__(root[, seed, val_size, test_size]) – Initialize the ColoringSplitter.
fit(dataset) – Split dataset based on coloring scheme from JSON file.
reset()
set_indices([train, val, test])
split(dataset)
Attributes
fitted
indices
test_idxs
test_len
train_idxs
train_len
val_idxs
val_len