Source code for torch_concepts.data.datasets.toy

import numpy as np
import torch
import pandas as pd
import os
import logging
from numpy.random import multivariate_normal, uniform
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_spd_matrix, make_low_rank_matrix
from typing import List, Optional, Union

from ..base.dataset import ConceptDataset
from ...annotations import Annotations, AxisAnnotation

logger = logging.getLogger(__name__)


def _xor(size, random_state=42):
    """Generate the XOR toy dataset.

    Two features are sampled with ``np.random.uniform(0, 1, ...)``. Each
    concept flags whether the matching feature is greater than 0.5, and the
    task is the logical XOR of the two concepts.

    Parameters
    ----------
    size : int
        Number of samples to generate.
    random_state : int, optional
        Seed handed to ``np.random.seed``; note that this reseeds NumPy's
        global generator. Default: 42

    Returns
    -------
    x : torch.Tensor
        Float tensor of shape (size, 2) holding the sampled features.
    cy : torch.Tensor
        Float tensor of shape (size, 3): concepts ``C1``, ``C2`` followed by
        the ``xor`` task.
    cy_names : list of str
        Column names of ``cy``.
    graph_c_to_y : pandas.DataFrame
        Adjacency matrix (rows = source, columns = target) with one edge
        from each concept to the task.
    """
    np.random.seed(random_state)
    features = np.random.uniform(0, 1, (size, 2))

    # One boolean concept per feature column; the task combines both.
    above_half = features > 0.5
    target = np.logical_xor(above_half[:, 0], above_half[:, 1])

    x = torch.FloatTensor(features)
    cy = torch.cat(
        [
            torch.FloatTensor(above_half),
            torch.FloatTensor(target).unsqueeze(-1),
        ],
        dim=-1,
    )

    cy_names = ['C1', 'C2', 'xor']
    # Both concepts point at the task (last column); nothing else is linked.
    adjacency = np.zeros((3, 3), dtype=np.int64)
    adjacency[:2, 2] = 1
    graph_c_to_y = pd.DataFrame(adjacency, index=cy_names, columns=cy_names)

    return x, cy, cy_names, graph_c_to_y


def _trigonometry(size, random_state=42):
    """Generate the trigonometry toy dataset.

    Three hidden variables are drawn from a normal distribution (mean 0,
    standard deviation 2). The inputs are ``sin(v) + v`` and ``cos(v) + v``
    for each hidden variable plus the sum of their squares; the concepts are
    the signs of the hidden variables and the task is whether their sum
    exceeds 1.

    Parameters
    ----------
    size : int
        Number of samples to generate.
    random_state : int, optional
        Seed handed to ``np.random.seed``; note that this reseeds NumPy's
        global generator. Default: 42

    Returns
    -------
    input_features : torch.Tensor
        Float tensor of shape (size, 7).
    cy : torch.Tensor
        Float tensor of shape (size, 4): concepts ``C1``-``C3`` followed by
        the ``sumGreaterThan1`` task.
    cy_names : list of str
        Column names of ``cy``.
    graph_c_to_y : pandas.DataFrame
        Adjacency matrix with one edge from each concept to the task.
    """
    np.random.seed(random_state)
    hidden = np.random.normal(0, 2, (size, 3))
    hx, hy, hz = hidden[:, 0], hidden[:, 1], hidden[:, 2]

    # Raw features: a (sin, cos) pair per hidden variable, then the squared norm.
    columns = []
    for latent in (hx, hy, hz):
        columns.append(np.sin(latent) + latent)
        columns.append(np.cos(latent) + latent)
    columns.append(hx ** 2 + hy ** 2 + hz ** 2)
    raw_features = np.stack(columns).T

    # Concepts are the signs of the hidden variables; the task thresholds their sum.
    sign_concepts = hidden > 0
    task = (hx + hy + hz) > 1

    input_features = torch.FloatTensor(raw_features)
    cy = torch.cat(
        [
            torch.FloatTensor(sign_concepts),
            torch.FloatTensor(task).unsqueeze(-1),
        ],
        dim=-1,
    )

    cy_names = ['C1', 'C2', 'C3', 'sumGreaterThan1']
    # Every concept points at the task (last column).
    adjacency = np.zeros((4, 4), dtype=np.int64)
    adjacency[:3, 3] = 1
    graph_c_to_y = pd.DataFrame(adjacency, index=cy_names, columns=cy_names)

    return input_features, cy, cy_names, graph_c_to_y


