diff --git a/src/braket/jobs/hybrid_job.py b/src/braket/jobs/hybrid_job.py index da0b8436a..bb92bdb2e 100644 --- a/src/braket/jobs/hybrid_job.py +++ b/src/braket/jobs/hybrid_job.py @@ -46,22 +46,22 @@ def hybrid_job( *, device: str, - include_modules: str | ModuleType | Iterable[str | ModuleType] = None, - dependencies: str | Path | list[str] = None, + include_modules: str | ModuleType | Iterable[str | ModuleType] | None = None, + dependencies: str | Path | list[str] | None = None, local: bool = False, - job_name: str = None, - image_uri: str = None, - input_data: str | dict | S3DataSourceConfig = None, + job_name: str | None = None, + image_uri: str | None = None, + input_data: str | dict | S3DataSourceConfig | None = None, wait_until_complete: bool = False, - instance_config: InstanceConfig = None, - distribution: str = None, - copy_checkpoints_from_job: str = None, - checkpoint_config: CheckpointConfig = None, - role_arn: str = None, - stopping_condition: StoppingCondition = None, - output_data_config: OutputDataConfig = None, - aws_session: AwsSession = None, - tags: dict[str, str] = None, + instance_config: InstanceConfig | None = None, + distribution: str | None = None, + copy_checkpoints_from_job: str | None = None, + checkpoint_config: CheckpointConfig | None = None, + role_arn: str | None = None, + stopping_condition: StoppingCondition | None = None, + output_data_config: OutputDataConfig | None = None, + aws_session: AwsSession | None = None, + tags: dict[str, str] | None = None, logger: Logger = getLogger(__name__), ) -> Callable: """Defines a hybrid job by decorating the entry point function. The job will be created @@ -80,27 +80,27 @@ def hybrid_job( When using embedded simulators, you may provide the device argument as string of the form: "local:/" or `None`. - include_modules (str | ModuleType | Iterable[str | ModuleType]): Either a + include_modules (str | ModuleType | Iterable[str | ModuleType] | None): Either a single module or module name or a list of module or module names referring to local modules to be included. Any references to members of these modules in the hybrid job algorithm code will be serialized as part of the algorithm code. Default: `[]` - dependencies (str | Path | list[str]): Path (absolute or relative) to a requirements.txt - file, or alternatively a list of strings, with each string being a `requirement - specifier `_, to be used for the hybrid job. local (bool): Whether to use local mode for the hybrid job. Default: `False` - job_name (str): A string that specifies the name with which the job is created. + job_name (str | None): A string that specifies the name with which the job is created. Allowed pattern for job name: `^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,50}$`. Defaults to f'{decorated-function-name}-{timestamp}'. - image_uri (str): A str that specifies the ECR image to use for executing the job. + image_uri (str | None): A str that specifies the ECR image to use for executing the job. `retrieve_image()` function may be used for retrieving the ECR image URIs for the containers supported by Braket. Default: ``. - input_data (str | dict | S3DataSourceConfig): Information about the training + input_data (str | dict | S3DataSourceConfig | None): Information about the training data. Dictionary maps channel names to local paths or S3 URIs. Contents found at any local paths will be uploaded to S3 at f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}'. If a local @@ -112,41 +112,41 @@ def hybrid_job( This would tail the job logs as it waits. Otherwise `False`. Ignored if using local mode. Default: `False`. - instance_config (InstanceConfig): Configuration of the instance(s) for running the + instance_config (InstanceConfig | None): Configuration of the instance(s) for running the classical code for the hybrid job. Default: `InstanceConfig(instanceType='ml.m5.large', instanceCount=1, volumeSizeInGB=30)`. - distribution (str): A str that specifies how the job should be distributed. + distribution (str | None): A str that specifies how the job should be distributed. If set to "data_parallel", the hyperparameters for the job will be set to use data parallelism features for PyTorch or TensorFlow. Default: `None`. - copy_checkpoints_from_job (str): A str that specifies the job ARN whose + copy_checkpoints_from_job (str | None): A str that specifies the job ARN whose checkpoint you want to use in the current job. Specifying this value will copy over the checkpoint data from `use_checkpoints_from_job`'s checkpoint_config s3Uri to the current job's checkpoint_config s3Uri, making it available at checkpoint_config.localPath during the job execution. Default: `None` - checkpoint_config (CheckpointConfig): Configuration that specifies the + checkpoint_config (CheckpointConfig | None): Configuration that specifies the location where checkpoint data is stored. Default: `CheckpointConfig(localPath='/opt/jobs/checkpoints', s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints')`. - role_arn (str): A str providing the IAM role ARN used to execute the + role_arn (str | None): A str providing the IAM role ARN used to execute the script. Default: IAM role returned by AwsSession's `get_default_jobs_role()`. - stopping_condition (StoppingCondition): The maximum length of time, in seconds, + stopping_condition (StoppingCondition | None): The maximum length of time, in seconds, and the maximum number of tasks that a job can run before being forcefully stopped. Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60). - output_data_config (OutputDataConfig): Specifies the location for the output of + output_data_config (OutputDataConfig | None): Specifies the location for the output of the job. Default: `OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data', kmsKeyId=None)`. - aws_session (AwsSession): AwsSession for connecting to AWS Services. + aws_session (AwsSession | None): AwsSession for connecting to AWS Services. Default: AwsSession() - tags (dict[str, str]): Dict specifying the key-value pairs for tagging this job. + tags (dict[str, str] | None): Dict specifying the key-value pairs for tagging this job. Default: {}. logger (Logger): Logger object with which to write logs, such as task statuses