Source code for torch_concepts.nn.modules.low.base.intervention

"""
Base intervention classes for concept-based models.

This module provides abstract base classes for implementing intervention strategies in concept-based models.
"""
import math
from abc import ABC, abstractmethod
from typing import Optional

import torch
import torch.nn as nn


class BaseConceptInterventionStrategy(nn.Module, ABC):
    """
    Abstract ``nn.Module`` base for intervention strategies.

    An intervention strategy defines how to intervene on layers (either on
    the parametrization or on the output). Concrete subclasses must provide
    ``forward``.
    """

    def __init__(self, *args, **kwargs):
        """Set up the underlying ``nn.Module``; extra arguments are ignored."""
        super().__init__()

    @abstractmethod
    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Apply the intervention; must be overridden by subclasses."""
        raise NotImplementedError
class BaseModuleInterventionStrategy(ABC):
    """
    Abstract base for intervention strategies that act on a module.

    An intervention strategy defines how to intervene on layers (either on
    the parametrization or on the output). Unlike an ``nn.Module``-based
    strategy, this class is a plain object whose subclasses must provide
    ``transform``.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the strategy; extra arguments are ignored."""
        super().__init__()

    @abstractmethod
    def transform(self, module: nn.Module, *args, **kwargs) -> nn.Module:
        """Transform ``module``; must be overridden by subclasses."""
        raise NotImplementedError
class BaseInterventionPolicy(nn.Module, ABC):
    """
    Abstract base class for intervention policies.

    Subclasses implement ``forward`` to compute intervention scores;
    ``build_mask`` converts a score matrix into a keep/replace mask.
    """

    def __init__(self):
        """Initialize the policy module."""
        super().__init__()

    @abstractmethod
    def forward(self, x, *args, **kwargs) -> torch.Tensor:
        """Forward method to compute the intervention scores based on input x."""
        raise NotImplementedError

    def build_mask(
        self,
        policy_scores: torch.Tensor,
        sel_idx: Optional[torch.LongTensor] = None,
        quantile: float = 1.0,
        eps: float = 1e-12,
    ) -> torch.Tensor:
        """
        Build a keep/replace mask from per-column policy scores.

        Mask entries are 1 ("keep") or 0 ("replace"). Columns that are not
        listed in ``sel_idx`` are always 1. Among the selected columns, each
        row gets its own threshold: the k-th smallest selected score, with
        ``k = 1 + floor(quantile * (K - 1))`` clamped to ``[1, K]``; scores
        strictly above the threshold are kept, the rest are replaced.

        The returned tensor has the hard mask as its forward value (up to
        floating-point rounding) while gradients flow through a soft proxy
        ``log1p(score) / log1p(row_max + eps)`` (straight-through estimator).

        Args:
            policy_scores: 2-D tensor of scores, shape ``[B, F]``.
                NOTE(review): ``log1p`` needs scores > -1; presumably scores
                are non-negative -- confirm against callers.
            sel_idx: 1-D indices of the columns eligible for replacement.
                ``None`` selects every column. Lists and non-long integer
                tensors are converted to a long tensor on the scores' device.
            quantile: Fraction used to pick the per-row threshold. With a
                single selected column, ``quantile < 1`` keeps it and any
                other value replaces it.
            eps: Small constant added to the per-row maximum before the
                ``log1p`` normalisation of the soft proxy.

        Returns:
            Tensor of shape ``[B, F]`` with the dtype and device of
            ``policy_scores``. If ``sel_idx`` is empty, an all-ones tensor.
        """
        n_rows, n_cols = policy_scores.shape
        device = policy_scores.device
        dtype = policy_scores.dtype

        if sel_idx is None:
            sel_idx = torch.arange(n_cols, dtype=torch.long, device=device)
        else:
            sel_idx = torch.as_tensor(sel_idx, dtype=torch.long, device=device)

        n_sel = sel_idx.numel()
        if n_sel == 0:
            # Nothing is eligible for replacement: keep everything.
            return torch.ones_like(policy_scores)

        sel = policy_scores.index_select(dim=1, index=sel_idx)  # [B, K]

        if n_sel == 1:
            # Edge case: a single selected column has no meaningful quantile.
            # q < 1 => keep; otherwise => replace.
            fill = torch.ones if quantile < 1.0 else torch.zeros
            sel_hard = fill((n_rows, 1), device=device, dtype=dtype)
        else:
            # Standard per-row quantile via kthvalue.
            k = int(max(1, min(n_sel, 1 + math.floor(quantile * (n_sel - 1)))))
            thr, _ = torch.kthvalue(sel, k, dim=1, keepdim=True)  # [B, 1]
            # Strict '>' so that ties at the threshold are replaced.
            sel_hard = (sel > thr).to(dtype)  # [B, K]

        scatter_idx = sel_idx.unsqueeze(0).expand(n_rows, -1)

        hard_mask = torch.ones((n_rows, n_cols), device=device, dtype=dtype)
        hard_mask.scatter_(1, scatter_idx, sel_hard)

        # Soft proxy used only for gradients; unselected columns stay at 1.
        row_max = sel.max(dim=1, keepdim=True).values + eps
        soft_sel = torch.log1p(sel) / torch.log1p(row_max)
        soft_proxy = torch.ones_like(policy_scores)
        soft_proxy.scatter_(1, scatter_idx, soft_sel)

        # Straight-through: forward value is the hard mask, backward pass
        # sees only the soft proxy.
        return (hard_mask - soft_proxy).detach() + soft_proxy