diff --git a/CHANGELOG.md b/CHANGELOG.md index 47d87d33a..7bafd45c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ # Changelog +## v1.68.0 (2024-01-25) + +### Features + + * update S3 locations for jobs + +## v1.67.0 (2024-01-23) + +### Features + + * add queue position to the logs for tasks and jobs + +## v1.66.0 (2024-01-11) + +### Features + + * update job name to use metadata + ## v1.65.1 (2023-12-25) ### Bug Fixes and Other Changes diff --git a/setup.py b/setup.py index 86ec939e3..d31f89f16 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,6 @@ "black", "botocore", "flake8<=5.0.4", - "flake8-rst-docstrings", "isort", "jsonschema==3.2.0", "pre-commit", diff --git a/src/braket/_sdk/_version.py b/src/braket/_sdk/_version.py index 278494287..051177529 100644 --- a/src/braket/_sdk/_version.py +++ b/src/braket/_sdk/_version.py @@ -15,4 +15,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "1.65.2.dev0" +__version__ = "1.68.1.dev0" diff --git a/src/braket/aws/aws_quantum_job.py b/src/braket/aws/aws_quantum_job.py index 1e929e857..562c61ba8 100644 --- a/src/braket/aws/aws_quantum_job.py +++ b/src/braket/aws/aws_quantum_job.py @@ -81,6 +81,7 @@ def create( aws_session: AwsSession | None = None, tags: dict[str, str] | None = None, logger: Logger = getLogger(__name__), + quiet: bool = False, reservation_arn: str | None = None, ) -> AwsQuantumJob: """Creates a hybrid job by invoking the Braket CreateJob API. @@ -176,6 +177,9 @@ def create( while waiting for quantum task to be in a terminal state. Default is `getLogger(__name__)` + quiet (bool): Sets the verbosity of the logger to low and does not report queue + position. Default is `False`. + reservation_arn (str | None): the reservation window arn provided by Braket Direct to reserve exclusive usage for the device to run the hybrid job on. Default: None. @@ -210,7 +214,7 @@ def create( ) job_arn = aws_session.create_job(**create_job_kwargs) - job = AwsQuantumJob(job_arn, aws_session) + job = AwsQuantumJob(job_arn, aws_session, quiet) if wait_until_complete: print(f"Initializing Braket Job: {job_arn}") @@ -218,15 +222,18 @@ def create( return job - def __init__(self, arn: str, aws_session: AwsSession | None = None): + def __init__(self, arn: str, aws_session: AwsSession | None = None, quiet: bool = False): """ Args: arn (str): The ARN of the hybrid job. aws_session (AwsSession | None): The `AwsSession` for connecting to AWS services. Default is `None`, in which case an `AwsSession` object will be created with the region of the hybrid job. + quiet (bool): Sets the verbosity of the logger to low and does not report queue + position. Default is `False`. """ self._arn: str = arn + self._quiet = quiet if aws_session: if not self._is_valid_aws_session_region_for_job_arn(aws_session, arn): raise ValueError( @@ -268,7 +275,7 @@ def arn(self) -> str: @property def name(self) -> str: """str: The name of the quantum job.""" - return self._arn.partition("job/")[-1] + return self.metadata(use_cached_value=True).get("jobName") def state(self, use_cached_value: bool = False) -> str: """The state of the quantum hybrid job. @@ -371,10 +378,11 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None: instance_count = self.metadata(use_cached_value=True)["instanceConfig"]["instanceCount"] has_streams = False color_wrap = logs.ColorWrap() + previous_state = self.state() while True: time.sleep(poll_interval_seconds) - + current_state = self.state() has_streams = logs.flush_log_streams( self._aws_session, log_group, @@ -384,14 +392,17 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None: instance_count, has_streams, color_wrap, + [previous_state, current_state], + self.queue_position().queue_position if not self._quiet else None, ) + previous_state = current_state if log_state == AwsQuantumJob.LogState.COMPLETE: break if log_state == AwsQuantumJob.LogState.JOB_COMPLETE: log_state = AwsQuantumJob.LogState.COMPLETE - elif self.state() in AwsQuantumJob.TERMINAL_STATES: + elif current_state in AwsQuantumJob.TERMINAL_STATES: log_state = AwsQuantumJob.LogState.JOB_COMPLETE def metadata(self, use_cached_value: bool = False) -> dict[str, Any]: diff --git a/src/braket/aws/aws_quantum_task.py b/src/braket/aws/aws_quantum_task.py index c490a4190..0785c03c5 100644 --- a/src/braket/aws/aws_quantum_task.py +++ b/src/braket/aws/aws_quantum_task.py @@ -105,6 +105,7 @@ def create( tags: dict[str, str] | None = None, inputs: dict[str, float] | None = None, gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None = None, + quiet: bool = False, reservation_arn: str | None = None, *args, **kwargs, @@ -152,6 +153,9 @@ def create( a `PulseSequence`. Default: None. + quiet (bool): Sets the verbosity of the logger to low and does not report queue + position. Default is `False`. + reservation_arn (str | None): The reservation ARN provided by Braket Direct to reserve exclusive usage for the device to run the quantum task on. Note: If you are creating tasks in a job that itself was created reservation ARN, @@ -215,6 +219,7 @@ def create( disable_qubit_rewiring, inputs, gate_definitions=gate_definitions, + quiet=quiet, *args, **kwargs, ) @@ -226,6 +231,7 @@ def __init__( poll_timeout_seconds: float = DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL, logger: Logger = getLogger(__name__), + quiet: bool = False, ): """ Args: @@ -238,6 +244,8 @@ def __init__( logger (Logger): Logger object with which to write logs, such as quantum task statuses while waiting for quantum task to be in a terminal state. Default is `getLogger(__name__)` + quiet (bool): Sets the verbosity of the logger to low and does not report queue + position. Default is `False`. Examples: >>> task = AwsQuantumTask(arn='task_arn') @@ -259,6 +267,7 @@ def __init__( self._poll_interval_seconds = poll_interval_seconds self._logger = logger + self._quiet = quiet self._metadata: dict[str, Any] = {} self._result: Union[ @@ -477,6 +486,11 @@ async def _wait_for_completion( while (time.time() - start_time) < self._poll_timeout_seconds: # Used cached metadata if cached status is terminal task_status = self._update_status_if_nonterminal() + if not self._quiet and task_status == "QUEUED": + queue = self.queue_position() + self._logger.debug( + f"Task is in {queue.queue_type} queue position: {queue.queue_position}" + ) self._logger.debug(f"Task {self._arn}: task status {task_status}") if task_status in AwsQuantumTask.RESULTS_READY_STATES: return self._download_result() diff --git a/src/braket/jobs/hybrid_job.py b/src/braket/jobs/hybrid_job.py index 707f18fd5..b8e1e58bf 100644 --- a/src/braket/jobs/hybrid_job.py +++ b/src/braket/jobs/hybrid_job.py @@ -63,6 +63,7 @@ def hybrid_job( aws_session: AwsSession | None = None, tags: dict[str, str] | None = None, logger: Logger = getLogger(__name__), + quiet: bool | None = None, reservation_arn: str | None = None, ) -> Callable: """Defines a hybrid job by decorating the entry point function. The job will be created @@ -71,7 +72,7 @@ def hybrid_job( The job created will be a `LocalQuantumJob` when `local` is set to `True`, otherwise an `AwsQuantumJob`. The following parameters will be ignored when running a job with `local` set to `True`: `wait_until_complete`, `instance_config`, `distribution`, - `copy_checkpoints_from_job`, `stopping_condition`, `tags`, and `logger`. + `copy_checkpoints_from_job`, `stopping_condition`, `tags`, `logger`, and `quiet`. Args: device (str | None): Device ARN of the QPU device that receives priority quantum @@ -153,6 +154,9 @@ def hybrid_job( logger (Logger): Logger object with which to write logs, such as task statuses while waiting for task to be in a terminal state. Default: `getLogger(__name__)` + quiet (bool | None): Sets the verbosity of the logger to low and does not report queue + position. Default is `False`. + reservation_arn (str | None): the reservation window arn provided by Braket Direct to reserve exclusive usage for the device to run the hybrid job on. Default: None. @@ -210,6 +214,7 @@ def job_wrapper(*args, **kwargs) -> Callable: "output_data_config": output_data_config, "aws_session": aws_session, "tags": tags, + "quiet": quiet, "reservation_arn": reservation_arn, } for key, value in optional_args.items(): diff --git a/src/braket/jobs/logs.py b/src/braket/jobs/logs.py index e0f54458d..734d51123 100644 --- a/src/braket/jobs/logs.py +++ b/src/braket/jobs/logs.py @@ -20,7 +20,7 @@ # Support for reading logs # ############################################################################## -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from botocore.exceptions import ClientError @@ -155,7 +155,7 @@ def log_stream( yield ev -def flush_log_streams( +def flush_log_streams( # noqa C901 aws_session: AwsSession, log_group: str, stream_prefix: str, @@ -164,6 +164,8 @@ def flush_log_streams( stream_count: int, has_streams: bool, color_wrap: ColorWrap, + state: list[str], + queue_position: Optional[str] = None, ) -> bool: """Flushes log streams to stdout. @@ -183,6 +185,9 @@ def flush_log_streams( been found. This value is possibly updated and returned at the end of execution. color_wrap (ColorWrap): An instance of ColorWrap to potentially color-wrap print statements from different streams. + state (list[str]): The previous and current state of the job. + queue_position (Optional[str]): The current queue position. This is not passed in if the job + is ran with `quiet=True` Returns: bool: Returns 'True' if any streams have been flushed. @@ -225,6 +230,10 @@ def flush_log_streams( positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1) else: positions[stream_names[idx]] = Position(timestamp=event["timestamp"], skip=1) + elif queue_position is not None and state[1] == "QUEUED": + print(f"Job queue position: {queue_position}", end="\n", flush=True) + elif state[0] != state[1] and state[1] == "RUNNING" and queue_position is not None: + print("Running:", end="\n", flush=True) else: print(".", end="", flush=True) return has_streams diff --git a/src/braket/jobs/quantum_job_creation.py b/src/braket/jobs/quantum_job_creation.py index 657ed0829..9e18faeab 100644 --- a/src/braket/jobs/quantum_job_creation.py +++ b/src/braket/jobs/quantum_job_creation.py @@ -161,14 +161,15 @@ def prepare_quantum_job( _validate_params(param_datatype_map) aws_session = aws_session or AwsSession() device_config = DeviceConfig(device) - job_name = job_name or _generate_default_job_name(image_uri=image_uri) + timestamp = str(int(time.time() * 1000)) + job_name = job_name or _generate_default_job_name(image_uri=image_uri, timestamp=timestamp) role_arn = role_arn or os.getenv("BRAKET_JOBS_ROLE_ARN", aws_session.get_default_jobs_role()) hyperparameters = hyperparameters or {} hyperparameters = {str(key): str(value) for key, value in hyperparameters.items()} input_data = input_data or {} tags = tags or {} default_bucket = aws_session.default_bucket() - input_data_list = _process_input_data(input_data, job_name, aws_session) + input_data_list = _process_input_data(input_data, job_name, aws_session, timestamp) instance_config = instance_config or InstanceConfig() stopping_condition = stopping_condition or StoppingCondition() output_data_config = output_data_config or OutputDataConfig() @@ -177,6 +178,7 @@ def prepare_quantum_job( default_bucket, "jobs", job_name, + timestamp, "script", ) @@ -201,6 +203,7 @@ def prepare_quantum_job( default_bucket, "jobs", job_name, + timestamp, "data", ) if not checkpoint_config.s3Uri: @@ -208,6 +211,7 @@ def prepare_quantum_job( default_bucket, "jobs", job_name, + timestamp, "checkpoints", ) if copy_checkpoints_from_job: @@ -251,19 +255,22 @@ def prepare_quantum_job( return create_job_kwargs -def _generate_default_job_name(image_uri: str | None = None, func: Callable | None = None) -> str: +def _generate_default_job_name( + image_uri: str | None = None, func: Callable | None = None, timestamp: int | str | None = None +) -> str: """ Generate default job name using the image uri and entrypoint function. Args: image_uri (str | None): URI for the image container. func (Callable | None): The entry point function. + timestamp (int | str | None): Optional timestamp to use instead of generating one. Returns: str: Hybrid job name. """ max_length = 50 - timestamp = str(int(time.time() * 1000)) + timestamp = timestamp if timestamp is not None else str(int(time.time() * 1000)) if func: name = func.__name__.replace("_", "-") @@ -395,7 +402,10 @@ def _validate_params(dict_arr: dict[str, tuple[any, any]]) -> None: def _process_input_data( - input_data: str | dict | S3DataSourceConfig, job_name: str, aws_session: AwsSession + input_data: str | dict | S3DataSourceConfig, + job_name: str, + aws_session: AwsSession, + subdirectory: str, ) -> list[dict[str, Any]]: """ Convert input data into a list of dicts compatible with the Braket API. @@ -405,6 +415,7 @@ def _process_input_data( can be an S3DataSourceConfig or a str corresponding to a local prefix or S3 prefix. job_name (str): Hybrid job name. aws_session (AwsSession): AwsSession for possibly uploading local data. + subdirectory (str): Subdirectory within job name for S3 locations. Returns: list[dict[str, Any]]: A list of channel configs. @@ -413,12 +424,18 @@ def _process_input_data( input_data = {"input": input_data} for channel_name, data in input_data.items(): if not isinstance(data, S3DataSourceConfig): - input_data[channel_name] = _process_channel(data, job_name, aws_session, channel_name) + input_data[channel_name] = _process_channel( + data, job_name, aws_session, channel_name, subdirectory + ) return _convert_input_to_config(input_data) def _process_channel( - location: str, job_name: str, aws_session: AwsSession, channel_name: str + location: str, + job_name: str, + aws_session: AwsSession, + channel_name: str, + subdirectory: str, ) -> S3DataSourceConfig: """ Convert a location to an S3DataSourceConfig, uploading local data to S3, if necessary. @@ -427,6 +444,7 @@ def _process_channel( job_name (str): Hybrid job name. aws_session (AwsSession): AwsSession to be used for uploading local data. channel_name (str): Name of the channel. + subdirectory (str): Subdirectory within job name for S3 locations. Returns: S3DataSourceConfig: S3DataSourceConfig for the channel. @@ -435,10 +453,16 @@ def _process_channel( return S3DataSourceConfig(location) else: # local prefix "path/to/prefix" will be mapped to - # s3://bucket/jobs/job-name/data/input/prefix + # s3://bucket/jobs/job-name/subdirectory/data/input/prefix location_name = Path(location).name s3_prefix = AwsSession.construct_s3_uri( - aws_session.default_bucket(), "jobs", job_name, "data", channel_name, location_name + aws_session.default_bucket(), + "jobs", + job_name, + subdirectory, + "data", + channel_name, + location_name, ) aws_session.upload_local_data(location, s3_prefix) return S3DataSourceConfig(s3_prefix) diff --git a/test/integ_tests/test_create_quantum_job.py b/test/integ_tests/test_create_quantum_job.py index b1ef2f12b..a8aaf4ca5 100644 --- a/test/integ_tests/test_create_quantum_job.py +++ b/test/integ_tests/test_create_quantum_job.py @@ -45,8 +45,9 @@ def test_failed_quantum_job(aws_session, capsys, failed_quantum_job): """ job = failed_quantum_job job_name = job.name - pattern = f"^arn:aws:braket:{aws_session.region}:\\d12:job/{job_name}$" - re.match(pattern=pattern, string=job.arn) + + pattern = f"^arn:aws:braket:{aws_session.region}:\\d{{12}}:job/[a-z0-9-]+$" + assert re.match(pattern=pattern, string=job.arn) # Check job is in failed state. while True: @@ -57,11 +58,17 @@ def test_failed_quantum_job(aws_session, capsys, failed_quantum_job): # Check whether the respective folder with files are created for script, # output, tasks and checkpoints. + job_name = job.name + s3_bucket = aws_session.default_bucket() + subdirectory = re.match( + rf"s3://{s3_bucket}/jobs/{job.name}/(\d+)/script/source.tar.gz", + job.metadata()["algorithmSpecification"]["scriptModeConfig"]["s3Uri"], + ).group(1) keys = aws_session.list_keys( - bucket=f"amazon-braket-{aws_session.region}-{aws_session.account_id}", - prefix=f"jobs/{job_name}", + bucket=s3_bucket, + prefix=f"jobs/{job_name}/{subdirectory}/", ) - assert keys == [f"jobs/{job_name}/script/source.tar.gz"] + assert keys == [f"jobs/{job_name}/{subdirectory}/script/source.tar.gz"] # no results saved assert job.result() == {} @@ -97,8 +104,8 @@ def test_completed_quantum_job(aws_session, capsys, completed_quantum_job): job = completed_quantum_job job_name = job.name - pattern = f"^arn:aws:braket:{aws_session.region}:\\d12:job/{job_name}$" - re.match(pattern=pattern, string=job.arn) + pattern = f"^arn:aws:braket:{aws_session.region}:\\d{{12}}:job/[a-z0-9-]+$" + assert re.match(pattern=pattern, string=job.arn) # check job is in completed state. while True: @@ -109,24 +116,36 @@ def test_completed_quantum_job(aws_session, capsys, completed_quantum_job): # Check whether the respective folder with files are created for script, # output, tasks and checkpoints. - s3_bucket = f"amazon-braket-{aws_session.region}-{aws_session.account_id}" + job_name = job.name + s3_bucket = aws_session.default_bucket() + subdirectory = re.match( + rf"s3://{s3_bucket}/jobs/{job.name}/(\d+)/script/source.tar.gz", + job.metadata()["algorithmSpecification"]["scriptModeConfig"]["s3Uri"], + ).group(1) keys = aws_session.list_keys( bucket=s3_bucket, - prefix=f"jobs/{job_name}", + prefix=f"jobs/{job_name}/{subdirectory}/", ) for expected_key in [ - f"jobs/{job_name}/script/source.tar.gz", - f"jobs/{job_name}/data/output/model.tar.gz", - f"jobs/{job_name}/tasks/[^/]*/results.json", - f"jobs/{job_name}/checkpoints/{job_name}_plain_data.json", - f"jobs/{job_name}/checkpoints/{job_name}.json", + f"jobs/{job_name}/{subdirectory}/script/source.tar.gz", + f"jobs/{job_name}/{subdirectory}/data/output/model.tar.gz", + f"jobs/{job_name}/{subdirectory}/checkpoints/{job_name}_plain_data.json", + f"jobs/{job_name}/{subdirectory}/checkpoints/{job_name}.json", ]: assert any(re.match(expected_key, key) for key in keys) + # Check that tasks exist in the correct location + tasks_keys = aws_session.list_keys( + bucket=s3_bucket, + prefix=f"jobs/{job_name}/tasks/", + ) + expected_task_location = f"jobs/{job_name}/tasks/[^/]*/results.json" + assert any(re.match(expected_task_location, key) for key in tasks_keys) + # Check if checkpoint is uploaded in requested format. for s3_key, expected_data in [ ( - f"jobs/{job_name}/checkpoints/{job_name}_plain_data.json", + f"jobs/{job_name}/{subdirectory}/checkpoints/{job_name}_plain_data.json", { "braketSchemaHeader": { "name": "braket.jobs_data.persisted_job_data", @@ -137,7 +156,7 @@ def test_completed_quantum_job(aws_session, capsys, completed_quantum_job): }, ), ( - f"jobs/{job_name}/checkpoints/{job_name}.json", + f"jobs/{job_name}/{subdirectory}/checkpoints/{job_name}.json", { "braketSchemaHeader": { "name": "braket.jobs_data.persisted_job_data", diff --git a/test/unit_tests/braket/aws/test_aws_quantum_job.py b/test/unit_tests/braket/aws/test_aws_quantum_job.py index 3a36d8e75..7f9dc1a84 100644 --- a/test/unit_tests/braket/aws/test_aws_quantum_job.py +++ b/test/unit_tests/braket/aws/test_aws_quantum_job.py @@ -93,6 +93,7 @@ def _get_job_response(**kwargs): "jobArn": "arn:aws:braket:us-west-2:875981177017:job/job-test-20210628140446", "jobName": "job-test-20210628140446", "outputDataConfig": {"s3Path": "s3://amazon-braket-jobs/job-path/data"}, + "queueInfo": {"position": "1", "queue": "JOBS_QUEUE"}, "roleArn": "arn:aws:iam::875981177017:role/AmazonBraketJobRole", "status": "RUNNING", "stoppingCondition": {"maxRuntimeInSeconds": 1200}, @@ -554,8 +555,9 @@ def test_arn(quantum_job_arn, aws_session): assert quantum_job.arn == quantum_job_arn -def test_name(quantum_job_arn, quantum_job_name, aws_session): +def test_name(quantum_job_arn, quantum_job_name, aws_session, generate_get_job_response): quantum_job = AwsQuantumJob(quantum_job_arn, aws_session) + aws_session.get_job.return_value = generate_get_job_response(jobName=quantum_job_name) assert quantum_job.name == quantum_job_name @@ -719,6 +721,14 @@ def test_logs( generate_get_job_response(status="RUNNING"), generate_get_job_response(status="RUNNING"), generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), generate_get_job_response(status="COMPLETED"), ) quantum_job._aws_session.describe_log_streams.side_effect = log_stream_responses @@ -739,6 +749,48 @@ def test_logs( ) +def test_logs_queue_progress( + quantum_job, + generate_get_job_response, + log_events_responses, + log_stream_responses, + capsys, +): + queue_info = {"queue": "JOBS_QUEUE", "position": "1"} + quantum_job._aws_session.get_job.side_effect = ( + generate_get_job_response(status="QUEUED", queue_info=queue_info), + generate_get_job_response(status="QUEUED", queue_info=queue_info), + generate_get_job_response(status="QUEUED", queue_info=queue_info), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + ) + quantum_job._aws_session.describe_log_streams.side_effect = log_stream_responses + quantum_job._aws_session.get_log_events.side_effect = log_events_responses + + quantum_job.logs(wait=True, poll_interval_seconds=0) + + captured = capsys.readouterr() + assert captured.out == "\n".join( + ( + f"Job queue position: {queue_info['position']}", + "Running:", + "", + "hi there #1", + "hi there #2", + "hi there #2a", + "hi there #3", + "", + ) + ) + + @patch.dict("os.environ", {"JPY_PARENT_PID": "True"}) def test_logs_multiple_instances( quantum_job, @@ -752,6 +804,15 @@ def test_logs_multiple_instances( generate_get_job_response(status="RUNNING"), generate_get_job_response(status="RUNNING"), generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="RUNNING"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), + generate_get_job_response(status="COMPLETED"), generate_get_job_response(status="COMPLETED"), ) log_stream_responses[-1]["logStreams"].append({"logStreamName": "stream-2"}) @@ -817,6 +878,7 @@ def get_log_events(log_group, log_stream, start_time, start_from_head, next_toke def test_logs_error(quantum_job, generate_get_job_response, capsys): quantum_job._aws_session.get_job.side_effect = ( + generate_get_job_response(status="RUNNING"), generate_get_job_response(status="RUNNING"), generate_get_job_response(status="RUNNING"), generate_get_job_response(status="COMPLETED"), diff --git a/test/unit_tests/braket/aws/test_aws_quantum_task.py b/test/unit_tests/braket/aws/test_aws_quantum_task.py index 656c37dcf..e96af57f5 100644 --- a/test/unit_tests/braket/aws/test_aws_quantum_task.py +++ b/test/unit_tests/braket/aws/test_aws_quantum_task.py @@ -83,6 +83,11 @@ def quantum_task(aws_session): return AwsQuantumTask("foo:bar:arn", aws_session, poll_timeout_seconds=2) +@pytest.fixture +def quantum_task_quiet(aws_session): + return AwsQuantumTask("foo:bar:arn", aws_session, poll_timeout_seconds=2, quiet=True) + + @pytest.fixture def circuit_task(aws_session): return AwsQuantumTask("foo:bar:arn", aws_session, poll_timeout_seconds=2) @@ -243,6 +248,23 @@ def test_queue_position(quantum_task): ) +def test_queued_quiet(quantum_task_quiet): + state_1 = "QUEUED" + _mock_metadata(quantum_task_quiet._aws_session, state_1) + assert quantum_task_quiet.queue_position() == QuantumTaskQueueInfo( + queue_type=QueueType.NORMAL, queue_position="2", message=None + ) + + state_2 = "COMPLETED" + message = ( + f"'Task is in {state_2} status. AmazonBraket does not show queue position for this status.'" + ) + _mock_metadata(quantum_task_quiet._aws_session, state_2) + assert quantum_task_quiet.queue_position() == QuantumTaskQueueInfo( + queue_type=QueueType.NORMAL, queue_position=None, message=message + ) + + def test_state(quantum_task): state_1 = "RUNNING" _mock_metadata(quantum_task._aws_session, state_1) @@ -432,6 +454,43 @@ def set_result_from_callback(future): assert result_from_future == result +@pytest.mark.parametrize( + "status, result", + [ + ("COMPLETED", GateModelQuantumTaskResult.from_string(MockS3.MOCK_S3_RESULT_GATE_MODEL)), + ("FAILED", None), + ], +) +def test_async_result_queued(circuit_task, status, result): + def set_result_from_callback(future): + # Set the result_from_callback variable in the enclosing functions scope + nonlocal result_from_callback + result_from_callback = future.result() + + _mock_metadata(circuit_task._aws_session, "QUEUED") + _mock_s3(circuit_task._aws_session, MockS3.MOCK_S3_RESULT_GATE_MODEL) + + future = circuit_task.async_result() + + # test the different ways to get the result from async + + # via callback + result_from_callback = None + future.add_done_callback(set_result_from_callback) + + # via asyncio waiting for result + _mock_metadata(circuit_task._aws_session, status) + event_loop = asyncio.get_event_loop() + result_from_waiting = event_loop.run_until_complete(future) + + # via future.result(). Note that this would fail if the future is not complete. + result_from_future = future.result() + + assert result_from_callback == result + assert result_from_waiting == result + assert result_from_future == result + + def test_failed_task(quantum_task): _mock_metadata(quantum_task._aws_session, "FAILED") _mock_s3(quantum_task._aws_session, MockS3.MOCK_S3_RESULT_GATE_MODEL) diff --git a/test/unit_tests/braket/jobs/test_quantum_job_creation.py b/test/unit_tests/braket/jobs/test_quantum_job_creation.py index bef4fd643..8cd1fbca9 100644 --- a/test/unit_tests/braket/jobs/test_quantum_job_creation.py +++ b/test/unit_tests/braket/jobs/test_quantum_job_creation.py @@ -323,8 +323,9 @@ def _translate_creation_args(create_job_args): image_uri = create_job_args["image_uri"] job_name = create_job_args["job_name"] or _generate_default_job_name(image_uri) default_bucket = aws_session.default_bucket() + timestamp = str(int(time.time() * 1000)) code_location = create_job_args["code_location"] or AwsSession.construct_s3_uri( - default_bucket, "jobs", job_name, "script" + default_bucket, "jobs", job_name, timestamp, "script" ) role_arn = create_job_args["role_arn"] or aws_session.get_default_jobs_role() device = create_job_args["device"] @@ -340,11 +341,13 @@ def _translate_creation_args(create_job_args): } hyperparameters.update(distributed_hyperparams) output_data_config = create_job_args["output_data_config"] or OutputDataConfig( - s3Path=AwsSession.construct_s3_uri(default_bucket, "jobs", job_name, "data") + s3Path=AwsSession.construct_s3_uri(default_bucket, "jobs", job_name, timestamp, "data") ) stopping_condition = create_job_args["stopping_condition"] or StoppingCondition() checkpoint_config = create_job_args["checkpoint_config"] or CheckpointConfig( - s3Uri=AwsSession.construct_s3_uri(default_bucket, "jobs", job_name, "checkpoints") + s3Uri=AwsSession.construct_s3_uri( + default_bucket, "jobs", job_name, timestamp, "checkpoints" + ) ) entry_point = create_job_args["entry_point"] source_module = create_job_args["source_module"] @@ -365,7 +368,7 @@ def _translate_creation_args(create_job_args): "jobName": job_name, "roleArn": role_arn, "algorithmSpecification": algorithm_specification, - "inputDataConfig": _process_input_data(input_data, job_name, aws_session), + "inputDataConfig": _process_input_data(input_data, job_name, aws_session, timestamp), "instanceConfig": asdict(instance_config), "outputDataConfig": asdict(output_data_config, dict_factory=_exclude_nones_factory), "checkpointConfig": asdict(checkpoint_config), @@ -403,6 +406,7 @@ def test_generate_default_job_name(mock_time, image_uri): mock_time.return_value = datetime.datetime.now().timestamp() timestamp = str(int(time.time() * 1000)) assert _generate_default_job_name(image_uri) == f"braket-job{job_type}-{timestamp}" + assert _generate_default_job_name(image_uri, timestamp="ts") == f"braket-job{job_type}-ts" @pytest.mark.parametrize( @@ -602,7 +606,7 @@ def test_invalid_input_parameters(entry_point, aws_session): "channelName": "input", "dataSource": { "s3DataSource": { - "s3Uri": "s3://default-bucket-name/jobs/job-name/data/input/prefix", + "s3Uri": "s3://default-bucket-name/jobs/job-name/ts/data/input/prefix", }, }, } @@ -651,7 +655,7 @@ def test_invalid_input_parameters(entry_point, aws_session): "channelName": "local-input", "dataSource": { "s3DataSource": { - "s3Uri": "s3://default-bucket-name/jobs/job-name/" + "s3Uri": "s3://default-bucket-name/jobs/job-name/ts/" "data/local-input/prefix", }, }, @@ -678,4 +682,4 @@ def test_invalid_input_parameters(entry_point, aws_session): ) def test_process_input_data(aws_session, input_data, input_data_configs): job_name = "job-name" - assert _process_input_data(input_data, job_name, aws_session) == input_data_configs + assert _process_input_data(input_data, job_name, aws_session, "ts") == input_data_configs diff --git a/tox.ini b/tox.ini index 92ccad2a9..26a5a43db 100644 --- a/tox.ini +++ b/tox.ini @@ -66,7 +66,6 @@ basepython = python3 skip_install = true deps = flake8 - flake8-rst-docstrings git+https://github.com/amazon-braket/amazon-braket-build-tools.git commands = flake8 --extend-exclude src {posargs}