torch_concepts.nn.compute_cace

compute_cace(model, dataloader, source_concept: str, target_concept: str, prob_high: float | Tensor = 1.0, prob_low: float | Tensor = 0.0) → Tensor

Compute the Causal Concept Effect (CaCE) of source_concept on target_concept.

Runs the interventions do(source = prob_high) and do(source = prob_low) over every batch in the dataloader and returns the mean difference between the two regimes in the model's prediction for target_concept.
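Written out as a formula, the score looks roughly as follows. The notation is introduced here for illustration: s and t are the source and target concepts, c_s is the source concept's logit, \hat{y}_t is the model's prediction for the target, and N is the number of inputs. The sketch assumes the mean is taken over all N individual inputs rather than over batch means, which this page does not specify.

```latex
\mathrm{CaCE}(s \to t) = \frac{1}{N} \sum_{i=1}^{N}
  \Big[ \hat{y}_t\big(x_i \mid do(c_s = \operatorname{logit}(p_{\text{high}}))\big)
      - \hat{y}_t\big(x_i \mid do(c_s = \operatorname{logit}(p_{\text{low}}))\big) \Big]
```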

prob_high and prob_low are given in probability space (0–1 for binary concepts) and are converted to logits internally via torch.logit.

  • Binary (default): prob_high=1, prob_low=0.

  • Categorical: pass probability vectors, e.g. prob_high=torch.tensor([0., 0., 1.]), prob_low=torch.tensor([1., 0., 0.]).
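The probability-to-logit step can be illustrated without PyTorch. The sketch below reimplements the documented behaviour of torch.logit in plain Python: log-odds log(p / (1 - p)), an optional eps that clamps inputs to [eps, 1 - eps], and elementwise application to vectors. The helper name logit is hypothetical. Note that without eps, the binary defaults 1.0 and 0.0 map to +inf and -inf; this page does not say whether compute_cace passes an eps to torch.logit.

```python
import math
from typing import List, Optional, Sequence, Union


def logit(p: Union[float, Sequence[float]],
          eps: Optional[float] = None) -> Union[float, List[float]]:
    """Elementwise log-odds log(p / (1 - p)), following torch.logit's documented rules.

    If eps is given, each value is first clamped to [eps, 1 - eps].
    Without eps: 0 -> -inf, 1 -> +inf, values outside [0, 1] -> nan.
    """
    def one(v: float) -> float:
        if eps is not None:
            v = min(max(v, eps), 1.0 - eps)
        if v < 0.0 or v > 1.0:
            return math.nan
        if v == 0.0:
            return -math.inf
        if v == 1.0:
            return math.inf
        return math.log(v / (1.0 - v))

    if isinstance(p, (int, float)):
        return one(float(p))
    # Vectors (e.g. categorical probability vectors) are converted element by element.
    return [one(float(v)) for v in p]


# Binary defaults: prob_high=1.0, prob_low=0.0
print(logit(1.0), logit(0.0))       # inf -inf
print(logit(1.0, eps=1e-6))         # finite, about 13.82
# Categorical one-hot probability vector
print(logit([0.0, 0.0, 1.0]))       # [-inf, -inf, inf]
```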

Parameters:
  • model – A high-level concept model (e.g. ConceptBottleneckModel).

  • dataloader – Iterable yielding batch dicts of the form {'inputs': {'x': Tensor}}.

  • source_concept – Name of the concept to intervene on.

  • target_concept – Name of the concept whose prediction is measured.

  • prob_high – Probability assigned to source_concept in the high regime (default 1.0).

  • prob_low – Probability assigned to source_concept in the low regime (default 0.0).
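To show how the parameters fit together, here is a dependency-free sketch of the estimator. Everything in it is hypothetical and is not the torch_concepts API: toy_model stands in for a concept model and accepts an interventions dict mapping concept names to overriding logits, batches hold plain Python floats instead of tensors, and prob_to_logit clamps with an assumed eps=1e-6 so that the defaults give finite logits. It also assumes the score is averaged over individual samples and computed as high minus low.

```python
import math
from typing import Callable, Dict, Iterable

EPS = 1e-6  # assumption: clamp so that p = 0 or 1 gives a finite logit


def prob_to_logit(p: float, eps: float = EPS) -> float:
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def toy_model(x: float, interventions: Dict[str, float]) -> Dict[str, float]:
    """Hypothetical two-concept model: c1 is predicted from x, task from c1.

    An entry in `interventions` replaces the model's own logit for that concept.
    Returns the predicted probability of every concept.
    """
    c1 = sigmoid(interventions.get("c1", 3.0 * x))
    task = sigmoid(interventions.get("task", 4.0 * c1 - 2.0))
    return {"c1": c1, "task": task}


def cace_sketch(model: Callable, dataloader: Iterable[dict], source: str,
                target: str, prob_high: float = 1.0, prob_low: float = 0.0) -> float:
    high, low = prob_to_logit(prob_high), prob_to_logit(prob_low)
    diffs = []
    for batch in dataloader:
        for x in batch["inputs"]["x"]:
            # Same input, two interventions on the source concept.
            y_high = model(x, {source: high})[target]
            y_low = model(x, {source: low})[target]
            diffs.append(y_high - y_low)
    if not diffs:
        raise ValueError("dataloader yielded no samples")
    return sum(diffs) / len(diffs)  # mean over all samples


loader = [{"inputs": {"x": [0.1, -0.5]}}, {"inputs": {"x": [2.0]}}]
print(round(cace_sketch(toy_model, loader, "c1", "task"), 4))  # 0.7616
print(cace_sketch(toy_model, loader, "task", "c1"))            # 0.0
```

In this toy model c1 fully determines task, so the score is close to tanh(1) ≈ 0.7616 for any inputs; swapping source and target gives 0 because intervening on task does not feed back into c1.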

Returns:

Scalar tensor with the CaCE score.

Example:

>>> from torch_concepts.nn import compute_cace
>>> cace = compute_cace(
...     model=cbm,
...     dataloader=test_loader,
...     source_concept="c1",
...     target_concept="task",
... )