# Source code for torch_concepts.nn.modules.high.models.blackbox

import torch
from torch import nn
from typing import List, Optional, Union

from .....utils import ensure_list
from .....annotations import Annotations
from ...metrics import ConceptMetrics
from ...loss import ConceptLoss
from ...outputs import ModelOutput, logits_from_params

from ...low.dense_layers import MLP
from ..base.model import BaseModel


class BlackBox(BaseModel):
    """Black-box baseline model without a concept bottleneck.

    This model implements a standard neural network architecture for
    concept-based tasks, without explicit concept bottleneck or interpretable
    intermediate representations. It uses a backbone mapping the raw input to
    the latent representation, then a single linear head whose output width is
    the sum of the cardinalities of all annotated concepts.

    Args:
        input_size (int): Dimensionality of input features.
        annotations (Annotations): Annotation object for output variables.
        lightning (bool, optional): Enable Lightning training. Default False.
        **kwargs: Additional arguments for BaseModel.

    Example:
        >>> from torch_concepts.annotations import Annotations
        >>> ann = Annotations(labels=['c1', 'task'], cardinalities=[1, 1])
        >>> model = BlackBox(input_size=8, annotations=ann)
        >>> out = model(torch.randn(2, 8))
    """

    def __init__(
        self,
        input_size: int,
        annotations: Annotations,
        lightning: bool = False,
        **kwargs
    ) -> None:
        """Initialise the base model and attach the linear prediction head."""
        super().__init__(
            input_size=input_size,
            annotations=annotations,
            lightning=lightning,
            **kwargs
        )

        # One output unit per logit: sum of the cardinalities of all concepts.
        output_size = sum(self.concept_annotations.cardinalities)
        self.linear = nn.Linear(self.latent_size, output_size)

    def build_query(self, ground_truth) -> dict:
        """Build query dict mapping each concept name to its ground-truth column.

        Parameters
        ----------
        ground_truth : torch.Tensor
            Full concept-level ground truth, shape ``(batch, n_concepts)``.
            When ``None``, every concept name is mapped to ``None``.

        Returns
        -------
        dict
            ``{concept_name: tensor(batch, cardinality)}`` for every concept.
        """
        if ground_truth is None:
            return {name: None for name in self.concept_names}

        axis = self.concept_annotations
        query = {}
        for col, name in enumerate(axis.labels):
            cardinality = axis.concept(name).cardinality
            column = ground_truth[:, col]
            if cardinality == 1:
                # Single-unit concept: keep the raw value as a (batch, 1) float.
                query[name] = column.float().unsqueeze(-1)
            else:
                # Multi-class concept: expand the class index to one-hot.
                query[name] = nn.functional.one_hot(
                    column.long(), cardinality
                ).float()
        return query

    def forward(
        self,
        x: Optional[torch.Tensor] = None,
        query: Optional[Union[List[str], str, dict]] = None,
        evidence: Optional[Union[dict, torch.Tensor]] = None,
        **kwargs
    ) -> ModelOutput:
        """Forward pass through the BlackBox model.

        Parameters
        ----------
        x : torch.Tensor, optional
            Input tensor. When ``None``, the tensor is extracted from
            ``evidence['input']`` (used by :meth:`BaseLearner.shared_step`).
        query : list of str, str or dict, optional
            Concept names to return. Defaults to all concepts. A single
            string is treated as one concept name. When a dict is supplied
            (from ``build_query``), the keys are used as names; an empty
            dict falls back to all concepts.
        evidence : dict or torch.Tensor, optional
            Evidence dict (``{'input': x}`` from shared_step) or raw tensor
            (ignored for BlackBox).
        **kwargs
            Additional arguments (ignored).

        Returns
        -------
        ModelOutput
            ``params[name]['logits']`` per queried concept (uniform with the
            PGM-based models).
        """
        # Resolve the raw input tensor
        if x is None and isinstance(evidence, dict):
            x = evidence.get('input', None)

        output = self.linear(self.backbone(x))
        axis = self.concept_annotations

        # query may be None, a single name, a dict (from build_query),
        # or a list of names.
        if query is None:
            names = axis.labels
        elif isinstance(query, str):
            # A bare string would otherwise be iterated character by character.
            names = [query]
        elif isinstance(query, dict):
            names = list(query.keys()) if query else axis.labels
        else:
            names = query

        params = {
            name: {"logits": output[:, axis.concept_slices[name]]}
            for name in names
        }
        out = ModelOutput(params=params)
        # FIXME: update ModelOutput to generalize beyond logits
        out.logits = logits_from_params(params, keys=list(names))
        return out


