Source code for torch_concepts.nn.modules.low.intervention.intervention

import functools
import inspect
import types
from abc import ABC, abstractmethod
from contextlib import contextmanager
from itertools import chain
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn

from torch_concepts import Annotations
from ..base.intervention import (
    BaseConceptInterventionStrategy,
    BaseModuleInterventionStrategy,
    BaseInterventionPolicy
)


class BaseInterventionModule(nn.Module, ABC):
    """
    Base class for intervention modules that wrap an original module with a
    specified intervention strategy and policy.

    On each call the wrapper runs the original module, asks the intervention
    policy for a mask over the output concepts, computes intervened
    predictions with the intervention strategy, and blends the two according
    to the mask. This allows flexible interventions on concept encoders.

    Subclasses must implement :meth:`build_context`, which returns extra
    keyword arguments forwarded to the policy and to concept-level
    strategies.
    """

    # Submodule names that belong to the wrapper itself. They are excluded
    # from the ``extra_modules`` dict handed to ``build_context`` and may not
    # be reused as names for user-supplied extra modules.
    _RESERVED_MODULE_NAMES = ("original_module", "intervention_strategy", "intervention_policy")

    def __init__(
            self,
            original_module: nn.Module,
            intervention_strategy: Union[BaseConceptInterventionStrategy, BaseModuleInterventionStrategy],
            intervention_policy: BaseInterventionPolicy,
            out_concepts_to_intervene_on: Union[List[str], List[int]] = None,
            quantile: float = 1.0,
            eps: float = 1e-12,
            build_context: Optional[Callable] = None,
            extra_modules: Optional[Dict[str, nn.Module]] = None,
            *args,
            **kwargs
    ):
        """
        Args:
            original_module: Module whose outputs are intervened on.
            intervention_strategy: Produces the intervened predictions.
            intervention_policy: Scores predictions and builds the
                intervention mask.
            out_concepts_to_intervene_on: Concept names or indices passed to
                the policy as ``sel_idx``; ``None`` means no explicit
                selection (``sel_idx`` is ``None``).
            quantile: Forwarded to ``intervention_policy.build_mask``.
            eps: Forwarded to ``intervention_policy.build_mask``.
            build_context: Optional callable stored as ``_build_context_fn``
                for use by a subclass's :meth:`build_context`.
            extra_modules: Optional ``{name: module}`` registered as
                submodules of this wrapper (so they follow device moves and
                appear in ``state_dict``).
            *args, **kwargs: Accepted and currently ignored.

        Raises:
            ValueError: If an ``extra_modules`` name collides with one of the
                wrapper's own submodule names.
        """
        super().__init__()
        self.original_module = original_module
        self.intervention_strategy = intervention_strategy
        self.intervention_policy = intervention_policy
        self.out_concepts_to_intervene_on = out_concepts_to_intervene_on
        self.quantile = quantile
        self.eps = eps
        self._build_context_fn = build_context
        if extra_modules:
            for name, module in extra_modules.items():
                # ``add_module`` would silently replace an existing submodule
                # of the same name (e.g. the wrapped module itself).
                if name in self._RESERVED_MODULE_NAMES:
                    raise ValueError(
                        f"extra_modules name {name!r} is reserved by "
                        f"{type(self).__name__}; choose a different name."
                    )
                self.add_module(name, module)
        self._patch_forward_signature()

    def _patch_forward_signature(self):
        """
        Patch ``self.forward`` at instance level so that it advertises the same
        named arguments as ``original_module.forward``, plus ``extra_tensors``
        as a keyword-only argument. IDEs and ``help()`` then show e.g.::

            forward(embeddings: Tensor, *, extra_tensors: Optional[Dict[str, Tensor]] = None)

        The patch only changes the reported signature: calls are delegated
        unchanged to this class's ``forward``. If the signature cannot be
        determined or combined, the instance is left unpatched.
        """
        try:
            orig_sig = inspect.signature(self.original_module.forward)
        except (ValueError, TypeError):
            return  # best effort: keep the generic (*args, **kwargs) signature

        params = [p for p in orig_sig.parameters.values() if p.name != 'self']
        extra_param = inspect.Parameter(
            'extra_tensors',
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=Optional[Dict[str, torch.Tensor]],
        )
        # Keyword-only parameters must precede **kwargs, so insert before it
        # when present, otherwise append.
        var_kw_idx = next(
            (i for i, p in enumerate(params) if p.kind == inspect.Parameter.VAR_KEYWORD),
            None,
        )
        if var_kw_idx is None:
            params.append(extra_param)
        else:
            params.insert(var_kw_idx, extra_param)

        # The wrapper below is bound as a method, so its signature needs a
        # leading bound parameter; inspect drops it for the bound method.
        # POSITIONAL_ONLY is always a valid first parameter kind.
        self_param = inspect.Parameter('self', inspect.Parameter.POSITIONAL_ONLY)
        try:
            new_sig = orig_sig.replace(parameters=[self_param, *params])
        except (ValueError, TypeError):
            return  # e.g. original already has an ``extra_tensors`` parameter

        # Use the concrete class's forward so subclass overrides are honored.
        forward_impl = type(self).forward

        @functools.wraps(forward_impl)
        def patched_forward(module, *args, **kwargs):
            return forward_impl(module, *args, **kwargs)

        patched_forward.__signature__ = new_sig
        # Bind as a method rather than closing over ``self``: copy.deepcopy
        # re-binds bound methods to the copy, whereas a closure would keep
        # calling the original instance.
        self.forward = types.MethodType(patched_forward, self)

    @property
    def sel_idx(self):
        """
        Indices of the concepts eligible for intervention, or ``None``.

        Integer entries are used directly. String entries are resolved via
        ``original_module.out_concepts.get_slice``.

        Returns:
            ``None`` when ``out_concepts_to_intervene_on`` is ``None``,
            otherwise a ``torch.long`` tensor of indices.

        Raises:
            ValueError: If names are given but ``original_module`` has no
                ``out_concepts`` attribute of type ``Annotations``, or if the
                entries are neither ``int`` nor ``str``.
        """
        targets = self.out_concepts_to_intervene_on
        if targets is None:
            return None

        first = targets[0]
        if isinstance(first, int):
            return torch.tensor(targets, dtype=torch.long)

        if isinstance(first, str):
            annotations = getattr(self.original_module, "out_concepts", None)
            if not isinstance(annotations, Annotations):
                raise ValueError("To use string-based concept selection, the original module must have an "
                                 "'out_concepts' attribute of type Annotations.")
            indices = annotations.get_slice(targets)
            if isinstance(indices, slice):
                # A slice with an omitted start begins at index 0.
                indices = list(range(indices.start or 0, indices.stop, indices.step or 1))
            return torch.tensor(indices, dtype=torch.long)

        raise ValueError("out_concepts_to_intervene_on must be a list of integers (indices) or strings (names)")

    @abstractmethod
    def build_context(
            self,
            original_module_inputs: Dict[str, torch.Tensor],
            original_module: nn.Module,
            original_module_predictions: torch.Tensor,
            extra_tensors: Dict[str, torch.Tensor] = None,
            extra_modules: Dict[str, nn.Module] = None,
    ) -> dict:
        """
        Return extra keyword arguments forwarded to the intervention policy
        and, for concept-level strategies, to the intervention strategy.

        Args:
            original_module_inputs: Wrapped-module inputs bound by parameter
                name (empty if binding failed).
            original_module: The wrapped module.
            original_module_predictions: The wrapped module's 2-D output.
            extra_tensors: Tensors passed via ``extra_tensors=`` at call time.
            extra_modules: Registered submodules other than the wrapper's own.
        """
        raise NotImplementedError("Subclasses must implement build_context method "
                                  "to provide extra context for policy and strategy.")

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """
        Run the wrapped module and blend its output with intervened predictions.

        Args:
            *args, **kwargs: Forwarded to ``original_module`` and to the
                policy and strategy.
            extra_tensors: Optional keyword-only dict of tensors. It is removed
                from ``kwargs`` before anything else is called and handed to
                :meth:`build_context` (an empty dict when omitted).

        Returns:
            ``original * mask + intervened * (1 - mask)``, where ``mask`` is
            the output of ``intervention_policy.build_mask`` cast to the
            predictions' dtype.

        Raises:
            AssertionError: If the wrapped module's output is not 2-D.
            ValueError: If the intervention strategy is of an unsupported type.
        """
        extra_tensors: Dict[str, torch.Tensor] = kwargs.pop('extra_tensors', None) or {}

        # Bind positional and keyword args to the wrapped module's parameter
        # names so build_context can look inputs up by name (best effort).
        try:
            sig = inspect.signature(self.original_module.forward)
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            original_module_inputs = dict(bound.arguments)
        except (TypeError, ValueError):
            original_module_inputs = {}

        original_module_predictions = self.original_module(*args, **kwargs)  # [B, F]
        assert original_module_predictions.dim() == 2, (
            f"BaseConceptInterventionStrategy expects 2-D tensors [Batch, N_concepts]. "
            f"Got shape: {original_module_predictions.shape}"
        )

        extra_modules = {
            name: module
            for name, module in self._modules.items()
            if name not in self._RESERVED_MODULE_NAMES
        }
        context = self.build_context(
            original_module_inputs,
            self.original_module,
            original_module_predictions,
            extra_tensors,
            extra_modules,
        )

        policy_scores = self.intervention_policy(original_module_predictions, *args, **kwargs, **context)
        intervention_mask = self.intervention_policy.build_mask(
            policy_scores, sel_idx=self.sel_idx, quantile=self.quantile, eps=self.eps
        ).to(dtype=original_module_predictions.dtype)

        if isinstance(self.intervention_strategy, BaseConceptInterventionStrategy):
            intervened_predictions = self.intervention_strategy(
                original_module_predictions, *args, **kwargs, **context
            )
        elif isinstance(self.intervention_strategy, BaseModuleInterventionStrategy):
            # Module-level strategies build a transformed copy of the wrapped
            # module and re-run it on the same inputs; context is not passed.
            intervened_module = self.intervention_strategy.transform(self.original_module, *args, **kwargs)
            intervened_predictions = intervened_module(*args, **kwargs)
        else:
            raise ValueError("Intervention strategy must be an instance of "
                             "BaseConceptInterventionStrategy or BaseModuleInterventionStrategy.")

        # mask == 1 keeps the original prediction, mask == 0 takes the intervened one.
        return (original_module_predictions * intervention_mask
                + intervened_predictions * (1.0 - intervention_mask))
