feat: update to 3.9 syntax #735

Merged: 2 commits, Oct 11, 2023
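In short: the diff below swaps `typing.Dict`, `List`, `Optional`, and `Union` for Python 3.9's builtin generics (PEP 585) and the `X | Y` union spelling (PEP 604), adding `from __future__ import annotations` where needed (see the config.py and data_persistence.py hunks). Under the future import, annotations are stored as strings rather than evaluated, which is what makes the `|` spelling usable before Python 3.10. A minimal sketch of the pattern, with hypothetical names rather than code from this diff:

```python
# Sketch of the migration pattern (hypothetical function, not SDK code).
# On Python 3.9, dict[str, ...] subscripts work natively (PEP 585), but
# evaluating `str | None` would raise TypeError before 3.10; the future
# import keeps annotations as unevaluated strings, so both spellings work.
from __future__ import annotations

from typing import Any


def create_job(
    hyperparameters: dict[str, Any] | None = None,  # previously: Dict[str, Any]
    input_data: str | dict | None = None,           # previously: Union[str, Dict]
    tags: dict[str, str] | None = None,             # previously: Dict[str, str]
) -> list[str]:                                     # previously: List[str]
    """Toy stand-in showing the new annotation style."""
    return sorted((tags or {}).keys())


print(create_job(tags={"project": "demo"}))  # ['project']
```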
34 changes: 16 additions & 18 deletions src/braket/aws/aws_quantum_job.py
@@ -20,7 +20,7 @@
from enum import Enum
from logging import Logger, getLogger
from pathlib import Path
from typing import Any, Dict, List, Union
from typing import Any

import boto3
from botocore.exceptions import ClientError
@@ -70,16 +70,16 @@ def create(
code_location: str = None,
role_arn: str = None,
wait_until_complete: bool = False,
hyperparameters: Dict[str, Any] = None,
input_data: Union[str, Dict, S3DataSourceConfig] = None,
hyperparameters: dict[str, Any] = None,
input_data: str | dict | S3DataSourceConfig = None,
instance_config: InstanceConfig = None,
distribution: str = None,
stopping_condition: StoppingCondition = None,
output_data_config: OutputDataConfig = None,
copy_checkpoints_from_job: str = None,
checkpoint_config: CheckpointConfig = None,
aws_session: AwsSession = None,
tags: Dict[str, str] = None,
tags: dict[str, str] = None,
logger: Logger = getLogger(__name__),
) -> AwsQuantumJob:
"""Creates a hybrid job by invoking the Braket CreateJob API.
@@ -121,12 +121,12 @@ def create(
This would tail the hybrid job logs as it waits. Otherwise `False`.
Default: `False`.

hyperparameters (Dict[str, Any]): Hyperparameters accessible to the hybrid job.
The hyperparameters are made accessible as a Dict[str, str] to the hybrid job.
hyperparameters (dict[str, Any]): Hyperparameters accessible to the hybrid job.
The hyperparameters are made accessible as a dict[str, str] to the hybrid job.
For convenience, this accepts other types for keys and values, but `str()`
is called to convert them before being passed on. Default: None.

input_data (Union[str, Dict, S3DataSourceConfig]): Information about the training
input_data (str | dict | S3DataSourceConfig): Information about the training
data. Dictionary maps channel names to local paths or S3 URIs. Contents found
at any local paths will be uploaded to S3 at
f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}'. If a local
@@ -166,7 +166,7 @@ def create(
aws_session (AwsSession): AwsSession for connecting to AWS Services.
Default: AwsSession()

tags (Dict[str, str]): Dict specifying the key-value pairs for tagging this hybrid job.
tags (dict[str, str]): Dict specifying the key-value pairs for tagging this hybrid job.
Default: {}.

logger (Logger): Logger object with which to write logs, such as quantum task statuses
@@ -386,7 +386,7 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None:
elif self.state() in AwsQuantumJob.TERMINAL_STATES:
log_state = AwsQuantumJob.LogState.JOB_COMPLETE

def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:
def metadata(self, use_cached_value: bool = False) -> dict[str, Any]:
"""Gets the hybrid job metadata defined in Amazon Braket.

Args:
@@ -395,7 +395,7 @@ def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:
`GetJob` is called to retrieve the metadata. If `False`, always calls
`GetJob`, which also updates the cached value. Default: `False`.
Returns:
Dict[str, Any]: Dict that specifies the hybrid job metadata defined in Amazon Braket.
dict[str, Any]: Dict that specifies the hybrid job metadata defined in Amazon Braket.
"""
if not use_cached_value or not self._metadata:
self._metadata = self._aws_session.get_job(self._arn)
@@ -405,7 +405,7 @@ def metrics(
self,
metric_type: MetricType = MetricType.TIMESTAMP,
statistic: MetricStatistic = MetricStatistic.MAX,
) -> Dict[str, List[Any]]:
) -> dict[str, list[Any]]:
"""Gets all the metrics data, where the keys are the column names, and the values are a list
containing the values in each row. For example, the table:
timestamp energy
@@ -422,7 +422,7 @@ def metrics(
when there is a conflict. Default: MetricStatistic.MAX.

Returns:
Dict[str, List[Any]] : The metrics data.
dict[str, list[Any]] : The metrics data.
"""
fetcher = CwlInsightsMetricsFetcher(self._aws_session)
metadata = self.metadata(True)
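Since the docstring's example table is cut off in this view: `metrics()` returns the data column-wise, one key per column name. A toy illustration of that shape (made-up values):

```python
# Each key is a column name; row i is the i-th entry of every list.
metrics: dict[str, list[float]] = {
    "timestamp": [0.0, 1.0, 2.0],
    "energy": [-0.9, -1.1, -1.2],
}
for row in zip(metrics["timestamp"], metrics["energy"]):
    print(row)  # (0.0, -0.9), then (1.0, -1.1), then (2.0, -1.2)
```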
@@ -451,7 +451,7 @@ def result(
self,
poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Retrieves the hybrid job result persisted using save_job_result() function.

Args:
@@ -461,7 +461,7 @@ def result(
Default: 5 seconds.

Returns:
Dict[str, Any]: Dict specifying the job results.
dict[str, Any]: Dict specifying the job results.

Raises:
RuntimeError: if hybrid job is in a FAILED or CANCELLED state.
@@ -481,7 +481,7 @@ def result(
return AwsQuantumJob._read_and_deserialize_results(temp_dir, job_name)

@staticmethod
def _read_and_deserialize_results(temp_dir: str, job_name: str) -> Dict[str, Any]:
def _read_and_deserialize_results(temp_dir: str, job_name: str) -> dict[str, Any]:
return load_job_result(Path(temp_dir, job_name, AwsQuantumJob.RESULTS_FILENAME))

def download_result(
@@ -566,9 +566,7 @@ def __hash__(self) -> int:
return hash(self.arn)

@staticmethod
def _initialize_session(
session_value: AwsSession, device: AwsDevice, logger: Logger
) -> AwsSession:
def _initialize_session(session_value: AwsSession, device: str, logger: Logger) -> AwsSession:
aws_session = session_value or AwsSession()
if device.startswith("local:"):
return aws_session
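One change in this file goes beyond the typing swap: `_initialize_session` now annotates `device` as `str` rather than `AwsDevice`, which matches the body shown above; `device.startswith("local:")` is a `str` method that an `AwsDevice` instance does not have. A standalone stub of the corrected contract (sketch only, not the SDK implementation):

```python
from logging import Logger, getLogger


def initialize_session_stub(device: str, logger: Logger = getLogger(__name__)) -> str:
    """Mirrors the fix: `device` is a string (a device ARN or a "local:..." name)."""
    if device.startswith("local:"):  # valid precisely because device is a str
        return "local-session"
    logger.debug("Looking up region for device %s", device)
    return "regional-session"


print(initialize_session_stub("local:braket/simulator"))  # local-session
```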
13 changes: 7 additions & 6 deletions src/braket/jobs/config.py
@@ -11,16 +11,17 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class CheckpointConfig:
"""Configuration that specifies the location where checkpoint data is stored."""

localPath: str = "/opt/jobs/checkpoints"
s3Uri: Optional[str] = None
s3Uri: str | None = None


@dataclass
@@ -36,8 +37,8 @@ class InstanceConfig:
class OutputDataConfig:
"""Configuration that specifies the location for the output of the hybrid job."""

s3Path: Optional[str] = None
kmsKeyId: Optional[str] = None
s3Path: str | None = None
kmsKeyId: str | None = None


@dataclass
@@ -61,8 +62,8 @@ class S3DataSourceConfig:

def __init__(
self,
s3_data,
content_type=None,
s3_data: str,
content_type: str | None = None,
):
"""Create a definition for input data used by a Braket Hybrid job.

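A note on this file: the dataclass fields keep their `None` defaults, so `s3Uri: str | None` only changes the annotation's spelling, and the new `from __future__ import annotations` at the top of the module means the `str | None` expression is never evaluated at class-definition time. That is what makes the spelling safe on Python 3.9. A small self-contained illustration (simplified stand-in, not the braket class):

```python
from __future__ import annotations  # annotations stay as strings on 3.9

from dataclasses import dataclass


@dataclass
class CheckpointConfigSketch:
    """Simplified stand-in for braket.jobs.config.CheckpointConfig."""

    localPath: str = "/opt/jobs/checkpoints"
    s3Uri: str | None = None  # if evaluated eagerly, `str | None` would raise TypeError on 3.9


print(CheckpointConfigSketch(s3Uri="s3://my-bucket/checkpoints"))  # hypothetical URI
```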
31 changes: 17 additions & 14 deletions src/braket/jobs/data_persistence.py
@@ -10,16 +10,19 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any

from braket.jobs.environment_variables import get_checkpoint_dir, get_job_name, get_results_dir
from braket.jobs.serialization import deserialize_values, serialize_values
from braket.jobs_data import PersistedJobData, PersistedJobDataFormat


def save_job_checkpoint(
checkpoint_data: Dict[str, Any],
checkpoint_data: dict[str, Any],
checkpoint_file_suffix: str = "",
data_format: PersistedJobDataFormat = PersistedJobDataFormat.PLAINTEXT,
) -> None:
@@ -35,7 +38,7 @@ def save_job_checkpoint(


Args:
checkpoint_data (Dict[str, Any]): Dict that specifies the checkpoint data to be persisted.
checkpoint_data (dict[str, Any]): Dict that specifies the checkpoint data to be persisted.
checkpoint_file_suffix (str): str that specifies the file suffix to be used for
the checkpoint filename. The resulting filename
`f"{job_name}(_{checkpoint_file_suffix}).json"` is used to save the checkpoints.
@@ -63,8 +66,8 @@ def save_job_checkpoint(


def load_job_checkpoint(
job_name: Optional[str] = None, checkpoint_file_suffix: str = ""
) -> Dict[str, Any]:
job_name: str | None = None, checkpoint_file_suffix: str = ""
) -> dict[str, Any]:
"""
Loads the job checkpoint data stored for the job named 'job_name', with the checkpoint
file that ends with the `checkpoint_file_suffix`. The `job_name` can refer to any job whose
@@ -77,7 +80,7 @@


Args:
job_name (Optional[str]): str that specifies the name of the job whose checkpoints
job_name (str | None): str that specifies the name of the job whose checkpoints
are to be loaded. Default: current job name.

checkpoint_file_suffix (str): str specifying the file suffix that is used to
@@ -86,7 +89,7 @@
checkpoint file. Default: ""

Returns:
Dict[str, Any]: Dict that contains the checkpoint data persisted in the checkpoint file.
dict[str, Any]: Dict that contains the checkpoint data persisted in the checkpoint file.

Raises:
FileNotFoundError: If the file `f"{job_name}(_{checkpoint_file_suffix})"` could not be found
@@ -109,7 +112,7 @@
return deserialized_data
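Taken together, `save_job_checkpoint` and `load_job_checkpoint` form a save/load round trip keyed by job name and file suffix. A usage sketch (this assumes a Braket job environment, where the job name and checkpoint directory come from environment variables; the values here are made up):

```python
from braket.jobs.data_persistence import load_job_checkpoint, save_job_checkpoint

# Persist interim state, then restore it later, e.g. in a job started
# with copy_checkpoints_from_job pointing at the earlier job's checkpoints.
save_job_checkpoint({"iteration": 7, "theta": [0.1, 0.2]}, checkpoint_file_suffix="warm")
state = load_job_checkpoint(checkpoint_file_suffix="warm")
print(state)  # {'iteration': 7, 'theta': [0.1, 0.2]}
```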


def _load_persisted_data(filename: Union[str, Path] = None) -> PersistedJobData:
def _load_persisted_data(filename: str | Path = None) -> PersistedJobData:
filename = filename or Path(get_results_dir()) / "results.json"
try:
with open(filename, mode="r") as f:
@@ -121,25 +124,25 @@ def _load_persisted_data(filename: str | Path = None) -> PersistedJobData:
)


def load_job_result(filename: Union[str, Path] = None) -> Dict[str, Any]:
def load_job_result(filename: str | Path = None) -> dict[str, Any]:
"""
Loads the job result of the currently running job.

Args:
filename (Union[str, Path]): Location of job results. Default `results.json` in job
filename (str | Path): Location of job results. Default `results.json` in job
results directory in a job instance or in working directory locally. This file
must be in the format used by `save_job_result`.

Returns:
Dict[str, Any]: Job result data of current job
dict[str, Any]: Job result data of current job
"""
persisted_data = _load_persisted_data(filename)
deserialized_data = deserialize_values(persisted_data.dataDictionary, persisted_data.dataFormat)
return deserialized_data


def save_job_result(
result_data: Union[Dict[str, Any], Any],
result_data: dict[str, Any] | Any,
data_format: PersistedJobDataFormat = None,
) -> None:
"""
Expand All @@ -152,7 +155,7 @@ def save_job_result(


Args:
result_data (Union[Dict[str, Any], Any]): Dict that specifies the result data to be
result_data (dict[str, Any] | Any): Dict that specifies the result data to be
persisted. If result data is not a dict, then it will be wrapped as
`{"result": result_data}`.
data_format (PersistedJobDataFormat): The data format used to serialize the
Expand Down Expand Up @@ -183,7 +186,7 @@ def save_job_result(
current_persisted_data.dataDictionary,
current_persisted_data.dataFormat,
)
updated_results = {**current_results, **result_data}
updated_results = current_results | result_data

with open(Path(get_results_dir()) / "results.json", "w") as f:
serialized_data = serialize_values(updated_results or {}, data_format)
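The `save_job_result` hunk above also adopts PEP 584's dict union operator: `current_results | result_data` replaces `{**current_results, **result_data}`. Both build a new dict, and on duplicate keys the right-hand operand wins, so freshly saved results override previously persisted ones. A quick demonstration with toy values:

```python
current_results = {"theta": 0.1, "energy": -1.0}  # previously persisted
result_data = {"energy": -1.2, "iterations": 40}  # newly saved

# PEP 584 (Python 3.9): same result as {**current_results, **result_data};
# the right-hand dict wins on key collisions.
updated_results = current_results | result_data
print(updated_results)  # {'theta': 0.1, 'energy': -1.2, 'iterations': 40}
```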
18 changes: 15 additions & 3 deletions src/braket/jobs/environment_variables.py
@@ -1,6 +1,18 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import json
import os
from typing import Dict


def get_job_name() -> str:
@@ -60,12 +72,12 @@ def get_checkpoint_dir() -> str:
return os.getenv("AMZN_BRAKET_CHECKPOINT_DIR", ".")


def get_hyperparameters() -> Dict[str, str]:
def get_hyperparameters() -> dict[str, str]:
"""
Get the job hyperparameters as a dict, with the values stringified.

Returns:
Dict[str, str]: The hyperparameters of the job.
dict[str, str]: The hyperparameters of the job.
"""
if "AMZN_BRAKET_HP_FILE" in os.environ:
with open(os.getenv("AMZN_BRAKET_HP_FILE"), "r") as f:
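For context on the hunk above: `get_hyperparameters` returns the job's hyperparameters as a `dict[str, str]`, read from the JSON file named by the `AMZN_BRAKET_HP_FILE` environment variable. A self-contained sketch of that contract with a simulated environment (the real file is provisioned by the Braket job container, and the empty-dict fallback is an assumption; the hunk is truncated before the function's final branch):

```python
import json
import os
import tempfile


def get_hyperparameters_sketch() -> dict[str, str]:
    """Replica of the logic shown above; the empty-dict fallback is assumed."""
    if "AMZN_BRAKET_HP_FILE" in os.environ:
        with open(os.getenv("AMZN_BRAKET_HP_FILE"), "r") as f:
            return json.load(f)
    return {}


# Simulate the hyperparameter file a job container would provide (toy values).
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"shots": "1000", "learning_rate": "0.01"}, f)
os.environ["AMZN_BRAKET_HP_FILE"] = f.name

print(get_hyperparameters_sketch())  # {'shots': '1000', 'learning_rate': '0.01'}
```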