diff --git a/tests/test_filter.py b/tests/test_filter.py index 1d64e67..d98878d 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -291,14 +291,20 @@ def test_filter_words_edge_cases(self): def test_load_filtered_policy_with_comments(self): """Test loading filtering policies with comments.""" - temp_file = tempfile.NamedTemporaryFile(delete=False) - try: - shutil.copyfile(get_examples("rbac_with_domains_policy.csv"), temp_file.name) + import tempfile + import shutil + + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: + with open(get_examples("rbac_with_domains_policy.csv"), "r") as source: + shutil.copyfileobj(source, temp_file) - with open(temp_file.name, "a") as f: - f.write("\n# This is a comment\np, admin, domain1, data3, read") + temp_file.write("\n# This is a comment\np, admin, domain1, data3, read") + temp_file.flush() - adapter = FilteredFileAdapter(temp_file.name) + temp_path = temp_file.name + + try: + adapter = FilteredFileAdapter(temp_path) e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter) filter = Filter() filter.P = ["", "domain1"] @@ -307,4 +313,7 @@ def test_load_filtered_policy_with_comments(self): e.load_filtered_policy(filter) self.assertTrue(e.has_policy(["admin", "domain1", "data3", "read"])) finally: - os.unlink(temp_file.name) + try: + os.unlink(temp_path) + except OSError: + pass