From 5cbde092e75da8cc1631f50b4152fe7c048037ed Mon Sep 17 00:00:00 2001 From: AlexShefY Date: Tue, 10 Dec 2024 15:44:43 +0100 Subject: [PATCH] fixes --- tests/test_dafny.py | 35 ++++++++- tests/test_nagini.py | 80 +++++++++++++++----- tests/test_pure_calls.py | 2 +- tests/test_verus.py | 15 +++- verified_cogen/runners/languages/dafny.py | 6 +- verified_cogen/runners/languages/language.py | 24 ++++-- verified_cogen/runners/languages/nagini.py | 11 ++- verified_cogen/runners/languages/verus.py | 12 ++- verified_cogen/runners/validating.py | 4 +- 9 files changed, 143 insertions(+), 46 deletions(-) diff --git a/tests/test_dafny.py b/tests/test_dafny.py index 0054023..8407dc3 100644 --- a/tests/test_dafny.py +++ b/tests/test_dafny.py @@ -22,7 +22,7 @@ def test_dafny_generate(): result := value * 2; }""" ) - assert dafny_lang.generate_validators(code) == dedent( + assert dafny_lang.generate_validators(code, True) == dedent( """\ method main_valid(value: int) returns (result: int) requires value >= 10 @@ -32,6 +32,36 @@ def test_dafny_generate(): ) +def test_dafny_generate_with_helper(): + dafny_lang = LanguageDatabase().get("dafny") + code = dedent( + """\ + function abs(n: int) returns (k: nat) { if n > 0 then n else -n } + + method main(value: int) returns (result: int) + requires value >= 10 + ensures result >= 20 + { + assert value * 2 >= 20; // assert-line + result := value * 2; + }""" + ) + assert dafny_lang.generate_validators(code, True) == dedent( + """\ + function abs_valid_pure(n: int) returns (k: nat) { + if n > 0 then n else -n + } + + method main_valid(value: int) returns (result: int) + requires value >= 10 + ensures result >= 20 + { var ret0 := main(value); return ret0; } + + method abs_valid(n: int) returns (k: nat) { var ret0 := abs(n); return ret0; } + """ + ) + + def test_dafny_generate_multiple_returns(): dafny_lang = LanguageDatabase().get("dafny") code = dedent( @@ -46,7 +76,7 @@ def test_dafny_generate_multiple_returns(): result2 := value * 3; }""" ) - assert dafny_lang.generate_validators(code) == dedent( + assert dafny_lang.generate_validators(code, True) == dedent( """\ method main_valid(value: int) returns (result: int, result2: int) requires value >= 10 @@ -167,3 +197,4 @@ def test_remove_all(): } }""" ) + diff --git a/tests/test_nagini.py b/tests/test_nagini.py index 1e0954a..4365728 100644 --- a/tests/test_nagini.py +++ b/tests/test_nagini.py @@ -22,11 +22,13 @@ def main(value: int) -> int: return value * 2 # impl-end""" ) - assert nagini_lang.generate_validators(code) == dedent( + assert nagini_lang.generate_validators(code, True) == dedent( """\ + def main_valid(value: int) -> int: Requires(value >= 10) Ensures(Result() >= 20) + ret = main(value) return ret""" ) @@ -48,8 +50,9 @@ def main(value: int) -> int: return value * 2 # impl-end""" ) - assert nagini_lang.generate_validators(code) == dedent( + assert nagini_lang.generate_validators(code, True) == dedent( """\ + def main_valid(value: int) -> int: # pre-conditions-start Requires(value >= 10) @@ -57,6 +60,7 @@ def main_valid(value: int) -> int: # post-conditions-start Ensures(Result() >= 20) # post-conditions-end + ret = main(value) return ret""" ) @@ -185,21 +189,21 @@ def test_nagini_large(): @Pure def lower(c : int) -> bool : - # impl-start + # pure-start return ((0) <= (c)) and ((c) <= (25)) - # impl-end + # pure-end @Pure def upper(c : int) -> bool : - # impl-start + # pure-start return ((26) <= (c)) and ((c) <= (51)) - # impl-end + # pure-end @Pure def alpha(c : int) -> bool : - # impl-start + # pure-start return (lower(c)) or (upper(c)) - # impl-end + # pure-end @Pure def flip__char(c : int) -> int : @@ -208,14 +212,14 @@ def flip__char(c : int) -> int : Ensures(upper(c) == lower(Result())) # pre-conditions-end - # impl-start + # pure-start if lower(c): return ((c) - (0)) + (26) elif upper(c): return ((c) + (0)) - (26) elif True: return c - # impl-end + # pure-end def flip__case(s : List[int]) -> List[int] : # pre-conditions-start @@ -246,25 +250,64 @@ def flip__case(s : List[int]) -> List[int] : return res # impl-end""" ) - # print(nagini_lang.generate_validators(code)) - assert nagini_lang.generate_validators(code) == dedent( + assert nagini_lang.generate_validators(code, True) == dedent( """\ + + @Pure + def lower_valid_pure(c : int) -> bool : + + return ((0) <= (c)) and ((c) <= (25)) + + @Pure + def upper_valid_pure(c : int) -> bool : + + return ((26) <= (c)) and ((c) <= (51)) + + @Pure + def alpha_valid_pure(c : int) -> bool : + + return (lower_valid_pure(c)) or (upper_valid_pure(c)) + + @Pure + def flip__char_valid_pure(c : int) -> int : + # pre-conditions-start + Ensures(lower_valid_pure(c) == upper_valid_pure(Result())) + Ensures(upper_valid_pure(c) == lower_valid_pure(Result())) + # pre-conditions-end + if lower_valid_pure(c): + return ((c) - (0)) + (26) + elif upper_valid_pure(c): + return ((c) + (0)) - (26) + elif True: + return c + def lower_valid(c : int) -> bool : + + ret = lower(c) return ret + def upper_valid(c : int) -> bool : + + ret = upper(c) return ret + def alpha_valid(c : int) -> bool : + + ret = alpha(c) return ret + def flip__char_valid(c : int) -> int : # pre-conditions-start - Ensures(lower(c) == upper(Result())) - Ensures(upper(c) == lower(Result())) + Ensures(lower_valid_pure(c) == upper_valid_pure(Result())) + Ensures(upper_valid_pure(c) == lower_valid_pure(Result())) # pre-conditions-end + ret = flip__char(c) return ret + def flip__case_valid(s : List[int]) -> List[int] : # pre-conditions-start Requires(Acc(list_pred(s))) @@ -273,9 +316,10 @@ def flip__case_valid(s : List[int]) -> List[int] : Ensures(Acc(list_pred(s))) Ensures(Acc(list_pred(Result()))) Ensures((len(Result())) == (len(s))) - Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), lower((s)[d_0_i_]) == upper((Result())[d_0_i_]))))) - Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), upper((s)[d_0_i_]) == lower((Result())[d_0_i_]))))) + Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), lower_valid_pure((s)[d_0_i_]) == upper_valid_pure((Result())[d_0_i_]))))) + Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), upper_valid_pure((s)[d_0_i_]) == lower_valid_pure((Result())[d_0_i_]))))) # post-conditions-end + ret = flip__case(s) return ret""" ) @@ -302,13 +346,15 @@ def flip__char(c : int) -> int : return c # impl-end""" ) - assert nagini_lang.generate_validators(code) == dedent( + assert nagini_lang.generate_validators(code, False) == dedent( """\ + def flip__char_valid(c : int) -> int : # pre-conditions-start Ensures(lower(c) == upper(Result())) Ensures(upper(c) == lower(Result())) # pre-conditions-end + ret = flip__char(c) return ret""" ) diff --git a/tests/test_pure_calls.py b/tests/test_pure_calls.py index 6a133ad..9211dbc 100644 --- a/tests/test_pure_calls.py +++ b/tests/test_pure_calls.py @@ -273,4 +273,4 @@ def f(n : int) -> List[int]: result: List[str] = ["factorial__spec"] - assert result == nagini_lang.find_pure_non_helpers(code) \ No newline at end of file + assert nagini_lang.find_pure_non_helpers(code) == result \ No newline at end of file diff --git a/tests/test_verus.py b/tests/test_verus.py index 479d1bc..3db3c23 100644 --- a/tests/test_verus.py +++ b/tests/test_verus.py @@ -24,7 +24,10 @@ def test_verus_generate(): value * 2 } - spec fn test() {} + spec fn test(val: i32) -> (result: i32) + { + val + } fn is_prime(num: u32) -> (result: bool) requires @@ -55,9 +58,14 @@ def test_verus_generate(): result }""" ) - assert verus_lang.generate_validators(code) == dedent( + assert verus_lang.generate_validators(code, True) == dedent( """\ verus!{ + spec fn test_valid_pure(val: i32) -> (result: i32) + { + val + } + fn main_valid(value: i32) -> (result: i32) requires value >= 10, @@ -71,6 +79,9 @@ def test_verus_generate(): ensures result <==> spec_prime(num as int), { let ret = is_prime(num); ret } + + fn test_valid(val: i32) -> (result: i32) + { let ret = test(val); ret } }""" ) diff --git a/verified_cogen/runners/languages/dafny.py b/verified_cogen/runners/languages/dafny.py index 3b8fa71..70bb5a4 100644 --- a/verified_cogen/runners/languages/dafny.py +++ b/verified_cogen/runners/languages/dafny.py @@ -9,12 +9,13 @@ """ DAFNY_VALIDATOR_TEMPLATE_PURE = """\ -function {method_name}_valid({parameters}) returns ({returns}){specs}\ +function {method_name}_valid_pure({parameters}) returns ({returns}){specs}\ { {body} } """ + class DafnyLanguage(GenericLanguage): method_regex: Pattern[str] @@ -32,7 +33,8 @@ def __init__(self, remove_annotations: list[AnnotationType]): # type: ignore r"method\s+(\w+)\s*\((.*?)\)\s*returns\s*\((.*?)\)(.*?)\{", re.DOTALL ), re.compile( - r"function\s+(\w+)\s*\((.*?)\)\s*returns\s*\((.*?)\)(.*?)\{(.*?)}", re.DOTALL + r"function\s+(\w+)\s*\((.*?)\)\s*returns\s*\((.*?)\)(.*?)\{(.*?)}", + re.DOTALL, ), DAFNY_VALIDATOR_TEMPLATE, DAFNY_VALIDATOR_TEMPLATE_PURE, diff --git a/verified_cogen/runners/languages/language.py b/verified_cogen/runners/languages/language.py index b3de9b9..b09d78c 100644 --- a/verified_cogen/runners/languages/language.py +++ b/verified_cogen/runners/languages/language.py @@ -19,7 +19,7 @@ class Language: def __init__(self, *args: list[Any], **kwargs: dict[str, Any]): ... @abstractmethod - def generate_validators(self, code: str) -> str: ... + def generate_validators(self, code: str, validate_helpers: bool) -> str: ... @abstractmethod def remove_conditions(self, code: str) -> str: ... @@ -103,17 +103,19 @@ def _validators_from_pure( for param in parameters.split(",") if param.strip() ), - ).replace("{body}", body) + ) + .replace("{body}", body) ) def replace_pure(self, code: str, pure_names: list[str]): for pure_name in pure_names: - code = code.replace(pure_name + "(", pure_name + "_valid(") + code = code.replace(pure_name + "(", pure_name + "_valid_pure(") return code - def generate_validators(self, code: str) -> str: + def generate_validators(self, code: str, validate_helpers: bool) -> str: pure_methods = list(self.pure_regex.finditer(code)) methods = list(self.method_regex.finditer(code)) + method_names = [match.group(1) for match in methods] validators: list[str] = [] pure_names: list[str] = [] @@ -131,10 +133,16 @@ def generate_validators(self, code: str) -> str: pure_match.group(5), ) + if method_name not in method_names: + methods += [pure_match] + + specs = self.replace_pure(specs, pure_names) body = self.replace_pure(body, pure_names) validators.append( - self._validators_from_pure(method_name, parameters, returns, specs, body) + self._validators_from_pure( + method_name, parameters, returns, specs, body + ) ) for match in methods: @@ -144,11 +152,11 @@ def generate_validators(self, code: str) -> str: match.group(3), match.group(4), ) - if method_name in pure_names: - continue + + specs = self.replace_pure(specs, pure_names) validators.append( - self.replace_pure(self._validators_from(method_name, parameters, returns, specs), pure_names) + self._validators_from(method_name, parameters, returns, specs) ) return "\n".join(validators) diff --git a/verified_cogen/runners/languages/nagini.py b/verified_cogen/runners/languages/nagini.py index 6c49d87..17544df 100644 --- a/verified_cogen/runners/languages/nagini.py +++ b/verified_cogen/runners/languages/nagini.py @@ -12,11 +12,10 @@ def {method_name}_valid({parameters}) -> {returns}:{specs}\ return ret\ """ -NAGINI_VALIDATOR_TEMPLATE_PURE = """\ +NAGINI_VALIDATOR_TEMPLATE_PURE = """ @Pure -def {method_name}_valid({parameters}) -> {returns}:{specs}\ - {body}\ -""" +def {method_name}_valid_pure({parameters}) -> {returns}:{specs}\ +{body}""" class NaginiLanguage(GenericLanguage): @@ -37,7 +36,7 @@ def __init__(self, remove_annotations: list[AnnotationType]): # type: ignore re.DOTALL, ), re.compile( - r"@Pure\s+def\s+(\w+)\s*\((.*?)\)\s*->\s*(.*?):(.*?)\s+# pure-start(.*?)# pure-end", + r"@Pure\s+def\s+(\w+)\s*\((.*?)\)\s*->\s*(.*?):(.*?)\s+# pure-start(.*?)\s+# pure-end", re.DOTALL, ), NAGINI_VALIDATOR_TEMPLATE, @@ -72,5 +71,5 @@ def find_pure_non_helpers(self, code: str) -> List[str]: methods = list(pattern.finditer(code)) non_helpers: list[str] = [] for match in methods: - non_helpers.append(match.group(3)) + non_helpers.append(match.group(1)) return non_helpers diff --git a/verified_cogen/runners/languages/verus.py b/verified_cogen/runners/languages/verus.py index e6729be..8b2c5a9 100644 --- a/verified_cogen/runners/languages/verus.py +++ b/verified_cogen/runners/languages/verus.py @@ -9,10 +9,8 @@ """ VERUS_VALIDATOR_TEMPLATE_PURE = """\ -spec fn {method_name}_valid({parameters}) -> ({returns}){specs}\ -{ - {body} -} +spec fn {method_name}_valid_pure({parameters}) -> ({returns}){specs}\ +{{body}} """ @@ -34,7 +32,7 @@ def __init__(self, remove_annotations: list[AnnotationType]): # type: ignore flags=re.DOTALL | re.MULTILINE, ), re.compile( - r"^\s*spec fn\s+(\w+)\s*\((.*?)\)\s*->\s*\((.*?)\)(.*?)\{", + r"^\s*spec fn\s+(\w+)\s*\((.*?)\)\s*->\s*\((.*?)\)(.*?)\{(.*?)}", flags=re.DOTALL | re.MULTILINE, ), VERUS_VALIDATOR_TEMPLATE, @@ -47,8 +45,8 @@ def __init__(self, remove_annotations: list[AnnotationType]): # type: ignore "//", ) - def generate_validators(self, code: str) -> str: - result = super().generate_validators(code) + def generate_validators(self, code: str, validate_helpers: bool) -> str: + result = super().generate_validators(code, validate_helpers) return "verus!{{\n{}}}".format(result) def separate_validator_errors(self, errors: str) -> tuple[str, str]: diff --git a/verified_cogen/runners/validating.py b/verified_cogen/runners/validating.py index fcc99cb..ae0758e 100644 --- a/verified_cogen/runners/validating.py +++ b/verified_cogen/runners/validating.py @@ -35,7 +35,9 @@ def __init__( self.pure_non_helpers = [] def _add_validators(self, prg: str, inv_prg: str): - validators = self.language.generate_validators(prg) + validators = self.language.generate_validators( + prg, not self.config.remove_helpers + ) comment = self.language.simple_comment val_prg = inv_prg + "\n" + comment + " ==== verifiers ==== \n" + validators return val_prg