def _dot(size, random_state=42):
    """Generate the dot-product toy dataset.

    Two random 2-d vectors ``v1`` and ``v3`` are sampled per example. The
    inputs are ``[v1 + v3, v1 - v3]``, the concepts are the signs of
    ``v1 . v2`` and ``v3 . v4`` (with ``v2 = (1, 1)`` and ``v4 = (-1, -1)``),
    and the task is whether ``v1 . v3`` is positive.

    Parameters
    ----------
    size : int
        Number of samples to generate.
    random_state : int, optional
        Seed handed to ``np.random.seed``; note that this reseeds NumPy's
        global generator. Default: 42

    Returns
    -------
    x : torch.Tensor
        Float tensor of shape (size, 4).
    cy : torch.Tensor
        Float tensor of shape (size, 3): the two concepts followed by the task.
    cy_names : list of str
        Column names of ``cy``.
    graph_c_to_y : pandas.DataFrame
        Adjacency matrix with one edge from each concept to the task.
    """
    emb_size = 2
    np.random.seed(random_state)
    v1 = np.random.randn(size, emb_size) * 2
    # Draw v3 from the *continuing* stream. Re-seeding here would make v3 an
    # exact copy of v1, which zeroes out ``v1 - v3`` and turns the task
    # ``v1 . v3 > 0`` into a constant.
    v3 = np.random.randn(size, emb_size) * 2
    v2 = np.ones(emb_size)
    v4 = -np.ones(emb_size)

    x = np.hstack([v1 + v3, v1 - v3])
    c = np.stack([
        np.dot(v1, v2).ravel() > 0,
        np.dot(v3, v4).ravel() > 0,
    ]).T
    y = ((v1 * v3).sum(axis=-1) > 0).astype(np.int64)

    x = torch.FloatTensor(x)
    c = torch.FloatTensor(c)
    y = torch.FloatTensor(y)

    cy = torch.cat([c, y.unsqueeze(-1)], dim=-1)
    cy_names = ['dotV1V2GreaterThan0', 'dotV3V4GreaterThan0', 'dotV1V3GreaterThan0']
    graph_c_to_y = pd.DataFrame(
        [[0, 0, 1],
         [0, 0, 1],
         [0, 0, 0]],
        index=cy_names,
        columns=cy_names,
    )

    return x, cy, cy_names, graph_c_to_y


def _toy_problem(n_samples: int = 10, seed: int = 42) -> torch.Tensor:
    """Sample the four boolean columns (A, B, C, D) of the checkmark problem.

    ``A`` and ``B`` are random bits drawn after seeding torch with ``seed``
    and ``seed + 1`` respectively; ``C`` and ``D`` are deterministic
    functions of them: ``C = NOT B`` and ``D = A AND C``.

    Parameters
    ----------
    n_samples : int, optional
        Number of rows to generate. Default: 10
    seed : int, optional
        Seed handed to ``torch.manual_seed`` (reseeds torch's global
        generator). Default: 42

    Returns
    -------
    torch.Tensor
        Float tensor of shape (n_samples, 4) with columns A, B, C, D encoded
        as 0.0 / 1.0.
    """
    torch.manual_seed(seed)
    col_a = torch.randint(0, 2, (n_samples,), dtype=torch.bool)
    torch.manual_seed(seed + 1)
    col_b = torch.randint(0, 2, (n_samples,), dtype=torch.bool)

    # C is the negation of B.
    col_c = torch.logical_not(col_b)
    # D holds only when both A and C hold.
    col_d = torch.logical_and(col_a, col_c)

    columns = [col_a, col_b, col_c, col_d]
    return torch.stack(columns, dim=1).float()


