torch_concepts.nn.intervention
- intervention(*, policies: Module | Sequence[Module], strategies: RewiringIntervention | Sequence[RewiringIntervention], target_concepts: str | int | Sequence[str | int], quantiles: float | Sequence[float] | None = 1.0, model: Module = None, global_policy: bool = False)
Context manager for applying interventions to concept-based models.
Enables interventions on concept modules by temporarily replacing the targeted model components with intervention wrappers for the duration of the context. Works on a single layer or on several layers at once.
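Swapping a component for a wrapper and putting the original back afterwards is a common PyTorch pattern. The following is a minimal, library-independent sketch of that pattern only; `MaskedOverride`, `swap_submodule`, and `TinyCBM` are hypothetical names, not part of torch_concepts, and the library's own wrappers may behave differently.

```python
import contextlib

import torch


class MaskedOverride(torch.nn.Module):
    """Hypothetical wrapper: replaces selected outputs of `inner` with fixed values."""

    def __init__(self, inner: torch.nn.Module, mask: torch.Tensor, value: torch.Tensor):
        super().__init__()
        self.inner = inner
        self.register_buffer("mask", mask)    # bool, one entry per concept
        self.register_buffer("value", value)  # replacement value per concept

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.inner(x)
        # Broadcast over the batch: masked concepts take `value`, others pass through.
        return torch.where(self.mask, self.value, out)


@contextlib.contextmanager
def swap_submodule(model: torch.nn.Module, name: str, wrapper: torch.nn.Module):
    """Temporarily set `model.<name>` to `wrapper`; restore the original on exit."""
    original = getattr(model, name)
    setattr(model, name, wrapper)
    try:
        yield wrapper
    finally:
        # Runs even if the body raises, so the model is never left patched.
        setattr(model, name, original)


class TinyCBM(torch.nn.Module):
    """Hypothetical toy model: inputs -> 3 concept probabilities -> 2 task outputs."""

    def __init__(self):
        super().__init__()
        self.concepts = torch.nn.Sequential(torch.nn.Linear(10, 3), torch.nn.Sigmoid())
        self.predictor = torch.nn.Linear(3, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.predictor(self.concepts(x))


torch.manual_seed(0)
model = TinyCBM()
x = torch.randn(4, 10)
mask = torch.tensor([True, False, True])
value = torch.tensor([1.0, 0.0, 1.0])

with swap_submodule(model, "concepts", MaskedOverride(model.concepts, mask, value)):
    c_int = model.concepts(x)  # concepts 0 and 2 forced to 1.0
    y_int = model(x)           # task head sees the intervened concepts

c_orig = model.concepts(x)     # original Sequential is back in place
```

The `try`/`finally` is the design point: the original component is restored whether the body finishes normally or raises.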
- Parameters:
policies – Policy module(s) that determine which concepts to intervene on.
strategies – Intervention strategy or strategies (e.g., DoIntervention) that determine how the selected concepts are modified.
target_concepts – Concept names/paths or indices to intervene on.
quantiles – Quantile thresholds for selective intervention (default: 1.0).
model – Optional reference to the model being intervened on (default: strategies[0].model).
global_policy – If True, multiple policies are coordinated globally to create a unified mask across all layers. If False (default), each policy operates independently on its layer. Only applies when target_concepts are strings and multiple policies are provided.
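The interaction between `quantiles` and `global_policy` can be pictured with a small sketch. It assumes, for illustration only, that each policy emits one score per concept and that a quantile q keeps the ceil(q * n) highest-scoring concepts per sample (so q = 1.0 selects every target). The actual selection rule in torch_concepts may differ; `topq_mask`, `local_masks`, and `global_masks` are hypothetical helpers.

```python
import math

import torch


def topq_mask(scores: torch.Tensor, q: float) -> torch.Tensor:
    """Hypothetical rule: per row, mark the ceil(q * n) highest scores as True."""
    n = scores.shape[-1]
    k = min(n, math.ceil(q * n))
    mask = torch.zeros_like(scores, dtype=torch.bool)
    if k > 0:
        idx = scores.topk(k, dim=-1).indices
        mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))
    return mask


def local_masks(scores_per_layer, q):
    """global_policy=False analogue: each layer is ranked on its own."""
    return [topq_mask(s, q) for s in scores_per_layer]


def global_masks(scores_per_layer, q):
    """global_policy=True analogue: rank all layers' concepts jointly, then split."""
    joint = topq_mask(torch.cat(scores_per_layer, dim=-1), q)
    sizes = [s.shape[-1] for s in scores_per_layer]
    return list(torch.split(joint, sizes, dim=-1))


# Scores for one sample from two hypothetical layers (2 concepts each).
scores = [torch.tensor([[0.9, 0.8]]), torch.tensor([[0.3, 0.2]])]
local = local_masks(scores, q=0.5)         # one concept per layer
coordinated = global_masks(scores, q=0.5)  # two concepts overall, both from layer 0
print([m.tolist() for m in local])         # [[[True, False]], [[True, False]]]
print([m.tolist() for m in coordinated])   # [[[True, True]], [[False, False]]]
```

Under this assumed rule, independent policies give every layer its own share of interventions, while the coordinated mask can spend the whole budget on the layer with the highest scores.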
- Yields:
The intervention wrapper (if target_concepts are indices) or None.
Example
>>> import torch
>>> from torch_concepts.nn import (
...     DoIntervention, intervention, RandomPolicy
... )
>>> from torch_concepts import Variable
>>>
>>> # Create a simple model
>>> class SimplePGM(torch.nn.Module):
...     def __init__(self, in_features, out_features):
...         super().__init__()
...         self.encoder = torch.nn.Linear(in_features, 3)
...         self.predictor = torch.nn.Linear(3, out_features)
...     def forward(self, x):
...         c = torch.sigmoid(self.encoder(x))
...         y = self.predictor(c)
...         return y
>>>
>>> model = SimplePGM(10, 3)
>>>
>>> # Create intervention strategy (target concept values [1, 0, 1])
>>> strategy = DoIntervention(model, torch.FloatTensor([1.0, 0.0, 1.0]))
>>>
>>> # Create policy (random selection)
>>> policy = RandomPolicy(out_features=3)
>>>
>>> # Apply intervention on specific concept indices
>>> x = torch.randn(4, 10)
>>> with intervention(
...     policies=policy,
...     strategies=strategy,
...     target_concepts=[0, 2],  # Intervene on concepts 0 and 2
...     quantiles=0.8
... ) as wrapper:
...     # Inside the context, interventions are active
...     output = wrapper(x=x)
>>>
>>> print(f"Output shape: {output.shape}")
Output shape: torch.Size([4, 3])
>>>
>>> # Example with global_policy=True for coordinated multi-layer intervention
>>> # (requires multiple layers and policies)