"""Loss functions for concept-based models."""
import inspect
import warnings
from typing import List, Mapping, Optional, Union
import torch
from torch import nn
from .utils import GroupConfig, check_collection
from .outputs import ModelOutput
from ...annotations import Annotations
from ...utils import instantiate_from_string
from ...concept_graph import ConceptGraph
def _get_forward_signature(module: nn.Module):
"""Introspect forward() to get accepted parameter names and whether it has **kwargs.
Returns:
Tuple[set, bool]: (set of parameter names, has_var_keyword)
"""
params = inspect.signature(module.forward).parameters
names = set()
has_var_keyword = False
for name, param in params.items():
if param.kind == inspect.Parameter.VAR_KEYWORD:
has_var_keyword = True
else:
names.add(name)
return names, has_var_keyword
def _normalize_loss_terms(terms, weights):
"""Normalize loss terms and weights to consistent list form.
Args:
terms: A single nn.Module, a list of nn.Module, or None.
weights: A list of floats, or None.
Returns:
Tuple of (list_of_modules, list_of_weights), or (None, None) if terms is None.
"""
if terms is None:
return None, None
if isinstance(terms, nn.Module):
terms = [terms]
if not isinstance(terms, (list, tuple)):
raise TypeError(
f"Loss terms must be an nn.Module or a list of nn.Module, got {type(terms)}"
)
if weights is None:
weights = [1.0] * len(terms)
if len(weights) != len(terms):
raise ValueError(
f"Number of weights ({len(weights)}) must match "
f"number of loss terms ({len(terms)})."
)
return list(terms), list(weights)
def get_concept_task_idx(annotations: Annotations, concepts: List[str], tasks: List[str]):
    """Look up concept and task positions at concept level and logit level.

    Args:
        annotations: Annotations providing ``get_index`` (per-name concept
            position) and ``get_slice`` (logit positions for a list of names).
        concepts: Names treated as concepts.
        tasks: Names treated as tasks.

    Returns:
        Tuple of ``(concept_indices, task_indices, concept_logit_indices,
        task_logit_indices)``; the first two are lists built from
        ``get_index``, the last two are whatever ``get_slice`` returns.
    """
    def _positions(names):
        return [annotations.get_index(name) for name in names]

    # Tuple items are evaluated left to right: concept-level lookups first,
    # then the logit-level slices.
    return (
        _positions(concepts),
        _positions(tasks),
        annotations.get_slice(concepts),
        annotations.get_slice(tasks),
    )
