Source code for torch_concepts.data.datasets.bnlearn

import os
import gzip
import shutil
import pandas as pd
import torch
import logging
from typing import List, Optional

from ...annotations import Annotations

logger = logging.getLogger(__name__)

from ..base import ConceptDataset
from ..preprocessing.autoencoder import extract_embs_from_autoencoder
from ..io import download_url

# DAG names treated as built in to bnlearn: for these, BnLearnDataset downloads
# nothing and passes the bare name to ``bn.import_DAG`` instead of a .bif path.
BUILTIN_DAGS = ['asia', 'alarm', 'andes', 'sachs', 'water']


def _import_bnlearn():
    """Import the optional bnlearn dependency on demand.

    Returns:
        A ``(bnlearn module, BayesianModelSampling class)`` pair.

    Raises:
        ImportError: if either ``bnlearn`` or ``pgmpy.sampling`` cannot be
            imported; the original error is chained as the cause.
    """
    try:
        import bnlearn as bnlearn_module
        from pgmpy.sampling import BayesianModelSampling as sampler_cls
    except ImportError as import_failure:
        install_hint = (
            "BnLearnDataset requires the `bnlearn` package (which pulls in pgmpy). "
            "Install it with: pip install bnlearn"
        )
        raise ImportError(install_hint) from import_failure
    return bnlearn_module, sampler_cls

