High-Level Models
Ready-to-use concept-based models with automatic or manual training support.
Summary
Model Classes
ConceptBottleneckModel | Alias for ConceptBottleneckModel_Joint.
ConceptBottleneckModel_Joint | High-level Concept Bottleneck Model using BipartiteModel.
BlackBox | BlackBox model.
Class Documentation
- class ConceptBottleneckModel(**kwargs)
Bases: ConceptBottleneckModel_Joint
Alias for ConceptBottleneckModel_Joint.
- class ConceptBottleneckModel_Joint(input_size: int, annotations: Annotations, task_names: List[str] | str, variable_distributions: Mapping | None = None, inference: BaseInference | None = <class 'torch_concepts.nn.modules.mid.inference.forward.DeterministicInference'>, loss: Module | None = None, metrics: Mapping | None = None, **kwargs)
Bases: BaseModel, JointLearner
High-level Concept Bottleneck Model using BipartiteModel.
Implements a two-stage architecture:
1. Backbone + Latent Encoder + Concept Encoder → Concept predictions
2. Concept predictions → Task predictions
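The two stages can be pictured with a minimal plain-PyTorch sketch. This is not the library's implementation: TinyConceptBottleneck, its layer sizes, and the choice to feed sigmoid-squashed concept logits into the task head are all assumptions made for illustration.

```python
import torch
from torch import nn


class TinyConceptBottleneck(nn.Module):
    """Hypothetical stand-in: input -> concept logits -> task logits."""

    def __init__(self, input_size: int, n_concepts: int, n_tasks: int, hidden: int = 16):
        super().__init__()
        # Stage 1: latent encoder followed by a concept encoder.
        self.latent_encoder = nn.Sequential(nn.Linear(input_size, hidden), nn.ReLU())
        self.concept_encoder = nn.Linear(hidden, n_concepts)
        # Stage 2: the task head sees only the concept predictions.
        self.task_head = nn.Linear(n_concepts, n_tasks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        concept_logits = self.concept_encoder(self.latent_encoder(x))
        # Assumption: concepts are passed on as probabilities, not raw logits.
        task_logits = self.task_head(torch.sigmoid(concept_logits))
        # Concatenate concept and task outputs along the feature axis.
        return torch.cat([concept_logits, task_logits], dim=-1)


torch.manual_seed(0)
model = TinyConceptBottleneck(input_size=8, n_concepts=2, n_tasks=1)
out = model(torch.randn(4, 8))
```

Because the task head receives nothing but the concept outputs, every task prediction in this sketch is a function of the concepts alone, which is the defining property of a bottleneck.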
Example
>>> import torch
>>> from torch.distributions import Bernoulli, Categorical
>>> from torch_concepts.annotations import AxisAnnotation, Annotations
>>> from torch_concepts.nn.modules.high.models.cbm import ConceptBottleneckModel_Joint
>>> ann = Annotations({
...     1: AxisAnnotation(
...         labels=['c1', 'task'],
...         cardinalities=[2, 1],
...         metadata={
...             'c1': {'type': 'discrete', 'distribution': Categorical},
...             'task': {'type': 'continuous', 'distribution': Bernoulli},
...         },
...     )
... })
>>> model = ConceptBottleneckModel_Joint(
...     input_size=8,
...     annotations=ann,
...     task_names=['task'],
...     variable_distributions=None,
... )
>>> x = torch.randn(2, 8)
>>> out = model(x, query=['c1', 'task'])
- forward(x: Tensor, query: List[str] | None = None) → Tensor
Forward pass through the CBM.
- Parameters:
x (torch.Tensor) – Input data (raw or pre-computed inputs).
query (List[str], optional) – Variables to query from PGM. Typically all concepts and tasks. Defaults to None.
backbone_kwargs (Optional[Mapping[str, Any]], optional) – Arguments for backbone. Defaults to None.
*args – Additional arguments for future extensions.
**kwargs – Additional arguments for future extensions.
- Returns:
Concatenated endogenous for the queried variables. Shape: (batch_size, sum of variable cardinalities).
- Return type:
torch.Tensor
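To make the output shape concrete, the sketch below computes the total width and per-variable column ranges from labels and cardinalities. It assumes that columns follow the order of the labels and that each variable occupies as many columns as its cardinality; the source states only the total width, and variable_slices is a hypothetical helper, not part of the library.

```python
from typing import Dict, List


def variable_slices(labels: List[str], cardinalities: List[int]) -> Dict[str, slice]:
    """Map each variable name to its column range in the concatenated output."""
    slices = {}
    start = 0
    for name, width in zip(labels, cardinalities):
        slices[name] = slice(start, start + width)
        start += width
    return slices


slices = variable_slices(["c1", "task"], [2, 1])
total_width = sum([2, 1])  # second dimension of the output tensor

# One row of a concatenated output, shown as a plain list for illustration.
row = [0.3, -1.2, 0.8]
c1_values = row[slices["c1"]]
task_values = row[slices["task"]]
```

The same slice objects can index the last dimension of a tensor, for example out[:, slices["c1"]].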
- filter_output_for_loss(forward_out, target)
No filtering needed - return raw endogenous for standard loss computation.
- Parameters:
forward_out – Model output endogenous.
target – Ground truth labels.
- Returns:
Dict with ‘input’ and ‘target’ for loss computation.
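One plausible reason for the 'input' and 'target' keys is that PyTorch loss modules use the same argument names, so the dict can be unpacked directly into a loss call. The sketch below uses a stand-in function with the documented return shape; whether the library's training loop consumes the dict this way is an assumption.

```python
import torch
from torch import nn


def filter_output_for_loss(forward_out: torch.Tensor, target: torch.Tensor) -> dict:
    """Stand-in mirroring the documented return value: no filtering applied."""
    return {"input": forward_out, "target": target}


loss_fn = nn.BCEWithLogitsLoss()
logits = torch.zeros(2, 3)  # placeholder model output
labels = torch.ones(2, 3)   # placeholder ground truth

batch = filter_output_for_loss(logits, labels)
# torch loss modules name their arguments `input` and `target`,
# so the dict can be unpacked straight into the call.
loss = loss_fn(**batch)
```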
- class BlackBox(input_size: int, annotations: Annotations, variable_distributions: Mapping | None = None, loss: Module | None = None, metrics: Mapping | None = None, inference: bool = False, **kwargs)
Bases: BaseModel, JointLearner
BlackBox model.
This model implements a standard neural network architecture for concept-based tasks, without an explicit concept bottleneck or interpretable intermediate representations. It uses a backbone for feature extraction and a latent encoder for concept prediction.
- Parameters:
input_size (int) – Dimensionality of input features.
annotations (Annotations) – Annotation object for output variables.
loss (nn.Module, optional) – Loss function for training.
metrics (Mapping, optional) – Metrics for evaluation.
backbone (nn.Module, optional) – Feature extraction module.
latent_encoder (nn.Module, optional) – Latent encoder module.
latent_encoder_kwargs (dict, optional) – Arguments for latent encoder.
**kwargs – Additional arguments for BaseModel.
Example
>>> model = BlackBox(input_size=8, annotations=ann)
>>> out = model(torch.randn(2, 8))
- forward(x: Tensor, query: List[str] | None = None) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
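The difference described in the note can be observed with a forward hook on any torch.nn.Module. The sketch below uses a plain nn.Linear rather than BlackBox, purely as an illustration of the general PyTorch behavior.

```python
import torch
from torch import nn

layer = nn.Linear(8, 3)
calls = []

# A forward hook runs only when the module instance itself is called.
handle = layer.register_forward_hook(
    lambda module, inputs, output: calls.append(output.shape)
)

x = torch.randn(2, 8)
layer(x)          # goes through __call__, so the hook fires
layer.forward(x)  # bypasses __call__, so the hook is skipped
handle.remove()
```

After both calls, calls holds a single entry, recorded by the first call only.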
- filter_output_for_loss(forward_out, target)
No filtering needed - return raw endogenous for standard loss computation.
- Parameters:
forward_out – Model output endogenous.
target – Ground truth labels.
- Returns:
Dict with ‘input’ and ‘target’ for loss computation.