Data Base Classes¶
This module provides base classes for data handling in concept-based models.
Summary¶
Dataset Base Classes: Base class for concept-annotated datasets.
DataModule Base Classes: PyTorch Lightning DataModule for concept-based datasets.
Scaler Base Classes: Abstract base class for data scaling transformations.
Splitter Base Classes: Abstract base class for dataset splitting strategies.
Class Documentation¶
Dataset Classes¶
- class ConceptDataset(input_data: ndarray | DataFrame | Tensor, concepts: ndarray | DataFrame | Tensor, annotations: Annotations | None = None, graph: DataFrame | None = None, concept_names_subset: List[str] | None = None, precision: int | str = 32, name: str | None = None)[source]¶
Bases: Dataset
Base class for concept-annotated datasets.
This class extends PyTorch’s Dataset to support concept annotations, concept graphs, and various metadata. It provides a unified interface for working with datasets that have both input features and concept labels.
- input_data¶
Input features/images.
- Type:
Tensor
- concepts¶
Concept annotations.
- Type:
Tensor
- annotations¶
Detailed concept annotations with metadata.
- Type:
Annotations or None
- Parameters:
input_data – Input features as numpy array, pandas DataFrame, or Tensor.
concepts – Concept annotations as numpy array, pandas DataFrame, or Tensor.
annotations – Optional Annotations object with concept metadata.
graph – Optional concept graph as pandas DataFrame or tensor.
concept_names_subset – Optional list to select subset of concepts.
precision – Numerical precision (16, 32, or 64, default: 32).
name – Optional dataset name.
exogenous – Optional exogenous variables (not yet implemented).
- Raises:
ValueError – If concepts is None or annotations don’t include axis 1.
NotImplementedError – If continuous concepts or exogenous variables are used.
Example
>>> X = torch.randn(100, 28, 28)  # 100 images
>>> C = torch.randint(0, 2, (100, 5))  # 5 binary concepts
>>> annotations = Annotations({1: AxisAnnotation(labels=['c1', 'c2', 'c3', 'c4', 'c5'])})
>>> dataset = ConceptDataset(X, C, annotations=annotations)
>>> len(dataset)
100
- property n_features: tuple¶
Shape of features in dataset’s input (excluding number of samples).
- Returns:
Shape of input features.
- Return type:
tuple
- property n_concepts: int¶
Number of concepts in the dataset.
- Returns:
Number of concepts, or 0 if no concepts.
- Return type:
int
- property concept_names: List[str]¶
List of concept names in the dataset.
- Returns:
Names of all concepts.
- Return type:
List[str]
- property annotations: Annotations | None¶
Annotations for the concepts in the dataset.
- abstract property raw_filenames: List[str]¶
The list of raw filenames in the self.root_dir folder that must be present in order to skip download(). Should be implemented by subclasses.
- abstract property processed_filenames: List[str]¶
The list of processed filenames in the self.root_dir folder that must be present in order to skip build(). Should be implemented by subclasses.
- property raw_paths: List[str]¶
The absolute paths of the raw files that must be present in order to skip downloading.
- property processed_paths: List[str]¶
The absolute paths of the processed files that must be present in order to skip building.
- maybe_reduce_annotations(annotations: Annotations, concept_names_subset: List[str] | None = None)[source]¶
Set concepts and labels for the dataset.
- Parameters:
annotations – Annotations object for all concepts.
concept_names_subset – List of strings naming the subset of concepts to use. If None, will use all concepts.
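To illustrate what selecting a concept subset by name involves, here is a minimal pure-Python sketch. It is not the library's implementation: `select_concepts`, the list-of-rows layout (standing in for a tensor), and the choice to raise `ValueError` on unknown names are all assumptions made for illustration.

```python
from typing import List, Optional, Sequence, Tuple


def select_concepts(
    rows: Sequence[Sequence[int]],
    concept_names: List[str],
    concept_names_subset: Optional[List[str]] = None,
) -> Tuple[List[List[int]], List[str]]:
    """Keep only the concept columns named in ``concept_names_subset``.

    Hypothetical helper; with ``None`` all concepts are kept.
    """
    if concept_names_subset is None:
        return [list(r) for r in rows], list(concept_names)

    # Assumption: unknown names are treated as an error.
    missing = [n for n in concept_names_subset if n not in concept_names]
    if missing:
        raise ValueError(f"Unknown concept names: {missing}")

    # Column positions, in the order requested by the subset.
    cols = [concept_names.index(n) for n in concept_names_subset]
    reduced = [[r[c] for c in cols] for r in rows]
    return reduced, list(concept_names_subset)


concepts = [[1, 0, 1], [0, 0, 1]]
names = ["c1", "c2", "c3"]
reduced, kept = select_concepts(concepts, names, ["c3", "c1"])
print(reduced, kept)  # [[1, 1], [1, 0]] ['c3', 'c1']
```

Reordering the columns to follow the subset order is a design choice of this sketch; an implementation could equally preserve the original column order.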
- set_graph(graph: DataFrame)[source]¶
Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.
- Parameters:
graph – A pandas DataFrame representing the adjacency matrix of the causal graph. Rows and columns should be named after the variables in the dataset.
- set_concepts(concepts: ndarray | DataFrame | Tensor)[source]¶
Set concept annotations for the dataset.
- Parameters:
concepts – Tensor of shape (n_samples, n_concepts) containing concept values.
concept_names – List of strings naming each concept. If None, will use numbered concepts like “concept_0”, “concept_1”, etc.
DataModule Classes¶
- class ConceptDataModule(dataset: ConceptDataset, val_size: float = 0.1, test_size: float = 0.2, batch_size: int = 64, backbone: str | Callable[[Tensor], Tensor] | None = None, precompute_embs: bool = False, force_recompute: bool = False, scalers: Mapping | None = None, splitter: object | None = None, workers: int = 0, pin_memory: bool = False)[source]¶
Bases: LightningDataModule
PyTorch Lightning DataModule for concept-based datasets.
Handles the complete data pipeline for concept-based learning:
Data splitting: Train/validation/test splits using configurable splitters
Embedding precomputation: Optional backbone feature extraction with caching
Data scaling: Optional normalization through configurable scalers
DataLoader creation: Efficient data loading with proper configurations
The datamodule automatically caches computed embeddings to disk, allowing fast reloading on subsequent runs without recomputation.
- Parameters:
dataset (ConceptDataset) – Complete dataset to be split and processed.
val_size (float, optional) – Validation set fraction (0.0 to 1.0). Default is 0.1.
test_size (float, optional) – Test set fraction (0.0 to 1.0). Default is 0.2.
batch_size (int, optional) – Mini-batch size for DataLoaders. Default is 64.
backbone (str or None, optional) – Feature extraction model name. Can be:
HuggingFace model: ‘facebook/dinov2-base’, ‘google/vit-base-patch16-224’
torchvision model: ‘resnet18’, ‘resnet50’, ‘vgg16’, ‘efficientnet_b0’
If provided with precompute_embs=True, embeddings are computed and cached to disk. Default is None.
precompute_embs (bool, optional) – If True and backbone is provided, precompute and cache backbone embeddings before training. Embeddings are saved to {dataset.root_dir}/{backbone_filename}.pt. Default is False.
force_recompute (bool, optional) – If True, recompute embeddings even if cached file exists. Useful when the dataset or backbone changes. Default is False.
scalers (Mapping or None, optional) – Dictionary of custom scalers for data normalization. Keys should match target keys in the batch (e.g., ‘input’, ‘concepts’). If None, no scaling is applied. Default is None.
splitter (object or None, optional) – Custom splitter for train/val/test splits. Must implement a split(dataset) method that sets train_idxs, val_idxs, and test_idxs attributes. If None, uses RandomSplitter with the specified val_size and test_size. Default is None.
workers (int, optional) – Number of subprocesses for data loading. 0 means data will be loaded in the main process. Default is 0.
pin_memory (bool, optional) – If True, the data loader will copy Tensors into pinned memory before returning them. Useful for GPU training. Default is False.
- dataset¶
The underlying concept dataset.
- Type:
ConceptDataset
- trainset¶
Training subset after setup().
- Type:
Subset or None
- valset¶
Validation subset after setup().
- Type:
Subset or None
- testset¶
Test subset after setup().
- Type:
Subset or None
- backbone¶
The backbone wrapper for feature extraction.
- Type:
Backbone or None
Examples
Basic usage with random splitting:
>>> from torch_concepts.data.datasets import ToyDataset
>>> dataset = ToyDataset(dataset='xor', n_gen=1000)
>>> dm = ConceptDataModule(
...     dataset=dataset,
...     val_size=0.1,
...     test_size=0.2,
...     batch_size=32
... )
>>> dm.setup('fit')
>>> print(f"Train: {dm.train_len}, Val: {dm.val_len}, Test: {dm.test_len}")
Train: 700, Val: 100, Test: 200
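The split sizes printed above follow directly from the fractions: 10% and 20% of 1000 samples go to validation and test, and the rest to training. The sketch below shows that arithmetic with a simple shuffled-index split. It is an assumption about how such a split could work, not necessarily how RandomSplitter is implemented; `random_split_indices` and its `seed` parameter are hypothetical.

```python
import random
from typing import List, Tuple


def random_split_indices(
    n_samples: int, val_size: float, test_size: float, seed: int = 0
) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle indices and cut them into train/val/test by fraction."""
    if not 0.0 <= val_size + test_size < 1.0:
        raise ValueError("val_size + test_size must be in [0, 1).")

    n_val = int(round(n_samples * val_size))
    n_test = int(round(n_samples * test_size))
    n_train = n_samples - n_val - n_test  # training set takes the remainder

    idxs = list(range(n_samples))
    random.Random(seed).shuffle(idxs)  # seeded for reproducibility

    train = idxs[:n_train]
    val = idxs[n_train:n_train + n_val]
    test = idxs[n_train + n_val:]
    return train, val, test


train, val, test = random_split_indices(1000, val_size=0.1, test_size=0.2)
print(len(train), len(val), len(test))  # 700 100 200
```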
Using backbone for embedding precomputation:
>>> dm = ConceptDataModule(
...     dataset=image_dataset,
...     backbone='resnet50',
...     precompute_embs=True,
...     batch_size=64,
...     workers=4
... )
>>> dm.setup('fit')  # Computes and caches embeddings
>>> # On subsequent runs, embeddings are loaded from cache
Using HuggingFace backbone:
>>> dm = ConceptDataModule(
...     dataset=image_dataset,
...     backbone='facebook/dinov2-base',
...     precompute_embs=True
... )
See also
Backbone – Feature extraction wrapper class.
ConceptDataset – Base dataset class for concept data.
RandomSplitter – Default splitter for train/val/test splits.
NativeSplitter – Splitter using dataset’s native splits.
- property trainset¶
The training subset.
- Returns:
Training data subset, or None if not yet set up.
- Return type:
Subset or None
- property valset¶
The validation subset.
- Returns:
Validation data subset, or None if not yet set up.
- Return type:
Subset or None
- property testset¶
The test subset.
- Returns:
Test data subset, or None if not yet set up.
- Return type:
Subset or None
- property train_len¶
Number of samples in the training set.
- Returns:
Training set length, or None if not set up.
- Return type:
int or None
- property val_len¶
Number of samples in the validation set.
- Returns:
Validation set length, or None if not set up.
- Return type:
int or None
- property test_len¶
Number of samples in the test set.
- Returns:
Test set length, or None if not set up.
- Return type:
int or None
- property n_samples: int¶
Total number of samples in the dataset.
- Returns:
Total number of samples.
- Return type:
int
- property backbone: Backbone | None¶
The backbone model wrapper for feature extraction.
- Returns:
The backbone wrapper, or None if not configured.
- Return type:
Backbone or None
- setup(stage: Literal['fit', 'validate', 'test', 'predict'] | None = None, backbone_device: str | None = None, verbose: bool | None = True) None[source]¶
Prepare the data for training, validation, or testing.
This method is called by PyTorch Lightning with ‘fit’, ‘validate’, ‘test’, or ‘predict’ stages. It handles:
Backbone embedding precomputation (if configured)
Data splitting using the configured splitter
- Parameters:
stage ({'fit', 'validate', 'test', 'predict'}, optional) – The stage for which data is being prepared. If None, prepares data for all stages. Default is None.
backbone_device (str, optional) – Device for backbone computation (‘cpu’, ‘cuda’, etc.). If None, auto-detects available hardware. Default is None.
verbose (bool, optional) – If True, print detailed logging information during setup. Default is True.
Notes
Embedding Caching Behavior:
When precompute_embs=True:
If cached embeddings exist at {dataset.root_dir}/{backbone.filename}, they are loaded automatically
If not, embeddings are computed using the backbone and saved to cache
Set force_recompute=True to always recompute
When precompute_embs=False:
Uses original input_data without backbone preprocessing
Backbone is ignored even if provided
Examples
>>> dm = ConceptDataModule(dataset, backbone='resnet50', precompute_embs=True)
>>> dm.setup('fit')  # Computes/loads embeddings and creates splits
>>> dm.setup('test', backbone_device='cuda:1')  # Use specific GPU
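The caching behavior described in the Notes amounts to a load-or-compute pattern. The sketch below illustrates it under stated assumptions and is not the library's code: `load_or_compute`, the `compute` callable, the `fake_backbone` function, and the use of `pickle` in place of `.pt` files are all hypothetical.

```python
import os
import pickle
import tempfile
from typing import Callable, List


def load_or_compute(
    cache_path: str,
    compute: Callable[[], List[float]],
    force_recompute: bool = False,
) -> List[float]:
    """Return cached embeddings if present, otherwise compute and cache them."""
    if os.path.exists(cache_path) and not force_recompute:
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    embs = compute()  # the expensive step (a backbone forward pass in practice)
    with open(cache_path, "wb") as f:
        pickle.dump(embs, f)
    return embs


calls = []


def fake_backbone() -> List[float]:
    calls.append(1)  # count how often the expensive step runs
    return [0.1, 0.2, 0.3]


with tempfile.TemporaryDirectory() as root_dir:
    path = os.path.join(root_dir, "fake_backbone.pkl")
    first = load_or_compute(path, fake_backbone)   # computes and caches
    second = load_or_compute(path, fake_backbone)  # loads from cache
    third = load_or_compute(path, fake_backbone, force_recompute=True)  # recomputes

print(len(calls))  # 2
```

The expensive function runs twice across three calls: once on the cold cache and once when `force_recompute=True` bypasses the cached file.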
- get_dataloader(split: Literal['train', 'val', 'test'] | None = None, shuffle: bool = False, batch_size: int | None = None) DataLoader | None[source]¶
Get the DataLoader for a specific split.
- Parameters:
split ({'train', 'val', 'test'}, optional) – Which split to create a DataLoader for. If None, returns a DataLoader for the entire dataset. Default is None.
shuffle (bool, optional) – Whether to shuffle the data. Typically True only for training. Default is False.
batch_size (int, optional) – Mini-batch size. If None, uses self.batch_size. Default is None.
- Returns:
DataLoader for the requested split, or None if the split is not available (e.g., empty split).
- Return type:
DataLoader or None
- Raises:
ValueError – If split is not one of ‘train’, ‘val’, ‘test’, or None.
Notes
For training DataLoaders, drop_last=True is set to ensure consistent batch sizes across iterations.
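The effect of drop_last on the number of batches per epoch can be worked out with simple arithmetic. The helper below is a hypothetical illustration (`n_batches` is not part of the library) that mirrors how a batched loader counts its batches.

```python
import math


def n_batches(n_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches yielded per epoch for a split of ``n_samples``."""
    if drop_last:
        return n_samples // batch_size        # incomplete final batch is discarded
    return math.ceil(n_samples / batch_size)  # final batch may be smaller


print(n_batches(700, 32, drop_last=True))   # 21 (28 samples are left out each epoch)
print(n_batches(700, 32, drop_last=False))  # 22 (final batch has 28 samples)
```

With shuffling enabled, the samples left out differ from epoch to epoch, so no sample is permanently excluded from training.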
- train_dataloader(shuffle: bool = True, batch_size: int | None = None) DataLoader | None[source]¶
Get the training DataLoader.
- val_dataloader(shuffle: bool = False, batch_size: int | None = None) DataLoader | None[source]¶
Get the validation DataLoader.
Scaler Classes¶
- class Scaler(bias=0.0, scale=1.0)[source]¶
Bases: ABC
Abstract base class for data scaling transformations.
Provides a consistent interface for fitting scalers to data and applying forward/inverse transformations. All concrete scaler implementations should inherit from this class and implement fit(), transform(), and inverse_transform() methods.
- Parameters:
Example
>>> class MinMaxScaler(Scaler):
...     def fit(self, x, dim=0):
...         self.min = x.min(dim=dim, keepdim=True)[0]
...         self.max = x.max(dim=dim, keepdim=True)[0]
...         return self
...
...     def transform(self, x):
...         return (x - self.min) / (self.max - self.min)
...
...     def inverse_transform(self, x):
...         return x * (self.max - self.min) + self.min
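The fit/transform/inverse_transform contract can also be exercised without PyTorch. The sketch below is a self-contained stand-in, not the library's Scaler: `SimpleScaler` and `StandardScaler1D` are hypothetical names, plain Python lists replace tensors, and the convention that `bias` and `scale` are applied as `(x - bias) / scale` is an assumption the source does not confirm.

```python
from abc import ABC, abstractmethod
from statistics import mean, pstdev
from typing import List


class SimpleScaler(ABC):
    """Hypothetical stand-in mirroring the fit/transform/inverse_transform interface."""

    def __init__(self, bias: float = 0.0, scale: float = 1.0):
        self.bias = bias
        self.scale = scale

    @abstractmethod
    def fit(self, x: List[float]) -> "SimpleScaler":
        """Estimate bias and scale from data; return self for chaining."""

    def transform(self, x: List[float]) -> List[float]:
        # Assumed convention: subtract the bias, then divide by the scale.
        return [(v - self.bias) / self.scale for v in x]

    def inverse_transform(self, x: List[float]) -> List[float]:
        # Exact inverse of transform().
        return [v * self.scale + self.bias for v in x]


class StandardScaler1D(SimpleScaler):
    """Zero-mean, unit-variance scaling of a 1-D list of values."""

    def fit(self, x: List[float]) -> "StandardScaler1D":
        self.bias = mean(x)
        std = pstdev(x)
        self.scale = std if std > 0 else 1.0  # guard against constant input
        return self


data = [2.0, 4.0, 6.0, 8.0]
scaler = StandardScaler1D().fit(data)
scaled = scaler.transform(data)
restored = scaler.inverse_transform(scaled)
```

The guard on a zero standard deviation avoids a division by zero for constant features; a min-max scaler would need the analogous check when the maximum equals the minimum.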
- abstract fit(x: Tensor, dim: int = 0) Scaler[source]¶
Fit the scaler to the input data.
- Parameters:
x – Input tensor to fit the scaler to.
dim – Dimension along which to compute statistics (default: 0).
- Returns:
The fitted scaler instance for method chaining.
- Return type:
self
- abstract transform(x: Tensor) Tensor[source]¶
Apply the fitted transformation to the input tensor.
- Parameters:
x – Input tensor to transform.
- Returns:
Transformed tensor with same shape as input.
Splitter Classes¶
- class Splitter[source]¶
Bases: ABC
Abstract base class for dataset splitting strategies.
Splitters divide a ConceptDataset into train, validation, and test splits. They store indices for each split and provide properties to access split sizes and indices. All concrete splitter implementations should inherit from this class and implement the fit() method.
Example
>>> class CustomSplitter(Splitter):
...     def fit(self, dataset):
...         n = len(dataset)
...         self.set_indices(
...             train=list(range(int(0.7*n))),
...             val=list(range(int(0.7*n), int(0.9*n))),
...             test=list(range(int(0.9*n), n))
...         )
...         self._fitted = True
>>>
>>> splitter = CustomSplitter()
>>> splitter.fit(my_dataset)
>>> print(f"Train: {splitter.train_len}, Val: {splitter.val_len}")
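When writing a custom splitter like the one above, it can help to verify that the three index lists are disjoint and stay within the dataset. The check below is a hypothetical helper (`check_split` is not part of the library), sketched in plain Python.

```python
from typing import Sequence


def check_split(
    n_samples: int,
    train: Sequence[int],
    val: Sequence[int],
    test: Sequence[int],
) -> None:
    """Raise ValueError if the splits overlap or reference invalid indices."""
    all_idxs = list(train) + list(val) + list(test)
    if len(set(all_idxs)) != len(all_idxs):
        raise ValueError("Splits overlap: an index appears more than once.")
    if any(i < 0 or i >= n_samples for i in all_idxs):
        raise ValueError("Split contains an out-of-range index.")


# A contiguous 70/20/10 split of 10 samples.
train = list(range(0, 7))
val = list(range(7, 9))
test = list(range(9, 10))
check_split(10, train, val, test)  # passes silently
```

Overlap between train and test indices is a common source of leaked evaluation data, which is why the sketch treats it as an error.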
- property indices¶
- property fitted¶
- property train_idxs¶
- property val_idxs¶
- property test_idxs¶
- property train_len¶
- property val_len¶
- property test_len¶
- abstract fit(dataset: ConceptDataset)[source]¶
Split the dataset into train/val/test sets.
This method should set the following attributes:
- self.train_idxs: List of training indices
- self.val_idxs: List of validation indices
- self.test_idxs: List of test indices
- Parameters:
dataset – The dataset to split.
- split(dataset: ConceptDataset) None[source]¶