torch_concepts.nn.compute_cace
- compute_cace(model, dataloader, source_concept: str, target_concept: str, prob_high: float | Tensor = 1.0, prob_low: float | Tensor = 0.0) → Tensor
Compute the Causal Concept Effect (CaCE) of source_concept on target_concept.
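Written out, the score is a mean difference between two interventional predictions. The notation here is ours, not the library's: $x_i$ is the $i$-th of $N$ inputs yielded by the dataloader, $c_s$ is source_concept, $\hat{y}_t$ is the model's prediction for target_concept, $\mathrm{do}(c_s = p)$ means overriding the source concept with probability $p$ (converted to a logit), and $p_{\mathrm{high}}, p_{\mathrm{low}}$ are prob_high and prob_low. Averaging over individual samples rather than over batches is an assumption.

```latex
\mathrm{CaCE}(c_s \to c_t) = \frac{1}{N} \sum_{i=1}^{N} \Big[ \hat{y}_t\big(x_i \mid \mathrm{do}(c_s = p_{\mathrm{high}})\big) - \hat{y}_t\big(x_i \mid \mathrm{do}(c_s = p_{\mathrm{low}})\big) \Big]
```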
Runs do(source = prob_high) vs. do(source = prob_low) over the dataloader and returns the mean difference between the two regimes in the model's prediction for target_concept.
Values are in probability space (0–1 for binary concepts). They are converted to logits internally via torch.logit.
- Binary (default): prob_high=1, prob_low=0.
- Categorical: pass probability vectors, e.g. prob_high=tensor([0, 0, 1]), prob_low=tensor([1, 0, 0]).
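The two-intervention procedure can be sketched in plain Python to make each step explicit. This is a minimal, hypothetical illustration, not the library's implementation: toy_predict_target stands in for a concept model, plain lists stand in for tensors in the documented {'inputs': {'x': ...}} batches, and the mean is assumed to be taken over individual samples. The clamp inside logit imitates the optional eps argument of torch.logit; whether compute_cace clamps at all is an assumption.

```python
import math
from typing import Callable, Dict, Iterable, List

# Batches shaped like the documented {'inputs': {'x': ...}} dicts,
# with plain lists standing in for tensors (one feature list per sample)
Batch = Dict[str, Dict[str, List[List[float]]]]


def logit(p: float, eps: float = 1e-6) -> float:
    # Probability -> logit, clamped to [eps, 1 - eps] so 0 and 1 stay finite
    # (assumption; compare the optional eps argument of torch.logit)
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def toy_predict_target(x: List[float], source_logit: float) -> float:
    # Hypothetical concept model: the target prediction depends on the input
    # (0.1 * x[0]) and on the intervened source concept (0.6 * its probability)
    return 0.1 * x[0] + 0.6 * sigmoid(source_logit)


def cace_sketch(
    predict_target: Callable[[List[float], float], float],
    dataloader: Iterable[Batch],
    prob_high: float = 1.0,
    prob_low: float = 0.0,
) -> float:
    # Convert both regimes from probability space to logits once
    z_high, z_low = logit(prob_high), logit(prob_low)
    total, n = 0.0, 0
    for batch in dataloader:
        for x in batch["inputs"]["x"]:
            # Same input, two interventions: do(source = high) vs do(source = low)
            total += predict_target(x, z_high) - predict_target(x, z_low)
            n += 1
    if n == 0:
        raise ValueError("dataloader yielded no samples")
    return total / n  # mean difference over all samples


loader = [
    {"inputs": {"x": [[0.2], [0.9]]}},
    {"inputs": {"x": [[0.5]]}},
]
print(round(cace_sketch(toy_predict_target, loader), 4))  # → 0.6
```

The input term cancels inside each difference, so the toy score is the source coefficient 0.6 (a hair less because of the clamp); with prob_high=0.75 and prob_low=0.25 it would be 0.3. A real implementation would more likely evaluate whole batches of tensors at once; the per-sample loop only keeps the sketch dependency-free.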
- Parameters:
model – A high-level concept model (e.g. ConceptBottleneckModel).
dataloader – Iterable yielding batch dicts with {'inputs': {'x': Tensor}}.
source_concept – Concept to intervene on.
target_concept – Concept whose prediction is measured.
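prob_high – Probability assigned to source_concept in the high regime: a scalar for binary concepts or a probability vector for categorical ones (default 1.0).
prob_low – Probability assigned to source_concept in the low regime, in the same form as prob_high (default 0.0).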
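Because torch.logit works element-wise, a categorical prob_high or prob_low becomes a vector of per-category logits. The dependency-free sketch below uses a hypothetical helper, logit_elementwise, in place of torch.logit to show that conversion for the binary defaults and for one-hot categorical vectors. Its optional eps clamp imitates torch.logit's eps argument; whether compute_cace clamps is not stated here, and without a clamp, probabilities of exactly 0 or 1 map to -inf or +inf.

```python
import math
from typing import List, Optional, Sequence


def logit_elementwise(probs: Sequence[float], eps: Optional[float] = None) -> List[float]:
    # Hypothetical stand-in for torch.logit: element-wise log(p / (1 - p));
    # if eps is given, clamp p to [eps, 1 - eps] first
    out: List[float] = []
    for p in probs:
        if eps is not None:
            p = min(max(p, eps), 1.0 - eps)
        if p <= 0.0:
            out.append(-math.inf)
        elif p >= 1.0:
            out.append(math.inf)
        else:
            out.append(math.log(p / (1.0 - p)))
    return out


# Binary defaults: one probability per regime
print(logit_elementwise([0.0]))  # → [-inf] without clamping
binary_high = logit_elementwise([1.0], eps=1e-6)  # large but finite

# Categorical regimes from the docstring: one-hot probability vectors
cat_high = logit_elementwise([0.0, 0.0, 1.0], eps=1e-6)
cat_low = logit_elementwise([1.0, 0.0, 0.0], eps=1e-6)
```

With clamping, each one-hot vector turns into symmetric large negative and positive logits, and the category that held probability 1 keeps the largest logit.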
- Returns:
Scalar tensor with the CaCE score.
Example:
>>> from torch_concepts.nn import compute_cace
>>> cace = compute_cace(
...     model=cbm,
...     dataloader=test_loader,
...     source_concept="c1",
...     target_concept="task",
... )