Data Base Classes

This module provides base classes for data handling in concept-based models.

Summary

Dataset Base Classes

ConceptDataset

Base class for concept-annotated datasets.

DataModule Base Classes

ConceptDataModule

PyTorch Lightning DataModule for concept-based datasets.

Scaler Base Classes

Scaler

Abstract base class for data scaling transformations.

Splitter Base Classes

Splitter

Abstract base class for dataset splitting strategies.

Class Documentation

Dataset Classes

class ConceptDataset(input_data: ndarray | DataFrame | Tensor, concepts: ndarray | DataFrame | Tensor, annotations: Annotations | None = None, graph: DataFrame | None = None, concept_names_subset: List[str] | None = None, precision: int | str = 32, name: str | None = None)[source]

Bases: Dataset

Base class for concept-annotated datasets.

This class extends PyTorch’s Dataset to support concept annotations, concept graphs, and various metadata. It provides a unified interface for working with datasets that have both input features and concept labels.

name

Name of the dataset.

Type:

str

precision

Numerical precision for tensors (16, 32, or 64).

Type:

int or str

input_data

Input features/images.

Type:

Tensor

concepts

Concept annotations.

Type:

Tensor

annotations

Detailed concept annotations with metadata.

Type:

Annotations

Parameters:
  • input_data – Input features as numpy array, pandas DataFrame, or Tensor.

  • concepts – Concept annotations as numpy array, pandas DataFrame, or Tensor.

  • annotations – Optional Annotations object with concept metadata.

  • graph – Optional concept graph as pandas DataFrame or tensor.

  • concept_names_subset – Optional list to select subset of concepts.

  • precision – Numerical precision (16, 32, or 64, default: 32).

  • name – Optional dataset name.

  • exogenous – Optional exogenous variables (not yet implemented).

Raises:
  • ValueError – If concepts is None or annotations don’t include axis 1.

  • NotImplementedError – If continuous concepts or exogenous variables are used.

Example

>>> X = torch.randn(100, 28, 28)  # 100 images
>>> C = torch.randint(0, 2, (100, 5))  # 5 binary concepts
>>> annotations = Annotations({1: AxisAnnotation(labels=['c1', 'c2', 'c3', 'c4', 'c5'])})
>>> dataset = ConceptDataset(X, C, annotations=annotations)
>>> len(dataset)
100
property n_samples: int

Number of samples in the dataset.

Returns:

Number of samples.

Return type:

int

property n_features: tuple

Shape of features in dataset’s input (excluding number of samples).

Returns:

Shape of input features.

Return type:

tuple

property n_concepts: int

Number of concepts in the dataset.

Returns:

Number of concepts, or 0 if no concepts.

Return type:

int

property concept_names: List[str]

List of concept names in the dataset.

Returns:

Names of all concepts.

Return type:

List[str]

property annotations: Annotations | None

Annotations for the concepts in the dataset.

property shape: tuple

Shape of the input tensor.

property exogenous: Dict[str, Tensor]

Mapping of dataset’s exogenous variables.

property n_exogenous: int

Number of exogenous variables in the dataset.

property graph: ConceptGraph | None

Adjacency matrix of the causal graph between concepts.

property has_exogenous: bool

Whether the dataset has exogenous information.

property has_concepts: bool

Whether the dataset has concept annotations.

property root_dir: str
abstract property raw_filenames: List[str]

The list of raw filenames in the self.root_dir folder that must be present in order to skip download(). Should be implemented by subclasses.

abstract property processed_filenames: List[str]

The list of processed filenames in the self.root_dir folder that must be present in order to skip build(). Should be implemented by subclasses.

property raw_paths: List[str]

The absolute paths of the raw files that must be present in order to skip downloading.

property processed_paths: List[str]

The absolute paths of the processed files that must be present in order to skip building.

maybe_download()[source]
maybe_build()[source]
download() → None[source]

Downloads dataset’s files to the self.root_dir folder.

build() → None[source]

Builds the dataset from the raw data into the self.root_dir folder, if needed.
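
The file-related properties and the maybe_download() / maybe_build() hooks suggest a skip-if-present pattern. The sketch below illustrates that pattern with a hypothetical stand-in class, FileBackedDataset, which is not part of the library. It assumes that maybe_download() calls download() only when a raw file is missing and that maybe_build() calls build() only when a processed file is missing; the actual implementation may differ.

```python
import os
import tempfile
from typing import List


class FileBackedDataset:
    """Hypothetical stand-in that mimics the documented file checks."""

    def __init__(self, root_dir: str) -> None:
        self.root_dir = root_dir
        self.download_calls = 0
        self.build_calls = 0

    @property
    def raw_filenames(self) -> List[str]:
        return ["raw.csv"]

    @property
    def processed_filenames(self) -> List[str]:
        return ["processed.csv"]

    @property
    def raw_paths(self) -> List[str]:
        return [os.path.join(self.root_dir, f) for f in self.raw_filenames]

    @property
    def processed_paths(self) -> List[str]:
        return [os.path.join(self.root_dir, f) for f in self.processed_filenames]

    def download(self) -> None:
        # Stand-in for a real download: write a tiny raw file.
        self.download_calls += 1
        with open(self.raw_paths[0], "w") as fh:
            fh.write("x,c\n1,0\n0,1\n")

    def build(self) -> None:
        # Stand-in for preprocessing: copy the raw file in upper case.
        self.build_calls += 1
        with open(self.raw_paths[0]) as src, open(self.processed_paths[0], "w") as dst:
            dst.write(src.read().upper())

    def maybe_download(self) -> None:
        # Assumed behaviour: download only when a raw file is missing.
        if not all(os.path.exists(p) for p in self.raw_paths):
            self.download()

    def maybe_build(self) -> None:
        # Assumed behaviour: build only when a processed file is missing.
        if not all(os.path.exists(p) for p in self.processed_paths):
            self.maybe_download()
            self.build()


with tempfile.TemporaryDirectory() as tmp:
    ds = FileBackedDataset(tmp)
    ds.maybe_build()  # first call: downloads, then builds
    ds.maybe_build()  # second call: both files exist, nothing to do
    calls = (ds.download_calls, ds.build_calls)
```

Under these assumptions each of download() and build() runs once, however many times maybe_build() is called.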

load_raw(*args, **kwargs)[source]

Loads raw dataset without any data preprocessing.

load(*args, **kwargs)[source]

Loads the raw dataset and preprocesses the data. Defaults to load_raw.

maybe_reduce_annotations(annotations: Annotations, concept_names_subset: List[str] | None = None)[source]

Set concept and labels for the dataset.

Parameters:
  • annotations – Annotations object for all concepts.

  • concept_names_subset – List of strings naming the subset of concepts to use. If None, will use all concepts.
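
Selecting a subset of concepts by name amounts to picking the matching columns of the concept matrix. The sketch below shows the idea on plain Python lists with a hypothetical helper, select_concepts, which is not a library function. Raising ValueError for unknown names and keeping the order of the subset are both assumptions made for illustration.

```python
from typing import List, Optional, Sequence, Tuple


def select_concepts(
    concepts: Sequence[Sequence[int]],
    concept_names: List[str],
    concept_names_subset: Optional[List[str]] = None,
) -> Tuple[List[List[int]], List[str]]:
    """Keep only the named concept columns (hypothetical helper)."""
    if concept_names_subset is None:
        # No subset requested: keep every concept.
        return [list(row) for row in concepts], list(concept_names)
    missing = [n for n in concept_names_subset if n not in concept_names]
    if missing:
        # Assumed behaviour: unknown names are an error.
        raise ValueError(f"Unknown concept names: {missing}")
    cols = [concept_names.index(n) for n in concept_names_subset]
    return [[row[j] for j in cols] for row in concepts], list(concept_names_subset)


names = ["c1", "c2", "c3"]
C = [[1, 0, 1], [0, 0, 1]]
reduced, kept = select_concepts(C, names, ["c3", "c1"])
```

Here reduced holds two columns per sample, ordered as requested in the subset.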

set_graph(graph: DataFrame)[source]

Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.

Parameters:

graph – A pandas DataFrame representing the adjacency matrix of the causal graph. Rows and columns should be named after the variables in the dataset.

set_concepts(concepts: ndarray | DataFrame | Tensor)[source]

Set concept annotations for the dataset.

Parameters:
  • concepts – Tensor of shape (n_samples, n_concepts) containing concept values

  • concept_names – List of strings naming each concept. If None, will use numbered concepts like “concept_0”, “concept_1”, etc.

add_exogenous(name: str, value: ndarray | DataFrame | Tensor, convert_precision: bool = True)[source]
remove_exogenous(name: str)[source]
add_scaler(key: str, scaler)[source]

Add a scaler for preprocessing a specific tensor.

Parameters:
  • key (str) – The name of the tensor to scale (‘input’, ‘concepts’).

  • scaler (Scaler) – The fitted scaler to use.

DataModule Classes

class ConceptDataModule(dataset: ConceptDataset, val_size: float = 0.1, test_size: float = 0.2, batch_size: int = 64, backbone: str | Callable[[Tensor], Tensor] | None = None, precompute_embs: bool = False, force_recompute: bool = False, scalers: Mapping | None = None, splitter: object | None = None, workers: int = 0, pin_memory: bool = False)[source]

Bases: LightningDataModule

PyTorch Lightning DataModule for concept-based datasets.

Handles the complete data pipeline for concept-based learning:

  1. Data splitting: Train/validation/test splits using configurable splitters

  2. Embedding precomputation: Optional backbone feature extraction with caching

  3. Data scaling: Optional normalization through configurable scalers

  4. DataLoader creation: Efficient data loading with proper configurations

The datamodule automatically caches computed embeddings to disk, allowing fast reloading on subsequent runs without recomputation.

Parameters:
  • dataset (ConceptDataset) – Complete dataset to be split and processed.

  • val_size (float, optional) – Validation set fraction (0.0 to 1.0). Default is 0.1.

  • test_size (float, optional) – Test set fraction (0.0 to 1.0). Default is 0.2.

  • batch_size (int, optional) – Mini-batch size for DataLoaders. Default is 64.

  • backbone (str or None, optional) –

    Feature extraction model name. Can be:

    • HuggingFace model: ‘facebook/dinov2-base’, ‘google/vit-base-patch16-224’

    • torchvision model: ‘resnet18’, ‘resnet50’, ‘vgg16’, ‘efficientnet_b0’

    If provided with precompute_embs=True, embeddings are computed and cached to disk. Default is None.

  • precompute_embs (bool, optional) – If True and backbone is provided, precompute and cache backbone embeddings before training. Embeddings are saved to {dataset.root_dir}/{backbone_filename}.pt. Default is False.

  • force_recompute (bool, optional) – If True, recompute embeddings even if cached file exists. Useful when the dataset or backbone changes. Default is False.

  • scalers (Mapping or None, optional) – Dictionary of custom scalers for data normalization. Keys should match target keys in the batch (e.g., ‘input’, ‘concepts’). If None, no scaling is applied. Default is None.

  • splitter (object or None, optional) – Custom splitter for train/val/test splits. Must implement a split(dataset) method that sets train_idxs, val_idxs, and test_idxs attributes. If None, uses RandomSplitter with the specified val_size and test_size. Default is None.

  • workers (int, optional) – Number of subprocesses for data loading. 0 means data will be loaded in the main process. Default is 0.

  • pin_memory (bool, optional) – If True, the data loader will copy Tensors into pinned memory before returning them. Useful for GPU training. Default is False.

dataset

The underlying concept dataset.

Type:

ConceptDataset

trainset

Training subset after setup().

Type:

Subset or None

valset

Validation subset after setup().

Type:

Subset or None

testset

Test subset after setup().

Type:

Subset or None

backbone

The backbone wrapper for feature extraction.

Type:

Backbone or None

scalers

Dictionary of scalers for data normalization.

Type:

dict

splitter

The splitter used for data splitting.

Type:

object

Examples

Basic usage with random splitting:

>>> from torch_concepts.data.datasets import ToyDataset
>>> dataset = ToyDataset(dataset='xor', n_gen=1000)
>>> dm = ConceptDataModule(
...     dataset=dataset,
...     val_size=0.1,
...     test_size=0.2,
...     batch_size=32
... )
>>> dm.setup('fit')
>>> print(f"Train: {dm.train_len}, Val: {dm.val_len}, Test: {dm.test_len}")
Train: 700, Val: 100, Test: 200
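
The sizes printed above follow directly from the fractions: 10% and 20% of 1000 samples go to validation and test, and the rest to training. The dependency-free sketch below reproduces that arithmetic with a hypothetical helper, random_split_indices. The rounding rule, the shuffling, and the seed handling are assumptions and may not match RandomSplitter.

```python
import random
from typing import List, Tuple


def random_split_indices(
    n_samples: int, val_size: float, test_size: float, seed: int = 0
) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle indices and cut them into train/val/test (hypothetical helper)."""
    n_val = round(val_size * n_samples)
    n_test = round(test_size * n_samples)
    idxs = list(range(n_samples))
    random.Random(seed).shuffle(idxs)  # seeded for reproducibility
    test = idxs[:n_test]
    val = idxs[n_test:n_test + n_val]
    train = idxs[n_test + n_val:]  # everything left over
    return train, val, test


train, val, test = random_split_indices(1000, val_size=0.1, test_size=0.2)
sizes = (len(train), len(val), len(test))
```

The three index lists are disjoint and together cover every sample exactly once.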

Using backbone for embedding precomputation:

>>> dm = ConceptDataModule(
...     dataset=image_dataset,
...     backbone='resnet50',
...     precompute_embs=True,
...     batch_size=64,
...     workers=4
... )
>>> dm.setup('fit')  # Computes and caches embeddings
>>> # On subsequent runs, embeddings are loaded from cache

Using HuggingFace backbone:

>>> dm = ConceptDataModule(
...     dataset=image_dataset,
...     backbone='facebook/dinov2-base',
...     precompute_embs=True
... )

See also

Backbone

Feature extraction wrapper class.

ConceptDataset

Base dataset class for concept data.

RandomSplitter

Default splitter for train/val/test splits.

NativeSplitter

Splitter using dataset’s native splits.

property trainset

The training subset.

Returns:

Training data subset, or None if not yet set up.

Return type:

Subset or None

property valset

The validation subset.

Returns:

Validation data subset, or None if not yet set up.

Return type:

Subset or None

property testset

The test subset.

Returns:

Test data subset, or None if not yet set up.

Return type:

Subset or None

property train_len

Number of samples in the training set.

Returns:

Training set length, or None if not set up.

Return type:

int or None

property val_len

Number of samples in the validation set.

Returns:

Validation set length, or None if not set up.

Return type:

int or None

property test_len

Number of samples in the test set.

Returns:

Test set length, or None if not set up.

Return type:

int or None

property n_samples: int

Total number of samples in the dataset.

Returns:

Total number of samples.

Return type:

int

property backbone: Backbone | None

The backbone model wrapper for feature extraction.

Returns:

The backbone wrapper, or None if not configured.

Return type:

Backbone or None

setup(stage: Literal['fit', 'validate', 'test', 'predict'] | None = None, backbone_device: str | None = None, verbose: bool | None = True) → None[source]

Prepare the data for training, validation, or testing.

This method is called by PyTorch Lightning with ‘fit’, ‘validate’, ‘test’, or ‘predict’ stages. It handles:

  1. Backbone embedding precomputation (if configured)

  2. Data splitting using the configured splitter

Parameters:
  • stage ({'fit', 'validate', 'test', 'predict'}, optional) – The stage for which data is being prepared. If None, prepares data for all stages. Default is None.

  • backbone_device (str, optional) – Device for backbone computation (‘cpu’, ‘cuda’, etc.). If None, auto-detects available hardware. Default is None.

  • verbose (bool, optional) – If True, print detailed logging information during setup. Default is True.

Notes

Embedding Caching Behavior:

When precompute_embs=True:

  • If cached embeddings exist at {dataset.root_dir}/{backbone.filename}, they are loaded automatically

  • If not, embeddings are computed using the backbone and saved to cache

  • Set force_recompute=True to always recompute

When precompute_embs=False:

  • Uses original input_data without backbone preprocessing

  • Backbone is ignored even if provided
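
The caching rules above follow a common load-or-compute pattern. The sketch below shows that pattern in isolation, using pickle and a fake backbone in place of real embeddings. The names load_or_compute and fake_backbone and the file name embeddings.pkl are hypothetical; the datamodule's own file format and naming are not shown here.

```python
import os
import pickle
import tempfile
from typing import Callable, List


def load_or_compute(
    cache_path: str,
    compute: Callable[[], List[float]],
    force_recompute: bool = False,
) -> List[float]:
    """Return cached values if present, otherwise compute and cache them."""
    if os.path.exists(cache_path) and not force_recompute:
        with open(cache_path, "rb") as fh:
            return pickle.load(fh)
    values = compute()
    with open(cache_path, "wb") as fh:
        pickle.dump(values, fh)
    return values


calls = []


def fake_backbone() -> List[float]:
    # Stand-in for an expensive forward pass over the whole dataset.
    calls.append(1)
    return [0.1, 0.2, 0.3]


with tempfile.TemporaryDirectory() as root_dir:
    path = os.path.join(root_dir, "embeddings.pkl")
    first = load_or_compute(path, fake_backbone)  # computes and saves
    second = load_or_compute(path, fake_backbone)  # loads from cache
    third = load_or_compute(path, fake_backbone, force_recompute=True)  # recomputes
    n_calls = len(calls)
```

The fake backbone runs twice in total: once on the cold cache and once when recomputation is forced.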

Examples

>>> dm = ConceptDataModule(dataset, backbone='resnet50', precompute_embs=True)
>>> dm.setup('fit')  # Computes/loads embeddings and creates splits
>>> dm.setup('test', backbone_device='cuda:1')  # Use specific GPU
get_dataloader(split: Literal['train', 'val', 'test'] | None = None, shuffle: bool = False, batch_size: int | None = None) → DataLoader | None[source]

Get the DataLoader for a specific split.

Parameters:
  • split ({'train', 'val', 'test'}, optional) – Which split to create a DataLoader for. If None, returns a DataLoader for the entire dataset. Default is None.

  • shuffle (bool, optional) – Whether to shuffle the data. Typically True only for training. Default is False.

  • batch_size (int, optional) – Mini-batch size. If None, uses self.batch_size. Default is None.

Returns:

DataLoader for the requested split, or None if the split is not available (e.g., empty split).

Return type:

DataLoader or None

Raises:

ValueError – If split is not one of ‘train’, ‘val’, ‘test’, or None.

Notes

For training DataLoaders, drop_last=True is set to ensure consistent batch sizes across iterations.
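
With drop_last=True a trailing partial batch is discarded, so the number of training batches is the floor of n_samples / batch_size rather than the ceiling. The small sketch below works through the numbers for a 700-sample training set and a batch size of 32; the helper n_batches is hypothetical and only mirrors the standard DataLoader length rule.

```python
import math


def n_batches(n_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a loader yields for the given settings."""
    if drop_last:
        return n_samples // batch_size  # incomplete last batch is dropped
    return math.ceil(n_samples / batch_size)  # incomplete last batch is kept


train_batches = n_batches(700, 32, drop_last=True)
eval_batches = n_batches(700, 32, drop_last=False)
leftover = 700 - train_batches * 32  # samples skipped in each training epoch
```

In this case 21 full batches are yielded per epoch and 28 samples are left out. With shuffling enabled, a different set of samples is left out each epoch.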

train_dataloader(shuffle: bool = True, batch_size: int | None = None) → DataLoader | None[source]

Get the training DataLoader.

Parameters:
  • shuffle (bool, optional) – Whether to shuffle the data. Default is True.

  • batch_size (int, optional) – Mini-batch size. If None, uses self.batch_size.

Returns:

Training DataLoader, or None if trainset is not available.

Return type:

DataLoader or None

val_dataloader(shuffle: bool = False, batch_size: int | None = None) → DataLoader | None[source]

Get the validation DataLoader.

Parameters:
  • shuffle (bool, optional) – Whether to shuffle the data. Default is False.

  • batch_size (int, optional) – Mini-batch size. If None, uses self.batch_size.

Returns:

Validation DataLoader, or None if valset is not available.

Return type:

DataLoader or None

test_dataloader(shuffle: bool = False, batch_size: int | None = None) → DataLoader | None[source]

Get the test DataLoader.

Parameters:
  • shuffle (bool, optional) – Whether to shuffle the data. Default is False.

  • batch_size (int, optional) – Mini-batch size. If None, uses self.batch_size.

Returns:

Test DataLoader, or None if testset is not available.

Return type:

DataLoader or None

Scaler Classes

class Scaler(bias=0.0, scale=1.0)[source]

Bases: ABC

Abstract base class for data scaling transformations.

Provides a consistent interface for fitting scalers to data and applying forward/inverse transformations. All concrete scaler implementations should inherit from this class and implement fit(), transform(), and inverse_transform() methods.

Parameters:
  • bias (float, optional) – Initial bias value. Defaults to 0.0.

  • scale (float, optional) – Initial scale value. Defaults to 1.0.

Example

>>> class MinMaxScaler(Scaler):
...     def fit(self, x, dim=0):
...         self.min = x.min(dim=dim, keepdim=True)[0]
...         self.max = x.max(dim=dim, keepdim=True)[0]
...         return self
...
...     def transform(self, x):
...         return (x - self.min) / (self.max - self.min)
...
...     def inverse_transform(self, x):
...         return x * (self.max - self.min) + self.min
abstract fit(x: Tensor, dim: int = 0) → Scaler[source]

Fit the scaler to the input data.

Parameters:
  • x – Input tensor to fit the scaler to.

  • dim – Dimension along which to compute statistics (default: 0).

Returns:

The fitted scaler instance for method chaining.

Return type:

self

abstract transform(x: Tensor) → Tensor[source]

Apply the fitted transformation to the input tensor.

Parameters:

x – Input tensor to transform.

Returns:

Transformed tensor with same shape as input.

abstract inverse_transform(x: Tensor) → Tensor[source]

Reverse the transformation to recover original data.

Parameters:

x – Transformed tensor to inverse-transform.

Returns:

Tensor in original scale with same shape as input.

fit_transform(x: Tensor, dim: int = 0) → Tensor[source]

Fit the scaler and transform the input data in one operation.

Parameters:
  • x – Input tensor to fit and transform.

  • dim – Dimension along which to compute statistics (default: 0).

Returns:

Transformed tensor with same shape as input.
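
As a rough illustration of how the four methods fit together, the sketch below re-creates the interface on plain Python lists. SketchScaler and SketchMinMaxScaler are hypothetical stand-ins, not library classes. The sketch assumes fit_transform is simply fit followed by transform on the same data, and it omits the dim argument and tensor handling.

```python
from abc import ABC, abstractmethod
from typing import List


class SketchScaler(ABC):
    """Plain-Python analogue of the documented interface (hypothetical)."""

    @abstractmethod
    def fit(self, x: List[float]) -> "SketchScaler": ...

    @abstractmethod
    def transform(self, x: List[float]) -> List[float]: ...

    @abstractmethod
    def inverse_transform(self, x: List[float]) -> List[float]: ...

    def fit_transform(self, x: List[float]) -> List[float]:
        # Assumed composition: fit on x, then transform the same x.
        return self.fit(x).transform(x)


class SketchMinMaxScaler(SketchScaler):
    def fit(self, x: List[float]) -> "SketchMinMaxScaler":
        self.min, self.max = min(x), max(x)
        return self  # returning self allows method chaining

    def transform(self, x: List[float]) -> List[float]:
        return [(v - self.min) / (self.max - self.min) for v in x]

    def inverse_transform(self, x: List[float]) -> List[float]:
        return [v * (self.max - self.min) + self.min for v in x]


data = [2.0, 4.0, 6.0, 10.0]
scaler = SketchMinMaxScaler()
scaled = scaler.fit_transform(data)  # values mapped into [0, 1]
restored = scaler.inverse_transform(scaled)  # round trip back to the input
```

A round trip through transform and inverse_transform recovers the input, which makes a useful sanity check for any concrete scaler. A constant input would divide by zero in this sketch, so a real implementation would need to guard against that.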

Splitter Classes

class Splitter[source]

Bases: ABC

Abstract base class for dataset splitting strategies.

Splitters divide a ConceptDataset into train, validation, and test splits. They store indices for each split and provide properties to access split sizes and indices. All concrete splitter implementations should inherit from this class and implement the fit() method.

train_idxs

Training set indices.

Type:

list

val_idxs

Validation set indices.

Type:

list

test_idxs

Test set indices.

Type:

list

Example

>>> class CustomSplitter(Splitter):
...     def fit(self, dataset):
...         n = len(dataset)
...         self.set_indices(
...             train=list(range(int(0.7*n))),
...             val=list(range(int(0.7*n), int(0.9*n))),
...             test=list(range(int(0.9*n), n))
...         )
...         self._fitted = True
>>>
>>> splitter = CustomSplitter()
>>> splitter.fit(my_dataset)
>>> print(f"Train: {splitter.train_len}, Val: {splitter.val_len}")
property indices
property fitted
property train_idxs
property val_idxs
property test_idxs
property train_len
property val_len
property test_len
set_indices(train=None, val=None, test=None)[source]
reset()[source]
abstract fit(dataset: ConceptDataset)[source]

Split the dataset into train/val/test sets.

This method should set the following attributes:

  • self.train_idxs: List of training indices

  • self.val_idxs: List of validation indices

  • self.test_idxs: List of test indices

Parameters:

dataset – The dataset to split.
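
When writing a custom fit(), it can help to verify that the three index lists do not overlap and stay within the dataset. The helper below, check_split, is a hypothetical sanity check and not part of the Splitter API. It uses integer arithmetic for the 70/20/10 cut points to avoid floating-point rounding surprises.

```python
from typing import Sequence


def check_split(
    train: Sequence[int], val: Sequence[int], test: Sequence[int], n_samples: int
) -> None:
    """Raise ValueError if the splits overlap or index out of range."""
    all_idxs = list(train) + list(val) + list(test)
    if len(set(all_idxs)) != len(all_idxs):
        raise ValueError("splits overlap")
    if any(i < 0 or i >= n_samples for i in all_idxs):
        raise ValueError("index out of range")


n = 10
cut_train = n * 7 // 10  # first 70% for training
cut_val = n * 9 // 10  # next 20% for validation
train = list(range(cut_train))
val = list(range(cut_train, cut_val))
test = list(range(cut_val, n))
check_split(train, val, test, n)  # passes silently
split_sizes = (len(train), len(val), len(test))
```

A check like this could be run at the end of fit(), before the indices are stored.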

split(dataset: ConceptDataset) → None[source]