torch_concepts.data.base.dataset.ConceptDataset
- class ConceptDataset(input_data: ndarray | DataFrame | Tensor, concepts: ndarray | DataFrame | Tensor, annotations: Annotations | None = None, graph: DataFrame | None = None, concept_names_subset: List[str] | None = None, precision: int | str = 32, name: str | None = None)
Base class for concept-annotated datasets.
This class extends PyTorch’s Dataset to support concept annotations, concept graphs, and associated metadata such as the dataset name and numerical precision. It provides a unified interface for datasets that pair input features with concept labels.
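A map-style PyTorch dataset only needs `__len__` and `__getitem__`. The minimal sketch below shows that protocol for paired inputs and concept labels, using NumPy arrays so it runs without PyTorch. It is not the torch_concepts implementation: the class name `ToyConceptDataset`, the dict item format with `"x"` and `"c"` keys, and the property names `feature_shape` and `num_concepts` are all hypothetical.

```python
import numpy as np


class ToyConceptDataset:
    """Hypothetical stand-in illustrating the map-style Dataset protocol.

    Not the torch_concepts implementation: the item format (a dict with
    "x" and "c" keys) is an assumption made for this sketch.
    """

    def __init__(self, input_data, concepts):
        input_data = np.asarray(input_data)
        concepts = np.asarray(concepts)
        # Inputs and concepts must describe the same samples.
        if len(input_data) != len(concepts):
            raise ValueError("input_data and concepts must have the same number of samples")
        self.input_data = input_data
        self.concepts = concepts

    def __len__(self):
        # Number of samples (size of the first axis).
        return len(self.input_data)

    def __getitem__(self, idx):
        # One sample: its input features and its concept labels.
        return {"x": self.input_data[idx], "c": self.concepts[idx]}

    @property
    def feature_shape(self):
        # Shape of a single input, excluding the sample axis.
        return self.input_data.shape[1:]

    @property
    def num_concepts(self):
        return self.concepts.shape[1]


X = np.random.randn(100, 28, 28)             # 100 "images"
C = np.random.randint(0, 2, size=(100, 5))   # 5 binary concepts
ds = ToyConceptDataset(X, C)
print(len(ds), ds.feature_shape, ds.num_concepts)  # 100 (28, 28) 5
```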
- input_data
Input features/images.
- Type:
Tensor
- concepts
Concept annotations.
- Type:
Tensor
- annotations
Detailed concept annotations with metadata.
- Type:
Annotations
- Parameters:
input_data – Input features as numpy array, pandas DataFrame, or Tensor.
concepts – Concept annotations as numpy array, pandas DataFrame, or Tensor.
annotations – Optional Annotations object with concept metadata.
graph – Optional concept graph as pandas DataFrame or tensor.
concept_names_subset – Optional list to select subset of concepts.
precision – Numerical precision (16, 32, or 64, default: 32).
name – Optional dataset name.
exogenous – Optional exogenous variables (not yet implemented).
- Raises:
ValueError – If concepts is None or annotations don’t include axis 1.
NotImplementedError – If continuous concepts or exogenous variables are used.
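The parameter handling can be illustrated with a small sketch that covers NumPy and pandas inputs only (Tensor inputs are left out to keep it dependency-light). It maps `precision` to a float dtype, converts a DataFrame to an array, and selects a concept subset by name. The helper names are hypothetical, and treating a string precision as a numeric string such as `"32"` is an assumption; ConceptDataset may accept other strings or cast concepts differently.

```python
import numpy as np
import pandas as pd

# Assumed mapping from bit width to NumPy float dtype.
_PRECISIONS = {16: np.float16, 32: np.float32, 64: np.float64}


def precision_to_dtype(precision):
    """Map 16/32/64 (int or numeric string such as "32") to a NumPy float dtype."""
    key = int(precision)  # "32" -> 32; non-numeric strings raise ValueError
    if key not in _PRECISIONS:
        raise ValueError(f"unsupported precision: {precision!r}")
    return _PRECISIONS[key]


def to_array(data, precision=32):
    """Convert an ndarray or DataFrame to an array of the requested precision."""
    if isinstance(data, pd.DataFrame):
        data = data.to_numpy()
    return np.asarray(data).astype(precision_to_dtype(precision))


def select_concepts(concepts, concept_names, subset=None):
    """Keep only the concept columns named in `subset` (all columns if None)."""
    names = list(concept_names)
    if subset is None:
        return concepts, names
    idx = [names.index(name) for name in subset]  # ValueError if a name is unknown
    return concepts[:, idx], list(subset)


rng = np.random.default_rng(0)
C = pd.DataFrame(rng.integers(0, 2, size=(4, 3)), columns=["c1", "c2", "c3"])
arr = to_array(C, precision="16")
sub, names = select_concepts(arr, C.columns, ["c3", "c1"])
print(arr.dtype, sub.shape, names)  # float16 (4, 2) ['c3', 'c1']
```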
Example
>>> import torch
>>> from torch_concepts.data.base.dataset import ConceptDataset
>>> X = torch.randn(100, 28, 28)  # 100 images
>>> C = torch.randint(0, 2, (100, 5))  # 5 binary concepts
>>> annotations = Annotations({1: AxisAnnotation(labels=['c1', 'c2', 'c3', 'c4', 'c5'])})
>>> dataset = ConceptDataset(X, C, annotations=annotations)
>>> len(dataset)
100
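The checks in the Raises section can be mirrored in a standalone function. This is a sketch under stated assumptions: `annotated_axes` is a plain set of axis indices standing in for an Annotations object, and "continuous" is approximated as "contains a non-integer value". The function name `validate_concepts` is hypothetical, and exogenous variables are not covered.

```python
import numpy as np


def validate_concepts(concepts, annotated_axes=None):
    """Hypothetical checks mirroring the documented ValueError/NotImplementedError cases."""
    if concepts is None:
        raise ValueError("concepts must be provided")
    # Assumption: annotated_axes is a set of axis indices; axis 1 holds the concepts.
    if annotated_axes is not None and 1 not in annotated_axes:
        raise ValueError("annotations must include axis 1")
    values = np.asarray(concepts)
    # Assumption: any non-integer value means the concepts are continuous.
    if not np.all(np.mod(values, 1) == 0):
        raise NotImplementedError("continuous concepts are not supported")
    return values


validate_concepts([[0, 1], [1, 0]], annotated_axes={1})  # passes without raising
```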
- __init__(input_data: ndarray | DataFrame | Tensor, concepts: ndarray | DataFrame | Tensor, annotations: Annotations | None = None, graph: DataFrame | None = None, concept_names_subset: List[str] | None = None, precision: int | str = 32, name: str | None = None)
Methods
- __init__(input_data, concepts[, ...])
- add_exogenous(name, value[, convert_precision])
- add_scaler(key, scaler) – Add a scaler for preprocessing a specific tensor.
- build() – Build the dataset from raw data into the self.root_dir folder, if needed.
- download() – Download the dataset's files to the self.root_dir folder.
- load(*args, **kwargs) – Load the raw dataset and preprocess the data.
- load_raw(*args, **kwargs) – Load the raw dataset without any preprocessing.
- maybe_reduce_annotations(annotations[, ...]) – Set the concepts and labels for the dataset. annotations is the Annotations object for all concepts; concept_names_subset is a list of strings naming the subset of concepts to use (if None, all concepts are used).
- remove_exogenous(name)
- set_concepts(concepts) – Set concept annotations for the dataset.
- set_graph(graph) – Set the adjacency matrix of the causal graph between concepts as a pandas DataFrame.
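set_graph takes the adjacency matrix as a pandas DataFrame. A minimal sketch of building such a matrix, assuming rows are parent concepts and columns are child concepts (this orientation is an assumption, not confirmed by this reference). Labeling the index and columns with concept names keeps the matrix aligned with the concepts when a subset is selected. The `parents` helper is hypothetical.

```python
import pandas as pd

concept_names = ["c1", "c2", "c3"]

# Assumed convention: graph.loc[parent, child] == 1 encodes the edge parent -> child.
graph = pd.DataFrame(0, index=concept_names, columns=concept_names)
graph.loc["c1", "c2"] = 1  # c1 -> c2
graph.loc["c2", "c3"] = 1  # c2 -> c3


def parents(graph, concept):
    """Return the concepts with an edge into `concept` (nonzero entries in its column)."""
    return list(graph.index[graph[concept].to_numpy() != 0])


# Restricting to a subset of concepts keeps rows and columns aligned by name.
sub_graph = graph.loc[["c1", "c2"], ["c1", "c2"]]

print(parents(graph, "c3"))           # ['c2']
print(sub_graph.to_numpy().tolist())  # [[0, 1], [0, 0]]
```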
Attributes
Annotations for the concepts in the dataset.
List of concept names in the dataset.
Mapping of dataset's exogenous variables.
Adjacency matrix of the causal graph between concepts.
Whether the dataset has concept annotations.
Whether the dataset has exogenous information.
Number of concepts in the dataset.
Number of exogenous variables in the dataset.
Shape of the features in the dataset's input (excluding the number of samples).
Number of samples in the dataset.
The list of processed filenames in the self.root_dir folder that must be present in order to skip build().
The absolute paths of the processed files that must be present in order to skip building.
The list of raw filenames in the self.root_dir folder that must be present in order to skip download().
The absolute paths of the raw files that must be present in order to skip downloading.
Shape of the input tensor.
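The raw and processed file attributes above imply a caching pattern: download only when raw files are missing, and build only when processed files are missing. The sketch illustrates that pattern with the standard library. download(), build(), and root_dir come from this reference; `FileBackedDataset`, `raw_filenames`, `processed_filenames`, `raw_paths`, `processed_paths`, `maybe_prepare`, and the choice to skip the download entirely once processed files exist are hypothetical.

```python
import os
import tempfile


class FileBackedDataset:
    """Hypothetical sketch of the download/build skip logic."""

    raw_filenames = ["raw.csv"]
    processed_filenames = ["processed.npz"]

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.calls = []  # records which steps actually ran

    @property
    def raw_paths(self):
        return [os.path.join(self.root_dir, f) for f in self.raw_filenames]

    @property
    def processed_paths(self):
        return [os.path.join(self.root_dir, f) for f in self.processed_filenames]

    def download(self):
        self.calls.append("download")
        for path in self.raw_paths:
            open(path, "w").close()  # placeholder for a real download

    def build(self):
        self.calls.append("build")
        for path in self.processed_paths:
            open(path, "w").close()  # placeholder for real preprocessing

    def maybe_prepare(self):
        # Assumption: existing processed files make both steps unnecessary.
        if all(os.path.exists(p) for p in self.processed_paths):
            return
        if not all(os.path.exists(p) for p in self.raw_paths):
            self.download()
        self.build()


with tempfile.TemporaryDirectory() as d:
    ds = FileBackedDataset(d)
    ds.maybe_prepare()
    ds.maybe_prepare()  # no-op: the processed file already exists
    print(ds.calls)  # ['download', 'build']
```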