torch_concepts.nn.PositiveWeightsIntervention

class PositiveWeightsIntervention

Intervention that replaces predicted concept values with ground-truth values.

Implements the operation do(C = c_true) by mixing predicted and ground-truth concept values according to a binary mask that selects which concepts are replaced.
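The masked mixing described above can be sketched in plain PyTorch as follows. This is a minimal illustration, not the torch_concepts implementation: the helper mix_concepts and the convention that a mask value of 1 means "use the ground truth" are assumptions.

```python
import torch


def mix_concepts(c_pred: torch.Tensor, c_true: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: apply do(C = c_true) where mask == 1, keep predictions elsewhere."""
    # All three tensors are assumed to share shape (batch_size, n_concepts).
    if not (c_pred.shape == c_true.shape == mask.shape):
        raise ValueError("c_pred, c_true and mask must have the same shape")
    # torch.where takes c_true where the mask is set and c_pred otherwise.
    return torch.where(mask.bool(), c_true, c_pred)


# Example with batch_size = 2 and n_concepts = 3.
c_pred = torch.tensor([[0.2, 0.9, 0.4],
                       [0.7, 0.1, 0.6]])
c_true = torch.tensor([[1.0, 0.0, 0.0],
                       [0.0, 1.0, 1.0]])
mask = torch.tensor([[1, 0, 1],
                     [0, 0, 1]])

c_mixed = mix_concepts(c_pred, c_true, mask)
print(c_mixed)
```

The same result can be written arithmetically as mask * c_true + (1 - mask) * c_pred; torch.where avoids the multiplication and keeps unmasked predictions bit-for-bit unchanged.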

Parameters:

ground_truth – Ground-truth concept values of shape (batch_size, n_concepts).

__init__()

Initialize the intervention module.

Methods

__init__() – Initialize the intervention module.

transform(module, *args, **kwargs) – Forward method to be implemented by subclasses.
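Because transform is left to subclasses, a concrete intervention has to supply it. The sketch below shows one possible shape for such a subclass. InterventionSketch, GroundTruthSketch and _MaskedOutput are hypothetical stand-ins, not torch_concepts classes, and the sketch assumes that transform receives a concept-predicting module and returns a wrapped module whose masked outputs are overwritten with ground truth.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class InterventionSketch(ABC):
    """Hypothetical stand-in for an intervention base class."""

    @abstractmethod
    def transform(self, module: nn.Module, *args, **kwargs) -> nn.Module:
        """Subclasses return an intervened version of `module`."""


class _MaskedOutput(nn.Module):
    """Wraps a concept predictor and overwrites masked outputs with ground truth."""

    def __init__(self, module: nn.Module, ground_truth: torch.Tensor, mask: torch.Tensor):
        super().__init__()
        self.module = module
        # Buffers move with the module under .to(device) but are not trained.
        self.register_buffer("ground_truth", ground_truth)
        self.register_buffer("mask", mask.bool())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        c_pred = self.module(x)
        # Assumed convention: mask == 1 selects the ground-truth value.
        return torch.where(self.mask, self.ground_truth, c_pred)


class GroundTruthSketch(InterventionSketch):
    def transform(self, module, ground_truth, mask):
        return _MaskedOutput(module, ground_truth, mask)


# Usage: intervene on a toy predictor with 4 input features and 3 concepts.
torch.manual_seed(0)
predictor = nn.Sequential(nn.Linear(4, 3), nn.Sigmoid())
x = torch.randn(2, 4)
gt = torch.tensor([[1.0, 0.0, 1.0],
                   [0.0, 1.0, 0.0]])
mask = torch.tensor([[1, 0, 0],
                     [0, 0, 1]])

intervened = GroundTruthSketch().transform(predictor, gt, mask)
out = intervened(x)
print(out)
```

Returning a wrapper module rather than editing the predictor in place leaves the original module untouched, so the same predictor can be evaluated with and without the intervention.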