Skip to content

Commit

Permalink
fix: refactor min similarity tracking for better performance
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Dec 13, 2022
1 parent cecc6b8 commit 256ca65
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 78 deletions.
46 changes: 46 additions & 0 deletions tensor_theorem_prover/prover/ProofContext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations
from copy import copy
from typing import Optional

from tensor_theorem_prover.prover.ProofStats import ProofStats

from .ProofStep import ProofStep


class ProofContext:
    """Accumulates successful leaf proof steps and tracks stats during a proof.

    Retains at most ``max_proofs`` of the highest-similarity proofs found so
    far, and raises ``min_similarity_threshold`` as weaker proofs are pruned
    so the search can abandon low-similarity branches early.
    """

    # Maximum number of proofs to retain, or None for unlimited
    max_proofs: Optional[int]
    # (similarity, leaf step, stats snapshot) tuples, sorted by similarity descending
    scored_proof_steps: list[tuple[float, ProofStep, ProofStats]]
    # current pruning threshold; unifications at or below this are rejected
    min_similarity_threshold: float
    stats: ProofStats

    def __init__(
        self,
        initial_min_similarity_threshold: float = 0.0,
        max_proofs: Optional[int] = None,
    ) -> None:
        """
        Args:
            initial_min_similarity_threshold: Starting pruning threshold.
            max_proofs: Maximum number of proofs to keep, or None for no limit.
        """
        self.stats = ProofStats()
        self.min_similarity_threshold = initial_min_similarity_threshold
        self.max_proofs = max_proofs
        self.scored_proof_steps = []

    def record_leaf_proof(self, proof_step: ProofStep) -> None:
        """Add a leaf proof step to the accumulator.

        The recorded score is the minimum similarity along the chain of
        parent steps, i.e. the weakest link in the proof.
        """

        # TODO: Make combining similarities customizable rather than always taking the minimum
        similarity = proof_step.similarity
        cur_step = proof_step
        while cur_step.parent:
            similarity = min(similarity, cur_step.parent.similarity)
            cur_step = cur_step.parent

        # make sure to clone the stats before appending, since the stats will continue to get mutated after this
        self.scored_proof_steps.append((similarity, proof_step, copy(self.stats)))
        # the list is already sorted before the append, so Timsort makes this effectively O(n)
        self.scored_proof_steps.sort(key=lambda x: x[0], reverse=True)
        # explicit None check: a truthiness check would treat max_proofs=0 as unlimited
        if self.max_proofs is not None and len(self.scored_proof_steps) > self.max_proofs:
            # Remove the proof step with the lowest similarity
            self.scored_proof_steps.pop()
            self.stats.discarded_proofs += 1
            if self.scored_proof_steps:
                # Update the minimum similarity threshold to the new lowest retained similarity
                self.min_similarity_threshold = self.scored_proof_steps[-1][0]
