Loss Functions

Concept-aware loss functions with automatic routing and weighting.

Summary

High-Level Losses

ConceptLoss

Concept loss for concept-based models.

WeightedConceptLoss

Weighted concept loss for concept-based models.

Low-Level Losses

Class Documentation

class ConceptLoss(annotations: Annotations, fn_collection: GroupConfig)

Bases: Module

Concept loss for concept-based models.

Automatically routes each concept to the appropriate loss function based on its type (binary, categorical, or continuous), as given by the annotation metadata.

Parameters:
  • annotations (Annotations) – Concept annotations with metadata including type information for each concept.

  • fn_collection (GroupConfig) – Loss function configuration per concept type. Keys should be ‘binary’, ‘categorical’, and/or ‘continuous’.
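The routing described above can be sketched in plain Python. The function `route_concepts` is hypothetical. The rule it applies is an assumption chosen to match the keys of `fn_collection`: a `Bernoulli` distribution maps to `binary`, `Categorical` maps to `categorical`, and anything else maps to `continuous`. Distributions are given as plain strings so the sketch runs without dependencies. It also assumes that each concept occupies `cardinality` consecutive logit columns, which matches the layout in the example below (1 binary logit followed by 3 categorical logits).

```python
# Hypothetical sketch of type-based routing; NOT the torch_concepts implementation.
# Assumed rule: Bernoulli -> 'binary', Categorical -> 'categorical', else 'continuous'.

def route_concepts(labels, cardinalities, metadata):
    """Return {loss_type: [(label, start_col, stop_col), ...]}."""
    routes = {}
    start = 0
    for label, card in zip(labels, cardinalities):
        dist = metadata[label].get('distribution', '')
        # Accept either a class (use its __name__) or a plain string name.
        name = getattr(dist, '__name__', dist)
        if name == 'Bernoulli':
            key = 'binary'
        elif name == 'Categorical':
            key = 'categorical'
        else:
            key = 'continuous'
        # Each concept is assumed to own `card` consecutive logit columns.
        stop = start + card
        routes.setdefault(key, []).append((label, start, stop))
        start = stop
    return routes


routes = route_concepts(
    labels=['is_round', 'color'],
    cardinalities=[1, 3],
    metadata={
        'is_round': {'type': 'discrete', 'distribution': 'Bernoulli'},
        'color': {'type': 'discrete', 'distribution': 'Categorical'},
    },
)
print(routes)
# {'binary': [('is_round', 0, 1)], 'categorical': [('color', 1, 4)]}
```

Each key of the result would then select the matching entry of `fn_collection`.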

Example

>>> import torch
>>> from torch_concepts.nn import ConceptLoss
>>> from torch_concepts import GroupConfig, Annotations, AxisAnnotation
>>> from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
>>> from torch.distributions import Bernoulli, Categorical
>>>
>>> # Define annotations
>>> ann = Annotations({1: AxisAnnotation(
...     labels=['is_round', 'color'],
...     cardinalities=[1, 3],
...     metadata={
...         'is_round': {'type': 'discrete', 'distribution': Bernoulli},
...         'color': {'type': 'discrete', 'distribution': Categorical}
...     }
... )})
>>>
>>> # Configure loss functions
>>> loss_config = GroupConfig(
...     binary=BCEWithLogitsLoss(),
...     categorical=CrossEntropyLoss()
... )
>>> loss_fn = ConceptLoss(ann[1], loss_config)
>>>
>>> # Compute loss
>>> predictions = torch.randn(2, 4)  # 1 binary + 3 categorical logits
>>> targets = torch.cat([
...     torch.randint(0, 2, (2, 1)),  # binary target
...     torch.randint(0, 3, (2, 1))   # categorical target
... ], dim=1)
>>> loss = loss_fn(predictions, targets)
forward(input: Tensor, target: Tensor) → Tensor

Compute the total loss across all concept types.

Splits the inputs and targets by concept type, computes the loss for each type, and sums the per-type losses into a single total.

Parameters:
  • input (torch.Tensor) – Model predictions in endogenous space (logits).

  • target (torch.Tensor) – Ground truth labels/values.

Returns:

Total computed loss (scalar).

Return type:

torch.Tensor
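The split-compute-sum procedure of `forward` can be illustrated in plain Python for the layout of the example above. This is a sketch, not the library's implementation. The helper names are hypothetical, and each per-type loss is averaged over the batch, mirroring the default `reduction='mean'` of `BCEWithLogitsLoss` and `CrossEntropyLoss`. The two averages are then summed.

```python
import math


def bce_with_logits(logit, target):
    # Numerically stable binary cross-entropy on a raw logit:
    # max(x, 0) - x*y + log(1 + exp(-|x|)).
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def cross_entropy(logits, target_index):
    # log-sum-exp (max subtracted for stability) minus the target-class logit.
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target_index]


def concept_loss(batch_logits, batch_targets):
    """Hypothetical: mean binary loss + mean categorical loss.

    Assumed layout: logit column 0 is binary, columns 1..3 are categorical;
    target column 0 is the binary label, target column 1 the class index.
    """
    n = len(batch_logits)
    binary = sum(bce_with_logits(row[0], float(t[0]))
                 for row, t in zip(batch_logits, batch_targets)) / n
    categorical = sum(cross_entropy(row[1:4], int(t[1]))
                      for row, t in zip(batch_logits, batch_targets)) / n
    # Per-type losses are summed into the total, as the docstring states.
    return binary + categorical


# All-zero logits give log(2) for the binary part and log(3) for the
# 3-way categorical part, so the total is log(6).
print(round(concept_loss([[0.0] * 4, [0.0] * 4], [[1, 2], [0, 0]]), 4))  # 1.7918
```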

training: bool
class WeightedConceptLoss(annotations: Annotations, fn_collection: GroupConfig, concept_weight: float, task_weight: float, task_names: List[str])

Bases: Module

Weighted concept loss for concept-based models.

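Computes a weighted combination of the concept loss and the task loss. Concepts listed in task_names are treated as tasks; all remaining concepts contribute to the concept loss.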

Parameters:
  • annotations (Annotations) – Annotations object with concept metadata.

  • fn_collection (GroupConfig) – Loss function configuration.

  • concept_weight (float) – Weight applied to the concept loss.

  • task_weight (float) – Weight applied to the task loss.

  • task_names (List[str]) – List of task concept names.
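One plausible way to separate tasks from concepts is to look up each label in `task_names` and collect the logit columns it occupies. The helper `split_concepts_and_tasks` is hypothetical. It assumes the same consecutive-column layout as the `ConceptLoss` example. With the annotations from the example below, the categorical concept `c2` spans three columns and the task occupies the last one. The same lookup would apply to the target, which holds one column per concept.

```python
# Hypothetical sketch; not the torch_concepts implementation.

def split_concepts_and_tasks(labels, cardinalities, task_names):
    """Return (concept_cols, task_cols) as lists of logit column indices."""
    concept_cols, task_cols = [], []
    start = 0
    for label, card in zip(labels, cardinalities):
        # Each label is assumed to own `card` consecutive logit columns.
        cols = list(range(start, start + card))
        if label in task_names:
            task_cols.extend(cols)
        else:
            concept_cols.extend(cols)
        start += card
    return concept_cols, task_cols


concept_cols, task_cols = split_concepts_and_tasks(
    ['c1', 'c2', 'task'], [1, 3, 1], task_names=['task'])
print(concept_cols, task_cols)  # [0, 1, 2, 3] [4]
```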

Example

>>> import torch
>>> from torch_concepts.nn.modules.loss import WeightedConceptLoss
>>> from torch_concepts.nn.modules.utils import GroupConfig
>>> from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
>>> from torch_concepts.annotations import AxisAnnotation, Annotations
>>> ann = Annotations({1: AxisAnnotation(labels=['c1', 'c2', 'task'], cardinalities=[1, 3, 1])})
>>> fn = GroupConfig(binary=BCEWithLogitsLoss(), categorical=CrossEntropyLoss())
>>> loss_fn = WeightedConceptLoss(ann, fn, concept_weight=0.7, task_weight=0.3, task_names=['task'])
>>> input = torch.randn(2, 5)
>>> target = torch.randint(0, 2, (2, 3))
>>> loss = loss_fn(input, target)
forward(input: Tensor, target: Tensor) → Tensor

Compute the weighted loss for concepts and tasks.

Parameters:
  • input (torch.Tensor) – Model predictions in endogenous space (logits).

  • target (torch.Tensor) – Ground truth labels/values.

Returns:

Weighted combination of concept and task losses (scalar).

Return type:

torch.Tensor
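The documentation does not state the exact combination rule. A natural reading of `concept_weight` and `task_weight` is a linear mix of the two scalar losses, but this is assumed here rather than confirmed by the source. The input loss values below are made-up numbers for illustration, and `weighted_concept_loss` is a hypothetical helper.

```python
# Hypothetical sketch of the assumed combination rule.

def weighted_concept_loss(concept_loss, task_loss, concept_weight, task_weight):
    # Assumed: plain linear mix of the two scalar losses.
    return concept_weight * concept_loss + task_weight * task_loss


# Illustrative inputs: 0.7 * 1.2 + 0.3 * 0.5 = 0.84 + 0.15 = 0.99
print(round(weighted_concept_loss(1.2, 0.5, 0.7, 0.3), 6))  # 0.99
```

Under this reading, choosing `concept_weight + task_weight = 1` keeps the combined loss on the same scale as its parts.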

training: bool