Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
alex28sh committed Dec 10, 2024
1 parent 77c3bdb commit 5cbde09
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 46 deletions.
35 changes: 33 additions & 2 deletions tests/test_dafny.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_dafny_generate():
result := value * 2;
}"""
)
assert dafny_lang.generate_validators(code) == dedent(
assert dafny_lang.generate_validators(code, True) == dedent(
"""\
method main_valid(value: int) returns (result: int)
requires value >= 10
Expand All @@ -32,6 +32,36 @@ def test_dafny_generate():
)


def test_dafny_generate_with_helper():
dafny_lang = LanguageDatabase().get("dafny")
code = dedent(
"""\
function abs(n: int) returns (k: nat) { if n > 0 then n else -n }
method main(value: int) returns (result: int)
requires value >= 10
ensures result >= 20
{
assert value * 2 >= 20; // assert-line
result := value * 2;
}"""
)
assert dafny_lang.generate_validators(code, True) == dedent(
"""\
function abs_valid_pure(n: int) returns (k: nat) {
if n > 0 then n else -n
}
method main_valid(value: int) returns (result: int)
requires value >= 10
ensures result >= 20
{ var ret0 := main(value); return ret0; }
method abs_valid(n: int) returns (k: nat) { var ret0 := abs(n); return ret0; }
"""
)


