# Source code for torch_concepts.nn.modules.low.intervention.policy.gradient
import torch
from typing import Optional
from ...base.intervention import BaseInterventionPolicy
[docs]
class GradientPolicy(BaseInterventionPolicy):
    """
    Intervention policy that ranks concepts by gradient magnitude.

    A concept's score is the absolute value of the gradient of a downstream
    output taken with respect to that concept. The larger the magnitude, the
    more the concept influences the downstream task, and the earlier it is
    prioritised for intervention.

    The gradients are not computed by this class. They must be passed in as
    ``concept_grads``, typically by attaching a ``build_context`` callable to
    the :class:`InterventionModule`. If no gradients are passed, every concept
    receives the same score of zero.

    Example::

        def build_context(module, predictions, *args, **kwargs):
            with torch.enable_grad():
                pred = predictions.detach().requires_grad_(True)
                task_out = module.task_head(pred)
                grads = torch.autograd.grad(task_out.sum(), pred)[0]
            return {"concept_grads": grads.detach()}

        intervention_module = InterventionModule(
            concept_encoder, strategy, GradientPolicy(),
            build_context=build_context,
            extra_modules={"task_head": task_head},
            quantile=0.5,
        )
    """

    def __init__(self):
        """Initialise the policy; it keeps no state beyond its base class."""
        super().__init__()

    def forward(
        self,
        concepts: torch.Tensor,
        *args,
        concept_grads: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Return one intervention score per concept from gradient magnitudes.

        Args:
            concepts: Input concepts of shape ``(batch_size, n_concepts)``.
                Only used as the template for the all-zero fallback.
            concept_grads: Gradient of a downstream output w.r.t. each
                concept, expected to have the same shape as ``concepts``.
                Supplied automatically when a ``build_context`` function is
                attached to the :class:`InterventionModule`.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            torch.Tensor: ``|concept_grads|`` when gradients are supplied,
            otherwise zeros shaped like ``concepts``.
        """
        if concept_grads is None:
            # No gradient information: give every concept the same score.
            return torch.zeros_like(concepts)

        scores = concept_grads.abs()
        return scores