"""
Concept graph representation and utilities.
This module provides a memory-efficient implementation of concept graphs using
sparse tensor representations. It includes utilities for graph analysis, conversions,
and topological operations.
"""
import torch
import pandas as pd
from collections import deque
from typing import Dict, List, Tuple, Union, Optional, Set
from torch import Tensor
import networkx as nx
def _dense_to_sparse_pytorch(adj_matrix: Tensor) -> Tuple[Tensor, Tensor]:
"""
Convert dense adjacency matrix to sparse COO format using pure PyTorch.
This is a differentiable alternative to torch_geometric's dense_to_sparse.
Args:
adj_matrix: Dense adjacency matrix of shape (n_nodes, n_nodes)
Returns:
edge_index: Tensor of shape (2, num_edges) with [source, target] indices
edge_weight: Tensor of shape (num_edges,) with edge weights
"""
# Get non-zero indices using torch.nonzero (differentiable)
indices = torch.nonzero(adj_matrix, as_tuple=False)
if indices.numel() == 0:
# Empty graph - return empty tensors with proper shape
device = adj_matrix.device
dtype = adj_matrix.dtype
return (torch.empty((2, 0), dtype=torch.long, device=device),
torch.empty(0, dtype=dtype, device=device))
# Transpose to get shape (2, num_edges) for edge_index
edge_index = indices.t().contiguous()
# Extract edge weights at non-zero positions
edge_weight = adj_matrix[indices[:, 0], indices[:, 1]]
return edge_index, edge_weight
[docs]
class ConceptGraph:
"""
Memory-efficient concept graph representation using sparse COO format.
This class stores graphs in sparse format (edge list) internally, making it
efficient for large sparse graphs. It provides utilities for graph analysis
and conversions to dense/NetworkX/pandas formats.
The graph is stored as:
- edge_index: Tensor of shape (2, num_edges) with [source, target] indices
- edge_weight: Tensor of shape (num_edges,) with edge weights
- node_names: List of node names
Attributes:
edge_index (Tensor): Edge list of shape (2, num_edges)
edge_weight (Tensor): Edge weights of shape (num_edges,)
node_names (List[str]): Names of nodes in the graph
n_nodes (int): Number of nodes in the graph
Args:
data (Tensor): Dense adjacency matrix of shape (n_nodes, n_nodes)
node_names (List[str], optional): Node names. If None, generates default names.
Example:
>>> import torch
>>> from torch_concepts import ConceptGraph
>>>
>>> # Create a simple directed graph
>>> # A -> B -> C
>>> # A -> C
>>> adj = torch.tensor([[0., 1., 1.],
... [0., 0., 1.],
... [0., 0., 0.]])
>>> graph = ConceptGraph(adj, node_names=['A', 'B', 'C'])
>>>
>>> # Get root nodes (no incoming edges)
>>> print(graph.get_root_nodes())
['A']
>>>
>>> # Get leaf nodes (no outgoing edges)
>>> print(graph.get_leaf_nodes())
['C']
>>>
>>> # Check edge existence
>>> print(graph.has_edge('A', 'B'))
True
>>> print(graph.has_edge('B', 'A'))
False
>>>
>>> # Get edge weight
>>> print(graph.get_edge_weight('A', 'C'))
1.0
>>>
>>> # Get successors and predecessors
>>> print(graph.get_successors('A'))
['B', 'C']
>>> print(graph.get_predecessors('C'))
['A', 'B']
>>>
>>> # Check if DAG
>>> print(graph.is_dag())
True
>>>
>>> # Topological sort
>>> print(graph.topological_sort())
['A', 'B', 'C']
>>>
>>> # Convert to NetworkX for visualization
>>> nx_graph = graph.to_networkx()
>>>
>>> # Convert to pandas DataFrame
>>> df = graph.to_pandas()
>>>
>>> # Create from sparse format directly
>>> edge_index = torch.tensor([[0, 0, 1], [1, 2, 2]])
>>> edge_weight = torch.tensor([1.0, 1.0, 1.0])
>>> graph2 = ConceptGraph.from_sparse(
... edge_index, edge_weight, n_nodes=3,
... node_names=['X', 'Y', 'Z']
... )
"""
[docs]
def __init__(self, data: Tensor, node_names: Optional[List[str]] = None):
"""Create new ConceptGraph instance from dense adjacency matrix."""
# Validate shape
if data.dim() != 2:
raise ValueError(f"Adjacency matrix must be 2D, got {data.dim()}D")
if data.shape[0] != data.shape[1]:
raise ValueError(f"Adjacency matrix must be square, got shape {data.shape}")
self._n_nodes = data.shape[0]
self.node_names = node_names if node_names is not None else [f"node_{i}" for i in range(self._n_nodes)]
if len(self.node_names) != self._n_nodes:
raise ValueError(f"Number of node names ({len(self.node_names)}) must match matrix size ({self._n_nodes})")
# Pre-compute node name to index mapping for O(1) lookup
self._node_name_to_index = {name: idx for idx, name in enumerate(self.node_names)}
# Convert to sparse format and store
self._edge_index, self._edge_weight = _dense_to_sparse_pytorch(data)
# Cache networkx graph for faster repeated access
self._nx_graph_cache = None
@property
def edge_index(self) -> Tensor:
"""Edge list of shape (2, num_edges)."""
return self._edge_index
@edge_index.setter
def edge_index(self, value: Tensor):
self._edge_index = value
self._nx_graph_cache = None # invalidate cache
@property
def edge_weight(self) -> Tensor:
"""Edge weights of shape (num_edges,)."""
return self._edge_weight
@edge_weight.setter
def edge_weight(self, value: Tensor):
self._edge_weight = value
self._nx_graph_cache = None # invalidate cache
@classmethod
def from_sparse(cls, edge_index: Tensor, edge_weight: Tensor, n_nodes: int, node_names: Optional[List[str]] = None):
"""
Create ConceptGraph directly from sparse format (more efficient).
Args:
edge_index: Tensor of shape (2, num_edges) with [source, target] indices
edge_weight: Tensor of shape (num_edges,) with edge weights
n_nodes: Number of nodes in the graph
node_names: Optional node names
Returns:
ConceptGraph instance
Example:
>>> import torch
>>> from torch_concepts import ConceptGraph
>>> edge_index = torch.tensor([[0, 0, 1], [1, 2, 2]])
>>> edge_weight = torch.tensor([1.0, 1.0, 1.0])
>>> graph = ConceptGraph.from_sparse(edge_index, edge_weight, n_nodes=3)
"""
# Create instance without going through __init__
instance = cls.__new__(cls)
instance._n_nodes = n_nodes
instance.node_names = node_names if node_names is not None else [f"node_{i}" for i in range(n_nodes)]
if len(instance.node_names) != n_nodes:
raise ValueError(f"Number of node names ({len(instance.node_names)}) must match n_nodes ({n_nodes})")
# Pre-compute node name to index mapping for O(1) lookup
instance._node_name_to_index = {name: idx for idx, name in enumerate(instance.node_names)}
instance.edge_index = edge_index
instance.edge_weight = edge_weight
# Cache networkx graph for faster repeated access
instance._nx_graph_cache = None
return instance
@property
def n_nodes(self) -> int:
"""Get number of nodes in the graph."""
return self._n_nodes
@property
def data(self) -> Tensor:
"""
Get dense adjacency matrix representation.
Note: This reconstructs the dense matrix from sparse format.
For frequent dense access, consider caching the result.
Returns:
Dense adjacency matrix of shape (n_nodes, n_nodes)
"""
# Reconstruct dense matrix from sparse format
adj = torch.zeros(self._n_nodes, self._n_nodes, dtype=self.edge_weight.dtype, device=self.edge_weight.device)
adj[self.edge_index[0], self.edge_index[1]] = self.edge_weight
return adj
def _node_to_index(self, node: Union[str, int]) -> int:
"""Convert node name or index to index."""
if isinstance(node, int):
if node < 0 or node >= self.n_nodes:
raise IndexError(f"Node index {node} out of range [0, {self.n_nodes})")
return node
elif isinstance(node, str):
# Use pre-computed dictionary for O(1) lookup instead of O(n) list search
idx = self._node_name_to_index.get(node)
if idx is None:
raise ValueError(f"Node '{node}' not found in graph")
return idx
else:
raise TypeError(f"Node must be str or int, got {type(node)}")
def __getitem__(self, key):
"""
Allow indexing like graph[i, j] or graph['A', 'B'].
For single edge queries (tuple of 2), uses sparse lookup.
For slice/advanced indexing, falls back to dense representation.
"""
if isinstance(key, tuple) and len(key) == 2:
# Optimized path for single edge lookup
row = self._node_to_index(key[0])
col = self._node_to_index(key[1])
# Search in sparse edge list
mask = (self.edge_index[0] == row) & (self.edge_index[1] == col)
if mask.any():
return self.edge_weight[mask]
return torch.tensor(0.0, dtype=self.edge_weight.dtype, device=self.edge_weight.device)
# For advanced indexing, use dense representation
return self.data[key]
def get_edge_weight(self, source: Union[str, int], target: Union[str, int]) -> float:
"""
Get the weight of an edge.
Args:
source: Source node name or index
target: Target node name or index
Returns:
Edge weight value (0.0 if edge doesn't exist)
"""
source_idx = self._node_to_index(source)
target_idx = self._node_to_index(target)
# Search in sparse edge list
mask = (self.edge_index[0] == source_idx) & (self.edge_index[1] == target_idx)
if mask.any():
return self.edge_weight[mask].item()
return 0.0
def has_edge(self, source: Union[str, int], target: Union[str, int], threshold: float = 0.0) -> bool:
"""
Check if an edge exists between two nodes.
Args:
source: Source node name or index
target: Target node name or index
threshold: Minimum weight to consider as edge
Returns:
True if edge exists, False otherwise
"""
weight = self.get_edge_weight(source, target)
return abs(weight) > threshold
def to_pandas(self) -> pd.DataFrame:
"""
Convert adjacency matrix to pandas DataFrame.
Returns:
pd.DataFrame with node names as index and columns
"""
return pd.DataFrame(
self.data.cpu().numpy(),
index=self.node_names,
columns=self.node_names
)
@property
def _nx_graph(self) -> nx.DiGraph:
"""
Get cached NetworkX graph (lazy initialization).
This property caches the NetworkX graph for faster repeated access.
The cache is created on first access.
Returns:
nx.DiGraph: Cached NetworkX directed graph
"""
if self._nx_graph_cache is None:
self._nx_graph_cache = self.to_networkx()
return self._nx_graph_cache
def to_networkx(self, threshold: float = 0.0) -> nx.DiGraph:
"""
Convert to NetworkX directed graph.
Args:
threshold: Minimum absolute value to consider as an edge
Returns:
nx.DiGraph: NetworkX directed graph
"""
# If threshold is 0.0 and we have a cache, return it
if threshold == 0.0 and self._nx_graph_cache is not None:
return self._nx_graph_cache
# Create empty directed graph
G = nx.DiGraph()
# Add all nodes with their names
G.add_nodes_from(self.node_names)
# Add edges from sparse representation
edge_index_np = self.edge_index.cpu().numpy()
edge_weight_np = self.edge_weight.cpu().numpy()
for i in range(edge_index_np.shape[1]):
source_idx = edge_index_np[0, i]
target_idx = edge_index_np[1, i]
weight = edge_weight_np[i]
# Apply threshold
if abs(weight) > threshold:
source_name = self.node_names[source_idx]
target_name = self.node_names[target_idx]
G.add_edge(source_name, target_name, weight=weight)
# Cache if threshold is 0.0
if threshold == 0.0 and self._nx_graph_cache is None:
self._nx_graph_cache = G
return G
def dense_to_sparse(self, threshold: float = 0.0) -> Tuple[Tensor, Tensor]:
"""
Get sparse COO format (edge list) representation.
Args:
threshold: Minimum value to consider as an edge (default: 0.0)
Returns:
edge_index: Tensor of shape (2, num_edges) with source and target indices
edge_weight: Tensor of shape (num_edges,) with edge weights
"""
if threshold > 0.0:
# Filter edges by threshold
mask = torch.abs(self.edge_weight) > threshold
return self.edge_index[:, mask], self.edge_weight[mask]
return self.edge_index, self.edge_weight
def get_root_nodes(self) -> List[str]:
"""
Get nodes with no incoming edges (in-degree = 0).
Returns:
List of root node names
"""
G = self._nx_graph
return [node for node, degree in G.in_degree() if degree == 0]
def get_leaf_nodes(self) -> List[str]:
"""
Get nodes with no outgoing edges (out-degree = 0).
Returns:
List of leaf node names
"""
G = self._nx_graph
return [node for node, degree in G.out_degree() if degree == 0]
def topological_sort(self) -> List[str]:
"""
Compute topological ordering of nodes.
Only valid for directed acyclic graphs (DAGs).
Returns:
List of node names in topological order
Raises:
nx.NetworkXError: If graph contains cycles
"""
G = self._nx_graph
return list(nx.topological_sort(G))
def get_levels(self) -> List[List[str]]:
"""Group nodes by depth from the graph roots.
Only valid for directed acyclic graphs (DAGs).
Returns:
List of lists, where ``result[d]`` contains the node names at
depth *d*. The outer list is sorted by increasing depth.
Example:
>>> import torch
>>> from torch_concepts import ConceptGraph
>>> adj = torch.tensor([[0., 1., 0.],
... [0., 0., 1.],
... [0., 0., 0.]])
>>> g = ConceptGraph(adj, node_names=['A', 'B', 'C'])
>>> g.get_levels()
[['A'], ['B'], ['C']]
"""
roots = self.get_root_nodes()
depths: Dict[str, int] = {}
queue: deque = deque()
for root in roots:
depths[root] = 0
queue.append(root)
while queue:
node = queue.popleft()
for child in self.get_successors(node):
new_depth = depths[node] + 1
if child not in depths or new_depth < depths[child]:
depths[child] = new_depth
queue.append(child)
# Nodes with no edges at all (isolated) default to depth 0
for name in self.node_names:
if name not in depths:
depths[name] = 0
# Group by depth
groups: Dict[int, List[str]] = {}
for name, d in depths.items():
groups.setdefault(d, []).append(name)
max_depth = max(groups) if groups else -1
return [groups.get(d, []) for d in range(max_depth + 1)]
def get_predecessors(self, node: Union[str, int]) -> List[str]:
"""
Get immediate predecessors (parents) of a node.
Args:
node: Node name (str) or index (int)
Returns:
List of predecessor node names
"""
G = self._nx_graph
node_name = self.node_names[node] if isinstance(node, int) else node
return list(G.predecessors(node_name))
def get_successors(self, node: Union[str, int]) -> List[str]:
"""
Get immediate successors (children) of a node.
Args:
node: Node name (str) or index (int)
Returns:
List of successor node names
"""
G = self._nx_graph
node_name = self.node_names[node] if isinstance(node, int) else node
return list(G.successors(node_name))
def get_ancestors(self, node: Union[str, int]) -> Set[str]:
"""
Get all ancestors of a node (transitive predecessors).
Args:
node: Node name (str) or index (int)
Returns:
Set of ancestor node names
"""
G = self._nx_graph
node_name = self.node_names[node] if isinstance(node, int) else node
return nx.ancestors(G, node_name)
def get_descendants(self, node: Union[str, int]) -> Set[str]:
"""
Get all descendants of a node (transitive successors).
Args:
node: Node name (str) or index (int)
Returns:
Set of descendant node names
"""
G = self._nx_graph
node_name = self.node_names[node] if isinstance(node, int) else node
return nx.descendants(G, node_name)
def is_directed_acyclic(self) -> bool:
"""
Check if the graph is a directed acyclic graph (DAG).
Returns:
True if graph is a DAG, False otherwise
"""
G = self._nx_graph
return nx.is_directed_acyclic_graph(G)
def is_dag(self) -> bool:
"""
Check if the graph is a directed acyclic graph (DAG).
Alias for is_directed_acyclic() for convenience.
Returns:
True if graph is a DAG, False otherwise
"""
return self.is_directed_acyclic()
def dense_to_sparse(
adj_matrix: Union[ConceptGraph, Tensor],
threshold: float = 0.0
) -> Tuple[Tensor, Tensor]:
"""
Convert dense adjacency matrix to sparse COO format (edge list).
Uses PyTorch Geometric's native dense_to_sparse function.
Args:
adj_matrix: Dense adjacency matrix (ConceptGraph or Tensor) of shape (n_nodes, n_nodes)
threshold: Minimum absolute value to consider as an edge (only used in fallback)
Returns:
edge_index: Tensor of shape (2, num_edges) with [source_indices, target_indices]
edge_weight: Tensor of shape (num_edges,) with edge weights
Example:
>>> import torch
>>> from torch_concepts.concept_graph import dense_to_sparse
>>> adj = torch.tensor([[0., 1., 0.],
... [0., 0., 1.],
... [0., 0., 0.]])
>>> edge_index, edge_weight = dense_to_sparse(adj)
>>> print(edge_index)
tensor([[0, 1],
[1, 2]])
>>> print(edge_weight)
tensor([1., 1.])
"""
# Extract tensor data
if isinstance(adj_matrix, ConceptGraph):
adj_tensor = adj_matrix.data
else:
adj_tensor = adj_matrix
return _dense_to_sparse_pytorch(adj_tensor)
def to_networkx_graph(
adj_matrix: Union[ConceptGraph, Tensor],
node_names: Optional[List[str]] = None,
threshold: float = 0.0
) -> nx.DiGraph:
"""
Convert adjacency matrix to NetworkX directed graph.
Uses NetworkX's native from_numpy_array function for conversion.
Args:
adj_matrix: Adjacency matrix (ConceptGraph or Tensor)
node_names: Optional node names. If adj_matrix is ConceptGraph,
uses its node_names. Otherwise uses integer indices.
threshold: Minimum absolute value to consider as an edge
Returns:
nx.DiGraph: NetworkX directed graph
Example:
>>> import torch
>>> from torch_concepts.concept_graph import to_networkx_graph
>>> adj = torch.tensor([[0., 1., 1.],
... [0., 0., 1.],
... [0., 0., 0.]])
>>> G = to_networkx_graph(adj, node_names=['A', 'B', 'C'])
>>> print(list(G.nodes()))
['A', 'B', 'C']
>>> print(list(G.edges()))
[('A', 'B'), ('A', 'C'), ('B', 'C')]
"""
# Extract node names and tensor data
if isinstance(adj_matrix, ConceptGraph):
if node_names is None:
node_names = adj_matrix.node_names
adj_tensor = adj_matrix.data
else:
adj_tensor = adj_matrix
if node_names is None:
node_names = list(range(adj_tensor.shape[0]))
# Apply threshold if needed
if threshold > 0.0:
adj_tensor = adj_tensor.clone()
adj_tensor[torch.abs(adj_tensor) <= threshold] = 0.0
# Convert to numpy for NetworkX
adj_numpy = adj_tensor.detach().cpu().numpy()
# Use NetworkX's native conversion
G = nx.from_numpy_array(adj_numpy, create_using=nx.DiGraph)
# Relabel nodes with custom names if provided
if node_names != list(range(len(node_names))):
mapping = {i: name for i, name in enumerate(node_names)}
G = nx.relabel_nodes(G, mapping)
return G
def get_root_nodes(
adj_matrix: Union[ConceptGraph, Tensor, nx.DiGraph],
node_names: Optional[List[str]] = None
) -> List[str]:
"""
Get nodes with no incoming edges (in-degree = 0).
Args:
adj_matrix: Adjacency matrix (ConceptGraph, Tensor) or NetworkX graph
node_names: Optional node names (only needed if adj_matrix is Tensor)
Returns:
List of root node names
Example:
>>> import torch
>>> from torch_concepts.concept_graph import get_root_nodes
>>> adj = torch.tensor([[0., 1., 1.],
... [0., 0., 1.],
... [0., 0., 0.]])
>>> roots = get_root_nodes(adj, node_names=['A', 'B', 'C'])
>>> print(roots)
['A']
"""
if isinstance(adj_matrix, nx.DiGraph):
G = adj_matrix
else:
if isinstance(adj_matrix, ConceptGraph):
node_names = adj_matrix.node_names
G = to_networkx_graph(adj_matrix, node_names=node_names)
return [node for node, degree in G.in_degree() if degree == 0]
def get_leaf_nodes(
adj_matrix: Union[ConceptGraph, Tensor, nx.DiGraph],
node_names: Optional[List[str]] = None
) -> List[str]:
"""
Get nodes with no outgoing edges (out-degree = 0).
Args:
adj_matrix: Adjacency matrix (ConceptGraph, Tensor) or NetworkX graph
node_names: Optional node names (only needed if adj_matrix is Tensor)
Returns:
List of leaf node names
Example:
>>> import torch
>>> from torch_concepts.concept_graph import get_leaf_nodes
>>> adj = torch.tensor([[0., 1., 1.],
... [0., 0., 1.],
... [0., 0., 0.]])
>>> leaves = get_leaf_nodes(adj, node_names=['A', 'B', 'C'])
>>> print(leaves)
['C']
"""
if isinstance(adj_matrix, nx.DiGraph):
G = adj_matrix
else:
if isinstance(adj_matrix, ConceptGraph):
node_names = adj_matrix.node_names
G = to_networkx_graph(adj_matrix, node_names=node_names)
return [node for node, degree in G.out_degree() if degree == 0]
def topological_sort(
adj_matrix: Union[ConceptGraph, Tensor, nx.DiGraph],
node_names: Optional[List[str]] = None
) -> List[str]:
"""
Compute topological ordering of nodes (only for DAGs).
Uses NetworkX's native topological_sort function.
Args:
adj_matrix: Adjacency matrix (ConceptGraph, Tensor) or NetworkX graph
node_names: Optional node names (only needed if adj_matrix is Tensor)
Returns:
List of node names in topological order
Raises:
nx.NetworkXError: If graph contains cycles
Example:
>>> import torch
>>> from torch_concepts.concept_graph import topological_sort
>>> adj = torch.tensor([[0., 1., 1.],
... [0., 0., 1.],
... [0., 0., 0.]])
>>> ordered = topological_sort(adj, node_names=['A', 'B', 'C'])
>>> print(ordered)
['A', 'B', 'C']
"""
if isinstance(adj_matrix, nx.DiGraph):
G = adj_matrix
else:
if isinstance(adj_matrix, ConceptGraph):
node_names = adj_matrix.node_names
G = to_networkx_graph(adj_matrix, node_names=node_names)
return list(nx.topological_sort(G))
def get_predecessors(
adj_matrix: Union[ConceptGraph, Tensor, nx.DiGraph],
node: Union[str, int],
node_names: Optional[List[str]] = None
) -> List[str]:
"""
Get immediate predecessors (parents) of a node.
Uses NetworkX's native predecessors method.
Args:
adj_matrix: Adjacency matrix (ConceptGraph, Tensor) or NetworkX graph
node: Node name (str) or index (int)
node_names: Optional node names (only needed if adj_matrix is Tensor)
Returns:
List of predecessor node names
Example:
>>> import torch
>>> from torch_concepts.concept_graph import get_predecessors
>>> adj = torch.tensor([[0., 1., 1.],
... [0., 0., 1.],
... [0., 0., 0.]])
>>> preds = get_predecessors(adj, 'C', node_names=['A', 'B', 'C'])
>>> print(preds)
['A', 'B']
"""
if isinstance(adj_matrix, nx.DiGraph):
G = adj_matrix
if isinstance(node, int) and node_names:
node = node_names[node]
else:
if isinstance(adj_matrix, ConceptGraph):
node_names = adj_matrix.node_names
G = to_networkx_graph(adj_matrix, node_names=node_names)
if isinstance(node, int):
node = node_names[node]
return list(G.predecessors(node))
def get_successors(
adj_matrix: Union[ConceptGraph, Tensor, nx.DiGraph],
node: Union[str, int],
node_names: Optional[List[str]] = None
) -> List[str]:
"""
Get immediate successors (children) of a node.
Uses NetworkX's native successors method.
Args:
adj_matrix: Adjacency matrix (ConceptGraph, Tensor) or NetworkX graph
node: Node name (str) or index (int)
node_names: Optional node names (only needed if adj_matrix is Tensor)
Returns:
List of successor node names
Example:
>>> import torch
>>> from torch_concepts.concept_graph import get_successors
>>> adj = torch.tensor([[0., 1., 1.],
... [0., 0., 1.],
... [0., 0., 0.]])
>>> succs = get_successors(adj, 'A', node_names=['A', 'B', 'C'])
>>> print(succs)
['B', 'C']
"""
if isinstance(adj_matrix, nx.DiGraph):
G = adj_matrix
if isinstance(node, int) and node_names:
node = node_names[node]
else:
if isinstance(adj_matrix, ConceptGraph):
node_names = adj_matrix.node_names
G = to_networkx_graph(adj_matrix, node_names=node_names)
if isinstance(node, int):
node = node_names[node]
return list(G.successors(node))