Source code for torch_concepts.nn.modules.low.intervention.strategy.positive_weights

import torch

from ...base.intervention import BaseModuleInterventionStrategy


class PositiveWeightsIntervention(BaseModuleInterventionStrategy):
    """Intervention that constrains a module's parameters to be non-negative.

    ``transform`` applies ReLU element-wise to every parameter returned by
    ``module.named_parameters()``. Each parameter entry ``w`` is overwritten
    in place with ``max(w, 0)``. The update runs under ``torch.no_grad()``,
    so autograd does not record it.

    The strategy takes no constructor arguments.
    """

    def __init__(self):
        super().__init__()

    def transform(self, module, *args, **kwargs):
        """Clamp all parameters of ``module`` to be non-negative, in place.

        Args:
            module: Object exposing ``named_parameters()`` (presumably a
                ``torch.nn.Module`` — confirm against callers). Its
                parameters are overwritten in place.
            *args: Ignored by this strategy. Presumably accepted so the
                signature matches other intervention strategies — verify
                against ``BaseModuleInterventionStrategy``.
            **kwargs: Ignored; see ``*args``.

        Returns:
            The same ``module`` object that was passed in, now holding
            non-negative parameters.
        """
        # Enter no_grad once for the whole sweep. The projection must not be
        # tracked by autograd, and in-place writes to leaf tensors that
        # require grad are rejected outside of no_grad.
        with torch.no_grad():
            for _, param in module.named_parameters():
                param.copy_(torch.relu(param))
        return module