Intervention Strategies and Context Manager
This module provides inference mechanisms for intervening on concept-based models.
Summary
Intervention Strategies
RewiringIntervention | Base class for rewiring-based interventions.
GroundTruthIntervention | Intervention that replaces predicted concepts with ground truth values.
DoIntervention | Intervention that sets concepts to constant values (do-calculus).
DistributionIntervention | Intervention that samples concept values from distributions.
Intervention Context Manager
intervention | Context manager for applying interventions to concept-based models.
Class Documentation
- class RewiringIntervention(model: Module, *args, **kwargs)
Bases: BaseIntervention
Base class for rewiring-based interventions.
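Rewiring interventions replace predicted concept values with target values according to a binary mask: entries where the mask is 1 keep the model's prediction, and entries where it is 0 are replaced by the target produced by _make_target. This implements do-calculus interventions.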
- Parameters:
model – The concept-based model to intervene on.
Example
>>> import torch
>>> from torch_concepts.nn import RewiringIntervention
>>>
>>> # Subclass to create custom intervention
>>> class MyIntervention(RewiringIntervention):
...     def _make_target(self, y, *args, **kwargs):
...         return torch.ones_like(y)
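The mask-and-target mechanism described above can be sketched in plain Python. This is a toy illustration, not the torch_concepts implementation. Lists stand in for tensors, and ToyRewiring, ToyOnes and apply are hypothetical names. Only the _make_target hook comes from the example, and here it is overridden to return all ones, like MyIntervention.

```python
from typing import List

Matrix = List[List[float]]


class ToyRewiring:
    """Hypothetical stand-in for a rewiring intervention (not the library class)."""

    def _make_target(self, y: Matrix) -> Matrix:
        # Subclasses decide which values replace the predictions.
        raise NotImplementedError

    def apply(self, y: Matrix, mask: Matrix) -> Matrix:
        # mask == 1 keeps the prediction, mask == 0 takes the target value.
        target = self._make_target(y)
        return [
            [m * p + (1 - m) * t for p, t, m in zip(row_y, row_t, row_m)]
            for row_y, row_t, row_m in zip(y, target, mask)
        ]


class ToyOnes(ToyRewiring):
    # Mirrors the MyIntervention example: the target is all ones.
    def _make_target(self, y: Matrix) -> Matrix:
        return [[1.0 for _ in row] for row in y]


y_pred = [[0.2, 0.7, 0.4]]
mask = [[1, 0, 1]]
print(ToyOnes().apply(y_pred, mask))  # [[0.2, 1.0, 0.4]]
```

Because the mask is binary, the arithmetic blend m * p + (1 - m) * t simply selects either the prediction or the target for each entry.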
- query(original_module: Module, mask: Tensor, *args, **kwargs) → Module
Create an intervention wrapper module.
- Parameters:
original_module – The original module to wrap.
mask – Binary mask (1=keep prediction, 0=replace with target).
*args – Additional arguments.
**kwargs – Additional keyword arguments.
- Returns:
Wrapped module with intervention applied.
- Return type:
nn.Module
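One way to picture what query returns is a wrapper that runs the original module and then rewires its output using the mask. The sketch below assumes that structure. ToyInterventionWrapper and double are hypothetical names, and plain callables stand in for nn.Module.

```python
from typing import Callable, List

Matrix = List[List[float]]


class ToyInterventionWrapper:
    """Hypothetical wrapper: runs the original module, then rewires its output."""

    def __init__(self, original: Callable[[Matrix], Matrix], mask: Matrix,
                 make_target: Callable[[Matrix], Matrix]) -> None:
        self.original = original
        self.mask = mask
        self.make_target = make_target

    def __call__(self, x: Matrix) -> Matrix:
        y = self.original(x)          # ordinary forward pass
        target = self.make_target(y)  # replacement values
        # Keep y where the mask is 1, use the target where it is 0.
        return [
            [p if m == 1 else t for p, t, m in zip(ry, rt, rm)]
            for ry, rt, rm in zip(y, target, self.mask)
        ]


def double(x: Matrix) -> Matrix:
    # Stand-in for the original concept module.
    return [[2.0 * v for v in row] for row in x]


def zeros_like(y: Matrix) -> Matrix:
    return [[0.0] * len(row) for row in y]


wrapped = ToyInterventionWrapper(double, mask=[[0, 1]], make_target=zeros_like)
print(wrapped([[1.0, 3.0]]))  # [[0.0, 6.0]]
```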
- class GroundTruthIntervention(model: Module, ground_truth: Tensor)
Bases: RewiringIntervention
Intervention that replaces predicted concepts with ground truth values.
Implements do(C=c_true) operations by mixing predicted and ground truth concept values based on a binary mask.
- Parameters:
model – The concept-based model to intervene on.
ground_truth – Ground truth concept values of shape (batch_size, n_concepts).
Example
>>> import torch
>>> from torch_concepts.nn import GroundTruthIntervention
>>>
>>> # Create a dummy model
>>> model = torch.nn.Linear(10, 5)
>>>
>>> # Ground truth values
>>> c_true = torch.tensor([[1.0, 0.0, 1.0, 0.0, 1.0],
...                        [0.0, 1.0, 0.0, 1.0, 0.0]])
>>>
>>> # Create intervention
>>> intervention = GroundTruthIntervention(model, c_true)
>>>
>>> # Apply intervention (typically done via context manager)
>>> # See intervention() context manager for complete usage
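For do(C=c_true), the target is the ground-truth tensor itself, so it must line up with the predictions sample by sample. The following is a minimal sketch that assumes the same mask convention as query (1 keeps the prediction, 0 substitutes the ground truth). The helper mix_with_ground_truth and the prediction and mask values are hypothetical. Only c_true is taken from the example.

```python
from typing import List

Matrix = List[List[float]]


def mix_with_ground_truth(pred: Matrix, c_true: Matrix, mask: Matrix) -> Matrix:
    """Hypothetical helper: apply do(C = c_true) on the entries where mask == 0."""
    # Ground truth must match the predictions: (batch_size, n_concepts).
    if len(pred) != len(c_true) or any(len(p) != len(t) for p, t in zip(pred, c_true)):
        raise ValueError("ground_truth must have shape (batch_size, n_concepts)")
    return [
        [p if m == 1 else t for p, t, m in zip(rp, rt, rm)]
        for rp, rt, rm in zip(pred, c_true, mask)
    ]


c_true = [[1.0, 0.0, 1.0, 0.0, 1.0],
          [0.0, 1.0, 0.0, 1.0, 0.0]]
pred = [[0.9, 0.2, 0.3, 0.6, 0.8],
        [0.1, 0.4, 0.2, 0.7, 0.5]]
mask = [[1, 1, 0, 0, 1],   # replace concepts 2 and 3 in sample 0
        [0, 1, 1, 1, 1]]   # replace concept 0 in sample 1
print(mix_with_ground_truth(pred, c_true, mask))
# [[0.9, 0.2, 1.0, 0.0, 0.8], [0.0, 0.4, 0.2, 0.7, 0.5]]
```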
- class DoIntervention(model: Module, constants: Tensor | float)
Bases: RewiringIntervention
Intervention that sets concepts to constant values (do-calculus).
Implements do(C=constant) operations, supporting scalar, per-concept, or per-sample constant values with automatic broadcasting.
- Parameters:
model – The concept-based model to intervene on.
constants – Constant values: a scalar, or a tensor of shape [F], [1, F], or [B, F], where B is the batch size and F is the number of concepts.
Example
>>> import torch
>>> from torch_concepts.nn import DoIntervention
>>>
>>> # Create a dummy model
>>> model = torch.nn.Linear(10, 3)
>>>
>>> # Set all concepts to 1.0
>>> intervention_scalar = DoIntervention(model, 1.0)
>>>
>>> # Set each concept to different values
>>> intervention_vec = DoIntervention(
...     model,
...     torch.tensor([0.5, 1.0, 0.0])
... )
>>>
>>> # Set per-sample values
>>> intervention_batch = DoIntervention(
...     model,
...     torch.tensor([[0.0, 1.0, 0.5],
...                   [1.0, 0.0, 0.5]])
... )
>>>
>>> # Use via context manager - see intervention()
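The broadcasting rule in the constants description can be made concrete with a plain-Python sketch. It follows the shape conventions listed above (B = batch size, F = number of concepts). It illustrates the broadcasting idea only and is not the library's own code. broadcast_constants is a hypothetical name.

```python
from typing import List, Sequence, Union

Constants = Union[float, Sequence[float], Sequence[Sequence[float]]]


def broadcast_constants(constants: Constants, batch_size: int,
                        n_concepts: int) -> List[List[float]]:
    """Hypothetical helper: expand scalar, [F], [1, F] or [B, F] constants to [B, F]."""
    if isinstance(constants, (int, float)):
        # Scalar: the same value for every sample and every concept.
        return [[float(constants)] * n_concepts for _ in range(batch_size)]
    rows = list(constants)
    if rows and isinstance(rows[0], (int, float)):
        # [F]: one value per concept, shared across the batch.
        rows = [rows]
    if len(rows) == 1:
        # [1, F]: repeat the single row for every sample.
        rows = rows * batch_size
    if len(rows) != batch_size or any(len(r) != n_concepts for r in rows):
        raise ValueError("constants must be scalar, [F], [1, F] or [B, F]")
    # Copy into fresh float lists so repeated rows are independent.
    return [[float(v) for v in r] for r in rows]


print(broadcast_constants(1.0, 2, 3))              # [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
print(broadcast_constants([0.5, 1.0, 0.0], 2, 3))  # [[0.5, 1.0, 0.0], [0.5, 1.0, 0.0]]
```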
- class DistributionIntervention(model: Module, dist)
Bases: RewiringIntervention
Intervention that samples concept values from distributions.
Implements do(C~D) operations where concepts are sampled from specified probability distributions, enabling distributional interventions.
- Parameters:
model – The concept-based model to intervene on.
dist – A torch.distributions.Distribution or list of per-concept distributions.
Example
>>> import torch
>>> from torch_concepts.nn import DistributionIntervention
>>> from torch.distributions import Bernoulli, Normal
>>>
>>> # Create a dummy model
>>> model = torch.nn.Linear(10, 3)
>>>
>>> # Single distribution for all concepts
>>> intervention_single = DistributionIntervention(
...     model,
...     Bernoulli(torch.tensor(0.7))
... )
>>>
>>> # Per-concept distributions
>>> intervention_multi = DistributionIntervention(
...     model,
...     [Bernoulli(torch.tensor(0.3)),
...      Normal(torch.tensor(0.0), torch.tensor(1.0)),
...      Bernoulli(torch.tensor(0.8))]
... )
>>>
>>> # Use via context manager - see intervention()
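The two accepted forms of dist can be illustrated without torch: one sampler shared by every concept, or one sampler per concept. This sketch assumes an independent draw for each (sample, concept) entry, which may differ from how the library actually draws samples. bernoulli, normal and sample_targets are hypothetical helpers built on Python's random module.

```python
import random
from typing import Callable, List, Sequence, Union

Sampler = Callable[[random.Random], float]


def bernoulli(p: float) -> Sampler:
    # Returns 1.0 with probability p, otherwise 0.0.
    return lambda rng: 1.0 if rng.random() < p else 0.0


def normal(mu: float, sigma: float) -> Sampler:
    return lambda rng: rng.gauss(mu, sigma)


def sample_targets(dist: Union[Sampler, Sequence[Sampler]], batch_size: int,
                   n_concepts: int, rng: random.Random) -> List[List[float]]:
    """Hypothetical helper: draw a [B, F] target with one draw per entry."""
    if callable(dist):
        # A single distribution shared by every concept.
        per_concept = [dist] * n_concepts
    else:
        per_concept = list(dist)
        if len(per_concept) != n_concepts:
            raise ValueError("need one distribution per concept")
    return [[d(rng) for d in per_concept] for _ in range(batch_size)]


rng = random.Random(0)
targets = sample_targets([bernoulli(0.3), normal(0.0, 1.0), bernoulli(0.8)], 4, 3, rng)
print(len(targets), len(targets[0]))  # 4 3
```

The sampled targets would then play the same role as the constants or ground truth in the other strategies: they fill the entries where the mask is 0.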
Function Documentation
- intervention(*, policies: Module | Sequence[Module], strategies: RewiringIntervention | Sequence[RewiringIntervention], target_concepts: str | int | Sequence[str | int], quantiles: float | Sequence[float] | None = 1.0, model: Module = None, global_policy: bool = False)
Context manager for applying interventions to concept-based models.
Enables interventions on concept modules by temporarily replacing model components with intervention wrappers. Supports single or multiple layers.
- Parameters:
policies – Policy module(s) that determine which concepts to intervene on.
strategies – Intervention strategy/strategies (e.g., DoIntervention).
target_concepts – Concept names/paths or indices to intervene on.
quantiles – Quantile thresholds for selective intervention (default: 1.0).
model – Optional model reference (default: strategies[0].model).
global_policy – If True, multiple policies are coordinated globally to create a unified mask across all layers. If False (default), each policy operates independently on its layer. Only applies when target_concepts are strings and multiple policies are provided.
- Yields:
The intervention wrapper (if target_concepts are indices) or None.
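This documentation does not specify how quantiles turns policy scores into a mask, so the following is only one plausible reading and should be treated as an assumption. Under this reading, each policy scores its concepts, the top fraction q of scores is intervened on (mask 0), and the remaining concepts keep their predictions (mask 1), so the default q = 1.0 intervenes on every target concept. The sketch also contrasts the two global_policy settings described above: a separate threshold for each layer versus one unified threshold over all layers. select_top_fraction and layer_masks are hypothetical names.

```python
import math
from typing import Dict, List


def select_top_fraction(scores: List[float], q: float) -> List[int]:
    """Hypothetical: mask 0 (= intervene) on the ceil(q * n) highest-scoring concepts."""
    k = math.ceil(q * len(scores))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    chosen = set(order[:k])
    return [0 if i in chosen else 1 for i in range(len(scores))]


def layer_masks(scores: Dict[str, List[float]], q: float,
                global_policy: bool) -> Dict[str, List[int]]:
    if not global_policy:
        # Each layer ranks only its own concepts.
        return {name: select_top_fraction(s, q) for name, s in scores.items()}
    # Global: rank the concepts of all layers together, then split the mask back.
    names = list(scores)
    flat = [v for name in names for v in scores[name]]
    flat_mask = select_top_fraction(flat, q)
    out: Dict[str, List[int]] = {}
    start = 0
    for name in names:
        n = len(scores[name])
        out[name] = flat_mask[start:start + n]
        start += n
    return out


scores = {"layer_a": [0.9, 0.8, 0.7], "layer_b": [0.2, 0.1, 0.05]}
print(layer_masks(scores, 0.5, global_policy=False))
# {'layer_a': [0, 0, 1], 'layer_b': [0, 0, 1]}
print(layer_masks(scores, 0.5, global_policy=True))
# {'layer_a': [0, 0, 0], 'layer_b': [1, 1, 1]}
```

With a unified mask, a layer whose policy scores are uniformly low may receive no interventions at all. Per-layer masks always spend the same fraction on every layer.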
Example
>>> import torch
>>> from torch_concepts.nn import (
...     DoIntervention, intervention, RandomPolicy
... )
>>> from torch_concepts import Variable
>>>
>>> # Create a simple model
>>> class SimplePGM(torch.nn.Module):
...     def __init__(self, in_features, out_features):
...         super().__init__()
...         self.encoder = torch.nn.Linear(in_features, 3)
...         self.predictor = torch.nn.Linear(3, out_features)
...     def forward(self, x):
...         c = torch.sigmoid(self.encoder(x))
...         y = self.predictor(c)
...         return y
>>>
>>> model = SimplePGM(10, 3)
>>>
>>> # Create intervention strategy (constant targets 1, 0, 1 for the three concepts)
>>> strategy = DoIntervention(model, torch.FloatTensor([1.0, 0.0, 1.0]))
>>>
>>> # Create policy (random selection)
>>> policy = RandomPolicy(out_features=3)
>>>
>>> # Apply intervention on specific concept indices
>>> x = torch.randn(4, 10)
>>> with intervention(
...     policies=policy,
...     strategies=strategy,
...     target_concepts=[0, 2],  # Intervene on concepts 0 and 2
...     quantiles=0.8
... ) as wrapper:
...     # Inside the context, interventions are active
...     output = wrapper(x=x)
>>>
>>> print(f"Output shape: {output.shape}")
Output shape: torch.Size([4, 3])
>>>
>>> # Example with global_policy=True for coordinated multi-layer intervention
>>> # (requires multiple layers and policies)
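The "temporarily replacing model components" behavior can be pictured as a context manager that swaps an attribute for a wrapper and puts the original back on exit. The following pure-Python sketch shows that pattern and assumes the original component is restored when the with block ends, including when the block raises an error. swap_attribute, ToyModel and clamp_to_one are hypothetical names. The real intervention() operates on torch modules, and this documentation does not show its internals.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator, List


@contextmanager
def swap_attribute(obj: Any, name: str,
                   make_wrapper: Callable[[Any], Any]) -> Iterator[Any]:
    """Hypothetical sketch: replace obj.<name> with a wrapper, restore it on exit."""
    original = getattr(obj, name)
    wrapper = make_wrapper(original)
    setattr(obj, name, wrapper)
    try:
        yield wrapper
    finally:
        # Restore even if the body raises, so the model is never left modified.
        setattr(obj, name, original)


class ToyModel:
    def __init__(self) -> None:
        self.encoder = lambda x: [v * 2 for v in x]  # stand-in concept layer

    def forward(self, x: List[float]) -> float:
        return sum(self.encoder(x))


def clamp_to_one(encoder: Callable[[List[float]], List[float]]):
    # do(C = 1) on every concept produced by the wrapped encoder.
    return lambda x: [1.0 for _ in encoder(x)]


model = ToyModel()
with swap_attribute(model, "encoder", clamp_to_one):
    print(model.forward([3.0, 4.0]))  # 2.0 (concepts forced to 1)
print(model.forward([3.0, 4.0]))      # 14.0 (original encoder restored)
```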