Source code for torch_concepts.nn.modules.low.base.inference
"""
Base inference and intervention classes for concept-based models.
This module provides abstract base classes for implementing inference mechanisms
and intervention strategies in concept-based models.
"""
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
[docs]
class BaseInference(torch.nn.Module):
"""
Abstract base class for inference modules.
Inference modules define how to query concept-based models to obtain
concept predictions, supporting various inference strategies such as
forward inference, ancestral sampling, or stochastic inference.
Example:
>>> import torch
>>> from torch_concepts.nn import BaseInference
>>>
>>> # Create a custom inference class
>>> class SimpleInference(BaseInference):
... def __init__(self, model):
... super().__init__()
... self.model = model
...
... def query(self, x, **kwargs):
... # Simple forward pass through model
... return self.model(x)
>>>
>>> # Example usage
>>> dummy_model = torch.nn.Linear(10, 5)
>>> inference = SimpleInference(dummy_model)
>>>
>>> # Generate random input
>>> x = torch.randn(2, 10) # batch_size=2, input_features=10
>>>
>>> # Query concepts using forward method
>>> concepts = inference(x)
>>> print(concepts.shape) # torch.Size([2, 5])
>>>
>>> # Or use query method directly
>>> concepts = inference.query(x)
>>> print(concepts.shape) # torch.Size([2, 5])
"""
[docs]
def __init__(self):
"""Initialize the inference module."""
super(BaseInference, self).__init__()
[docs]
def forward(self,
x: torch.Tensor,
*args,
**kwargs) -> torch.Tensor:
"""
Forward pass delegates to the query method.
Args:
x: Input tensor.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
torch.Tensor: Queried concepts.
"""
return self.query(x, *args, **kwargs)
[docs]
@abstractmethod
def query(self,
*args,
**kwargs) -> torch.Tensor:
"""
Query model to get concepts.
This method must be implemented by subclasses to define the
specific inference strategy.
Args:
*args: Variable length argument list (typically includes input x).
**kwargs: Arbitrary keyword arguments (may include intervention c).
Returns:
torch.Tensor: Queried concept predictions.
Raises:
NotImplementedError: This is an abstract method.
"""
raise NotImplementedError
[docs]
class BaseIntervention(BaseInference, ABC):
"""
Abstract base class for intervention modules.
Intervention modules modify concept-based models by replacing certain
modules, enabling causal reasoning and what-if analysis.
This class provides a framework for implementing different intervention
strategies on concept-based models.
Attributes:
model (nn.Module): The concept-based model to apply interventions to.
Args:
model: The neural network model to intervene on.
Example:
>>> import torch
>>> import torch.nn as nn
>>> from torch_concepts.nn import BaseIntervention
>>>
>>> # Create a custom intervention class
>>> class CustomIntervention(BaseIntervention):
... def query(self, module_name, **kwargs):
... # Get the module to intervene on
... module = self.model.get_submodule(module_name)
... # Apply intervention logic
... return module(**kwargs)
>>>
>>> # Create a simple concept model
>>> class ConceptModel(nn.Module):
... def __init__(self):
... super().__init__()
... self.encoder = nn.Linear(10, 5)
... self.predictor = nn.Linear(5, 3)
...
... def forward(self, x):
... concepts = torch.sigmoid(self.encoder(x))
... return self.predictor(concepts)
>>>
>>> # Example usage
>>> model = ConceptModel()
>>> intervention = CustomIntervention(model)
>>>
>>> # Generate random input
>>> x = torch.randn(2, 10) # batch_size=2, input_features=10
>>>
>>> # Query encoder module
>>> encoder_output = intervention.query('encoder', input=x)
>>> print(encoder_output.shape) # torch.Size([2, 5])
"""
[docs]
def __init__(self, model: nn.Module):
"""
Initialize the intervention module.
Args:
model (nn.Module): The concept-based model to apply interventions to.
"""
super().__init__()
self.model = model