diff --git a/tests/__pycache__/test_environment.cpython-310-pytest-8.3.4.pyc b/tests/__pycache__/test_environment.cpython-310-pytest-8.3.4.pyc index 109ddad..50c3dc7 100644 Binary files a/tests/__pycache__/test_environment.cpython-310-pytest-8.3.4.pyc and b/tests/__pycache__/test_environment.cpython-310-pytest-8.3.4.pyc differ diff --git a/tests/test_environment.py b/tests/test_environment.py index 2036aa4..eb94584 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -7,6 +7,7 @@ from unittest import skipIf from packaging import version # Fix: import version directly import torch +import os # Added for environment variable checks logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -20,8 +21,10 @@ def setUpClass(cls): cls.required_packages = { 'torch': '2.0.1', 'torchvision': '0.15.2', - 'torchaudio': '2.0.1' + 'torchaudio': '2.0.1', + 'numpy': '1.21.0' # Added numpy as an example } + cls.required_env_vars = ['HOME', 'PATH'] # Example environment variables cls._check_package_installations() @classmethod @@ -124,7 +127,7 @@ def test_memory_allocation(self): import torch test_sizes = [(1000, 1000), (2000, 2000)] - devices = ['cpu'] + [f'cuda:{i}' for i in range(torch.cuda.device_count())] if torch.cuda.is_available() else [] + devices = ['cpu'] + [f'cuda:{i}' for i in range(torch.cuda.device_count())] if torch.cuda.is_available() else ['cpu'] for device in devices: for size in test_sizes: @@ -216,6 +219,13 @@ def test_error_handling_and_logging(self): logger.error(f"Unexpected error during tensor creation: {str(e)}") self.fail(f"Unexpected error during tensor creation: {str(e)}") + def test_environment_variables(self): + """Verify essential environment variables are set""" + for var in self.required_env_vars: + value = os.getenv(var) + self.assertIsNotNone(value, f"Environment variable '{var}' is not set") + logger.info(f"Environment variable '{var}': {value}") + if __name__ == '__main__': logger.info("Starting environment tests") logger.info(f"Platform: {platform.platform()}")