class ConceptLoss(nn.Module):
    """
    Concept loss for concept-based models.

    Automatically routes to appropriate loss functions based on concept types
    (binary, categorical, continuous) using annotation metadata. Each type
    accepts either a single loss module or a list of loss modules with
    optional per-term weights, enabling type-specific composition (e.g.
    adding a regularizer only to binary concepts).

    Each loss term receives only the keyword arguments its ``forward()``
    accepts: ``input``, ``target``, ``padding_mask`` (categorical only) and
    any entries of ``ModelOutput.extra``. Entries of ``extra`` never override
    ``input``, ``target`` or ``padding_mask``.

    Args:
        annotations (Annotations): Concept annotations with metadata including
            type information for each concept.
        binary (nn.Module or list of nn.Module, optional): Loss function(s)
            for binary concepts. A single module (e.g. ``BCEWithLogitsLoss()``)
            or a list of modules to be summed.
        categorical (nn.Module or list of nn.Module, optional): Loss function(s)
            for categorical concepts. A single module (e.g.
            ``CrossEntropyLoss()``) or a list of modules.
        continuous (nn.Module or list of nn.Module, optional): Loss function(s)
            for continuous concepts (e.g. ``MSELoss()``). Not yet supported.
        binary_weights (list of float, optional): Per-term weights when
            ``binary`` is a list. Defaults to ``[1.0, ...]``.
        categorical_weights (list of float, optional): Per-term weights when
            ``categorical`` is a list. Defaults to ``[1.0, ...]``.
        continuous_weights (list of float, optional): Per-term weights when
            ``continuous`` is a list. Defaults to ``[1.0, ...]``.

    Example:
        >>> from torch_concepts.nn import ConceptLoss, L1LogitRegularizer
        >>> from torch_concepts import Annotations
        >>> from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
        >>>
        >>> ann = Annotations(
        ...     labels=['is_round', 'color'],
        ...     cardinalities=[1, 3],
        ...     types=['binary', 'categorical'],
        ... )
        >>>
        >>> # Single loss per type (backward compatible)
        >>> loss_fn = ConceptLoss(
        ...     ann,
        ...     binary=BCEWithLogitsLoss(),
        ...     categorical=CrossEntropyLoss()
        ... )
        >>>
        >>> # Composite loss per type with weights
        >>> loss_fn = ConceptLoss(
        ...     ann,
        ...     binary=[BCEWithLogitsLoss(), L1LogitRegularizer(scale=0.01)],
        ...     binary_weights=[1.0, 0.5],
        ...     categorical=CrossEntropyLoss()
        ... )
    """

    def __init__(
        self,
        annotations: Annotations,
        binary: Optional[Union[nn.Module, List[nn.Module]]] = None,
        categorical: Optional[Union[nn.Module, List[nn.Module]]] = None,
        continuous: Optional[Union[nn.Module, List[nn.Module]]] = None,
        binary_weights: Optional[List[float]] = None,
        categorical_weights: Optional[List[float]] = None,
        continuous_weights: Optional[List[float]] = None,
    ):
        super().__init__()
        # Normalize every per-type spec to (list_of_modules, list_of_weights).
        binary, binary_weights = _normalize_loss_terms(binary, binary_weights)
        categorical, categorical_weights = _normalize_loss_terms(
            categorical, categorical_weights
        )
        continuous, continuous_weights = _normalize_loss_terms(
            continuous, continuous_weights
        )

        # Validate against annotations (check_collection checks None vs not-None).
        fn_collection = GroupConfig(
            binary=binary, categorical=categorical, continuous=continuous
        )
        self.fn_collection = check_collection(annotations, fn_collection, 'loss')

        # Use cached type_groups / cardinalities from Annotations.
        self.groups = annotations.type_groups
        self.cardinalities = annotations.cardinalities

        # Register modules, weights, and forward() signatures per type.
        self._type_weights = {}
        self._type_signatures = {}
        weights_map = {
            'binary': binary_weights,
            'categorical': categorical_weights,
            'continuous': continuous_weights,
        }
        for type_name in ('binary', 'categorical', 'continuous'):
            terms = self.fn_collection.get(type_name)
            if terms is None:
                continue
            # nn.ModuleList so the terms' parameters/buffers are tracked.
            setattr(self, f'_{type_name}_terms', nn.ModuleList(terms))
            self._type_weights[type_name] = weights_map[type_name]
            self._type_signatures[type_name] = [
                _get_forward_signature(m) for m in terms
            ]

        # Categorical logits are right-padded to the largest cardinality.
        if self.fn_collection.get('categorical'):
            cat_idx = self.groups['categorical']['concept_idx']
            self.max_card = max(self.cardinalities[i] for i in cat_idx)
        if self.fn_collection.get('continuous'):
            cont_idx = self.groups['continuous']['concept_idx']
            self.max_dim = max(self.cardinalities[i] for i in cont_idx)

    def __repr__(self) -> str:
        parts = []
        for t in ('binary', 'categorical', 'continuous'):
            terms = self.fn_collection.get(t)
            if terms is None:
                continue
            weights = self._type_weights[t]
            if len(terms) == 1 and weights[0] == 1.0:
                parts.append(f"{t}={terms[0].__class__.__name__}")
            else:
                term_strs = []
                for m, w in zip(terms, weights):
                    n = m.__class__.__name__
                    term_strs.append(f"{w}*{n}" if w != 1.0 else n)
                parts.append(f"{t}=[{' + '.join(term_strs)}]")
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def _compute_type_loss(self, type_name: str, kwargs: dict) -> torch.Tensor:
        """Compute weighted sum of loss terms for a specific concept type.

        Each term receives only the kwargs its ``forward()`` signature accepts
        (or all of them if it declares ``**kwargs``). If ``padding_mask`` is
        present in *kwargs* but a term accepts neither ``padding_mask`` nor
        ``target`` (and has no ``**kwargs``), a warning is emitted so that
        users are aware their custom loss/regularizer is receiving padded
        values without explicit masking information.
        """
        terms = getattr(self, f'_{type_name}_terms')
        weights = self._type_weights[type_name]
        signatures = self._type_signatures[type_name]
        has_padding = 'padding_mask' in kwargs
        total = torch.tensor(0.0, device=kwargs['input'].device)
        for module, weight, (sig, has_var_kw) in zip(terms, weights, signatures):
            if has_var_kw:
                term_kwargs = dict(kwargs)
            else:
                term_kwargs = {k: v for k, v in kwargs.items() if k in sig}
                # Terms taking ``target`` are exempt — presumably because
                # standard classification losses (e.g. CrossEntropyLoss) give
                # -inf padded slots zero probability and never target them.
                if has_padding and 'padding_mask' not in sig and 'target' not in sig:
                    warnings.warn(
                        f"{module.__class__.__name__} does not accept a "
                        f"'padding_mask' parameter. Categorical concept "
                        f"logits are padded with -inf for concepts with "
                        f"cardinality < max_cardinality. If this module "
                        f"could be affected by this, add a 'padding_mask' parameter "
                        f"to its forward() to handle padded positions "
                        f"correctly.",
                        stacklevel=2,
                    )
            total = total + weight * module(**term_kwargs)
        return total

    @staticmethod
    def _merge_extra(core: dict, extra: Mapping) -> dict:
        """Combine the computed loss kwargs with ``ModelOutput.extra``.

        Keys in *core* (``input``, ``target``, ``padding_mask``) always win:
        a colliding ``extra`` entry is dropped with a warning instead of
        silently replacing the logits/targets fed to every loss term.
        """
        clashes = sorted(set(core) & set(extra))
        if clashes:
            warnings.warn(
                f"ModelOutput.extra contains reserved key(s) {clashes}; "
                f"these entries are ignored in favour of the values "
                f"computed by ConceptLoss.",
                stacklevel=2,
            )
        return {**extra, **core}

    def _prepare_categorical(self, input: torch.Tensor, target: torch.Tensor):
        """Pad and stack categorical logits/targets for summary computation.

        Logits of each categorical concept are right-padded with ``-inf`` to
        ``self.max_card`` and stacked concept-major along dim 0; targets are
        flattened in the same concept-major order.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            ``(padded_logits, targets, padding_mask)`` ready for loss
            functions like ``CrossEntropyLoss``. ``padding_mask`` is a
            boolean tensor of the same shape as ``padded_logits`` that is
            ``True`` for real logit positions and ``False`` for padding.
        """
        cat_concept_idx = self.groups['categorical']['concept_idx']
        split_tuple = torch.split(
            input[:, self.groups['categorical']['logits_idx']],
            [self.cardinalities[i] for i in cat_concept_idx],
            dim=1,
        )
        padded_logits = []
        masks = []
        for logits in split_tuple:
            pad_size = self.max_card - logits.shape[1]
            padded_logits.append(
                nn.functional.pad(logits, (0, pad_size), value=float('-inf'))
            )
            mask = torch.ones(
                logits.shape[0], self.max_card,
                dtype=torch.bool, device=logits.device,
            )
            if pad_size > 0:
                mask[:, -pad_size:] = False
            masks.append(mask)
        cat_logits = torch.cat(padded_logits, dim=0)
        cat_mask = torch.cat(masks, dim=0)
        # (B, K) -> (K, B) -> flat, matching the concept-major stacking above.
        cat_targets = target[:, cat_concept_idx].T.reshape(-1).long()
        return cat_logits, cat_targets, cat_mask

    def forward(self, output: ModelOutput) -> torch.Tensor:
        """Compute total loss across all concept types.

        Splits ``output.logits`` and ``output.target`` by concept type,
        merges them with any ``output.extra`` entries (which cannot override
        the computed ``input``/``target``/``padding_mask``), computes
        individual losses (each a weighted sum of its terms dispatched by
        signature), and sums them.

        Args:
            output (ModelOutput): Structured model output containing
                ``logits``, ``target``, and optionally ``extra``.

        Returns:
            torch.Tensor: Total computed loss (scalar).

        Raises:
            NotImplementedError: If continuous loss terms are configured.
        """
        logits = output.logits
        target = output.target
        extra = dict(output.extra) if output.extra else {}
        total_loss = torch.tensor(0.0, device=logits.device)

        # Binary concepts
        if self.fn_collection.get('binary'):
            binary_logits = logits[:, self.groups['binary']['logits_idx']]
            binary_targets = target[:, self.groups['binary']['concept_idx']].float()
            total_loss = total_loss + self._compute_type_loss(
                'binary',
                self._merge_extra(
                    {'input': binary_logits, 'target': binary_targets}, extra
                ),
            )

        # Categorical concepts
        if self.fn_collection.get('categorical'):
            cat_logits, cat_targets, cat_mask = self._prepare_categorical(
                logits, target
            )
            total_loss = total_loss + self._compute_type_loss(
                'categorical',
                self._merge_extra(
                    {
                        'input': cat_logits,
                        'target': cat_targets,
                        'padding_mask': cat_mask,
                    },
                    extra,
                ),
            )

        # Continuous concepts
        if self.fn_collection.get('continuous'):
            raise NotImplementedError("Continuous concepts not yet implemented.")
        return total_loss
