Source code for torch_concepts.nn.modules.low.predictors.linear
"""
Linear predictor modules for concept-based models.
This module provides linear prediction layers that transform concept
representations into new concept representations using a linear layer.
"""
from typing import Union
import torch
from torch_concepts import Annotations
from ..base.layer import BaseConceptLayer
from ....functional import prune_linear_layer
[docs]
class LinearConceptToConcept(BaseConceptLayer):
"""
Linear concept predictor.
This predictor transforms input concept representations into other concept
representations using a linear layer.
Attributes:
in_concepts (int): Number of input concept representations.
out_concepts (int): Number of output concept representations.
predictor (nn.Sequential): The prediction network.
Args:
in_concepts: Number of input concept representations.
out_concepts: Number of output concept representations.
Example:
>>> import torch
>>> from torch_concepts.nn import LinearConceptToConcept
>>>
>>> # Create predictor
>>> predictor = LinearConceptToConcept(
... in_concepts=10,
... out_concepts=5
... )
>>>
>>> # Forward pass
>>> in_concepts = torch.rand(2, 10) # batch_size=2, in_concepts=10
>>> out_concepts = predictor(in_concepts)
>>> print(out_concepts.shape)
torch.Size([2, 5])
References:
Koh et al. "Concept Bottleneck Models", ICML 2020.
https://arxiv.org/pdf/2007.04612
"""
[docs]
def __init__(
self,
in_concepts: Union[int, Annotations],
out_concepts: Union[int, Annotations],
*args,
**kwargs,
):
"""
Initialize the predictor.
Args:
in_concepts: Number of input concept representations.
out_concepts: Number of output concept representations.
"""
super().__init__(
in_concepts=in_concepts,
out_concepts=out_concepts,
)
self.predictor = torch.nn.Linear(
in_concepts,
out_concepts,
*args,
**kwargs,
)
def forward(
self,
concepts: torch.Tensor
) -> torch.Tensor:
"""
Forward pass through the predictor.
Args:
concepts: Input concepts of shape (..., in_concepts).
Returns:
torch.Tensor: Predicted concept probabilities of shape (..., out_concepts).
"""
return self.predictor(concepts)
def prune(self, mask: torch.Tensor):
"""
Prune input features based on a binary mask.
Removes input features where mask is False/0, reducing model complexity.
Args:
mask: Binary mask of shape (in_concepts,) indicating which
features to keep (True/1) or remove (False/0).
Example:
>>> import torch
>>> from torch_concepts.nn import LinearConceptToConcept
>>>
>>> predictor = LinearConceptToConcept(in_concepts=10, out_concepts=5)
>>>
>>> # Prune first 3 features
>>> mask = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=torch.bool)
>>> predictor.prune(mask)
>>>
>>> # Now only accepts 7 input features
>>> concepts = torch.randn(2, 7)
>>> probs = predictor(concepts)
>>> print(probs.shape)
torch.Size([2, 5])
"""
self.in_concepts = sum(mask.int())
self.predictor = prune_linear_layer(self.predictor, mask, dim=0)