44 changes: 20 additions & 24 deletions tensor_theorem_prover/prover/ResolutionProver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tensor_theorem_prover.normalize import Skolemizer, CNFDisjunction, to_cnf
from tensor_theorem_prover.prover.Proof import Proof
from tensor_theorem_prover.prover.ProofStats import ProofStats
from tensor_theorem_prover.prover.ProofStepAccumulator import ProofStepAccumulator
from tensor_theorem_prover.prover.ProofContext import ProofContext
from tensor_theorem_prover.prover.operations.resolve import resolve
from tensor_theorem_prover.prover.ProofStep import ProofStep
from tensor_theorem_prover.similarity import (
Expand Down Expand Up @@ -94,27 +94,28 @@ def prove_all_with_stats(
parsed_extra_knowledge = self._parse_knowledge(extra_knowledge or [])
proofs = []
knowledge = self.base_knowledge + parsed_extra_knowledge + inverted_goals
leaf_proof_steps_acc = ProofStepAccumulator(max_proofs)
proof_stats = ProofStats()
ctx = ProofContext(
initial_min_similarity_threshold=self.min_similarity_threshold,
max_proofs=max_proofs,
)
similarity_func = self.similarity_func
if self.cache_similarity and self.similarity_func:
similarity_func = similarity_with_cache(
self.similarity_func, self.similarity_cache, proof_stats
self.similarity_func, self.similarity_cache, ctx.stats
)

for inverted_goal in inverted_goals:
self._prove_all_recursive(
inverted_goal,
knowledge,
similarity_func,
leaf_proof_steps_acc,
proof_stats,
ctx,
)
for (
similarity,
leaf_proof_step,
leaf_proof_stats,
) in leaf_proof_steps_acc.scored_proof_steps:
) in ctx.scored_proof_steps:
proofs.append(
Proof(
inverted_goal,
Expand All @@ -126,7 +127,7 @@ def prove_all_with_stats(

return (
sorted(proofs, key=lambda proof: proof.similarity, reverse=True),
proof_stats,
ctx.stats,
)

def purge_similarity_cache(self) -> None:
Expand All @@ -142,15 +143,15 @@ def _prove_all_recursive(
goal: CNFDisjunction,
knowledge: Iterable[CNFDisjunction],
similarity_func: Optional[SimilarityFunc],
leaf_proof_steps_acc: ProofStepAccumulator,
proof_stats: ProofStats,
ctx: ProofContext,
depth: int = 0,
parent_state: Optional[ProofStep] = None,
) -> None:
if parent_state and depth >= self.max_proof_depth:
return
if depth >= proof_stats.max_depth_seen:
proof_stats.max_depth_seen = depth
if depth >= ctx.stats.max_depth_seen:
# add 1 to match the depth stat seen in proofs. It's strange if the proof has depth 12, but max_depth_seen is 11
ctx.stats.max_depth_seen = depth + 1
for clause in knowledge:
# resolution always ends up removing a literal from the clause and the goal, and combining the remaining literals
# so we know what the length of the resolvent will be before we even try to resolve
Expand All @@ -160,35 +161,30 @@ def _prove_all_recursive(
> self.max_resolvent_width
):
continue
min_similarity_threshold = self.min_similarity_threshold
if leaf_proof_steps_acc.is_full():
min_similarity_threshold = leaf_proof_steps_acc.min_similarity
proof_stats.attempted_resolutions += 1
ctx.stats.attempted_resolutions += 1
next_states = resolve(
goal,
clause,
min_similarity_threshold=min_similarity_threshold,
similarity_func=similarity_func,
parent=parent_state,
proof_stats=proof_stats,
ctx=ctx,
)
if len(next_states) > 0:
proof_stats.successful_resolutions += 1
ctx.stats.successful_resolutions += 1
for next_state in next_states:
if next_state.resolvent is None:
raise ValueError("Resolvent was unexpectedly not present")
if len(next_state.resolvent.literals) == 0:
leaf_proof_steps_acc.add_proof(next_state, proof_stats)
ctx.record_leaf_proof(next_state)
else:
resolvent_width = len(next_state.resolvent.literals)
if resolvent_width >= proof_stats.max_resolvent_width_seen:
proof_stats.max_resolvent_width_seen = resolvent_width
if resolvent_width >= ctx.stats.max_resolvent_width_seen:
ctx.stats.max_resolvent_width_seen = resolvent_width
self._prove_all_recursive(
next_state.resolvent,
knowledge,
similarity_func,
leaf_proof_steps_acc,
proof_stats,
ctx,
depth + 1,
next_state,
)
14 changes: 5 additions & 9 deletions tensor_theorem_prover/prover/operations/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional

from tensor_theorem_prover.normalize.to_cnf import CNFDisjunction, CNFLiteral
from tensor_theorem_prover.prover.ProofStats import ProofStats
from tensor_theorem_prover.prover.ProofContext import ProofContext
from tensor_theorem_prover.prover.ProofStep import ProofStep, SubstitutionsMap
from tensor_theorem_prover.similarity import SimilarityFunc
from tensor_theorem_prover.types import Atom, Term, Variable
Expand All @@ -14,10 +14,9 @@
def resolve(
source: CNFDisjunction,
target: CNFDisjunction,
min_similarity_threshold: float = 0.5,
ctx: ProofContext,
similarity_func: Optional[SimilarityFunc] = None,
parent: Optional[ProofStep] = None,
proof_stats: Optional[ProofStats] = None,
) -> list[ProofStep]:
"""Resolve a source and target CNF disjunction
Expand All @@ -35,18 +34,15 @@ def resolve(
# we can only resolve literals with the opposite polarity
if source_literal.polarity == target_literal.polarity:
continue
if proof_stats:
proof_stats.attempted_unifications += 1
ctx.stats.attempted_unifications += 1
unification = unify(
source_literal.atom,
target_literal.atom,
min_similarity_threshold,
ctx,
similarity_func,
proof_stats,
)
if unification:
if proof_stats:
proof_stats.successful_unifications += 1
ctx.stats.successful_unifications += 1

resolvent = _build_resolvent(
source, target, source_literal, target_literal, unification
Expand Down
31 changes: 10 additions & 21 deletions tensor_theorem_prover/prover/operations/unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass, field
from typing import Dict, Iterable, Optional, Tuple
from typing_extensions import Literal
from tensor_theorem_prover.prover.ProofStats import ProofStats
from tensor_theorem_prover.prover.ProofContext import ProofContext

from tensor_theorem_prover.prover.ProofStep import SubstitutionsMap
from tensor_theorem_prover.similarity import SimilarityFunc, symbol_compare
Expand All @@ -19,9 +19,8 @@ class Unification:
def unify(
source: Atom,
target: Atom,
min_similarity_threshold: float = 0.5,
ctx: ProofContext,
similarity_func: Optional[SimilarityFunc] = None,
proof_stats: Optional[ProofStats] = None,
) -> Unification | None:
"""
Fuzzy-optional implementation of unify
Expand All @@ -36,20 +35,14 @@ def unify(
# if there is no comparison function provided, just use symbol compare (non-fuzzy comparisons)
adjusted_similarity_func = similarity_func or symbol_compare
similarity = adjusted_similarity_func(source.predicate, target.predicate)
if proof_stats:
proof_stats.similarity_comparisons += 1
ctx.stats.similarity_comparisons += 1

# abort early if the predicate similarity is too low
if similarity <= min_similarity_threshold:
if similarity <= ctx.min_similarity_threshold:
return None

return _unify_terms(
source.terms,
target.terms,
similarity,
adjusted_similarity_func,
min_similarity_threshold,
proof_stats,
source.terms, target.terms, similarity, adjusted_similarity_func, ctx
)


Expand All @@ -64,8 +57,7 @@ def _unify_terms(
target_terms: Iterable[Term],
similarity: float,
similarity_func: SimilarityFunc,
min_similarity_threshold: float,
proof_stats: Optional[ProofStats],
ctx: ProofContext,
) -> Unification | None:
"""
Unification with optional vector similarity, based on Robinson's 1965 algorithm, as described in:
Expand All @@ -81,8 +73,7 @@ def _unify_terms(
substitutions,
cur_similarity,
similarity_func,
min_similarity_threshold,
proof_stats,
ctx,
)
if result is None:
return None
Expand Down Expand Up @@ -152,8 +143,7 @@ def _unify_term_pair(
substitutions: SubstitutionSet,
similarity: float,
similarity_func: SimilarityFunc,
min_similarity_threshold: float,
proof_stats: Optional[ProofStats],
ctx: ProofContext,
) -> tuple[SubstitutionSet, float] | None:
"""
Check if a pair of terms can be unified, part of Robinson's 1965 algorithm
Expand Down Expand Up @@ -182,9 +172,8 @@ def _unify_term_pair(
# TODO: should we add a separate similarity func for constants which is bidirectional?
similarity_func(cur_source_term, cur_target_term),
)
if proof_stats:
proof_stats.similarity_comparisons += 1
if cur_similarity <= min_similarity_threshold:
ctx.stats.similarity_comparisons += 1
if cur_similarity <= ctx.min_similarity_threshold:
return None
elif isinstance(cur_source_term, Variable):
if isinstance(cur_target_term, Variable):
Expand Down
Loading

0 comments on commit 256ca65

Please sign in to comment.