Intervention Strategies and Context Manager

This module provides inference mechanisms for intervening on concept-based models.

Summary

Intervention Strategies

  • RewiringIntervention – Base class for rewiring-based interventions.

  • GroundTruthIntervention – Intervention that replaces predicted concepts with ground truth values.

  • DoIntervention – Intervention that sets concepts to constant values (do-calculus).

  • DistributionIntervention – Intervention that samples concept values from distributions.

Intervention Context Manager

  • intervention – Context manager for applying interventions to concept-based models.

Class Documentation

class RewiringIntervention(model: Module, *args, **kwargs)

Bases: BaseIntervention

Base class for rewiring-based interventions.

Rewiring interventions replace predicted concept values with target values wherever a binary mask is 0 (a mask value of 1 keeps the prediction), implementing do-calculus operations.

Parameters:

model – The concept-based model to intervene on.

Example

>>> import torch
>>> from torch_concepts.nn import RewiringIntervention
>>>
>>> # Subclass and override _make_target to define a custom intervention
>>> class MyIntervention(RewiringIntervention):
...     def _make_target(self, y, *args, **kwargs):
...         return torch.ones_like(y)

query(original_module: Module, mask: Tensor, *args, **kwargs) -> Module

Create an intervention wrapper module.

Parameters:
  • original_module – The original module to wrap.

  • mask – Binary mask (1=keep prediction, 0=replace with target).

  • *args – Additional arguments.

  • **kwargs – Additional keyword arguments.

Returns:

Wrapped module with intervention applied.

Return type:

nn.Module
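
The wrapping behaviour described for query() can be illustrated with a minimal, dependency-free sketch. It assumes the mix is elementwise and uses plain Python lists in place of tensors; make_wrapper, predict, and the make_target argument are hypothetical names for illustration, not part of the torch_concepts API.

```python
def make_wrapper(original, mask, make_target):
    """Return a callable that mixes original outputs with target values.

    Hypothetical helper: `original` maps an input to a list of concept
    predictions, `mask` is a list of 0/1 flags, and `make_target` builds
    the replacement values from the predictions.
    """
    def wrapped(x):
        y = original(x)
        target = make_target(y)
        # Documented mask convention: 1 keeps the prediction,
        # 0 substitutes the target value.
        return [m * p + (1 - m) * t for m, p, t in zip(mask, y, target)]
    return wrapped


def predict(x):
    """Toy stand-in for a module: fixed concept predictions."""
    return [0.2, 0.7, 0.4]


# Replace only the second concept with a target of 1.0.
wrapper = make_wrapper(predict, mask=[1, 0, 1],
                       make_target=lambda y: [1.0] * len(y))
result = wrapper(None)
```

Here only the position whose mask entry is 0 changes; the other predictions pass through untouched.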


training: bool

class GroundTruthIntervention(model: Module, ground_truth: Tensor)

Bases: RewiringIntervention

Intervention that replaces predicted concepts with ground truth values.

Implements do(C=c_true) operations by mixing predicted and ground truth concept values based on a binary mask.

Parameters:
  • model – The concept-based model to intervene on.

  • ground_truth – Ground truth concept values of shape (batch_size, n_concepts).

Example

>>> import torch
>>> from torch_concepts.nn import GroundTruthIntervention
>>>
>>> # Create a dummy model
>>> model = torch.nn.Linear(10, 5)
>>>
>>> # Ground truth values
>>> c_true = torch.tensor([[1.0, 0.0, 1.0, 0.0, 1.0],
...                        [0.0, 1.0, 0.0, 1.0, 0.0]])
>>>
>>> # Create intervention
>>> intervention = GroundTruthIntervention(model, c_true)
>>>
>>> # Apply intervention (typically done via context manager)
>>> # See intervention() context manager for complete usage

training: bool

class DoIntervention(model: Module, constants: Tensor | float)

Bases: RewiringIntervention

Intervention that sets concepts to constant values (do-calculus).

Implements do(C=constant) operations, supporting scalar, per-concept, or per-sample constant values with automatic broadcasting.

Parameters:
  • model – The concept-based model to intervene on.

  • constants – Constant values: a scalar, or a tensor of shape [F], [1,F], or [B,F], where B is the batch size and F the number of concepts.
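
How the accepted shapes might expand to one value per sample and concept can be sketched without any dependencies. This is an assumption-laden illustration using nested lists, not the library's implementation; broadcast_constants is a hypothetical helper.

```python
def broadcast_constants(constants, batch_size, n_concepts):
    """Expand constants to a [batch_size][n_concepts] nested list.

    Accepts a scalar, a flat list of length F, a [1, F] nested list, or a
    full [B, F] nested list (hypothetical helper for illustration).
    """
    if isinstance(constants, (int, float)):
        rows = [[constants] * n_concepts]        # scalar -> [1, F]
    elif not isinstance(constants[0], (list, tuple)):
        rows = [list(constants)]                 # [F] -> [1, F]
    else:
        rows = [list(row) for row in constants]  # already two-dimensional

    if len(rows) == 1:
        rows = rows * batch_size                 # [1, F] -> [B, F]

    if len(rows) != batch_size or any(len(r) != n_concepts for r in rows):
        raise ValueError("constants cannot be broadcast to [B, F]")
    return [[float(v) for v in row] for row in rows]


scalar = broadcast_constants(1.0, batch_size=2, n_concepts=3)
per_concept = broadcast_constants([0.5, 1.0, 0.0], batch_size=2, n_concepts=3)
per_sample = broadcast_constants([[0.0, 1.0, 0.5], [1.0, 0.0, 0.5]],
                                 batch_size=2, n_concepts=3)
```

In every case the result has one row per sample, so it can be mixed with predictions position by position.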

Example

>>> import torch
>>> from torch_concepts.nn import DoIntervention
>>>
>>> # Create a dummy model
>>> model = torch.nn.Linear(10, 3)
>>>
>>> # Set all concepts to 1.0
>>> intervention_scalar = DoIntervention(model, 1.0)
>>>
>>> # Set each concept to different values
>>> intervention_vec = DoIntervention(
...     model,
...     torch.tensor([0.5, 1.0, 0.0])
... )
>>>
>>> # Set per-sample values
>>> intervention_batch = DoIntervention(
...     model,
...     torch.tensor([[0.0, 1.0, 0.5],
...                   [1.0, 0.0, 0.5]])
... )
>>>
>>> # Use via context manager - see intervention()

training: bool

class DistributionIntervention(model: Module, dist)

Bases: RewiringIntervention

Intervention that samples concept values from distributions.

Implements do(C~D) operations, in which intervened concept values are sampled from the specified probability distributions.

Parameters:
  • model – The concept-based model to intervene on.

  • dist – A torch.distributions.Distribution or list of per-concept distributions.
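
The difference between a single shared distribution and a list of per-concept distributions can be sketched with the standard library alone. The samplers below are hypothetical stand-ins for torch.distributions objects, and sample_targets is an illustrative helper, not a library function.

```python
import random


def bernoulli(p):
    """Return a sampler that draws 1.0 with probability p, else 0.0."""
    return lambda rng: 1.0 if rng.random() < p else 0.0


def normal(mean, std):
    """Return a sampler that draws from a Gaussian."""
    return lambda rng: rng.gauss(mean, std)


def sample_targets(dist, batch_size, n_concepts, rng):
    """Draw a [batch_size][n_concepts] nested list of target values.

    `dist` is either one sampler shared by all concepts or a list with
    one sampler per concept.
    """
    samplers = list(dist) if isinstance(dist, (list, tuple)) else [dist] * n_concepts
    if len(samplers) != n_concepts:
        raise ValueError("need exactly one sampler per concept")
    return [[sampler(rng) for sampler in samplers] for _ in range(batch_size)]


rng = random.Random(0)  # seeded so repeated runs draw the same values

# One distribution shared by all three concepts.
shared = sample_targets(bernoulli(0.7), batch_size=4, n_concepts=3, rng=rng)

# A different distribution for each concept.
mixed = sample_targets(
    [bernoulli(0.3), normal(0.0, 1.0), bernoulli(0.8)],
    batch_size=4, n_concepts=3, rng=rng,
)
```

The sampled values then play the role of the target in the mask-based replacement.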

Example

>>> import torch
>>> from torch_concepts.nn import DistributionIntervention
>>> from torch.distributions import Bernoulli, Normal
>>>
>>> # Create a dummy model
>>> model = torch.nn.Linear(10, 3)
>>>
>>> # Single distribution for all concepts
>>> intervention_single = DistributionIntervention(
...     model,
...     Bernoulli(torch.tensor(0.7))
... )
>>>
>>> # Per-concept distributions
>>> intervention_multi = DistributionIntervention(
...     model,
...     [Bernoulli(torch.tensor(0.3)),
...      Normal(torch.tensor(0.0), torch.tensor(1.0)),
...      Bernoulli(torch.tensor(0.8))]
... )
>>>
>>> # Use via context manager - see intervention()

training: bool

Function Documentation

intervention(*, policies: Module | Sequence[Module], strategies: RewiringIntervention | Sequence[RewiringIntervention], target_concepts: str | int | Sequence[str | int], quantiles: float | Sequence[float] | None = 1.0, model: Module = None, global_policy: bool = False)

Context manager for applying interventions to concept-based models.

Enables interventions on concept modules by temporarily replacing model components with intervention wrappers. Supports single or multiple layers.

Parameters:
  • policies – Policy module(s) that determine which concepts to intervene on.

  • strategies – Intervention strategy/strategies (e.g., DoIntervention).

  • target_concepts – Concept names/paths or indices to intervene on.

  • quantiles – Quantile thresholds for selective intervention (default: 1.0).

  • model – Optional model reference (default: strategies[0].model).

  • global_policy – If True, multiple policies are coordinated globally to create a unified mask across all layers. If False (default), each policy operates independently on its layer. Only applies when target_concepts are strings and multiple policies are provided.

Yields:

The intervention wrapper (if target_concepts are indices) or None.
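
The idea of temporarily replacing a model component and restoring it afterwards can be sketched as a generic pattern. This is a minimal illustration under the assumption that the swap is a plain attribute replacement; swap_attribute and ToyModel are hypothetical names and say nothing about how intervention() is actually implemented.

```python
from contextlib import contextmanager


@contextmanager
def swap_attribute(obj, name, replacement):
    """Temporarily replace obj.<name>, restoring the original on exit."""
    original = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield replacement
    finally:
        # Runs even if the body raises, so the model is never left patched.
        setattr(obj, name, original)


def predicted_concepts(x):
    """Toy encoder: fixed concept predictions."""
    return [0.2, 0.7, 0.4]


def intervened_concepts(x):
    """Toy replacement: every concept forced to 1.0."""
    return [1.0, 1.0, 1.0]


class ToyModel:
    def __init__(self):
        self.encoder = predicted_concepts

    def forward(self, x):
        # The "task prediction" is simply the sum of the concepts.
        return sum(self.encoder(x))


model = ToyModel()
before = model.forward(None)
with swap_attribute(model, "encoder", intervened_concepts):
    during = model.forward(None)   # uses the replacement encoder
after = model.forward(None)        # original encoder is back
```

Restoring in a finally clause is the key design choice: the intervention stays scoped to the with block even when the body fails.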

Example

>>> import torch
>>> from torch_concepts.nn import (
...     DoIntervention, intervention, RandomPolicy
... )
>>>
>>> # Create a simple model
>>> class SimplePGM(torch.nn.Module):
...     def __init__(self, in_features, out_features):
...         super().__init__()
...         self.encoder = torch.nn.Linear(in_features, 3)
...         self.predictor = torch.nn.Linear(3, out_features)
...     def forward(self, x):
...         c = torch.sigmoid(self.encoder(x))
...         y = self.predictor(c)
...         return y
>>>
>>> model = SimplePGM(10, 3)
>>>
>>> # Create intervention strategy (set the three concepts to 1, 0, and 1)
>>> strategy = DoIntervention(model, torch.tensor([1.0, 0.0, 1.0]))
>>>
>>> # Create policy (random selection)
>>> policy = RandomPolicy(out_features=3)
>>>
>>> # Apply intervention on specific concept indices
>>> x = torch.randn(4, 10)
>>> with intervention(
...     policies=policy,
...     strategies=strategy,
...     target_concepts=[0, 2],  # Intervene on concepts 0 and 2
...     quantiles=0.8
... ) as wrapper:
...     # Inside context, interventions are active
...     output = wrapper(x=x)
>>>
>>> print(f"Output shape: {output.shape}")
Output shape: torch.Size([4, 3])
>>>
>>> # Example with global_policy=True for coordinated multi-layer intervention
>>> # (requires multiple layers and policies)