torch_concepts.nn.functional.linear_equation_eval

linear_equation_eval(concept_weights: Tensor, c_pred: Tensor, bias: Tensor | None = None) → Tensor

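Evaluates a set of linear equations on concept predictions. Each sample in the batch has its own equations, given by its slice of concept_weights: memory_size weight matrices, each mapping the sample's n_concepts concept predictions to n_classes outputs.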
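Written out element-wise, the documented shapes imply the following computation. This is a reading of the shapes rather than a statement of the library's code; here W stands for concept_weights, c for c_pred and β for bias (taken as zero when bias is None), with b indexing the batch, m the memory slot, j the concept and k the class:

$$
\hat{y}_{b,k,m} = \sum_{j=1}^{n_\text{concepts}} W_{b,m,j,k}\, c_{b,j} + \beta_{b,m,k}
$$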

Parameters:
  • concept_weights – Parameters representing the weights of multiple linear models with shape (batch_size, memory_size, n_concepts, n_classes).

  • c_pred – Concept predictions with shape (batch_size, n_concepts).

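  • bias – Optional bias term added to the outputs of the linear models, with shape (batch_size, memory_size, n_classes). Its last two axes are ordered differently from those of the returned tensor. Defaults to None (no bias).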

Returns:

Predictions made by the linear models with shape (batch_size, n_classes, memory_size).

Return type:

Tensor
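The shape contract above can be reproduced in plain PyTorch. This is a minimal sketch, not the library's implementation: the name linear_equation_eval_sketch is hypothetical, and contracting with torch.einsum using the subscripts "bmcy,bc->bym" (then transposing the bias) is an assumption chosen to match the documented input and output shapes.

```python
from typing import Optional

import torch


def linear_equation_eval_sketch(
    concept_weights: torch.Tensor,   # (batch_size, memory_size, n_concepts, n_classes)
    c_pred: torch.Tensor,            # (batch_size, n_concepts)
    bias: Optional[torch.Tensor] = None,  # (batch_size, memory_size, n_classes)
) -> torch.Tensor:
    """Hypothetical re-implementation of the documented behaviour (sketch only)."""
    if concept_weights.shape[-2] != c_pred.shape[-1]:
        raise ValueError("n_concepts differs between concept_weights and c_pred")
    # For every sample b, memory slot m and class y, sum over concepts c.
    # Result is laid out as (batch_size, n_classes, memory_size).
    y_pred = torch.einsum("bmcy,bc->bym", concept_weights, c_pred)
    if bias is not None:
        # bias is (b, m, y); swap its last two axes to align with (b, y, m).
        y_pred = y_pred + bias.permute(0, 2, 1)
    return y_pred


batch_size, memory_size, n_concepts, n_classes = 4, 3, 5, 2
W = torch.randn(batch_size, memory_size, n_concepts, n_classes)
c = torch.rand(batch_size, n_concepts)
b = torch.randn(batch_size, memory_size, n_classes)
y = linear_equation_eval_sketch(W, c, b)
print(y.shape)  # torch.Size([4, 2, 3])
```

Because the bias is stored as (batch_size, memory_size, n_classes) while the output is (batch_size, n_classes, memory_size), the sketch permutes the bias before adding it; adding it directly would fail to broadcast unless the two trailing sizes happened to coincide.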