torch_concepts.nn.functional.cace_score
- cace_score(y_pred_c0, y_pred_c1)
Compute the Average Causal Effect (ACE), also known as the Causal Concept Effect (CaCE) score.
The ACE/CaCE score measures the causal effect of a concept on a model's predictions. For each class, it is computed as the absolute difference between the expected predictions when the concept is inactive (c0) and when it is active (c1), where the expectation is taken over the batch dimension.
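Written out, and assuming the expectation is estimated by the mean over the B = batch_size samples, the score for class k is the following, where y^{(c_0)}_{i,k} and y^{(c_1)}_{i,k} denote entry (i, k) of y_pred_c0 and y_pred_c1:

```latex
\mathrm{CaCE}_k = \left| \frac{1}{B} \sum_{i=1}^{B} y^{(c_1)}_{i,k} - \frac{1}{B} \sum_{i=1}^{B} y^{(c_0)}_{i,k} \right|,
\qquad k = 1, \dots, \text{num\_classes}
```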
Main reference: “Explaining Classifiers with Causal Concept Effect (CaCE)”
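The computation described above can be sketched in PyTorch as follows. This is a minimal illustration, not the library's source: `cace_score_sketch` is a hypothetical name, and the sketch assumes the expectation is the batch mean and that the absolute value is taken per class after averaging, as the description states. The actual `cace_score` may differ in details such as input validation.

```python
import torch


def cace_score_sketch(y_pred_c0: torch.Tensor, y_pred_c1: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the ACE/CaCE score.

    Assumption: the expectation is estimated by the mean over the batch
    dimension, and the score is the absolute per-class difference of those means.
    """
    if y_pred_c0.shape != y_pred_c1.shape:
        raise ValueError("y_pred_c0 and y_pred_c1 must have the same shape")
    # Expected prediction per class with the concept inactive / active: shape (num_classes,)
    mean_c0 = y_pred_c0.mean(dim=0)
    mean_c1 = y_pred_c1.mean(dim=0)
    # Absolute difference between the two expectations, one value per class
    return (mean_c1 - mean_c0).abs()


# Usage: a batch of 4 samples and 3 classes, predicted with the concept off and on
y_pred_c0 = torch.tensor([[0.2, 0.5, 0.3],
                          [0.1, 0.6, 0.3],
                          [0.3, 0.4, 0.3],
                          [0.2, 0.5, 0.3]])
y_pred_c1 = torch.tensor([[0.6, 0.2, 0.2],
                          [0.5, 0.3, 0.2],
                          [0.7, 0.1, 0.2],
                          [0.6, 0.2, 0.2]])
scores = cace_score_sketch(y_pred_c0, y_pred_c1)
print(scores)  # approximately [0.4, 0.3, 0.1], one score per class
```

Because the absolute value is applied after averaging, opposite-signed per-sample effects can cancel out; taking the mean of per-sample absolute differences instead would measure a different quantity.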
- Parameters:
y_pred_c0 (torch.Tensor) – Predictions of the model when the concept is inactive. Shape: (batch_size, num_classes).
y_pred_c1 (torch.Tensor) – Predictions of the model when the concept is active. Shape: (batch_size, num_classes).
- Returns:
The ACE/CaCE score for each class. Shape: (num_classes,).
- Return type: