"""
Base dataset class for concept-annotated datasets.
This module provides the ConceptDataset class, which serves as the foundation
for all concept-based datasets in the torch_concepts package.
"""
from abc import abstractmethod
import os
import numpy as np
import pandas as pd
from torch import Tensor
from torch.utils.data import Dataset
from copy import deepcopy
from typing import Dict, List, Optional, Union
import warnings
from ...nn.modules.mid.constructors.concept_graph import ConceptGraph
from ...annotations import Annotations, AxisAnnotation
from ..utils import files_exist, parse_tensor, convert_precision
# TODO: implement masks for missing values
# TODO: add exogenous
# TODO: range for continuous concepts
# TODO: add the possibility to annotate multiple axes (e.g., for relational concepts)
[docs]
class ConceptDataset(Dataset):
"""
Base class for concept-annotated datasets.
This class extends PyTorch's Dataset to support concept annotations,
concept graphs, and various metadata. It provides a unified interface
for working with datasets that have both input features and concept labels.
Attributes:
name (str): Name of the dataset.
precision (int or str): Numerical precision for tensors (16, 32, or 64).
input_data (Tensor): Input features/images.
concepts (Tensor): Concept annotations.
annotations (Annotations): Detailed concept annotations with metadata.
Args:
input_data: Input features as numpy array, pandas DataFrame, or Tensor.
concepts: Concept annotations as numpy array, pandas DataFrame, or Tensor.
annotations: Optional Annotations object with concept metadata.
graph: Optional concept graph as pandas DataFrame or tensor.
concept_names_subset: Optional list to select subset of concepts.
precision: Numerical precision (16, 32, or 64, default: 32).
name: Optional dataset name.
exogenous: Reserved for exogenous variables; not yet implemented and not currently accepted by ``__init__``.
Raises:
ValueError: If concepts is None or annotations don't include axis 1.
NotImplementedError: If continuous concepts or exogenous variables are used.
Example:
>>> X = torch.randn(100, 28, 28) # 100 images
>>> C = torch.randint(0, 2, (100, 5)) # 5 binary concepts
>>> annotations = Annotations({1: AxisAnnotation(labels=['c1', 'c2', 'c3', 'c4', 'c5'])})
>>> dataset = ConceptDataset(X, C, annotations=annotations)
>>> len(dataset)
100
"""
[docs]
def __init__(
    self,
    input_data: Union[np.ndarray, pd.DataFrame, Tensor],
    concepts: Union[np.ndarray, pd.DataFrame, Tensor],
    annotations: Optional[Annotations] = None,
    graph: Optional[pd.DataFrame] = None,
    concept_names_subset: Optional[List[str]] = None,
    precision: Union[int, str] = 32,
    name: Optional[str] = None,
    # TODO: implement handling of exogenous inputs
):
    """
    Validate the concept annotations and store inputs, concepts and graph.

    Args:
        input_data: Input features; forwarded to ``parse_tensor``.
        concepts: Concept values; must not be ``None``.
        annotations: Concept annotations. When ``None``, default labels
            ``concept_{i}`` with ``'discrete'`` type metadata are created
            (one per column of ``concepts``) and a warning is emitted.
        graph: Optional adjacency matrix; forwarded to ``set_graph``.
        concept_names_subset: Optional names of the concepts to keep.
        precision: Precision forwarded to ``parse_tensor``.
        name: Dataset name; defaults to the class name.

    Raises:
        ValueError: If ``concepts`` is ``None``, axis 1 is not annotated,
            a metadata entry lacks a valid ``'type'``, or a cardinalities
            list is given but does not cover every label.
        NotImplementedError: If any concept is declared ``'continuous'``.
    """
    super(ConceptDataset, self).__init__()

    # Set info
    self.name = name if name is not None else self.__class__.__name__
    self.precision = precision
    # whether input_data contains precomputed embeddings
    self.embs_precomputed = False

    if concepts is None:
        raise ValueError("Concepts must be provided for ConceptDataset.")

    # Fall back to default numbered, discrete concepts when no annotations
    # are supplied.
    if annotations is None:
        warnings.warn("No concept annotations provided. These will be set to default numbered "
                      "concepts 'concept_{i}'. All concepts will be treated as binary.")
        default_labels = [f"concept_{i}" for i in range(concepts.shape[1])]
        annotations = Annotations({
            1: AxisAnnotation(
                labels=default_labels,
                cardinalities=None,  # assume binary
                # assume discrete (bernoulli); one fresh dict per concept
                metadata={label: {'type': 'discrete'} for label in default_labels},
            )
        })

    # Axis 1 must be the annotated concept axis
    if 1 not in annotations.annotated_axes:
        raise ValueError("Concept annotations must include axis 1 for concepts. "
                         "Axis 0 is always assumed to be the batch dimension")

    axis_annotation = annotations[1]
    metadata = axis_annotation.metadata

    # Explicit raises instead of ``assert`` so that the checks are not
    # stripped when Python runs with ``-O``.
    if metadata is not None:
        if not all('type' in meta for meta in metadata.values()):
            raise ValueError("Concept metadata must contain 'type' for each concept.")
        if not all(meta['type'] in ('discrete', 'continuous') for meta in metadata.values()):
            raise ValueError("Concept metadata 'type' must be either 'discrete' or 'continuous'.")

    # When cardinalities are given, every label needs one
    if axis_annotation.cardinalities is not None:
        labels_with_cardinality = {
            label
            for label, card in zip(axis_annotation.labels, axis_annotation.cardinalities)
            if card is not None
        }
        concept_names_without_cardinality = [
            label for label in axis_annotation.labels
            if label not in labels_with_cardinality
        ]
        if concept_names_without_cardinality:
            raise ValueError(f"Cardinalities list provided but missing cardinality for concepts: {concept_names_without_cardinality}")

    # Unsupported concept types
    # TODO: implement continuous concept types
    if metadata is not None and any(meta['type'] == 'continuous' for meta in metadata.values()):
        raise NotImplementedError("Continuous concept types are not supported yet.")

    # Backing field of the ``annotations`` property
    self._annotations = annotations
    # maybe reduce annotations based on subset of concept names
    self.maybe_reduce_annotations(annotations, concept_names_subset)

    # Set dataset's input data X
    # TODO: input is assumed to be one of "np.ndarray, pd.DataFrame, Tensor" for now;
    # allow more complex data structures in the future with a custom parser
    self.input_data: Tensor = parse_tensor(input_data, 'input', self.precision)

    # Store concept data C (``concepts`` is known to be non-None here)
    self.concepts = None
    self.set_concepts(concepts)

    # Store graph among all (selected) concepts
    self._graph = None
    if graph is not None:
        self.set_graph(graph)