def _checkmark(n_samples: int = 10, seed: int =42, perturb: float = 0.1):
    """Generate the checkmark toy dataset.

    The concepts are the boolean columns produced by ``_toy_problem``. The
    inputs are those same columns rescaled from {0, 1} to {-1, +1} with
    Gaussian noise of scale ``perturb`` added.

    Parameters
    ----------
    n_samples : int, optional
        Number of samples to generate. Default: 10
    seed : int, optional
        Seed used for the concepts and, via ``torch.manual_seed``, for the
        noise. Default: 42
    perturb : float, optional
        Multiplier applied to the standard-normal noise. Default: 0.1

    Returns
    -------
    x : torch.Tensor
        Noisy inputs, same shape as the concepts.
    c : torch.Tensor
        Concept tensor with columns A, B, C, D.
    concept_names : list of str
        ``['A', 'B', 'C', 'D']``.
    dag : pandas.DataFrame
        Adjacency matrix (rows = source, columns = target) with edges
        A -> D, B -> C and C -> D.
    """
    concepts = _toy_problem(n_samples, seed)

    torch.manual_seed(seed)
    noise = torch.randn_like(concepts) * perturb
    # Map {0, 1} to {-1, +1}, then perturb.
    inputs = concepts * 2 - 1 + noise

    concept_names = ['A', 'B', 'C', 'D']
    edges = (('A', 'D'), ('B', 'C'), ('C', 'D'))  # D has no outgoing edge
    adjacency = np.zeros((len(concept_names), len(concept_names)), dtype=np.int64)
    for source, target in edges:
        adjacency[concept_names.index(source), concept_names.index(target)] = 1
    dag = pd.DataFrame(adjacency, index=concept_names, columns=concept_names)

    return inputs, concepts, concept_names, dag