class WeightedConceptLoss(nn.Module):
    """
    Weighted concept loss for concept-based models.

    Splits the annotated labels into concepts (every label not listed in
    ``task_names``) and tasks, evaluates a separate :class:`ConceptLoss` on
    each group, and returns
    ``concept_weight * concept_loss + task_weight * task_loss``.

    Args:
        annotations (Annotations): Annotations object with concept metadata.
        concept_weight (float): Weight for concept loss.
        task_weight (float): Weight for task loss.
        task_names (List[str]): List of task concept names.
        binary (nn.Module or list of nn.Module, optional): Loss function(s) for binary concepts.
        categorical (nn.Module or list of nn.Module, optional): Loss function(s) for categorical concepts.
        continuous (nn.Module or list of nn.Module, optional): Loss function(s) for continuous concepts.
        binary_weights (list of float, optional): Per-term weights when ``binary`` is a list.
        categorical_weights (list of float, optional): Per-term weights when ``categorical`` is a list.
        continuous_weights (list of float, optional): Per-term weights when ``continuous`` is a list.

    Example:
        >>> from torch_concepts.nn.modules.loss import WeightedConceptLoss
        >>> from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
        >>> from torch_concepts.annotations import Annotations
        >>> import torch
        >>> from torch_concepts.nn.modules.outputs import ModelOutput
        >>> ann = Annotations(labels=['c1', 'c2', 'task'], cardinalities=[1, 1, 1])
        >>> loss_fn = WeightedConceptLoss(
        ...     ann, concept_weight=0.7, task_weight=0.3,
        ...     task_names=['task'], binary=BCEWithLogitsLoss()
        ... )
        >>> out = ModelOutput(logits=torch.randn(2, 3), target=torch.randint(0, 2, (2, 3)).float())
        >>> loss = loss_fn(out)
    """

    def __init__(
        self,
        annotations: Annotations,
        concept_weight: float,
        task_weight: float,
        task_names: List[str],
        binary: Optional[Union[nn.Module, List[nn.Module]]] = None,
        categorical: Optional[Union[nn.Module, List[nn.Module]]] = None,
        continuous: Optional[Union[nn.Module, List[nn.Module]]] = None,
        binary_weights: Optional[List[float]] = None,
        categorical_weights: Optional[List[float]] = None,
        continuous_weights: Optional[List[float]] = None,
    ):
        super().__init__()
        self.concept_weight = concept_weight
        self.task_weight = task_weight
        # Raw (un-normalized) loss configuration; only used by __repr__.
        self.fn_collection = GroupConfig(
            binary=binary, categorical=categorical, continuous=continuous
        )

        # Both sub-losses share the same per-type configuration.
        shared_cfg = dict(
            binary=binary,
            categorical=categorical,
            continuous=continuous,
            binary_weights=binary_weights,
            categorical_weights=categorical_weights,
            continuous_weights=continuous_weights,
        )

        concept_names = [
            label for label in annotations.labels if label not in task_names
        ]
        task_annotations = annotations.subset(task_names)
        concept_annotations = annotations.subset(concept_names)

        self.concept_loss = ConceptLoss(concept_annotations, **shared_cfg)
        self.task_loss = ConceptLoss(task_annotations, **shared_cfg)

        index_bundle = get_concept_task_idx(annotations, concept_names, task_names)
        self.target_c_idx = index_bundle[0]
        self.target_t_idx = index_bundle[1]
        self.input_c_idx = index_bundle[2]
        self.input_t_idx = index_bundle[3]

    def __repr__(self) -> str:
        return f"{type(self).__name__}(fn_collection={self.fn_collection})"

    def forward(self, output: ModelOutput) -> torch.Tensor:
        """Compute weighted loss for concepts and tasks.

        Args:
            output (ModelOutput): Structured model output containing
                ``logits``, ``target``, and optionally ``extra`` (forwarded
                unchanged to both sub-losses).

        Returns:
            torch.Tensor: Weighted combination of concept and task losses (scalar).
        """
        logits = output.logits
        labels = output.target
        shared_extra = dict(output.extra) if output.extra else {}

        def _restrict(logit_idx, label_idx):
            # FIXME: update ModelOutput to generalize beyond logits
            part = ModelOutput(target=labels[:, label_idx], extra=shared_extra or None)
            part.logits = logits[:, logit_idx]
            return part

        concept_part = _restrict(self.input_c_idx, self.target_c_idx)
        task_part = _restrict(self.input_t_idx, self.target_t_idx)

        concept_term = self.concept_loss(concept_part)
        task_term = self.task_loss(task_part)
        return concept_term * self.concept_weight + task_term * self.task_weight
