Source code for torch_concepts.nn.modules.low.inference.intervention

"""
Inference and intervention modules for concept-based models.

This module provides intervention strategies that modify concept values during
inference, enabling causal reasoning and what-if analysis in concept-based models.
"""
import math
import contextlib
from abc import abstractmethod
from typing import List, Sequence, Union, Optional, Dict
import torch
import torch.nn as nn

from ...mid.models.cpd import ParametricCPD
from ..base.inference import BaseIntervention

# ---------------- core helpers ----------------

def _get_submodule(model: nn.Module, dotted: str) -> nn.Module:
    """Resolve a dot-separated path to a nested submodule of ``model``.

    The path is split on ``"."`` and walked one component at a time, asking
    each intermediate module for its child via ``get_submodule``.

    Args:
        model: Root module the path is relative to.
        dotted: Dot-separated submodule path (e.g. ``"encoder.0"``).

    Returns:
        The module reached after consuming every path component.
    """
    node = model
    remaining = dotted.split(".")
    while remaining:
        # Consume the path left-to-right, descending one level per component.
        node = node.get_submodule(remaining.pop(0))
    return node

def _set_submodule(model: nn.Module, dotted: str, new: nn.Module) -> None:
    """Replace the submodule found at a dot-separated path with ``new``.

    Args:
        model: Root module the path is relative to.
        dotted: Dot-separated path of the submodule to replace.
        new: Module to install at that location.

    Raises:
        ValueError: If ``dotted`` is the empty string.

    Notes:
        * When the last path component is numeric and the parent supports
          item assignment (e.g. ``nn.Sequential``), the module is set by index.
        * Otherwise it is set as an attribute on the parent. A ``new`` that is
          not already a ``ParametricCPD`` is wrapped in one (with
          ``concepts=dotted``) before being attached.
    """
    parts = dotted.split(".")
    # str.split always yields at least one item, so an empty path shows up as [""].
    if not parts or parts == [""]:
        raise ValueError("Dotted path must not be empty")

    *ancestors, leaf = parts
    parent = model.get_submodule(".".join(ancestors)) if ancestors else model

    # Indexed containers: assign by position instead of by attribute name.
    if leaf.isdigit() and hasattr(parent, "__setitem__"):
        parent[int(leaf)] = new
        return

    # Attribute assignment: keep ParametricCPD instances as-is, wrap anything else.
    if isinstance(new, ParametricCPD):
        replacement = new
    else:
        replacement = ParametricCPD(concepts=dotted, parametrization=new)
    setattr(parent, leaf, replacement)

def _as_list(x, n: int):
    """Return ``x`` as a list of exactly ``n`` items.

    Args:
        x: Either a list/tuple (which must already have length ``n``) or any
            other object, which is repeated ``n`` times.
        n: Required length of the result.

    Returns:
        A new list of length ``n``.

    Raises:
        ValueError: If ``x`` is a list/tuple whose length differs from ``n``.
    """
    if not isinstance(x, (list, tuple)):
        # Singleton: broadcast the same object n times.
        return [x] * n
    if len(x) != n:
        raise ValueError(f"Expected list of length {n}, got {len(x)}")
    return list(x)

# ---------------- strategy ----------------