class BnLearnDataset(ConceptDataset):
    """Concept dataset sampled from a bnlearn Bayesian network.

    The network is selected by ``name``. Names listed in ``BUILTIN_DAGS`` are
    handed to ``bnlearn.import_DAG`` as-is; any other name is fetched as a
    ``.bif`` file from the bnlearn repository. For example, ``'asia'`` is a
    small expert system that models the relationship between traveling to
    Asia, smoking habits, and various lung diseases.

    ``n_gen`` samples are drawn from the network by forward sampling; the
    sampled table is used as the concepts, and embeddings are derived from it
    with ``extract_embs_from_autoencoder``.
    """

    def __init__(
        self,
        name: str,  # name of the bnlearn DAG
        root: Optional[str] = None,  # root directory to store/load the dataset
        seed: int = 42,  # seed for data generation
        n_gen: int = 10000,
        concept_subset: Optional[list] = None,  # subset of concept labels
        label_descriptions: Optional[dict] = None,
        autoencoder_kwargs: Optional[dict] = None,  # kwargs of the autoencoder used to extract latent representations
    ):
        """Load (building first if needed) the dataset and initialise the parent.

        Args:
            name: Name of the bnlearn DAG.
            root: Directory used to store/load the dataset. Defaults to
                ``<cwd>/data/<name>``.
            seed: Seed passed to the forward sampler.
            n_gen: Number of samples to generate.
            concept_subset: Optional subset of concept names, forwarded to the
                parent as ``concept_names_subset``.
            label_descriptions: Stored on the instance; not otherwise used by
                this class.
            autoencoder_kwargs: Keyword settings forwarded to
                ``extract_embs_from_autoencoder`` (an empty dict when ``None``).
        """
        self.name = name
        self.seed = seed

        # Without an explicit root, fall back to ./data/<name> under the
        # current working directory.
        if root is None:
            root = os.path.join(os.getcwd(), 'data', self.name)
        self.root = root
        self.n_gen = n_gen
        self.autoencoder_kwargs = autoencoder_kwargs
        self.label_descriptions = label_descriptions

        # NOTE: the data is loaded *before* the parent initialiser runs, so
        # every attribute assigned above must be in place at this point.
        # embeddings is a torch tensor
        # concepts is a pandas dataframe
        # annotations is an object Annotations
        # graph is the adjacency matrix as a pandas dataframe
        embeddings, concepts, annotations, graph = self.load()

        # Initialize parent class
        super().__init__(
            input_data=embeddings,
            concepts=concepts,
            annotations=annotations,
            graph=graph,
            concept_names_subset=concept_subset,  # subset of concept names
        )

    @property
    def raw_filenames(self) -> List[str]:
        """List of raw filenames that need to be present in the raw directory
        for the dataset to be considered present."""
        if self.name in BUILTIN_DAGS:
            return []  # nothing to download, these are built-in in bnlearn
        return [f"{self.name}.bif"]

    @property
    def processed_filenames(self) -> List[str]:
        """List of processed filenames that will be created during build step.

        Only the embeddings and concepts files are keyed by ``n_gen`` and
        ``seed``; the annotations and graph files use fixed names.
        """
        return [
            f"embs_N_{self.n_gen}_seed_{self.seed}.pt",
            f"concepts_N_{self.n_gen}_seed_{self.seed}.h5",
            "annotations.pt",
            "graph.h5",
        ]

    def download(self):
        """Fetch and decompress the ``.bif`` file for a non-built-in DAG.

        Does nothing when ``self.name`` is in ``BUILTIN_DAGS``. Otherwise the
        gzipped file is downloaded into ``self.root_dir``, decompressed to
        ``self.raw_paths[0]``, and the archive is removed.
        """
        if self.name in BUILTIN_DAGS:
            return

        url = f'https://www.bnlearn.com/bnrepository/{self.name}/{self.name}.bif.gz'
        gz_path = download_url(url, self.root_dir)
        bif_path = self.raw_paths[0]

        # Decompress into a temporary sibling file and move it into place only
        # once complete, so a failed decompression never leaves a truncated
        # .bif at the final path.
        partial_path = f"{bif_path}.part"
        try:
            with gzip.open(gz_path, 'rb') as f_in, open(partial_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
            os.replace(partial_path, bif_path)
        finally:
            # Remove the .gz file after extraction (and, on failure, the
            # partial output as well).
            for leftover in (partial_path, gz_path):
                if os.path.exists(leftover):
                    os.unlink(leftover)

    def build(self):
        """Sample the network and write all processed files.

        Imports the DAG, draws ``n_gen`` forward samples with ``seed``,
        extracts embeddings, derives concept annotations and the adjacency
        matrix, and saves everything to ``self.processed_paths``.

        Raises:
            ImportError: if ``bnlearn`` / ``pgmpy`` are not installed.
        """
        bn, BayesianModelSampling = _import_bnlearn()

        # Defined outside this class (presumably on ConceptDataset); expected
        # to call ``download`` when raw files are missing -- TODO confirm.
        self.maybe_download()

        # Built-in DAGs are imported by name, everything else from the .bif file.
        dag_source = self.name if self.name in BUILTIN_DAGS else self.raw_paths[0]
        self.bn_model_dict = bn.import_DAG(dag_source)
        self.bn_model = self.bn_model_dict["model"]

        # generate data
        inference = BayesianModelSampling(self.bn_model)
        df = inference.forward_sample(size=self.n_gen, seed=self.seed)

        # extract embeddings from latent autoencoder state
        # (the concepts copy is taken before df is handed to the extractor)
        concepts = df.copy()
        embeddings = extract_embs_from_autoencoder(
            df,
            self.autoencoder_kwargs if self.autoencoder_kwargs is not None else {}
        )

        # get concept annotations
        concept_names = list(self.bn_model.nodes())
        # Query the cardinalities once instead of once per node.
        node_cardinalities = self.bn_model.get_cardinality()
        cardinalities = [int(node_cardinalities[node]) for node in concept_names]
        # categorical concepts with card=2 will be treated as Bernoulli (card=1)
        cardinalities = [1 if card == 2 else card for card in cardinalities]
        # all bnlearn nodes are discrete: card==1 -> binary, card>1 -> categorical
        types = ['binary' if card == 1 else 'categorical' for card in cardinalities]

        annotations = Annotations(labels=concept_names, cardinalities=cardinalities, types=types)

        # get the graph for the endogenous concepts
        graph = self.bn_model_dict['adjmat']
        graph = graph.astype(int)

        # ---- save all ----
        # save embeddings
        logger.info("Saving dataset to %s", self.root_dir)
        torch.save(embeddings, self.processed_paths[0])
        # save concepts
        concepts.to_hdf(self.processed_paths[1], key="concepts", mode="w")
        # save concept annotations
        torch.save(annotations, self.processed_paths[2])
        # save graph
        graph.to_hdf(self.processed_paths[3], key="graph", mode="w")

    def load_raw(self):
        """Read the processed files from disk.

        Returns:
            Tuple ``(embeddings, concepts, annotations, graph)`` read from
            ``self.processed_paths`` in that order.
        """
        # Defined outside this class (presumably on ConceptDataset); expected
        # to call ``build`` when processed files are missing -- TODO confirm.
        self.maybe_build()
        logger.info("Loading dataset from %s", self.root_dir)
        # NOTE(security): weights_only=False unpickles arbitrary objects. It is
        # needed for the Annotations object, so only point ``root`` at
        # directories whose contents you trust.
        embeddings = torch.load(self.processed_paths[0], weights_only=False)
        concepts = pd.read_hdf(self.processed_paths[1], "concepts")
        annotations = torch.load(self.processed_paths[2], weights_only=False)
        graph = pd.read_hdf(self.processed_paths[3], "graph")
        return embeddings, concepts, annotations, graph

    def load(self):
        """Return ``(embeddings, concepts, annotations, graph)`` from ``load_raw``."""
        embeddings, concepts, annotations, graph = self.load_raw()
        return embeddings, concepts, annotations, graph