class InterventionModule(BaseInterventionModule):
    """
    Concrete intervention wrapper whose :meth:`build_context` delegates to an
    optional user-supplied callable (returning an empty context otherwise).
    """

    def __init__(
            self,
            original_module: nn.Module,
            intervention_strategy: Union[BaseConceptInterventionStrategy, BaseModuleInterventionStrategy],
            intervention_policy: BaseInterventionPolicy,
            out_concepts_to_intervene_on: Union[List[str], List[int]] = None,
            quantile: float = 1.0,
            eps: float = 1e-12,
            build_context: Optional[Callable] = None,
            extra_modules: Optional[Dict[str, nn.Module]] = None,
            *args,
            **kwargs
    ):
        # Pass everything positionally: mixing ``build_context=...`` keywords
        # with ``*args`` makes any extra positional argument collide with
        # ``build_context`` ("got multiple values for argument").
        super().__init__(
            original_module,
            intervention_strategy,
            intervention_policy,
            out_concepts_to_intervene_on,
            quantile,
            eps,
            build_context,
            extra_modules,
            *args,
            **kwargs
        )

    def build_context(
            self,
            original_module_inputs: Dict[str, torch.Tensor],
            original_module: nn.Module,
            original_module_predictions: torch.Tensor,
            extra_tensors: Dict[str, torch.Tensor] = None,
            extra_modules: Dict[str, nn.Module] = None,
    ) -> dict:
        """
        Build extra context passed as kwargs to the policy and strategy.

        Override this method in a subclass, or supply a ``build_context``
        callable at construction time. The callable is invoked positionally,
        in this order::

            build_context(
                original_module_predictions,
                original_module,
                original_module_inputs,
                extra_tensors,
                extra_modules,
            )

        where:

        - ``original_module_predictions`` — the wrapped module's 2-D output,
          as returned by it (not detached)
        - ``original_module`` — the wrapped module
        - ``original_module_inputs`` — wrapped-module inputs bound by parameter
          name via ``inspect.signature`` (e.g. ``{"embeddings": tensor}``);
          empty if binding failed
        - ``extra_tensors`` — dict of tensors passed via ``extra_tensors=`` at
          call time (empty dict when omitted)
        - ``extra_modules`` — registered submodules other than the wrapper's
          own (e.g. ``{"task_head": task_head}``)

        The callable must return a dict. Without a callable this returns an
        empty dict.
        """
        if self._build_context_fn is None:
            return {}
        return self._build_context_fn(
            original_module_predictions,
            original_module,
            original_module_inputs,
            extra_tensors,
            extra_modules,
        )