class BlackBoxTaskOnly(BaseModel):
    """Black-box baseline model that predicts task variables only.

    This model implements a standard neural network architecture for
    predicting tasks only, without explicit concept bottleneck or
    interpretable intermediate representations. It uses a backbone mapping
    the raw input to the latent representation, then a linear head whose
    output width is the sum of the cardinalities of the task variables.

    Args:
        input_size (int): Dimensionality of input features.
        annotations (Annotations): Annotation object for output variables.
        task_names (Union[List[str], str]): Task names to predict.
        lightning (bool, optional): Enable Lightning training. Default False.
        **kwargs: Additional arguments for BaseModel.

    Attributes:
        task_annotations (Annotations): Sub-annotation restricted to task
            concepts only. Use this to build ``ConceptLoss`` /
            ``ConceptMetrics``.
        task_concept_idx (List[int]): Concept-level column indices used to
            slice the ground-truth target tensor to match the task-only
            output.

    Example:
        >>> from torch_concepts.annotations import Annotations
        >>> ann = Annotations(labels=['c1', 'task'], cardinalities=[1, 1])
        >>> model = BlackBoxTaskOnly(input_size=8, annotations=ann, task_names=['task'])
        >>> out = model(torch.randn(2, 8))
    """

    def __init__(
        self,
        input_size: int,
        annotations: Annotations,
        task_names: Union[List[str], str],
        lightning: bool = False,
        **kwargs
    ) -> None:
        """Initialise the base model, task-only loss and linear head."""
        self.task_names = ensure_list(task_names)

        # Pre-compute task annotations before super().__init__ so that
        # setup_metrics (called by BaseLearner.__init__) can use them.
        self.task_annotations = annotations.subset(self.task_names)
        self.task_concept_idx = [
            annotations.get_index(name) for name in self.task_names
        ]

        super().__init__(
            input_size=input_size,
            annotations=annotations,
            lightning=lightning,
            **kwargs
        )

        # Rebuild loss with task-only annotations so index slicing matches
        # the task-only tensors produced by prepare_target.
        previous_loss = getattr(self, 'loss', None)
        if isinstance(previous_loss, ConceptLoss):
            loss_fns = previous_loss.fn_collection
            loss_weights = previous_loss._type_weights
            self.loss = ConceptLoss(
                annotations=self.task_annotations,
                binary=loss_fns.get('binary'),
                categorical=loss_fns.get('categorical'),
                continuous=loss_fns.get('continuous'),
                binary_weights=loss_weights.get('binary'),
                categorical_weights=loss_weights.get('categorical'),
                continuous_weights=loss_weights.get('continuous'),
            )

        # Logit-level output size from the task sub-annotation
        output_size = sum(self.task_annotations.cardinalities)
        self.linear = nn.Linear(self.latent_size, output_size)

    def build_query(self, ground_truth) -> dict:
        """Build query dict mapping each *task* name to its ground-truth column.

        Parameters
        ----------
        ground_truth : torch.Tensor
            Full concept-level ground truth, shape ``(batch, n_all_concepts)``.
            When ``None``, every task name is mapped to ``None``.

        Returns
        -------
        dict
            ``{task_name: tensor(batch, cardinality)}`` for every task.
        """
        if ground_truth is None:
            return {name: None for name in self.task_names}

        axis = self.concept_annotations
        query = {}
        # task_concept_idx holds each task's column in the full target tensor.
        for col, name in zip(self.task_concept_idx, self.task_names):
            cardinality = axis.concept(name).cardinality
            column = ground_truth[:, col]
            if cardinality == 1:
                # Single-unit task: keep the raw value as a (batch, 1) float.
                query[name] = column.float().unsqueeze(-1)
            else:
                # Multi-class task: expand the class index to one-hot.
                query[name] = nn.functional.one_hot(
                    column.long(), cardinality
                ).float()
        return query

    def forward(
        self,
        x: Optional[torch.Tensor] = None,
        query=None,
        evidence: Optional[Union[dict, torch.Tensor]] = None,
        **kwargs
    ) -> ModelOutput:
        """Forward pass through the BlackBoxTaskOnly model.

        Parameters
        ----------
        x : torch.Tensor, optional
            Input tensor. When ``None``, the tensor is extracted from
            ``evidence['input']`` (used by :meth:`BaseLearner.shared_step`).
        query : list of str or dict, optional
            Ignored; predictions are always returned for ``task_names``.
        evidence : dict or torch.Tensor, optional
            Evidence dict (``{'input': x}`` from shared_step) or raw tensor
            (ignored).
        **kwargs
            Additional arguments (ignored).

        Returns
        -------
        ModelOutput
            ``params[name]['logits']`` per task (uniform with the PGM-based
            models).
        """
        # Resolve the raw input tensor
        if x is None and isinstance(evidence, dict):
            x = evidence.get('input', None)

        output = self.linear(self.backbone(x))

        # The linear head spans the task sub-annotation; slice it per task.
        slices = self.task_annotations.concept_slices
        params = {
            name: {"logits": output[:, slices[name]]}
            for name in self.task_names
        }
        out = ModelOutput(params=params)
        # FIXME: update ModelOutput to generalize beyond logits
        out.logits = logits_from_params(params, keys=list(self.task_names))
        return out

    def prepare_target(self, target: torch.Tensor) -> torch.Tensor:
        """Slice target to task-only columns.

        Parameters
        ----------
        target : torch.Tensor
            Full concept-level ground truth labels.

        Returns
        -------
        torch.Tensor
            Target sliced to task columns only.
        """
        return target[:, self.task_concept_idx]

    def setup_metrics(self, metrics: ConceptMetrics):
        """Rebuild metrics with task-only annotations.

        The base ``setup_metrics`` clones the original ``ConceptMetrics``
        which was constructed with the *full* concept annotations. Because
        ``BlackBoxTaskOnly`` outputs only task logits, the internal index
        mappings would be misaligned. This override reconstructs the
        metrics using ``task_annotations`` so that indices match the
        task-only output.
        """
        metric_fns = metrics.fn_collection
        task_metrics = ConceptMetrics(
            annotations=self.task_annotations,
            binary=metric_fns.get('binary'),
            categorical=metric_fns.get('categorical'),
            continuous=metric_fns.get('continuous'),
            summary=metrics.summary,
            per_concept=metrics.per_concept,
        )
        super().setup_metrics(task_metrics)