torch_concepts.nn.functional.linear_equation_eval
- linear_equation_eval(concept_weights: Tensor, c_pred: Tensor, bias: Tensor | None = None) → Tensor
Evaluate a set of linear equations on concept predictions. Each sample in the batch has its own equations, given by that sample's slice of concept_weights.
- Parameters:
concept_weights – Parameters representing the weights of multiple linear models with shape (batch_size, memory_size, n_concepts, n_classes).
c_pred – Concept predictions with shape (batch_size, n_concepts).
bias – Optional bias term to add to the linear models, with shape (batch_size, memory_size, n_classes). Defaults to None.
- Returns:
- Predictions made by the linear models with shape (batch_size, n_classes, memory_size).
- Return type:
Tensor
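
Example: the shapes documented above suggest a contraction over the n_concepts axis, with the class and memory axes swapped in the output. The following is a minimal sketch of that computation in plain PyTorch, not the library's actual implementation. The helper name linear_equation_eval_sketch is hypothetical, and the way the bias is applied (transposing its last two axes before adding) is an assumption inferred from the documented shapes.

```python
import torch


def linear_equation_eval_sketch(concept_weights, c_pred, bias=None):
    """Hypothetical stand-in inferred from the documented shapes only."""
    # concept_weights: (batch_size, memory_size, n_concepts, n_classes)
    # c_pred:          (batch_size, n_concepts)
    # Sum over the concept axis; output is (batch_size, n_classes, memory_size).
    y_pred = torch.einsum("bmcy,bc->bym", concept_weights, c_pred)
    if bias is not None:
        # Assumption: bias is (batch_size, memory_size, n_classes), so its
        # last two axes are swapped to line up with the output layout.
        y_pred = y_pred + bias.transpose(-1, -2)
    return y_pred


torch.manual_seed(0)
batch_size, memory_size, n_concepts, n_classes = 4, 2, 5, 3
concept_weights = torch.randn(batch_size, memory_size, n_concepts, n_classes)
c_pred = torch.rand(batch_size, n_concepts)
bias = torch.randn(batch_size, memory_size, n_classes)

y_pred = linear_equation_eval_sketch(concept_weights, c_pred, bias)
print(y_pred.shape)  # torch.Size([4, 3, 2])
```

Under these assumptions, y_pred[b, y, m] equals the dot product of concept_weights[b, m, :, y] with c_pred[b], plus bias[b, m, y] when a bias is given.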