Source code for torch_concepts.data.datasets.bnlearn

import os
import gzip
import shutil
import pandas as pd
import torch
import logging
from typing import List, Optional
import bnlearn as bn
from pgmpy.sampling import BayesianModelSampling

from ...annotations import Annotations, AxisAnnotation

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)

from ..base import ConceptDataset
from ..preprocessing.autoencoder import extract_embs_from_autoencoder
from ..io import download_url

# DAG names that are built into bnlearn: they are imported by name and have
# no raw ``.bif`` file to download. Any other name is fetched from the
# bnlearn repository by ``BnLearnDataset.download``.
BUILTIN_DAGS = ['asia', 'alarm', 'andes', 'sachs', 'water']

class BnLearnDataset(ConceptDataset):
    """Concept dataset sampled from a Bayesian network of the bnlearn collection.

    The network is identified by ``name``. Names listed in ``BUILTIN_DAGS``
    are imported directly through ``bnlearn``; any other name is downloaded
    as ``<name>.bif.gz`` from ``https://www.bnlearn.com/bnrepository`` and
    decompressed. ``n_gen`` samples are then drawn from the network by
    forward sampling, and embeddings are obtained by passing the samples to
    ``extract_embs_from_autoencoder``.

    Args:
        name: Name of the bnlearn DAG (e.g. ``'asia'``).
        root: Directory used to store/load the dataset. When ``None``,
            ``<cwd>/data/<name>`` is used.
        seed: Seed used for data generation.
        n_gen: Number of samples to generate.
        concept_subset: Optional subset of concept names, forwarded to the
            parent class as ``concept_names_subset``.
        label_descriptions: Optional dict stored on the instance as
            ``self.label_descriptions``.
        autoencoder_kwargs: Optional kwargs of the autoencoder used to
            extract latent representations. ``None`` means ``{}``.
    """

    def __init__(
        self,
        name: str,
        root: Optional[str] = None,
        seed: int = 42,
        n_gen: int = 10000,
        concept_subset: Optional[list] = None,
        label_descriptions: Optional[dict] = None,
        autoencoder_kwargs: Optional[dict] = None,
    ):
        self.name = name
        self.seed = seed

        # If root is not provided, create a local folder automatically.
        if root is None:
            root = os.path.join(os.getcwd(), 'data', self.name)
        self.root = root

        self.n_gen = n_gen
        self.autoencoder_kwargs = autoencoder_kwargs
        self.label_descriptions = label_descriptions

        # The attributes above must be set before load(): the file-name
        # properties and build() read them.
        #   embeddings  -> torch tensor
        #   concepts    -> pandas DataFrame
        #   annotations -> Annotations object
        #   graph       -> adjacency matrix as a pandas DataFrame
        embeddings, concepts, annotations, graph = self.load()

        # Initialize parent class.
        super().__init__(
            input_data=embeddings,
            concepts=concepts,
            annotations=annotations,
            graph=graph,
            concept_names_subset=concept_subset,  # subset of concept names
        )

    @property
    def raw_filenames(self) -> List[str]:
        """List of raw filenames that need to be present in the raw directory
        for the dataset to be considered present.

        Built-in DAGs need no raw file; every other DAG needs its ``.bif``.
        """
        if self.name in BUILTIN_DAGS:
            return []  # nothing to download, these are built-in in bnlearn
        return [f"{self.name}.bif"]

    @property
    def processed_filenames(self) -> List[str]:
        """List of processed filenames that will be created during build step.

        Order matters: ``build`` and ``load_raw`` index ``processed_paths``
        positionally (embeddings, concepts, annotations, graph).
        """
        # NOTE(review): only n_gen and seed are encoded in the names;
        # autoencoder_kwargs are not, and the annotation/graph names are
        # fixed. Presumably a different root is expected per configuration.
        return [
            f"embs_N_{self.n_gen}_seed_{self.seed}.pt",
            f"concepts_N_{self.n_gen}_seed_{self.seed}.h5",
            "annotations.pt",
            "graph.h5",
        ]

    def download(self) -> None:
        """Download and decompress the ``.bif`` file of a non built-in DAG.

        Does nothing for DAGs in ``BUILTIN_DAGS``. Otherwise the gzip archive
        is fetched into ``self.root_dir``, extracted to ``self.raw_paths[0]``
        and removed. If extraction fails, the partially written ``.bif`` is
        deleted as well, and the original exception is re-raised.
        """
        if self.name in BUILTIN_DAGS:
            return

        url = f'https://www.bnlearn.com/bnrepository/{self.name}/{self.name}.bif.gz'
        gz_path = download_url(url, self.root_dir)
        bif_path = self.raw_paths[0]

        try:
            # Decompress .gz file.
            with gzip.open(gz_path, 'rb') as f_in, open(bif_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
        except BaseException:
            # Do not leave a truncated .bif behind: it could be mistaken for
            # a complete raw file on the next run.
            if os.path.exists(bif_path):
                os.remove(bif_path)
            raise
        finally:
            # Remove the archive whether or not extraction succeeded.
            os.unlink(gz_path)

    def build(self) -> None:
        """Sample the network and write all processed files.

        Sets ``self.bn_model_dict`` and ``self.bn_model``, then saves the
        embeddings, the sampled concepts, the annotations and the graph to
        the four paths in ``self.processed_paths``.
        """
        self.maybe_download()

        # Built-in DAGs are imported by name, the others from the raw file.
        dag_source = self.name if self.name in BUILTIN_DAGS else self.raw_paths[0]
        self.bn_model_dict = bn.import_DAG(dag_source)
        self.bn_model = self.bn_model_dict["model"]

        # Generate data.
        sampler = BayesianModelSampling(self.bn_model)
        df = sampler.forward_sample(size=self.n_gen, seed=self.seed)

        # Keep an independent copy of the samples as the concepts, then
        # extract embeddings from the latent autoencoder state.
        concepts = df.copy()
        ae_kwargs = {} if self.autoencoder_kwargs is None else self.autoencoder_kwargs
        embeddings = extract_embs_from_autoencoder(df, ae_kwargs)

        # Get concept annotations.
        # NOTE(review): labels follow the model's node order; this assumes
        # the sampled DataFrame has the same columns - TODO confirm.
        concept_names = list(self.bn_model.nodes())

        # Concept metadata: store as many entries as needed, but at least the
        # variable 'type' ('discrete' or 'continuous').
        concept_metadata = {node: {'type': 'discrete'} for node in concept_names}

        # Query the cardinalities once instead of once per node.
        model_cardinalities = self.bn_model.get_cardinality()
        cardinalities = [int(model_cardinalities[node]) for node in concept_names]
        # Categorical concepts with card=2 will be treated as Bernoulli (card=1).
        cardinalities = [1 if card == 2 else card for card in cardinalities]

        annotations = Annotations({
            # 0: batch axis, do not need to annotate
            # 1: concepts axis, always annotate
            1: AxisAnnotation(
                labels=concept_names,
                cardinalities=cardinalities,
                metadata=concept_metadata,
            ),
        })

        # Get the graph for the endogenous concepts.
        graph = self.bn_model_dict['adjmat'].astype(int)

        # ---- save all ----
        logger.info("Saving dataset to %s", self.root_dir)
        embs_path, concepts_path, annotations_path, graph_path = self.processed_paths[:4]
        torch.save(embeddings, embs_path)
        concepts.to_hdf(concepts_path, key="concepts", mode="w")
        torch.save(annotations, annotations_path)
        graph.to_hdf(graph_path, key="graph", mode="w")

    def load_raw(self):
        """Load the processed files, building them first if necessary.

        Returns:
            Tuple ``(embeddings, concepts, annotations, graph)`` read from
            ``self.processed_paths`` in that order.
        """
        self.maybe_build()
        logger.info("Loading dataset from %s", self.root_dir)

        # SECURITY NOTE: weights_only=False unpickles arbitrary objects. Only
        # load files written by build(); never point root at untrusted data.
        embeddings = torch.load(self.processed_paths[0], weights_only=False)
        concepts = pd.read_hdf(self.processed_paths[1], "concepts")
        annotations = torch.load(self.processed_paths[2], weights_only=False)
        graph = pd.read_hdf(self.processed_paths[3], "graph")
        return embeddings, concepts, annotations, graph

    def load(self):
        """Return ``(embeddings, concepts, annotations, graph)``.

        No extra processing is applied on top of ``load_raw``.
        """
        embeddings, concepts, annotations, graph = self.load_raw()
        return embeddings, concepts, annotations, graph