From ad31cc45f566e029e89beaa253b24f4727f80dfb Mon Sep 17 00:00:00 2001 From: mateuslatrova Date: Wed, 6 Dec 2023 15:34:59 -0300 Subject: [PATCH] add support for .pyw files Now, pipreqs will also scan imports in .pyw files by default. --- pipreqs/pipreqs.py | 8 ++++---- tests/_data_pyw/py.py | 5 +++++ tests/_data_pyw/pyw.pyw | 3 +++ tests/test_pipreqs.py | 39 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 tests/_data_pyw/py.py create mode 100644 tests/_data_pyw/pyw.pyw diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index 2162c8b..b969ab4 100644 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -52,7 +52,7 @@ from pipreqs import __version__ REGEXP = [re.compile(r"^import (.+)$"), re.compile(r"^from ((?!\.+).*?) import (?:.*)$")] - +DEFAULT_EXTENSIONS = [".py", ".pyw"] scan_noteboooks = False @@ -126,7 +126,7 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links dirs[:] = [d for d in dirs if d not in ignore_dirs] candidates.append(os.path.basename(root)) - py_files = [file for file in files if file_ext_is_allowed(file, [".py"])] + py_files = [file for file in files if file_ext_is_allowed(file, DEFAULT_EXTENSIONS)] candidates.extend([os.path.splitext(filename)[0] for filename in py_files]) files = [fn for fn in files if file_ext_is_allowed(fn, extensions)] @@ -172,11 +172,11 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links def get_file_extensions(): - return [".py", ".ipynb"] if scan_noteboooks else [".py"] + return DEFAULT_EXTENSIONS + [".ipynb"] if scan_noteboooks else DEFAULT_EXTENSIONS def read_file_content(file_name: str, encoding="utf-8"): - if file_ext_is_allowed(file_name, [".py"]): + if file_ext_is_allowed(file_name, DEFAULT_EXTENSIONS): with open(file_name, "r", encoding=encoding) as f: contents = f.read() elif file_ext_is_allowed(file_name, [".ipynb"]) and scan_noteboooks: diff --git a/tests/_data_pyw/py.py b/tests/_data_pyw/py.py new file mode 100644 index 0000000..d6a91ae --- /dev/null +++ b/tests/_data_pyw/py.py @@ -0,0 +1,5 @@ +import airflow +import numpy + +airflow +numpy diff --git a/tests/_data_pyw/pyw.pyw b/tests/_data_pyw/pyw.pyw new file mode 100644 index 0000000..8377bb2 --- /dev/null +++ b/tests/_data_pyw/pyw.pyw @@ -0,0 +1,3 @@ +import matplotlib +import pandas +import tensorflow diff --git a/tests/test_pipreqs.py b/tests/test_pipreqs.py index 1418b87..240355b 100644 --- a/tests/test_pipreqs.py +++ b/tests/test_pipreqs.py @@ -629,6 +629,45 @@ def test_ignore_notebooks(self): assert os.path.exists(notebook_requirement_path) == 1 assert os.path.getsize(notebook_requirement_path) == 1 # file only has a "\n", meaning it's empty + def test_pipreqs_get_imports_from_pyw_file(self): + pyw_test_dirpath = os.path.join(os.path.dirname(__file__), "_data_pyw") + requirements_path = os.path.join(pyw_test_dirpath, "requirements.txt") + + pipreqs.init( + { + "": pyw_test_dirpath, + "--savepath": None, + "--print": False, + "--use-local": None, + "--force": True, + "--proxy": None, + "--pypi-server": None, + "--diff": None, + "--clean": None, + "--mode": None, + } + ) + + self.assertTrue(os.path.exists(requirements_path)) + + expected_imports = [ + "airflow", + "matplotlib", + "numpy", + "pandas", + "tensorflow", + ] + + with open(requirements_path, "r") as f: + imports_data = f.read().lower() + for _import in expected_imports: + self.assertTrue( + _import.lower() in imports_data, + f"'{_import}' import was expected but not found.", + ) + + os.remove(requirements_path) + def mock_scan_notebooks(self): pipreqs.scan_noteboooks = Mock(return_value=True) pipreqs.handle_scan_noteboooks()