torch_concepts.data.base.dataset.ConceptDataset

class ConceptDataset(input_data: ndarray | DataFrame | Tensor, concepts: ndarray | DataFrame | Tensor, annotations: Annotations | None = None, graph: DataFrame | None = None, concept_names_subset: List[str] | None = None, precision: int | str = 32, name: str | None = None)

Base class for concept-annotated datasets.

This class extends PyTorch’s Dataset to support concept annotations, concept graphs, and various metadata. It provides a unified interface for working with datasets that have both input features and concept labels.
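To illustrate what one interface over inputs and concept labels can look like, here is a minimal sketch of a `torch.utils.data.Dataset` that pairs each input sample with its concept vector. The class name `MinimalConceptDataset` and the dict-shaped items are hypothetical. `ConceptDataset` itself may store and return samples differently.

```python
from typing import Dict, List

import torch
from torch.utils.data import Dataset


class MinimalConceptDataset(Dataset):
    """Hypothetical illustration, not the torch_concepts implementation."""

    def __init__(self, inputs: torch.Tensor, concepts: torch.Tensor, concept_names: List[str]):
        # Inputs and concepts must describe the same samples.
        if inputs.shape[0] != concepts.shape[0]:
            raise ValueError("inputs and concepts must have the same number of samples")
        # One name per concept column.
        if concepts.shape[1] != len(concept_names):
            raise ValueError("one name is needed per concept column")
        self.inputs = inputs
        self.concepts = concepts
        self.concept_names = list(concept_names)

    def __len__(self) -> int:
        return self.inputs.shape[0]

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # Assumed item format: one input sample plus its concept vector.
        return {"x": self.inputs[idx], "c": self.concepts[idx]}


ds = MinimalConceptDataset(
    torch.randn(100, 28, 28),
    torch.randint(0, 2, (100, 5)),
    ["c1", "c2", "c3", "c4", "c5"],
)
item = ds[0]
print(len(ds), item["x"].shape, item["c"].shape)  # 100 torch.Size([28, 28]) torch.Size([5])
```

Because the items are dicts of tensors, PyTorch's default `DataLoader` collation batches them key by key.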

name

Name of the dataset.

Type:

str

precision

Numerical precision for tensors (16, 32, or 64).

Type:

int or str

input_data

Input features/images.

Type:

Tensor

concepts

Concept annotations.

Type:

Tensor

annotations

Detailed concept annotations with metadata.

Type:

Annotations
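The attributes above are typed as Tensor, while the constructor also accepts NumPy arrays and pandas DataFrames together with a precision of 16, 32, or 64. The sketch below shows one way such inputs could be normalized to a floating-point tensor of the requested precision. The helper name `to_tensor` and the bit-width-to-dtype mapping are assumptions, and the sketch accepts precision only as an int or a digit string. Which string aliases the library itself accepts is not documented here.

```python
import numpy as np
import pandas as pd
import torch

# Hypothetical mapping from the documented precision values to float dtypes.
_PRECISION_TO_DTYPE = {16: torch.float16, 32: torch.float32, 64: torch.float64}


def to_tensor(data, precision=32):
    """Convert an ndarray, DataFrame, or Tensor to a float tensor of the given precision."""
    bits = int(precision)  # "32" -> 32; non-numeric strings raise ValueError here
    if bits not in _PRECISION_TO_DTYPE:
        raise ValueError(f"precision must be 16, 32, or 64, got {precision!r}")
    dtype = _PRECISION_TO_DTYPE[bits]
    if isinstance(data, pd.DataFrame):
        data = data.to_numpy()
    if isinstance(data, np.ndarray):
        data = torch.tensor(data)  # copies, so read-only arrays are fine
    if not isinstance(data, torch.Tensor):
        raise TypeError(f"unsupported input type: {type(data).__name__}")
    return data.to(dtype)


print(to_tensor(np.zeros((2, 3))).dtype)                    # torch.float32
print(to_tensor(pd.DataFrame({"a": [1, 2]}), "64").dtype)  # torch.float64
print(to_tensor(torch.ones(2), 16).dtype)                  # torch.float16
```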

Parameters:
  • input_data – Input features as numpy array, pandas DataFrame, or Tensor.

  • concepts – Concept annotations as numpy array, pandas DataFrame, or Tensor.

  • annotations – Optional Annotations object with concept metadata.

  • graph – Optional concept graph as pandas DataFrame or tensor.

  • concept_names_subset – Optional list to select subset of concepts.

  • precision – Numerical precision (16, 32, or 64, default: 32).

  • name – Optional dataset name.

  • exogenous – Optional exogenous variables (not yet implemented; not part of the current constructor signature).
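The concept_names_subset parameter selects a subset of concepts by name. A minimal sketch of that selection on a concept tensor, assuming concepts are stored as columns in the same order as their names. The helper `select_concepts` is hypothetical, and it keeps the order given in the subset; whether `ConceptDataset` preserves the subset order or the original column order is not stated here.

```python
from typing import List, Tuple

import torch


def select_concepts(
    concepts: torch.Tensor, concept_names: List[str], subset: List[str]
) -> Tuple[torch.Tensor, List[str]]:
    """Keep only the concept columns named in `subset`, in the order given."""
    missing = [name for name in subset if name not in concept_names]
    if missing:
        raise ValueError(f"unknown concept names: {missing}")
    # Translate names to column indices, then index the concept dimension.
    idx = [concept_names.index(name) for name in subset]
    return concepts[:, idx], list(subset)


C = torch.tensor([[1, 0, 1], [0, 1, 1]])
C_sub, names = select_concepts(C, ["c1", "c2", "c3"], ["c3", "c1"])
print(C_sub.tolist(), names)  # [[1, 1], [1, 0]] ['c3', 'c1']
```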

Raises:
  • ValueError – If concepts is None, or if annotations does not include axis 1.

  • NotImplementedError – If continuous concepts or exogenous variables are used.

Example

>>> import torch
>>> X = torch.randn(100, 28, 28)  # 100 single-channel 28x28 inputs
>>> C = torch.randint(0, 2, (100, 5))  # 5 binary concepts
>>> annotations = Annotations({1: AxisAnnotation(labels=['c1', 'c2', 'c3', 'c4', 'c5'])})
>>> dataset = ConceptDataset(X, C, annotations=annotations)
>>> len(dataset)
100
__init__(input_data: ndarray | DataFrame | Tensor, concepts: ndarray | DataFrame | Tensor, annotations: Annotations | None = None, graph: DataFrame | None = None, concept_names_subset: List[str] | None = None, precision: int | str = 32, name: str | None = None)

Methods

__init__(input_data, concepts[, ...])

add_exogenous(name, value[, convert_precision])

add_scaler(key, scaler)

Add a scaler for preprocessing a specific tensor.

build()

Build the dataset from the raw data into the self.root_dir folder, where needed.

download()

Download the dataset's files to the self.root_dir folder.

load(*args, **kwargs)

Load the raw dataset and preprocess the data.

load_raw(*args, **kwargs)

Load the raw dataset without any preprocessing.

maybe_build()

maybe_download()

maybe_reduce_annotations(annotations[, ...])

Set concepts and labels for the dataset.

  • annotations – Annotations object for all concepts.

  • concept_names_subset – List of strings naming the subset of concepts to use. If None, all concepts are used.

remove_exogenous(name)

set_concepts(concepts)

Set concept annotations for the dataset.

set_graph(graph)

Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.
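set_graph expects the adjacency matrix of the concept graph as a pandas DataFrame. The sketch below builds one whose index and columns are the concept names, so edges can be read by name; a frame like this is what one would pass to set_graph(graph). The edge direction (a nonzero entry at row `parent`, column `child` meaning parent -> child) and the helper `parents` are assumptions, so check the library's convention before relying on them.

```python
import pandas as pd

concept_names = ["c1", "c2", "c3"]

# Assumed convention: graph.loc[parent, child] == 1 means an edge parent -> child.
graph = pd.DataFrame(0, index=concept_names, columns=concept_names)
graph.loc["c1", "c2"] = 1  # c1 -> c2
graph.loc["c2", "c3"] = 1  # c2 -> c3


def parents(adj: pd.DataFrame, concept: str) -> list:
    """Names of concepts with an edge into `concept` under the row-to-column convention."""
    return adj.index[adj[concept] != 0].tolist()


print(parents(graph, "c3"))  # ['c2']
print(parents(graph, "c1"))  # []
```

Labeling rows and columns with concept names keeps the adjacency aligned with concept_names even after a subset of concepts is selected.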

Attributes

annotations

Annotations for the concepts in the dataset.

concept_names

List of concept names in the dataset.

exogenous

Mapping of dataset's exogenous variables.

graph

Adjacency matrix of the causal graph between concepts.

has_concepts

Whether the dataset has concept annotations.

has_exogenous

Whether the dataset has exogenous information.

n_concepts

Number of concepts in the dataset.

n_exogenous

Number of exogenous variables in the dataset.

n_features

Shape of the dataset's input features, excluding the sample dimension.

n_samples

Number of samples in the dataset.

processed_filenames

The list of processed filenames in the self.root_dir folder that must be present in order to skip build().

processed_paths

The absolute paths of the processed files that must be present in order to skip building.

raw_filenames

The list of raw filenames in the self.root_dir folder that must be present in order to skip download().

raw_paths

The absolute paths of the raw files that must be present in order to skip downloading.

root_dir

shape

Shape of the input tensor.
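The raw_filenames and processed_filenames descriptions imply a simple rule: download() (or build()) can be skipped when every listed file already exists under root_dir. The maybe_download and maybe_build methods presumably apply this check, although their summaries are empty here. A standard-library sketch of the rule follows; the helpers `all_present` and `maybe_run` and the filename `raw_data.csv` are hypothetical.

```python
import os
import tempfile
from typing import Callable, List


def all_present(root_dir: str, filenames: List[str]) -> bool:
    """True if every listed file exists under root_dir."""
    paths = [os.path.join(root_dir, name) for name in filenames]
    return all(os.path.exists(p) for p in paths)


def maybe_run(root_dir: str, filenames: List[str], step: Callable[[], None]) -> bool:
    """Run `step` only when some required file is missing; return whether it ran."""
    if all_present(root_dir, filenames):
        return False
    step()
    return True


with tempfile.TemporaryDirectory() as root:
    needed = ["raw_data.csv"]  # hypothetical raw filename

    def fake_download():
        # Stand-in for download(): just creates the expected file.
        with open(os.path.join(root, "raw_data.csv"), "w") as f:
            f.write("x,c\n")

    print(maybe_run(root, needed, fake_download))  # True: file was missing
    print(maybe_run(root, needed, fake_download))  # False: file now present
```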