class RewiringIntervention(BaseIntervention):
    """
    Base class for rewiring-based interventions.

    A rewiring intervention blends a module's predicted concept values with
    target values according to a mask: entries where the mask is 1 keep the
    prediction, entries where it is 0 take the target (a do-calculus style
    replacement). Subclasses define the target via ``_make_target``.

    Args:
        model: The concept-based model to intervene on.

    Example:
        >>> import torch
        >>> from torch_concepts.nn import RewiringIntervention
        >>>
        >>> # Subclass to create custom intervention
        >>> class MyIntervention(RewiringIntervention):
        ...     def _make_target(self, y, *args, **kwargs):
        ...         return torch.ones_like(y)
        >>>
    """

    def __init__(self, model: nn.Module, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but not forwarded.
        super().__init__(model)

    @abstractmethod
    def _make_target(self, y: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Build the tensor of replacement values for an intervention.

        Args:
            y: Predicted concept values.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: Target values for intervention.
        """
        raise NotImplementedError

    def query(self, original_module: nn.Module, mask: torch.Tensor, *args, **kwargs) -> nn.Module:
        """
        Wrap ``original_module`` so that its output is rewired on every call.

        Args:
            original_module: The original module to wrap.
            mask: Binary mask (1=keep prediction, 0=replace with target).
            *args: Additional arguments (unused here).
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            nn.Module: Wrapped module with intervention applied.
        """
        strategy = self

        class _Rewire(nn.Module):
            def __init__(self, orig: nn.Module, mask_: torch.Tensor):
                super().__init__()
                self.orig = orig
                # Store a private copy of the mask as a buffer of the wrapper.
                self.register_buffer("mask", mask_.clone())

            def forward(self, **kwargs) -> torch.Tensor:
                pred = self.orig(**kwargs)  # [B, F]
                assert pred.dim() == 2, (f"RewiringIntervention expects 2-D tensors [Batch, N_concepts]. "
                                         f"Got shape: {pred.shape}")
                target = strategy._make_target(pred)  # [B, F]
                keep = self.mask.to(dtype=pred.dtype)
                kept_part = pred * keep
                replaced_part = target * (1.0 - keep)
                return kept_part + replaced_part

        return _Rewire(original_module, mask)
# -------------------- Concrete strategies --------------------
class GroundTruthIntervention(RewiringIntervention):
    """
    Intervention whose target values are ground-truth concept values.

    Implements do(C=c_true): wherever the intervention mask selects a concept
    for replacement, the stored ground truth is used instead of the prediction.

    Args:
        model: The concept-based model to intervene on.
        ground_truth: Ground truth concept values of shape (batch_size, n_concepts).

    Example:
        >>> import torch
        >>> from torch_concepts.nn import GroundTruthIntervention
        >>>
        >>> # Create a dummy model
        >>> model = torch.nn.Linear(10, 5)
        >>>
        >>> # Ground truth values
        >>> c_true = torch.tensor([[1.0, 0.0, 1.0, 0.0, 1.0],
        ...                        [0.0, 1.0, 0.0, 1.0, 0.0]])
        >>>
        >>> # Create intervention
        >>> intervention = GroundTruthIntervention(model, c_true)
        >>>
        >>> # Apply intervention (typically done via context manager)
        >>> # See intervention() context manager for complete usage
    """

    def __init__(self, model: nn.Module, ground_truth: torch.Tensor):
        super().__init__(model)
        # Registered as a buffer under the name "ground_truth".
        self.register_buffer("ground_truth", ground_truth)

    def _make_target(self, y: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Return the stored ground truth, cast to the dtype and device of ``y``."""
        truth = self.ground_truth
        return truth.to(device=y.device, dtype=y.dtype)
class DoIntervention(RewiringIntervention):
    """
    Intervention that sets concepts to constant values (do-calculus).

    Implements do(C=constant). The constants may be a scalar, a per-concept
    vector, or a per-sample matrix; they are broadcast to the shape of the
    predictions when the target is built.

    Args:
        model: The concept-based model to intervene on.
        constants: Constant values (scalar, [F], [1,F], or [B,F]).

    Example:
        >>> import torch
        >>> from torch_concepts.nn import DoIntervention
        >>>
        >>> # Create a dummy model
        >>> model = torch.nn.Linear(10, 3)
        >>>
        >>> # Set all concepts to 1.0
        >>> intervention_scalar = DoIntervention(model, 1.0)
        >>>
        >>> # Set each concept to different values
        >>> intervention_vec = DoIntervention(
        ...     model,
        ...     torch.tensor([0.5, 1.0, 0.0])
        ... )
        >>>
        >>> # Set per-sample values
        >>> intervention_batch = DoIntervention(
        ...     model,
        ...     torch.tensor([[0.0, 1.0, 0.5],
        ...                   [1.0, 0.0, 0.5]])
        ... )
        >>>
        >>> # Use via context manager - see intervention()
    """

    def __init__(self, model: nn.Module, constants: torch.Tensor | float):
        super().__init__(model)
        # Non-tensor input (e.g. a Python float) is converted to a tensor first.
        if torch.is_tensor(constants):
            as_tensor = constants
        else:
            as_tensor = torch.tensor(constants)
        self.register_buffer("constants", as_tensor)

    # unified signature matching base
    def _make_target(self, y: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Broadcast the stored constants to the shape of ``y``.

        Args:
            y: Predicted concept values of shape [B, F].

        Returns:
            torch.Tensor: Constants expanded to [B, F], with the dtype and
            device of ``y``.

        Raises:
            ValueError: If the constants have more than two dimensions.
            AssertionError: If a constants dimension does not match ``y``.
        """
        B, F = y.shape
        values = self.constants
        rank = values.dim()

        if rank > 2:
            raise ValueError("constants must be scalar, [F], [1, F], or [B, F]")

        if rank == 0:
            # scalar -> [B, F]
            values = values.view(1, 1).expand(B, F)
        elif rank == 1:
            # [F] -> [B, F]
            assert values.numel() == F, f"constants [F] must have F={F}, got {values.numel()}"
            values = values.unsqueeze(0).expand(B, F)
        else:
            b, f = values.shape
            assert f == F, f"constants second dim must be F={F}, got {f}"
            if b != 1:
                assert b == B, f"constants first dim must be B={B} or 1, got {b}"
            else:
                values = values.expand(B, F)  # [1, F] -> [B, F]

        return values.to(dtype=y.dtype, device=y.device)
class DistributionIntervention(RewiringIntervention):
    """
    Intervention that samples concept values from distributions.

    Implements do(C~D): the replacement values are drawn either from a single
    distribution shared by all concepts or from one distribution per concept.

    Args:
        model: The concept-based model to intervene on.
        dist: A torch.distributions.Distribution or list of per-concept distributions.

    Example:
        >>> import torch
        >>> from torch_concepts.nn import DistributionIntervention
        >>> from torch.distributions import Bernoulli, Normal
        >>>
        >>> # Create a dummy model
        >>> model = torch.nn.Linear(10, 3)
        >>>
        >>> # Single distribution for all concepts
        >>> intervention_single = DistributionIntervention(
        ...     model,
        ...     Bernoulli(torch.tensor(0.7))
        ... )
        >>>
        >>> # Per-concept distributions
        >>> intervention_multi = DistributionIntervention(
        ...     model,
        ...     [Bernoulli(torch.tensor(0.3)),
        ...      Normal(torch.tensor(0.0), torch.tensor(1.0)),
        ...      Bernoulli(torch.tensor(0.8))]
        ... )
        >>>
        >>> # Use via context manager - see intervention()
    """

    def __init__(self, model: nn.Module, dist):
        super().__init__(model)
        self.dist = dist

    # unified signature matching base
    def _make_target(self, y: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Draw a [B, F] tensor of target values from the configured distribution(s).

        Args:
            y: Predicted concept values of shape [B, F]; only its shape, dtype
                and device are used.

        Returns:
            torch.Tensor: Sampled targets on the device and in the dtype of ``y``.

        Raises:
            AssertionError: If a per-concept collection does not contain
                exactly F distributions.
        """
        n_rows, n_cols = y.shape

        def _draw(distribution, shape):
            # Prefer rsample (reparameterized); fall back to sample when the
            # distribution does not implement it.
            if hasattr(distribution, "rsample"):
                try:
                    return distribution.rsample(shape)
                except NotImplementedError:
                    pass
            return distribution.sample(shape)

        if not hasattr(self.dist, "sample"):
            # No ``sample`` attribute: treat ``dist`` as an iterable of
            # per-feature distributions.
            per_feature = list(self.dist)
            assert len(per_feature) == n_cols, f"Need {n_cols} per-feature distributions, got {len(per_feature)}"
            columns = [_draw(d, (n_rows,)) for d in per_feature]  # each [B]
            drawn = torch.stack(columns, dim=1)  # [B, F]
        else:
            # One distribution shared by all features.
            drawn = _draw(self.dist, (n_rows, n_cols))

        return drawn.to(device=y.device, dtype=y.dtype)
# ---------------- wrapper ----------------

class _CachedOutput(nn.Module):
    """Module that ignores its inputs and returns a pre-computed tensor.

    Used to hand an already computed output to ``strategy.query`` so the
    wrapped module is not evaluated a second time.
    """

    def __init__(self, y_cached: torch.Tensor):
        super().__init__()
        self.y_cached = y_cached  # keep graph-connected tensor; do NOT detach

    def forward(self, **kwargs) -> torch.Tensor:
        return self.y_cached


class _InterventionWrapper(nn.Module):
    """
    Wrap a module so that part of its output is replaced by an intervention.

    On each forward pass the original module is evaluated, the policy scores
    its output, a keep/replace mask is derived from those scores, and the
    strategy rewires the output using that mask.

    Args:
        original: Module whose output is intervened on.
        policy: Module mapping the output to per-concept scores of the same
            [B, F] layout.
        strategy: Intervention strategy providing ``query(module, mask)``.
        quantile: Quantile used to derive the per-row threshold.
        subset: Optional column indices eligible for intervention; all columns
            when ``None``.
        eps: Small constant added to the row maximum in the soft proxy.
    """

    def __init__(
        self,
        original: nn.Module,
        policy: nn.Module,
        strategy: "RewiringIntervention",
        quantile: float,
        subset: Optional[List[int]] = None,
        eps: float = 1e-12,
    ):
        super().__init__()
        self.original = original
        self.policy = policy
        self.strategy = strategy
        self.quantile = float(quantile)
        self.subset = subset
        self.eps = eps
        # Expose the callable of the wrapped module as ``forward_to_check``.
        if hasattr(original, "parametrization"):
            if hasattr(original.parametrization, "forward_to_check"):
                self.forward_to_check = original.parametrization.forward_to_check
            elif hasattr(original.parametrization, "forward"):
                self.forward_to_check = original.parametrization.forward
        else:
            self.forward_to_check = original.forward

    def _build_mask(self, policy_endogenous: torch.Tensor) -> torch.Tensor:
        """
        Build the keep/replace mask from policy scores.

        Args:
            policy_endogenous: Policy scores of shape [B, F].

        Returns:
            torch.Tensor: Mask of shape [B, F] whose forward value is 1 where
            the prediction is kept and 0 where it is replaced. Columns outside
            ``subset`` are always 1. Gradients flow through a soft proxy
            (straight-through estimator).
        """
        B, F = policy_endogenous.shape
        device = policy_endogenous.device
        dtype = policy_endogenous.dtype

        if self.subset is not None:
            sel_idx = torch.tensor(self.subset, device=device, dtype=torch.long)
        else:
            sel_idx = torch.arange(F, device=device, dtype=torch.long)

        K = sel_idx.numel()
        if K == 0:
            # Nothing selectable: keep every prediction.
            return torch.ones_like(policy_endogenous)

        sel = policy_endogenous.index_select(dim=1, index=sel_idx)  # [B, K]

        if K == 1:
            # Edge case: single selected column.
            # q < 1 => keep; q == 1 => replace.
            fill = 1.0 if self.quantile < 1.0 else 0.0
            sel_mask_hard = torch.full((B, 1), fill, device=device, dtype=dtype)
        else:
            # Standard per-row quantile via kthvalue.
            k = int(max(1, min(K, 1 + math.floor(self.quantile * (K - 1)))))
            thr, _ = torch.kthvalue(sel, k, dim=1, keepdim=True)  # [B,1]
            # Strict '>' so ties at the threshold are replaced.
            sel_mask_hard = (sel > thr).to(dtype)  # [B,K]

        scatter_idx = sel_idx.unsqueeze(0).expand(B, -1)

        hard = torch.ones((B, F), device=device, dtype=dtype)
        hard.scatter_(1, scatter_idx, sel_mask_hard)

        # STE proxy: the forward value equals the hard mask, the backward pass
        # uses the soft scores. The configured ``eps`` is used in every case.
        row_max = sel.max(dim=1, keepdim=True).values + self.eps
        soft_sel = torch.log1p(sel) / torch.log1p(row_max)
        soft_proxy = torch.ones_like(policy_endogenous)
        soft_proxy.scatter_(1, scatter_idx, soft_sel)

        return (hard - soft_proxy).detach() + soft_proxy

    def forward(self, **kwargs) -> torch.Tensor:
        """Evaluate the original module and return its intervened output."""
        y = self.original(**kwargs)
        endogenous = self.policy(y)  # [B,F], 0 = most uncertain, +inf = most certain
        mask = self._build_mask(endogenous)  # 1 keep, 0 replace

        # Proxy that returns the cached y instead of recomputing it.
        cached = _CachedOutput(y)

        # Use the existing strategy API to blend prediction and target.
        replacer = self.strategy.query(cached, mask)
        return replacer(**kwargs)


# ---------------- global policy wrapper ----------------

class _GlobalPolicyState:
    """
    Shared state for coordinating a global policy across multiple wrappers.

    The same state object is given to every wrapper created for a global
    policy. It collects the policy scores of all wrappers, computes one mask
    over the concatenation of those scores, and hands each wrapper its slice.

    Args:
        n_wrappers: Number of wrappers that must register before the mask can
            be computed.
        quantile: Fraction of all concepts to intervene on.
        eps: Small constant added to the row maximum in the soft proxy.
    """

    def __init__(self, n_wrappers: int, quantile: float, eps: float = 1e-12):
        self.n_wrappers = n_wrappers
        self.quantile = float(quantile)
        self.eps = eps
        # Scores and outputs indexed by wrapper_id
        self.endogenous_cache: Dict[int, torch.Tensor] = {}
        self.outputs_cache: Dict[int, torch.Tensor] = {}
        self.global_mask: Optional[torch.Tensor] = None
        self.batch_size: Optional[int] = None

    def reset(self):
        """Reset state for a new forward pass."""
        self.endogenous_cache.clear()
        self.outputs_cache.clear()
        self.global_mask = None
        self.batch_size = None

    def register(self, wrapper_id: int, endogenous: torch.Tensor, output: torch.Tensor):
        """
        Register the policy scores and output of one wrapper.

        Args:
            wrapper_id: Position of the wrapper in the global ordering.
            endogenous: Policy scores of that wrapper, batch dimension first.
            output: Output of the wrapped module.
        """
        # A different batch size means a new batch: drop everything cached.
        if self.batch_size is not None and endogenous.shape[0] != self.batch_size:
            self.reset()
        # New scores make any previously computed mask stale. Without this, a
        # later forward pass with the same batch size would silently reuse the
        # mask of the earlier pass.
        self.global_mask = None
        self.batch_size = endogenous.shape[0]
        self.endogenous_cache[wrapper_id] = endogenous
        self.outputs_cache[wrapper_id] = output

    def is_ready(self) -> bool:
        """Check if all wrappers have registered their endogenous."""
        return len(self.endogenous_cache) == self.n_wrappers

    def compute_global_mask(self):
        """
        Compute the global mask once all endogenous are collected.

        Does nothing if a mask is already available.

        Raises:
            RuntimeError: If not every wrapper has registered yet.
        """
        if self.global_mask is not None:
            return  # Already computed

        if not self.is_ready():
            raise RuntimeError(
                f"Cannot compute global mask: only {len(self.endogenous_cache)}/{self.n_wrappers} wrappers registered"
            )

        # Concatenate all endogenous in wrapper_id order
        all_endogenous = torch.cat([self.endogenous_cache[i] for i in range(self.n_wrappers)], dim=1)
        B, F_total = all_endogenous.shape
        device = all_endogenous.device
        dtype = all_endogenous.dtype

        if F_total == 0:
            self.global_mask = torch.ones((B, 0), device=device, dtype=dtype)
            return

        # quantile determines the fraction of concepts to intervene on
        # quantile=0 -> intervene on 0% (mask=1 for all, keep all)
        # quantile=1 -> intervene on 100% (mask=0 for all, replace all)
        num_to_intervene = int(max(0, min(F_total, math.ceil(self.quantile * F_total))))

        if num_to_intervene == 0:
            # Keep all predictions (mask=1 means keep).
            self.global_mask = torch.ones((B, F_total), device=device, dtype=dtype)
            return

        if num_to_intervene == F_total:
            # Replace all predictions (mask=0 means intervene).
            self.global_mask = torch.zeros((B, F_total), device=device, dtype=dtype)
            return

        # kthvalue(k) returns the k-th smallest value, so the threshold of the
        # top ``num_to_intervene`` scores sits at k = F_total - num_to_intervene + 1.
        k = F_total - num_to_intervene + 1
        thr, _ = torch.kthvalue(all_endogenous, k, dim=1, keepdim=True)  # [B,1]

        # Scores >= threshold are intervened on (mask=0); the rest are kept (mask=1).
        mask_hard = (all_endogenous < thr).to(dtype)  # [B, F_total]

        # STE proxy
        row_max = all_endogenous.max(dim=1, keepdim=True).values + self.eps
        soft_proxy = torch.log1p(all_endogenous) / torch.log1p(row_max)
        self.global_mask = (mask_hard - soft_proxy).detach() + soft_proxy

    def get_mask_slice(self, wrapper_id: int) -> torch.Tensor:
        """
        Get the mask slice for a specific wrapper.

        Raises:
            RuntimeError: If the global mask has not been computed yet.
        """
        if self.global_mask is None:
            raise RuntimeError("Global mask not computed yet")

        # Start/end column of this wrapper, derived from the cached output widths
        start_idx = sum(self.outputs_cache[i].shape[1] for i in range(wrapper_id))
        end_idx = start_idx + self.outputs_cache[wrapper_id].shape[1]
        return self.global_mask[:, start_idx:end_idx]


class _GlobalPolicyInterventionWrapper(nn.Module):
    """
    Intervention wrapper that uses a shared global state for coordinated masking.

    ``forward`` only evaluates the wrapped module, registers the policy scores
    with the shared state and returns the unmodified output. The intervention
    itself is applied afterwards through ``apply_intervention``, once every
    wrapper sharing the state has registered.

    Args:
        original: Module whose output is intervened on.
        policy: Module mapping the output to per-concept scores.
        strategy: Intervention strategy providing ``query(module, mask)``.
        wrapper_id: Position of this wrapper in the global ordering.
        shared_state: State object shared by all coordinated wrappers.
    """

    def __init__(
        self,
        original: nn.Module,
        policy: nn.Module,
        strategy: "RewiringIntervention",
        wrapper_id: int,
        shared_state: '_GlobalPolicyState',
    ):
        super().__init__()
        self.original = original
        self.policy = policy
        self.strategy = strategy
        self.wrapper_id = wrapper_id
        self.shared_state = shared_state
        # Expose the callable of the wrapped module as ``forward_to_check``.
        if hasattr(original, "parametrization"):
            if hasattr(original.parametrization, "forward_to_check"):
                self.forward_to_check = original.parametrization.forward_to_check
            elif hasattr(original.parametrization, "forward"):
                self.forward_to_check = original.parametrization.forward
        else:
            self.forward_to_check = original.forward

    def forward(self, **kwargs) -> torch.Tensor:
        """
        Forward pass that collects policy endogenous but does NOT apply intervention.

        Returns the original output. Intervention is applied later via
        apply_intervention().
        """
        # Get output from original module
        y = self.original(**kwargs)

        # Compute policy endogenous
        endogenous = self.policy(y)  # [B, F_i]

        # Register with shared state
        self.shared_state.register(self.wrapper_id, endogenous, y)

        # Always return original output - intervention applied later
        return y

    def apply_intervention(self, y: torch.Tensor) -> torch.Tensor:
        """
        Apply intervention to the output after all wrappers are ready.

        Args:
            y: The original output from forward()

        Returns:
            Intervened output

        Raises:
            RuntimeError: If not every wrapper has registered with the shared
                state yet.
        """
        if not self.shared_state.is_ready():
            raise RuntimeError(
                f"Cannot apply intervention: only {len(self.shared_state.endogenous_cache)}/{self.shared_state.n_wrappers} wrappers registered"
            )

        # Compute global mask if not already computed
        if self.shared_state.global_mask is None:
            self.shared_state.compute_global_mask()

        # Get mask slice for this wrapper
        mask = self.shared_state.get_mask_slice(self.wrapper_id)

        # Feed the already computed output through the strategy
        cached = _CachedOutput(y)
        replacer = self.strategy.query(cached, mask)
        result = replacer()
        return result


# ---------------- context manager (now multi-layer) ----------------
@contextlib.contextmanager
def intervention(
    *,
    policies: Union[nn.Module, Sequence[nn.Module]],
    strategies: Union[RewiringIntervention, Sequence[RewiringIntervention]],
    target_concepts: Union[str, int, Sequence[Union[str, int]]],
    quantiles: Optional[Union[float, Sequence[float]]] = 1.,
    model: Optional[nn.Module] = None,
    global_policy: bool = False,
):
    """
    Context manager for applying interventions to concept-based models.

    Enables interventions on concept modules by temporarily replacing model
    components with intervention wrappers. Supports single or multiple layers.
    The replaced submodules are put back when the context exits, including
    when the body raises.

    Args:
        policies: Policy module(s) that determine which concepts to intervene on.
        strategies: Intervention strategy/strategies (e.g., DoIntervention).
        target_concepts: Concept names/paths or indices to intervene on. A
            single string or integer is treated as a one-element list.
        quantiles: Quantile thresholds for selective intervention (default: 1.0).
        model: Optional model reference. When omitted, the model attached to
            the strategy (the first one, if several are given) is used.
        global_policy: If True, multiple policies are coordinated globally to
            create a unified mask across all layers. If False (default), each
            policy operates independently on its layer. Only applies when
            target_concepts are strings and multiple policies are provided.

    Yields:
        The intervention wrapper (if target_concepts are indices) or None.

    Raises:
        ValueError: If ``target_concepts`` is empty, if a broadcast argument
            has the wrong length, or if ``global_policy=True`` is combined
            with differing quantiles.
        AssertionError: If index-based targets are combined with sequences of
            policies/strategies/quantiles or with ``global_policy=True``.

    Example:
        >>> import torch
        >>> from torch_concepts.nn import (
        ...     DoIntervention, intervention, RandomPolicy
        ... )
        >>> from torch_concepts import Variable
        >>>
        >>> # Create a simple model
        >>> class SimplePGM(torch.nn.Module):
        ...     def __init__(self, in_features, out_features):
        ...         super().__init__()
        ...         self.encoder = torch.nn.Linear(in_features, 3)
        ...         self.predictor = torch.nn.Linear(3, out_features)
        ...     def forward(self, x):
        ...         c = torch.sigmoid(self.encoder(x))
        ...         y = self.predictor(c)
        ...         return y
        >>>
        >>> model = SimplePGM(10, 3)
        >>>
        >>> # Create intervention strategy (set concepts to 1)
        >>> strategy = DoIntervention(model, torch.FloatTensor([1.0, 0.0, 1.0]))
        >>>
        >>> # Create policy (random selection)
        >>> policy = RandomPolicy(out_features=3)
        >>>
        >>> # Apply intervention on specific concept indices
        >>> x = torch.randn(4, 10)
        >>> with intervention(
        ...     policies=policy,
        ...     strategies=strategy,
        ...     target_concepts=[0, 2],  # Intervene on concepts 0 and 2
        ...     quantiles=0.8
        ... ) as wrapper:
        ...     # Inside context, interventions are active
        ...     output = wrapper(x=x)
        >>>
        >>> print(f"Output shape: {output.shape}")
        Output shape: torch.Size([4, 3])
        >>>
        >>> # Example with global_policy=True for coordinated multi-layer intervention
        >>> # (requires multiple layers and policies)
    """
    # Normalise target_concepts to a list and compute N. A bare int is wrapped
    # as well, since the signature allows it and len() would fail on it.
    if isinstance(target_concepts, (str, int)):
        target_concepts = [target_concepts]
    if len(target_concepts) == 0:
        raise ValueError("target_concepts must not be empty")
    N = len(target_concepts)

    # Choose the reference model: an explicit ``model`` wins, otherwise fall
    # back to the model attached to the (first) strategy.
    if model is not None:
        ref_model = model
    elif isinstance(strategies, Sequence):
        ref_model = strategies[0].model
    else:
        ref_model = strategies.model

    originals: List[tuple[str, nn.Module]] = []

    try:
        if isinstance(target_concepts[0], int):
            # Index-based targets: a single wrapper around the whole model.
            assert not isinstance(policies, Sequence), "When target_concepts are indices, only a single policy is supported"
            assert not isinstance(strategies, Sequence), "When target_concepts are indices, only a single strategy is supported"
            assert not isinstance(quantiles, Sequence), "When target_concepts are indices, only a single quantile is supported"
            assert not global_policy, "global_policy not supported for index-based interventions"
            wrap = _InterventionWrapper(
                original=ref_model,
                policy=policies,
                strategy=strategies,
                quantile=quantiles,
                subset=target_concepts  # type: ignore
            )
            yield wrap
        else:
            # Broadcast/validate others
            policies = _as_list(policies, N)
            strategies = _as_list(strategies, N)
            quantiles = _as_list(quantiles, N)

            if global_policy:
                # Global policy mode: one unified mask across all layers,
                # which requires a single shared quantile.
                if not all(q == quantiles[0] for q in quantiles):
                    raise ValueError(
                        "When global_policy=True, all quantiles must be the same. "
                        f"Got: {quantiles}"
                    )

                global_quantile = quantiles[0]

                # Create shared state for coordination
                shared_state = _GlobalPolicyState(n_wrappers=N, quantile=global_quantile)

                # Create global wrappers for each layer
                for wrapper_id, (path, pol, strat) in enumerate(zip(target_concepts, policies, strategies)):
                    orig = _get_submodule(ref_model, path)
                    # Record before replacing so a failure mid-loop is still undone.
                    originals.append((path, orig))

                    wrapper = _GlobalPolicyInterventionWrapper(
                        original=orig,
                        policy=pol,
                        strategy=strat,
                        wrapper_id=wrapper_id,
                        shared_state=shared_state,
                    )
                    _set_submodule(ref_model, path, wrapper)

                # Nothing to hand to the caller: the wrappers live inside the model.
                yield
            else:
                # Independent mode (default/backward compatible): each policy creates its own mask
                for path, pol, strat, q in zip(target_concepts, policies, strategies, quantiles):
                    orig = _get_submodule(ref_model, path)
                    originals.append((path, orig))
                    wrap = _InterventionWrapper(
                        original=orig,
                        policy=pol,
                        strategy=strat,
                        quantile=q,
                    )
                    _set_submodule(ref_model, path, wrap)
                yield
    finally:
        # Restore originals in reverse (LIFO) order so that repeated or nested
        # paths end up with the module that was there first.
        for path, orig in reversed(originals):
            _set_submodule(ref_model, path, orig)