diff --git a/tensor_theorem_prover/prover/ProofContext.py b/tensor_theorem_prover/prover/ProofContext.py
new file mode 100644
index 0000000..a7b406d
--- /dev/null
+++ b/tensor_theorem_prover/prover/ProofContext.py
@@ -0,0 +1,46 @@
+from __future__ import annotations
+from copy import copy
+from typing import Optional
+
+from tensor_theorem_prover.prover.ProofStats import ProofStats
+
+from .ProofStep import ProofStep
+
+
+class ProofContext:
+    """Helper class which accumulates successful proof steps and keeps track of stats during the proof process"""
+
+    max_proofs: Optional[int]
+    scored_proof_steps: list[tuple[float, ProofStep, ProofStats]]
+    min_similarity_threshold: float
+    stats: ProofStats
+
+    def __init__(
+        self,
+        initial_min_similarity_threshold: float = 0.0,
+        max_proofs: Optional[int] = None,
+    ) -> None:
+        self.stats = ProofStats()
+        self.min_similarity_threshold = initial_min_similarity_threshold
+        self.max_proofs = max_proofs
+        self.scored_proof_steps = []
+
+    def record_leaf_proof(self, proof_step: ProofStep) -> None:
+        """Record a completed leaf proof step, keeping only the best max_proofs proofs"""
+
+        # TODO: Make combining similarities customizable rather than always taking the minimum
+        similarity = proof_step.similarity
+        cur_step = proof_step
+        while cur_step.parent:
+            similarity = min(similarity, cur_step.parent.similarity)
+            cur_step = cur_step.parent
+
+        # Make sure to clone the stats before appending, since the stats will continue to get mutated after this
+        self.scored_proof_steps.append((similarity, proof_step, copy(self.stats)))
+        self.scored_proof_steps.sort(key=lambda x: x[0], reverse=True)
+        if self.max_proofs and len(self.scored_proof_steps) > self.max_proofs:
+            # Remove the proof step with the lowest similarity
+            self.scored_proof_steps.pop()
+            self.stats.discarded_proofs += 1
+            # Update the minimum similarity threshold to the new lowest similarity
+            self.min_similarity_threshold = self.scored_proof_steps[-1][0]
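
Note for reviewers: the pruning behavior of the new class is easiest to see in isolation. Once max_proofs proofs have been recorded, each further leaf proof evicts the weakest entry and raises min_similarity_threshold to the weakest retained score, so later resolution work can be cut off earlier. A minimal sketch, not part of the patch; FakeStep is a hypothetical stub exposing only the similarity and parent attributes that record_leaf_proof reads:

    from dataclasses import dataclass
    from typing import Optional

    from tensor_theorem_prover.prover.ProofContext import ProofContext

    @dataclass
    class FakeStep:
        # Stand-in for ProofStep: only the two fields record_leaf_proof reads
        similarity: float
        parent: Optional["FakeStep"] = None

    ctx = ProofContext(initial_min_similarity_threshold=0.5, max_proofs=2)
    for score in (0.9, 0.7, 0.8):
        ctx.record_leaf_proof(FakeStep(similarity=score))

    # The 0.7 proof was evicted; the threshold tightened to the weakest
    # retained score, and the eviction was counted in the stats.
    assert [s for s, _, _ in ctx.scored_proof_steps] == [0.9, 0.8]
    assert ctx.min_similarity_threshold == 0.8
    assert ctx.stats.discarded_proofs == 1
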
diff --git a/tensor_theorem_prover/prover/ResolutionProver.py b/tensor_theorem_prover/prover/ResolutionProver.py
index c2b0912..36af3b4 100644
--- a/tensor_theorem_prover/prover/ResolutionProver.py
+++ b/tensor_theorem_prover/prover/ResolutionProver.py
@@ -5,7 +5,7 @@
 from tensor_theorem_prover.normalize import Skolemizer, CNFDisjunction, to_cnf
 from tensor_theorem_prover.prover.Proof import Proof
 from tensor_theorem_prover.prover.ProofStats import ProofStats
-from tensor_theorem_prover.prover.ProofStepAccumulator import ProofStepAccumulator
+from tensor_theorem_prover.prover.ProofContext import ProofContext
 from tensor_theorem_prover.prover.operations.resolve import resolve
 from tensor_theorem_prover.prover.ProofStep import ProofStep
 from tensor_theorem_prover.similarity import (
@@ -94,12 +94,14 @@ def prove_all_with_stats(
         parsed_extra_knowledge = self._parse_knowledge(extra_knowledge or [])
         proofs = []
         knowledge = self.base_knowledge + parsed_extra_knowledge + inverted_goals
-        leaf_proof_steps_acc = ProofStepAccumulator(max_proofs)
-        proof_stats = ProofStats()
+        ctx = ProofContext(
+            initial_min_similarity_threshold=self.min_similarity_threshold,
+            max_proofs=max_proofs,
+        )
         similarity_func = self.similarity_func
         if self.cache_similarity and self.similarity_func:
             similarity_func = similarity_with_cache(
-                self.similarity_func, self.similarity_cache, proof_stats
+                self.similarity_func, self.similarity_cache, ctx.stats
             )
 
         for inverted_goal in inverted_goals:
@@ -107,14 +109,13 @@
                 inverted_goal,
                 knowledge,
                 similarity_func,
-                leaf_proof_steps_acc,
-                proof_stats,
+                ctx,
             )
             for (
                 similarity,
                 leaf_proof_step,
                 leaf_proof_stats,
-            ) in leaf_proof_steps_acc.scored_proof_steps:
+            ) in ctx.scored_proof_steps:
                 proofs.append(
                     Proof(
                         inverted_goal,
@@ -126,7 +127,7 @@
 
         return (
             sorted(proofs, key=lambda proof: proof.similarity, reverse=True),
-            proof_stats,
+            ctx.stats,
         )
 
     def purge_similarity_cache(self) -> None:
@@ -142,15 +143,15 @@
         goal: CNFDisjunction,
         knowledge: Iterable[CNFDisjunction],
         similarity_func: Optional[SimilarityFunc],
-        leaf_proof_steps_acc: ProofStepAccumulator,
-        proof_stats: ProofStats,
+        ctx: ProofContext,
         depth: int = 0,
         parent_state: Optional[ProofStep] = None,
     ) -> None:
         if parent_state and depth >= self.max_proof_depth:
             return
-        if depth >= proof_stats.max_depth_seen:
-            proof_stats.max_depth_seen = depth
+        if depth >= ctx.stats.max_depth_seen:
+            # Add 1 to match the depth reported in proofs; it would be odd for a proof to have depth 12 while max_depth_seen is 11
+            ctx.stats.max_depth_seen = depth + 1
         for clause in knowledge:
             # resolution always ends up removing a literal from the clause and the goal, and combining the remaining literals
             # so we know what the length of the resolvent will be before we even try to resolve
@@ -160,35 +161,30 @@
                 > self.max_resolvent_width
             ):
                 continue
-            min_similarity_threshold = self.min_similarity_threshold
-            if leaf_proof_steps_acc.is_full():
-                min_similarity_threshold = leaf_proof_steps_acc.min_similarity
-            proof_stats.attempted_resolutions += 1
+            ctx.stats.attempted_resolutions += 1
             next_states = resolve(
                 goal,
                 clause,
-                min_similarity_threshold=min_similarity_threshold,
                 similarity_func=similarity_func,
                 parent=parent_state,
-                proof_stats=proof_stats,
+                ctx=ctx,
             )
             if len(next_states) > 0:
-                proof_stats.successful_resolutions += 1
+                ctx.stats.successful_resolutions += 1
             for next_state in next_states:
                 if next_state.resolvent is None:
                     raise ValueError("Resolvent was unexpectedly not present")
                 if len(next_state.resolvent.literals) == 0:
-                    leaf_proof_steps_acc.add_proof(next_state, proof_stats)
+                    ctx.record_leaf_proof(next_state)
                 else:
                     resolvent_width = len(next_state.resolvent.literals)
-                    if resolvent_width >= proof_stats.max_resolvent_width_seen:
-                        proof_stats.max_resolvent_width_seen = resolvent_width
+                    if resolvent_width >= ctx.stats.max_resolvent_width_seen:
+                        ctx.stats.max_resolvent_width_seen = resolvent_width
                     self._prove_all_recursive(
                         next_state.resolvent,
                         knowledge,
                         similarity_func,
-                        leaf_proof_steps_acc,
-                        proof_stats,
+                        ctx,
                         depth + 1,
                         next_state,
                     )
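
One detail worth highlighting in _prove_all_recursive: the resolvent-width pre-check works because resolution removes exactly one literal from the goal and one from the clause and merges the remainder, so the resolvent's width is known before any unification is attempted. An illustration with made-up widths:

    # A 3-literal goal resolved against a 4-literal clause always yields
    # (3 - 1) + (4 - 1) = 5 literals, so with max_resolvent_width = 4 the
    # pair can be skipped without calling resolve() at all.
    goal_width, clause_width, max_resolvent_width = 3, 4, 4
    resolvent_width = (goal_width - 1) + (clause_width - 1)
    assert resolvent_width == 5
    assert resolvent_width > max_resolvent_width  # pruned before resolution
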
diff --git a/tensor_theorem_prover/prover/operations/resolve.py b/tensor_theorem_prover/prover/operations/resolve.py
index 9ac1e34..09ad783 100644
--- a/tensor_theorem_prover/prover/operations/resolve.py
+++ b/tensor_theorem_prover/prover/operations/resolve.py
@@ -3,7 +3,7 @@
 from typing import Optional
 
 from tensor_theorem_prover.normalize.to_cnf import CNFDisjunction, CNFLiteral
-from tensor_theorem_prover.prover.ProofStats import ProofStats
+from tensor_theorem_prover.prover.ProofContext import ProofContext
 from tensor_theorem_prover.prover.ProofStep import ProofStep, SubstitutionsMap
 from tensor_theorem_prover.similarity import SimilarityFunc
 from tensor_theorem_prover.types import Atom, Term, Variable
@@ -14,10 +14,9 @@
 def resolve(
     source: CNFDisjunction,
     target: CNFDisjunction,
-    min_similarity_threshold: float = 0.5,
+    ctx: ProofContext,
     similarity_func: Optional[SimilarityFunc] = None,
     parent: Optional[ProofStep] = None,
-    proof_stats: Optional[ProofStats] = None,
 ) -> list[ProofStep]:
     """Resolve a source and target CNF disjunction
 
@@ -35,18 +34,15 @@
             # we can only resolve literals with the opposite polarity
             if source_literal.polarity == target_literal.polarity:
                 continue
-            if proof_stats:
-                proof_stats.attempted_unifications += 1
+            ctx.stats.attempted_unifications += 1
             unification = unify(
                 source_literal.atom,
                 target_literal.atom,
-                min_similarity_threshold,
+                ctx,
                 similarity_func,
-                proof_stats,
             )
             if unification:
-                if proof_stats:
-                    proof_stats.successful_unifications += 1
+                ctx.stats.successful_unifications += 1
                 resolvent = _build_resolvent(
                     source, target, source_literal, target_literal, unification
                 )
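
For readers less familiar with resolution, the polarity guard above means a literal pair is only a candidate when exactly one side is negated: P(x) can resolve against a negated P(y), but never against another positive P(y). Restated as a hypothetical standalone predicate, not a helper in the library:

    def is_resolution_candidate(source_polarity: bool, target_polarity: bool) -> bool:
        # Mirrors the polarity comparison in resolve(): opposite polarities only
        return source_polarity != target_polarity

    assert is_resolution_candidate(True, False)
    assert not is_resolution_candidate(True, True)
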
diff --git a/tensor_theorem_prover/prover/operations/unify.py b/tensor_theorem_prover/prover/operations/unify.py
index 42d9db2..45d1305 100644
--- a/tensor_theorem_prover/prover/operations/unify.py
+++ b/tensor_theorem_prover/prover/operations/unify.py
@@ -2,7 +2,7 @@
 from dataclasses import dataclass, field
 from typing import Dict, Iterable, Optional, Tuple
 from typing_extensions import Literal
 
-from tensor_theorem_prover.prover.ProofStats import ProofStats
+from tensor_theorem_prover.prover.ProofContext import ProofContext
 from tensor_theorem_prover.prover.ProofStep import SubstitutionsMap
 from tensor_theorem_prover.similarity import SimilarityFunc, symbol_compare
@@ -19,9 +19,8 @@ class Unification:
 def unify(
     source: Atom,
     target: Atom,
-    min_similarity_threshold: float = 0.5,
+    ctx: ProofContext,
     similarity_func: Optional[SimilarityFunc] = None,
-    proof_stats: Optional[ProofStats] = None,
 ) -> Unification | None:
     """
     Fuzzy-optional implementation of unify
@@ -36,20 +35,14 @@
     # if there is no comparison function provided, just use symbol compare (non-fuzzy comparisons)
     adjusted_similarity_func = similarity_func or symbol_compare
     similarity = adjusted_similarity_func(source.predicate, target.predicate)
-    if proof_stats:
-        proof_stats.similarity_comparisons += 1
+    ctx.stats.similarity_comparisons += 1
 
     # abort early if the predicate similarity is too low
-    if similarity <= min_similarity_threshold:
+    if similarity <= ctx.min_similarity_threshold:
         return None
 
     return _unify_terms(
-        source.terms,
-        target.terms,
-        similarity,
-        adjusted_similarity_func,
-        min_similarity_threshold,
-        proof_stats,
+        source.terms, target.terms, similarity, adjusted_similarity_func, ctx
     )
 
 
@@ -64,8 +57,7 @@ def _unify_terms(
     source_terms: Iterable[Term],
     target_terms: Iterable[Term],
     similarity: float,
     similarity_func: SimilarityFunc,
-    min_similarity_threshold: float,
-    proof_stats: Optional[ProofStats],
+    ctx: ProofContext,
 ) -> Unification | None:
     """
     Unification with optional vector similarity, based on Robinson's 1965 algorithm, as described in:
@@ -81,8 +73,7 @@
             substitutions,
             cur_similarity,
             similarity_func,
-            min_similarity_threshold,
-            proof_stats,
+            ctx,
         )
         if result is None:
             return None
@@ -152,8 +143,7 @@ def _unify_term_pair(
     substitutions: SubstitutionSet,
     similarity: float,
     similarity_func: SimilarityFunc,
-    min_similarity_threshold: float,
-    proof_stats: Optional[ProofStats],
+    ctx: ProofContext,
 ) -> tuple[SubstitutionSet, float] | None:
     """
     Check if a pair of terms can be unified, part of Robinson's 1965 algorithm
@@ -182,9 +172,8 @@
                 # TODO: should we add a separate similarity func for constants which is bidirectional?
                 similarity_func(cur_source_term, cur_target_term),
             )
-            if proof_stats:
-                proof_stats.similarity_comparisons += 1
-            if cur_similarity <= min_similarity_threshold:
+            ctx.stats.similarity_comparisons += 1
+            if cur_similarity <= ctx.min_similarity_threshold:
                 return None
         elif isinstance(cur_source_term, Variable):
             if isinstance(cur_target_term, Variable):
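
A sketch of calling the new unify signature directly, modeled on the updated tests below. It assumes Predicate, Constant, and Variable are importable from tensor_theorem_prover.types, as the imports in unify.py suggest; the predicate names and embeddings here are invented:

    import numpy as np

    from tensor_theorem_prover.prover.ProofContext import ProofContext
    from tensor_theorem_prover.prover.operations.unify import unify
    from tensor_theorem_prover.similarity import cosine_similarity
    from tensor_theorem_prover.types import Constant, Predicate, Variable

    lives_in = Predicate("lives_in", np.array([1.0, 0.0, 1.0]))
    resides_in = Predicate("resides_in", np.array([1.0, 0.1, 0.9]))
    alice = Constant("alice", np.array([0.5, 0.5, 0.0]))
    X = Variable("X")

    # The similarity threshold now travels inside the context instead of
    # being passed as a keyword argument to unify()
    ctx = ProofContext(initial_min_similarity_threshold=0.5)
    unification = unify(
        lives_in(X), resides_in(alice), ctx, similarity_func=cosine_similarity
    )

    # The predicate embeddings are nearly parallel (cosine ~0.996), clearing
    # the 0.5 threshold; the comparison is tallied on ctx.stats
    assert unification is not None
    assert unification.source_substitutions == {X: alice}
    assert ctx.stats.similarity_comparisons >= 1
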
diff --git a/tests/prover/operations/test_unify.py b/tests/prover/operations/test_unify.py
index 1b65492..80686b2 100644
--- a/tests/prover/operations/test_unify.py
+++ b/tests/prover/operations/test_unify.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import numpy as np
+from tensor_theorem_prover.prover.ProofContext import ProofContext
 from tensor_theorem_prover.prover.operations.unify import unify, Unification
 from tensor_theorem_prover.similarity import cosine_similarity
 
@@ -25,88 +26,93 @@
 Z = Variable("Z")
 
 
+ctx = lambda: ProofContext(initial_min_similarity_threshold=0.5)
+
+
 def test_unify_with_all_constants() -> None:
     source = pred1(const1, const2)
     target = pred1(const1, const2)
-    assert unify(source, target) == Unification({}, {})
+    assert unify(source, target, ctx()) == Unification({}, {})
 
 
 def test_unify_fails_if_preds_dont_match() -> None:
     source = pred1(const1, const2)
     target = pred2(const1, const2)
-    assert unify(source, target) is None
+    assert unify(source, target, ctx()) is None
 
 
 def test_unify_fails_if_terms_dont_match() -> None:
     source = pred1(const2, const2)
     target = pred1(const1, const2)
-    assert unify(source, target) is None
+    assert unify(source, target, ctx()) is None
 
 
 def test_unify_fails_if_functions_dont_match() -> None:
     source = pred1(func1(X))
     target = pred1(func2(Y))
-    assert unify(source, target) is None
+    assert unify(source, target, ctx()) is None
 
 
 def test_unify_fails_if_functions_take_different_number_of_params() -> None:
     source = pred1(func1(X, Y))
     target = pred1(func1(X))
-    assert unify(source, target) is None
+    assert unify(source, target, ctx()) is None
 
 
 def test_unify_fails_if_terms_have_differing_lengths() -> None:
     source = pred1(const1)
     target = pred1(const1, const2)
-    assert unify(source, target) is None
+    assert unify(source, target, ctx()) is None
 
 
 def test_unify_with_source_var_to_target_const() -> None:
     source = pred1(X, const1)
     target = pred1(const2, const1)
-    assert unify(source, target) == Unification({X: const2}, {})
+    assert unify(source, target, ctx()) == Unification({X: const2}, {})
 
 
 def test_unify_with_source_const_to_target_var() -> None:
     source = pred1(const2, const1)
     target = pred1(X, const1)
-    assert unify(source, target) == Unification({}, {X: const2})
+    assert unify(source, target, ctx()) == Unification({}, {X: const2})
 
 
 def test_unify_with_source_var_to_target_var() -> None:
     source = pred1(X, const1)
     target = pred1(Y, const1)
-    assert unify(source, target) == Unification({}, {Y: X})
+    assert unify(source, target, ctx()) == Unification({}, {Y: X})
 
 
 def test_unify_with_repeated_vars_in_source() -> None:
     source = pred1(X, X)
     target = pred1(Y, const1)
-    assert unify(source, target) == Unification({X: const1}, {Y: const1})
+    assert unify(source, target, ctx()) == Unification({X: const1}, {Y: const1})
 
 
 def test_unify_with_repeated_vars_in_target() -> None:
     source = pred1(X, const1)
     target = pred1(Y, Y)
-    assert unify(source, target) == Unification({X: const1}, {Y: const1})
+    assert unify(source, target, ctx()) == Unification({X: const1}, {Y: const1})
 
 
 def test_unify_fails_with_unfulfilable_constraints() -> None:
     source = pred1(X, X)
     target = pred1(const1, const2)
-    assert unify(source, target) is None
+    assert unify(source, target, ctx()) is None
 
 
 def test_unify_with_source_var_to_target_var_with_repeat_constants() -> None:
     source = pred1(X, X, X, X)
     target = pred1(const1, Y, Z, const1)
-    assert unify(source, target) == Unification({X: const1}, {Y: const1, Z: const1})
+    assert unify(source, target, ctx()) == Unification(
+        {X: const1}, {Y: const1, Z: const1}
+    )
 
 
 def test_unify_with_chained_vars() -> None:
     source = pred1(X, X, Y, Y, Z, Z)
     target = pred1(Y, X, X, Z, Z, const2)
-    assert unify(source, target) == Unification(
+    assert unify(source, target, ctx()) == Unification(
         {X: const2, Y: const2, Z: const2}, {X: const2, Y: const2, Z: const2}
     )
 
@@ -114,37 +120,37 @@ def test_unify_with_chained_vars() -> None:
 def test_unify_with_function_map_var_to_const() -> None:
     source = pred1(func1(X))
     target = pred1(func1(const1))
-    assert unify(source, target) == Unification({X: const1}, {})
+    assert unify(source, target, ctx()) == Unification({X: const1}, {})
 
 
 def test_unify_with_function_map_var_to_var() -> None:
     source = pred1(func1(X))
     target = pred1(func1(Y))
-    assert unify(source, target) == Unification({}, {Y: X})
+    assert unify(source, target, ctx()) == Unification({}, {Y: X})
 
 
 def test_unify_with_function_map_var_to_var_with_repeat_constants() -> None:
     source = pred1(func1(X, X))
     target = pred1(func1(const1, Y))
-    assert unify(source, target) == Unification({X: const1}, {Y: const1})
+    assert unify(source, target, ctx()) == Unification({X: const1}, {Y: const1})
 
 
 def test_unify_with_function_map_var_to_var_with_repeat_constants2() -> None:
     source = pred1(func1(const1, Y))
     target = pred1(func1(X, X))
-    assert unify(source, target) == Unification({Y: const1}, {X: const1})
+    assert unify(source, target, ctx()) == Unification({Y: const1}, {X: const1})
 
 
 def test_unify_bind_nested_function_var() -> None:
     source = pred1(func1(X))
     target = pred1(func1(func2(const1)))
-    assert unify(source, target) == Unification({X: func2(const1)}, {})
+    assert unify(source, target, ctx()) == Unification({X: func2(const1)}, {})
 
 
 def test_unify_fails_to_bind_reciprocal_functions() -> None:
     source = pred1(func1(X), X)
     target = pred1(Y, func1(Y))
-    assert unify(source, target) is None
+    assert unify(source, target, ctx()) is None
 
 
 def test_unify_with_predicate_vector_embeddings() -> None:
@@ -152,7 +158,7 @@
     vec_pred2 = Predicate("pred2", np.array([1, 0, 0.9, 1]))
     source = vec_pred1(X)
     target = vec_pred2(const1)
-    unification = unify(source, target, similarity_func=cosine_similarity)
+    unification = unify(source, target, ctx(), similarity_func=cosine_similarity)
     assert unification is not None
     assert unification.source_substitutions == {X: const1}
     assert unification.target_substitutions == {}
@@ -164,7 +170,7 @@
     vec_pred2 = Predicate("pred2", np.array([1, 0, 0.3, 1]))
     source = vec_pred1(X)
     target = vec_pred2(const1)
-    assert unify(source, target, similarity_func=cosine_similarity) is None
+    assert unify(source, target, ctx(), similarity_func=cosine_similarity) is None
 
 
 def test_unify_with_constant_vector_embeddings() -> None:
@@ -172,7 +178,7 @@
     vec_const2 = Constant("const2", np.array([1, 0, 0.9, 1]))
     source = pred1(vec_const1)
     target = pred1(vec_const2)
-    unification = unify(source, target, similarity_func=cosine_similarity)
+    unification = unify(source, target, ctx(), similarity_func=cosine_similarity)
     assert unification is not None
     assert unification.source_substitutions == {}
     assert unification.target_substitutions == {}
@@ -184,4 +190,4 @@
     vec_const2 = Constant("const2", np.array([1, 0, 0.3, 1]))
     source = pred1(vec_const1)
     target = pred1(vec_const2)
-    assert unify(source, target, similarity_func=cosine_similarity) is None
+    assert unify(source, target, ctx(), similarity_func=cosine_similarity) is None