torch_concepts.nn.functional.logic_rule_eval
- logic_rule_eval(concept_weights: torch.Tensor, c_pred: torch.Tensor, memory_idxs: torch.Tensor | None = None, semantic=<torch_concepts.nn.modules.low.semantic.CMRSemantic object>) → Tensor
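Use concept weights to make predictions based on logic rules: each rule stored in memory is evaluated on the concept predictions c_pred, separately for every task.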
- Parameters:
concept_weights – Concept weights with shape (batch_size, memory_size, n_concepts, n_tasks, n_roles), where n_roles=3.
c_pred – Concept predictions with shape (batch_size, n_concepts).
memory_idxs – Indices of the rules to evaluate, with shape (batch_size, n_tasks). Defaults to None, which evaluates all rules.
semantic – Semantic used for rule evaluation. Defaults to a CMRSemantic instance.
- Returns:
Rule predictions with shape (batch_size, n_tasks, memory_size).
- Return type:
Tensor
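To make the documented shapes concrete, here is a minimal NumPy sketch of rule evaluation in the style of Concept-based Memory Reasoner (CMR) rules. It is not the library's implementation: `sketch_rule_eval` is a hypothetical helper. It assumes that the role axis is ordered (positive, negative, irrelevant) and that the role weights sum to 1. It also uses a product over concepts in place of the `semantic` argument and ignores `memory_idxs`.

```python
import numpy as np


def sketch_rule_eval(concept_weights: np.ndarray, c_pred: np.ndarray) -> np.ndarray:
    """Hypothetical illustration, not torch_concepts' logic_rule_eval.

    concept_weights: (batch_size, memory_size, n_concepts, n_tasks, 3);
        role axis ASSUMED ordered as (positive, negative, irrelevant).
    c_pred: (batch_size, n_concepts), concept probabilities in [0, 1].
    Returns: (batch_size, n_tasks, memory_size), one truth degree per rule.
    """
    # Split the role axis; each has shape (batch_size, memory_size, n_concepts, n_tasks).
    pos = concept_weights[..., 0]
    neg = concept_weights[..., 1]
    irr = concept_weights[..., 2]
    # Broadcast concepts to (batch_size, 1, n_concepts, 1) to line up with the roles.
    c = c_pred[:, None, :, None]
    # Literal truth degree: positive -> c, negative -> 1 - c, irrelevant -> 1.
    literal = pos * c + neg * (1.0 - c) + irr
    # Conjunction over concepts via a product: (batch_size, memory_size, n_tasks).
    rule = literal.prod(axis=2)
    # Reorder to the documented layout (batch_size, n_tasks, memory_size).
    return rule.transpose(0, 2, 1)


# Usage: one sample, two rules in memory, two concepts, one task.
weights = np.zeros((1, 2, 2, 1, 3))
weights[0, 0, 0, 0, 0] = 1.0  # rule 0: concept 0 is positive
weights[0, 0, 1, 0, 1] = 1.0  # rule 0: concept 1 is negative
weights[0, 1, 0, 0, 2] = 1.0  # rule 1: concept 0 is irrelevant
weights[0, 1, 1, 0, 0] = 1.0  # rule 1: concept 1 is positive
c_pred = np.array([[0.9, 0.2]])

out = sketch_rule_eval(weights, c_pred)
print(out.shape)  # (1, 1, 2)
# Rule 0 is about 0.9 * (1 - 0.2) = 0.72; rule 1 is 1 * 0.2 = 0.2.
print(out)
```

With one-hot role weights, each rule reduces to a crisp conjunction: rule 0 reads "concept 0 and not concept 1", and rule 1 reads "concept 1". Soft role weights interpolate between these cases, and with normalized weights every literal stays in [0, 1].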