def test_dafny_generate_multiple_returns():
dafny_lang = LanguageDatabase().get("dafny")
code = dedent(
Expand All @@ -46,7 +76,7 @@ def test_dafny_generate_multiple_returns():
result2 := value * 3;
}"""
)
assert dafny_lang.generate_validators(code) == dedent(
assert dafny_lang.generate_validators(code, True) == dedent(
"""\
method main_valid(value: int) returns (result: int, result2: int)
requires value >= 10
Expand Down Expand Up @@ -167,3 +197,4 @@ def test_remove_all():
}
}"""
)

80 changes: 63 additions & 17 deletions tests/test_nagini.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ def main(value: int) -> int:
return value * 2
# impl-end"""
)
assert nagini_lang.generate_validators(code) == dedent(
assert nagini_lang.generate_validators(code, True) == dedent(
"""\
def main_valid(value: int) -> int:
Requires(value >= 10)
Ensures(Result() >= 20)
ret = main(value)
return ret"""
)
Expand All @@ -48,15 +50,17 @@ def main(value: int) -> int:
return value * 2
# impl-end"""
)
assert nagini_lang.generate_validators(code) == dedent(
assert nagini_lang.generate_validators(code, True) == dedent(
"""\
def main_valid(value: int) -> int:
# pre-conditions-start
Requires(value >= 10)
# pre-conditions-end
# post-conditions-start
Ensures(Result() >= 20)
# post-conditions-end
ret = main(value)
return ret"""
)
Expand Down Expand Up @@ -185,21 +189,21 @@ def test_nagini_large():
@Pure
def lower(c : int) -> bool :
# impl-start
# pure-start
return ((0) <= (c)) and ((c) <= (25))
# impl-end
# pure-end
@Pure
def upper(c : int) -> bool :
# impl-start
# pure-start
return ((26) <= (c)) and ((c) <= (51))
# impl-end
# pure-end
@Pure
def alpha(c : int) -> bool :
# impl-start
# pure-start
return (lower(c)) or (upper(c))
# impl-end
# pure-end
@Pure
def flip__char(c : int) -> int :
Expand All @@ -208,14 +212,14 @@ def flip__char(c : int) -> int :
Ensures(upper(c) == lower(Result()))
# pre-conditions-end
# impl-start
# pure-start
if lower(c):
return ((c) - (0)) + (26)
elif upper(c):
return ((c) + (0)) - (26)
elif True:
return c
# impl-end
# pure-end
def flip__case(s : List[int]) -> List[int] :
# pre-conditions-start
Expand Down Expand Up @@ -246,25 +250,64 @@ def flip__case(s : List[int]) -> List[int] :
return res
# impl-end"""
)
# print(nagini_lang.generate_validators(code))
assert nagini_lang.generate_validators(code) == dedent(
assert nagini_lang.generate_validators(code, True) == dedent(
"""\
@Pure
def lower_valid_pure(c : int) -> bool :
return ((0) <= (c)) and ((c) <= (25))
@Pure
def upper_valid_pure(c : int) -> bool :
return ((26) <= (c)) and ((c) <= (51))
@Pure
def alpha_valid_pure(c : int) -> bool :
return (lower_valid_pure(c)) or (upper_valid_pure(c))
@Pure
def flip__char_valid_pure(c : int) -> int :
# pre-conditions-start
Ensures(lower_valid_pure(c) == upper_valid_pure(Result()))
Ensures(upper_valid_pure(c) == lower_valid_pure(Result()))
# pre-conditions-end
if lower_valid_pure(c):
return ((c) - (0)) + (26)
elif upper_valid_pure(c):
return ((c) + (0)) - (26)
elif True:
return c
def lower_valid(c : int) -> bool :
ret = lower(c)
return ret
def upper_valid(c : int) -> bool :
ret = upper(c)
return ret
def alpha_valid(c : int) -> bool :
ret = alpha(c)
return ret
def flip__char_valid(c : int) -> int :
# pre-conditions-start
Ensures(lower(c) == upper(Result()))
Ensures(upper(c) == lower(Result()))
Ensures(lower_valid_pure(c) == upper_valid_pure(Result()))
Ensures(upper_valid_pure(c) == lower_valid_pure(Result()))
# pre-conditions-end
ret = flip__char(c)
return ret
def flip__case_valid(s : List[int]) -> List[int] :
# pre-conditions-start
Requires(Acc(list_pred(s)))
Expand All @@ -273,9 +316,10 @@ def flip__case_valid(s : List[int]) -> List[int] :
Ensures(Acc(list_pred(s)))
Ensures(Acc(list_pred(Result())))
Ensures((len(Result())) == (len(s)))
Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), lower((s)[d_0_i_]) == upper((Result())[d_0_i_])))))
Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), upper((s)[d_0_i_]) == lower((Result())[d_0_i_])))))
Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), lower_valid_pure((s)[d_0_i_]) == upper_valid_pure((Result())[d_0_i_])))))
Ensures(Forall(int, lambda d_0_i_: (Implies(((0) <= (d_0_i_)) and ((d_0_i_) < (len(s))), upper_valid_pure((s)[d_0_i_]) == lower_valid_pure((Result())[d_0_i_])))))
# post-conditions-end
ret = flip__case(s)
return ret"""
)
Expand All @@ -302,13 +346,15 @@ def flip__char(c : int) -> int :
return c
# impl-end"""
)
assert nagini_lang.generate_validators(code) == dedent(
assert nagini_lang.generate_validators(code, False) == dedent(
"""\
def flip__char_valid(c : int) -> int :
# pre-conditions-start
Ensures(lower(c) == upper(Result()))
Ensures(upper(c) == lower(Result()))
# pre-conditions-end
ret = flip__char(c)
return ret"""
)
2 changes: 1 addition & 1 deletion tests/test_pure_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,4 +273,4 @@ def f(n : int) -> List[int]:

result: List[str] = ["factorial__spec"]

assert result == nagini_lang.find_pure_non_helpers(code)
assert nagini_lang.find_pure_non_helpers(code) == result
15 changes: 13 additions & 2 deletions tests/test_verus.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ def test_verus_generate():
value * 2
}
spec fn test() {}
spec fn test(val: i32) -> (result: i32)
{
val
}
fn is_prime(num: u32) -> (result: bool)
requires
Expand Down Expand Up @@ -55,9 +58,14 @@ def test_verus_generate():
result
}"""
)
assert verus_lang.generate_validators(code) == dedent(
assert verus_lang.generate_validators(code, True) == dedent(
"""\
verus!{
spec fn test_valid_pure(val: i32) -> (result: i32)
{
val
}
fn main_valid(value: i32) -> (result: i32)
requires
value >= 10,
Expand All @@ -71,6 +79,9 @@ def test_verus_generate():
ensures
result <==> spec_prime(num as int),
{ let ret = is_prime(num); ret }
fn test_valid(val: i32) -> (result: i32)
{ let ret = test(val); ret }
}"""
)

Expand Down
6 changes: 4 additions & 2 deletions verified_cogen/runners/languages/dafny.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
"""

