High-Level Models

Ready-to-use concept-based models with automatic or manual training support.

Summary

Model Classes

ConceptBottleneckModel

Alias for ConceptBottleneckModel_Joint.

ConceptBottleneckModel_Joint

High-level Concept Bottleneck Model using BipartiteModel.

BlackBox

BlackBox model.

Class Documentation

class ConceptBottleneckModel(**kwargs)

Bases: ConceptBottleneckModel_Joint

Alias for ConceptBottleneckModel_Joint.

prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
training: bool
class ConceptBottleneckModel_Joint(input_size: int, annotations: Annotations, task_names: List[str] | str, variable_distributions: Mapping | None = None, inference: BaseInference | None = DeterministicInference, loss: Module | None = None, metrics: Mapping | None = None, **kwargs)

Bases: BaseModel, JointLearner

High-level Concept Bottleneck Model using BipartiteModel.

Implements a two-stage architecture:

  1. Backbone + Latent Encoder + Concept Encoder → Concept predictions

  2. Concept predictions → Task predictions
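The two stages can be illustrated with a dependency-free toy. This is only a sketch of the bottleneck idea, not the library's implementation: the weights, the helper names (`concept_stage`, `task_stage`, `toy_cbm`), and the use of plain Python lists in place of tensors are all assumptions made for illustration.

```python
import math


def sigmoid(z):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical fixed weights: 3 input features -> 2 concepts -> 1 task.
CONCEPT_WEIGHTS = [[0.5, -1.0, 0.25], [1.5, 0.0, -0.5]]
TASK_WEIGHTS = [2.0, -1.0]


def concept_stage(x):
    """Stage 1: map input features to one logit per concept."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in CONCEPT_WEIGHTS]


def task_stage(concept_logits):
    """Stage 2: predict the task from concept probabilities only."""
    probs = [sigmoid(z) for z in concept_logits]
    return sum(w * p for w, p in zip(TASK_WEIGHTS, probs))


def toy_cbm(x):
    """Run both stages; return concept logits followed by the task logit."""
    concept_logits = concept_stage(x)
    return concept_logits + [task_stage(concept_logits)]


out = toy_cbm([1.0, 2.0, 3.0])
print(len(out))  # 3: two concept logits plus one task logit
```

In this toy, stage 2 reads only the concept values, so overriding a concept logit before calling `task_stage` changes the task prediction without touching the input.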

Example

>>> import torch
>>> from torch_concepts.nn.modules.high.models.cbm import ConceptBottleneckModel_Joint
>>> from torch_concepts.annotations import AxisAnnotation, Annotations
>>> from torch.distributions import Categorical, Bernoulli
>>> ann = Annotations({
...     1: AxisAnnotation(
...         labels=['c1', 'task'],
...         cardinalities=[2, 1],
...         metadata={
...             'c1': {'type': 'discrete', 'distribution': Categorical},
...             'task': {'type': 'continuous', 'distribution': Bernoulli}
...         }
...     )})
>>> model = ConceptBottleneckModel_Joint(
...     input_size=8,
...     annotations=ann,
...     task_names=['task'],
...     variable_distributions=None
... )
>>> x = torch.randn(2, 8)  # batch of 2 samples with 8 input features
>>> out = model(x, query=['c1', 'task'])
forward(x: Tensor, query: List[str] | None = None) → Tensor

Forward pass through CBM.

Parameters:
  • x (torch.Tensor) – Input data (raw or pre-computed inputs).

  • query (List[str], optional) – Variables to query from PGM. Typically all concepts and tasks. Defaults to None.

  • backbone_kwargs (Optional[Mapping[str, Any]], optional) – Arguments for backbone. Defaults to None.

  • *args – Additional arguments for future extensions.

  • **kwargs – Additional arguments for future extensions.

Returns:

Concatenated endogenous for queried variables.

Shape: (batch_size, sum of variable cardinalities).

Return type:

torch.Tensor
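The output width follows from the cardinalities of the queried variables. The helper below is a hypothetical, dependency-free sketch (`column_slices` is not part of torch_concepts) that computes the column span each variable would occupy, assuming columns are laid out in query order. The labels and cardinalities are those of the class example.

```python
def column_slices(labels, cardinalities, query):
    """Return {variable: (start, stop)} column spans for the queried variables.

    Assumes (for illustration only) that columns follow the order of `query`.
    """
    size = dict(zip(labels, cardinalities))
    spans, start = {}, 0
    for name in query:
        stop = start + size[name]
        spans[name] = (start, stop)
        start = stop
    return spans


spans = column_slices(["c1", "task"], [2, 1], query=["c1", "task"])
total_width = max(stop for _, stop in spans.values())
print(spans)        # {'c1': (0, 2), 'task': (2, 3)}
print(total_width)  # 3, so a batch of 2 would have shape (2, 3)
```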

filter_output_for_loss(forward_out, target)

No filtering needed - return raw endogenous for standard loss computation.

Parameters:
  • forward_out – Model output endogenous.

  • target – Ground truth labels.

Returns:

Dict with ‘input’ and ‘target’ for loss computation.
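The returned keys match the `input` and `target` parameter names used by PyTorch loss functions, so the dict can presumably be unpacked straight into a loss call. A minimal sketch of that pattern, using a hypothetical pass-through function (`passthrough_filter`) and a plain-Python squared-error loss standing in for a real `torch.nn` loss:

```python
def passthrough_filter(forward_out, target):
    """Hypothetical stand-in: wrap outputs and labels without filtering."""
    return {"input": forward_out, "target": target}


def mean_squared_error(input, target):
    """Toy loss using the same keyword names as PyTorch losses."""
    return sum((i - t) ** 2 for i, t in zip(input, target)) / len(input)


batch = passthrough_filter([0.5, 1.0, -1.0], [1.0, 1.0, 0.0])
loss = mean_squared_error(**batch)  # keys are unpacked as keyword arguments
print(loss)
```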

filter_output_for_metrics(forward_out, target)

No filtering needed - return raw endogenous for metric computation.

Parameters:
  • forward_out – Model output endogenous.

  • target – Ground truth labels.

Returns:

Dict with ‘input’ and ‘target’ for metric computation.

prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
training: bool
class BlackBox(input_size: int, annotations: Annotations, variable_distributions: Mapping | None = None, loss: Module | None = None, metrics: Mapping | None = None, inference: bool = False, **kwargs)

Bases: BaseModel, JointLearner

BlackBox model.

This model implements a standard neural network architecture for concept-based tasks, without an explicit concept bottleneck or interpretable intermediate representations. It uses a backbone for feature extraction and a latent encoder for concept prediction.

Parameters:
  • input_size (int) – Dimensionality of input features.

  • annotations (Annotations) – Annotation object for output variables.

  • loss (nn.Module, optional) – Loss function for training.

  • metrics (Mapping, optional) – Metrics for evaluation.

  • backbone (nn.Module, optional) – Feature extraction module.

  • latent_encoder (nn.Module, optional) – Latent encoder module.

  • latent_encoder_kwargs (dict, optional) – Arguments for latent encoder.

  • **kwargs – Additional arguments for BaseModel.

Example

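>>> import torch
>>> # `ann` is an Annotations object, e.g. the one built in the ConceptBottleneckModel_Joint example
>>> model = BlackBox(input_size=8, annotations=ann)
>>> out = model(torch.randn(2, 8))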
forward(x: Tensor, query: List[str] | None = None) → Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

filter_output_for_loss(forward_out, target)

No filtering needed - return raw endogenous for standard loss computation.

Parameters:
  • forward_out – Model output endogenous.

  • target – Ground truth labels.

Returns:

Dict with ‘input’ and ‘target’ for loss computation.

filter_output_for_metrics(forward_out, target)

No filtering needed - return raw endogenous for metric computation.

Parameters:
  • forward_out – Model output endogenous.

  • target – Ground truth labels.

Returns:

Dict with ‘input’ and ‘target’ for metric computation.

prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
training: bool