Skip to content

Commit

Permalink
adding keyword only args
Browse files Browse the repository at this point in the history
  • Loading branch information
krneta committed Nov 2, 2023
1 parent 35c0328 commit c7dbc0f
Showing 1 changed file with 30 additions and 30 deletions.
60 changes: 30 additions & 30 deletions src/braket/jobs/hybrid_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,22 @@
def hybrid_job(
*,
device: str,
include_modules: str | ModuleType | Iterable[str | ModuleType] = None,
dependencies: str | Path | list[str] = None,
include_modules: str | ModuleType | Iterable[str | ModuleType] | None = None,
dependencies: str | Path | list[str] | None = None,
local: bool = False,
job_name: str = None,
image_uri: str = None,
input_data: str | dict | S3DataSourceConfig = None,
job_name: str | None = None,
image_uri: str | None = None,
input_data: str | dict | S3DataSourceConfig | None = None,
wait_until_complete: bool = False,
instance_config: InstanceConfig = None,
distribution: str = None,
copy_checkpoints_from_job: str = None,
checkpoint_config: CheckpointConfig = None,
role_arn: str = None,
stopping_condition: StoppingCondition = None,
output_data_config: OutputDataConfig = None,
aws_session: AwsSession = None,
tags: dict[str, str] = None,
instance_config: InstanceConfig | None = None,
distribution: str | None = None,
copy_checkpoints_from_job: str | None = None,
checkpoint_config: CheckpointConfig | None = None,
role_arn: str | None = None,
stopping_condition: StoppingCondition | None = None,
output_data_config: OutputDataConfig | None = None,
aws_session: AwsSession | None = None,
tags: dict[str, str] | None = None,
logger: Logger = getLogger(__name__),
) -> Callable:
"""Defines a hybrid job by decorating the entry point function. The job will be created
Expand All @@ -80,27 +80,27 @@ def hybrid_job(
When using embedded simulators, you may provide the device argument as string of the
form: "local:<provider>/<simulator_name>" or `None`.
include_modules (str | ModuleType | Iterable[str | ModuleType]): Either a
include_modules (str | ModuleType | Iterable[str | ModuleType] | None): Either a
single module or module name or a list of module or module names referring to local
modules to be included. Any references to members of these modules in the hybrid job
algorithm code will be serialized as part of the algorithm code. Default: `[]`
dependencies (str | Path | list[str]): Path (absolute or relative) to a requirements.txt
file, or alternatively a list of strings, with each string being a `requirement
specifier <https://pip.pypa.io/en/stable/reference/requirement-specifiers/
dependencies (str | Path | list[str] | None): Path (absolute or relative) to a
requirements.txt file, or alternatively a list of strings, with each string being a
`requirement specifier <https://pip.pypa.io/en/stable/reference/requirement-specifiers/
#requirement-specifiers>`_, to be used for the hybrid job.
local (bool): Whether to use local mode for the hybrid job. Default: `False`
job_name (str): A string that specifies the name with which the job is created.
job_name (str | None): A string that specifies the name with which the job is created.
Allowed pattern for job name: `^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,50}$`. Defaults to
f'{decorated-function-name}-{timestamp}'.
image_uri (str): A str that specifies the ECR image to use for executing the job.
image_uri (str | None): A str that specifies the ECR image to use for executing the job.
`retrieve_image()` function may be used for retrieving the ECR image URIs
for the containers supported by Braket. Default: `<Braket base image_uri>`.
input_data (str | dict | S3DataSourceConfig): Information about the training
input_data (str | dict | S3DataSourceConfig | None): Information about the training
data. Dictionary maps channel names to local paths or S3 URIs. Contents found
at any local paths will be uploaded to S3 at
f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}'. If a local
Expand All @@ -112,41 +112,41 @@ def hybrid_job(
This would tail the job logs as it waits. Otherwise `False`. Ignored if using
local mode. Default: `False`.
instance_config (InstanceConfig): Configuration of the instance(s) for running the
instance_config (InstanceConfig | None): Configuration of the instance(s) for running the
classical code for the hybrid job. Default:
`InstanceConfig(instanceType='ml.m5.large', instanceCount=1, volumeSizeInGB=30)`.
distribution (str): A str that specifies how the job should be distributed.
distribution (str | None): A str that specifies how the job should be distributed.
If set to "data_parallel", the hyperparameters for the job will be set to use data
parallelism features for PyTorch or TensorFlow. Default: `None`.
copy_checkpoints_from_job (str): A str that specifies the job ARN whose
copy_checkpoints_from_job (str | None): A str that specifies the job ARN whose
checkpoint you want to use in the current job. Specifying this value will copy
over the checkpoint data from `use_checkpoints_from_job`'s checkpoint_config
s3Uri to the current job's checkpoint_config s3Uri, making it available at
checkpoint_config.localPath during the job execution. Default: `None`
checkpoint_config (CheckpointConfig): Configuration that specifies the
checkpoint_config (CheckpointConfig | None): Configuration that specifies the
location where checkpoint data is stored.
Default: `CheckpointConfig(localPath='/opt/jobs/checkpoints',
s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints')`.
role_arn (str): A str providing the IAM role ARN used to execute the
role_arn (str | None): A str providing the IAM role ARN used to execute the
script. Default: IAM role returned by AwsSession's `get_default_jobs_role()`.
stopping_condition (StoppingCondition): The maximum length of time, in seconds,
stopping_condition (StoppingCondition | None): The maximum length of time, in seconds,
and the maximum number of tasks that a job can run before being forcefully stopped.
Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60).
output_data_config (OutputDataConfig): Specifies the location for the output of
output_data_config (OutputDataConfig | None): Specifies the location for the output of
the job.
Default: `OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data',
kmsKeyId=None)`.
aws_session (AwsSession): AwsSession for connecting to AWS Services.
aws_session (AwsSession | None): AwsSession for connecting to AWS Services.
Default: AwsSession()
tags (dict[str, str]): Dict specifying the key-value pairs for tagging this job.
tags (dict[str, str] | None): Dict specifying the key-value pairs for tagging this job.
Default: {}.
logger (Logger): Logger object with which to write logs, such as task statuses
Expand Down

0 comments on commit c7dbc0f

Please sign in to comment.