Model Constructors
This module provides constructors for building concept-based models from specifications.
Summary
Constructor Classes
BipartiteModel | Bipartite concept graph model with concepts and tasks in separate layers.
GraphModel | Concept-based model with explicit graph structure between concepts and tasks.
Class Documentation
- class BipartiteModel(task_names: List[str] | str, input_size: int, annotations: Annotations, encoder: LazyConstructor | Module, predictor: LazyConstructor | Module, use_source_exogenous: bool | None = None, source_exogenous: LazyConstructor | Module | None = None, internal_exogenous: LazyConstructor | Module | None = None)
Bases: GraphModel
Bipartite concept graph model with concepts and tasks in separate layers.
This model implements a bipartite graph structure where concepts only connect to tasks (not to each other), creating a clean separation between concept and task layers. This is useful for multi-task learning with shared concepts.
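The bipartite constraint can be pictured as an adjacency matrix in which the only nonzero entries run from a concept row to a task column. The following is a minimal sketch in plain Python, not the library's implementation: `bipartite_adjacency` is a hypothetical helper introduced here for illustration, and the label names are borrowed from the class example.

```python
# Hypothetical helper for illustration only; it is not part of torch_concepts.
def bipartite_adjacency(all_labels, task_names):
    """Split labels into concepts and tasks and build a concept -> task adjacency.

    adjacency[i][j] == 1 exactly when label i is a concept and label j is a task.
    """
    task_set = set(task_names)
    missing = task_set - set(all_labels)
    assert not missing, f"unknown task names: {sorted(missing)}"

    # Every label that is not a task is treated as a concept (an assumption).
    concept_names = [name for name in all_labels if name not in task_set]
    adjacency = [
        [1 if (src not in task_set and dst in task_set) else 0 for dst in all_labels]
        for src in all_labels
    ]
    return concept_names, adjacency


all_labels = ('color', 'shape', 'size', 'task1', 'task2')
concept_names, adjacency = bipartite_adjacency(all_labels, ['task1', 'task2'])

print(concept_names)   # ['color', 'shape', 'size']
print(adjacency[0])    # [0, 0, 0, 1, 1]  'color' points to both tasks
print(adjacency[3])    # [0, 0, 0, 0, 0]  tasks have no outgoing edges
```

With three concepts and two tasks the matrix holds six edges, and no concept row has an entry in a concept column.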
- Parameters:
task_names – Task name or list of task names (each must appear in the annotations labels).
input_size – Size of input features.
annotations – Annotations object with concept and task metadata.
encoder – LazyConstructor or Module for encoding concepts from inputs.
predictor – LazyConstructor or Module for predicting tasks from concepts.
use_source_exogenous – Whether to use exogenous features for source nodes.
source_exogenous – Optional propagator for source exogenous features.
internal_exogenous – Optional propagator for internal exogenous features.
Example
>>> import torch
>>> from torch_concepts import Annotations, AxisAnnotation
>>> from torch_concepts.nn import BipartiteModel, LazyConstructor, LinearCC
>>> from torch.distributions import Bernoulli
>>>
>>> # Define concepts and tasks
>>> all_labels = ('color', 'shape', 'size', 'task1', 'task2')
>>> metadata = {'color': {'distribution': Bernoulli},
...             'shape': {'distribution': Bernoulli},
...             'size': {'distribution': Bernoulli},
...             'task1': {'distribution': Bernoulli},
...             'task2': {'distribution': Bernoulli}}
>>> annotations = Annotations({
...     1: AxisAnnotation(labels=all_labels, metadata=metadata)
... })
>>>
>>> # Create bipartite model with tasks
>>> task_names = ['task1', 'task2']
>>>
>>> model = BipartiteModel(
...     task_names=task_names,
...     input_size=784,
...     annotations=annotations,
...     encoder=LazyConstructor(torch.nn.Linear),
...     predictor=LazyConstructor(LinearCC)
... )
>>>
>>> # Generate random input
>>> x = torch.randn(8, 784)  # batch_size=8
>>>
>>> # Forward pass (implementation depends on GraphModel)
>>> # Concepts are encoded, then tasks predicted from concepts
>>> print(model.concept_names)  # ['color', 'shape', 'size']
>>> print(model.task_names)  # ['task1', 'task2']
>>> print(model.probabilistic_model)
>>>
>>> # The bipartite structure ensures:
>>> # - Concepts don't predict other concepts
>>> # - Only concepts -> tasks edges exist
- class GraphModel(model_graph: ConceptGraph, input_size: int, annotations: Annotations, encoder: LazyConstructor | Module, predictor: LazyConstructor | Module, use_source_exogenous: bool | None = None, source_exogenous: LazyConstructor | Module | None = None, internal_exogenous: LazyConstructor | Module | None = None)
Bases: BaseConstructor
Concept-based model with explicit graph structure between concepts and tasks.
This model builds a probabilistic model based on a provided concept graph structure. It automatically constructs the necessary variables and CPDs following the graph’s topological order, supporting both root concepts (encoded from inputs) and internal concepts (predicted from parents).
The graph structure defines dependencies between concepts, enabling:
- Hierarchical concept learning
- Causal reasoning with interventions
- Structured prediction with concept dependencies
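To make the distinction between root concepts, internal concepts, and topological order concrete, here is a minimal sketch using Kahn's algorithm on a plain adjacency matrix. It illustrates the idea only and is not how GraphModel is implemented: `analyze_graph` is a hypothetical helper, and the convention that `adjacency[i][j] == 1` means an edge from node i to node j is an assumption taken from the class example.

```python
from collections import deque


# Hypothetical helper for illustration only; it is not part of torch_concepts.
def analyze_graph(adjacency, node_names):
    """Return (root_nodes, internal_nodes, topological_order) for a DAG.

    Assumes adjacency[i][j] == 1 means an edge node_names[i] -> node_names[j].
    """
    n = len(node_names)
    in_degree = [sum(adjacency[i][j] for i in range(n)) for j in range(n)]

    # Roots have no parents; internal nodes have at least one.
    root_nodes = [node_names[j] for j in range(n) if in_degree[j] == 0]
    internal_nodes = [node_names[j] for j in range(n) if in_degree[j] > 0]

    # Kahn's algorithm: repeatedly emit nodes whose parents are all emitted.
    remaining = list(in_degree)
    queue = deque(j for j in range(n) if remaining[j] == 0)
    order = []
    while queue:
        i = queue.popleft()
        order.append(node_names[i])
        for j in range(n):
            if adjacency[i][j]:
                remaining[j] -= 1
                if remaining[j] == 0:
                    queue.append(j)
    return root_nodes, internal_nodes, order


names = ['A', 'B', 'C', 'D']
adjacency = [
    [0, 0, 1, 0],  # A -> C
    [0, 0, 1, 0],  # B -> C
    [0, 0, 0, 1],  # C -> D
    [0, 0, 0, 0],  # D has no children
]
root_nodes, internal_nodes, graph_order = analyze_graph(adjacency, names)

print(root_nodes)      # ['A', 'B']
print(internal_nodes)  # ['C', 'D']
print(graph_order)     # ['A', 'B', 'C', 'D']
```

Processing nodes in this order guarantees that every parent is available before any node that depends on it is predicted.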
- model_graph
Directed acyclic graph defining concept relationships.
- Type:
ConceptGraph
- probabilistic_model
Underlying PGM with variables and CPDs.
- Type:
- Parameters:
model_graph – ConceptGraph defining the structure (must be a DAG).
input_size – Size of input features.
annotations – Annotations object with concept metadata and distributions.
encoder – LazyConstructor or Module for encoding root concepts from inputs.
predictor – LazyConstructor or Module for predicting internal concepts from parents.
use_source_exogenous – Whether to use source exogenous features for predictions.
source_exogenous – Optional propagator for source exogenous features.
internal_exogenous – Optional propagator for internal exogenous features.
- Raises:
AssertionError – If model_graph is not a DAG.
AssertionError – If node names don’t match annotations labels.
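The two checks listed above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's code: `validate_graph` is a hypothetical helper, "match" is interpreted here as set equality between node names and annotation labels, and acyclicity is tested with a three-color depth-first search.

```python
# Hypothetical helper for illustration only; it is not part of torch_concepts.
def validate_graph(adjacency, node_names, annotation_labels):
    """Raise AssertionError if names mismatch or the graph contains a cycle."""
    # Assumption: "match" means the two collections contain the same names.
    assert set(node_names) == set(annotation_labels), (
        "node names do not match annotation labels"
    )

    n = len(node_names)
    WHITE, GRAY, BLACK = 0, 1, 2  # unvisited, on the current path, finished
    color = [WHITE] * n

    def has_cycle_from(i):
        color[i] = GRAY
        for j in range(n):
            if adjacency[i][j]:
                if color[j] == GRAY:  # back edge to the current path
                    return True
                if color[j] == WHITE and has_cycle_from(j):
                    return True
        color[i] = BLACK
        return False

    is_dag = not any(color[i] == WHITE and has_cycle_from(i) for i in range(n))
    assert is_dag, "model_graph is not a DAG"


names = ['A', 'B', 'C', 'D']
dag = [
    [0, 0, 1, 0],  # A -> C
    [0, 0, 1, 0],  # B -> C
    [0, 0, 0, 1],  # C -> D
    [0, 0, 0, 0],
]
cyclic = [row[:] for row in dag]
cyclic[3][0] = 1  # D -> A closes the cycle A -> C -> D -> A

validate_graph(dag, names, names)  # passes silently
```

Calling `validate_graph(cyclic, names, names)` or passing labels that omit a node would raise an AssertionError in this sketch.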
Example
>>> import torch
>>> import pandas as pd
>>> from torch_concepts import Annotations, AxisAnnotation, ConceptGraph
>>> from torch_concepts.nn import GraphModel, LazyConstructor, LinearCC
>>> from torch.distributions import Bernoulli
>>>
>>> # Define concepts and their structure
>>> # Structure: input -> [A, B] -> C -> D
>>> # A and B are root nodes (no parents)
>>> # C depends on A and B
>>> # D depends on C
>>> concept_names = ['A', 'B', 'C', 'D']
>>>
>>> # Create graph structure as adjacency matrix
>>> graph_df = pd.DataFrame(0, index=concept_names, columns=concept_names)
>>> graph_df.loc['A', 'C'] = 1  # A -> C
>>> graph_df.loc['B', 'C'] = 1  # B -> C
>>> graph_df.loc['C', 'D'] = 1  # C -> D
>>>
>>> graph = ConceptGraph(
...     torch.FloatTensor(graph_df.values),
...     node_names=concept_names
... )
>>>
>>> # Create annotations with distributions
>>> annotations = Annotations({
...     1: AxisAnnotation(
...         labels=tuple(concept_names),
...         metadata={
...             'A': {'distribution': Bernoulli},
...             'B': {'distribution': Bernoulli},
...             'C': {'distribution': Bernoulli},
...             'D': {'distribution': Bernoulli}
...         }
...     )
... })
>>>
>>> # Create GraphModel
>>> model = GraphModel(
...     model_graph=graph,
...     input_size=784,
...     annotations=annotations,
...     encoder=LazyConstructor(torch.nn.Linear),
...     predictor=LazyConstructor(LinearCC),
... )
>>>
>>> # Inspect the graph structure
>>> print(model.root_nodes)  # ['A', 'B'] - no parents
>>> print(model.internal_nodes)  # ['C', 'D'] - have parents
>>> print(model.graph_order)  # ['A', 'B', 'C', 'D'] - topological order
>>>
>>> # Check graph properties
>>> print(model.model_graph.is_dag())  # True
>>> print(model.model_graph.get_predecessors('C'))  # ['A', 'B']
>>> print(model.model_graph.get_successors('C'))  # ['D']
- References
Dominici, et al. “Causal concept graph models: Beyond causal opacity in deep learning”, ICLR 2025. https://arxiv.org/abs/2405.16507.
De Felice, et al. “Causally reliable concept bottleneck models”, NeurIPS. https://arxiv.org/abs/2503.04363v1.