Source code for torch_concepts.nn.modules.high.models.graph_cbm
"""Graph Concept Bottleneck Model — a CBM over an explicit DAG.

A worked example showing that the
:class:`~torch_concepts.nn.modules.high.base.homogen.HomogenGraphModel`
assembler is genuinely extensible: a plain linear concept bottleneck over an
arbitrary DAG (concepts may have other concepts as parents) is obtained by
supplying *only* the encoder and predictor layers — no embeddings and no
custom assembly.
"""
from typing import Optional, Union
from .....annotations import Annotations
from .....concept_graph import ConceptGraph
from ...low.encoders.linear import LinearEmbeddingToConcept
from ...low.predictors.linear import LinearConceptToConcept
from ...mid.inference.base import BaseInference
from ...mid.inference.torch.deterministic import DeterministicInference
from ..base.homogen import HomogenGraphModel
[docs]
class GraphConceptBottleneckModel(HomogenGraphModel):
    """Linear concept bottleneck defined over a directed acyclic concept graph.

    Root concepts are encoded from the latent representation. Every other
    concept is predicted from its parent concepts. Only the two layer hooks,
    :meth:`build_encoder` and :meth:`build_predictor`, are specific to this
    model; assembly is left to :class:`HomogenGraphModel`.

    Parameters
    ----------
    input_size : int
        Dimensionality of the input features (after the backbone, if any).
    annotations : Annotations
        Concept annotations (labels, cardinalities, types).
    graph : ConceptGraph
        Directed acyclic graph over the concepts; its node names must match
        the annotation labels.
    inference, inference_kwargs, train_inference, train_inference_kwargs
        Configuration of the evaluation and training inference engines,
        handed unchanged to :meth:`setup_inference`
        (see :class:`ConceptBottleneckModel`).
    lightning : bool, default False
        If True, adds Lightning training capabilities.
    **kwargs
        Forwarded to the parent constructor (e.g. ``backbone``,
        ``latent_size``).
    """

    # Concept types this model accepts.
    supported_concept_types = frozenset({"binary", "categorical"})
    # Name of the parameter used for discrete variables.
    param_for_discrete_var = "logits"
    # Plain CBM: neither source nor internal concepts carry embeddings.
    source_embeddings = False
    internal_embeddings = False

    def __init__(
        self,
        input_size: int,
        annotations: Annotations,
        graph: ConceptGraph,
        inference: Optional[BaseInference] = DeterministicInference,
        inference_kwargs: Optional[dict] = None,
        train_inference: Optional[BaseInference] = None,
        train_inference_kwargs: Optional[dict] = None,
        lightning: bool = False,
        **kwargs,
    ):
        # NOTE(review): the default for ``inference`` is a class, while the
        # annotation describes a ``BaseInference`` instance. Confirm what
        # ``setup_inference`` expects and tighten the hint accordingly.
        super().__init__(
            input_size=input_size,
            annotations=annotations,
            graph=graph,
            lightning=lightning,
            **kwargs,
        )

        # The inference engines (train and eval) are configured on top of
        # ``self.pgm``, so the model has to be built first.
        pgm = self._build_individual_model()
        self.pgm = pgm

        engine_config = (
            inference,
            inference_kwargs,
            train_inference,
            train_inference_kwargs,
        )
        self.setup_inference(*engine_config)

    # ------------------------------------------------------------------
    # Layer hooks: the only pieces specific to this model
    # ------------------------------------------------------------------
    def build_encoder(self, in_embeddings, out_concepts):
        """Create the linear layer that encodes a root concept.

        Parameters
        ----------
        in_embeddings
            Input size specification forwarded as ``in_embeddings`` to
            :class:`LinearEmbeddingToConcept`.
        out_concepts
            Output specification forwarded as ``out_concepts``.

        Returns
        -------
        LinearEmbeddingToConcept
        """
        encoder = LinearEmbeddingToConcept(
            in_embeddings=in_embeddings,
            out_concepts=out_concepts,
        )
        return encoder

    def build_predictor(self, in_concepts: Annotations, in_embeddings, out_concepts):
        """Create the linear layer that predicts a concept from its parents.

        Parameters
        ----------
        in_concepts : Annotations
            Annotations of the parent concepts. Their summed cardinalities
            give the layer's input width.
        in_embeddings
            Ignored here. Presumably it is part of the hook signature and
            irrelevant because this model sets ``internal_embeddings = False``.
        out_concepts
            Output specification forwarded as ``out_concepts``.

        Returns
        -------
        LinearConceptToConcept
        """
        parent_width = int(sum(in_concepts.cardinalities))
        return LinearConceptToConcept(
            in_concepts=parent_width,
            out_concepts=out_concepts,
        )