@contextmanager
def intervention(
        original_module: nn.Module,
        intervention_strategy: Union[BaseConceptInterventionStrategy, BaseModuleInterventionStrategy],
        intervention_policy: BaseInterventionPolicy,
        out_concepts_to_intervene_on: Union[List[str], List[int]] = None,
        quantile: float = 1.0,
        eps: float = 1e-12,
        build_context: Optional[Callable] = None,
        extra_modules: Optional[Dict[str, nn.Module]] = None,
        *args,
        **kwargs
):
    """
    Context manager yielding an :class:`InterventionModule` that wraps
    ``original_module`` with the given strategy and policy.

    Arguments are the same as for :func:`intervene`. No cleanup is performed
    when the ``with`` block exits.
    """
    yield InterventionModule(
        original_module,
        intervention_strategy,
        intervention_policy,
        out_concepts_to_intervene_on,
        quantile,
        eps,
        build_context,
        extra_modules,
        *args,
        **kwargs
    )


def intervene(
        original_module: nn.Module,
        intervention_strategy: Union[BaseConceptInterventionStrategy, BaseModuleInterventionStrategy],
        intervention_policy: BaseInterventionPolicy,
        out_concepts_to_intervene_on: Union[List[str], List[int]] = None,
        quantile: float = 1.0,
        eps: float = 1e-12,
        build_context: Optional[Callable] = None,
        extra_modules: Optional[Dict[str, nn.Module]] = None,
        *args,
        **kwargs
) -> nn.Module:
    """
    Wrap a concept encoder module with an intervention strategy and policy.

    Args:
        original_module: The original module to wrap.
        intervention_strategy: The intervention strategy to apply.
        intervention_policy: The intervention policy to determine which
            concepts to intervene on.
        out_concepts_to_intervene_on: A list of concept names or indices to
            intervene on. If None, no explicit selection is passed to the
            policy (``sel_idx=None``).
        quantile: Forwarded to the policy's ``build_mask``; the fraction of
            selected concepts to intervene on (default 1.0 = all selected).
        eps: Small epsilon forwarded to the policy's ``build_mask``.
        build_context: Optional callable returning a dict of extra kwargs
            forwarded to the policy and (concept-level) strategy. It is called
            positionally as ``build_context(original_module_predictions,
            original_module, original_module_inputs, extra_tensors,
            extra_modules)``. ``original_module_predictions`` is the
            un-intervened output. ``extra_tensors`` is the dict passed via
            ``extra_tensors=`` at call time. ``extra_modules`` holds the
            registered extra modules.
        extra_modules: Optional dict of ``{name: nn.Module}`` to register
            inside the :class:`InterventionModule` (e.g. a task head needed by
            ``build_context``). Registered modules are tracked by PyTorch
            (parameters, device, state_dict).

    Returns:
        nn.Module: A new module that applies the specified intervention
        strategy and policy.
    """
    return InterventionModule(
        original_module,
        intervention_strategy,
        intervention_policy,
        out_concepts_to_intervene_on,
        quantile,
        eps,
        build_context,
        extra_modules,
        *args,
        **kwargs
    )