class ToyDataset(ConceptDataset):
    """
    Synthetic datasets for concept-based learning experiments.

    This class provides several toy datasets with known ground-truth concept
    relationships and causal structures. Each dataset includes input
    features, binary concepts, tasks, and a directed acyclic graph (DAG)
    representing concept-to-task relationships.

    Available Datasets
    ------------------
    - **xor**: Simple XOR dataset with 2 input features, 2 concepts (C1, C2),
      and 1 task (xor). The task is the XOR of the two concepts.
    - **trigonometry**: Dataset with 7 trigonometric input features derived
      from 3 hidden variables, 3 concepts (C1, C2, C3) representing the signs
      of the hidden variables, and 1 task (sumGreaterThan1).
    - **dot**: Dataset with 4 input features, 2 concepts based on dot
      products (dotV1V2GreaterThan0, dotV3V4GreaterThan0), and 1 task
      (dotV1V3GreaterThan0).
    - **checkmark**: Dataset with 4 input features and 4 concepts
      (A, B, C, D), where C = NOT B and D = A AND C, demonstrating causal
      relationships.

    Parameters
    ----------
    dataset : str
        Name of the toy dataset to load. Must be one of: 'xor',
        'trigonometry', 'dot', or 'checkmark'.
    root : str, optional
        Root directory to store/load the dataset files. If None, defaults to
        './data/toy_datasets/{dataset_name}'. Default: None
    seed : int, optional
        Random seed for reproducible data generation. Default: 42
    n_gen : int, optional
        Number of samples to generate. Default: 10000
    concept_subset : list of str, optional
        Subset of concept names to use. If provided, only the specified
        concepts will be included in the dataset. Default: None (use all
        concepts)

    Attributes
    ----------
    input_data : torch.Tensor
        Input features tensor of shape (n_samples, n_features).
    concepts : torch.Tensor
        Concepts and tasks tensor of shape (n_samples, n_concepts + n_tasks).
        Note: This includes both concepts and tasks concatenated.
    annotations : Annotations
        Metadata about concept names, cardinalities, and types.
    graph : pandas.DataFrame
        Directed acyclic graph representing concept-to-task relationships.
        Stored as an adjacency matrix with concept/task names as indices.
    concept_names : list of str
        Names of all concepts and tasks in the dataset.
    n_concepts : int
        Total number of concepts and tasks (includes both).
    n_features : tuple or int
        Dimensionality of input features.

    Examples
    --------
    Basic usage with XOR dataset:

    >>> from torch_concepts.data.datasets import ToyDataset
    >>>
    >>> # Create XOR dataset with 1000 samples
    >>> dataset = ToyDataset(dataset='xor', seed=42, n_gen=1000)
    >>> print(f"Dataset size: {len(dataset)}")
    >>> print(f"Input features: {dataset.n_features}")
    >>> print(f"Concepts: {dataset.concept_names}")
    >>>
    >>> # Access a single sample
    >>> sample = dataset[0]
    >>> x = sample['inputs']['x']  # input features
    >>> c = sample['concepts']['c']  # concepts and task
    >>>
    >>> # Get concept graph
    >>> print(dataset.graph)

    References
    ----------
    .. [1] Espinosa Zarlenga, M., et al. "Concept Embedding Models: Beyond
       the Accuracy-Explainability Trade-Off", NeurIPS 2022.
       https://arxiv.org/abs/2209.09056
    .. [2] Dominici, G., et al. (2025). "Causal Concept Graph Models: Beyond
       Causal Opacity in Deep Learning." ICLR 2025.
       https://arxiv.org/abs/2405.16507

    See Also
    --------
    CompletenessDataset : Synthetic dataset for concept completeness experiments
    """

    def __init__(
        self,
        dataset: str,  # name of the toy dataset ('xor', 'trigonometry', 'dot', 'checkmark')
        root: str = None,  # root directory to store/load the dataset
        seed: int = 42,  # seed for data generation
        n_gen: int = 10000,  # number of samples to generate
        concept_subset: Optional[list] = None,  # subset of concept labels
    ):
        """Validate the dataset name, load (or generate) the data, and
        initialize the parent dataset with it."""
        dataset_key = dataset.lower()
        if dataset_key not in TOYDATASETS:
            raise ValueError(f"Dataset {dataset} not found. Available datasets: {TOYDATASETS}")

        self.dataset_name = dataset_key
        self.name = dataset_key
        self.seed = seed

        # Without an explicit root, fall back to a folder under the current
        # working directory.
        if root is None:
            root = os.path.join(os.getcwd(), 'data', 'toy_datasets', self.dataset_name)
        self.root = root
        self.n_gen = n_gen

        # Loading happens before the parent constructor runs because its
        # results are the parent's constructor arguments.
        input_data, concepts, annotations, graph = self.load()

        super().__init__(
            input_data=input_data,
            concepts=concepts,
            annotations=annotations,
            graph=graph,
            concept_names_subset=concept_subset,
            name=f"ToyDataset_{dataset}"
        )

    @property
    def raw_filenames(self) -> List[str]:
        """No raw files needed - data is generated."""
        return []

    @property
    def processed_filenames(self) -> List[str]:
        """List of processed filenames that will be created during build step."""
        prefix = self.dataset_name
        sample_tag = f"N_{self.n_gen}_seed_{self.seed}"
        return [
            f"{prefix}_input_{sample_tag}.pt",
            f"{prefix}_concepts_{sample_tag}.pt",
            f"{prefix}_annotations.pt",
            f"{prefix}_graph.h5",
        ]

    def download(self):
        """No download needed for toy datasets."""
        pass

    def build(self):
        """Generate synthetic data and save to disk."""
        logger.info(f"Generating {self.dataset_name} dataset with n_gen={self.n_gen}, seed={self.seed}")

        # Map each dataset name to its generator function.
        generators = {
            'xor': _xor,
            'trigonometry': _trigonometry,
            'dot': _dot,
            'checkmark': _checkmark,
        }
        generate = generators.get(self.dataset_name)
        if generate is None:
            raise ValueError(f"Unknown dataset: {self.dataset_name}")
        input_data, concepts, concept_names, graph = generate(self.n_gen, self.seed)

        # Create annotations: every column is a discrete concept with
        # cardinality 1 (all binary concepts).
        concept_metadata = {}
        for concept_name in concept_names:
            concept_metadata[concept_name] = {'type': 'discrete'}
        cardinalities = (1,) * len(concept_names)

        axis_annotation = AxisAnnotation(
            labels=concept_names,
            cardinalities=cardinalities,
            metadata=concept_metadata
        )
        annotations = Annotations({1: axis_annotation})

        # Save all data, in the order given by ``processed_filenames``.
        logger.info(f"Saving dataset to {self.root_dir}")
        os.makedirs(self.root_dir, exist_ok=True)

        input_path, concepts_path, annotations_path, graph_path = self.processed_paths[:4]
        torch.save(input_data, input_path)
        torch.save(concepts, concepts_path)
        torch.save(annotations, annotations_path)
        graph.to_hdf(graph_path, key="graph", mode="w")

    def load_raw(self):
        """Load the generated dataset from disk."""
        self.maybe_build()
        logger.info(f"Loading dataset from {self.root_dir}")

        input_path, concepts_path, annotations_path, graph_path = self.processed_paths[:4]
        input_data = torch.load(input_path, weights_only=False)
        concepts = torch.load(concepts_path, weights_only=False)
        annotations = torch.load(annotations_path, weights_only=False)
        graph = pd.read_hdf(graph_path, "graph")

        return input_data, concepts, annotations, graph

    def load(self):
        """Load the dataset (wraps load_raw)."""
        return self.load_raw()
