Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add queue position to the logs for tasks and jobs #821

Merged
merged 10 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/braket/aws/aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def create(
aws_session: AwsSession | None = None,
tags: dict[str, str] | None = None,
logger: Logger = getLogger(__name__),
quiet: bool = False,
reservation_arn: str | None = None,
) -> AwsQuantumJob:
"""Creates a hybrid job by invoking the Braket CreateJob API.
Expand Down Expand Up @@ -176,6 +177,9 @@ def create(
while waiting for quantum task to be in a terminal state. Default is
`getLogger(__name__)`

quiet (bool): Sets the verbosity of the logger to low and does not report queue
position. Default is `False`.

reservation_arn (str | None): the reservation window arn provided by Braket
Direct to reserve exclusive usage for the device to run the hybrid job on.
Default: None.
Expand Down Expand Up @@ -210,23 +214,26 @@ def create(
)

job_arn = aws_session.create_job(**create_job_kwargs)
job = AwsQuantumJob(job_arn, aws_session)
job = AwsQuantumJob(job_arn, aws_session, quiet)

if wait_until_complete:
print(f"Initializing Braket Job: {job_arn}")
job.logs(wait=True)

return job

def __init__(self, arn: str, aws_session: AwsSession | None = None):
def __init__(self, arn: str, aws_session: AwsSession | None = None, quiet: bool = False):
"""
Args:
arn (str): The ARN of the hybrid job.
aws_session (AwsSession | None): The `AwsSession` for connecting to AWS services.
Default is `None`, in which case an `AwsSession` object will be created with the
region of the hybrid job.
quiet (bool): Sets the verbosity of the logger to low and does not report queue
position. Default is `False`.
"""
self._arn: str = arn
self._quiet = quiet
if aws_session:
if not self._is_valid_aws_session_region_for_job_arn(aws_session, arn):
raise ValueError(
Expand Down Expand Up @@ -371,10 +378,11 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None:
instance_count = self.metadata(use_cached_value=True)["instanceConfig"]["instanceCount"]
has_streams = False
color_wrap = logs.ColorWrap()
previous_state = self.state()

while True:
time.sleep(poll_interval_seconds)

current_state = self.state()
has_streams = logs.flush_log_streams(
self._aws_session,
log_group,
Expand All @@ -384,14 +392,17 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None:
instance_count,
has_streams,
color_wrap,
[previous_state, current_state],
self.queue_position().queue_position if not self._quiet else None,
)
previous_state = current_state

if log_state == AwsQuantumJob.LogState.COMPLETE:
break

if log_state == AwsQuantumJob.LogState.JOB_COMPLETE:
log_state = AwsQuantumJob.LogState.COMPLETE
elif self.state() in AwsQuantumJob.TERMINAL_STATES:
elif current_state in AwsQuantumJob.TERMINAL_STATES:
log_state = AwsQuantumJob.LogState.JOB_COMPLETE

def metadata(self, use_cached_value: bool = False) -> dict[str, Any]:
Expand Down
14 changes: 14 additions & 0 deletions src/braket/aws/aws_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def create(
tags: dict[str, str] | None = None,
inputs: dict[str, float] | None = None,
gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None = None,
quiet: bool = False,
reservation_arn: str | None = None,
*args,
**kwargs,
Expand Down Expand Up @@ -152,6 +153,9 @@ def create(
a `PulseSequence`.
Default: None.

quiet (bool): Sets the verbosity of the logger to low and does not report queue
position. Default is `False`.

reservation_arn (str | None): The reservation ARN provided by Braket Direct
to reserve exclusive usage for the device to run the quantum task on.
Note: If you are creating tasks in a job that itself was created reservation ARN,
Expand Down Expand Up @@ -215,6 +219,7 @@ def create(
disable_qubit_rewiring,
inputs,
gate_definitions=gate_definitions,
quiet=quiet,
*args,
**kwargs,
)
Expand All @@ -226,6 +231,7 @@ def __init__(
poll_timeout_seconds: float = DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL,
logger: Logger = getLogger(__name__),
quiet: bool = False,
):
"""
Args:
Expand All @@ -238,6 +244,8 @@ def __init__(
logger (Logger): Logger object with which to write logs, such as quantum task statuses
while waiting for quantum task to be in a terminal state. Default is
`getLogger(__name__)`
quiet (bool): Sets the verbosity of the logger to low and does not report queue
position. Default is `False`.

Examples:
>>> task = AwsQuantumTask(arn='task_arn')
Expand All @@ -259,6 +267,7 @@ def __init__(
self._poll_interval_seconds = poll_interval_seconds

self._logger = logger
self._quiet = quiet

self._metadata: dict[str, Any] = {}
self._result: Union[
Expand Down Expand Up @@ -477,6 +486,11 @@ async def _wait_for_completion(
while (time.time() - start_time) < self._poll_timeout_seconds:
# Used cached metadata if cached status is terminal
task_status = self._update_status_if_nonterminal()
if not self._quiet and task_status == "QUEUED":
queue = self.queue_position()
self._logger.debug(
f"Task is in {queue.queue_type} queue position: {queue.queue_position}"
)
self._logger.debug(f"Task {self._arn}: task status {task_status}")
if task_status in AwsQuantumTask.RESULTS_READY_STATES:
return self._download_result()
Expand Down
7 changes: 6 additions & 1 deletion src/braket/jobs/hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def hybrid_job(
aws_session: AwsSession | None = None,
tags: dict[str, str] | None = None,
logger: Logger = getLogger(__name__),
quiet: bool | None = None,
reservation_arn: str | None = None,
) -> Callable:
"""Defines a hybrid job by decorating the entry point function. The job will be created
Expand All @@ -71,7 +72,7 @@ def hybrid_job(
The job created will be a `LocalQuantumJob` when `local` is set to `True`, otherwise an
`AwsQuantumJob`. The following parameters will be ignored when running a job with
`local` set to `True`: `wait_until_complete`, `instance_config`, `distribution`,
`copy_checkpoints_from_job`, `stopping_condition`, `tags`, and `logger`.
`copy_checkpoints_from_job`, `stopping_condition`, `tags`, `logger`, and `quiet`.

Args:
device (str | None): Device ARN of the QPU device that receives priority quantum
Expand Down Expand Up @@ -153,6 +154,9 @@ def hybrid_job(
logger (Logger): Logger object with which to write logs, such as task statuses
while waiting for task to be in a terminal state. Default: `getLogger(__name__)`

quiet (bool | None): Sets the verbosity of the logger to low and does not report queue
position. Default: `None`, which is treated as `False`.

reservation_arn (str | None): the reservation window arn provided by Braket
Direct to reserve exclusive usage for the device to run the hybrid job on.
Default: None.
Expand Down Expand Up @@ -210,6 +214,7 @@ def job_wrapper(*args, **kwargs) -> Callable:
"output_data_config": output_data_config,
"aws_session": aws_session,
"tags": tags,
"quiet": quiet,
"reservation_arn": reservation_arn,
}
for key, value in optional_args.items():
Expand Down
13 changes: 11 additions & 2 deletions src/braket/jobs/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# Support for reading logs
#
##############################################################################
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

from botocore.exceptions import ClientError

Expand Down Expand Up @@ -155,7 +155,7 @@ def log_stream(
yield ev


def flush_log_streams(
def flush_log_streams( # noqa C901
aws_session: AwsSession,
log_group: str,
stream_prefix: str,
Expand All @@ -164,6 +164,8 @@ def flush_log_streams(
stream_count: int,
has_streams: bool,
color_wrap: ColorWrap,
state: list[str],
queue_position: Optional[str] = None,
) -> bool:
"""Flushes log streams to stdout.

Expand All @@ -183,6 +185,9 @@ def flush_log_streams(
been found. This value is possibly updated and returned at the end of execution.
color_wrap (ColorWrap): An instance of ColorWrap to potentially color-wrap print statements
from different streams.
state (list[str]): The previous and current state of the job.
queue_position (Optional[str]): The current queue position. This is not passed in if the job
is run with `quiet=True`.

Returns:
bool: Returns 'True' if any streams have been flushed.
Expand Down Expand Up @@ -225,6 +230,10 @@ def flush_log_streams(
positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)
else:
positions[stream_names[idx]] = Position(timestamp=event["timestamp"], skip=1)
elif queue_position is not None and state[1] == "QUEUED":
print(f"Job queue position: {queue_position}", end="\n", flush=True)
elif state[0] != state[1] and state[1] == "RUNNING" and queue_position is not None:
print("Running:", end="\n", flush=True)
else:
print(".", end="", flush=True)
return has_streams
61 changes: 61 additions & 0 deletions test/unit_tests/braket/aws/test_aws_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def _get_job_response(**kwargs):
"jobArn": "arn:aws:braket:us-west-2:875981177017:job/job-test-20210628140446",
"jobName": "job-test-20210628140446",
"outputDataConfig": {"s3Path": "s3://amazon-braket-jobs/job-path/data"},
"queueInfo": {"position": "1", "queue": "JOBS_QUEUE"},
"roleArn": "arn:aws:iam::875981177017:role/AmazonBraketJobRole",
"status": "RUNNING",
"stoppingCondition": {"maxRuntimeInSeconds": 1200},
Expand Down Expand Up @@ -720,6 +721,14 @@ def test_logs(
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
)
quantum_job._aws_session.describe_log_streams.side_effect = log_stream_responses
Expand All @@ -740,6 +749,48 @@ def test_logs(
)


def test_logs_queue_progress(
    quantum_job,
    generate_get_job_response,
    log_events_responses,
    log_stream_responses,
    capsys,
):
    """Verify that ``AwsQuantumJob.logs(wait=True)`` reports the queue position
    while the job is QUEUED and prints a ``Running:`` marker on the
    QUEUED -> RUNNING transition, followed by the streamed log lines.

    All AWS interaction is mocked via the ``quantum_job`` fixture; ``capsys``
    captures what ``logs()`` writes to stdout.
    """
    # Queue metadata embedded in the mocked GetJob responses while QUEUED;
    # "position": "1" is what logs() is expected to echo back.
    queue_info = {"queue": "JOBS_QUEUE", "position": "1"}
    # Scripted job lifecycle for successive GetJob calls: QUEUED x3,
    # RUNNING x3, then COMPLETED. COMPLETED is repeated because logs()
    # keeps polling state while draining the remaining log events —
    # NOTE(review): the exact response count must match the number of
    # state() calls logs() makes per iteration; confirm against
    # AwsQuantumJob.logs() if the polling loop changes.
    quantum_job._aws_session.get_job.side_effect = (
        generate_get_job_response(status="QUEUED", queue_info=queue_info),
        generate_get_job_response(status="QUEUED", queue_info=queue_info),
        generate_get_job_response(status="QUEUED", queue_info=queue_info),
        generate_get_job_response(status="RUNNING"),
        generate_get_job_response(status="RUNNING"),
        generate_get_job_response(status="RUNNING"),
        generate_get_job_response(status="COMPLETED"),
        generate_get_job_response(status="COMPLETED"),
        generate_get_job_response(status="COMPLETED"),
        generate_get_job_response(status="COMPLETED"),
        generate_get_job_response(status="COMPLETED"),
        generate_get_job_response(status="COMPLETED"),
    )
    # Canned CloudWatch Logs responses supplied by fixtures.
    quantum_job._aws_session.describe_log_streams.side_effect = log_stream_responses
    quantum_job._aws_session.get_log_events.side_effect = log_events_responses

    # poll_interval_seconds=0 keeps the test fast (no real sleeping between polls).
    quantum_job.logs(wait=True, poll_interval_seconds=0)

    captured = capsys.readouterr()
    # Expected stdout: one queue-position line, the "Running:" transition
    # marker, a separator blank line, then the four fixture log events.
    assert captured.out == "\n".join(
        (
            f"Job queue position: {queue_info['position']}",
            "Running:",
            "",
            "hi there #1",
            "hi there #2",
            "hi there #2a",
            "hi there #3",
            "",
        )
    )


@patch.dict("os.environ", {"JPY_PARENT_PID": "True"})
def test_logs_multiple_instances(
quantum_job,
Expand All @@ -753,6 +804,15 @@ def test_logs_multiple_instances(
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
generate_get_job_response(status="COMPLETED"),
)
log_stream_responses[-1]["logStreams"].append({"logStreamName": "stream-2"})
Expand Down Expand Up @@ -818,6 +878,7 @@ def get_log_events(log_group, log_stream, start_time, start_from_head, next_toke

def test_logs_error(quantum_job, generate_get_job_response, capsys):
quantum_job._aws_session.get_job.side_effect = (
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="RUNNING"),
generate_get_job_response(status="COMPLETED"),
Expand Down
Loading