# Source code for torch_concepts.nn.modules.low.sequential
"""
Annotated sequential container for concept-based pipelines.
"""
from typing import Optional, Union
import torch
from torch_concepts.annotations import Annotations
from torch_concepts.tensor import AnnotatedTensor
# [docs] (documentation-page link marker left over from the HTML export; not part of the module code)
class Sequential(torch.nn.Sequential):
    """``nn.Sequential`` whose **first** module may take multiple inputs.

    Standard ``nn.Sequential`` threads one tensor through the chain, so its first
    layer cannot be a multi-input PyC layer such as
    :class:`~torch_concepts.nn.MixConceptEmbeddingToConcept` (``forward(concepts,
    embeddings)``). This subclass forwards **all** of its inputs to the first
    module, then threads that module's single output through the rest — while a
    single-tensor ``seq(x)`` still behaves exactly like ``nn.Sequential``.

    If ``out_concepts`` (an :class:`~torch_concepts.Annotations`) is set,
    :meth:`annotate` wraps an output in an
    :class:`~torch_concepts.tensor.AnnotatedTensor` to label its columns.
    """

    def __init__(self, *args, out_concepts: Optional[Annotations] = None, **kwargs):
        """Build the container and remember the optional output annotations.

        Args:
            *args: Forwarded unchanged to ``torch.nn.Sequential.__init__``.
            out_concepts: Annotations used by :meth:`annotate` when the caller
                does not pass any explicitly. Stored as ``self.out_concepts``.
            **kwargs: Forwarded unchanged to ``torch.nn.Sequential.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.out_concepts = out_concepts

    def forward(self, *args, **kwargs):
        """Run the chain, giving every input to the first module only.

        Args:
            *args: Positional inputs passed to the first module.
            **kwargs: Keyword inputs passed to the first module.

        Returns:
            The output of the last module. For an empty container, the single
            positional input is returned unchanged when exactly one positional
            argument and no keyword arguments were given; otherwise ``None``.
        """
        # Decide "empty container" from the container itself rather than by
        # catching StopIteration around the first call: a StopIteration raised
        # *inside* the first module must propagate instead of being mistaken
        # for an empty chain.
        if len(self) == 0:
            return args[0] if len(args) == 1 and not kwargs else None

        modules = iter(self)
        output = next(modules)(*args, **kwargs)  # first layer takes all inputs
        for module in modules:
            output = module(output)  # the rest are single-tensor
        return output

    def annotate(self, x, out_concepts: Optional[Annotations] = None) -> AnnotatedTensor:
        """Wrap ``x`` in an ``AnnotatedTensor`` when annotations are available.

        Args:
            x: The value to annotate.
            out_concepts: Annotations to attach. When ``None``, falls back to
                ``self.out_concepts`` if that is an ``Annotations`` instance.

        Returns:
            ``AnnotatedTensor(x, out_concepts)`` when annotations were given or
            stored on the container; otherwise ``x`` itself, unchanged.
        """
        if out_concepts is None:
            if not isinstance(self.out_concepts, Annotations):
                return x  # nothing to label with: hand the input back as-is
            out_concepts = self.out_concepts
        return AnnotatedTensor(x, out_concepts)