torch_concepts.nn.functional.cace_score

cace_score(y_pred_c0, y_pred_c1)

Compute the Average Causal Effect (ACE), also known as the Causal Concept Effect (CaCE) score.

The ACE/CaCE score measures the causal effect of a concept on a model's predictions. It is computed as the absolute difference between the expected predictions when the concept is inactive (c0) and when it is active (c1).
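Written out per class, and assuming the expectation is estimated as the mean over the batch (an assumption suggested by the input shape (batch_size, num_classes) and the output shape (num_classes,)), the score for class k can be expressed as below. Here N is the batch size, K is the number of classes, and \hat{y}^{(c_0)}, \hat{y}^{(c_1)} denote the entries of y_pred_c0 and y_pred_c1; these symbols are introduced for illustration only.

```latex
\mathrm{CaCE}_k = \left| \frac{1}{N} \sum_{i=1}^{N} \hat{y}^{(c_1)}_{i,k} \;-\; \frac{1}{N} \sum_{i=1}^{N} \hat{y}^{(c_0)}_{i,k} \right|, \qquad k = 1, \dots, K
```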

Main reference: “Explaining Classifiers with Causal Concept Effect (CaCE)”

Parameters:
  • y_pred_c0 (torch.Tensor) – Predictions of the model when the concept is inactive. Shape: (batch_size, num_classes).

  • y_pred_c1 (torch.Tensor) – Predictions of the model when the concept is active. Shape: (batch_size, num_classes).

Returns:

The ACE/CaCE score for each class. Shape: (num_classes,).

Return type:

torch.Tensor
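The sketch below illustrates the documented shapes and the computation described above. It is a hypothetical reference implementation (`cace_score_sketch` is not part of torch_concepts) and assumes that the expectation is the mean over the batch dimension and that the absolute value is taken after averaging; the shape check is also an assumption of the sketch, not documented behavior of `cace_score`.

```python
import torch


def cace_score_sketch(y_pred_c0: torch.Tensor, y_pred_c1: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the ACE/CaCE score (not the library's code).

    Assumes the expectation is the batch mean (dim 0) and that the absolute
    value is applied to the difference of the batch means.
    """
    # Sketch-only sanity check: both inputs must be (batch_size, num_classes).
    if y_pred_c0.shape != y_pred_c1.shape:
        raise ValueError("y_pred_c0 and y_pred_c1 must have the same shape")
    # Expected prediction per class under each concept state: shape (num_classes,)
    mean_c0 = y_pred_c0.mean(dim=0)
    mean_c1 = y_pred_c1.mean(dim=0)
    # Absolute per-class difference: shape (num_classes,)
    return (mean_c1 - mean_c0).abs()


# Toy predictions: batch_size = 4, num_classes = 3
y_pred_c0 = torch.tensor([[0.2, 0.5, 0.3],
                          [0.1, 0.6, 0.3],
                          [0.3, 0.4, 0.3],
                          [0.2, 0.5, 0.3]])  # concept inactive
y_pred_c1 = torch.tensor([[0.6, 0.2, 0.2],
                          [0.7, 0.1, 0.2],
                          [0.5, 0.3, 0.2],
                          [0.6, 0.2, 0.2]])  # concept active

score = cace_score_sketch(y_pred_c0, y_pred_c1)
print(score.shape)  # torch.Size([3])
print(score)        # values close to [0.4, 0.3, 0.1]
```

With torch_concepts installed, the same two tensors can be passed to `cace_score(y_pred_c0, y_pred_c1)` to obtain the library's own result of shape (num_classes,).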