diff --git a/tensor_theorem_prover/prover/ResolutionProver.py b/tensor_theorem_prover/prover/ResolutionProver.py index b299188..4e21525 100644 --- a/tensor_theorem_prover/prover/ResolutionProver.py +++ b/tensor_theorem_prover/prover/ResolutionProver.py @@ -90,6 +90,11 @@ def prove_all( def purge_similarity_cache(self) -> None: self.similarity_cache.clear() + def reset(self) -> None: + """Clear all knowledge from the prover and wipe the similarity cache""" + self.base_knowledge = [] + self.purge_similarity_cache() + def _prove_all_recursive( self, goal: CNFDisjunction, diff --git a/tests/prover/test_ResolutionProver.py b/tests/prover/test_ResolutionProver.py index aa7daa9..f4d418e 100644 --- a/tests/prover/test_ResolutionProver.py +++ b/tests/prover/test_ResolutionProver.py @@ -200,3 +200,11 @@ def test_purge_similarity_cache() -> None: prover.similarity_cache = {(1, 2): 0.5} prover.purge_similarity_cache() assert prover.similarity_cache == {} + + +def test_reset() -> None: + prover = ResolutionProver(knowledge=[parent_of(homer, bart)]) + prover.similarity_cache = {(1, 2): 0.5} + prover.reset() + assert prover.similarity_cache == {} + assert prover.base_knowledge == []