diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index 2162c8b..b969ab4 100644 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -52,7 +52,7 @@ from pipreqs import __version__ REGEXP = [re.compile(r"^import (.+)$"), re.compile(r"^from ((?!\.+).*?) import (?:.*)$")] - +DEFAULT_EXTENSIONS = [".py", ".pyw"] scan_noteboooks = False @@ -126,7 +126,7 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links dirs[:] = [d for d in dirs if d not in ignore_dirs] candidates.append(os.path.basename(root)) - py_files = [file for file in files if file_ext_is_allowed(file, [".py"])] + py_files = [file for file in files if file_ext_is_allowed(file, DEFAULT_EXTENSIONS)] candidates.extend([os.path.splitext(filename)[0] for filename in py_files]) files = [fn for fn in files if file_ext_is_allowed(fn, extensions)] @@ -172,11 +172,11 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links def get_file_extensions(): - return [".py", ".ipynb"] if scan_noteboooks else [".py"] + return DEFAULT_EXTENSIONS + [".ipynb"] if scan_noteboooks else DEFAULT_EXTENSIONS def read_file_content(file_name: str, encoding="utf-8"): - if file_ext_is_allowed(file_name, [".py"]): + if file_ext_is_allowed(file_name, DEFAULT_EXTENSIONS): with open(file_name, "r", encoding=encoding) as f: contents = f.read() elif file_ext_is_allowed(file_name, [".ipynb"]) and scan_noteboooks: diff --git a/tests/_data_pyw/py.py b/tests/_data_pyw/py.py new file mode 100644 index 0000000..d6a91ae --- /dev/null +++ b/tests/_data_pyw/py.py @@ -0,0 +1,5 @@ +import airflow +import numpy + +airflow +numpy diff --git a/tests/_data_pyw/pyw.pyw b/tests/_data_pyw/pyw.pyw new file mode 100644 index 0000000..8377bb2 --- /dev/null +++ b/tests/_data_pyw/pyw.pyw @@ -0,0 +1,3 @@ +import matplotlib +import pandas +import tensorflow diff --git a/tests/test_pipreqs.py b/tests/test_pipreqs.py index 1418b87..240355b 100644 --- a/tests/test_pipreqs.py +++ b/tests/test_pipreqs.py @@ -629,6 +629,45 @@ def test_ignore_notebooks(self): assert os.path.exists(notebook_requirement_path) == 1 assert os.path.getsize(notebook_requirement_path) == 1 # file only has a "\n", meaning it's empty + def test_pipreqs_get_imports_from_pyw_file(self): + pyw_test_dirpath = os.path.join(os.path.dirname(__file__), "_data_pyw") + requirements_path = os.path.join(pyw_test_dirpath, "requirements.txt") + + pipreqs.init( + { + "": pyw_test_dirpath, + "--savepath": None, + "--print": False, + "--use-local": None, + "--force": True, + "--proxy": None, + "--pypi-server": None, + "--diff": None, + "--clean": None, + "--mode": None, + } + ) + + self.assertTrue(os.path.exists(requirements_path)) + + expected_imports = [ + "airflow", + "matplotlib", + "numpy", + "pandas", + "tensorflow", + ] + + with open(requirements_path, "r") as f: + imports_data = f.read().lower() + for _import in expected_imports: + self.assertTrue( + _import.lower() in imports_data, + f"'{_import}' import was expected but not found.", + ) + + os.remove(requirements_path) + def mock_scan_notebooks(self): pipreqs.scan_noteboooks = Mock(return_value=True) pipreqs.handle_scan_noteboooks()