def __repr__(self):
"""
Return string representation of the dataset.
Returns:
str: String showing dataset name and dimensions.
"""
return f"{self.name}(n_samples={self.n_samples}, n_features={self.n_features}, n_concepts={self.n_concepts})"
def __len__(self) -> int:
"""
Return number of samples in the dataset.
Returns:
int: Number of samples.
"""
return self.n_samples
def __getitem__(self, item):
"""
Get a single sample from the dataset.
Args:
item (int): Index of the sample to retrieve.
Returns:
dict: Dictionary containing 'inputs' and 'concepts' sub-dictionaries.
"""
# Get raw input data and concepts
x = self.input_data[item]
c = self.concepts[item]
# TODO: handle missing values with masks
# Create sample dictionary
sample = {
'inputs': {'x': x}, # input data: multiple inputs can be stored in a dict
'concepts': {'c': c}, # concepts: multiple concepts can be stored in a dict
# TODO: add scalers when these are set
# also check if batch transforms work correctly inside the model training loop
# 'transforms': {'x': self.scalers.get('input', None),
# 'c': self.scalers.get('concepts', None)}
}
return sample
# Dataset properties #####################################################
@property
def n_samples(self) -> int:
"""
Number of samples in the dataset.
Returns:
int: Number of samples.
"""
return self.input_data.size(0)
@property
def n_features(self) -> tuple:
"""
Shape of features in dataset's input (excluding number of samples).
Returns:
tuple: Shape of input features.
"""
return tuple(self.input_data.size()[1:])
@property
def n_concepts(self) -> int:
"""
Number of concepts in the dataset.
Returns:
int: Number of concepts, or 0 if no concepts.
"""
return len(self.concept_names) if self.has_concepts else 0
@property
def concept_names(self) -> List[str]:
    """
    Labels of annotated axis 1, as reported by the annotations object.

    Returns:
        List[str]: Names of all concepts.
    """
    concept_axis = 1  # axis 0 is the batch dimension
    return self.annotations.get_axis_labels(concept_axis)