class DepthWeightedConceptLoss(nn.Module):
    """Depth-weighted concept loss for graph-structured concept models.

    Applies different weights to concept losses based on their depth
    in a directed acyclic graph (DAG). Concepts at the graph sources
    (roots, depth 0) receive ``source_weight``; at each subsequent depth
    level the weight is multiplied by ``depth_decay``.

    Weight at depth *d* = ``source_weight * depth_decay ** d``

    Annotated concepts that do not appear in the graph are treated as depth 0
    and share the depth-0 loss group with the graph sources.

    Args:
        annotations (Annotations): Concept annotations with metadata.
        graph (ConceptGraph): DAG defining structure among concepts.
        source_weight (float): Weight applied to loss terms at depth 0
            (graph sources). Default ``1.0``.
        depth_decay (float): Multiplicative factor applied at every
            additional depth level. Values < 1 down-weight deeper
            concepts; values > 1 up-weight them. Default ``0.5``.
        binary (nn.Module or list of nn.Module, optional): Loss function(s)
            for binary concepts (e.g. ``BCEWithLogitsLoss()``).
        categorical (nn.Module or list of nn.Module, optional): Loss function(s)
            for categorical concepts (e.g. ``CrossEntropyLoss()``).
        continuous (nn.Module or list of nn.Module, optional): Loss function(s)
            for continuous concepts (e.g. ``MSELoss()``). Not yet supported.
        binary_weights (list of float, optional): Per-term weights when
            ``binary`` is a list.
        categorical_weights (list of float, optional): Per-term weights when
            ``categorical`` is a list.
        continuous_weights (list of float, optional): Per-term weights when
            ``continuous`` is a list.

    Example:
        >>> import torch
        >>> from torch_concepts.nn.modules.loss import DepthWeightedConceptLoss
        >>> from torch_concepts.annotations import Annotations
        >>> from torch_concepts import ConceptGraph
        >>>
        >>> ann = Annotations(
        ...     labels=['A', 'B', 'C'],
        ...     cardinalities=[1, 1, 1],
        ...     types=['binary', 'binary', 'binary'],
        ... )
        >>> adj = torch.tensor([[0., 1., 0.],
        ...                     [0., 0., 1.],
        ...                     [0., 0., 0.]])
        >>> graph = ConceptGraph(adj, node_names=['A', 'B', 'C'])
        >>> loss_fn = DepthWeightedConceptLoss(
        ...     ann, graph,
        ...     source_weight=1.0, depth_decay=0.5,
        ...     binary=torch.nn.BCEWithLogitsLoss()
        ... )
        >>> from torch_concepts.nn.modules.outputs import ModelOutput
        >>> out = ModelOutput(logits=torch.randn(4, 3), target=torch.randint(0, 2, (4, 3)).float())
        >>> loss = loss_fn(out)
    """

    def __init__(
        self,
        annotations: Annotations,
        graph: ConceptGraph,
        source_weight: float = 1.0,
        depth_decay: float = 0.5,
        binary: Optional[Union[nn.Module, List[nn.Module]]] = None,
        categorical: Optional[Union[nn.Module, List[nn.Module]]] = None,
        continuous: Optional[Union[nn.Module, List[nn.Module]]] = None,
        binary_weights: Optional[List[float]] = None,
        categorical_weights: Optional[List[float]] = None,
        continuous_weights: Optional[List[float]] = None,
    ):
        super().__init__()
        self.source_weight = source_weight
        self.depth_decay = depth_decay

        loss_cfg = dict(
            binary=binary,
            categorical=categorical,
            continuous=continuous,
            binary_weights=binary_weights,
            categorical_weights=categorical_weights,
            continuous_weights=continuous_weights,
        )

        # For each non-empty depth level store a ConceptLoss sub-module
        # (registered as ``loss_depth_{d}``), concept-level indices (target
        # slicing), logit-level indices (input slicing), and its weight.
        self._depth_levels: List[int] = []
        self._depth_weights_list: List[float] = []
        self._target_idx: List[List[int]] = []
        self._input_idx: List[List[int]] = []

        names_by_depth = self._group_by_depth(
            list(annotations.labels), graph.get_levels()
        )
        for d in sorted(names_by_depth):
            names = names_by_depth[d]
            setattr(
                self,
                f"loss_depth_{d}",
                ConceptLoss(annotations.subset(names), **loss_cfg),
            )
            self._depth_levels.append(d)
            self._target_idx.append([annotations.get_index(n) for n in names])
            self._input_idx.append(annotations.get_slice(names))
            self._depth_weights_list.append(source_weight * (depth_decay ** d))

    @staticmethod
    def _group_by_depth(concept_names, depth_levels) -> Mapping[int, List[str]]:
        """Map each depth to the annotated concepts that live at that depth.

        Concepts absent from the graph are assigned depth 0 and merged with
        the graph sources (previously they were dropped whenever depth 0
        already had concepts). Within a level, names keep annotation order so
        the sub-annotations and index lists built from them stay aligned.
        Levels are expected to partition the graph nodes; if a name appears
        in several levels, its first (shallowest) level is used.
        """
        depth_of = {}
        for d, level_names in enumerate(depth_levels):
            for name in level_names:
                depth_of.setdefault(name, d)
        grouped = {}
        for name in concept_names:
            grouped.setdefault(depth_of.get(name, 0), []).append(name)
        return grouped

    # ------------------------------------------------------------------
    # repr
    # ------------------------------------------------------------------
    def __repr__(self) -> str:
        parts = [
            f"depth_{d}: weight={w:.4g}"
            for d, w in zip(self._depth_levels, self._depth_weights_list)
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"

    # ------------------------------------------------------------------
    # forward
    # ------------------------------------------------------------------
    def forward(self, output: ModelOutput) -> torch.Tensor:
        """Compute depth-weighted loss across all concept depths.

        Args:
            output (ModelOutput): Structured model output containing
                ``logits``, ``target``, and optionally ``extra`` (forwarded
                unchanged to every depth's sub-loss).

        Returns:
            torch.Tensor: Total depth-weighted loss (scalar).
        """
        logits = output.logits
        target = output.target
        extra = dict(output.extra) if output.extra else {}
        total_loss = torch.tensor(0.0, device=logits.device)
        for d, weight, t_idx, i_idx in zip(
            self._depth_levels,
            self._depth_weights_list,
            self._target_idx,
            self._input_idx,
        ):
            sub_out = ModelOutput(target=target[:, t_idx], extra=extra or None)
            # FIXME: update ModelOutput to generalize beyond logits
            sub_out.logits = logits[:, i_idx]
            sub_loss = getattr(self, f"loss_depth_{d}")
            total_loss = total_loss + weight * sub_loss(sub_out)
        return total_loss
class L1LogitRegularizer(nn.Module):
    """Penalise large logit magnitudes via L1 regularisation.

    Computes ``scale * mean(|input|)`` over all valid (non-padded)
    positions. When used as a categorical loss term inside
    :class:`ConceptLoss`, a ``padding_mask`` is automatically provided
    to distinguish real logits from padding. Without a mask, only finite
    entries are used, so ``-inf`` padding is ignored.

    Example usage inside :class:`ConceptLoss`::

        loss_fn = ConceptLoss(
            annotations=ann,
            binary=[BCEWithLogitsLoss(), L1LogitRegularizer(scale=0.01)],
            binary_weights=[1.0, 0.5],
        )

    Args:
        scale (float): Multiplicative factor applied to the L1 mean.
            Default ``1.0``.

    Returns:
        torch.Tensor: Scalar regularisation loss.
    """

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.scale = scale

    def forward(
        self,
        input: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return ``scale * mean(|input|)`` over the selected positions.

        Args:
            input: Logit tensor.
            padding_mask: Optional mask with the same shape as ``input``;
                truthy entries mark real logits. Integer/float masks are
                converted to bool so they select positions instead of being
                interpreted as index tensors.

        Returns:
            torch.Tensor: Scalar loss; zero (same dtype/device as ``input``)
            when no position is selected.
        """
        if padding_mask is not None:
            mask = padding_mask.bool()
        else:
            mask = torch.isfinite(input)
        if mask.any():
            return self.scale * input[mask].abs().mean()
        return input.new_zeros(())