diff --git a/src/braket/jobs/_entry_point_template.py b/src/braket/jobs/_entry_point_template.py index 5165824e1..285e4d85e 100644 --- a/src/braket/jobs/_entry_point_template.py +++ b/src/braket/jobs/_entry_point_template.py @@ -1,12 +1,23 @@ run_entry_point = """ import cloudpickle -from braket.jobs import save_job_result +import os +from braket.jobs import get_results_dir, save_job_result from braket.jobs_data import PersistedJobDataFormat + +# set working directory to results dir +os.chdir(get_results_dir()) + +# create symlinks to input data +links = link_input() + # load and run serialized entry point function recovered = cloudpickle.loads({serialized}) def {function_name}(): - result = recovered() + try: + result = recovered() + finally: + clean_links(links) if result is not None: save_job_result(result, data_format=PersistedJobDataFormat.PICKLED_V4) return result @@ -16,31 +27,53 @@ def {function_name}(): from pathlib import Path from braket.jobs import get_input_data_dir -# map of data sources to lists of matched local files -prefix_matches = {prefix_matches} -def make_link(input_link_path, input_data_path): +def make_link(input_link_path, input_data_path, links): """ Create symlink from input_link_path to input_data_path. """ input_link_path.parent.mkdir(parents=True, exist_ok=True) input_link_path.symlink_to(input_data_path) print(input_link_path, '->', input_data_path) + links[input_link_path] = input_data_path + + +def link_input(): + links = {{}} + dirs = set() + # map of data sources to lists of matched local files + prefix_matches = {prefix_matches} -for channel, data in {input_data_items}: + for channel, data in {input_data_items}: - if channel in {prefix_channels}: - # link all matched files - for input_link_name in prefix_matches[channel]: - input_link_path = Path(input_link_name) - input_data_path = Path(get_input_data_dir(channel)) / input_link_path.name - make_link(input_link_path, input_data_path) + if channel in {prefix_channels}: + # link all matched files + for input_link_name in prefix_matches[channel]: + input_link_path = Path(input_link_name) + input_data_path = Path(get_input_data_dir(channel)) / input_link_path.name + make_link(input_link_path, input_data_path, links) - else: - input_link_path = Path(data) - if channel in {directory_channels}: - # link directory source directly to input channel directory - input_data_path = Path(get_input_data_dir(channel)) else: - # link file source to file within input channel directory - input_data_path = Path(get_input_data_dir(channel), Path(data).name) - make_link(input_link_path, input_data_path) + input_link_path = Path(data) + if channel in {directory_channels}: + # link directory source directly to input channel directory + input_data_path = Path(get_input_data_dir(channel)) + else: + # link file source to file within input channel directory + input_data_path = Path(get_input_data_dir(channel), Path(data).name) + make_link(input_link_path, input_data_path, links) + + return links + + +def clean_links(links): + for link, target in links.items(): + if link.is_symlink and link.readlink() == target: + link.unlink() + + if link.is_relative_to(Path()): + for dir in link.parents[:-1]: + try: + dir.rmdir() + except: + # directory not empty + pass ''' diff --git a/test/integ_tests/test_create_quantum_job.py b/test/integ_tests/test_create_quantum_job.py index 5364c8e3b..80a684476 100644 --- a/test/integ_tests/test_create_quantum_job.py +++ b/test/integ_tests/test_create_quantum_job.py @@ -158,14 +158,17 @@ def test_completed_quantum_job(aws_session, capsys): with tempfile.TemporaryDirectory() as temp_dir: os.chdir(temp_dir) - job.download_result() - assert ( - Path(AwsQuantumJob.RESULTS_TAR_FILENAME).exists() and Path(downloaded_result).exists() - ) + try: + job.download_result() + assert ( + Path(AwsQuantumJob.RESULTS_TAR_FILENAME).exists() + and Path(downloaded_result).exists() + ) - # Check results match the expectations. - assert job.result() == {"converged": True, "energy": -0.2} - os.chdir(current_dir) + # Check results match the expectations. + assert job.result() == {"converged": True, "energy": -0.2} + finally: + os.chdir(current_dir) # Check the logs and validate it contains required output. job.logs(wait=True) @@ -227,9 +230,27 @@ def decorator_job(a, b: int, c=0, d: float = 1.0, **extras): "extra_arg": "extra_value", } - job = decorator_job(MyClass, 2, d=5, extra_arg="extra_value") + with open("test/output_file.txt", "w") as f: + f.write("hello") + + job = decorator_job(MyClass(), 2, d=5, extra_arg="extra_value") assert job.result()["status"] == "SUCCESS" + current_dir = Path.cwd() + with tempfile.TemporaryDirectory() as temp_dir: + os.chdir(temp_dir) + try: + job.download_result() + with open(Path(job.name, "test", "output_file.txt"), "r") as f: + assert f.read() == "hello" + assert ( + Path(job.name, "results.json").exists() + and Path(job.name, "test").exists() + and not Path(job.name, "test", "integ_tests").exists() + ) + finally: + os.chdir(current_dir) + def test_decorator_job_submodule(): @hybrid_job( @@ -247,7 +268,7 @@ def test_decorator_job_submodule(): "my_dir": str(Path("test", "integ_tests", "job_test_module")), }, ) - def decorator_job(): + def decorator_job_submodule(): save_job_result(submodule_helper()) with open(Path(get_input_data_dir("my_input")) / "requirements.txt", "r") as f: assert f.readlines() == ["pytest\n"] @@ -266,5 +287,5 @@ def decorator_job(): assert f.readlines() == ["pytest\n"] assert dir(pytest) - job = decorator_job() + job = decorator_job_submodule() assert job.result()["status"] == "SUCCESS"