@property
def annotations(self) -> Optional[Annotations]:
"""Annotations for the concepts in the dataset."""
return self._annotations if hasattr(self, '_annotations') else None
@property
def shape(self) -> tuple:
"""Shape of the input tensor."""
return tuple(self.input_data.size())
@property
def exogenous(self) -> Dict[str, Tensor]:
"""Mapping of dataset's exogenous variables."""
# return {name: attr['value'] for name, attr in self._exogenous.items()}
raise NotImplementedError("Exogenous variables are not supported for now.")
@property
def n_exogenous(self) -> int:
"""Number of exogenous variables in the dataset."""
# return len(self._exogenous)
raise NotImplementedError("Exogenous variables are not supported for now.")
@property
def graph(self) -> Optional[ConceptGraph]:
"""Adjacency matrix of the causal graph between concepts."""
return self._graph
# Dataset flags #####################################################
@property
def has_exogenous(self) -> bool:
"""Whether the dataset has exogenous information."""
# return self.n_exogenous > 0
raise NotImplementedError("Exogenous variables are not supported for now.")
@property
def has_concepts(self) -> bool:
"""Whether the dataset has concept annotations."""
return self.concepts is not None
@property
def root_dir(self) -> str:
if isinstance(self.root, str):
root = os.path.expanduser(os.path.normpath(self.root))
else:
raise ValueError("Invalid root directory")
return root
@property
@abstractmethod
def raw_filenames(self) -> List[str]:
"""The list of raw filenames in the :obj:`self.root_dir` folder that must be
present in order to skip `download()`. Should be implemented by subclasses."""
pass
@property
@abstractmethod
def processed_filenames(self) -> List[str]:
"""The list of processed filenames in the :obj:`self.root_dir` folder that must be
present in order to skip `build()`. Should be implemented by subclasses."""
pass
@property
def raw_paths(self) -> List[str]:
"""The absolute paths of the raw files that must be present in order to skip downloading."""
return [os.path.join(self.root_dir, f) for f in self.raw_filenames]
@property
def processed_paths(self) -> List[str]:
"""The absolute paths of the processed files that must be present in order to skip building."""
return [os.path.join(self.root_dir, f) for f in self.processed_filenames]
# Directory utilities ###########################################################
# Loading pipeline: load() → load_raw() → build() → download()
[docs]
def maybe_download(self):
    """Call :meth:`download` unless ``files_exist(self.raw_paths)`` holds.

    The root directory is created first (if needed) before downloading.
    """
    if files_exist(self.raw_paths):
        return
    os.makedirs(self.root_dir, exist_ok=True)
    self.download()
[docs]
def maybe_build(self):
    """Call :meth:`build` unless ``files_exist(self.processed_paths)`` holds.

    The root directory is created first (if needed) before building.
    """
    if files_exist(self.processed_paths):
        return
    os.makedirs(self.root_dir, exist_ok=True)
    self.build()
