torch_concepts.nn.functional.logic_rule_eval

logic_rule_eval(concept_weights: torch.Tensor, c_pred: torch.Tensor, memory_idxs: torch.Tensor | None = None, semantic=<torch_concepts.nn.modules.low.semantic.CMRSemantic object>) → torch.Tensor

Evaluate the logic rules encoded by concept_weights on the concept predictions c_pred, producing one prediction per task and per rule.

Parameters:
  • concept_weights – Concept weights with shape (batch_size, memory_size, n_concepts, n_tasks, n_roles), where n_roles = 3.

  • c_pred – Concept predictions with shape (batch_size, n_concepts).

  • memory_idxs – Indices of the rules to evaluate, with shape (batch_size, n_tasks). Defaults to None, in which case all rules are evaluated.

  • semantic – Semantic function to use for rule evaluation.

Returns:
  Rule predictions with shape (batch_size, n_tasks, memory_size).

Return type:
  torch.Tensor
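
A minimal usage sketch of the shapes described above. The tensor sizes, the softmax normalisation over the role axis, and the helper name evaluate_rules are assumptions made for illustration; they are not taken from the library. The call itself relies only on the documented defaults for memory_idxs and semantic.

```python
import torch

# Illustrative sizes (assumptions, not library defaults).
batch_size, memory_size, n_concepts, n_tasks, n_roles = 4, 2, 5, 3, 3

# Concept weights: one score per (sample, rule, concept, task, role).
# Normalising over the role axis with softmax is an assumption made here
# so that the three role scores of each concept sum to 1.
concept_weights = torch.softmax(
    torch.randn(batch_size, memory_size, n_concepts, n_tasks, n_roles),
    dim=-1,
)

# Concept predictions, drawn here as random values in [0, 1).
c_pred = torch.rand(batch_size, n_concepts)


def evaluate_rules(concept_weights, c_pred):
    """Hypothetical helper: call logic_rule_eval with its default arguments."""
    # Imported lazily so the shape setup above runs even without the library.
    from torch_concepts.nn.functional import logic_rule_eval

    return logic_rule_eval(concept_weights, c_pred)


# With torch_concepts installed:
# rule_preds = evaluate_rules(concept_weights, c_pred)
# According to the documentation above, rule_preds has shape
# (batch_size, n_tasks, memory_size).
```

Leaving memory_idxs at None evaluates every rule in the memory. To evaluate a selection of rules instead, pass an index tensor of shape (batch_size, n_tasks) as the third argument.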