feat: use results dir as working directory for decorator jobs #743

Merged: 12 commits, Oct 14, 2023
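In short, decorator jobs now run with the job's results directory as the working directory, so files written to relative paths inside the job are captured in the job results. A minimal sketch of the user-visible effect (the device setting and file name are illustrative assumptions, not taken from this PR):

```python
from braket.jobs import hybrid_job, save_job_result

@hybrid_job(device=None)  # illustrative; a real device ARN also works
def my_job():
    # With this change, relative writes land in the results directory,
    # so this file is included in what job.download_result() retrieves.
    with open("output.txt", "w") as f:  # hypothetical file name
        f.write("hello")
    save_job_result({"status": "SUCCESS"})
```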
73 changes: 53 additions & 20 deletions src/braket/jobs/_entry_point_template.py
@@ -1,12 +1,23 @@
run_entry_point = """
import cloudpickle
-from braket.jobs import save_job_result
+import os
+from braket.jobs import get_results_dir, save_job_result
from braket.jobs_data import PersistedJobDataFormat
+# set working directory to results dir
+os.chdir(get_results_dir())
+# create symlinks to input data
+links = link_input()
# load and run serialized entry point function
recovered = cloudpickle.loads({serialized})
def {function_name}():
-    result = recovered()
+    try:
+        result = recovered()
+    finally:
+        clean_links(links)
    if result is not None:
        save_job_result(result, data_format=PersistedJobDataFormat.PICKLED_V4)
    return result
@@ -16,31 +27,53 @@ def {function_name}():
from pathlib import Path
from braket.jobs import get_input_data_dir
-# map of data sources to lists of matched local files
-prefix_matches = {prefix_matches}
-def make_link(input_link_path, input_data_path):
+def make_link(input_link_path, input_data_path, links):
    """ Create symlink from input_link_path to input_data_path. """
    input_link_path.parent.mkdir(parents=True, exist_ok=True)
    input_link_path.symlink_to(input_data_path)
    print(input_link_path, '->', input_data_path)
+    links[input_link_path] = input_data_path
+def link_input():
+    links = {{}}
+    dirs = set()
+    # map of data sources to lists of matched local files
+    prefix_matches = {prefix_matches}
-for channel, data in {input_data_items}:
-    if channel in {prefix_channels}:
-        # link all matched files
-        for input_link_name in prefix_matches[channel]:
-            input_link_path = Path(input_link_name)
-            input_data_path = Path(get_input_data_dir(channel)) / input_link_path.name
-            make_link(input_link_path, input_data_path)
-    else:
-        input_link_path = Path(data)
-        if channel in {directory_channels}:
-            # link directory source directly to input channel directory
-            input_data_path = Path(get_input_data_dir(channel))
-        else:
-            # link file source to file within input channel directory
-            input_data_path = Path(get_input_data_dir(channel), Path(data).name)
-        make_link(input_link_path, input_data_path)
+    for channel, data in {input_data_items}:
+        if channel in {prefix_channels}:
+            # link all matched files
+            for input_link_name in prefix_matches[channel]:
+                input_link_path = Path(input_link_name)
+                input_data_path = Path(get_input_data_dir(channel)) / input_link_path.name
+                make_link(input_link_path, input_data_path, links)
+        else:
+            input_link_path = Path(data)
+            if channel in {directory_channels}:
+                # link directory source directly to input channel directory
+                input_data_path = Path(get_input_data_dir(channel))
+            else:
+                # link file source to file within input channel directory
+                input_data_path = Path(get_input_data_dir(channel), Path(data).name)
+            make_link(input_link_path, input_data_path, links)
+    return links
+def clean_links(links):
+    for link, target in links.items():
+        if link.is_symlink() and link.readlink() == target:
+            link.unlink()
+        if link.is_relative_to(Path()):
+            for dir in link.parents[:-1]:
+                try:
+                    dir.rmdir()
+                except:
+                    # directory not empty
+                    pass
'''
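For orientation: the template strings above are filled with `str.format` before being written into the job script, which is why literal braces in the template are doubled (`links = {{}}`) while `{serialized}` and `{function_name}` are placeholders. A minimal sketch of that fill-and-exec pattern, under stated assumptions (this is not the SDK's actual wiring; the tiny template and `entry_point` name are hypothetical):

```python
import cloudpickle

# Hypothetical template in the same style as run_entry_point above.
template = """
import cloudpickle
recovered = cloudpickle.loads({serialized})
def {function_name}():
    return recovered()
"""

def user_fn():
    return {"status": "SUCCESS"}

# repr() of the pickled bytes yields a valid Python bytes literal,
# so the placeholder expands into executable source.
source = template.format(
    serialized=repr(cloudpickle.dumps(user_fn)),
    function_name="entry_point",
)

namespace = {}
exec(source, namespace)            # defines entry_point in namespace
print(namespace["entry_point"]())  # {'status': 'SUCCESS'}
```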
41 changes: 31 additions & 10 deletions test/integ_tests/test_create_quantum_job.py
@@ -158,14 +158,17 @@ def test_completed_quantum_job(aws_session, capsys):

    with tempfile.TemporaryDirectory() as temp_dir:
        os.chdir(temp_dir)
-        job.download_result()
-        assert (
-            Path(AwsQuantumJob.RESULTS_TAR_FILENAME).exists() and Path(downloaded_result).exists()
-        )
-
-        # Check results match the expectations.
-        assert job.result() == {"converged": True, "energy": -0.2}
-        os.chdir(current_dir)
+        try:
+            job.download_result()
+            assert (
+                Path(AwsQuantumJob.RESULTS_TAR_FILENAME).exists()
+                and Path(downloaded_result).exists()
+            )
+
+            # Check results match the expectations.
+            assert job.result() == {"converged": True, "energy": -0.2}
+        finally:
+            os.chdir(current_dir)

    # Check the logs and validate it contains required output.
    job.logs(wait=True)
@@ -227,9 +230,27 @@ def decorator_job(a, b: int, c=0, d: float = 1.0, **extras):
"extra_arg": "extra_value",
}

job = decorator_job(MyClass, 2, d=5, extra_arg="extra_value")
with open("test/output_file.txt", "w") as f:
f.write("hello")

job = decorator_job(MyClass(), 2, d=5, extra_arg="extra_value")
assert job.result()["status"] == "SUCCESS"

current_dir = Path.cwd()
with tempfile.TemporaryDirectory() as temp_dir:
os.chdir(temp_dir)
try:
job.download_result()
with open(Path(job.name, "test", "output_file.txt"), "r") as f:
assert f.read() == "hello"
assert (
Path(job.name, "results.json").exists()
and Path(job.name, "test").exists()
and not Path(job.name, "test", "integ_tests").exists()
)
finally:
os.chdir(current_dir)


def test_decorator_job_submodule():
@hybrid_job(
@@ -247,7 +268,7 @@ def test_decorator_job_submodule():
"my_dir": str(Path("test", "integ_tests", "job_test_module")),
},
)
def decorator_job():
def decorator_job_submodule():
save_job_result(submodule_helper())
with open(Path(get_input_data_dir("my_input")) / "requirements.txt", "r") as f:
assert f.readlines() == ["pytest\n"]
@@ -266,5 +287,5 @@ def decorator_job():
        assert f.readlines() == ["pytest\n"]
        assert dir(pytest)

-    job = decorator_job()
+    job = decorator_job_submodule()
    assert job.result()["status"] == "SUCCESS"
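Taken together, the tests illustrate a reusable pattern: chdir into a temporary directory, download, assert on the extracted layout, and always restore the working directory. A hedged sketch of that pattern (assumes `job` is a completed hybrid job object exposing the `name`, `download_result()`, and `result()` members used above):

```python
import os
import tempfile
from pathlib import Path

def verify_results_layout(job) -> None:
    """Download a job's results into a temp dir and check the layout."""
    start_dir = Path.cwd()
    with tempfile.TemporaryDirectory() as temp_dir:
        os.chdir(temp_dir)
        try:
            job.download_result()  # extracts into ./<job name>/
            assert Path(job.name, "results.json").exists()
        finally:
            # Restore the working directory even if an assertion fails,
            # so later tests are not affected by the chdir.
            os.chdir(start_dir)
```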