torch_concepts.data.splitters.coloring.ColoringSplitter

class ColoringSplitter(root: str, seed: int | None = None, val_size: int | float = 0.1, test_size: int | float = 0.2)[source]

Coloring-based splitting strategy for distribution shift experiments.

Divides a dataset into train/val/test splits based on a pre-computed coloring scheme stored in a JSON file. This ensures that training and validation sets contain samples with ‘training’ coloring, while test sets contain samples with ‘test’ coloring.

This is useful for:
  • Out-of-distribution (OOD) evaluation
  • Domain adaptation experiments
  • Controlled distribution shift scenarios

Note: Assumes the dataset is already shuffled and that a coloring file exists at {root}/coloring_mode_seed_{seed}.json
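
Since the splitter expects the coloring file to exist already, a minimal sketch of producing one may help. It relies only on the file name pattern and the index-to-label format described on this page; the helper name `write_coloring_file`, the temporary root directory, and the 8/2 label assignment are illustrative assumptions, not part of the library.

```python
import json
import tempfile
from pathlib import Path


def write_coloring_file(root, seed, colors):
    """Hypothetical helper: write labels so that colors[i] belongs to sample i."""
    path = Path(root) / f"coloring_mode_seed_{seed}.json"
    path.parent.mkdir(parents=True, exist_ok=True)
    # JSON object keys must be strings, so sample indices are stringified.
    mapping = {str(i): color for i, color in enumerate(colors)}
    with path.open("w") as f:
        json.dump(mapping, f)
    return path


# Illustrative assignment: first 8 samples 'training', last 2 'test'.
root = tempfile.mkdtemp()
coloring_path = write_coloring_file(root, 42, ["training"] * 8 + ["test"] * 2)
```

The resulting `root` and `seed=42` could then be passed to `ColoringSplitter`, provided the dataset indices line up with the keys in the file.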

Parameters:
  • root (str) – Root directory containing the coloring mode JSON file.

  • seed (int, optional) – Random seed used to identify the coloring file. Defaults to None.

  • val_size (Union[int, float], optional) – Validation set size (from ‘training’ colored samples). Defaults to 0.1.

  • test_size (Union[int, float], optional) – Test set size (from ‘test’ colored samples). Defaults to 0.2.

Example

>>> # Create a coloring file first: coloring_mode_seed_42.json
>>> # Format: {"0": "training", "1": "training", "2": "test", ...}
>>>
>>> splitter = ColoringSplitter(
...     root='data/my_dataset',
...     seed=42,
...     val_size=0.1,
...     test_size=0.2
... )
>>> splitter.fit(dataset)
>>> # Train/val from 'training' samples, test from 'test' samples
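
To make the comment above concrete, the following sketch shows one way sample indices could be partitioned by color before any train/val/test sizes are applied. It is not the library's implementation; the function name `partition_by_color` and the sample mapping are assumptions for illustration, based only on the file format shown in the example.

```python
import json


def partition_by_color(coloring):
    """Hypothetical helper: group sample indices by their coloring label."""
    train_pool, test_pool = [], []
    for key, color in coloring.items():
        if color == "training":
            train_pool.append(int(key))  # candidates for train and val
        elif color == "test":
            test_pool.append(int(key))   # candidates for test
    return sorted(train_pool), sorted(test_pool)


# Illustrative mapping in the documented format.
coloring = json.loads(
    '{"0": "training", "1": "training", "2": "test", "3": "training", "4": "test"}'
)
train_pool, test_pool = partition_by_color(coloring)
```

Under this reading, train and validation indices would be drawn from `train_pool`, and test indices from `test_pool`.
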
__init__(root: str, seed: int | None = None, val_size: int | float = 0.1, test_size: int | float = 0.2)[source]

Initialize the ColoringSplitter.

Parameters:
  • root (str) – Root directory containing coloring mode JSON file.

  • seed (int, optional) – Random seed to identify coloring file. File expected at {root}/coloring_mode_seed_{seed}.json. Defaults to None.

  • val_size (Union[int, float], optional) – Validation set size (from ‘training’ samples). A float is interpreted as a fraction; an int as an absolute sample count. Defaults to 0.1.

  • test_size (Union[int, float], optional) – Test set size (from ‘test’ samples). A float is interpreted as a fraction; an int as an absolute sample count. Defaults to 0.2.
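
The int-versus-float convention for the size parameters can be sketched as below. This is an assumption-laden illustration rather than the library's code: the helper name `resolve_size` is hypothetical, the page does not state whether a fraction is taken relative to the matching color pool or to the whole dataset (the sketch assumes the pool), and the rounding and clamping behavior are guesses.

```python
def resolve_size(size, pool_len):
    """Hypothetical helper: turn a fractional or absolute size into a count."""
    if isinstance(size, float):
        if not 0.0 <= size <= 1.0:
            raise ValueError("fractional size must lie in [0, 1]")
        # Assumption: fraction of the pool, truncated toward zero.
        return int(size * pool_len)
    # Assumption: an absolute count is capped at the pool size.
    return min(size, pool_len)


n_val = resolve_size(0.5, 10)   # fraction of a 10-sample pool
n_test = resolve_size(3, 10)    # absolute count
```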

Methods

__init__(root[, seed, val_size, test_size])

Initialize the ColoringSplitter.

fit(dataset)

Split dataset based on coloring scheme from JSON file.

reset()

set_indices([train, val, test])

split(dataset)

Attributes

fitted

indices

test_idxs

test_len

train_idxs

train_len

val_idxs

val_len