def _relu(x):
    """Element-wise ReLU for NumPy arrays: entries that are not positive are
    multiplied by zero."""
    return x * (x > 0)


def _random_nonlin_map(n_in, n_out, n_hidden, rank=1000):
    """Build a random ReLU network with two hidden layers as a NumPy callable.

    Parameters
    ----------
    n_in : int
        Input dimensionality.
    n_out : int
        Output dimensionality.
    n_hidden : int
        Width of both hidden layers.
    rank : int, optional
        ``effective_rank`` passed to ``make_low_rank_matrix`` for every
        weight matrix. Default: 1000

    Returns
    -------
    callable
        Function mapping an array of shape (n, n_in) to an array of shape
        (n, n_out).
    """
    W_0 = make_low_rank_matrix(n_in, n_hidden, effective_rank=rank)
    W_1 = make_low_rank_matrix(n_hidden, n_hidden, effective_rank=rank)
    W_2 = make_low_rank_matrix(n_hidden, n_out, effective_rank=rank)

    # No biases. These are deliberately kept as ``uniform(0, 0)`` draws
    # rather than ``np.zeros`` so that the sequence of calls made on NumPy's
    # global generator (and therefore the generated data) stays unchanged.
    b_0 = np.random.uniform(0, 0, (1, n_hidden))
    b_1 = np.random.uniform(0, 0, (1, n_hidden))
    b_2 = np.random.uniform(0, 0, (1, n_out))

    def nlin_map(x):
        n_rows = x.shape[0]
        hidden_0 = _relu(np.matmul(x, W_0) + np.tile(b_0, (n_rows, 1)))
        hidden_1 = _relu(np.matmul(hidden_0, W_1) + np.tile(b_1, (n_rows, 1)))
        return np.matmul(hidden_1, W_2) + np.tile(b_2, (n_rows, 1))

    return nlin_map


def _task_labels(logits, n_tasks):
    """Convert raw task outputs of shape (n, n_tasks) into a float label vector.

    With a single task the output is thresholded (sigmoid >= 0.5) into
    0.0 / 1.0. With several tasks the index of the largest output is used and
    cast to float: ``argmax`` yields an integer tensor, which the legacy
    ``torch.FloatTensor(...)`` constructor refuses to wrap.
    """
    logits = torch.FloatTensor(logits)
    if n_tasks == 1:
        probabilities = torch.sigmoid(logits).squeeze(-1)
        return (probabilities >= 0.5) * 1.0
    return torch.argmax(logits, dim=-1).float()


