Loss Functions
Concept-aware loss functions with automatic routing and weighting.
Summary
High-Level Losses
ConceptLoss | Concept loss for concept-based models.
WeightedConceptLoss | Weighted concept loss for concept-based models.
Low-Level Losses
Class Documentation
- class ConceptLoss(annotations: Annotations, fn_collection: GroupConfig)
Bases: Module
Concept loss for concept-based models.
Automatically routes each concept to the appropriate loss function for its type (binary, categorical, or continuous), as given by the annotation metadata.
- Parameters:
annotations (Annotations) – Concept annotations with metadata including type information for each concept.
fn_collection (GroupConfig) – Loss function configuration per concept type. Keys should be ‘binary’, ‘categorical’, and/or ‘continuous’.
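As a rough illustration of the routing step, the sketch below assigns each concept to a loss group and records the logit columns it occupies. It is plain Python, not the library's code: `route_concepts` is a hypothetical helper, and the rule it applies (a `'continuous'` type goes to the continuous group, a discrete concept with cardinality 1 to the binary group, any other discrete concept to the categorical group) is an assumption about how the annotation metadata might be interpreted.

```python
from typing import Dict, List, Tuple


# Hypothetical helper, not part of torch_concepts: the grouping rule below is
# an assumption about how concept type and cardinality map to loss groups.
def route_concepts(
    labels: List[str],
    cardinalities: List[int],
    types: List[str],
) -> Dict[str, List[Tuple[str, int, int]]]:
    groups: Dict[str, List[Tuple[str, int, int]]] = {
        "binary": [],
        "categorical": [],
        "continuous": [],
    }
    start = 0
    for label, card, ctype in zip(labels, cardinalities, types):
        # Each concept occupies `card` consecutive logit columns.
        end = start + card
        if ctype == "continuous":
            group = "continuous"
        elif card == 1:
            group = "binary"  # a single logit: Bernoulli-style concept
        else:
            group = "categorical"  # one logit per class
        groups[group].append((label, start, end))
        start = end
    return groups
```

For the annotations in the example below, `route_concepts(['is_round', 'color'], [1, 3], ['discrete', 'discrete'])` places `('is_round', 0, 1)` in the binary group and `('color', 1, 4)` in the categorical group.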
Example
>>> import torch
>>> from torch_concepts.nn import ConceptLoss
>>> from torch_concepts import GroupConfig, Annotations, AxisAnnotation
>>> from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
>>> from torch.distributions import Bernoulli, Categorical
>>>
>>> # Define annotations
>>> ann = Annotations({1: AxisAnnotation(
...     labels=['is_round', 'color'],
...     cardinalities=[1, 3],
...     metadata={
...         'is_round': {'type': 'discrete', 'distribution': Bernoulli},
...         'color': {'type': 'discrete', 'distribution': Categorical}
...     }
... )})
>>>
>>> # Configure loss functions
>>> loss_config = GroupConfig(
...     binary=BCEWithLogitsLoss(),
...     categorical=CrossEntropyLoss()
... )
>>> loss_fn = ConceptLoss(ann[1], loss_config)
>>>
>>> # Compute loss
>>> predictions = torch.randn(2, 4)  # 1 binary + 3 categorical logits
>>> targets = torch.cat([
...     torch.randint(0, 2, (2, 1)),  # binary target
...     torch.randint(0, 3, (2, 1))   # categorical target
... ], dim=1)
>>> loss = loss_fn(predictions, targets)
- forward(input: Tensor, target: Tensor) → Tensor
Compute total loss across all concept types.
Splits inputs and targets by concept type, computes individual losses, and sums them to get the total loss.
- Parameters:
input (torch.Tensor) – Model predictions in endogenous space (logits).
target (torch.Tensor) – Ground truth labels/values.
- Returns:
Total computed loss (scalar).
- Return type:
torch.Tensor
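A dependency-free sketch of the steps forward describes: walk the flat logit vector concept by concept, score binary concepts with binary cross-entropy on the logit and categorical concepts with softmax cross-entropy, average within each group, and add the group losses. The helper names are hypothetical, only binary and categorical concepts are handled, and the per-group mean is an assumption chosen to mirror the default `reduction='mean'` of `BCEWithLogitsLoss` and `CrossEntropyLoss`; the library's exact reduction may differ.

```python
import math
from typing import List


# Hypothetical sketch, not the torch_concepts implementation.
def bce_with_logits(x: float, y: float) -> float:
    # Numerically stable binary cross-entropy computed from a single logit.
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


def cross_entropy(z: List[float], y: int) -> float:
    # Stable log-sum-exp of the logits minus the logit of the true class.
    m = max(z)
    return m + math.log(sum(math.exp(v - m) for v in z)) - z[y]


def concept_loss(
    logits: List[List[float]],
    targets: List[List[int]],
    cardinalities: List[int],
) -> float:
    binary_terms: List[float] = []
    categorical_terms: List[float] = []
    for row, target_row in zip(logits, targets):
        start = 0
        for j, card in enumerate(cardinalities):
            chunk = row[start:start + card]  # logits belonging to concept j
            if card == 1:
                binary_terms.append(bce_with_logits(chunk[0], float(target_row[j])))
            else:
                categorical_terms.append(cross_entropy(chunk, int(target_row[j])))
            start += card
    # Assumed reduction: mean within each group, then sum across groups.
    total = 0.0
    if binary_terms:
        total += sum(binary_terms) / len(binary_terms)
    if categorical_terms:
        total += sum(categorical_terms) / len(categorical_terms)
    return total
```

With all-zero logits the binary term is log 2 and the three-way categorical term is log 3, so `concept_loss([[0.0, 0.0, 0.0, 0.0]], [[1, 2]], [1, 3])` returns log 6 ≈ 1.792.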
- class WeightedConceptLoss(annotations: Annotations, fn_collection: GroupConfig, concept_weight: float, task_weight: float, task_names: List[str])
Bases: Module
Weighted concept loss for concept-based models.
Computes a weighted combination of concept and task losses.
- Parameters:
annotations (Annotations) – Annotations object with concept metadata.
fn_collection (GroupConfig) – Loss function configuration.
concept_weight (float) – Weight applied to the concept loss.
task_weight (float) – Weight applied to the task loss.
task_names (List[str]) – List of task concept names.
Example
>>> import torch
>>> from torch_concepts.nn.modules.loss import WeightedConceptLoss
>>> from torch_concepts.nn.modules.utils import GroupConfig
>>> from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
>>> from torch_concepts.annotations import AxisAnnotation, Annotations
>>> ann = Annotations({1: AxisAnnotation(labels=['c1', 'c2', 'task'], cardinalities=[1, 3, 1])})
>>> fn = GroupConfig(binary=BCEWithLogitsLoss(), categorical=CrossEntropyLoss())
>>> loss_fn = WeightedConceptLoss(ann, fn, concept_weight=0.7, task_weight=0.3, task_names=['task'])
>>> input = torch.randn(2, 5)  # 1 + 3 + 1 logits
>>> target = torch.randint(0, 2, (2, 3))  # one target column per concept
>>> loss = loss_fn(input, target)
- forward(input: Tensor, target: Tensor) → Tensor
Compute weighted loss for concepts and tasks.
- Parameters:
input (torch.Tensor) – Model predictions in endogenous space (logits).
target (torch.Tensor) – Ground truth labels/values.
- Returns:
Weighted combination of concept and task losses (scalar).
- Return type:
torch.Tensor
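One plausible reading of the weighted combination is sketched below: losses for concepts named in `task_names` form the task term, the remaining losses form the concept term, and the two terms are scaled by `task_weight` and `concept_weight` and added. `weighted_concept_loss` is a hypothetical function operating on precomputed per-concept losses, and summing within each term is an assumption, not confirmed library behavior.

```python
from typing import Dict, List


# Hypothetical sketch, not the torch_concepts implementation: assumes the
# concept and task terms are plain sums of precomputed per-concept losses.
def weighted_concept_loss(
    per_concept_losses: Dict[str, float],
    task_names: List[str],
    concept_weight: float,
    task_weight: float,
) -> float:
    tasks = set(task_names)
    # Everything not listed as a task contributes to the concept term.
    concept_term = sum(v for k, v in per_concept_losses.items() if k not in tasks)
    task_term = sum(v for k, v in per_concept_losses.items() if k in tasks)
    return concept_weight * concept_term + task_weight * task_term
```

For per-concept losses `{'c1': 1.0, 'c2': 2.0, 'task': 4.0}` with `concept_weight=0.7` and `task_weight=0.3`, this gives 0.7 × 3.0 + 0.3 × 4.0 = 3.3.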