Graph Learners
This module provides graph learning algorithms for discovering concept relationships from data.
Summary
Graph Learning Classes
- WANDAGraphLearner: WANDA Graph Learner for concept structure discovery.
Class Documentation
- class WANDAGraphLearner(row_labels: List[str], col_labels: List[str], priority_var: float = 1.0, hard_threshold: bool = True, threshold_init: float = 0.0, eps: float = 1e-12)
Bases: BaseGraphLearner
WANDA Graph Learner for concept structure discovery. Adapted from COSMO.
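WANDA learns a directed acyclic graph (DAG) structure by assigning a priority value to each concept and creating edges from priority differences. Acyclicity holds by construction: an edge from i to j requires priority[j] > priority[i] + threshold[i], so with non-negative thresholds the priority strictly increases along every path and no path can return to the concept it started from.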
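The acyclicity argument can be checked on a toy example. The sketch below is plain Python and independent of torch_concepts; the helper name `priority_edges` and the sample priorities are illustrative assumptions, and it models only the hard-threshold edge rule given for `weighted_adj`. It shows that every edge points to a higher-priority concept, so sorting concepts by ascending priority yields a valid topological order.

```python
def priority_edges(priorities, thresholds):
    """Hypothetical helper: edges (i, j) with priorities[j] > priorities[i] + thresholds[i]."""
    n = len(priorities)
    return [
        (i, j)
        for i in range(n)
        for j in range(n)
        if i != j and priorities[j] > priorities[i] + thresholds[i]
    ]


priorities = [0.3, -1.2, 0.9, 0.1]
thresholds = [0.0] * len(priorities)  # non-negative, like the default threshold_init=0.0
edges = priority_edges(priorities, thresholds)

# Every edge goes from a lower priority to a strictly higher one...
assert all(priorities[j] > priorities[i] for i, j in edges)

# ...so ordering concepts by ascending priority is a topological order:
# every edge points forward in this order, hence no cycle can exist.
order = sorted(range(len(priorities)), key=lambda k: priorities[k])
position = {node: rank for rank, node in enumerate(order)}
assert all(position[i] < position[j] for i, j in edges)
print(order, edges)
```

With distinct priorities and zero thresholds, every pair of concepts is connected in exactly one direction, giving a complete DAG; positive thresholds only remove edges from it.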
- np_params
Learnable priority values for each concept.
- Type:
nn.Parameter
- threshold
Fixed threshold for edge creation (not learnable).
- Type:
- Parameters:
row_labels – List of concept names for graph rows.
col_labels – List of concept names for graph columns.
priority_var – Variance for priority initialization (default: 1.0).
hard_threshold – Use hard thresholding for edges (default: True).
threshold_init – Initial value for threshold (default: 0.0).
Example
>>> import torch
>>> from torch_concepts.nn import WANDAGraphLearner
>>>
>>> # Create WANDA learner for 5 concepts
>>> concepts = ['c1', 'c2', 'c3', 'c4', 'c5']
>>> wanda = WANDAGraphLearner(
...     row_labels=concepts,
...     col_labels=concepts,
...     priority_var=1.0,
...     hard_threshold=True,
...     threshold_init=0.5
... )
>>>
>>> # Get current graph estimate
>>> adj_matrix = wanda.weighted_adj
>>> print(adj_matrix.shape)
torch.Size([5, 5])
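To read an estimated structure as named edges, one option is to pair the nonzero entries of the adjacency matrix with the concept labels. The helper `edges_from_adjacency` below is hypothetical (not a torch_concepts function) and assumes that rows are sources and columns are targets, matching the "edge from i to j" convention described for `weighted_adj`, and that any nonzero entry denotes an edge. With the learner from the example, the matrix could be obtained via `wanda.weighted_adj.detach().cpu().numpy()`; here a hand-written matrix stands in for it.

```python
import numpy as np


def edges_from_adjacency(adj, row_labels, col_labels, tol=0.0):
    """Hypothetical helper: (source, target, weight) for entries with |weight| > tol."""
    adj = np.asarray(adj, dtype=float)
    rows, cols = np.nonzero(np.abs(adj) > tol)  # row-major order of matching entries
    return [(row_labels[i], col_labels[j], float(adj[i, j])) for i, j in zip(rows, cols)]


concepts = ['c1', 'c2', 'c3']
adj = np.array([[0.0, 1.0, 1.0],
                [0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0]])  # stand-in for a learned adjacency matrix
edges = edges_from_adjacency(adj, concepts, concepts)
print(edges)
```

Raising `tol` filters out weak entries if the learned matrix contains small non-binary weights.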
References
Massidda et al. “Constraint-Free Structure Learning with Smooth Acyclic Orientations”. https://arxiv.org/abs/2309.08406
- property weighted_adj: Tensor
Compute the weighted adjacency matrix from learned priorities.
Computes an orientation matrix based on priority differences. An edge from i to j exists when priority[j] > priority[i] + threshold[i]. The diagonal is always zero (no self-loops).
- Returns:
Weighted adjacency matrix of shape (n_labels, n_labels).
- Return type:
Tensor
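A minimal vectorized sketch of the edge rule, assuming NumPy and covering only the hard-threshold case. `hard_orientation` is a hypothetical helper, not part of torch_concepts, and it returns a 0/1 matrix, whereas the library's `weighted_adj` may carry non-binary weights. Rows index the source concept i and columns the target j.

```python
import numpy as np


def hard_orientation(priorities, thresholds):
    """Hypothetical helper: A[i, j] = 1 iff priorities[j] > priorities[i] + thresholds[i]."""
    p = np.asarray(priorities, dtype=float)
    t = np.asarray(thresholds, dtype=float)
    # Broadcasting compares every target priority p[j] (columns)
    # against every source priority plus its threshold (rows).
    adj = (p[None, :] > p[:, None] + t[:, None]).astype(float)
    np.fill_diagonal(adj, 0.0)  # no self-loops, even for negative thresholds
    return adj


priorities = np.array([0.0, 1.0, 0.4])
loose = hard_orientation(priorities, np.zeros(3))     # threshold 0.0
tight = hard_orientation(priorities, np.full(3, 0.5))  # threshold 0.5
print(loose)
print(tight)
```

Raising the thresholds from 0.0 to 0.5 drops the edge from concept 0 to concept 2, whose priority gap (0.4) no longer clears the threshold; under this rule a larger threshold_init therefore starts from a sparser graph.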