def _complete(
    n_samples: int = 10,
    p: int = 2,
    n_views: int = 10,
    n_concepts: int = 2,
    n_hidden_concepts: int = 0,
    n_tasks: int = 1,
    seed: int = 42,
):
    """Generate data for the concept-completeness experiments.

    Inputs are sampled from a multivariate normal and standardized. A random
    nonlinear map ``g`` turns them into binary concepts (observed plus
    hidden), and a second random map ``f`` turns *all* concepts into the task
    label. Only the first ``n_concepts`` concepts are returned.

    Parameters
    ----------
    n_samples : int, optional
        Number of samples to generate. Default: 10
    p : int, optional
        Dimensionality of each view. Default: 2
    n_views : int, optional
        Number of views; the input has ``p * n_views`` features. Default: 10
    n_concepts : int, optional
        Number of observed concepts. Default: 2
    n_hidden_concepts : int, optional
        Number of concepts that feed the task but are not returned. Default: 0
    n_tasks : int, optional
        Output width of the task map. With 1 the label is binary; otherwise
        it is the argmax index over the outputs. Default: 1
    seed : int, optional
        Seed handed to ``np.random.seed`` (reseeds NumPy's global generator)
        and to ``make_spd_matrix``. Default: 42

    Returns
    -------
    X : torch.Tensor
        Float tensor of shape (n_samples, p * n_views).
    uy : torch.Tensor
        Float tensor of shape (n_samples, n_concepts + 1): observed concepts
        followed by the task label.
    uy_names : list of str
        ``['C0', ..., 'y']``.
    graph_c_to_y : pandas.DataFrame
        Adjacency matrix with one edge from each observed concept to ``y``.
    """
    total_concepts = n_concepts + n_hidden_concepts

    # Replicability
    np.random.seed(seed)

    # Generate covariates
    n_features = p * n_views
    mu = uniform(-5, 5, n_features)
    sigma = make_spd_matrix(n_features, random_state=seed)
    X = multivariate_normal(mean=mu, cov=sigma, size=n_samples)
    X = StandardScaler().fit_transform(X)

    # Nonlinear maps: g (inputs -> concepts) must be built before f
    # (concepts -> tasks) because both draw from the same random stream.
    g = _random_nonlin_map(
        n_in=n_features,
        n_out=total_concepts,
        n_hidden=int((n_features + total_concepts) / 2),
    )
    f = _random_nonlin_map(
        n_in=total_concepts,
        n_out=n_tasks,
        n_hidden=int(total_concepts / 2),
    )

    # Generate concepts
    c = torch.sigmoid(torch.FloatTensor(g(X)))
    c = (c >= 0.5) * 1.0

    # Generate labels from all concepts, hidden ones included
    y = _task_labels(f(c.detach().numpy()), n_tasks)

    # Only the observed concepts are exposed
    u = c[:, :n_concepts]

    X = torch.FloatTensor(X)
    uy = torch.cat([u, y.unsqueeze(-1)], dim=-1)
    uy_names = [f'C{i}' for i in range(n_concepts)] + ['y']

    # Concepts influence the task (last column)
    adjacency = np.zeros((n_concepts + 1, n_concepts + 1))
    adjacency[:n_concepts, n_concepts] = 1
    graph_c_to_y = pd.DataFrame(adjacency, index=uy_names, columns=uy_names)

    return X, uy, uy_names, graph_c_to_y