[docs]
def download(self) -> None:
"""Downloads dataset's files to the :obj:`self.root_dir` folder."""
raise NotImplementedError
[docs]
def build(self) -> None:
"""Eventually build the dataset from raw data to :obj:`self.root_dir`
folder."""
pass
[docs]
def load_raw(self, *args, **kwargs):
"""Loads raw dataset without any data preprocessing."""
raise NotImplementedError
[docs]
def load(self, *args, **kwargs):
"""Loads raw dataset and preprocess data.
Default to :obj:`load_raw`."""
return self.load_raw(*args, **kwargs)
# Setters ##############################################################
[docs]
def maybe_reduce_annotations(self,
                             annotations: Annotations,
                             concept_names_subset: Optional[List[str]] = None):
    """Record all concept names and optionally narrow the annotations.

    ``self.concept_names_all`` is always set to the axis-1 labels of
    ``annotations``. When a subset is requested, ``self._annotations`` is
    replaced by a new ``Annotations`` object holding only the selected
    concepts, in the order given by ``concept_names_subset``.

    Args:
        annotations: Annotations object for all concepts.
        concept_names_subset: List of strings naming the subset of concepts to use.
            If :obj:`None`, will use all concepts and leave
            ``self._annotations`` untouched.

    Raises:
        ValueError: If a requested concept is not among the axis-1 labels.
    """
    self.concept_names_all = annotations.get_axis_labels(1)
    if concept_names_subset is None:
        return

    # Work on a shallow copy; the entries are plain names.
    to_select = list(concept_names_subset)

    # All requested concepts must exist. Raise explicitly rather than
    # ``assert`` so the check survives ``python -O``.
    missing_concepts = set(to_select) - set(self.concept_names_all)
    if missing_concepts:
        raise ValueError(f"Concepts not found in dataset: {missing_concepts}")

    # Positions of the selected concepts within the full label list
    indices = [self.concept_names_all.index(label) for label in to_select]

    axis_annotation = annotations[1]
    reduced_labels = tuple(axis_annotation.labels[i] for i in indices)

    # NOTE(review): whether AxisAnnotation can leave ``cardinalities`` or
    # ``states`` as None is not visible from here; the guards below are
    # defensive so that a None value is not subscripted.
    cardinalities = axis_annotation.cardinalities
    reduced_cardinalities = (
        None if cardinalities is None
        else tuple(cardinalities[i] for i in indices)
    )
    states = axis_annotation.states
    reduced_states = (
        None if states is None
        else tuple(states[i] for i in indices)
    )

    # Metadata is keyed by label, so look it up by the selected labels
    if axis_annotation.metadata is not None:
        reduced_metadata = {label: axis_annotation.metadata[label]
                            for label in reduced_labels}
    else:
        reduced_metadata = None

    reduced_kwargs = {
        'labels': reduced_labels,
        'cardinalities': reduced_cardinalities,
        'metadata': reduced_metadata,
    }
    # Only forward ``states`` when there is something to forward
    if reduced_states is not None:
        reduced_kwargs['states'] = reduced_states

    # Create reduced annotations
    self._annotations = Annotations({1: AxisAnnotation(**reduced_kwargs)})
[docs]
def set_graph(self, graph: pd.DataFrame):
"""Set the adjacency matrix of the causal graph between concepts
as a pandas DataFrame.
Args:
graph: A pandas DataFrame representing the adjacency matrix of the
causal graph. Rows and columns should be named after the
variables in the dataset.
"""
if not isinstance(graph, pd.DataFrame):
raise TypeError(f"Graph must be a pandas DataFrame, got {type(graph).__name__}.")
# eventually extract subset
graph = graph.loc[self.concept_names, self.concept_names]
self._graph = ConceptGraph(
data=parse_tensor(graph, 'graph', self.precision),
node_names=self.concept_names
)
[docs]
def set_concepts(self, concepts: Union[np.ndarray, pd.DataFrame, Tensor]):
"""Set concept annotations for the dataset.
Args:
concepts: Tensor of shape (n_samples, n_concepts) containing concept values
concept_names: List of strings naming each concept. If None, will use
numbered concepts like "concept_0", "concept_1", etc.
"""
# Validate shape
# concepts' length must match dataset's length
if concepts.shape[0] != self.n_samples:
raise RuntimeError(f"Concepts has {concepts.shape[0]} samples but "
f"input_data has {self.n_samples}.")
# eventually extract subset
if isinstance(concepts, pd.DataFrame):
concepts = concepts.loc[:, self.concept_names]
elif isinstance(concepts, np.ndarray) or isinstance(concepts, Tensor):
rows = [self.concept_names_all.index(name) for name in self.concept_names]
concepts = concepts[:, rows]
else:
raise TypeError(f"Concepts must be a np.ndarray, pd.DataFrame, "
f"or Tensor, got {type(concepts).__name__}.")
#########################################################################
###### modify this to change convention for how to store concepts ######
#########################################################################
# convert pd.Dataframe to tensor
concepts = parse_tensor(concepts, 'concepts', self.precision)
#########################################################################
self.concepts = concepts
[docs]
def add_exogenous(self,
name: str,
value: Union[np.ndarray, pd.DataFrame, Tensor],
convert_precision: bool = True):
raise NotImplementedError("Exogenous variables are not supported for now.")
[docs]
def remove_exogenous(self, name: str):
raise NotImplementedError("Exogenous variables are not supported for now.")
[docs]
def add_scaler(self, key: str, scaler):
"""Add a scaler for preprocessing a specific tensor.
Args:
key (str): The name of the tensor to scale ('input', 'concepts').
scaler (Scaler): The fitted scaler to use.
"""
if key not in ['input', 'concepts']:
raise KeyError(f"{key} not in dataset. Valid keys: 'input', 'concepts'")
self.scalers[key] = scaler
# Utilities ###########################################################