Skip to content

Commit

Permalink
Add environment variable checks to test suite and include numpy as a …
Browse files Browse the repository at this point in the history
…required package
  • Loading branch information
kasinadhsarma committed Dec 26, 2024
1 parent c866e4c commit fdc596d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
Binary file modified tests/__pycache__/test_environment.cpython-310-pytest-8.3.4.pyc
Binary file not shown.
14 changes: 12 additions & 2 deletions tests/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from unittest import skipIf
from packaging import version # Fix: import version directly
import torch
import os # Added for environment variable checks

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand All @@ -20,8 +21,10 @@ def setUpClass(cls):
cls.required_packages = {
'torch': '2.0.1',
'torchvision': '0.15.2',
'torchaudio': '2.0.1'
'torchaudio': '2.0.1',
'numpy': '1.21.0' # Added numpy as an example
}
cls.required_env_vars = ['HOME', 'PATH'] # Example environment variables
cls._check_package_installations()

@classmethod
Expand Down Expand Up @@ -124,7 +127,7 @@ def test_memory_allocation(self):
import torch

test_sizes = [(1000, 1000), (2000, 2000)]
devices = ['cpu'] + [f'cuda:{i}' for i in range(torch.cuda.device_count())] if torch.cuda.is_available() else []
devices = ['cpu'] + [f'cuda:{i}' for i in range(torch.cuda.device_count())] if torch.cuda.is_available() else ['cpu']

for device in devices:
for size in test_sizes:
Expand Down Expand Up @@ -216,6 +219,13 @@ def test_error_handling_and_logging(self):
logger.error(f"Unexpected error during tensor creation: {str(e)}")
self.fail(f"Unexpected error during tensor creation: {str(e)}")

def test_environment_variables(self):
"""Verify essential environment variables are set"""
for var in self.required_env_vars:
value = os.getenv(var)
self.assertIsNotNone(value, f"Environment variable '{var}' is not set")
logger.info(f"Environment variable '{var}': {value}")

if __name__ == '__main__':
logger.info("Starting environment tests")
logger.info(f"Platform: {platform.platform()}")
Expand Down

0 comments on commit fdc596d

Please sign in to comment.