class CompletenessDataset(ConceptDataset):
    """
    Synthetic dataset for concept bottleneck completeness experiments.

    This dataset generates synthetic data to study complete vs. incomplete
    concept bottlenecks. Data is generated using randomly initialized
    multi-layer perceptrons with ReLU activations. Input features are sampled
    from a multivariate normal distribution, and concepts are derived through
    nonlinear transformations. Hidden concepts can be included to simulate
    incomplete bottlenecks.

    The dataset uses a two-stage generation process:

    1. Map inputs X to concepts C (both observed and hidden) via nonlinear
       function g
    2. Map concepts C to tasks Y via nonlinear function f

    Parameters
    ----------
    name : str
        Name identifier for the dataset (used for file storage).
    root : str, optional
        Root directory to store/load the dataset files. If None, defaults to
        './data/completeness_datasets/{name}'. Default: None
    seed : int, optional
        Random seed for reproducible data generation. Default: 42
    n_gen : int, optional
        Number of samples to generate. Default: 10000
    p : int, optional
        Dimensionality of each view (feature group). Default: 2
    n_views : int, optional
        Number of views/feature groups. Total input features = p * n_views.
        Default: 10
    n_concepts : int, optional
        Number of observable concepts (not including hidden concepts).
        Default: 2
    n_hidden_concepts : int, optional
        Number of hidden concepts not observable in the bottleneck. Use this
        to simulate incomplete concept bottlenecks. Default: 0
    n_tasks : int, optional
        Number of downstream tasks to predict. Default: 1
    concept_subset : list of str, optional
        Subset of concept names to use. If provided, only the specified
        concepts will be included. Concept names follow format 'C0', 'C1',
        etc. Default: None

    Attributes
    ----------
    input_data : torch.Tensor
        Input features tensor of shape (n_samples, p * n_views).
    concepts : torch.Tensor
        Concepts and tasks tensor of shape (n_samples, n_concepts + n_tasks).
        Note: Hidden concepts are NOT included in this tensor.
    annotations : Annotations
        Metadata about concept names, cardinalities, and types.
    graph : pandas.DataFrame
        Directed acyclic graph representing concept-to-task relationships.
        All concepts influence all tasks in this dataset.
    concept_names : list of str
        Names of all concepts and tasks. Format: ['C0', 'C1', ..., 'y']
    n_concepts : int
        Total number of observable concepts and tasks (includes both,
        excludes hidden).
    n_features : tuple or int
        Dimensionality of input features (p * n_views).

    Examples
    --------
    Basic usage with complete bottleneck:

    >>> from torch_concepts.data.datasets import CompletenessDataset
    >>>
    >>> # Create dataset with complete bottleneck (no hidden concepts)
    >>> dataset = CompletenessDataset(
    ...     name='complete_exp',
    ...     n_gen=5000,
    ...     n_concepts=5,
    ...     n_hidden_concepts=0,
    ...     seed=42
    ... )
    >>> print(f"Dataset size: {len(dataset)}")
    >>> print(f"Input features: {dataset.n_features}")
    >>> print(f"Concepts: {dataset.concept_names}")

    Creating incomplete bottleneck with hidden concepts:

    >>> from torch_concepts.data.datasets import CompletenessDataset
    >>>
    >>> # Create dataset with incomplete bottleneck
    >>> dataset = CompletenessDataset(
    ...     name='incomplete_exp',
    ...     n_gen=5000,
    ...     n_concepts=3,  # 3 observable concepts
    ...     n_hidden_concepts=2,  # 2 hidden concepts (not in bottleneck)
    ...     seed=42
    ... )
    >>> # The hidden concepts affect tasks but are not observable
    >>> print(f"Observable concepts: {dataset.n_concepts}")

    References
    ----------
    .. [1] Laguna, S., et al. "Beyond Concept Bottleneck Models: How to Make
       Black Boxes Intervenable?", NeurIPS 2024.
       https://arxiv.org/abs/2401.13544
    """

    def __init__(
        self,
        name: str,  # name of the dataset
        root: str = None,  # root directory to store/load the dataset
        seed: int = 42,  # seed for data generation
        n_gen: int = 10000,  # number of samples to generate
        p: int = 2,  # dimensionality of each view
        n_views: int = 10,  # number of views
        n_concepts: int = 2,  # number of concepts
        n_hidden_concepts: int = 0,  # number of hidden concepts
        n_tasks: int = 1,  # number of tasks
        concept_subset: Optional[list] = None,  # subset of concept labels
    ):
        """Store the generation settings, load (or generate) the data, and
        initialize the parent dataset with it."""
        self.name = name
        self.seed = seed

        # If root is not provided, create a local folder automatically
        if root is None:
            root = os.path.join(os.getcwd(), 'data', 'completeness_datasets', name)
        self.root = root

        self.n_gen = n_gen
        self.p = p
        self.n_views = n_views
        self._n_concepts = n_concepts  # Use internal variable to avoid property conflict
        self._n_hidden_concepts = n_hidden_concepts
        self._n_tasks = n_tasks

        # Load data (will generate if not exists)
        input_data, concepts, annotations, graph = self.load()

        # Initialize parent class
        super().__init__(
            input_data=input_data,
            concepts=concepts,
            annotations=annotations,
            graph=graph,
            concept_names_subset=concept_subset,
            name=name
        )

    @property
    def raw_filenames(self) -> List[str]:
        """No raw files needed - data is generated."""
        return []

    @property
    def processed_filenames(self) -> List[str]:
        """List of processed filenames that will be created during build step.

        Every filename encodes all generation settings that influence the
        file's content. Otherwise two configurations sharing the same
        ``name``/``root`` (e.g. differing only in ``n_tasks``, or switching
        back to an earlier ``n_concepts``) would find all files present and
        silently load data, annotations or a graph generated for another
        configuration.
        """
        data_tag = (
            f"N_{self.n_gen}_p_{self.p}_views_{self.n_views}"
            f"_concepts_{self._n_concepts}_hidden_{self._n_hidden_concepts}"
            f"_tasks_{self._n_tasks}_seed_{self.seed}"
        )
        return [
            f"input_{data_tag}.pt",
            f"concepts_{data_tag}.pt",
            # Annotations depend on the concept count and on n_tasks
            # (cardinality of 'y').
            f"annotations_concepts_{self._n_concepts}_tasks_{self._n_tasks}.pt",
            # The graph shape depends on the number of observed concepts.
            f"graph_concepts_{self._n_concepts}.h5",
        ]

    def download(self):
        """No download needed for synthetic datasets."""
        pass

    def build(self):
        """Generate synthetic completeness data and save to disk."""
        logger.info(f"Generating completeness dataset with n_gen={self.n_gen}, seed={self.seed}")

        # Generate data using _complete function
        input_data, concepts, concept_names, graph = _complete(
            n_samples=self.n_gen,
            p=self.p,
            n_views=self.n_views,
            n_concepts=self._n_concepts,
            n_hidden_concepts=self._n_hidden_concepts,
            n_tasks=self._n_tasks,
            seed=self.seed,
        )

        # Create annotations: binary concepts have cardinality 1, the task
        # column gets cardinality n_tasks.
        concept_metadata = {
            name: {'type': 'discrete'} for name in concept_names
        }
        cardinalities = (1,) * self._n_concepts + (self._n_tasks,)

        annotations = Annotations({
            1: AxisAnnotation(
                labels=concept_names,
                cardinalities=cardinalities,
                metadata=concept_metadata
            )
        })

        # Save all data. The graph goes through ``processed_paths`` like the
        # other files so that it is written exactly where it is looked up.
        logger.info(f"Saving dataset to {self.root_dir}")
        os.makedirs(self.root_dir, exist_ok=True)

        torch.save(input_data, self.processed_paths[0])
        torch.save(concepts, self.processed_paths[1])
        torch.save(annotations, self.processed_paths[2])
        graph.to_hdf(self.processed_paths[3], key="graph", mode="w")

    def load_raw(self):
        """Load the generated dataset from disk."""
        self.maybe_build()
        logger.info(f"Loading dataset from {self.root_dir}")

        input_data = torch.load(self.processed_paths[0], weights_only=False)
        concepts = torch.load(self.processed_paths[1], weights_only=False)
        annotations = torch.load(self.processed_paths[2], weights_only=False)
        graph = pd.read_hdf(self.processed_paths[3], "graph")

        return input_data, concepts, annotations, graph

    def load(self):
        """Load the dataset (wraps load_raw)."""
        return self.load_raw()
# Dataset names accepted by ``ToyDataset(dataset=...)``; the list is also
# printed in the error raised for an unknown name.
TOYDATASETS = [
    'xor',
    'trigonometry',
    'dot',
    'checkmark',
]