Graph Learners

This module provides graph learning algorithms for discovering concept relationships from data.

Summary

Graph Learning Classes

WANDAGraphLearner

WANDA Graph Learner for concept structure discovery.

Class Documentation

class WANDAGraphLearner(row_labels: List[str], col_labels: List[str], priority_var: float = 1.0, hard_threshold: bool = True, threshold_init: float = 0.0, eps: float = 1e-12)

Bases: BaseGraphLearner

WANDA Graph Learner for concept structure discovery. Adapted from COSMO.

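WANDA learns a directed acyclic graph (DAG) over the concepts by assigning a priority value to each concept and creating edges from priority differences. Acyclicity holds by construction: an edge from i to j requires priority[j] > priority[i] + threshold[i], so with a non-negative threshold every edge points toward a strictly higher priority and no directed cycle can form.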
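The acyclicity argument can be checked with a small plain-Python sketch. It is illustrative only and is not the library's implementation: the helpers `sample_priorities`, `hard_adjacency` and `is_acyclic` are hypothetical, and drawing priorities from a zero-mean Gaussian with variance `priority_var` is an assumption about what "variance for priority initialization" means. The edge rule is the one documented for `weighted_adj`.

```python
import math
import random
from typing import List


def sample_priorities(n: int, priority_var: float = 1.0, seed: int = 0) -> List[float]:
    # ASSUMPTION: priorities ~ N(0, priority_var); the library may initialize differently.
    rng = random.Random(seed)
    return [rng.gauss(0.0, math.sqrt(priority_var)) for _ in range(n)]


def hard_adjacency(priority: List[float], threshold: List[float]) -> List[List[int]]:
    # Edge i -> j iff priority[j] > priority[i] + threshold[i]; the diagonal stays 0.
    n = len(priority)
    return [
        [1 if i != j and priority[j] > priority[i] + threshold[i] else 0 for j in range(n)]
        for i in range(n)
    ]


def is_acyclic(adj: List[List[int]]) -> bool:
    # Kahn's algorithm: repeatedly remove nodes with no incoming edges.
    # Every node gets removed if and only if the graph contains no directed cycle.
    n = len(adj)
    indegree = [sum(adj[i][j] for i in range(n)) for j in range(n)]
    queue = [j for j in range(n) if indegree[j] == 0]
    visited = 0
    while queue:
        i = queue.pop()
        visited += 1
        for j in range(n):
            if adj[i][j]:
                indegree[j] -= 1
                if indegree[j] == 0:
                    queue.append(j)
    return visited == n


if __name__ == "__main__":
    adj = hard_adjacency([0.0, 1.0, 2.0], [0.0, 0.0, 0.0])
    print(adj)  # [[0, 1, 1], [0, 0, 1], [0, 0, 0]]
    print(is_acyclic(adj))  # True
```

Because every edge strictly increases the priority when the threshold is non-negative, `is_acyclic` returns True for any sampled priorities; a negative threshold would allow edges in both directions between concepts with close priorities.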

np_params

Learnable priority values for each concept.

Type:

nn.Parameter

priority_var

Variance for priority initialization.

Type:

float

threshold

Fixed threshold for edge creation (not learnable).

Type:

torch.Tensor

hard_threshold

Whether to use hard or soft thresholding.

Type:

bool

Parameters:
  • row_labels – List of concept names for graph rows.

  • col_labels – List of concept names for graph columns.

  • priority_var – Variance for priority initialization (default: 1.0).

  • hard_threshold – Use hard thresholding for edges (default: True).

  • threshold_init – Value used to initialize the per-concept threshold, which is then kept fixed during training (default: 0.0).

Example

>>> import torch
>>> from torch_concepts.nn import WANDAGraphLearner
>>>
>>> # Create WANDA learner for 5 concepts
>>> concepts = ['c1', 'c2', 'c3', 'c4', 'c5']
>>> wanda = WANDAGraphLearner(
...     row_labels=concepts,
...     col_labels=concepts,
...     priority_var=1.0,
...     hard_threshold=True,
...     threshold_init=0.5
... )
>>>
>>> # Get current graph estimate
>>> adj_matrix = wanda.weighted_adj
>>> print(adj_matrix.shape)
torch.Size([5, 5])

References

Massidda et al. “Constraint-Free Structure Learning with Smooth Acyclic Orientations”. https://arxiv.org/abs/2309.08406

property weighted_adj: Tensor

Compute the weighted adjacency matrix from learned priorities.

Computes an orientation matrix based on priority differences. An edge from i to j exists when priority[j] > priority[i] + threshold[i]. The diagonal is always zero (no self-loops).

Returns:

Weighted adjacency matrix of shape (n_labels, n_labels), where n_labels is the number of concept labels.

Return type:

torch.Tensor
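The difference between hard and soft thresholding can be sketched as follows. This is a hypothetical illustration in plain Python: the hard branch follows the edge rule stated above, while the soft branch uses a logistic sigmoid of the margin priority[j] - priority[i] - threshold[i] with an invented `temperature` parameter. That relaxation is an assumption and may not match what `WANDAGraphLearner` computes when `hard_threshold=False`.

```python
import math
from typing import List


def edge_weights(
    priority: List[float],
    threshold: List[float],
    hard: bool = True,
    temperature: float = 1.0,  # hypothetical parameter, only used by the soft branch
) -> List[List[float]]:
    n = len(priority)
    weights = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # diagonal is always zero: no self-loops
            # Margin by which j's priority exceeds i's priority plus i's threshold.
            margin = priority[j] - priority[i] - threshold[i]
            if hard:
                weights[i][j] = 1.0 if margin > 0 else 0.0
            else:
                # ASSUMED soft relaxation: numerically stable logistic sigmoid.
                z = margin / temperature
                if z >= 0:
                    weights[i][j] = 1.0 / (1.0 + math.exp(-z))
                else:
                    e = math.exp(z)
                    weights[i][j] = e / (1.0 + e)
    return weights
```

With a non-zero threshold, a small priority gap is not enough to create an edge: for priorities [0.0, 0.3] and thresholds [0.5, 0.5], the hard rule yields no edges at all, whereas priorities [0.0, 1.0] yield the single edge from concept 0 to concept 1.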
