torch_concepts.nn.intervention

intervention(*, policies: Module | Sequence[Module], strategies: RewiringIntervention | Sequence[RewiringIntervention], target_concepts: str | int | Sequence[str | int], quantiles: float | Sequence[float] | None = 1.0, model: Module = None, global_policy: bool = False)

Context manager for applying interventions to concept-based models.

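Enables interventions on concept modules by temporarily replacing model components with intervention wrappers; the original components are restored when the context exits. Supports a single layer or multiple layers (pass sequences of policies and strategies).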
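
As a rough illustration of the swap-and-restore pattern described above, the plain-Python sketch below installs a wrapper in place of a named component for the duration of a `with` block and restores the original on exit, even if the block raises. It is not the torch_concepts implementation: `swap_component`, `Encoder`, `ForceConcepts`, and `Model` are hypothetical names chosen for this example, and the toy components return lists instead of tensors.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator


@contextmanager
def swap_component(owner: Any, name: str,
                   make_wrapper: Callable[[Any], Any]) -> Iterator[Any]:
    """Hypothetical helper: temporarily replace `owner.<name>` with a wrapper.

    The wrapper is built from the original component, installed for the
    duration of the `with` block, and the original is put back on exit,
    even if the block raises.
    """
    original = getattr(owner, name)
    wrapper = make_wrapper(original)
    setattr(owner, name, wrapper)
    try:
        yield wrapper
    finally:
        setattr(owner, name, original)


class Encoder:
    """Toy stand-in for a concept encoder: returns fixed concept values."""
    def __call__(self, x):
        return [0.2, 0.7, 0.4]


class ForceConcepts:
    """Toy wrapper: overwrites selected concept indices with given values."""
    def __init__(self, inner, indices, values):
        self.inner, self.indices, self.values = inner, indices, values

    def __call__(self, x):
        c = list(self.inner(x))
        for i, v in zip(self.indices, self.values):
            c[i] = v
        return c


class Model:
    def __init__(self):
        self.encoder = Encoder()


model = Model()
with swap_component(model, "encoder",
                    lambda enc: ForceConcepts(enc, [0, 2], [1.0, 1.0])):
    inside = model.encoder(None)   # [1.0, 0.7, 1.0]: wrapper is active
outside = model.encoder(None)      # [0.2, 0.7, 0.4]: original restored
```

The same `getattr`/`setattr` approach should also work on a `torch.nn.Module`, provided the wrapper is itself an `nn.Module`: PyTorch only accepts a module (or `None`) when reassigning an attribute that holds a registered child module.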

Parameters:
  • policies – Policy module(s) that determine which concepts to intervene on.

  • strategies – Intervention strategy or strategies (e.g., DoIntervention) that define how the selected concepts are modified.

  • target_concepts – Concept names/paths or indices to intervene on.

  • quantiles – Quantile thresholds for selective intervention (default: 1.0).

  • model – Optional model reference (default: strategies[0].model).

  • global_policy – If True, multiple policies are coordinated globally to create a unified mask across all layers. If False (default), each policy operates independently on its layer. Only applies when target_concepts are strings and multiple policies are provided.
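
The parameters above do not define exactly how a policy, a quantile, and global_policy interact, so the following is only a sketch under stated assumptions, not the library's rule. It assumes that a policy assigns a score to each targeted concept, that a quantile q selects the ceil(q * n) highest-scoring concepts (so the default q = 1.0 selects every targeted concept), and that a do-style strategy overwrites the selected concepts with fixed values while leaving the rest untouched. `select_mask`, `do_intervene`, and `layer_masks` are hypothetical helpers; `layer_masks` contrasts ranking scores separately per layer (global_policy=False) with ranking them once across all layers (global_policy=True).

```python
import math
from typing import Dict, List, Sequence


def select_mask(scores: Sequence[float], q: float) -> List[bool]:
    """ASSUMED rule (not the library's): keep the ceil(q * n) highest scores."""
    k = math.ceil(q * len(scores))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    chosen = set(ranked[:k])
    return [i in chosen for i in range(len(scores))]


def do_intervene(c: Sequence[float], values: Sequence[float],
                 mask: Sequence[bool]) -> List[float]:
    """Do-style intervention: masked concepts take the fixed value, others are kept."""
    return [v if m else ci for ci, v, m in zip(c, values, mask)]


def layer_masks(scores: Dict[str, List[float]], q: float,
                global_policy: bool) -> Dict[str, List[bool]]:
    """Independent per-layer selection vs. one selection over all layers pooled."""
    if not global_policy:
        return {name: select_mask(s, q) for name, s in scores.items()}
    # Pool every (layer, index, score) triple, rank once, then split per layer.
    flat = [(name, i, s) for name, ss in scores.items() for i, s in enumerate(ss)]
    pooled = select_mask([s for _, _, s in flat], q)
    masks = {name: [False] * len(ss) for name, ss in scores.items()}
    for (name, i, _), m in zip(flat, pooled):
        masks[name][i] = m
    return masks


# Single layer: scores favor concepts 0 and 2; q = 0.5 keeps ceil(1.5) = 2 of 3.
mask = select_mask([0.9, 0.1, 0.5], q=0.5)                    # [True, False, True]
new_c = do_intervene([0.2, 0.7, 0.4], [1.0, 0.0, 1.0], mask)  # [1.0, 0.7, 1.0]

# Two layers: per-layer picks the top concept in each; global picks the top two overall.
scores = {"layer_a": [0.9, 0.8], "layer_b": [0.3, 0.1]}
local = layer_masks(scores, q=0.5, global_policy=False)  # a: [T, F], b: [T, F]
joint = layer_masks(scores, q=0.5, global_policy=True)   # a: [T, T], b: [F, F]
```

Under these assumptions, a global ranking can concentrate all interventions in one layer when its scores dominate, whereas independent per-layer selection always spends part of the budget in every layer.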

Yields:

The intervention wrapper when target_concepts are indices; None when they are concept names/paths.

Example

>>> import torch
>>> from torch_concepts.nn import (
...     DoIntervention, intervention, RandomPolicy
... )
>>> from torch_concepts import Variable
>>>
>>> # Create a simple model
>>> class SimplePGM(torch.nn.Module):
...     def __init__(self, in_features, out_features):
...         super().__init__()
...         self.encoder = torch.nn.Linear(in_features, 3)
...         self.predictor = torch.nn.Linear(3, out_features)
...     def forward(self, x):
...         c = torch.sigmoid(self.encoder(x))
...         y = self.predictor(c)
...         return y
>>>
>>> model = SimplePGM(10, 3)
>>>
>>> # Create intervention strategy (intervened concepts take these values)
>>> strategy = DoIntervention(model, torch.tensor([1.0, 0.0, 1.0]))
>>>
>>> # Create policy (random selection)
>>> policy = RandomPolicy(out_features=3)
>>>
>>> # Apply intervention on specific concept indices
>>> x = torch.randn(4, 10)
>>> with intervention(
...     policies=policy,
...     strategies=strategy,
...     target_concepts=[0, 2],  # Intervene on concepts 0 and 2
...     quantiles=0.8
... ) as wrapper:
...     # Inside context, interventions are active
...     output = wrapper(x=x)
>>>
>>> print(f"Output shape: {output.shape}")
Output shape: torch.Size([4, 3])
>>>
>>> # Example with global_policy=True for coordinated multi-layer intervention
>>> # (requires multiple layers and policies)