torch_concepts.nn.functional.logic_memory_reconstruction

logic_memory_reconstruction(concept_weights: Tensor, c_true: Tensor, y_true: Tensor) → Tensor

Reconstruct task predictions from concept reconstructions, ground-truth concepts and ground-truth tasks.

Parameters:
  • concept_weights – concept reconstructions with shape (batch_size, memory_size, n_concepts, n_tasks).

  • c_true – concept ground truth with shape (batch_size, n_concepts).

  • y_true – task ground truth with shape (batch_size, n_tasks).
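The three inputs must agree on batch_size, n_concepts and n_tasks. The following sketch checks that contract before the tensors are passed to the function. `check_reconstruction_inputs` is a hypothetical helper, not part of torch_concepts. It only encodes the shapes listed in the parameters above.

```python
import torch


def check_reconstruction_inputs(
    concept_weights: torch.Tensor,
    c_true: torch.Tensor,
    y_true: torch.Tensor,
) -> tuple[int, int, int, int]:
    """Hypothetical helper (not a torch_concepts API): validate the documented
    shapes and return (batch_size, memory_size, n_concepts, n_tasks)."""
    if concept_weights.dim() != 4:
        raise ValueError(
            "concept_weights must be 4-D: (batch_size, memory_size, n_concepts, n_tasks)"
        )
    batch_size, memory_size, n_concepts, n_tasks = concept_weights.shape
    # c_true must share batch_size and n_concepts with concept_weights.
    if tuple(c_true.shape) != (batch_size, n_concepts):
        raise ValueError(
            f"c_true must have shape ({batch_size}, {n_concepts}), got {tuple(c_true.shape)}"
        )
    # y_true must share batch_size and n_tasks with concept_weights.
    if tuple(y_true.shape) != (batch_size, n_tasks):
        raise ValueError(
            f"y_true must have shape ({batch_size}, {n_tasks}), got {tuple(y_true.shape)}"
        )
    return batch_size, memory_size, n_concepts, n_tasks


# Example inputs: 4 samples, 5 memory slots, 3 concepts, 2 tasks.
concept_weights = torch.rand(4, 5, 3, 2)
c_true = torch.randint(0, 2, (4, 3)).float()
y_true = torch.randint(0, 2, (4, 2)).float()
print(check_reconstruction_inputs(concept_weights, c_true, y_true))  # (4, 5, 3, 2)
```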

Returns:

Reconstructed tasks with shape (batch_size, n_tasks, memory_size).

Return type:

torch.Tensor
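This page does not say how the three inputs are combined. The sketch below uses one hypothetical rule and is not the torch_concepts implementation. It treats each entry of concept_weights as the probability that a memorized rule expects a concept to be true. It multiplies the per-concept agreement with c_true across concepts, then scores the result against y_true. Its main purpose is to show the shape bookkeeping, including the final permute that moves memory_size to the last axis, as the documented return shape requires. `sketch_logic_memory_reconstruction` is a hypothetical name.

```python
import torch


def sketch_logic_memory_reconstruction(
    concept_weights: torch.Tensor,  # (batch_size, memory_size, n_concepts, n_tasks)
    c_true: torch.Tensor,           # (batch_size, n_concepts)
    y_true: torch.Tensor,           # (batch_size, n_tasks)
) -> torch.Tensor:
    # HYPOTHETICAL combination rule, assumed for illustration only.
    c = c_true[:, None, :, None]                  # (batch, 1, n_concepts, 1)
    # Agreement between each rule's expected concept value and the observed one.
    match = concept_weights * c + (1 - concept_weights) * (1 - c)
    rule_score = match.prod(dim=2)                # (batch, memory_size, n_tasks)
    y = y_true[:, None, :]                        # (batch, 1, n_tasks)
    # Likelihood of the observed task label under each rule's score.
    recon = rule_score * y + (1 - rule_score) * (1 - y)
    return recon.permute(0, 2, 1)                 # (batch, n_tasks, memory_size)


batch_size, memory_size, n_concepts, n_tasks = 4, 5, 3, 2
torch.manual_seed(0)
concept_weights = torch.rand(batch_size, memory_size, n_concepts, n_tasks)
c_true = torch.randint(0, 2, (batch_size, n_concepts)).float()
y_true = torch.randint(0, 2, (batch_size, n_tasks)).float()
out = sketch_logic_memory_reconstruction(concept_weights, c_true, y_true)
print(out.shape)  # torch.Size([4, 2, 5])
```

Because the inputs lie in [0, 1], every intermediate product stays in [0, 1]. The output can therefore be read as a per-rule, per-task score under this assumed rule.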