DAFNY_VALIDATOR_TEMPLATE_PURE = """\
function {method_name}_valid({parameters}) returns ({returns}){specs}\
function {method_name}_valid_pure({parameters}) returns ({returns}){specs}\
{
{body}
}
"""


class DafnyLanguage(GenericLanguage):
method_regex: Pattern[str]

Expand All @@ -32,7 +33,8 @@ def __init__(self, remove_annotations: list[AnnotationType]): # type: ignore
r"method\s+(\w+)\s*\((.*?)\)\s*returns\s*\((.*?)\)(.*?)\{", re.DOTALL
),
re.compile(
r"function\s+(\w+)\s*\((.*?)\)\s*returns\s*\((.*?)\)(.*?)\{(.*?)}", re.DOTALL
r"function\s+(\w+)\s*\((.*?)\)\s*returns\s*\((.*?)\)(.*?)\{(.*?)}",
re.DOTALL,
),
DAFNY_VALIDATOR_TEMPLATE,
DAFNY_VALIDATOR_TEMPLATE_PURE,
Expand Down
24 changes: 16 additions & 8 deletions verified_cogen/runners/languages/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Language:
def __init__(self, *args: list[Any], **kwargs: dict[str, Any]): ...

@abstractmethod
def generate_validators(self, code: str) -> str: ...
def generate_validators(self, code: str, validate_helpers: bool) -> str: ...

@abstractmethod
def remove_conditions(self, code: str) -> str: ...
Expand Down Expand Up @@ -103,17 +103,19 @@ def _validators_from_pure(
for param in parameters.split(",")
if param.strip()
),
).replace("{body}", body)
)
.replace("{body}", body)
)

def replace_pure(self, code: str, pure_names: list[str]):
for pure_name in pure_names:
code = code.replace(pure_name + "(", pure_name + "_valid(")
code = code.replace(pure_name + "(", pure_name + "_valid_pure(")
return code

def generate_validators(self, code: str) -> str:
def generate_validators(self, code: str, validate_helpers: bool) -> str:
pure_methods = list(self.pure_regex.finditer(code))
methods = list(self.method_regex.finditer(code))
method_names = [match.group(1) for match in methods]

validators: list[str] = []
pure_names: list[str] = []
Expand All @@ -131,10 +133,16 @@ def generate_validators(self, code: str) -> str:
pure_match.group(5),
)

if method_name not in method_names:
methods += [pure_match]

specs = self.replace_pure(specs, pure_names)
body = self.replace_pure(body, pure_names)

validators.append(
self._validators_from_pure(method_name, parameters, returns, specs, body)
self._validators_from_pure(
method_name, parameters, returns, specs, body
)
)

for match in methods:
Expand All @@ -144,11 +152,11 @@ def generate_validators(self, code: str) -> str:
match.group(3),
match.group(4),
)
if method_name in pure_names:
continue

specs = self.replace_pure(specs, pure_names)

validators.append(
self.replace_pure(self._validators_from(method_name, parameters, returns, specs), pure_names)
self._validators_from(method_name, parameters, returns, specs)
)

return "\n".join(validators)
Expand Down
Loading

0 comments on commit 5cbde09

Please sign in to comment.