diff --git a/.coveragerc b/.coveragerc index 79545ce2b..fe9662c0a 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,7 +3,7 @@ parallel = True branch = True source = src -omit = +omit = **/braket/ir/* **/braket/device_schema/* **/braket/schema_common/* @@ -23,9 +23,15 @@ exclude_lines = # Have to re-enable the standard pragma pragma: no cover + # Skipping import testing + from importlib.metadata import entry_points + # Don't complain if tests don't hit defensive assertion code: raise NotImplementedError + # Avoid situation where system version causes coverage issues + if sys.version_info.minor == 9: + [html] directory = build/coverage diff --git a/.github/dependabot.yml b/.github/dependabot.yml index ed79a0d63..04595aed1 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -10,4 +10,3 @@ updates: interval: "weekly" commit-message: prefix: infra - diff --git a/.github/workflows/check-format.yml b/.github/workflows/check-format.yml index 1795135e7..a6106b2d7 100644 --- a/.github/workflows/check-format.yml +++ b/.github/workflows/check-format.yml @@ -23,8 +23,7 @@ jobs: python-version: '3.9' - name: Install dependencies run: | - pip install --upgrade pip - pip install -e .[test] + pip install tox - name: Run code format checks run: | - tox -e linters_check + tox -e linters_check -p auto diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml index b12cde750..a6a359e93 100644 --- a/.github/workflows/publish-to-pypi.yml +++ b/.github/workflows/publish-to-pypi.yml @@ -26,6 +26,6 @@ jobs: - name: Build a binary wheel and a source tarball run: python setup.py sdist bdist_wheel - name: Publish distribution to PyPI - uses: pypa/gh-action-pypi-publish@2f6f737ca5f74c637829c0f5c3acd0e29ea5e8bf # release/v1 + uses: pypa/gh-action-pypi-publish@81e9d935c883d0b210363ab89cf05f3894778450 # release/v1 with: password: ${{ secrets.pypi_token }} diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 74b49db9e..f4ac4e916 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -36,5 +36,7 @@ jobs: run: | tox -e unit-tests - name: Upload coverage report to Codecov - uses: codecov/codecov-action@4fe8c5f003fae66aa5ebb77cfd3e7bfbbda0b6b0 # v3.1.5 + uses: codecov/codecov-action@54bcd8715eee62d40e33596ef5e8f0f48dbbccab # v4.1.0 + with: + token: ${{ secrets.CODECOV_TOKEN }} if: ${{ strategy.job-index }} == 0 diff --git a/.gitignore b/.gitignore index d91f4d305..bca0430b0 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ __pycache__/ /build /venv /dist +/model.tar.gz diff --git a/.readthedocs.yml b/.readthedocs.yml index e824a6afc..b6ca23199 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -12,7 +12,7 @@ sphinx: # Optionally build your docs in additional formats such as PDF formats: - pdf - + # setting up build.os and the python version build: os: ubuntu-22.04 diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ac2a7325..343965f60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,245 @@ # Changelog +## v1.80.0 (2024-05-22) + +### Features + + * add support for the ARN region + * Add support for SerializableProgram abstraction to Device interface + +### Bug Fixes and Other Changes + + * job fixture for endpoint support + +## v1.79.1 (2024-05-08) + +### Bug Fixes and Other Changes + + * check the qubit set length against observables + +## v1.79.0 (2024-05-06) + +### Features + + * Direct Reservation context manager + +### Documentation Changes + + * correct the example in the measure docstring + +## v1.78.0 (2024-04-18) + +### Features + + * add phase RX gate + +## v1.77.6 (2024-04-17) + +### Bug Fixes and Other Changes + + * if rydberg local is not pulled, pass in None + +## v1.77.5 (2024-04-16) + +### Bug Fixes and Other Changes + + * remove optional discretization fields + +## v1.77.4 (2024-04-16) + +### Bug Fixes and Other Changes + + * discretize method now takes None as an arg + +### Documentation Changes + + * Correct miscellaneous spelling mistakes in docstrings + +## v1.77.3.post0 (2024-04-15) + +### Documentation Changes + + * correct gphase matrix representation + +## v1.77.3 (2024-04-11) + +### Bug Fixes and Other Changes + + * measure target qubits are required + +## v1.77.2 (2024-04-10) + +### Bug Fixes and Other Changes + + * remove shifting field from testing + +## v1.77.1 (2024-04-10) + +### Bug Fixes and Other Changes + + * add measure qubit targets in braket_program_context + +## v1.77.0 (2024-04-10) + +### Features + + * rename shifting field to local detuning + +## v1.76.3 (2024-04-09) + +### Bug Fixes and Other Changes + + * Replace pkg_resources with importlib.metadata + +### Documentation Changes + + * Improve gphase unitary matrix definition in docstring + +## v1.76.2 (2024-04-08) + +### Bug Fixes and Other Changes + + * backwards compatibility for local detuning + +## v1.76.1 (2024-04-08) + +### Bug Fixes and Other Changes + + * Support single-register measurements in `from_ir` + * prevent repeated measurements on a qubit + +## v1.76.0 (2024-04-01) + +### Features + + * add support for OpenQASM measure on a subset of qubits + +### Bug Fixes and Other Changes + + * restore the dependent test back to pennylane + +### Documentation Changes + + * fix GPI2 gate matrix representation + +## v1.75.0 (2024-03-28) + +### Features + + * upgrade to pydantic 2.x + +### Bug Fixes and Other Changes + + * change schemas constraint + +## v1.74.1 (2024-03-27) + +### Bug Fixes and Other Changes + + * temporarily pin the schemas version + +## v1.74.0 (2024-03-21) + +### Features + + * Allow sets of calibrations in batches + +### Bug Fixes and Other Changes + + * batch tasking passing lists to single tasks + +## v1.73.3 (2024-03-18) + +### Bug Fixes and Other Changes + + * store account id if already accessed + +## v1.73.2 (2024-03-13) + +### Bug Fixes and Other Changes + + * increase tol value for our integ tests + +## v1.73.1 (2024-03-11) + +### Bug Fixes and Other Changes + + * allow for braket endpoint to be set within the jobs + +## v1.73.0 (2024-03-07) + +### Features + + * update circuit drawing + +## v1.72.2 (2024-03-04) + +### Bug Fixes and Other Changes + + * validate FreeParameter name + +## v1.72.1 (2024-02-28) + +### Bug Fixes and Other Changes + + * escape slash in metrics prefix + +## v1.72.0 (2024-02-27) + +### Features + + * FreeParameterExpression division + +## v1.71.0 (2024-02-26) + +### Features + + * update log stream prefix for new jobs + +## v1.70.3 (2024-02-21) + +### Bug Fixes and Other Changes + + * remove test with job creation with qpu + * use the caller's account id based on the session + * docs: add note about using env variables for endpoint + +## v1.70.2 (2024-02-14) + +### Bug Fixes and Other Changes + + * Sort input parameters when doing testing equality of two PulseSequences + +## v1.70.1 (2024-02-13) + +### Bug Fixes and Other Changes + + * Do not autodeclare FreeParameter in OQpy + +## v1.70.0 (2024-02-12) + +### Features + + * Support noise models in DM simulators + +## v1.69.1 (2024-02-08) + +### Bug Fixes and Other Changes + + * let price tracker checks skip over devices without execution win… + +## v1.69.0 (2024-02-06) + +### Features + + * update OQpy to version 0.3.5 + +## v1.68.3 (2024-02-05) + +### Bug Fixes and Other Changes + + * Allow identities in PauliString observable + ## v1.68.2 (2024-01-31) ### Bug Fixes and Other Changes diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7362a17b7..0df22f51e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -218,9 +218,9 @@ You can then find the generated HTML files in `build/documentation/html`. Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels ((enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/amazon-braket/amazon-braket-sdk-python/labels/help%20wanted) issues is a great place to start. ## Building Integrations -The Amazon Braket SDK supports integrations with popular quantum computing frameworks such as [PennyLane](https://github.com/amazon-braket/amazon-braket-pennylane-plugin-python), [Strawberryfields](https://github.com/amazon-braket/amazon-braket-strawberryfields-plugin-python) and [DWave's Ocean library](https://github.com/amazon-braket/amazon-braket-ocean-plugin-python). These serve as a good reference for a new integration you wish to develop. +The Amazon Braket SDK supports integrations with popular quantum computing frameworks such as [PennyLane](https://github.com/amazon-braket/amazon-braket-pennylane-plugin-python), [Strawberryfields](https://github.com/amazon-braket/amazon-braket-strawberryfields-plugin-python) and [DWave's Ocean library](https://github.com/amazon-braket/amazon-braket-ocean-plugin-python). These serve as a good reference for a new integration you wish to develop. -When developing a new integration with the Amazon Braket SDK, please remember to update the [user agent header](https://datatracker.ietf.org/doc/html/rfc7231#section-5.5.3) to include version information for your integration. An example can be found [here](https://github.com/amazon-braket/amazon-braket-pennylane-plugin-python/commit/ccee35604afc2b04d83ee9103eccb2821a4256cb). +When developing a new integration with the Amazon Braket SDK, please remember to update the [user agent header](https://datatracker.ietf.org/doc/html/rfc7231#section-5.5.3) to include version information for your integration. An example can be found [here](https://github.com/amazon-braket/amazon-braket-pennylane-plugin-python/commit/ccee35604afc2b04d83ee9103eccb2821a4256cb). ## Code of Conduct diff --git a/README.md b/README.md index 1c79e8034..0c935853d 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,8 @@ Many quantum algorithms need to run multiple independent circuits, and submittin ```python circuits = [bell for _ in range(5)] batch = device.run_batch(circuits, shots=100) -print(batch.results()[0].measurement_counts) # The result of the first quantum task in the batch +# The result of the first quantum task in the batch +print(batch.results()[0].measurement_counts) ``` ### Running a hybrid job @@ -139,14 +140,14 @@ from braket.aws import AwsDevice device = AwsDevice("arn:aws:braket:::device/qpu/rigetti/Aspen-8") bell = Circuit().h(0).cnot(0, 1) -task = device.run(bell) +task = device.run(bell) print(task.result().measurement_counts) ``` When you execute your task, Amazon Braket polls for a result. By default, Braket polls for 5 days; however, it is possible to change this by modifying the `poll_timeout_seconds` parameter in `AwsDevice.run`, as in the example below. Keep in mind that if your polling timeout is too short, results may not be returned within the polling time, such as when a QPU is unavailable, and a local timeout error is returned. You can always restart the polling by using `task.result()`. ```python -task = device.run(bell, poll_timeout_seconds=86400) # 1 day +task = device.run(bell, poll_timeout_seconds=86400) # 1 day print(task.result().measurement_counts) ``` @@ -205,6 +206,12 @@ To run linters and doc generators and unit tests: tox ``` +or if your machine can handle multithreaded workloads, run them in parallel with: + +```bash +tox -p auto +``` + ### Integration Tests First, configure a profile to use your account to interact with AWS. To learn more, see [Configure AWS CLI](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html). @@ -232,15 +239,15 @@ tox -e integ-tests -- your-arguments ### Issues and Bug Reports -If you encounter bugs or face issues while using the SDK, please let us know by posting -the issue on our [Github issue tracker](https://github.com/amazon-braket/amazon-braket-sdk-python/issues/). +If you encounter bugs or face issues while using the SDK, please let us know by posting +the issue on our [Github issue tracker](https://github.com/amazon-braket/amazon-braket-sdk-python/issues/). For other issues or general questions, please ask on the [Quantum Computing Stack Exchange](https://quantumcomputing.stackexchange.com/questions/ask?Tags=amazon-braket). ### Feedback and Feature Requests -If you have feedback or features that you would like to see on Amazon Braket, we would love to hear from you! -[Github issues](https://github.com/amazon-braket/amazon-braket-sdk-python/issues/) is our preferred mechanism for collecting feedback and feature requests, allowing other users -to engage in the conversation, and +1 issues to help drive priority. +If you have feedback or features that you would like to see on Amazon Braket, we would love to hear from you! +[Github issues](https://github.com/amazon-braket/amazon-braket-sdk-python/issues/) is our preferred mechanism for collecting feedback and feature requests, allowing other users +to engage in the conversation, and +1 issues to help drive priority. ### Code contributors diff --git a/doc/conf.py b/doc/conf.py index bc1afaafd..b5eae7acc 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,14 +1,13 @@ """Sphinx configuration.""" import datetime - -import pkg_resources +from importlib.metadata import version # Sphinx configuration below. project = "amazon-braket-sdk" -version = pkg_resources.require(project)[0].version +version = version(project) release = version -copyright = "{}, Amazon.com".format(datetime.datetime.now().year) +copyright = f"{datetime.datetime.now().year}, Amazon.com" extensions = [ "sphinxcontrib.apidoc", @@ -30,7 +29,7 @@ html_theme_options = { "prev_next_buttons_location": "both", } -htmlhelp_basename = "{}doc".format(project) +htmlhelp_basename = f"{project}doc" language = "en" diff --git a/doc/examples-adv-circuits-algorithms.rst b/doc/examples-adv-circuits-algorithms.rst index b37e87e34..23de8a047 100644 --- a/doc/examples-adv-circuits-algorithms.rst +++ b/doc/examples-adv-circuits-algorithms.rst @@ -2,30 +2,30 @@ Advanced circuits and algorithms ################################ -Learn more about working with advanced circuits and algoritms. +Learn more about working with advanced circuits and algorithms. .. toctree:: :maxdepth: 2 - + ********************************************************************************************************************************************************** `Grover's search algorithm `_ ********************************************************************************************************************************************************** -This tutorial provides a step-by-step walkthrough of Grover's quantum algorithm. -You learn how to build the corresponding quantum circuit with simple modular building -blocks using the Amazon Braket SDK. You will learn how to build custom -gates that are not part of the basic gate set provided by the SDK. A custom gate can used +This tutorial provides a step-by-step walkthrough of Grover's quantum algorithm. +You learn how to build the corresponding quantum circuit with simple modular building +blocks using the Amazon Braket SDK. You will learn how to build custom +gates that are not part of the basic gate set provided by the SDK. A custom gate can used as a core quantum gate by registering it as a subroutine. ****************************************************************************************************************************************************************************************************************** `Quantum amplitude amplification `_ ****************************************************************************************************************************************************************************************************************** -This tutorial provides a detailed discussion and implementation of the Quantum Amplitude Amplification (QAA) -algorithm using the Amazon Braket SDK. QAA is a routine in quantum computing which generalizes the idea behind -Grover's famous search algorithm, with applications across many quantum algorithms. QAA uses an iterative -approach to systematically increase the probability of finding one or multiple -target states in a given search space. In a quantum computer, QAA can be used to obtain a +This tutorial provides a detailed discussion and implementation of the Quantum Amplitude Amplification (QAA) +algorithm using the Amazon Braket SDK. QAA is a routine in quantum computing which generalizes the idea behind +Grover's famous search algorithm, with applications across many quantum algorithms. QAA uses an iterative +approach to systematically increase the probability of finding one or multiple +target states in a given search space. In a quantum computer, QAA can be used to obtain a quadratic speedup over several classical algorithms. @@ -33,18 +33,18 @@ quadratic speedup over several classical algorithms. `Quantum Fourier transform `_ ************************************************************************************************************************************************************************************************ -This tutorial provides a detailed implementation of the Quantum Fourier Transform (QFT) and -its inverse using Amazon Braket's SDK. The QFT is an important subroutine to many quantum algorithms, -most famously Shor's algorithm for factoring and the quantum phase estimation (QPE) algorithm -for estimating the eigenvalues of a unitary operator. +This tutorial provides a detailed implementation of the Quantum Fourier Transform (QFT) and +its inverse using Amazon Braket's SDK. The QFT is an important subroutine to many quantum algorithms, +most famously Shor's algorithm for factoring and the quantum phase estimation (QPE) algorithm +for estimating the eigenvalues of a unitary operator. ********************************************************************************************************************************************************************************************* `Quantum phase estimation `_ ********************************************************************************************************************************************************************************************* -This tutorial provides a detailed implementation of the Quantum Phase Estimation (QPE) -algorithm using the Amazon Braket SDK. The QPE algorithm is designed to estimate the -eigenvalues of a unitary operator. Eigenvalue problems can be found across many -disciplines and application areas, including principal component analysis (PCA) -as used in machine learning and the solution of differential equations in mathematics, physics, -engineering and chemistry. +This tutorial provides a detailed implementation of the Quantum Phase Estimation (QPE) +algorithm using the Amazon Braket SDK. The QPE algorithm is designed to estimate the +eigenvalues of a unitary operator. Eigenvalue problems can be found across many +disciplines and application areas, including principal component analysis (PCA) +as used in machine learning and the solution of differential equations in mathematics, physics, +engineering and chemistry. diff --git a/doc/examples-braket-features.rst b/doc/examples-braket-features.rst index 75361f172..25c088ab1 100644 --- a/doc/examples-braket-features.rst +++ b/doc/examples-braket-features.rst @@ -11,30 +11,30 @@ Learn more about the indivudal features of Amazon Braket. `Getting notifications when a quantum task completes `_ ***************************************************************************************************************************************************************************************************************************************************************** -This tutorial illustrates how Amazon Braket integrates with Amazon EventBridge for -event-based processing. In the tutorial, you will learn how to configure Amazon Braket -and Amazon Eventbridge to receive text notification about quantum task completions on your phone. +This tutorial illustrates how Amazon Braket integrates with Amazon EventBridge for +event-based processing. In the tutorial, you will learn how to configure Amazon Braket +and Amazon Eventbridge to receive text notification about quantum task completions on your phone. *********************************************************************************************************************************************************************** `Allocating Qubits on QPU Devices `_ *********************************************************************************************************************************************************************** -This tutorial explains how you can use the Amazon Braket SDK to allocate the qubit +This tutorial explains how you can use the Amazon Braket SDK to allocate the qubit selection for your circuits manually, when running on QPUs. *************************************************************************************************************************************************************************************************** `Getting Devices and Checking Device Properties `_ *************************************************************************************************************************************************************************************************** -This example shows how to interact with the Amazon Braket GetDevice API to -retrieve Amazon Braket devices (such as simulators and QPUs) programatically, +This example shows how to interact with the Amazon Braket GetDevice API to +retrieve Amazon Braket devices (such as simulators and QPUs) programmatically, and how to gain access to their properties. *********************************************************************************************************************************************************************************** `Using the tensor network simulator TN1 `_ *********************************************************************************************************************************************************************************** -This notebook introduces the Amazon Braket managed tensor network simulator, TN1. +This notebook introduces the Amazon Braket managed tensor network simulator, TN1. You will learn about how TN1 works, how to use it, and which problems are best suited to run on TN1. diff --git a/doc/examples-getting-started.rst b/doc/examples-getting-started.rst index 8c9eb90f5..64c6939af 100644 --- a/doc/examples-getting-started.rst +++ b/doc/examples-getting-started.rst @@ -6,7 +6,7 @@ Get started on Amazon Braket with some introductory examples. .. toctree:: :maxdepth: 2 - + ********************************************************************************************************************************************************* `Getting started `_ ********************************************************************************************************************************************************* @@ -17,11 +17,11 @@ A hello-world tutorial that shows you how to build a simple circuit and run it o `Running quantum circuits on simulators `_ ****************************************************************************************************************************************************************************************************************************** -This tutorial prepares a paradigmatic example for a multi-qubit entangled state, -the so-called GHZ state (named after the three physicists Greenberger, Horne, and Zeilinger). -The GHZ state is extremely non-classical, and therefore very sensitive to decoherence. -It is often used as a performance benchmark for today's hardware. In many quantum information -protocols it is used as a resource for quantum error correction, quantum communication, +This tutorial prepares a paradigmatic example for a multi-qubit entangled state, +the so-called GHZ state (named after the three physicists Greenberger, Horne, and Zeilinger). +The GHZ state is extremely non-classical, and therefore very sensitive to decoherence. +It is often used as a performance benchmark for today's hardware. In many quantum information +protocols it is used as a resource for quantum error correction, quantum communication, and quantum metrology. **Note:** When a circuit is ran using a simulator, customers are required to use contiguous qubits/indices. @@ -30,30 +30,29 @@ and quantum metrology. `Running quantum circuits on QPU devices `_ ********************************************************************************************************************************************************************************************************************************* -This tutorial prepares a maximally-entangled Bell state between two qubits, -for classical simulators and for QPUs. For classical devices, we can run the circuit on a -local simulator or a cloud-based managed simulator. For the quantum devices, -we run the circuit on the superconducting machine from Rigetti, and on the ion-trap -machine provided by IonQ. +This tutorial prepares a maximally-entangled Bell state between two qubits, +for classical simulators and for QPUs. For classical devices, we can run the circuit on a +local simulator or a cloud-based managed simulator. For the quantum devices, +we run the circuit on the superconducting machine from Rigetti, and on the ion-trap +machine provided by IonQ. ****************************************************************************************************************************************************************************************************************************************************** `Deep Dive into the anatomy of quantum circuits `_ ****************************************************************************************************************************************************************************************************************************************************** -This tutorial discusses in detail the anatomy of quantum circuits in the Amazon -Braket SDK. You will learn how to build (parameterized) circuits and display them +This tutorial discusses in detail the anatomy of quantum circuits in the Amazon +Braket SDK. You will learn how to build (parameterized) circuits and display them graphically, and how to append circuits to each other. Next, learn -more about circuit depth and circuit size. Finally you will learn how to execute -the circuit on a device of our choice (defining a quantum task) and how to track, log, +more about circuit depth and circuit size. Finally you will learn how to execute +the circuit on a device of our choice (defining a quantum task) and how to track, log, recover, or cancel a quantum task efficiently. *************************************************************************************************************************************************************** `Superdense coding `_ *************************************************************************************************************************************************************** -This tutorial constructs an implementation of the superdense coding protocol using -the Amazon Braket SDK. Superdense coding is a method of transmitting two classical -bits by sending only one qubit. Starting with a pair of entanged qubits, the sender -(aka Alice) applies a certain quantum gate to their qubit and sends the result +This tutorial constructs an implementation of the superdense coding protocol using +the Amazon Braket SDK. Superdense coding is a method of transmitting two classical +bits by sending only one qubit. Starting with a pair of entanged qubits, the sender +(aka Alice) applies a certain quantum gate to their qubit and sends the result to the receiver (aka Bob), who is then able to decode the full two-bit message. - diff --git a/doc/examples-hybrid-quantum.rst b/doc/examples-hybrid-quantum.rst index 9c7f3aca2..9a0a8efba 100644 --- a/doc/examples-hybrid-quantum.rst +++ b/doc/examples-hybrid-quantum.rst @@ -11,19 +11,19 @@ Learn more about hybrid quantum algorithms. `QAOA `_ ************************************************************************************************************************************* -This tutorial shows how to (approximately) solve binary combinatorial optimization problems -using the Quantum Approximate Optimization Algorithm (QAOA). +This tutorial shows how to (approximately) solve binary combinatorial optimization problems +using the Quantum Approximate Optimization Algorithm (QAOA). ************************************************************************************************************************************************************************************ `VQE Transverse Ising `_ ************************************************************************************************************************************************************************************ This tutorial shows how to solve for the ground state of the Transverse Ising Model -using the variational quantum eigenvalue solver (VQE). +using the variational quantum eigenvalue solver (VQE). **************************************************************************************************************************************************************** `VQE Chemistry `_ **************************************************************************************************************************************************************** -This tutorial shows how to implement the Variational Quantum Eigensolver (VQE) algorithm in -Amazon Braket SDK to compute the potential energy surface (PES) for the Hydrogen molecule. +This tutorial shows how to implement the Variational Quantum Eigensolver (VQE) algorithm in +Amazon Braket SDK to compute the potential energy surface (PES) for the Hydrogen molecule. diff --git a/doc/examples-ml-pennylane.rst b/doc/examples-ml-pennylane.rst index 5c7db93aa..1aa57cc4c 100644 --- a/doc/examples-ml-pennylane.rst +++ b/doc/examples-ml-pennylane.rst @@ -11,37 +11,37 @@ Learn more about how to combine PennyLane with Amazon Braket. `Combining PennyLane with Amazon Braket `_ ************************************************************************************************************************************************************************** -This tutorial shows you how to construct circuits and evaluate their gradients in +This tutorial shows you how to construct circuits and evaluate their gradients in PennyLane with execution performed using Amazon Braket. ***************************************************************************************************************************************************************************************************************************************************** `Computing gradients in parallel with PennyLane-Braket `_ ***************************************************************************************************************************************************************************************************************************************************** -Learn how to speed up training of quantum circuits by using parallel execution on -Amazon Braket. Quantum circuit training involving gradients -requires multiple device executions. The Amazon Braket SV1 simulator can be used to overcome this. -The tutorial benchmarks SV1 against a local simulator, showing that SV1 outperforms the -local simulator for both executions and gradient calculations. This illustrates how +Learn how to speed up training of quantum circuits by using parallel execution on +Amazon Braket. Quantum circuit training involving gradients +requires multiple device executions. The Amazon Braket SV1 simulator can be used to overcome this. +The tutorial benchmarks SV1 against a local simulator, showing that SV1 outperforms the +local simulator for both executions and gradient calculations. This illustrates how parallel capabilities can be combined between PennyLane and SV1. ****************************************************************************************************************************************************************************************** `Graph optimization with QAOA `_ ****************************************************************************************************************************************************************************************** -In this tutorial, you learn how quantum circuit training can be applied to a problem -of practical relevance in graph optimization. It easy it is to train a QAOA circuit in -PennyLane to solve the maximum clique problem on a simple example graph. The tutorial -then extends to a more difficult 20-node graph and uses the parallel capabilities of -the Amazon Braket SV1 simulator to speed up gradient calculations and hence train the quantum circuit faster, +In this tutorial, you learn how quantum circuit training can be applied to a problem +of practical relevance in graph optimization. It easy it is to train a QAOA circuit in +PennyLane to solve the maximum clique problem on a simple example graph. The tutorial +then extends to a more difficult 20-node graph and uses the parallel capabilities of +the Amazon Braket SV1 simulator to speed up gradient calculations and hence train the quantum circuit faster, using around 1-2 minutes per iteration. *************************************************************************************************************************************************************************************************************** `Hydrogen Molecule geometry with VQE `_ *************************************************************************************************************************************************************************************************************** -In this tutorial, you will learn how PennyLane and Amazon Braket can be combined to solve an -important problem in quantum chemistry. The ground state energy of molecular hydrogen is calculated -by optimizing a VQE circuit using the local Braket simulator. This tutorial highlights how -qubit-wise commuting observables can be measured together in PennyLane and Amazon Braket, +In this tutorial, you will learn how PennyLane and Amazon Braket can be combined to solve an +important problem in quantum chemistry. The ground state energy of molecular hydrogen is calculated +by optimizing a VQE circuit using the local Braket simulator. This tutorial highlights how +qubit-wise commuting observables can be measured together in PennyLane and Amazon Braket, making optimization more efficient. diff --git a/doc/examples.rst b/doc/examples.rst index 87c2e1f7a..93aac757b 100644 --- a/doc/examples.rst +++ b/doc/examples.rst @@ -1,7 +1,7 @@ ######## Examples ######## - + There are several examples available in the Amazon Braket repo: https://github.com/amazon-braket/amazon-braket-examples. @@ -14,5 +14,3 @@ https://github.com/amazon-braket/amazon-braket-examples. examples-hybrid-quantum.rst examples-ml-pennylane.rst examples-hybrid-jobs.rst - - diff --git a/doc/getting-started.rst b/doc/getting-started.rst index 205254740..31493b789 100644 --- a/doc/getting-started.rst +++ b/doc/getting-started.rst @@ -16,7 +16,7 @@ at https://docs.aws.amazon.com/braket/index.html. Getting started using an Amazon Braket notebook ************************************************ -You can use the AWS Console to enable Amazon Braket, +You can use the AWS Console to enable Amazon Braket, then create an Amazon Braket notebook instance and run your first circuit with the Amazon Braket Python SDK: @@ -25,7 +25,7 @@ and run your first circuit with the Amazon Braket Python SDK: 3. `Run your first circuit using the Amazon Braket Python SDK `_. When you use an Amazon Braket notebook, the Amazon Braket SDK and plugins are -preloaded. +preloaded. *********************************** Getting started in your environment @@ -37,4 +37,3 @@ after enabling Amazon Braket and configuring the AWS SDK for Python: 1. `Enable Amazon Braket `_. 2. Configure the AWS SDK for Python (Boto3) using the `Quickstart `_. 3. `Run your first circuit using the Amazon Braket Python SDK `_. - diff --git a/doc/index.rst b/doc/index.rst index 8d996f4cc..54d10b54d 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -2,7 +2,7 @@ Amazon Braket Python SDK ######################## -The Amazon Braket Python SDK is an open source library to design and build quantum circuits, +The Amazon Braket Python SDK is an open source library to design and build quantum circuits, submit them to Amazon Braket devices as quantum tasks, and monitor their execution. This documentation provides information about the Amazon Braket Python SDK library. The project @@ -29,7 +29,7 @@ Explore Amazon Braket examples. :maxdepth: 3 examples.rst - + *************** Python SDK APIs @@ -39,6 +39,5 @@ The Amazon Braket Python SDK APIs: .. toctree:: :maxdepth: 2 - - _apidoc/modules + _apidoc/modules diff --git a/examples/reservation.py b/examples/reservation.py index 682f71f50..83be87ebd 100644 --- a/examples/reservation.py +++ b/examples/reservation.py @@ -11,7 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from braket.aws import AwsDevice +from braket.aws import AwsDevice, DirectReservation from braket.circuits import Circuit from braket.devices import Devices @@ -19,6 +19,17 @@ device = AwsDevice(Devices.IonQ.Aria1) # To run a task in a device reservation, change the device to the one you reserved -# and fill in your reservation ARN -task = device.run(bell, shots=100, reservation_arn="reservation ARN") +# and fill in your reservation ARN. +with DirectReservation(device, reservation_arn=""): + task = device.run(bell, shots=100) +print(task.result().measurement_counts) + +# Alternatively, you may start the reservation globally +reservation = DirectReservation(device, reservation_arn="").start() +task = device.run(bell, shots=100) +print(task.result().measurement_counts) +reservation.stop() # stop creating tasks in the reservation + +# Lastly, you may pass the reservation ARN directly to a quantum task +task = device.run(bell, shots=100, reservation_arn="") print(task.result().measurement_counts) diff --git a/pydoclint-baseline.txt b/pydoclint-baseline.txt new file mode 100644 index 000000000..816c4265a --- /dev/null +++ b/pydoclint-baseline.txt @@ -0,0 +1,233 @@ +src/braket/aws/aws_device.py + DOC101: Method `AwsDevice.run_batch`: Docstring contains fewer arguments than in function signature. + DOC103: Method `AwsDevice.run_batch`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**aws_quantum_task_kwargs: , *aws_quantum_task_args: ]. +-------------------- +src/braket/aws/aws_quantum_job.py + DOC502: Method `AwsQuantumJob.create` has a "Raises" section in the docstring, but there are not "raise" statements in the body + DOC101: Method `AwsQuantumJob._is_valid_aws_session_region_for_job_arn`: Docstring contains fewer arguments than in function signature. + DOC109: Method `AwsQuantumJob._is_valid_aws_session_region_for_job_arn`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `AwsQuantumJob._is_valid_aws_session_region_for_job_arn`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [aws_session: AwsSession, job_arn: str]. + DOC502: Method `AwsQuantumJob.logs` has a "Raises" section in the docstring, but there are not "raise" statements in the body + DOC502: Method `AwsQuantumJob.cancel` has a "Raises" section in the docstring, but there are not "raise" statements in the body +-------------------- +src/braket/aws/aws_quantum_task.py + DOC101: Method `AwsQuantumTask.create`: Docstring contains fewer arguments than in function signature. + DOC103: Method `AwsQuantumTask.create`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: , *args: ]. + DOC501: Method `AwsQuantumTask.create` has "raise" statements, but the docstring does not have a "Raises" section + DOC101: Method `AwsQuantumTask._aws_session_for_task_arn`: Docstring contains fewer arguments than in function signature. + DOC109: Method `AwsQuantumTask._aws_session_for_task_arn`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `AwsQuantumTask._aws_session_for_task_arn`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [task_arn: str]. + DOC501: Function `_create_annealing_device_params` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/aws/aws_quantum_task_batch.py + DOC501: Method `AwsQuantumTaskBatch.results` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `AwsQuantumTaskBatch.retry_unsuccessful_tasks` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/aws/aws_session.py + DOC106: Method `AwsSession.create_quantum_task`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `AwsSession.create_quantum_task`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC106: Method `AwsSession.create_job`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `AwsSession.create_job`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC001: Function/method `parse_s3_uri`: Potential formatting errors in docstring. Error message: Expected a colon in 'a valid S3 URI.'. + DOC101: Method `AwsSession.parse_s3_uri`: Docstring contains fewer arguments than in function signature. + DOC109: Method `AwsSession.parse_s3_uri`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `AwsSession.parse_s3_uri`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [s3_uri: str]. + DOC201: Method `AwsSession.parse_s3_uri` does not have a return section in docstring + DOC203: Method `AwsSession.parse_s3_uri` return type(s) in docstring not consistent with the return annotation. Return annotation has 1 type(s); docstring return section has 0 type(s). + DOC501: Method `AwsSession.parse_s3_uri` has "raise" statements, but the docstring does not have a "Raises" section + DOC001: Function/method `construct_s3_uri`: Potential formatting errors in docstring. Error message: Expected a colon in 'valid to generate an S3 URI'. + DOC101: Method `AwsSession.construct_s3_uri`: Docstring contains fewer arguments than in function signature. + DOC109: Method `AwsSession.construct_s3_uri`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `AwsSession.construct_s3_uri`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [*dirs: str, bucket: str]. + DOC201: Method `AwsSession.construct_s3_uri` does not have a return section in docstring + DOC203: Method `AwsSession.construct_s3_uri` return type(s) in docstring not consistent with the return annotation. Return annotation has 1 type(s); docstring return section has 0 type(s). + DOC501: Method `AwsSession.construct_s3_uri` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `AwsSession.get_full_image_tag` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/circuits/angled_gate.py + DOC101: Method `AngledGate.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `AngledGate.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `AngledGate.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `AngledGate.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC501: Method `DoubleAngledGate.adjoint` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `TripleAngledGate.adjoint` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/circuits/braket_program_context.py + DOC101: Method `BraketProgramContext.add_gate_instruction`: Docstring contains fewer arguments than in function signature. + DOC103: Method `BraketProgramContext.add_gate_instruction`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [*params: ]. +-------------------- +src/braket/circuits/circuit.py + DOC101: Method `Circuit.__init__`: Docstring contains fewer arguments than in function signature. + DOC103: Method `Circuit.__init__`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: , *args: ]. + DOC502: Method `Circuit.__init__` has a "Raises" section in the docstring, but there are not "raise" statements in the body + DOC105: Method `Circuit.apply_gate_noise`: Argument names match, but type hints do not match + DOC001: Function/method `_validate_parameters`: Potential formatting errors in docstring. Error message: No specification for "Raises": "" + DOC101: Method `Circuit._validate_parameters`: Docstring contains fewer arguments than in function signature. + DOC109: Method `Circuit._validate_parameters`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `Circuit._validate_parameters`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [parameter_values: dict[str, Number]]. + DOC501: Method `Circuit._validate_parameters` has "raise" statements, but the docstring does not have a "Raises" section + DOC101: Method `Circuit.add`: Docstring contains fewer arguments than in function signature. + DOC103: Method `Circuit.add`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: , *args: ]. + DOC502: Method `Circuit.to_unitary` has a "Raises" section in the docstring, but there are not "raise" statements in the body +-------------------- +src/braket/circuits/compiler_directive.py + DOC501: Method `CompilerDirective.__init__` has "raise" statements, but the docstring does not have a "Raises" section + DOC101: Method `CompilerDirective.to_ir`: Docstring contains fewer arguments than in function signature. + DOC103: Method `CompilerDirective.to_ir`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC501: Method `CompilerDirective.counterpart` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/circuits/gate.py + DOC502: Method `Gate.__init__` has a "Raises" section in the docstring, but there are not "raise" statements in the body + DOC501: Method `Gate.adjoint` has "raise" statements, but the docstring does not have a "Raises" section + DOC001: Function/method `to_ir`: Potential formatting errors in docstring. Error message: Expected a colon in "properties don't correspond to the `ir_type`.". + DOC101: Method `Gate.to_ir`: Docstring contains fewer arguments than in function signature. + DOC109: Method `Gate.to_ir`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `Gate.to_ir`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [control: Optional[QubitSet], control_state: Optional[BasisStateInput], ir_type: IRType, power: float, serialization_properties: Optional[SerializationProperties], target: QubitSet]. + DOC201: Method `Gate.to_ir` does not have a return section in docstring + DOC203: Method `Gate.to_ir` return type(s) in docstring not consistent with the return annotation. Return annotation has 1 type(s); docstring return section has 0 type(s). + DOC501: Method `Gate.to_ir` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `Gate._to_jaqcd` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/circuits/gates.py + DOC105: Method `Unitary.__init__`: Argument names match, but type hints do not match + DOC105: Method `Unitary.unitary`: Argument names match, but type hints do not match + DOC501: Method `PulseGate.__init__` has "raise" statements, but the docstring does not have a "Raises" section + DOC101: Method `PulseGate.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `PulseGate.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `PulseGate.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `PulseGate.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. +-------------------- +src/braket/circuits/noise.py + DOC502: Method `Noise.__init__` has a "Raises" section in the docstring, but there are not "raise" statements in the body + DOC001: Function/method `to_ir`: Potential formatting errors in docstring. Error message: Expected a colon in "properties don't correspond to the `ir_type`.". + DOC101: Method `Noise.to_ir`: Docstring contains fewer arguments than in function signature. + DOC109: Method `Noise.to_ir`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `Noise.to_ir`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [ir_type: IRType, serialization_properties: SerializationProperties | None, target: QubitSet]. + DOC201: Method `Noise.to_ir` does not have a return section in docstring + DOC203: Method `Noise.to_ir` return type(s) in docstring not consistent with the return annotation. Return annotation has 1 type(s); docstring return section has 0 type(s). + DOC501: Method `Noise.to_ir` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `Noise._to_jaqcd` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `Noise._to_openqasm` has "raise" statements, but the docstring does not have a "Raises" section + DOC101: Method `Noise.to_matrix`: Docstring contains fewer arguments than in function signature. + DOC106: Method `Noise.to_matrix`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `Noise.to_matrix`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `Noise.to_matrix`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: , *args: ]. + DOC203: Method `Noise.to_matrix` return type(s) in docstring not consistent with the return annotation. Return annotation types: ['Iterable[np.ndarray]']; docstring return section types: ['Iterable[ndarray]'] + DOC501: Method `Noise.to_matrix` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `Noise.from_dict` has "raise" statements, but the docstring does not have a "Raises" section + DOC502: Method `SingleProbabilisticNoise.__init__` has a "Raises" section in the docstring, but there are not "raise" statements in the body + DOC101: Method `SingleProbabilisticNoise.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `SingleProbabilisticNoise.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `SingleProbabilisticNoise.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `SingleProbabilisticNoise.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC502: Method `SingleProbabilisticNoise_34.__init__` has a "Raises" section in the docstring, but there are not "raise" statements in the body + DOC502: Method `SingleProbabilisticNoise_1516.__init__` has a "Raises" section in the docstring, but there are not "raise" statements in the body + DOC101: Method `MultiQubitPauliNoise.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `MultiQubitPauliNoise.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `MultiQubitPauliNoise.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `MultiQubitPauliNoise.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC203: Method `PauliNoise.probX` return type(s) in docstring not consistent with the return annotation. Return annotation types: ['Union[FreeParameterExpression, float]']; docstring return section types: [''] + DOC203: Method `PauliNoise.probY` return type(s) in docstring not consistent with the return annotation. Return annotation types: ['Union[FreeParameterExpression, float]']; docstring return section types: [''] + DOC203: Method `PauliNoise.probZ` return type(s) in docstring not consistent with the return annotation. Return annotation types: ['Union[FreeParameterExpression, float]']; docstring return section types: [''] + DOC101: Method `PauliNoise.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `PauliNoise.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `PauliNoise.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `PauliNoise.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC502: Method `DampingNoise.__init__` has a "Raises" section in the docstring, but there are not "raise" statements in the body + DOC101: Method `DampingNoise.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `DampingNoise.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `DampingNoise.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `DampingNoise.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC502: Method `GeneralizedAmplitudeDampingNoise.__init__` has a "Raises" section in the docstring, but there are not "raise" statements in the body + DOC501: Function `_validate_param_value` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/circuits/noise_helpers.py + DOC501: Function `check_noise_target_gates` has "raise" statements, but the docstring does not have a "Raises" section + DOC105: Function `check_noise_target_unitary`: Argument names match, but type hints do not match + DOC501: Function `check_noise_target_unitary` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Function `check_noise_target_qubits` has "raise" statements, but the docstring does not have a "Raises" section + DOC105: Function `apply_noise_to_gates`: Argument names match, but type hints do not match + DOC502: Function `apply_noise_to_gates` has a "Raises" section in the docstring, but there are not "raise" statements in the body +-------------------- +src/braket/circuits/noise_model/criteria.py + DOC501: Method `Criteria.applicable_key_types` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `Criteria.get_keys` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `Criteria.to_dict` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `Criteria.from_dict` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/circuits/noise_model/criteria_input_parsing.py + DOC501: Function `parse_operator_input` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Function `parse_qubit_input` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/circuits/noise_model/initialization_criteria.py + DOC501: Method `InitializationCriteria.qubit_intersection` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/circuits/noise_model/result_type_criteria.py + DOC501: Method `ResultTypeCriteria.result_type_matches` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/circuits/noises.py + DOC101: Method `PauliChannel.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `PauliChannel.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `PauliChannel.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `PauliChannel.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC101: Method `Depolarizing.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `Depolarizing.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `Depolarizing.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `Depolarizing.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC101: Method `TwoQubitDepolarizing.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `TwoQubitDepolarizing.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `TwoQubitDepolarizing.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `TwoQubitDepolarizing.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC101: Method `TwoQubitDephasing.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `TwoQubitDephasing.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `TwoQubitDephasing.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `TwoQubitDephasing.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC101: Method `TwoQubitPauliChannel.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `TwoQubitPauliChannel.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `TwoQubitPauliChannel.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `TwoQubitPauliChannel.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC101: Method `AmplitudeDamping.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `AmplitudeDamping.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `AmplitudeDamping.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `AmplitudeDamping.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC101: Method `GeneralizedAmplitudeDamping.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `GeneralizedAmplitudeDamping.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `GeneralizedAmplitudeDamping.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `GeneralizedAmplitudeDamping.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC101: Method `PhaseDamping.bind_values`: Docstring contains fewer arguments than in function signature. + DOC106: Method `PhaseDamping.bind_values`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `PhaseDamping.bind_values`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `PhaseDamping.bind_values`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. + DOC501: Method `Kraus.kraus` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `Kraus.to_dict` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `Kraus.from_dict` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/circuits/observable.py + DOC501: Method `Observable._to_openqasm` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `Observable.basis_rotation_gates` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `Observable.eigenvalues` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `Observable.eigenvalue` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/circuits/observables.py + DOC501: Method `TensorProduct.__init__` has "raise" statements, but the docstring does not have a "Raises" section + DOC501: Method `TensorProduct.eigenvalue` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/circuits/operator.py + DOC101: Method `Operator.to_ir`: Docstring contains fewer arguments than in function signature. + DOC106: Method `Operator.to_ir`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature + DOC109: Method `Operator.to_ir`: The option `--arg-type-hints-in-docstring` is `True` but there are no type hints in the docstring arg list + DOC103: Method `Operator.to_ir`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: , *args: ]. +-------------------- +src/braket/circuits/result_type.py + DOC101: Method `ResultType.to_ir`: Docstring contains fewer arguments than in function signature. + DOC103: Method `ResultType.to_ir`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ]. +-------------------- +src/braket/devices/local_simulator.py + DOC101: Method `LocalSimulator.run_batch`: Docstring contains fewer arguments than in function signature. + DOC103: Method `LocalSimulator.run_batch`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: , *args: ]. + DOC501: Method `LocalSimulator.run_batch` has "raise" statements, but the docstring does not have a "Raises" section +-------------------- +src/braket/tasks/gate_model_quantum_task_result.py + DOC502: Method `GateModelQuantumTaskResult.from_object` has a "Raises" section in the docstring, but there are not "raise" statements in the body + DOC502: Method `GateModelQuantumTaskResult.from_string` has a "Raises" section in the docstring, but there are not "raise" statements in the body +-------------------- diff --git a/setup.cfg b/setup.cfg index 3bad103a3..ab93f5955 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,7 +5,7 @@ test=pytest xfail_strict = true # https://pytest-xdist.readthedocs.io/en/latest/known-limitations.html addopts = - --verbose -n logical --durations=0 --durations-min=1 + --verbose -n logical --durations=0 --durations-min=1 --dist worksteal testpaths = test/unit_tests filterwarnings= # Issue #557 in `pytest-cov` (currently v4.x) has not moved for a while now, @@ -18,7 +18,7 @@ line_length = 100 multi_line_output = 3 include_trailing_comma = true profile = black - + [flake8] ignore = # not pep8, black adds whitespace before ':' @@ -32,13 +32,10 @@ ignore = RST201,RST203,RST301, max_line_length = 100 max-complexity = 10 -exclude = +exclude = __pycache__ .tox .git bin build venv -rst-roles = - # Python programming language: - py:func,py:mod,mod diff --git a/setup.py b/setup.py index d31f89f16..6763c4a55 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ from setuptools import find_namespace_packages, setup -with open("README.md", "r") as fh: +with open("README.md") as fh: long_description = fh.read() with open("src/braket/_sdk/_version.py") as f: @@ -27,20 +27,20 @@ packages=find_namespace_packages(where="src", exclude=("test",)), package_dir={"": "src"}, install_requires=[ - "amazon-braket-schemas>=1.19.1", - "amazon-braket-default-simulator>=1.19.1", - "oqpy~=0.2.1", - "setuptools", + "amazon-braket-schemas>=1.21.3", + "amazon-braket-default-simulator>=1.21.4", + "oqpy~=0.3.5", "backoff", "boltons", "boto3>=1.28.53", "cloudpickle==2.2.1", "nest-asyncio", "networkx", - "numpy", + "numpy<2", "openpulse", "openqasm3", "sympy", + "backports.entry-points-selectable", ], extras_require={ "test": [ diff --git a/src/braket/_sdk/_version.py b/src/braket/_sdk/_version.py index 30917282e..ce2f1eb66 100644 --- a/src/braket/_sdk/_version.py +++ b/src/braket/_sdk/_version.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. """Version information. - Version number (major.minor.patch[-label]) +Version number (major.minor.patch[-label]) """ -__version__ = "1.68.3.dev0" +__version__ = "1.80.1.dev0" diff --git a/src/braket/ahs/__init__.py b/src/braket/ahs/__init__.py index 8a9fd2666..5bf3cc61e 100644 --- a/src/braket/ahs/__init__.py +++ b/src/braket/ahs/__init__.py @@ -17,5 +17,6 @@ from braket.ahs.driving_field import DrivingField # noqa: F401 from braket.ahs.field import Field # noqa: F401 from braket.ahs.hamiltonian import Hamiltonian # noqa: F401 +from braket.ahs.local_detuning import LocalDetuning # noqa: F401 from braket.ahs.pattern import Pattern # noqa: F401 from braket.ahs.shifting_field import ShiftingField # noqa: F401 diff --git a/src/braket/ahs/analog_hamiltonian_simulation.py b/src/braket/ahs/analog_hamiltonian_simulation.py index b02091d4f..af02471e2 100644 --- a/src/braket/ahs/analog_hamiltonian_simulation.py +++ b/src/braket/ahs/analog_hamiltonian_simulation.py @@ -21,12 +21,12 @@ from braket.ahs.discretization_types import DiscretizationError, DiscretizationProperties from braket.ahs.driving_field import DrivingField from braket.ahs.hamiltonian import Hamiltonian -from braket.ahs.shifting_field import ShiftingField +from braket.ahs.local_detuning import LocalDetuning from braket.device_schema import DeviceActionType class AnalogHamiltonianSimulation: - SHIFTING_FIELDS_PROPERTY = "shifting_fields" + LOCAL_DETUNING_PROPERTY = "local_detuning" DRIVING_FIELDS_PROPERTY = "driving_fields" def __init__(self, register: AtomArrangement, hamiltonian: Hamiltonian) -> None: @@ -54,7 +54,7 @@ def to_ir(self) -> ir.Program: representation. Returns: - Program: A representation of the circuit in the IR format. + ir.Program: A representation of the circuit in the IR format. """ return ir.Program( setup=ir.Setup(ahs_register=self._register_to_ir()), @@ -74,10 +74,10 @@ def _hamiltonian_to_ir(self) -> ir.Hamiltonian: terms[term_type].append(term_ir) return ir.Hamiltonian( drivingFields=terms[AnalogHamiltonianSimulation.DRIVING_FIELDS_PROPERTY], - shiftingFields=terms[AnalogHamiltonianSimulation.SHIFTING_FIELDS_PROPERTY], + localDetuning=terms[AnalogHamiltonianSimulation.LOCAL_DETUNING_PROPERTY], ) - def discretize(self, device) -> AnalogHamiltonianSimulation: # noqa + def discretize(self, device: AwsDevice) -> AnalogHamiltonianSimulation: # noqa """Creates a new AnalogHamiltonianSimulation with all numerical values represented as Decimal objects with fixed precision based on the capabilities of the device. @@ -88,9 +88,8 @@ def discretize(self, device) -> AnalogHamiltonianSimulation: # noqa AnalogHamiltonianSimulation: A discretized version of this program. Raises: - DiscretizeError: If unable to discretize the program. + DiscretizationError: If unable to discretize the program. """ - required_action_schema = DeviceActionType.AHS if (required_action_schema not in device.properties.action) or ( device.properties.action[required_action_schema].actionType != required_action_schema @@ -117,8 +116,8 @@ def _get_term_ir( @_get_term_ir.register -def _(term: ShiftingField) -> tuple[str, ir.ShiftingField]: - return AnalogHamiltonianSimulation.SHIFTING_FIELDS_PROPERTY, ir.ShiftingField( +def _(term: LocalDetuning) -> tuple[str, ir.LocalDetuning]: + return AnalogHamiltonianSimulation.LOCAL_DETUNING_PROPERTY, ir.LocalDetuning( magnitude=ir.PhysicalField( time_series=ir.TimeSeries( times=term.magnitude.time_series.times(), diff --git a/src/braket/ahs/atom_arrangement.py b/src/braket/ahs/atom_arrangement.py index bb7088347..24d4fc9aa 100644 --- a/src/braket/ahs/atom_arrangement.py +++ b/src/braket/ahs/atom_arrangement.py @@ -73,6 +73,7 @@ def add( atom (in meters). The coordinates can be a numpy array of shape (2,) or a tuple of int, float, Decimal site_type (SiteType): The type of site. Optional. Default is FILLED. + Returns: AtomArrangement: returns self (to allow for chaining). """ @@ -109,6 +110,9 @@ def discretize(self, properties: DiscretizationProperties) -> AtomArrangement: properties (DiscretizationProperties): Capabilities of a device that represent the resolution with which the device can implement the parameters. + Raises: + DiscretizationError: If unable to discretize the program. + Returns: AtomArrangement: A new discretized atom arrangement. """ @@ -117,9 +121,9 @@ def discretize(self, properties: DiscretizationProperties) -> AtomArrangement: discretized_arrangement = AtomArrangement() for site in self._sites: new_coordinates = tuple( - (round(Decimal(c) / position_res) * position_res for c in site.coordinate) + round(Decimal(c) / position_res) * position_res for c in site.coordinate ) discretized_arrangement.add(new_coordinates, site.site_type) return discretized_arrangement except Exception as e: - raise DiscretizationError(f"Failed to discretize register {e}") + raise DiscretizationError(f"Failed to discretize register {e}") from e diff --git a/src/braket/ahs/discretization_types.py b/src/braket/ahs/discretization_types.py index c7df1fcfc..49efa0d34 100644 --- a/src/braket/ahs/discretization_types.py +++ b/src/braket/ahs/discretization_types.py @@ -18,8 +18,6 @@ class DiscretizationError(Exception): """Raised if the discretization of the numerical values of the AHS program fails.""" - pass - @dataclass class DiscretizationProperties: diff --git a/src/braket/ahs/driving_field.py b/src/braket/ahs/driving_field.py index 02c8bd276..cbf01838d 100644 --- a/src/braket/ahs/driving_field.py +++ b/src/braket/ahs/driving_field.py @@ -104,7 +104,6 @@ def stitch( Returns: DrivingField: The stitched DrivingField object. """ - amplitude = self.amplitude.time_series.stitch(other.amplitude.time_series, boundary) detuning = self.detuning.time_series.stitch(other.detuning.time_series, boundary) phase = self.phase.time_series.stitch(other.phase.time_series, boundary) @@ -123,17 +122,23 @@ def discretize(self, properties: DiscretizationProperties) -> DrivingField: """ driving_parameters = properties.rydberg.rydbergGlobal time_resolution = driving_parameters.timeResolution + + amplitude_value_resolution = driving_parameters.rabiFrequencyResolution discretized_amplitude = self.amplitude.discretize( time_resolution=time_resolution, - value_resolution=driving_parameters.rabiFrequencyResolution, + value_resolution=amplitude_value_resolution, ) + + phase_value_resolution = driving_parameters.phaseResolution discretized_phase = self.phase.discretize( time_resolution=time_resolution, - value_resolution=driving_parameters.phaseResolution, + value_resolution=phase_value_resolution, ) + + detuning_value_resolution = driving_parameters.detuningResolution discretized_detuning = self.detuning.discretize( time_resolution=time_resolution, - value_resolution=driving_parameters.detuningResolution, + value_resolution=detuning_value_resolution, ) return DrivingField( amplitude=discretized_amplitude, phase=discretized_phase, detuning=discretized_detuning @@ -143,8 +148,7 @@ def discretize(self, properties: DiscretizationProperties) -> DrivingField: def from_lists( times: list[float], amplitudes: list[float], detunings: list[float], phases: list[float] ) -> DrivingField: - """ - Builds DrivingField Hamiltonian from lists defining time evolution + """Builds DrivingField Hamiltonian from lists defining time evolution of Hamiltonian parameters (Rabi frequency, detuning, phase). The values of the parameters at each time points are global for all atoms. @@ -154,6 +158,9 @@ def from_lists( detunings (list[float]): The values of the detuning phases (list[float]): The values of the phase + Raises: + ValueError: If any of the input args length is different from the rest. + Returns: DrivingField: DrivingField Hamiltonian. """ diff --git a/src/braket/ahs/field.py b/src/braket/ahs/field.py index 9a473fd99..1522b9d65 100644 --- a/src/braket/ahs/field.py +++ b/src/braket/ahs/field.py @@ -16,7 +16,6 @@ from decimal import Decimal from typing import Optional -from braket.ahs.discretization_types import DiscretizationError from braket.ahs.pattern import Pattern from braket.timings.time_series import TimeSeries @@ -44,8 +43,8 @@ def pattern(self) -> Optional[Pattern]: def discretize( self, - time_resolution: Decimal, - value_resolution: Decimal, + time_resolution: Optional[Decimal] = None, + value_resolution: Optional[Decimal] = None, pattern_resolution: Optional[Decimal] = None, ) -> Field: """Creates a discretized version of the field, @@ -53,24 +52,17 @@ def discretize( closest multiple of their corresponding resolutions. Args: - time_resolution (Decimal): Time resolution - value_resolution (Decimal): Value resolution + time_resolution (Optional[Decimal]): Time resolution + value_resolution (Optional[Decimal]): Value resolution pattern_resolution (Optional[Decimal]): Pattern resolution Returns: Field: A new discretized field. - - Raises: - ValueError: if pattern_resolution is None, but there is a Pattern """ discretized_time_series = self.time_series.discretize(time_resolution, value_resolution) if self.pattern is None: discretized_pattern = None else: - if pattern_resolution is None: - raise DiscretizationError( - f"{self.pattern} is defined but has no pattern_resolution defined" - ) discretized_pattern = self.pattern.discretize(pattern_resolution) discretized_field = Field(time_series=discretized_time_series, pattern=discretized_pattern) return discretized_field diff --git a/src/braket/ahs/local_detuning.py b/src/braket/ahs/local_detuning.py new file mode 100644 index 000000000..00b420210 --- /dev/null +++ b/src/braket/ahs/local_detuning.py @@ -0,0 +1,162 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from __future__ import annotations + +from braket.ahs.discretization_types import DiscretizationProperties +from braket.ahs.field import Field +from braket.ahs.hamiltonian import Hamiltonian +from braket.ahs.pattern import Pattern +from braket.timings.time_series import StitchBoundaryCondition, TimeSeries + + +class LocalDetuning(Hamiltonian): + def __init__(self, magnitude: Field) -> None: + r"""Creates a Hamiltonian term :math:`H_{shift}` representing the local detuning + that changes the energy of the Rydberg level in an AnalogHamiltonianSimulation, + defined by the formula + + .. math:: + H_{shift} (t) := -\Delta(t) \sum_k h_k | r_k \rangle \langle r_k | + + where + + :math:`\Delta(t)` is the magnitude of the frequency shift in rad/s, + + :math:`h_k` is the site coefficient of atom :math:`k`, + a dimensionless real number between 0 and 1, + + :math:`|r_k \rangle` is the Rydberg state of atom :math:`k`. + + with the sum :math:`\sum_k` taken over all target atoms. + + Args: + magnitude (Field): containing the global magnitude time series :math:`\Delta(t)`, + where time is measured in seconds (s) and values are measured in rad/s, and the + local pattern :math:`h_k` of dimensionless real numbers between 0 and 1. + """ + super().__init__() + self._magnitude = magnitude + + @property + def terms(self) -> list[Hamiltonian]: + return [self] + + @property + def magnitude(self) -> Field: + r"""Field: containing the global magnitude time series :math:`\Delta(t)`, + where time is measured in seconds (s) and values measured in rad/s) + and the local pattern :math:`h_k` of dimensionless real numbers between 0 and 1. + """ + return self._magnitude + + @staticmethod + def from_lists(times: list[float], values: list[float], pattern: list[float]) -> LocalDetuning: + """Get the shifting field from a set of time points, values and pattern + + Args: + times (list[float]): The time points of the shifting field + values (list[float]): The values of the shifting field + pattern (list[float]): The pattern of the shifting field + + Raises: + ValueError: If the length of times and values differs. + + Returns: + LocalDetuning: The shifting field obtained + """ + if len(times) != len(values): + raise ValueError("The length of the times and values lists must be equal.") + + magnitude = TimeSeries() + for t, v in zip(times, values): + magnitude.put(t, v) + shift = LocalDetuning(Field(magnitude, Pattern(pattern))) + + return shift + + def stitch( + self, other: LocalDetuning, boundary: StitchBoundaryCondition = StitchBoundaryCondition.MEAN + ) -> LocalDetuning: + """Stitches two shifting fields based on TimeSeries.stitch method. + The time points of the second LocalDetuning are shifted such that the first time point of + the second LocalDetuning coincides with the last time point of the first LocalDetuning. + The boundary point value is handled according to StitchBoundaryCondition argument value. + + Args: + other (LocalDetuning): The second local detuning to be stitched with. + boundary (StitchBoundaryCondition): {"mean", "left", "right"}. Boundary point handler. + + Possible options are + - "mean" - take the average of the boundary value points of the first + and the second time series. + - "left" - use the last value from the left time series as the boundary point. + - "right" - use the first value from the right time series as the boundary point. + + Raises: + ValueError: The LocalDetuning patterns differ. + + Returns: + LocalDetuning: The stitched LocalDetuning object. + + Example (StitchBoundaryCondition.MEAN): + :: + time_series_1 = TimeSeries.from_lists(times=[0, 0.1], values=[1, 2]) + time_series_2 = TimeSeries.from_lists(times=[0.2, 0.4], values=[4, 5]) + + stitch_ts = time_series_1.stitch(time_series_2, boundary=StitchBoundaryCondition.MEAN) + + Result: + stitch_ts.times() = [0, 0.1, 0.3] + stitch_ts.values() = [1, 3, 5] + + Example (StitchBoundaryCondition.LEFT): + :: + stitch_ts = time_series_1.stitch(time_series_2, boundary=StitchBoundaryCondition.LEFT) + + Result: + stitch_ts.times() = [0, 0.1, 0.3] + stitch_ts.values() = [1, 2, 5] + + Example (StitchBoundaryCondition.RIGHT): + :: + stitch_ts = time_series_1.stitch(time_series_2, boundary=StitchBoundaryCondition.RIGHT) + + Result: + stitch_ts.times() = [0, 0.1, 0.3] + stitch_ts.values() = [1, 4, 5] + """ + if self.magnitude.pattern.series != other.magnitude.pattern.series: + raise ValueError("The LocalDetuning pattern for both fields must be equal.") + + new_ts = self.magnitude.time_series.stitch(other.magnitude.time_series, boundary) + return LocalDetuning(Field(new_ts, self.magnitude.pattern)) + + def discretize(self, properties: DiscretizationProperties) -> LocalDetuning: + """Creates a discretized version of the LocalDetuning. + + Args: + properties (DiscretizationProperties): Capabilities of a device that represent the + resolution with which the device can implement the parameters. + + Returns: + LocalDetuning: A new discretized LocalDetuning. + """ + local_detuning_parameters = properties.rydberg.rydbergLocal + time_resolution = ( + local_detuning_parameters.timeResolution if local_detuning_parameters else None + ) + discretized_magnitude = self.magnitude.discretize( + time_resolution=time_resolution, + ) + return LocalDetuning(discretized_magnitude) diff --git a/src/braket/ahs/pattern.py b/src/braket/ahs/pattern.py index 17e40a36f..92637fe0f 100644 --- a/src/braket/ahs/pattern.py +++ b/src/braket/ahs/pattern.py @@ -15,6 +15,7 @@ from decimal import Decimal from numbers import Number +from typing import Optional class Pattern: @@ -30,19 +31,25 @@ def __init__(self, series: list[Number]): @property def series(self) -> list[Number]: """list[Number]: A series of numbers representing the local - pattern of real numbers.""" + pattern of real numbers. + """ return self._series - def discretize(self, resolution: Decimal) -> Pattern: + def discretize(self, resolution: Optional[Decimal]) -> Pattern: """Creates a discretized version of the pattern, where each value is rounded to the closest multiple of the resolution. Args: - resolution (Decimal): Resolution of the discretization + resolution (Optional[Decimal]): Resolution of the discretization Returns: Pattern: The new discretized pattern """ - discretized_series = [round(Decimal(num) / resolution) * resolution for num in self.series] + if resolution is None: + discretized_series = [Decimal(num) for num in self.series] + else: + discretized_series = [ + round(Decimal(num) / resolution) * resolution for num in self.series + ] return Pattern(series=discretized_series) diff --git a/src/braket/ahs/shifting_field.py b/src/braket/ahs/shifting_field.py index 846d7ad2e..34e24191e 100644 --- a/src/braket/ahs/shifting_field.py +++ b/src/braket/ahs/shifting_field.py @@ -11,144 +11,10 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from __future__ import annotations +from braket.ahs.local_detuning import LocalDetuning -from braket.ahs.discretization_types import DiscretizationProperties -from braket.ahs.field import Field -from braket.ahs.hamiltonian import Hamiltonian -from braket.ahs.pattern import Pattern -from braket.timings.time_series import StitchBoundaryCondition, TimeSeries - - -class ShiftingField(Hamiltonian): - def __init__(self, magnitude: Field) -> None: - r"""Creates a Hamiltonian term :math:`H_{shift}` representing the shifting field - that changes the energy of the Rydberg level in an AnalogHamiltonianSimulation, - defined by the formula - - .. math:: - H_{shift} (t) := -\Delta(t) \sum_k h_k | r_k \rangle \langle r_k | - - where - - :math:`\Delta(t)` is the magnitude of the frequency shift in rad/s, - - :math:`h_k` is the site coefficient of atom :math:`k`, - a dimensionless real number between 0 and 1, - - :math:`|r_k \rangle` is the Rydberg state of atom :math:`k`. - - with the sum :math:`\sum_k` taken over all target atoms. - - Args: - magnitude (Field): containing the global magnitude time series :math:`\Delta(t)`, - where time is measured in seconds (s) and values are measured in rad/s, and the - local pattern :math:`h_k` of dimensionless real numbers between 0 and 1. - """ - super().__init__() - self._magnitude = magnitude - - @property - def terms(self) -> list[Hamiltonian]: - return [self] - - @property - def magnitude(self) -> Field: - r"""Field: containing the global magnitude time series :math:`\Delta(t)`, - where time is measured in seconds (s) and values measured in rad/s) - and the local pattern :math:`h_k` of dimensionless real numbers between 0 and 1.""" - return self._magnitude - - @staticmethod - def from_lists(times: list[float], values: list[float], pattern: list[float]) -> ShiftingField: - """Get the shifting field from a set of time points, values and pattern - - Args: - times (list[float]): The time points of the shifting field - values (list[float]): The values of the shifting field - pattern (list[float]): The pattern of the shifting field - - Returns: - ShiftingField: The shifting field obtained - """ - if len(times) != len(values): - raise ValueError("The length of the times and values lists must be equal.") - - magnitude = TimeSeries() - for t, v in zip(times, values): - magnitude.put(t, v) - shift = ShiftingField(Field(magnitude, Pattern(pattern))) - - return shift - - def stitch( - self, other: ShiftingField, boundary: StitchBoundaryCondition = StitchBoundaryCondition.MEAN - ) -> ShiftingField: - """Stitches two shifting fields based on TimeSeries.stitch method. - The time points of the second ShiftingField are shifted such that the first time point of - the second ShiftingField coincides with the last time point of the first ShiftingField. - The boundary point value is handled according to StitchBoundaryCondition argument value. - - Args: - other (ShiftingField): The second shifting field to be stitched with. - boundary (StitchBoundaryCondition): {"mean", "left", "right"}. Boundary point handler. - - Possible options are - - "mean" - take the average of the boundary value points of the first - and the second time series. - - "left" - use the last value from the left time series as the boundary point. - - "right" - use the first value from the right time series as the boundary point. - - Returns: - ShiftingField: The stitched ShiftingField object. - - Example (StitchBoundaryCondition.MEAN): - :: - time_series_1 = TimeSeries.from_lists(times=[0, 0.1], values=[1, 2]) - time_series_2 = TimeSeries.from_lists(times=[0.2, 0.4], values=[4, 5]) - - stitch_ts = time_series_1.stitch(time_series_2, boundary=StitchBoundaryCondition.MEAN) - - Result: - stitch_ts.times() = [0, 0.1, 0.3] - stitch_ts.values() = [1, 3, 5] - - Example (StitchBoundaryCondition.LEFT): - :: - stitch_ts = time_series_1.stitch(time_series_2, boundary=StitchBoundaryCondition.LEFT) - - Result: - stitch_ts.times() = [0, 0.1, 0.3] - stitch_ts.values() = [1, 2, 5] - - Example (StitchBoundaryCondition.RIGHT): - :: - stitch_ts = time_series_1.stitch(time_series_2, boundary=StitchBoundaryCondition.RIGHT) - - Result: - stitch_ts.times() = [0, 0.1, 0.3] - stitch_ts.values() = [1, 4, 5] - """ - if not (self.magnitude.pattern.series == other.magnitude.pattern.series): - raise ValueError("The ShiftingField pattern for both fields must be equal.") - - new_ts = self.magnitude.time_series.stitch(other.magnitude.time_series, boundary) - return ShiftingField(Field(new_ts, self.magnitude.pattern)) - - def discretize(self, properties: DiscretizationProperties) -> ShiftingField: - """Creates a discretized version of the ShiftingField. - - Args: - properties (DiscretizationProperties): Capabilities of a device that represent the - resolution with which the device can implement the parameters. - - Returns: - ShiftingField: A new discretized ShiftingField. - """ - shifting_parameters = properties.rydberg.rydbergLocal - discretized_magnitude = self.magnitude.discretize( - time_resolution=shifting_parameters.timeResolution, - value_resolution=shifting_parameters.commonDetuningResolution, - pattern_resolution=shifting_parameters.localDetuningResolution, - ) - return ShiftingField(discretized_magnitude) +# The class `ShiftingField` is deprecated. Please use `LocalDetuning` instead. +# This file and class will be removed in a future version. +# We are retaining this now to avoid breaking backwards compatibility for users already +# utilizing this nomenclature. +ShiftingField = LocalDetuning diff --git a/src/braket/annealing/problem.py b/src/braket/annealing/problem.py index d8b40c372..9515cbfa6 100644 --- a/src/braket/annealing/problem.py +++ b/src/braket/annealing/problem.py @@ -14,7 +14,6 @@ from __future__ import annotations from enum import Enum -from typing import Dict, Tuple import braket.ir.annealing as ir @@ -37,16 +36,16 @@ class Problem: def __init__( self, problem_type: ProblemType, - linear: Dict[int, float] | None = None, - quadratic: Dict[Tuple[int, int], float] | None = None, + linear: dict[int, float] | None = None, + quadratic: dict[tuple[int, int], float] | None = None, ): - """ + """Initializes a `Problem`. Args: problem_type (ProblemType): The type of annealing problem - linear (Dict[int, float] | None): The linear terms of this problem, + linear (dict[int, float] | None): The linear terms of this problem, as a map of variable to coefficient - quadratic (Dict[Tuple[int, int], float] | None): The quadratic terms of this problem, + quadratic (dict[tuple[int, int], float] | None): The quadratic terms of this problem, as a map of variables to coefficient Examples: @@ -71,20 +70,20 @@ def problem_type(self) -> ProblemType: return self._problem_type @property - def linear(self) -> Dict[int, float]: + def linear(self) -> dict[int, float]: """The linear terms of this problem. Returns: - Dict[int, float]: The linear terms of this problem, as a map of variable to coefficient + dict[int, float]: The linear terms of this problem, as a map of variable to coefficient """ return self._linear @property - def quadratic(self) -> Dict[Tuple[int, int], float]: + def quadratic(self) -> dict[tuple[int, int], float]: """The quadratic terms of this problem. Returns: - Dict[Tuple[int, int], float]: The quadratic terms of this problem, + dict[tuple[int, int], float]: The quadratic terms of this problem, as a map of variables to coefficient """ return self._quadratic @@ -102,11 +101,11 @@ def add_linear_term(self, term: int, coefficient: float) -> Problem: self._linear[term] = coefficient return self - def add_linear_terms(self, coefficients: Dict[int, float]) -> Problem: + def add_linear_terms(self, coefficients: dict[int, float]) -> Problem: """Adds linear terms to the problem. Args: - coefficients (Dict[int, float]): A map of variable to coefficient + coefficients (dict[int, float]): A map of variable to coefficient Returns: Problem: This problem object @@ -114,11 +113,11 @@ def add_linear_terms(self, coefficients: Dict[int, float]) -> Problem: self._linear.update(coefficients) return self - def add_quadratic_term(self, term: Tuple[int, int], coefficient: float) -> Problem: + def add_quadratic_term(self, term: tuple[int, int], coefficient: float) -> Problem: """Adds a quadratic term to the problem. Args: - term (Tuple[int, int]): The variables of the quadratic term + term (tuple[int, int]): The variables of the quadratic term coefficient (float): The coefficient of the quadratic term Returns: @@ -127,11 +126,11 @@ def add_quadratic_term(self, term: Tuple[int, int], coefficient: float) -> Probl self._quadratic[term] = coefficient return self - def add_quadratic_terms(self, coefficients: Dict[Tuple[int, int], float]) -> Problem: + def add_quadratic_terms(self, coefficients: dict[tuple[int, int], float]) -> Problem: """Adds quadratic terms to the problem. Args: - coefficients (Dict[Tuple[int, int], float]): A map of variables to coefficient + coefficients (dict[tuple[int, int], float]): A map of variables to coefficient Returns: Problem: This problem object diff --git a/src/braket/aws/__init__.py b/src/braket/aws/__init__.py index d0b3a3411..3be348f34 100644 --- a/src/braket/aws/__init__.py +++ b/src/braket/aws/__init__.py @@ -16,3 +16,4 @@ from braket.aws.aws_quantum_task import AwsQuantumTask # noqa: F401 from braket.aws.aws_quantum_task_batch import AwsQuantumTaskBatch # noqa: F401 from braket.aws.aws_session import AwsSession # noqa: F401 +from braket.aws.direct_reservations import DirectReservation # noqa: F401 diff --git a/src/braket/aws/aws_device.py b/src/braket/aws/aws_device.py index 14adacb3a..041098f5a 100644 --- a/src/braket/aws/aws_device.py +++ b/src/braket/aws/aws_device.py @@ -20,7 +20,7 @@ import warnings from datetime import datetime from enum import Enum -from typing import Optional, Union +from typing import Any, ClassVar, Optional, Union from botocore.errorfactory import ClientError from networkx import DiGraph, complete_graph, from_edgelist @@ -33,11 +33,12 @@ from braket.aws.queue_information import QueueDepthInfo, QueueType from braket.circuits import Circuit, Gate, QubitSet from braket.circuits.gate_calibrations import GateCalibrations +from braket.circuits.noise_model import NoiseModel from braket.device_schema import DeviceCapabilities, ExecutionDay, GateModelQpuParadigmProperties from braket.device_schema.dwave import DwaveProviderProperties -from braket.device_schema.pulse.pulse_device_action_properties_v1 import ( # noqa TODO: Remove device_action module once this is added to init in the schemas repo - PulseDeviceActionProperties, -) + +# TODO: Remove device_action module once this is added to init in the schemas repo +from braket.device_schema.pulse.pulse_device_action_properties_v1 import PulseDeviceActionProperties from braket.devices.device import Device from braket.ir.blackbird import Program as BlackbirdProgram from braket.ir.openqasm import Program as OpenQasmProgram @@ -56,13 +57,12 @@ class AwsDeviceType(str, Enum): class AwsDevice(Device): - """ - Amazon Braket implementation of a device. + """Amazon Braket implementation of a device. Use this class to retrieve the latest metadata about the device and to run a quantum task on the device. """ - REGIONS = ("us-east-1", "us-west-1", "us-west-2", "eu-west-2") + REGIONS = ("us-east-1", "us-west-1", "us-west-2", "eu-west-2", "eu-north-1") DEFAULT_SHOTS_QPU = 1000 DEFAULT_SHOTS_SIMULATOR = 0 @@ -70,7 +70,7 @@ class AwsDevice(Device): _GET_DEVICES_ORDER_BY_KEYS = frozenset({"arn", "name", "type", "provider_name", "status"}) - _RIGETTI_GATES_TO_BRAKET = { + _RIGETTI_GATES_TO_BRAKET: ClassVar[dict[str, str | None]] = { # Rx_12 does not exist in the Braket SDK, it is a gate between |1> and |2>. "Rx_12": None, "Cz": "CZ", @@ -78,11 +78,20 @@ class AwsDevice(Device): "Xy": "XY", } - def __init__(self, arn: str, aws_session: Optional[AwsSession] = None): - """ + def __init__( + self, + arn: str, + aws_session: Optional[AwsSession] = None, + noise_model: Optional[NoiseModel] = None, + ): + """Initializes an `AwsDevice`. + Args: arn (str): The ARN of the device aws_session (Optional[AwsSession]): An AWS session object. Default is `None`. + noise_model (Optional[NoiseModel]): The Braket noise model to apply to the circuit + before execution. Noise model can only be added to the devices that support + noise simulation. Note: Some devices (QPUs) are physically located in specific AWS Regions. In some cases, @@ -104,6 +113,9 @@ def __init__(self, arn: str, aws_session: Optional[AwsSession] = None): self._aws_session = self._get_session_and_initialize(aws_session or AwsSession()) self._ports = None self._frames = None + if noise_model: + self._validate_device_noise_model_support(noise_model) + self._noise_model = noise_model def run( self, @@ -122,15 +134,14 @@ def run( inputs: Optional[dict[str, float]] = None, gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] = None, reservation_arn: str | None = None, - *aws_quantum_task_args, - **aws_quantum_task_kwargs, + *aws_quantum_task_args: Any, + **aws_quantum_task_kwargs: Any, ) -> AwsQuantumTask: - """ - Run a quantum task specification on this device. A quantum task can be a circuit or an + """Run a quantum task specification on this device. A quantum task can be a circuit or an annealing problem. Args: - task_specification (Union[Circuit, Problem, OpenQasmProgram, BlackbirdProgram, PulseSequence, AnalogHamiltonianSimulation]): # noqa + task_specification (Union[Circuit, Problem, OpenQasmProgram, BlackbirdProgram, PulseSequence, AnalogHamiltonianSimulation]): Specification of quantum task (circuit, OpenQASM program or AHS program) to run on device. s3_destination_folder (Optional[S3DestinationFolder]): The S3 location to @@ -156,6 +167,8 @@ def run( Note: If you are creating tasks in a job that itself was created reservation ARN, those tasks do not need to be created with the reservation ARN. Default: None. + *aws_quantum_task_args (Any): Arbitrary arguments. + **aws_quantum_task_kwargs (Any): Arbitrary keyword arguments. Returns: AwsQuantumTask: An AwsQuantumTask that tracks the execution on the device. @@ -188,7 +201,9 @@ def run( See Also: `braket.aws.aws_quantum_task.AwsQuantumTask.create()` - """ + """ # noqa E501 + if self._noise_model: + task_specification = self._apply_noise_model_to_circuit(task_specification) return AwsQuantumTask.create( self._aws_session, self._arn, @@ -283,7 +298,12 @@ def run_batch( See Also: `braket.aws.aws_quantum_task_batch.AwsQuantumTaskBatch` - """ + """ # noqa E501 + if self._noise_model: + task_specifications = [ + self._apply_noise_model_to_circuit(task_specification) + for task_specification in task_specifications + ] return AwsQuantumTaskBatch( AwsSession.copy_session(self._aws_session, max_connections=max_connections), self._arn, @@ -308,9 +328,7 @@ def run_batch( ) def refresh_metadata(self) -> None: - """ - Refresh the `AwsDevice` object with the most recent Device metadata. - """ + """Refresh the `AwsDevice` object with the most recent Device metadata.""" self._populate_properties(self._aws_session) def _get_session_and_initialize(self, session: AwsSession) -> AwsSession: @@ -336,7 +354,7 @@ def _get_regional_device_session(self, session: AwsSession) -> AwsSession: ValueError(f"'{self._arn}' not found") if e.response["Error"]["Code"] == "ResourceNotFoundException" else e - ) + ) from e def _get_non_regional_device_session(self, session: AwsSession) -> AwsSession: current_region = session.region @@ -344,11 +362,10 @@ def _get_non_regional_device_session(self, session: AwsSession) -> AwsSession: self._populate_properties(session) return session except ClientError as e: - if e.response["Error"]["Code"] == "ResourceNotFoundException": - if "qpu" not in self._arn: - raise ValueError(f"Simulator '{self._arn}' not found in '{current_region}'") - else: + if e.response["Error"]["Code"] != "ResourceNotFoundException": raise e + if "qpu" not in self._arn: + raise ValueError(f"Simulator '{self._arn}' not found in '{current_region}'") from e # Search remaining regions for QPU for region in frozenset(AwsDevice.REGIONS) - {current_region}: region_session = AwsSession.copy_session(session, region) @@ -398,8 +415,7 @@ def arn(self) -> str: @property def gate_calibrations(self) -> Optional[GateCalibrations]: - """ - Calibration data for a QPU. Calibration data is shown for gates on particular gubits. + """Calibration data for a QPU. Calibration data is shown for gates on particular gubits. If a QPU does not expose these calibrations, None is returned. Returns: @@ -413,6 +429,7 @@ def gate_calibrations(self) -> Optional[GateCalibrations]: @property def is_available(self) -> bool: """Returns true if the device is currently available. + Returns: bool: Return if the device is currently available. """ @@ -474,7 +491,8 @@ def properties(self) -> DeviceCapabilities: Please see `braket.device_schema` in amazon-braket-schemas-python_ - .. _amazon-braket-schemas-python: https://github.com/aws/amazon-braket-schemas-python""" + .. _amazon-braket-schemas-python: https://github.com/aws/amazon-braket-schemas-python + """ return self._properties @property @@ -500,8 +518,7 @@ def topology_graph(self) -> DiGraph: return self._topology_graph def _construct_topology_graph(self) -> DiGraph: - """ - Construct topology graph. If no such metadata is available, return `None`. + """Construct topology graph. If no such metadata is available, return `None`. Returns: DiGraph: topology of QPU as a networkx `DiGraph` object. @@ -538,9 +555,9 @@ def _default_max_parallel(self) -> int: return AwsDevice.DEFAULT_MAX_PARALLEL def __repr__(self): - return "Device('name': {}, 'arn': {})".format(self.name, self.arn) + return f"Device('name': {self.name}, 'arn': {self.arn})" - def __eq__(self, other): + def __eq__(self, other: AwsDevice): if isinstance(other, AwsDevice): return self.arn == other.arn return NotImplemented @@ -548,16 +565,18 @@ def __eq__(self, other): @property def frames(self) -> dict[str, Frame]: """Returns a dict mapping frame ids to the frame objects for predefined frames - for this device.""" + for this device. + """ self._update_pulse_properties() - return self._frames or dict() + return self._frames or {} @property def ports(self) -> dict[str, Port]: """Returns a dict mapping port ids to the port objects for predefined ports - for this device.""" + for this device. + """ self._update_pulse_properties() - return self._ports or dict() + return self._ports or {} @staticmethod def get_devices( @@ -569,8 +588,7 @@ def get_devices( order_by: str = "name", aws_session: Optional[AwsSession] = None, ) -> list[AwsDevice]: - """ - Get devices based on filters and desired ordering. The result is the AND of + """Get devices based on filters and desired ordering. The result is the AND of all the filters `arns`, `names`, `types`, `statuses`, `provider_names`. Examples: @@ -593,18 +611,18 @@ def get_devices( aws_session (Optional[AwsSession]): An AWS session object. Default is `None`. + Raises: + ValueError: order_by not in ['arn', 'name', 'type', 'provider_name', 'status'] + Returns: list[AwsDevice]: list of AWS devices """ - if order_by not in AwsDevice._GET_DEVICES_ORDER_BY_KEYS: raise ValueError( f"order_by '{order_by}' must be in {AwsDevice._GET_DEVICES_ORDER_BY_KEYS}" ) - types = ( - frozenset(types) if types else frozenset({device_type for device_type in AwsDeviceType}) - ) - aws_session = aws_session if aws_session else AwsSession() + types = frozenset(types or AwsDeviceType) + aws_session = aws_session or AwsSession() device_map = {} session_region = aws_session.boto_session.region_name search_regions = ( @@ -631,19 +649,18 @@ def get_devices( provider_names=provider_names, ) ] - device_map.update( - { - arn: AwsDevice(arn, session_for_region) - for arn in region_device_arns - if arn not in device_map - } - ) + device_map |= { + arn: AwsDevice(arn, session_for_region) + for arn in region_device_arns + if arn not in device_map + } except ClientError as e: error_code = e.response["Error"]["Code"] warnings.warn( f"{error_code}: Unable to search region '{region}' for devices." " Please check your settings or try again later." - f" Continuing without devices in '{region}'." + f" Continuing without devices in '{region}'.", + stacklevel=1, ) devices = list(device_map.values()) @@ -651,33 +668,34 @@ def get_devices( return devices def _update_pulse_properties(self) -> None: - if hasattr(self.properties, "pulse") and isinstance( + if not hasattr(self.properties, "pulse") or not isinstance( self.properties.pulse, PulseDeviceActionProperties ): - if self._ports is None: - self._ports = dict() - port_data = self.properties.pulse.ports - for port_id, port in port_data.items(): - self._ports[port_id] = Port( - port_id=port_id, dt=port.dt, properties=json.loads(port.json()) + return + if self._ports is None: + self._ports = {} + port_data = self.properties.pulse.ports + for port_id, port in port_data.items(): + self._ports[port_id] = Port( + port_id=port_id, dt=port.dt, properties=json.loads(port.json()) + ) + if self._frames is None: + self._frames = {} + if frame_data := self.properties.pulse.frames: + for frame_id, frame in frame_data.items(): + self._frames[frame_id] = Frame( + frame_id=frame_id, + port=self._ports[frame.portId], + frequency=frame.frequency, + phase=frame.phase, + is_predefined=True, + properties=json.loads(frame.json()), ) - if self._frames is None: - self._frames = dict() - frame_data = self.properties.pulse.frames - if frame_data: - for frame_id, frame in frame_data.items(): - self._frames[frame_id] = Frame( - frame_id=frame_id, - port=self._ports[frame.portId], - frequency=frame.frequency, - phase=frame.phase, - is_predefined=True, - properties=json.loads(frame.json()), - ) @staticmethod def get_device_region(device_arn: str) -> str: """Gets the region from a device arn. + Args: device_arn (str): The device ARN. @@ -689,15 +707,14 @@ def get_device_region(device_arn: str) -> str: """ try: return device_arn.split(":")[3] - except IndexError: + except IndexError as e: raise ValueError( f"Device ARN is not a valid format: {device_arn}. For valid Braket ARNs, " "see 'https://docs.aws.amazon.com/braket/latest/developerguide/braket-devices.html'" - ) + ) from e def queue_depth(self) -> QueueDepthInfo: - """ - Task queue depth refers to the total number of quantum tasks currently waiting + """Task queue depth refers to the total number of quantum tasks currently waiting to run on a particular device. Returns: @@ -744,8 +761,7 @@ def queue_depth(self) -> QueueDepthInfo: return QueueDepthInfo(**queue_info) def refresh_gate_calibrations(self) -> Optional[GateCalibrations]: - """ - Refreshes the gate calibration data upon request. + """Refreshes the gate calibration data upon request. If the device does not have calibration data, None is returned. @@ -769,15 +785,15 @@ def refresh_gate_calibrations(self) -> Optional[GateCalibrations]: json.loads(f.read().decode("utf-8")) ) return GateCalibrations(json_calibration_data) - except urllib.error.URLError: + except urllib.error.URLError as e: raise urllib.error.URLError( f"Unable to reach {self.properties.pulse.nativeGateCalibrationsRef}" - ) + ) from e else: return None def _parse_waveforms(self, waveforms_json: dict) -> dict: - waveforms = dict() + waveforms = {} for waveform in waveforms_json: parsed_waveform = _parse_waveform_from_calibration_schema(waveforms_json[waveform]) waveforms[parsed_waveform.id] = parsed_waveform @@ -791,8 +807,7 @@ def _parse_pulse_sequence( def _parse_calibration_json( self, calibration_data: dict ) -> dict[tuple[Gate, QubitSet], PulseSequence]: - """ - Takes the json string from the device calibration URL and returns a structured dictionary of + """Takes the json string from the device calibration URL and returns a structured dictionary of corresponding `dict[tuple[Gate, QubitSet], PulseSequence]` to represent the calibration data. Args: @@ -801,7 +816,7 @@ def _parse_calibration_json( Returns: dict[tuple[Gate, QubitSet], PulseSequence]: The - structured data based on a mapping of `tuple[Gate, Qubit]` to its calibration repesented as a + structured data based on a mapping of `tuple[Gate, Qubit]` to its calibration represented as a `PulseSequence`. """ # noqa: E501 diff --git a/src/braket/aws/aws_quantum_job.py b/src/braket/aws/aws_quantum_job.py index 31347311e..37eb67603 100644 --- a/src/braket/aws/aws_quantum_job.py +++ b/src/braket/aws/aws_quantum_job.py @@ -20,7 +20,7 @@ from enum import Enum from logging import Logger, getLogger from pathlib import Path -from typing import Any +from typing import Any, ClassVar import boto3 from botocore.exceptions import ClientError @@ -49,12 +49,14 @@ class AwsQuantumJob(QuantumJob): """Amazon Braket implementation of a quantum job.""" - TERMINAL_STATES = {"CANCELLED", "COMPLETED", "FAILED"} + TERMINAL_STATES: ClassVar[set[str]] = {"CANCELLED", "COMPLETED", "FAILED"} RESULTS_FILENAME = "results.json" RESULTS_TAR_FILENAME = "model.tar.gz" LOG_GROUP = "/aws/braket/jobs" class LogState(Enum): + """Log state enum.""" + TAILING = "tailing" JOB_COMPLETE = "job_complete" COMPLETE = "complete" @@ -223,7 +225,8 @@ def create( return job def __init__(self, arn: str, aws_session: AwsSession | None = None, quiet: bool = False): - """ + """Initializes an `AwsQuantumJob`. + Args: arn (str): The ARN of the hybrid job. aws_session (AwsSession | None): The `AwsSession` for connecting to AWS services. @@ -231,6 +234,9 @@ def __init__(self, arn: str, aws_session: AwsSession | None = None, quiet: bool region of the hybrid job. quiet (bool): Sets the verbosity of the logger to low and does not report queue position. Default is `False`. + + Raises: + ValueError: Supplied region and session region do not match. """ self._arn: str = arn self._quiet = quiet @@ -246,8 +252,11 @@ def __init__(self, arn: str, aws_session: AwsSession | None = None, quiet: bool @staticmethod def _is_valid_aws_session_region_for_job_arn(aws_session: AwsSession, job_arn: str) -> bool: - """ - bool: `True` when the aws_session region matches the job_arn region; otherwise `False`. + """Checks whether the job region and session region match. + + Returns: + bool: `True` when the aws_session region matches the job_arn region; otherwise + `False`. """ job_region = job_arn.split(":")[3] return job_region == aws_session.region @@ -277,6 +286,17 @@ def name(self) -> str: """str: The name of the quantum job.""" return self.metadata(use_cached_value=True).get("jobName") + @property + def _logs_prefix(self) -> str: + """str: the prefix for the job logs.""" + # jobs ARNs used to contain the job name and use a log prefix of `job-name` + # now job ARNs use a UUID and a log prefix of `job-name/UUID` + return ( + f"{self.name}" + if self.arn.endswith(self.name) + else f"{self.name}/{self.arn.split('/')[-1]}" + ) + def state(self, use_cached_value: bool = False) -> str: """The state of the quantum hybrid job. @@ -285,6 +305,7 @@ def state(self, use_cached_value: bool = False) -> str: value from the Amazon Braket `GetJob` operation. If `False`, calls the `GetJob` operation to retrieve metadata, which also updates the cached value. Default = `False`. + Returns: str: The value of `status` in `metadata()`. This is the value of the `status` key in the Amazon Braket `GetJob` operation. @@ -295,8 +316,7 @@ def state(self, use_cached_value: bool = False) -> str: return self.metadata(use_cached_value).get("status") def queue_position(self) -> HybridJobQueueInfo: - """ - The queue position details for the hybrid job. + """The queue position details for the hybrid job. Returns: HybridJobQueueInfo: Instance of HybridJobQueueInfo class representing @@ -372,7 +392,6 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None: ) log_group = AwsQuantumJob.LOG_GROUP - stream_prefix = f"{self.name}/" stream_names = [] # The list of log streams positions = {} # The current position in each stream, map of stream name -> position instance_count = self.metadata(use_cached_value=True)["instanceConfig"]["instanceCount"] @@ -386,14 +405,14 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None: has_streams = logs.flush_log_streams( self._aws_session, log_group, - stream_prefix, + self._logs_prefix, stream_names, positions, instance_count, has_streams, color_wrap, [previous_state, current_state], - self.queue_position().queue_position if not self._quiet else None, + None if self._quiet else self.queue_position().queue_position, ) previous_state = current_state @@ -413,6 +432,7 @@ def metadata(self, use_cached_value: bool = False) -> dict[str, Any]: from the Amazon Braket `GetJob` operation, if it exists; if does not exist, `GetJob` is called to retrieve the metadata. If `False`, always calls `GetJob`, which also updates the cached value. Default: `False`. + Returns: dict[str, Any]: Dict that specifies the hybrid job metadata defined in Amazon Braket. """ @@ -441,18 +461,19 @@ def metrics( when there is a conflict. Default: MetricStatistic.MAX. Returns: - dict[str, list[Any]] : The metrics data. + dict[str, list[Any]]: The metrics data. """ fetcher = CwlInsightsMetricsFetcher(self._aws_session) metadata = self.metadata(True) - job_name = metadata["jobName"] job_start = None job_end = None if "startedAt" in metadata: job_start = int(metadata["startedAt"].timestamp()) if self.state() in AwsQuantumJob.TERMINAL_STATES and "endedAt" in metadata: job_end = int(math.ceil(metadata["endedAt"].timestamp())) - return fetcher.get_metrics_for_job(job_name, metric_type, statistic, job_start, job_end) + return fetcher.get_metrics_for_job( + self.name, metric_type, statistic, job_start, job_end, self._logs_prefix + ) def cancel(self) -> str: """Cancels the job. @@ -471,7 +492,7 @@ def result( poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL, ) -> dict[str, Any]: - """Retrieves the hybrid job result persisted using save_job_result() function. + """Retrieves the hybrid job result persisted using the `save_job_result` function. Args: poll_timeout_seconds (float): The polling timeout, in seconds, for `result()`. @@ -486,7 +507,6 @@ def result( RuntimeError: if hybrid job is in a FAILED or CANCELLED state. TimeoutError: if hybrid job execution exceeds the polling timeout period. """ - with tempfile.TemporaryDirectory() as temp_dir: job_name = self.metadata(True)["jobName"] @@ -526,7 +546,6 @@ def download_result( RuntimeError: if hybrid job is in a FAILED or CANCELLED state. TimeoutError: if hybrid job execution exceeds the polling timeout period. """ - extract_to = extract_to or Path.cwd() timeout_time = time.time() + poll_timeout_seconds @@ -556,17 +575,16 @@ def _attempt_results_download(self, output_bucket_uri: str, output_s3_path: str) s3_uri=output_bucket_uri, filename=AwsQuantumJob.RESULTS_TAR_FILENAME ) except ClientError as e: - if e.response["Error"]["Code"] == "404": - exception_response = { - "Error": { - "Code": "404", - "Message": f"Error retrieving results, " - f"could not find results at '{output_s3_path}'", - } - } - raise ClientError(exception_response, "HeadObject") from e - else: + if e.response["Error"]["Code"] != "404": raise e + exception_response = { + "Error": { + "Code": "404", + "Message": f"Error retrieving results, " + f"could not find results at '{output_s3_path}'", + } + } + raise ClientError(exception_response, "HeadObject") from e @staticmethod def _extract_tar_file(extract_path: str) -> None: @@ -576,10 +594,8 @@ def _extract_tar_file(extract_path: str) -> None: def __repr__(self) -> str: return f"AwsQuantumJob('arn':'{self.arn}')" - def __eq__(self, other) -> bool: - if isinstance(other, AwsQuantumJob): - return self.arn == other.arn - return False + def __eq__(self, other: AwsQuantumJob) -> bool: + return self.arn == other.arn if isinstance(other, AwsQuantumJob) else False def __hash__(self) -> int: return hash(self.arn) @@ -613,7 +629,7 @@ def _initialize_regional_device_session( ValueError(f"'{device}' not found.") if e.response["Error"]["Code"] == "ResourceNotFoundException" else e - ) + ) from e @staticmethod def _initialize_non_regional_device_session( @@ -624,12 +640,11 @@ def _initialize_non_regional_device_session( aws_session.get_device(device) return aws_session except ClientError as e: - if e.response["Error"]["Code"] == "ResourceNotFoundException": - if "qpu" not in device: - raise ValueError(f"Simulator '{device}' not found in '{original_region}'") - else: + if e.response["Error"]["Code"] != "ResourceNotFoundException": raise e + if "qpu" not in device: + raise ValueError(f"Simulator '{device}' not found in '{original_region}'") from e for region in frozenset(AwsDevice.REGIONS) - {original_region}: device_session = aws_session.copy_session(region=region) try: diff --git a/src/braket/aws/aws_quantum_task.py b/src/braket/aws/aws_quantum_task.py index 7b349310b..a21c7782c 100644 --- a/src/braket/aws/aws_quantum_task.py +++ b/src/braket/aws/aws_quantum_task.py @@ -17,7 +17,7 @@ import time from functools import singledispatch from logging import Logger, getLogger -from typing import Any, Optional, Union +from typing import Any, ClassVar, Optional, Union import boto3 @@ -34,6 +34,7 @@ IRType, OpenQASMSerializationProperties, QubitReferenceType, + SerializableProgram, ) from braket.device_schema import GateModelParameters from braket.device_schema.dwave import ( @@ -75,12 +76,13 @@ class AwsQuantumTask(QuantumTask): """Amazon Braket implementation of a quantum task. A quantum task can be a circuit, - an OpenQASM program or an AHS program.""" + an OpenQASM program or an AHS program. + """ # TODO: Add API documentation that defines these states. Make it clear this is the contract. - NO_RESULT_TERMINAL_STATES = {"FAILED", "CANCELLED"} - RESULTS_READY_STATES = {"COMPLETED"} - TERMINAL_STATES = RESULTS_READY_STATES.union(NO_RESULT_TERMINAL_STATES) + NO_RESULT_TERMINAL_STATES: ClassVar[set[str]] = {"FAILED", "CANCELLED"} + RESULTS_READY_STATES: ClassVar[set[str]] = {"COMPLETED"} + TERMINAL_STATES: ClassVar[set[str]] = RESULTS_READY_STATES.union(NO_RESULT_TERMINAL_STATES) DEFAULT_RESULTS_POLL_TIMEOUT = 432000 DEFAULT_RESULTS_POLL_INTERVAL = 1 @@ -104,7 +106,7 @@ def create( disable_qubit_rewiring: bool = False, tags: dict[str, str] | None = None, inputs: dict[str, float] | None = None, - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None = None, + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence] | None = None, quiet: bool = False, reservation_arn: str | None = None, *args, @@ -147,10 +149,9 @@ def create( IR. If the IR supports inputs, the inputs will be updated with this value. Default: {}. - gate_definitions (Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None): - A `Dict` for user defined gate calibration. The calibration is defined for - for a particular `Gate` on a particular `QubitSet` and is represented by - a `PulseSequence`. + gate_definitions (dict[tuple[Gate, QubitSet], PulseSequence] | None): A `dict` + of user defined gate calibrations. Each calibration is defined for a particular + `Gate` on a particular `QubitSet` and is represented by a `PulseSequence`. Default: None. quiet (bool): Sets the verbosity of the logger to low and does not report queue @@ -174,7 +175,7 @@ def create( See Also: `braket.aws.aws_quantum_simulator.AwsQuantumSimulator.run()` `braket.aws.aws_qpu.AwsQpu.run()` - """ + """ # noqa E501 if len(s3_destination_folder) != 2: raise ValueError( "s3_destination_folder must be of size 2 with a 'bucket' and 'key' respectively." @@ -189,6 +190,7 @@ def create( if tags is not None: create_task_kwargs.update({"tags": tags}) inputs = inputs or {} + gate_definitions = gate_definitions or {} if reservation_arn: create_task_kwargs.update( @@ -204,10 +206,9 @@ def create( if isinstance(task_specification, Circuit): param_names = {param.name for param in task_specification.parameters} - unbounded_parameters = param_names - set(inputs.keys()) - if unbounded_parameters: + if unbounded_parameters := param_names - set(inputs.keys()): raise ValueError( - f"Cannot execute circuit with unbound parameters: " f"{unbounded_parameters}" + f"Cannot execute circuit with unbound parameters: {unbounded_parameters}" ) return _create_internal( @@ -233,7 +234,8 @@ def __init__( logger: Logger = getLogger(__name__), quiet: bool = False, ): - """ + """Initializes an `AwsQuantumTask`. + Args: arn (str): The ARN of the quantum task. aws_session (AwsSession | None): The `AwsSession` for connecting to AWS services. @@ -258,7 +260,6 @@ def __init__( >>> result = task.result() GateModelQuantumTaskResult(...) """ - self._arn: str = arn self._aws_session: AwsSession = aws_session or AwsQuantumTask._aws_session_for_task_arn( task_arn=arn @@ -276,9 +277,8 @@ def __init__( @staticmethod def _aws_session_for_task_arn(task_arn: str) -> AwsSession: - """ - Get an AwsSession for the Quantum Task ARN. The AWS session should be in the region of the - quantum task. + """Get an AwsSession for the Quantum Task ARN. The AWS session should be in the region of + the quantum task. Returns: AwsSession: `AwsSession` object with default `boto_session` in quantum task's region. @@ -294,58 +294,57 @@ def id(self) -> str: def _cancel_future(self) -> None: """Cancel the future if it exists. Else, create a cancelled future.""" - if hasattr(self, "_future"): - self._future.cancel() - else: + if not hasattr(self, "_future"): self._future = asyncio.Future() - self._future.cancel() + self._future.cancel() def cancel(self) -> None: """Cancel the quantum task. This cancels the future and the quantum task in Amazon - Braket.""" + Braket. + """ self._cancel_future() self._aws_session.cancel_quantum_task(self._arn) def metadata(self, use_cached_value: bool = False) -> dict[str, Any]: - """ - Get quantum task metadata defined in Amazon Braket. + """Get quantum task metadata defined in Amazon Braket. Args: use_cached_value (bool): If `True`, uses the value most recently retrieved from the Amazon Braket `GetQuantumTask` operation, if it exists; if not, `GetQuantumTask` will be called to retrieve the metadata. If `False`, always calls `GetQuantumTask`, which also updates the cached value. Default: `False`. + Returns: dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation. If `use_cached_value` is `True`, Amazon Braket is not called and the most recently retrieved value is used, unless `GetQuantumTask` was never called, in which case - it wil still be called to populate the metadata for the first time. + it will still be called to populate the metadata for the first time. """ if not use_cached_value or not self._metadata: self._metadata = self._aws_session.get_quantum_task(self._arn) return self._metadata def state(self, use_cached_value: bool = False) -> str: - """ - The state of the quantum task. + """The state of the quantum task. Args: use_cached_value (bool): If `True`, uses the value most recently retrieved from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the `GetQuantumTask` operation to retrieve metadata, which also updates the cached value. Default = `False`. + Returns: str: The value of `status` in `metadata()`. This is the value of the `status` key in the Amazon Braket `GetQuantumTask` operation. If `use_cached_value` is `True`, the value most recently returned from the `GetQuantumTask` operation is used. + See Also: `metadata()` """ return self._status(use_cached_value) def queue_position(self) -> QuantumTaskQueueInfo: - """ - The queue position details for the quantum task. + """The queue position details for the quantum task. Returns: QuantumTaskQueueInfo: Instance of QuantumTaskQueueInfo class @@ -399,8 +398,7 @@ def result( ) -> Union[ GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult ]: - """ - Get the quantum task result by polling Amazon Braket to see if the task is completed. + """Get the quantum task result by polling Amazon Braket to see if the task is completed. Once the quantum task is completed, the result is retrieved from S3 and returned as a `GateModelQuantumTaskResult` or `AnnealingQuantumTaskResult` @@ -409,10 +407,10 @@ def result( Consecutive calls to this method return a cached result from the preceding request. Returns: - Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult]: # noqa - The result of the quantum task, if the quantum task completed successfully; returns + Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult]: The + result of the quantum task, if the quantum task completed successfully; returns `None` if the quantum task did not complete successfully or the future timed out. - """ + """ # noqa E501 if self._result or ( self._metadata and self._status(True) in self.NO_RESULT_TERMINAL_STATES ): @@ -445,15 +443,13 @@ def _get_future(self) -> asyncio.Future: return self._future def async_result(self) -> asyncio.Task: - """ - Get the quantum task result asynchronously. Consecutive calls to this method return + """Get the quantum task result asynchronously. Consecutive calls to this method return the result cached from the most recent request. """ return self._get_future() async def _create_future(self) -> asyncio.Task: - """ - Wrap the `_wait_for_completion` coroutine inside a future-like object. + """Wrap the `_wait_for_completion` coroutine inside a future-like object. Invoking this method starts the coroutine and returns back the future-like object that contains it. Note that this does not block on the coroutine to finish. @@ -467,19 +463,19 @@ async def _wait_for_completion( ) -> Union[ GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult ]: - """ - Waits for the quantum task to be completed, then returns the result from the S3 bucket. + """Waits for the quantum task to be completed, then returns the result from the S3 bucket. Returns: - Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]: If the task is in the + Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult]: If the task is in the `AwsQuantumTask.RESULTS_READY_STATES` state within the specified time limit, the result from the S3 bucket is loaded and returned. `None` is returned if a timeout occurs or task state is in `AwsQuantumTask.NO_RESULT_TERMINAL_STATES`. + Note: Timeout and sleep intervals are defined in the constructor fields `poll_timeout_seconds` and `poll_interval_seconds` respectively. - """ + """ # noqa E501 self._logger.debug(f"Task {self._arn}: start polling for completion") start_time = time.time() @@ -511,10 +507,10 @@ async def _wait_for_completion( return None def _has_reservation_arn_from_metadata(self, current_metadata: dict[str, Any]) -> bool: - for association in current_metadata.get("associations", []): - if association.get("type") == "RESERVATION_TIME_WINDOW_ARN": - return True - return False + return any( + association.get("type") == "RESERVATION_TIME_WINDOW_ARN" + for association in current_metadata.get("associations", []) + ) def _download_result( self, @@ -545,10 +541,8 @@ def _download_result( def __repr__(self) -> str: return f"AwsQuantumTask('id/taskArn':'{self.id}')" - def __eq__(self, other) -> bool: - if isinstance(other, AwsQuantumTask): - return self.id == other.id - return False + def __eq__(self, other: AwsQuantumTask) -> bool: + return self.id == other.id if isinstance(other, AwsQuantumTask) else False def __hash__(self) -> int: return hash(self.id) @@ -563,7 +557,7 @@ def _create_internal( device_parameters: Union[dict, BraketSchemaBase], disable_qubit_rewiring: bool, inputs: dict[str, float], - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], *args, **kwargs, ) -> AwsQuantumTask: @@ -579,11 +573,16 @@ def _( _device_parameters: Union[dict, BraketSchemaBase], # Not currently used for OpenQasmProgram _disable_qubit_rewiring: bool, inputs: dict[str, float], - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], *args, **kwargs, ) -> AwsQuantumTask: - create_task_kwargs.update({"action": OpenQASMProgram(source=pulse_sequence.to_ir()).json()}) + openqasm_program = OpenQASMProgram( + source=pulse_sequence.to_ir(), + inputs=inputs or {}, + ) + + create_task_kwargs["action"] = openqasm_program.json() task_arn = aws_session.create_quantum_task(**create_task_kwargs) return AwsQuantumTask(task_arn, aws_session, *args, **kwargs) @@ -597,7 +596,7 @@ def _( device_parameters: Union[dict, BraketSchemaBase], _disable_qubit_rewiring: bool, inputs: dict[str, float], - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], *args, **kwargs, ) -> AwsQuantumTask: @@ -608,7 +607,7 @@ def _( source=openqasm_program.source, inputs=inputs_copy, ) - create_task_kwargs.update({"action": openqasm_program.json()}) + create_task_kwargs["action"] = openqasm_program.json() if device_parameters: final_device_parameters = ( _circuit_device_params_from_dict( @@ -616,17 +615,43 @@ def _( device_arn, GateModelParameters(qubitCount=0), # qubitCount unused ) - if type(device_parameters) is dict + if isinstance(device_parameters, dict) else device_parameters ) - create_task_kwargs.update( - {"deviceParameters": final_device_parameters.json(exclude_none=True)} - ) + create_task_kwargs["deviceParameters"] = final_device_parameters.json(exclude_none=True) task_arn = aws_session.create_quantum_task(**create_task_kwargs) return AwsQuantumTask(task_arn, aws_session, *args, **kwargs) +@_create_internal.register +def _( + serializable_program: SerializableProgram, + aws_session: AwsSession, + create_task_kwargs: dict[str, Any], + device_arn: str, + device_parameters: Union[dict, BraketSchemaBase], + _disable_qubit_rewiring: bool, + inputs: dict[str, float], + gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + *args, + **kwargs, +) -> AwsQuantumTask: + openqasm_program = OpenQASMProgram(source=serializable_program.to_ir(ir_type=IRType.OPENQASM)) + return _create_internal( + openqasm_program, + aws_session, + create_task_kwargs, + device_arn, + device_parameters, + _disable_qubit_rewiring, + inputs, + gate_definitions, + *args, + **kwargs, + ) + + @_create_internal.register def _( blackbird_program: BlackbirdProgram, @@ -636,11 +661,11 @@ def _( _device_parameters: Union[dict, BraketSchemaBase], _disable_qubit_rewiring: bool, inputs: dict[str, float], - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], *args, **kwargs, ) -> AwsQuantumTask: - create_task_kwargs.update({"action": blackbird_program.json()}) + create_task_kwargs["action"] = blackbird_program.json() task_arn = aws_session.create_quantum_task(**create_task_kwargs) return AwsQuantumTask(task_arn, aws_session, *args, **kwargs) @@ -654,7 +679,7 @@ def _( device_parameters: Union[dict, BraketSchemaBase], disable_qubit_rewiring: bool, inputs: dict[str, float], - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], *args, **kwargs, ) -> AwsQuantumTask: @@ -666,7 +691,7 @@ def _( ) final_device_parameters = ( _circuit_device_params_from_dict(device_parameters or {}, device_arn, paradigm_parameters) - if type(device_parameters) is dict + if isinstance(device_parameters, dict) else device_parameters ) @@ -675,7 +700,7 @@ def _( if ( disable_qubit_rewiring or Instruction(StartVerbatimBox()) in circuit.instructions - or gate_definitions is not None + or gate_definitions or any(isinstance(instruction.operator, PulseGate) for instruction in circuit.instructions) ): qubit_reference_type = QubitReferenceType.PHYSICAL @@ -698,12 +723,10 @@ def _( inputs=inputs_copy, ) - create_task_kwargs.update( - { - "action": openqasm_program.json(), - "deviceParameters": final_device_parameters.json(exclude_none=True), - } - ) + create_task_kwargs |= { + "action": openqasm_program.json(), + "deviceParameters": final_device_parameters.json(exclude_none=True), + } task_arn = aws_session.create_quantum_task(**create_task_kwargs) return AwsQuantumTask(task_arn, aws_session, *args, **kwargs) @@ -720,19 +743,17 @@ def _( DwaveAdvantageDeviceParameters, Dwave2000QDeviceParameters, ], - _, + _: bool, inputs: dict[str, float], gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], *args, **kwargs, ) -> AwsQuantumTask: device_params = _create_annealing_device_params(device_parameters, device_arn) - create_task_kwargs.update( - { - "action": problem.to_ir().json(), - "deviceParameters": device_params.json(exclude_none=True), - } - ) + create_task_kwargs |= { + "action": problem.to_ir().json(), + "deviceParameters": device_params.json(exclude_none=True), + } task_arn = aws_session.create_quantum_task(**create_task_kwargs) return AwsQuantumTask(task_arn, aws_session, *args, **kwargs) @@ -745,13 +766,13 @@ def _( create_task_kwargs: dict[str, Any], device_arn: str, device_parameters: dict, - _, + _: AnalogHamiltonianSimulationTaskResult, inputs: dict[str, float], gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], *args, **kwargs, ) -> AwsQuantumTask: - create_task_kwargs.update({"action": analog_hamiltonian_simulation.to_ir().json()}) + create_task_kwargs["action"] = analog_hamiltonian_simulation.to_ir().json() task_arn = aws_session.create_quantum_task(**create_task_kwargs) return AwsQuantumTask(task_arn, aws_session, *args, **kwargs) @@ -788,7 +809,7 @@ def _create_annealing_device_params( Union[DwaveAdvantageDeviceParameters, Dwave2000QDeviceParameters]: The device parameters. """ - if type(device_params) is not dict: + if not isinstance(device_params, dict): device_params = device_params.dict() # check for device level or provider level parameters @@ -829,7 +850,7 @@ def _create_common_params( @singledispatch def _format_result( - result: Union[GateModelTaskResult, AnnealingTaskResult, PhotonicModelTaskResult] + result: Union[GateModelTaskResult, AnnealingTaskResult, PhotonicModelTaskResult], ) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult]: raise TypeError("Invalid result specification type") diff --git a/src/braket/aws/aws_quantum_task_batch.py b/src/braket/aws/aws_quantum_task_batch.py index 6c505e6ef..300963a6f 100644 --- a/src/braket/aws/aws_quantum_task_batch.py +++ b/src/braket/aws/aws_quantum_task_batch.py @@ -16,15 +16,18 @@ import time from concurrent.futures.thread import ThreadPoolExecutor from itertools import repeat -from typing import Union +from typing import Any, Union from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation from braket.annealing import Problem from braket.aws.aws_quantum_task import AwsQuantumTask from braket.aws.aws_session import AwsSession from braket.circuits import Circuit +from braket.circuits.gate import Gate from braket.ir.blackbird import Program as BlackbirdProgram from braket.ir.openqasm import Program as OpenQasmProgram +from braket.pulse.pulse_sequence import PulseSequence +from braket.registers.qubit_set import QubitSet from braket.tasks.quantum_task_batch import QuantumTaskBatch @@ -61,9 +64,16 @@ def __init__( poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, inputs: Union[dict[str, float], list[dict[str, float]]] | None = None, + gate_definitions: ( + Union[ + dict[tuple[Gate, QubitSet], PulseSequence], + list[dict[tuple[Gate, QubitSet], PulseSequence]], + ] + | None + ) = None, reservation_arn: str | None = None, - *aws_quantum_task_args, - **aws_quantum_task_kwargs, + *aws_quantum_task_args: Any, + **aws_quantum_task_kwargs: Any, ): """Creates a batch of quantum tasks. @@ -92,12 +102,17 @@ def __init__( inputs (Union[dict[str, float], list[dict[str, float]]] | None): Inputs to be passed along with the IR. If the IR supports inputs, the inputs will be updated with this value. Default: {}. + gate_definitions (Union[dict[tuple[Gate, QubitSet], PulseSequence], list[dict[tuple[Gate, QubitSet], PulseSequence]]] | None): # noqa: E501 + User-defined gate calibration. The calibration is defined for a particular `Gate` on a + particular `QubitSet` and is represented by a `PulseSequence`. Default: None. reservation_arn (str | None): The reservation ARN provided by Braket Direct to reserve exclusive usage for the device to run the quantum task on. Note: If you are creating tasks in a job that itself was created reservation ARN, those tasks do not need to be created with the reservation ARN. Default: None. - """ + *aws_quantum_task_args (Any): Arbitrary args for `QuantumTask`. + **aws_quantum_task_kwargs (Any): Arbitrary kwargs for `QuantumTask`., + """ # noqa E501 self._tasks = AwsQuantumTaskBatch._execute( aws_session, device_arn, @@ -109,6 +124,7 @@ def __init__( poll_timeout_seconds, poll_interval_seconds, inputs, + gate_definitions, reservation_arn, *aws_quantum_task_args, **aws_quantum_task_kwargs, @@ -132,7 +148,7 @@ def __init__( self._aws_quantum_task_kwargs = aws_quantum_task_kwargs @staticmethod - def _tasks_and_inputs( + def _tasks_inputs_gatedefs( task_specifications: Union[ Union[Circuit, Problem, OpenQasmProgram, BlackbirdProgram, AnalogHamiltonianSimulation], list[ @@ -142,58 +158,64 @@ def _tasks_and_inputs( ], ], inputs: Union[dict[str, float], list[dict[str, float]]] = None, + gate_definitions: Union[ + dict[tuple[Gate, QubitSet], PulseSequence], + list[dict[tuple[Gate, QubitSet], PulseSequence]], + ] = None, ) -> list[ tuple[ Union[Circuit, Problem, OpenQasmProgram, BlackbirdProgram, AnalogHamiltonianSimulation], dict[str, float], + dict[tuple[Gate, QubitSet], PulseSequence], ] ]: inputs = inputs or {} - - max_inputs_tasks = 1 - single_task = isinstance( - task_specifications, - (Circuit, Problem, OpenQasmProgram, BlackbirdProgram, AnalogHamiltonianSimulation), + gate_definitions = gate_definitions or {} + + single_task_type = ( + Circuit, + Problem, + OpenQasmProgram, + BlackbirdProgram, + AnalogHamiltonianSimulation, ) - single_input = isinstance(inputs, dict) - - max_inputs_tasks = ( - max(max_inputs_tasks, len(task_specifications)) if not single_task else max_inputs_tasks - ) - max_inputs_tasks = ( - max(max_inputs_tasks, len(inputs)) if not single_input else max_inputs_tasks - ) - - if not single_task and not single_input: - if len(task_specifications) != len(inputs): - raise ValueError( - "Multiple inputs and task specifications must " "be equal in number." - ) + single_input_type = dict + single_gate_definitions_type = dict - if single_task: - task_specifications = repeat(task_specifications, times=max_inputs_tasks) + args = [task_specifications, inputs, gate_definitions] + single_arg_types = [single_task_type, single_input_type, single_gate_definitions_type] - if single_input: - inputs = repeat(inputs, times=max_inputs_tasks) + batch_length = 1 + arg_lengths = [] + for arg, single_arg_type in zip(args, single_arg_types): + arg_length = 1 if isinstance(arg, single_arg_type) else len(arg) + arg_lengths.append(arg_length) - tasks_and_inputs = zip(task_specifications, inputs) + if arg_length != 1: + if batch_length != 1 and arg_length != batch_length: + raise ValueError( + "Multiple inputs, task specifications and gate definitions must " + "be equal in length." + ) + else: + batch_length = arg_length - if single_task and single_input: - tasks_and_inputs = list(tasks_and_inputs) + for i, arg_length in enumerate(arg_lengths): + if isinstance(args[i], (dict, single_task_type)): + args[i] = repeat(args[i], batch_length) - tasks_and_inputs = list(tasks_and_inputs) + tasks_inputs_definitions = list(zip(*args)) - for task_specification, input_map in tasks_and_inputs: + for task_specification, input_map, _gate_definitions in tasks_inputs_definitions: if isinstance(task_specification, Circuit): param_names = {param.name for param in task_specification.parameters} - unbounded_parameters = param_names - set(input_map.keys()) - if unbounded_parameters: + if unbounded_parameters := param_names - set(input_map.keys()): raise ValueError( f"Cannot execute circuit with unbound parameters: " f"{unbounded_parameters}" ) - return tasks_and_inputs + return tasks_inputs_definitions @staticmethod def _execute( @@ -214,13 +236,22 @@ def _execute( poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, inputs: Union[dict[str, float], list[dict[str, float]]] = None, + gate_definitions: ( + Union[ + dict[tuple[Gate, QubitSet], PulseSequence], + list[dict[tuple[Gate, QubitSet], PulseSequence]], + ] + | None + ) = None, reservation_arn: str | None = None, *args, **kwargs, ) -> list[AwsQuantumTask]: - tasks_and_inputs = AwsQuantumTaskBatch._tasks_and_inputs(task_specifications, inputs) + tasks_inputs_gatedefs = AwsQuantumTaskBatch._tasks_inputs_gatedefs( + task_specifications, inputs, gate_definitions + ) max_threads = min(max_parallel, max_workers) - remaining = [0 for _ in tasks_and_inputs] + remaining = [0 for _ in tasks_inputs_gatedefs] try: with ThreadPoolExecutor(max_workers=max_threads) as executor: task_futures = [ @@ -235,11 +266,12 @@ def _execute( poll_timeout_seconds=poll_timeout_seconds, poll_interval_seconds=poll_interval_seconds, inputs=input_map, + gate_definitions=gatedefs, reservation_arn=reservation_arn, *args, **kwargs, ) - for task, input_map in tasks_and_inputs + for task, input_map, gatedefs in tasks_inputs_gatedefs ] except KeyboardInterrupt: # If an exception is thrown before the thread pool has finished, @@ -267,6 +299,7 @@ def _create_task( shots: int, poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, inputs: dict[str, float] = None, + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence] | None = None, reservation_arn: str | None = None, *args, **kwargs, @@ -279,6 +312,7 @@ def _create_task( shots, poll_interval_seconds=poll_interval_seconds, inputs=inputs, + gate_definitions=gate_definitions, reservation_arn=reservation_arn, *args, **kwargs, @@ -288,9 +322,7 @@ def _create_task( # If the quantum task hits a terminal state before all quantum tasks have been created, # it can be returned immediately - while remaining: - if task.state() in AwsQuantumTask.TERMINAL_STATES: - break + while remaining and task.state() not in AwsQuantumTask.TERMINAL_STATES: time.sleep(poll_interval_seconds) return task @@ -328,7 +360,7 @@ def results( retries = 0 while self._unsuccessful and retries < max_retries: self.retry_unsuccessful_tasks() - retries = retries + 1 + retries += 1 if fail_unsuccessful and self._unsuccessful: raise RuntimeError( @@ -386,7 +418,8 @@ def retry_unsuccessful_tasks(self) -> bool: @property def tasks(self) -> list[AwsQuantumTask]: """list[AwsQuantumTask]: The quantum tasks in this batch, as a list of AwsQuantumTask - objects""" + objects + """ return list(self._tasks) @property @@ -396,7 +429,8 @@ def size(self) -> int: @property def unfinished(self) -> set[str]: - """Gets all the IDs of all the quantum tasks in teh batch that have yet to complete. + """Gets all the IDs of all the quantum tasks in the batch that have yet to complete. + Returns: set[str]: The IDs of all the quantum tasks in the batch that have yet to complete. """ @@ -414,5 +448,6 @@ def unfinished(self) -> set[str]: @property def unsuccessful(self) -> set[str]: """set[str]: The IDs of all the FAILED, CANCELLED, or timed out quantum tasks in the - batch.""" + batch. + """ return set(self._unsuccessful) diff --git a/src/braket/aws/aws_session.py b/src/braket/aws/aws_session.py index 0534871a2..b4cdfcd31 100644 --- a/src/braket/aws/aws_session.py +++ b/src/braket/aws/aws_session.py @@ -17,6 +17,7 @@ import os import os.path import re +import warnings from functools import cache from pathlib import Path from typing import Any, NamedTuple, Optional @@ -33,10 +34,14 @@ from braket.tracking.tracking_events import _TaskCreationEvent, _TaskStatusEvent -class AwsSession(object): +class AwsSession: """Manage interactions with AWS services.""" - S3DestinationFolder = NamedTuple("S3DestinationFolder", [("bucket", str), ("key", str)]) + class S3DestinationFolder(NamedTuple): + """A `NamedTuple` for an S3 bucket and object key.""" + + bucket: str + key: str def __init__( self, @@ -45,12 +50,16 @@ def __init__( config: Config | None = None, default_bucket: str | None = None, ): - """ + """Initializes an `AwsSession`. + Args: - boto_session (Session | None): A boto3 session object. + boto_session (boto3.Session | None): A boto3 session object. braket_client (client | None): A boto3 Braket client. config (Config | None): A botocore Config object. default_bucket (str | None): The name of the default bucket of the AWS Session. + + Raises: + ValueError: invalid boto_session or braket_client. """ if ( boto_session @@ -77,7 +86,6 @@ def __init__( self.braket_client = self.boto_session.client( "braket", config=self._config, endpoint_url=os.environ.get("BRAKET_ENDPOINT") ) - self._update_user_agent() self._custom_default_bucket = bool(default_bucket) self._default_bucket = default_bucket or os.environ.get("AMZN_BRAKET_OUT_S3_BUCKET") @@ -93,6 +101,7 @@ def __init__( self._sts = None self._logs = None self._ecr = None + self._account_id = None @property def region(self) -> str: @@ -100,7 +109,14 @@ def region(self) -> str: @property def account_id(self) -> str: - return self.sts_client.get_caller_identity()["Account"] + """Gets the caller's account number. + + Returns: + str: The account number of the caller. + """ + if not self._account_id: + self._account_id = self.sts_client.get_caller_identity()["Account"] + return self._account_id @property def iam_client(self) -> client: @@ -158,8 +174,7 @@ def ecr_client(self) -> client: return self._ecr def _update_user_agent(self) -> None: - """ - Updates the `User-Agent` header forwarded by boto3 to include the braket-sdk, + """Updates the `User-Agent` header forwarded by boto3 to include the braket-sdk, braket-schemas and the notebook instance version. The header is a string of space delimited values (For example: "Boto3/1.14.43 Python/3.7.9 Botocore/1.17.44"). """ @@ -176,8 +191,7 @@ def _notebook_instance_version() -> str: ) def add_braket_user_agent(self, user_agent: str) -> None: - """ - Appends the `user-agent` value to the User-Agent header, if it does not yet exist in the + """Appends the `user-agent` value to the User-Agent header, if it does not yet exist in the header. This method is typically only relevant for libraries integrating with the Amazon Braket SDK. @@ -204,8 +218,7 @@ def _add_cost_tracker_count_handler(request: awsrequest.AWSRequest, **kwargs) -> # Quantum Tasks # def cancel_quantum_task(self, arn: str) -> None: - """ - Cancel the quantum task. + """Cancel the quantum task. Args: arn (str): The ARN of the quantum task to cancel. @@ -214,20 +227,45 @@ def cancel_quantum_task(self, arn: str) -> None: broadcast_event(_TaskStatusEvent(arn=arn, status=response["cancellationStatus"])) def create_quantum_task(self, **boto3_kwargs) -> str: - """ - Create a quantum task. + """Create a quantum task. Args: - ``**boto3_kwargs``: Keyword arguments for the Amazon Braket `CreateQuantumTask` + **boto3_kwargs: Keyword arguments for the Amazon Braket `CreateQuantumTask` operation. Returns: str: The ARN of the quantum task. """ + # Add reservation arn if available and device is correct. + context_device_arn = os.getenv("AMZN_BRAKET_RESERVATION_DEVICE_ARN") + context_reservation_arn = os.getenv("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN") + + # if the task has a reservation_arn and also context does, raise a warning + # Raise warning if reservation ARN is found in both context and task parameters + task_has_reservation = any( + item.get("type") == "RESERVATION_TIME_WINDOW_ARN" + for item in boto3_kwargs.get("associations", []) + ) + if task_has_reservation and context_reservation_arn: + warnings.warn( + "A reservation ARN was passed to 'CreateQuantumTask', but it is being overridden " + "by a 'DirectReservation' context. If this was not intended, please review your " + "reservation ARN settings or the context in which 'CreateQuantumTask' is called." + ) + + # Ensure reservation only applies to specific device + if context_device_arn == boto3_kwargs["deviceArn"] and context_reservation_arn: + boto3_kwargs["associations"] = [ + { + "arn": context_reservation_arn, + "type": "RESERVATION_TIME_WINDOW_ARN", + } + ] + # Add job token to request, if available. job_token = os.getenv("AMZN_BRAKET_JOB_TOKEN") if job_token: - boto3_kwargs.update({"jobToken": job_token}) + boto3_kwargs["jobToken"] = job_token response = self.braket_client.create_quantum_task(**boto3_kwargs) broadcast_event( _TaskCreationEvent( @@ -240,11 +278,10 @@ def create_quantum_task(self, **boto3_kwargs) -> str: return response["quantumTaskArn"] def create_job(self, **boto3_kwargs) -> str: - """ - Create a quantum hybrid job. + """Create a quantum hybrid job. Args: - ``**boto3_kwargs``: Keyword arguments for the Amazon Braket `CreateJob` operation. + **boto3_kwargs: Keyword arguments for the Amazon Braket `CreateJob` operation. Returns: str: The ARN of the hybrid job. @@ -271,8 +308,7 @@ def _should_giveup(err: Exception) -> bool: giveup=_should_giveup.__func__, ) def get_quantum_task(self, arn: str) -> dict[str, Any]: - """ - Gets the quantum task. + """Gets the quantum task. Args: arn (str): The ARN of the quantum task to get. @@ -287,9 +323,8 @@ def get_quantum_task(self, arn: str) -> dict[str, Any]: return response def get_default_jobs_role(self) -> str: - """ - Returns the role ARN for the default hybrid jobs role created in the Amazon Braket Console. - It will pick the first role it finds with the `RoleName` prefix + """This returns the role ARN for the default hybrid jobs role created in the Amazon Braket + Console. It will pick the first role it finds with the `RoleName` prefix `AmazonBraketJobsExecutionRole` with a `PathPrefix` of `/service-role/`. Returns: @@ -298,7 +333,7 @@ def get_default_jobs_role(self) -> str: Raises: RuntimeError: If no roles can be found with the prefix - `/service-role/AmazonBraketJobsExecutionRole`. + `/service-role/AmazonBraketJobsExecutionRole`. """ roles_paginator = self.iam_client.get_paginator("list_roles") for page in roles_paginator.paginate(PathPrefix="/service-role/"): @@ -318,8 +353,7 @@ def get_default_jobs_role(self) -> str: giveup=_should_giveup.__func__, ) def get_job(self, arn: str) -> dict[str, Any]: - """ - Gets the hybrid job. + """Gets the hybrid job. Args: arn (str): The ARN of the hybrid job to get. @@ -330,8 +364,7 @@ def get_job(self, arn: str) -> dict[str, Any]: return self.braket_client.get_job(jobArn=arn, additionalAttributeNames=["QueueInfo"]) def cancel_job(self, arn: str) -> dict[str, Any]: - """ - Cancel the hybrid job. + """Cancel the hybrid job. Args: arn (str): The ARN of the hybrid job to cancel. @@ -342,8 +375,7 @@ def cancel_job(self, arn: str) -> dict[str, Any]: return self.braket_client.cancel_job(jobArn=arn) def retrieve_s3_object_body(self, s3_bucket: str, s3_object_key: str) -> str: - """ - Retrieve the S3 object body. + """Retrieve the S3 object body. Args: s3_bucket (str): The S3 bucket name. @@ -367,8 +399,7 @@ def upload_to_s3(self, filename: str, s3_uri: str) -> None: self.s3_client.upload_file(filename, bucket, key) def upload_local_data(self, local_prefix: str, s3_prefix: str) -> None: - """ - Upload local data matching a prefix to a corresponding location in S3 + """Upload local data matching a prefix to a corresponding location in S3 Args: local_prefix (str): a prefix designating files to be uploaded to S3. All files @@ -398,7 +429,7 @@ def upload_local_data(self, local_prefix: str, s3_prefix: str) -> None: relative_prefix = str(Path(local_prefix).relative_to(base_dir)) else: base_dir = Path() - relative_prefix = str(local_prefix) + relative_prefix = local_prefix for file in itertools.chain( # files that match the prefix base_dir.glob(f"{relative_prefix}*"), @@ -410,8 +441,7 @@ def upload_local_data(self, local_prefix: str, s3_prefix: str) -> None: self.upload_to_s3(str(file), s3_uri) def download_from_s3(self, s3_uri: str, filename: str) -> None: - """ - Download file from S3 + """Download file from S3 Args: s3_uri (str): The S3 uri from where the file will be downloaded. @@ -421,8 +451,7 @@ def download_from_s3(self, s3_uri: str, filename: str) -> None: self.s3_client.download_file(bucket, key, filename) def copy_s3_object(self, source_s3_uri: str, destination_s3_uri: str) -> None: - """ - Copy object from another location in s3. Does nothing if source and + """Copy object from another location in s3. Does nothing if source and destination URIs are the same. Args: @@ -445,8 +474,7 @@ def copy_s3_object(self, source_s3_uri: str, destination_s3_uri: str) -> None: ) def copy_s3_directory(self, source_s3_path: str, destination_s3_path: str) -> None: - """ - Copy all objects from a specified directory in S3. Does nothing if source and + """Copy all objects from a specified directory in S3. Does nothing if source and destination URIs are the same. Preserves nesting structure, will not overwrite other files in the destination location unless they share a name with a file being copied. @@ -475,8 +503,7 @@ def copy_s3_directory(self, source_s3_path: str, destination_s3_path: str) -> No ) def list_keys(self, bucket: str, prefix: str) -> list[str]: - """ - Lists keys matching prefix in bucket. + """Lists keys matching prefix in bucket. Args: bucket (str): Bucket to be queried. @@ -501,8 +528,7 @@ def list_keys(self, bucket: str, prefix: str) -> list[str]: return keys def default_bucket(self) -> str: - """ - Returns the name of the default bucket of the AWS Session. In the following order + """Returns the name of the default bucket of the AWS Session. In the following order of priority, it will return either the parameter `default_bucket` set during initialization of the AwsSession (if not None), the bucket being used by the currently running Braket Hybrid Job (if evoked inside of a Braket Hybrid Job), or a default @@ -580,7 +606,12 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name: str, region: str) error_code = e.response["Error"]["Code"] message = e.response["Error"]["Message"] - if error_code == "BucketAlreadyOwnedByYou": + if ( + error_code == "BucketAlreadyOwnedByYou" + or error_code != "BucketAlreadyExists" + and error_code == "OperationAborted" + and "conflicting conditional operation" in message + ): pass elif error_code == "BucketAlreadyExists": raise ValueError( @@ -588,18 +619,11 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name: str, region: str) f"for another account. Please supply alternative " f"bucket name via AwsSession constructor `AwsSession()`." ) from None - elif ( - error_code == "OperationAborted" and "conflicting conditional operation" in message - ): - # If this bucket is already being concurrently created, we don't need to create - # it again. - pass else: raise def get_device(self, arn: str) -> dict[str, Any]: - """ - Calls the Amazon Braket `get_device` API to retrieve device metadata. + """Calls the Amazon Braket `get_device` API to retrieve device metadata. Args: arn (str): The ARN of the device. @@ -617,8 +641,7 @@ def search_devices( statuses: Optional[list[str]] = None, provider_names: Optional[list[str]] = None, ) -> list[dict[str, Any]]: - """ - Get devices based on filters. The result is the AND of + """Get devices based on filters. The result is the AND of all the filters `arns`, `names`, `types`, `statuses`, `provider_names`. Args: @@ -657,6 +680,7 @@ def search_devices( @staticmethod def is_s3_uri(string: str) -> bool: """Determines if a given string is an S3 URI. + Args: string (str): the string to check. @@ -671,8 +695,7 @@ def is_s3_uri(string: str) -> bool: @staticmethod def parse_s3_uri(s3_uri: str) -> tuple[str, str]: - """ - Parse S3 URI to get bucket and key + """Parse S3 URI to get bucket and key Args: s3_uri (str): S3 URI. @@ -690,12 +713,12 @@ def parse_s3_uri(s3_uri: str) -> tuple[str, str]: s3_uri_match = re.match(r"^https://([^./]+)\.[sS]3\.[^/]+/(.+)$", s3_uri) or re.match( r"^[sS]3://([^./]+)/(.+)$", s3_uri ) - assert s3_uri_match + if s3_uri_match is None: + raise AssertionError bucket, key = s3_uri_match.groups() - assert bucket and key return bucket, key - except (AssertionError, ValueError): - raise ValueError(f"Not a valid S3 uri: {s3_uri}") + except (AssertionError, ValueError) as e: + raise ValueError(f"Not a valid S3 uri: {s3_uri}") from e @staticmethod def construct_s3_uri(bucket: str, *dirs: str) -> str: @@ -703,7 +726,7 @@ def construct_s3_uri(bucket: str, *dirs: str) -> str: Args: bucket (str): S3 URI. - ``*dirs`` (str): directories to be appended in the resulting S3 URI + *dirs (str): directories to be appended in the resulting S3 URI Returns: str: S3 URI @@ -723,8 +746,7 @@ def describe_log_streams( limit: Optional[int] = None, next_token: Optional[str] = None, ) -> dict[str, Any]: - """ - Describes CloudWatch log streams in a log group with a given prefix. + """Describes CloudWatch log streams in a log group with a given prefix. Args: log_group (str): Name of the log group. @@ -735,7 +757,7 @@ def describe_log_streams( Would have been received in a previous call. Returns: - dict[str, Any]: Dicionary containing logStreams and nextToken + dict[str, Any]: Dictionary containing logStreams and nextToken """ log_stream_args = { "logGroupName": log_group, @@ -744,10 +766,10 @@ def describe_log_streams( } if limit: - log_stream_args.update({"limit": limit}) + log_stream_args["limit"] = limit if next_token: - log_stream_args.update({"nextToken": next_token}) + log_stream_args["nextToken"] = next_token return self.logs_client.describe_log_streams(**log_stream_args) @@ -759,8 +781,7 @@ def get_log_events( start_from_head: bool = True, next_token: Optional[str] = None, ) -> dict[str, Any]: - """ - Gets CloudWatch log events from a given log stream. + """Gets CloudWatch log events from a given log stream. Args: log_group (str): Name of the log group. @@ -772,7 +793,7 @@ def get_log_events( Would have been received in a previous call. Returns: - dict[str, Any]: Dicionary containing events, nextForwardToken, and nextBackwardToken + dict[str, Any]: Dictionary containing events, nextForwardToken, and nextBackwardToken """ log_events_args = { "logGroupName": log_group, @@ -782,7 +803,7 @@ def get_log_events( } if next_token: - log_events_args.update({"nextToken": next_token}) + log_events_args["nextToken"] = next_token return self.logs_client.get_log_events(**log_events_args) @@ -791,8 +812,7 @@ def copy_session( region: Optional[str] = None, max_connections: Optional[int] = None, ) -> AwsSession: - """ - Creates a new AwsSession based on the region. + """Creates a new AwsSession based on the region. Args: region (Optional[str]): Name of the region. Default = `None`. @@ -805,6 +825,11 @@ def copy_session( config = Config(max_pool_connections=max_connections) if max_connections else None session_region = self.boto_session.region_name new_region = region or session_region + + # note that this method does not copy a custom Braket endpoint URL, since those are + # region-specific. If you have an endpoint that you wish to be used by copied AwsSessions + # (i.e. for task batching), please use the `BRAKET_ENDPOINT` environment variable. + creds = self.boto_session.get_credentials() default_bucket = self._default_bucket if self._custom_default_bucket else None profile_name = self.boto_session.profile_name @@ -833,8 +858,7 @@ def copy_session( @cache def get_full_image_tag(self, image_uri: str) -> str: - """ - Get verbose image tag from image uri. + """Get verbose image tag from image uri. Args: image_uri (str): Image uri to get tag for. diff --git a/src/braket/aws/direct_reservations.py b/src/braket/aws/direct_reservations.py new file mode 100644 index 000000000..4ffcb8fce --- /dev/null +++ b/src/braket/aws/direct_reservations.py @@ -0,0 +1,98 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from __future__ import annotations + +import os +import warnings +from contextlib import AbstractContextManager + +from braket.aws.aws_device import AwsDevice +from braket.devices import Device + + +class DirectReservation(AbstractContextManager): + """ + Context manager that modifies AwsQuantumTasks created within the context to use a reservation + ARN for all tasks targeting the specified device. Note: this context manager only allows for + one reservation at a time. + + Reservations are AWS account and device specific. Only the AWS account that created the + reservation can use your reservation ARN. Additionally, the reservation ARN is only valid on the + reserved device at the chosen start and end times. + + Args: + device (Device | str | None): The Braket device for which you have a reservation ARN, or + optionally the device ARN. + reservation_arn (str | None): The Braket Direct reservation ARN to be applied to all + quantum tasks run within the context. + + Examples: + As a context manager + >>> with DirectReservation(device_arn, reservation_arn=""): + ... task1 = device.run(circuit, shots) + ... task2 = device.run(circuit, shots) + + or start the reservation + >>> DirectReservation(device_arn, reservation_arn="").start() + ... task1 = device.run(circuit, shots) + ... task2 = device.run(circuit, shots) + + References: + + [1] https://docs.aws.amazon.com/braket/latest/developerguide/braket-reservations.html + """ + + _is_active = False # Class variable to track active reservation context + + def __init__(self, device: Device | str | None, reservation_arn: str | None): + if isinstance(device, AwsDevice): + self.device_arn = device.arn + elif isinstance(device, str): + self.device_arn = AwsDevice(device).arn # validate ARN early + elif isinstance(device, Device) or device is None: # LocalSimulator + warnings.warn( + "Using a local simulator with the reservation. For a reservation on a QPU, please " + "ensure the device matches the reserved Braket device." + ) + self.device_arn = "" # instead of None, use empty string + else: + raise TypeError("Device must be an AwsDevice or its ARN, or a local simulator device.") + + self.reservation_arn = reservation_arn + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.stop() + + def start(self) -> None: + """Start the reservation context.""" + if DirectReservation._is_active: + raise RuntimeError("Another reservation is already active.") + + os.environ["AMZN_BRAKET_RESERVATION_DEVICE_ARN"] = self.device_arn + if self.reservation_arn: + os.environ["AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN"] = self.reservation_arn + DirectReservation._is_active = True + + def stop(self) -> None: + """Stop the reservation context.""" + if not DirectReservation._is_active: + warnings.warn("Reservation context is not active.") + return + os.environ.pop("AMZN_BRAKET_RESERVATION_DEVICE_ARN", None) + os.environ.pop("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN", None) + DirectReservation._is_active = False diff --git a/src/braket/aws/queue_information.py b/src/braket/aws/queue_information.py index 109632751..77e5f3554 100644 --- a/src/braket/aws/queue_information.py +++ b/src/braket/aws/queue_information.py @@ -17,8 +17,7 @@ class QueueType(str, Enum): - """ - Enumerates the possible priorities for the queue. + """Enumerates the possible priorities for the queue. Values: NORMAL: Represents normal queue for the device. @@ -31,8 +30,7 @@ class QueueType(str, Enum): @dataclass() class QueueDepthInfo: - """ - Represents quantum tasks and hybrid jobs queue depth information. + """Represents quantum tasks and hybrid jobs queue depth information. Attributes: quantum_tasks (dict[QueueType, str]): number of quantum tasks waiting @@ -49,8 +47,7 @@ class QueueDepthInfo: @dataclass class QuantumTaskQueueInfo: - """ - Represents quantum tasks queue information. + """Represents quantum tasks queue information. Attributes: queue_type (QueueType): type of the quantum_task queue either 'Normal' @@ -68,8 +65,7 @@ class QuantumTaskQueueInfo: @dataclass class HybridJobQueueInfo: - """ - Represents hybrid job queue information. + """Represents hybrid job queue information. Attributes: queue_position (Optional[str]): current position of your hybrid job within a respective diff --git a/src/braket/circuits/__init__.py b/src/braket/circuits/__init__.py index d2788746c..a5fb52980 100644 --- a/src/braket/circuits/__init__.py +++ b/src/braket/circuits/__init__.py @@ -20,7 +20,6 @@ result_types, ) from braket.circuits.angled_gate import AngledGate, DoubleAngledGate # noqa: F401 -from braket.circuits.ascii_circuit_diagram import AsciiCircuitDiagram # noqa: F401 from braket.circuits.circuit import Circuit # noqa: F401 from braket.circuits.circuit_diagram import CircuitDiagram # noqa: F401 from braket.circuits.compiler_directive import CompilerDirective # noqa: F401 @@ -38,3 +37,9 @@ from braket.circuits.qubit import Qubit, QubitInput # noqa: F401 from braket.circuits.qubit_set import QubitSet, QubitSetInput # noqa: F401 from braket.circuits.result_type import ObservableResultType, ResultType # noqa: F401 +from braket.circuits.text_diagram_builders.ascii_circuit_diagram import ( # noqa: F401 + AsciiCircuitDiagram, +) +from braket.circuits.text_diagram_builders.unicode_circuit_diagram import ( # noqa: F401 + UnicodeCircuitDiagram, +) diff --git a/src/braket/circuits/angled_gate.py b/src/braket/circuits/angled_gate.py index e453177a7..49447e58f 100644 --- a/src/braket/circuits/angled_gate.py +++ b/src/braket/circuits/angled_gate.py @@ -27,9 +27,7 @@ class AngledGate(Gate, Parameterizable): - """ - Class `AngledGate` represents a quantum gate that operates on N qubits and an angle. - """ + """Class `AngledGate` represents a quantum gate that operates on N qubits and an angle.""" def __init__( self, @@ -37,7 +35,8 @@ def __init__( qubit_count: Optional[int], ascii_symbols: Sequence[str], ): - """ + """Initializes an `AngledGate`. + Args: angle (Union[FreeParameterExpression, float]): The angle of the gate in radians or expression representation. @@ -63,8 +62,7 @@ def __init__( @property def parameters(self) -> list[Union[FreeParameterExpression, float]]: - """ - Returns the parameters associated with the object, either unbound free parameters or + """Returns the parameters associated with the object, either unbound free parameters or bound values. Returns: @@ -75,8 +73,7 @@ def parameters(self) -> list[Union[FreeParameterExpression, float]]: @property def angle(self) -> Union[FreeParameterExpression, float]: - """ - Returns the angle for the gate + """Returns the angle of the gate Returns: Union[FreeParameterExpression, float]: The angle of the gate in radians @@ -110,7 +107,7 @@ def adjoint(self) -> list[Gate]: new._ascii_symbols = new_ascii_symbols return [new] - def __eq__(self, other): + def __eq__(self, other: AngledGate): return ( isinstance(other, AngledGate) and self.name == other.name @@ -125,8 +122,8 @@ def __hash__(self): class DoubleAngledGate(Gate, Parameterizable): - """ - Class `DoubleAngledGate` represents a quantum gate that operates on N qubits and two angles. + """Class `DoubleAngledGate` represents a quantum gate that operates on N qubits and + two angles. """ def __init__( @@ -136,7 +133,8 @@ def __init__( qubit_count: Optional[int], ascii_symbols: Sequence[str], ): - """ + """Inits a `DoubleAngledGate`. + Args: angle_1 (Union[FreeParameterExpression, float]): The first angle of the gate in radians or expression representation. @@ -168,8 +166,7 @@ def __init__( @property def parameters(self) -> list[Union[FreeParameterExpression, float]]: - """ - Returns the parameters associated with the object, either unbound free parameters or + """Returns the parameters associated with the object, either unbound free parameters or bound values. Returns: @@ -180,8 +177,7 @@ def parameters(self) -> list[Union[FreeParameterExpression, float]]: @property def angle_1(self) -> Union[FreeParameterExpression, float]: - """ - Returns the first angle for the gate + """Returns the first angle of the gate Returns: Union[FreeParameterExpression, float]: The first angle of the gate in radians @@ -190,20 +186,18 @@ def angle_1(self) -> Union[FreeParameterExpression, float]: @property def angle_2(self) -> Union[FreeParameterExpression, float]: - """ - Returns the second angle for the gate + """Returns the second angle of the gate Returns: Union[FreeParameterExpression, float]: The second angle of the gate in radians """ return self._parameters[1] - def bind_values(self, **kwargs) -> AngledGate: - """ - Takes in parameters and attempts to assign them to values. + def bind_values(self, **kwargs: FreeParameterExpression | str) -> AngledGate: + """Takes in parameters and attempts to assign them to values. Args: - ``**kwargs``: The parameters that are being assigned. + **kwargs (FreeParameterExpression | str): The parameters that are being assigned. Returns: AngledGate: A new Gate of the same type with the requested parameters bound. @@ -221,7 +215,7 @@ def adjoint(self) -> list[Gate]: """ raise NotImplementedError - def __eq__(self, other): + def __eq__(self, other: DoubleAngledGate): return ( isinstance(other, DoubleAngledGate) and self.name == other.name @@ -240,8 +234,8 @@ def __hash__(self): class TripleAngledGate(Gate, Parameterizable): - """ - Class `TripleAngledGate` represents a quantum gate that operates on N qubits and three angles. + """Class `TripleAngledGate` represents a quantum gate that operates on N qubits and + three angles. """ def __init__( @@ -252,7 +246,8 @@ def __init__( qubit_count: Optional[int], ascii_symbols: Sequence[str], ): - """ + """Inits a `TripleAngledGate`. + Args: angle_1 (Union[FreeParameterExpression, float]): The first angle of the gate in radians or expression representation. @@ -287,8 +282,7 @@ def __init__( @property def parameters(self) -> list[Union[FreeParameterExpression, float]]: - """ - Returns the parameters associated with the object, either unbound free parameters or + """Returns the parameters associated with the object, either unbound free parameters or bound values. Returns: @@ -299,8 +293,7 @@ def parameters(self) -> list[Union[FreeParameterExpression, float]]: @property def angle_1(self) -> Union[FreeParameterExpression, float]: - """ - Returns the first angle for the gate + """Returns the first angle of the gate Returns: Union[FreeParameterExpression, float]: The first angle of the gate in radians @@ -309,8 +302,7 @@ def angle_1(self) -> Union[FreeParameterExpression, float]: @property def angle_2(self) -> Union[FreeParameterExpression, float]: - """ - Returns the second angle for the gate + """Returns the second angle of the gate Returns: Union[FreeParameterExpression, float]: The second angle of the gate in radians @@ -319,20 +311,18 @@ def angle_2(self) -> Union[FreeParameterExpression, float]: @property def angle_3(self) -> Union[FreeParameterExpression, float]: - """ - Returns the second angle for the gate + """Returns the third angle of the gate Returns: Union[FreeParameterExpression, float]: The third angle of the gate in radians """ return self._parameters[2] - def bind_values(self, **kwargs) -> AngledGate: - """ - Takes in parameters and attempts to assign them to values. + def bind_values(self, **kwargs: FreeParameterExpression | str) -> AngledGate: + """Takes in parameters and attempts to assign them to values. Args: - ``**kwargs``: The parameters that are being assigned. + **kwargs (FreeParameterExpression | str): The parameters that are being assigned. Returns: AngledGate: A new Gate of the same type with the requested parameters bound. @@ -350,7 +340,7 @@ def adjoint(self) -> list[Gate]: """ raise NotImplementedError - def __eq__(self, other): + def __eq__(self, other: TripleAngledGate): return ( isinstance(other, TripleAngledGate) and self.name == other.name @@ -382,8 +372,7 @@ def _(angle_1: FreeParameterExpression, angle_2: FreeParameterExpression): def angled_ascii_characters(gate: str, angle: Union[FreeParameterExpression, float]) -> str: - """ - Generates a formatted ascii representation of an angled gate. + """Generates a formatted ascii representation of an angled gate. Args: gate (str): The name of the gate. @@ -400,24 +389,22 @@ def _multi_angled_ascii_characters( gate: str, *angles: Union[FreeParameterExpression, float], ) -> str: - """ - Generates a formatted ascii representation of an angled gate. + """Generates a formatted ascii representation of an angled gate. Args: gate (str): The name of the gate. - `*angles` (Union[FreeParameterExpression, float]): angles in radians. + *angles (Union[FreeParameterExpression, float]): angles in radians. Returns: str: Returns the ascii representation for an angled gate. """ - def format_string(angle: Union[FreeParameterExpression, float]) -> str: - """ - Formats an angle for ASCII representation. + def format_string(angle: FreeParameterExpression | float) -> str: + """Formats an angle for ASCII representation. Args: - angle (Union[FreeParameterExpression, float]): The angle to format. + angle (FreeParameterExpression | float): The angle to format. Returns: str: The ASCII representation of the angle. @@ -427,13 +414,13 @@ def format_string(angle: Union[FreeParameterExpression, float]) -> str: return f"{gate}({', '.join(f'{angle:{format_string(angle)}}' for angle in angles)})" -def get_angle(gate: AngledGate, **kwargs) -> AngledGate: - """ - Gets the angle with all values substituted in that are requested. +def get_angle(gate: AngledGate, **kwargs: FreeParameterExpression | str) -> AngledGate: + """Gets the angle with all values substituted in that are requested. Args: gate (AngledGate): The subclass of AngledGate for which the angle is being obtained. - ``**kwargs``: The named parameters that are being filled for a particular gate. + **kwargs (FreeParameterExpression | str): The named parameters that are being filled + for a particular gate. Returns: AngledGate: A new gate of the type of the AngledGate originally used with all @@ -445,25 +432,28 @@ def get_angle(gate: AngledGate, **kwargs) -> AngledGate: return type(gate)(angle=new_angle) -def _get_angles(gate: TripleAngledGate, **kwargs) -> TripleAngledGate: - """ - Gets the angle with all values substituted in that are requested. +def _get_angles( + gate: DoubleAngledGate | TripleAngledGate, **kwargs: FreeParameterExpression | str +) -> DoubleAngledGate | TripleAngledGate: + """Gets the angle with all values substituted in that are requested. Args: - gate (TripleAngledGate): The subclass of TripleAngledGate for which the angle is being - obtained. - ``**kwargs``: The named parameters that are being filled for a particular gate. + gate (DoubleAngledGate | TripleAngledGate): The subclass of multi angle AngledGate for + which the angle is being obtained. + **kwargs (FreeParameterExpression | str): The named parameters that are being filled + for a particular gate. Returns: - TripleAngledGate: A new gate of the type of the AngledGate originally used with all angles - updated. + DoubleAngledGate | TripleAngledGate: A new gate of the type of the AngledGate + originally used with all angles updated. """ - new_angles = [ - ( + angles = [f"angle_{i + 1}" for i in range(len(gate._parameters))] + new_angles_args = { + angle: ( getattr(gate, angle).subs(kwargs) if isinstance(getattr(gate, angle), FreeParameterExpression) else getattr(gate, angle) ) - for angle in ("angle_1", "angle_2", "angle_3") - ] - return type(gate)(angle_1=new_angles[0], angle_2=new_angles[1], angle_3=new_angles[2]) + for angle in angles + } + return type(gate)(**new_angles_args) diff --git a/src/braket/circuits/ascii_circuit_diagram.py b/src/braket/circuits/ascii_circuit_diagram.py index 2c7024574..accbce161 100644 --- a/src/braket/circuits/ascii_circuit_diagram.py +++ b/src/braket/circuits/ascii_circuit_diagram.py @@ -11,411 +11,8 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from __future__ import annotations - -from functools import reduce -from typing import Union - -import braket.circuits.circuit as cir -from braket.circuits.circuit_diagram import CircuitDiagram -from braket.circuits.compiler_directive import CompilerDirective -from braket.circuits.gate import Gate -from braket.circuits.instruction import Instruction -from braket.circuits.moments import MomentType -from braket.circuits.noise import Noise -from braket.circuits.result_type import ResultType -from braket.registers.qubit import Qubit -from braket.registers.qubit_set import QubitSet - - -class AsciiCircuitDiagram(CircuitDiagram): - """Builds ASCII string circuit diagrams.""" - - @staticmethod - def build_diagram(circuit: cir.Circuit) -> str: - """ - Build an ASCII string circuit diagram. - - Args: - circuit (Circuit): Circuit for which to build a diagram. - - Returns: - str: ASCII string circuit diagram. - """ - - if not circuit.instructions: - return "" - - if all(m.moment_type == MomentType.GLOBAL_PHASE for m in circuit._moments): - return f"Global phase: {circuit.global_phase}" - - circuit_qubits = circuit.qubits - circuit_qubits.sort() - - y_axis_str, global_phase = AsciiCircuitDiagram._prepare_diagram_vars( - circuit, circuit_qubits - ) - - time_slices = circuit.moments.time_slices() - column_strs = [] - - # Moment columns - for time, instructions in time_slices.items(): - global_phase = AsciiCircuitDiagram._compute_moment_global_phase( - global_phase, instructions - ) - moment_str = AsciiCircuitDiagram._ascii_diagram_column_set( - str(time), circuit_qubits, instructions, global_phase - ) - column_strs.append(moment_str) - - # Result type columns - additional_result_types, target_result_types = AsciiCircuitDiagram._categorize_result_types( - circuit.result_types - ) - if target_result_types: - column_strs.append( - AsciiCircuitDiagram._ascii_diagram_column_set( - "Result Types", circuit_qubits, target_result_types, global_phase - ) - ) - - # Unite strings - lines = y_axis_str.split("\n") - for col_str in column_strs: - for i, line_in_col in enumerate(col_str.split("\n")): - lines[i] += line_in_col - - # Time on top and bottom - lines.append(lines[0]) - - if global_phase: - lines.append(f"\nGlobal phase: {global_phase}") - - # Additional result types line on bottom - if additional_result_types: - lines.append(f"\nAdditional result types: {', '.join(additional_result_types)}") - - # A list of parameters in the circuit to the currently assigned values. - if circuit.parameters: - lines.append( - "\nUnassigned parameters: " - f"{sorted(circuit.parameters, key=lambda param: param.name)}." - ) - - return "\n".join(lines) - - @staticmethod - def _prepare_diagram_vars( - circuit: cir.Circuit, circuit_qubits: QubitSet - ) -> tuple[str, float | None]: - # Y Axis Column - y_axis_width = len(str(int(max(circuit_qubits)))) - y_axis_str = "{0:{width}} : |\n".format("T", width=y_axis_width + 1) - - global_phase = None - if any(m.moment_type == MomentType.GLOBAL_PHASE for m in circuit._moments): - y_axis_str += "{0:{width}} : |\n".format("GP", width=y_axis_width) - global_phase = 0 - - for qubit in circuit_qubits: - y_axis_str += "{0:{width}}\n".format(" ", width=y_axis_width + 5) - y_axis_str += "q{0:{width}} : -\n".format(str(int(qubit)), width=y_axis_width) - - return y_axis_str, global_phase - - @staticmethod - def _compute_moment_global_phase( - global_phase: float | None, items: list[Instruction] - ) -> float | None: - """ - Compute the integrated phase at a certain moment. - - Args: - global_phase (float | None): The integrated phase up to the computed moment - items (list[Instruction]): list of instructions - - Returns: - float | None: The updated integrated phase. - """ - moment_phase = 0 - for item in items: - if ( - isinstance(item, Instruction) - and isinstance(item.operator, Gate) - and item.operator.name == "GPhase" - ): - moment_phase += item.operator.angle - return global_phase + moment_phase if global_phase is not None else None - - @staticmethod - def _ascii_group_items( - circuit_qubits: QubitSet, - items: list[Union[Instruction, ResultType]], - ) -> list[tuple[QubitSet, list[Instruction]]]: - """ - Group instructions in a moment for ASCII diagram - - Args: - circuit_qubits (QubitSet): set of qubits in circuit - items (list[Union[Instruction, ResultType]]): list of instructions or result types - - Returns: - list[tuple[QubitSet, list[Instruction]]]: list of grouped instructions or result types. - """ - groupings = [] - for item in items: - # Can only print Gate and Noise operators for instructions at the moment - if isinstance(item, Instruction) and not isinstance( - item.operator, (Gate, Noise, CompilerDirective) - ): - continue - - # As a zero-qubit gate, GPhase can be grouped with anything. We set qubit_range - # to an empty list and we just add it to the first group below. - if ( - isinstance(item, Instruction) - and isinstance(item.operator, Gate) - and item.operator.name == "GPhase" - ): - qubit_range = QubitSet() - elif (isinstance(item, ResultType) and not item.target) or ( - isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective) - ): - qubit_range = circuit_qubits - else: - if isinstance(item.target, list): - target = reduce(QubitSet.union, map(QubitSet, item.target), QubitSet()) - else: - target = item.target - control = getattr(item, "control", QubitSet()) - target_and_control = target.union(control) - qubit_range = QubitSet(range(min(target_and_control), max(target_and_control) + 1)) - - found_grouping = False - for group in groupings: - qubits_added = group[0] - instr_group = group[1] - # Take into account overlapping multi-qubit gates - if not qubits_added.intersection(set(qubit_range)): - instr_group.append(item) - qubits_added.update(qubit_range) - found_grouping = True - break - - if not found_grouping: - groupings.append((qubit_range, [item])) - - return groupings - - @staticmethod - def _categorize_result_types( - result_types: list[ResultType], - ) -> tuple[list[str], list[ResultType]]: - """ - Categorize result types into result types with target and those without. - - Args: - result_types (list[ResultType]): list of result types - - Returns: - tuple[list[str], list[ResultType]]: first element is a list of result types - without `target` attribute; second element is a list of result types with - `target` attribute - """ - additional_result_types = [] - target_result_types = [] - for result_type in result_types: - if hasattr(result_type, "target"): - target_result_types.append(result_type) - else: - additional_result_types.extend(result_type.ascii_symbols) - return additional_result_types, target_result_types - - @staticmethod - def _ascii_diagram_column_set( - col_title: str, - circuit_qubits: QubitSet, - items: list[Union[Instruction, ResultType]], - global_phase: float | None, - ) -> str: - """ - Return a set of columns in the ASCII string diagram of the circuit for a list of items. - - Args: - col_title (str): title of column set - circuit_qubits (QubitSet): qubits in circuit - items (list[Union[Instruction, ResultType]]): list of instructions or result types - global_phase (float | None): the integrated global phase up to this set - - Returns: - str: An ASCII string diagram for the column set. - """ - - # Group items to separate out overlapping multi-qubit items - groupings = AsciiCircuitDiagram._ascii_group_items(circuit_qubits, items) - - column_strs = [ - AsciiCircuitDiagram._ascii_diagram_column(circuit_qubits, grouping[1], global_phase) - for grouping in groupings - ] - - # Unite column strings - lines = column_strs[0].split("\n") - for column_str in column_strs[1:]: - for i, moment_line in enumerate(column_str.split("\n")): - lines[i] += moment_line - - # Adjust for column title width - col_title_width = len(col_title) - symbols_width = len(lines[0]) - 1 - if symbols_width < col_title_width: - diff = col_title_width - symbols_width - for i in range(len(lines) - 1): - if lines[i].endswith("-"): - lines[i] += "-" * diff - else: - lines[i] += " " - - first_line = "{:^{width}}|\n".format(col_title, width=len(lines[0]) - 1) - - return first_line + "\n".join(lines) - - @staticmethod - def _ascii_diagram_column( - circuit_qubits: QubitSet, - items: list[Union[Instruction, ResultType]], - global_phase: float | None = None, - ) -> str: - """ - Return a column in the ASCII string diagram of the circuit for a given list of items. - - Args: - circuit_qubits (QubitSet): qubits in circuit - items (list[Union[Instruction, ResultType]]): list of instructions or result types - global_phase (float | None): the integrated global phase up to this column - - Returns: - str: an ASCII string diagram for the specified moment in time for a column. - """ - symbols = {qubit: "-" for qubit in circuit_qubits} - margins = {qubit: " " for qubit in circuit_qubits} - - for item in items: - if isinstance(item, ResultType) and not item.target: - target_qubits = circuit_qubits - control_qubits = QubitSet() - target_and_control = target_qubits.union(control_qubits) - qubits = circuit_qubits - ascii_symbols = [item.ascii_symbols[0]] * len(circuit_qubits) - elif isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective): - target_qubits = circuit_qubits - control_qubits = QubitSet() - target_and_control = target_qubits.union(control_qubits) - qubits = circuit_qubits - ascii_symbol = item.ascii_symbols[0] - marker = "*" * len(ascii_symbol) - num_after = len(circuit_qubits) - 1 - after = ["|"] * (num_after - 1) + ([marker] if num_after else []) - ascii_symbols = [ascii_symbol] + after - elif ( - isinstance(item, Instruction) - and isinstance(item.operator, Gate) - and item.operator.name == "GPhase" - ): - target_qubits = circuit_qubits - control_qubits = QubitSet() - target_and_control = QubitSet() - qubits = circuit_qubits - ascii_symbols = "-" * len(circuit_qubits) - else: - if isinstance(item.target, list): - target_qubits = reduce(QubitSet.union, map(QubitSet, item.target), QubitSet()) - else: - target_qubits = item.target - control_qubits = getattr(item, "control", QubitSet()) - map_control_qubit_states = AsciiCircuitDiagram._build_map_control_qubits( - item, control_qubits - ) - - target_and_control = target_qubits.union(control_qubits) - qubits = QubitSet(range(min(target_and_control), max(target_and_control) + 1)) - - ascii_symbols = item.ascii_symbols - - for qubit in qubits: - # Determine if the qubit is part of the item or in the middle of a - # multi qubit item. - if qubit in target_qubits: - item_qubit_index = [ - index for index, q in enumerate(target_qubits) if q == qubit - ][0] - power_string = ( - f"^{power}" - if ( - (power := getattr(item, "power", 1)) != 1 - # this has the limitation of not printing the power - # when a user has a gate genuinely named C, but - # is necessary to enable proper printing of custom - # gates with built-in control qubits - and ascii_symbols[item_qubit_index] != "C" - ) - else "" - ) - symbols[qubit] = ( - f"({ascii_symbols[item_qubit_index]}{power_string})" - if power_string - else ascii_symbols[item_qubit_index] - ) - elif qubit in control_qubits: - symbols[qubit] = "C" if map_control_qubit_states[qubit] else "N" - else: - symbols[qubit] = "|" - - # Set the margin to be a connector if not on the first qubit - if target_and_control and qubit != min(target_and_control): - margins[qubit] = "|" - - output = AsciiCircuitDiagram._create_output(symbols, margins, circuit_qubits, global_phase) - return output - - @staticmethod - def _create_output( - symbols: dict[Qubit, str], - margins: dict[Qubit, str], - qubits: QubitSet, - global_phase: float | None, - ) -> str: - symbols_width = max([len(symbol) for symbol in symbols.values()]) - output = "" - - if global_phase is not None: - global_phase_str = ( - f"{global_phase:.2f}" if isinstance(global_phase, float) else str(global_phase) - ) - symbols_width = max([symbols_width, len(global_phase_str)]) - output += "{0:{fill}{align}{width}}|\n".format( - global_phase_str, - fill=" ", - align="^", - width=symbols_width, - ) - - for qubit in qubits: - output += "{0:{width}}\n".format(margins[qubit], width=symbols_width + 1) - output += "{0:{fill}{align}{width}}\n".format( - symbols[qubit], fill="-", align="<", width=symbols_width + 1 - ) - return output - - @staticmethod - def _build_map_control_qubits(item: Instruction, control_qubits: QubitSet) -> dict(Qubit, int): - control_state = getattr(item, "control_state", None) - if control_state is not None: - map_control_qubit_states = { - qubit: state for qubit, state in zip(control_qubits, control_state) - } - else: - map_control_qubit_states = {qubit: 1 for qubit in control_qubits} - - return map_control_qubit_states +# Moving ascii_circuit_diagram.py into the text_diagram_builders folder in order +# to group all classes that print circuits in a text format. +from braket.circuits.text_diagram_builders.ascii_circuit_diagram import ( # noqa: F401 + AsciiCircuitDiagram, +) diff --git a/src/braket/circuits/basis_state.py b/src/braket/circuits/basis_state.py index b6ce11bc8..86578fc89 100644 --- a/src/braket/circuits/basis_state.py +++ b/src/braket/circuits/basis_state.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from functools import singledispatch from typing import Optional, Union @@ -5,7 +7,7 @@ class BasisState: - def __init__(self, state: "BasisStateInput", size: Optional[int] = None): + def __init__(self, state: BasisStateInput, size: Optional[int] = None): self.state = _as_tuple(state, size) @property @@ -30,7 +32,7 @@ def __len__(self) -> int: def __iter__(self): return iter(self.state) - def __eq__(self, other): + def __eq__(self, other: BasisState): return self.state == other.state def __bool__(self): @@ -42,7 +44,7 @@ def __str__(self): def __repr__(self): return f'BasisState("{self.as_string}")' - def __getitem__(self, item): + def __getitem__(self, item: int): return BasisState(self.state[item]) diff --git a/src/braket/circuits/braket_program_context.py b/src/braket/circuits/braket_program_context.py index 863513565..4371637d3 100644 --- a/src/braket/circuits/braket_program_context.py +++ b/src/braket/circuits/braket_program_context.py @@ -18,6 +18,7 @@ from braket.circuits import Circuit, Instruction from braket.circuits.gates import Unitary +from braket.circuits.measure import Measure from braket.circuits.noises import Kraus from braket.circuits.translations import ( BRAKET_GATES, @@ -31,7 +32,8 @@ class BraketProgramContext(AbstractProgramContext): def __init__(self, circuit: Optional[Circuit] = None): - """ + """Inits a `BraketProgramContext`. + Args: circuit (Optional[Circuit]): A partially-built circuit to continue building with this context. Default: None. @@ -133,8 +135,7 @@ def add_kraus_instruction(self, matrices: list[np.ndarray], target: list[int]) - self._circuit.add_instruction(instruction) def add_result(self, result: Results) -> None: - """ - Abstract method to add result type to the circuit + """Abstract method to add result type to the circuit Args: result (Results): The result object representing the measurement results @@ -159,3 +160,17 @@ def handle_parameter_value( return evaluated_value return FreeParameterExpression(evaluated_value) return value + + def add_measure(self, target: tuple[int]) -> None: + """Add a measure instruction to the circuit + + Args: + target (tuple[int]): the target qubits to be measured. + """ + for index, qubit in enumerate(target): + instruction = Instruction(Measure(index=index), qubit) + self._circuit.add_instruction(instruction) + if self._circuit._measure_targets: + self._circuit._measure_targets.append(qubit) + else: + self._circuit._measure_targets = [qubit] diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index a10da287c..c19bb0a06 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -13,19 +13,21 @@ from __future__ import annotations +from collections import Counter from collections.abc import Callable, Iterable from numbers import Number from typing import Any, Optional, TypeVar, Union import numpy as np import oqpy +from sympy import Expr from braket.circuits import compiler_directives -from braket.circuits.ascii_circuit_diagram import AsciiCircuitDiagram from braket.circuits.free_parameter import FreeParameter from braket.circuits.free_parameter_expression import FreeParameterExpression from braket.circuits.gate import Gate from braket.circuits.instruction import Instruction +from braket.circuits.measure import Measure from braket.circuits.moments import Moments, MomentType from braket.circuits.noise import Noise from braket.circuits.noise_helpers import ( @@ -50,14 +52,16 @@ QubitReferenceType, SerializationProperties, ) +from braket.circuits.text_diagram_builders.unicode_circuit_diagram import UnicodeCircuitDiagram from braket.circuits.unitary_calculation import calculate_unitary_big_endian from braket.default_simulator.openqasm.interpreter import Interpreter from braket.ir.jaqcd import Program as JaqcdProgram from braket.ir.openqasm import Program as OpenQasmProgram from braket.ir.openqasm.program_v1 import io_type -from braket.pulse import ArbitraryWaveform, Frame from braket.pulse.ast.qasm_parser import ast_to_qasm +from braket.pulse.frame import Frame from braket.pulse.pulse_sequence import PulseSequence, _validate_uniqueness +from braket.pulse.waveforms import Waveform from braket.registers.qubit import QubitInput from braket.registers.qubit_set import QubitSet, QubitSetInput @@ -69,8 +73,7 @@ class Circuit: - """ - A representation of a quantum circuit that contains the instructions to be performed on a + """A representation of a quantum circuit that contains the instructions to be performed on a quantum device and the requested result types. See :mod:`braket.circuits.gates` module for all of the supported instructions. @@ -85,8 +88,7 @@ class Circuit: @classmethod def register_subroutine(cls, func: SubroutineCallable) -> None: - """ - Register the subroutine `func` as an attribute of the `Circuit` class. The attribute name + """Register the subroutine `func` as an attribute of the `Circuit` class. The attribute name is the name of `func`. Args: @@ -115,10 +117,11 @@ def method_from_subroutine(self, *args, **kwargs) -> SubroutineReturn: setattr(cls, function_name, method_from_subroutine) function_attr = getattr(cls, function_name) - setattr(function_attr, "__doc__", func.__doc__) + function_attr.__doc__ = func.__doc__ def __init__(self, addable: AddableTypes | None = None, *args, **kwargs): - """ + """Inits a `Circuit`. + Args: addable (AddableTypes | None): The item(s) to add to self. Default = None. @@ -147,6 +150,7 @@ def __init__(self, addable: AddableTypes | None = None, *args, **kwargs): self._parameters = set() self._observables_simultaneously_measurable = True self._has_compiler_directives = False + self._measure_targets = None if addable is not None: self.add(addable, *args, **kwargs) @@ -160,11 +164,9 @@ def depth(self) -> int: def global_phase(self) -> float: """float: Get the global phase of the circuit.""" return sum( - [ - instr.operator.angle - for moment, instr in self._moments.items() - if moment.moment_type == MomentType.GLOBAL_PHASE - ] + instr.operator.angle + for moment, instr in self._moments.items() + if moment.moment_type == MomentType.GLOBAL_PHASE ) @property @@ -192,8 +194,7 @@ def basis_rotation_instructions(self) -> list[Instruction]: # Note that basis_rotation_instructions can change each time a new instruction # is added to the circuit because `self._moments.qubits` would change basis_rotation_instructions = [] - all_qubit_observable = self._qubit_observable_mapping.get(Circuit._ALL_QUBITS) - if all_qubit_observable: + if all_qubit_observable := self._qubit_observable_mapping.get(Circuit._ALL_QUBITS): for target in self.qubits: basis_rotation_instructions += Circuit._observable_to_instruction( all_qubit_observable, target @@ -222,6 +223,7 @@ def moments(self) -> Moments: @property def qubit_count(self) -> int: """Get the qubit count for this circuit. Note that this includes observables. + Returns: int: The qubit count for this circuit. """ @@ -235,8 +237,7 @@ def qubits(self) -> QubitSet: @property def parameters(self) -> set[FreeParameter]: - """ - Gets a set of the parameters in the Circuit. + """Gets a set of the parameters in the Circuit. Returns: set[FreeParameter]: The `FreeParameters` in the Circuit. @@ -249,8 +250,7 @@ def add_result_type( target: QubitSetInput | None = None, target_mapping: dict[QubitInput, QubitInput] | None = None, ) -> Circuit: - """ - Add a requested result type to `self`, returns `self` for chaining ability. + """Add a requested result type to `self`, returns `self` for chaining ability. Args: result_type (ResultType): `ResultType` to add into `self`. @@ -273,6 +273,7 @@ def add_result_type( Raises: TypeError: If both `target_mapping` and `target` are supplied. + ValueError: If a measure instruction exists on the current circuit. Examples: >>> result_type = ResultType.Probability(target=[0, 1]) @@ -298,6 +299,12 @@ def add_result_type( if target_mapping and target is not None: raise TypeError("Only one of 'target_mapping' or 'target' can be supplied.") + if self._measure_targets: + raise ValueError( + "cannot add a result type to a circuit which already contains a " + "measure instruction." + ) + if not target_mapping and not target: # Nothing has been supplied, add result_type result_type_to_add = result_type @@ -407,14 +414,48 @@ def _add_to_qubit_observable_set(self, result_type: ResultType) -> None: if isinstance(result_type, ObservableResultType) and result_type.target: self._qubit_observable_set.update(result_type.target) + def _check_if_qubit_measured( + self, + instruction: Instruction, + target: QubitSetInput | None = None, + target_mapping: dict[QubitInput, QubitInput] | None = None, + ) -> None: + """Checks if the target qubits are measured. If the qubit is already measured + the instruction will not be added to the Circuit. + + Args: + instruction (Instruction): `Instruction` to add into `self`. + target (QubitSetInput | None): Target qubits for the + `instruction`. If a single qubit gate, an instruction is created for every index + in `target`. + Default = `None`. + target_mapping (dict[QubitInput, QubitInput] | None): A dictionary of + qubit mappings to apply to the `instruction.target`. Key is the qubit in + `instruction.target` and the value is what the key will be changed to. + Default = `None`. + + Raises: + ValueError: If adding a gate or noise operation after a measure instruction. + """ + if self._measure_targets: + measure_on_target_mapping = target_mapping and any( + targ in self._measure_targets for targ in target_mapping.values() + ) + if ( + # check if there is a measure instruction on the targeted qubit(s) + measure_on_target_mapping + or any(tar in self._measure_targets for tar in QubitSet(target)) + or any(tar in self._measure_targets for tar in QubitSet(instruction.target)) + ): + raise ValueError("cannot apply instruction to measured qubits.") + def add_instruction( self, instruction: Instruction, target: QubitSetInput | None = None, target_mapping: dict[QubitInput, QubitInput] | None = None, ) -> Circuit: - """ - Add an instruction to `self`, returns `self` for chaining ability. + """Add an instruction to `self`, returns `self` for chaining ability. Args: instruction (Instruction): `Instruction` to add into `self`. @@ -432,6 +473,7 @@ def add_instruction( Raises: TypeError: If both `target_mapping` and `target` are supplied. + ValueError: If adding a gate or noise after a measure instruction. Examples: >>> instr = Instruction(Gate.CNot(), [0, 1]) @@ -459,6 +501,9 @@ def add_instruction( if target_mapping and target is not None: raise TypeError("Only one of 'target_mapping' or 'target' can be supplied.") + # Check if there is a measure instruction on the circuit + self._check_if_qubit_measured(instruction, target, target_mapping) + if not target_mapping and not target: # Nothing has been supplied, add instruction instructions_to_add = [instruction] @@ -474,7 +519,9 @@ def add_instruction( if self._check_for_params(instruction): for param in instruction.operator.parameters: - if isinstance(param, FreeParameterExpression): + if isinstance(param, FreeParameterExpression) and isinstance( + param.expression, Expr + ): free_params = param.expression.free_symbols for parameter in free_params: self._parameters.add(FreeParameter(parameter.name)) @@ -483,8 +530,7 @@ def add_instruction( return self def _check_for_params(self, instruction: Instruction) -> bool: - """ - This checks for free parameters in an :class:{Instruction}. Checks children classes of + """This checks for free parameters in an :class:{Instruction}. Checks children classes of :class:{Parameterizable}. Args: @@ -505,8 +551,7 @@ def add_circuit( target: QubitSetInput | None = None, target_mapping: dict[QubitInput, QubitInput] | None = None, ) -> Circuit: - """ - Add a `circuit` to self, returns self for chaining ability. + """Add a `Circuit` to `self`, returning `self` for chaining ability. Args: circuit (Circuit): Circuit to add into self. @@ -581,9 +626,8 @@ def add_verbatim_box( target: QubitSetInput | None = None, target_mapping: dict[QubitInput, QubitInput] | None = None, ) -> Circuit: - """ - Add a verbatim `circuit` to self, that is, ensures that `circuit` is not modified in any way - by the compiler. + """Add a verbatim `Circuit` to `self`, ensuring that the circuit is not modified in + any way by the compiler. Args: verbatim_circuit (Circuit): Circuit to add into self. @@ -637,6 +681,9 @@ def add_verbatim_box( if verbatim_circuit.result_types: raise ValueError("Verbatim subcircuit is not measured and cannot have result types") + if verbatim_circuit._measure_targets: + raise ValueError("cannot measure a subcircuit inside a verbatim box.") + if verbatim_circuit.instructions: self.add_instruction(Instruction(compiler_directives.StartVerbatimBox())) for instruction in verbatim_circuit.instructions: @@ -645,6 +692,71 @@ def add_verbatim_box( self._has_compiler_directives = True return self + def _add_measure(self, target_qubits: QubitSetInput) -> None: + """Adds a measure instruction to the the circuit + + Args: + target_qubits (QubitSetInput): target qubits to measure. + """ + for idx, target in enumerate(target_qubits): + num_qubits_measured = ( + len(self._measure_targets) + if self._measure_targets and len(target_qubits) == 1 + else 0 + ) + self.add_instruction( + Instruction( + operator=Measure(index=idx + num_qubits_measured), + target=target, + ) + ) + if self._measure_targets: + self._measure_targets.append(target) + else: + self._measure_targets = [target] + + def measure(self, target_qubits: QubitSetInput) -> Circuit: + """ + Add a `measure` operator to `self` ensuring only the target qubits are measured. + + Args: + target_qubits (QubitSetInput): target qubits to measure. + + Returns: + Circuit: self + + Raises: + IndexError: If `self` has no qubits. + IndexError: If target qubits are not within the range of the current circuit. + ValueError: If the current circuit contains any result types. + ValueError: If the target qubit is already measured. + + Examples: + >>> circ = Circuit.h(0).cnot(0, 1).measure([0]) + >>> circ.print(list(circ.instructions)) + [Instruction('operator': H('qubit_count': 1), 'target': QubitSet([Qubit(0)]), + Instruction('operator': CNot('qubit_count': 2), 'target': QubitSet([Qubit(0), + Qubit(1)]), + Instruction('operator': Measure, 'target': QubitSet([Qubit(0)])] + """ + if not isinstance(target_qubits, Iterable): + target_qubits = QubitSet(target_qubits) + + # Check if result types are added on the circuit + if self.result_types: + raise ValueError("a circuit cannot contain both measure instructions and result types.") + + # Check if there are repeated qubits in the same measurement + if len(target_qubits) != len(set(target_qubits)): + intersection = [qubit for qubit, count in Counter(target_qubits).items() if count > 1] + raise ValueError( + f"cannot repeat qubit(s) {', '.join(map(str, intersection))} " + "in the same measurement." + ) + self._add_measure(target_qubits=target_qubits) + + return self + def apply_gate_noise( self, noise: Union[type[Noise], Iterable[type[Noise]]], @@ -699,7 +811,8 @@ def apply_gate_noise( If `target_unitary` is not a unitary. If `noise` is multi-qubit noise and `target_gates` contain gates with the number of qubits not the same as `noise.qubit_count`. - Warning: + + Warning: If `noise` is multi-qubit noise while there is no gate with the same number of qubits in `target_qubits` or in the whole circuit when `target_qubits` is not given. @@ -763,9 +876,12 @@ def apply_gate_noise( # check target_qubits target_qubits = check_noise_target_qubits(self, target_qubits) - if not all(qubit in self.qubits for qubit in target_qubits): + if any(qubit not in self.qubits for qubit in target_qubits): raise IndexError("target_qubits must be within the range of the current circuit.") + # Check if there is a measure instruction on the circuit + self._check_if_qubit_measured(instruction=noise, target=target_qubits) + # make noise a list noise = wrap_with_list(noise) @@ -859,8 +975,7 @@ def apply_initialization_noise( return apply_noise_to_moments(self, noise, target_qubits, "initialization") def make_bound_circuit(self, param_values: dict[str, Number], strict: bool = False) -> Circuit: - """ - Binds FreeParameters based upon their name and values passed in. If parameters + """Binds `FreeParameter`s based upon their name and values passed in. If parameters share the same name, all the parameters of that name will be set to the mapped value. Args: @@ -878,27 +993,23 @@ def make_bound_circuit(self, param_values: dict[str, Number], strict: bool = Fal return self._use_parameter_value(param_values) def _validate_parameters(self, parameter_values: dict[str, Number]) -> None: - """ - This runs a check to see that the parameters are in the Circuit. + """Checks that the parameters are in the `Circuit`. Args: parameter_values (dict[str, Number]): A mapping of FreeParameter names to a value to assign to them. Raises: - ValueError: If a parameter name is given which does not appear in the circuit. - + ValueError: If there are no parameters that match the key for the arg + param_values. """ - parameter_strings = set() - for parameter in self.parameters: - parameter_strings.add(str(parameter)) + parameter_strings = {str(parameter) for parameter in self.parameters} for param in parameter_values: if param not in parameter_strings: raise ValueError(f"No parameter in the circuit named: {param}") def _use_parameter_value(self, param_values: dict[str, Number]) -> Circuit: - """ - Creates a Circuit that uses the parameter values passed in. + """Creates a `Circuit` that uses the parameter values passed in. Args: param_values (dict[str, Number]): A mapping of FreeParameter names @@ -926,8 +1037,7 @@ def _use_parameter_value(self, param_values: dict[str, Number]) -> Circuit: @staticmethod def _validate_parameter_value(val: Any) -> None: - """ - Validates the value being used is a Number. + """Validates the value being used is a `Number`. Args: val (Any): The value be verified. @@ -937,7 +1047,7 @@ def _validate_parameter_value(val: Any) -> None: """ if not isinstance(val, Number): raise ValueError( - f"Parameters can only be assigned numeric values. " f"Invalid inputs: {val}" + f"Parameters can only be assigned numeric values. Invalid inputs: {val}" ) def apply_readout_noise( @@ -1000,7 +1110,7 @@ def apply_readout_noise( target_qubits = [target_qubits] if not all(isinstance(q, int) for q in target_qubits): raise TypeError("target_qubits must be integer(s)") - if not all(q >= 0 for q in target_qubits): + if any(q < 0 for q in target_qubits): raise ValueError("target_qubits must contain only non-negative integers.") target_qubits = QubitSet(target_qubits) @@ -1019,8 +1129,7 @@ def apply_readout_noise( return apply_noise_to_moments(self, noise, target_qubits, "readout") def add(self, addable: AddableTypes, *args, **kwargs) -> Circuit: - """ - Generic add method for adding item(s) to self. Any arguments that + """Generic add method for adding item(s) to self. Any arguments that `add_circuit()` and / or `add_instruction()` and / or `add_result_type` supports are supported by this method. If adding a subroutine, check with that subroutines documentation to determine what @@ -1092,9 +1201,8 @@ def adjoint(self) -> Circuit: circ.add_result_type(result_type) return circ - def diagram(self, circuit_diagram_class: type = AsciiCircuitDiagram) -> str: - """ - Get a diagram for the current circuit. + def diagram(self, circuit_diagram_class: type = UnicodeCircuitDiagram) -> str: + """Get a diagram for the current circuit. Args: circuit_diagram_class (type): A `CircuitDiagram` class that builds the @@ -1108,21 +1216,20 @@ def diagram(self, circuit_diagram_class: type = AsciiCircuitDiagram) -> str: def to_ir( self, ir_type: IRType = IRType.JAQCD, - serialization_properties: Optional[SerializationProperties] = None, - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] = None, + serialization_properties: SerializationProperties | None = None, + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence] | None = None, ) -> Union[OpenQasmProgram, JaqcdProgram]: - """ - Converts the circuit into the canonical intermediate representation. + """Converts the circuit into the canonical intermediate representation. If the circuit is sent over the wire, this method is called before it is sent. Args: ir_type (IRType): The IRType to use for converting the circuit object to its IR representation. - serialization_properties (Optional[SerializationProperties]): The serialization + serialization_properties (SerializationProperties | None): The serialization properties to use while serializing the object to the IR representation. The serialization properties supplied must correspond to the supplied `ir_type`. Defaults to None. - gate_definitions (Optional[dict[tuple[Gate, QubitSet], PulseSequence]]): The + gate_definitions (dict[tuple[Gate, QubitSet], PulseSequence] | None): The calibration data for the device. default: None. Returns: @@ -1131,8 +1238,9 @@ def to_ir( Raises: ValueError: If the supplied `ir_type` is not supported, or if the supplied serialization - properties don't correspond to the `ir_type`. + properties don't correspond to the `ir_type`. """ + gate_definitions = gate_definitions or {} if ir_type == IRType.JAQCD: return self._to_jaqcd() elif ir_type == IRType.OPENQASM: @@ -1145,7 +1253,7 @@ def to_ir( ) return self._to_openqasm( serialization_properties or OpenQASMSerializationProperties(), - gate_definitions.copy() if gate_definitions is not None else None, + gate_definitions.copy(), ) else: raise ValueError(f"Supplied ir_type {ir_type} is not supported.") @@ -1154,8 +1262,7 @@ def to_ir( def from_ir( source: Union[str, OpenQasmProgram], inputs: Optional[dict[str, io_type]] = None ) -> Circuit: - """ - Converts an OpenQASM program to a Braket Circuit object. + """Converts an OpenQASM program to a Braket Circuit object. Args: source (Union[str, OpenQasmProgram]): OpenQASM string. @@ -1194,7 +1301,7 @@ def _to_jaqcd(self) -> JaqcdProgram: def _to_openqasm( self, serialization_properties: OpenQASMSerializationProperties, - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], ) -> OpenQasmProgram: ir_instructions = self._create_openqasm_header(serialization_properties, gate_definitions) openqasm_ir_type = IRType.OPENQASM @@ -1216,7 +1323,8 @@ def _to_openqasm( for result_type in self.result_types ] ) - else: + # measure all the qubits if a measure instruction is not provided + elif self._measure_targets is None: qubits = ( sorted(self.qubits) if serialization_properties.qubit_reference_type == QubitReferenceType.VIRTUAL @@ -1231,13 +1339,18 @@ def _to_openqasm( def _create_openqasm_header( self, serialization_properties: OpenQASMSerializationProperties, - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], ) -> list[str]: ir_instructions = ["OPENQASM 3.0;"] - for parameter in self.parameters: - ir_instructions.append(f"input float {parameter};") + frame_wf_declarations = self._generate_frame_wf_defcal_declarations(gate_definitions) + ir_instructions.extend(f"input float {parameter};" for parameter in self.parameters) if not self.result_types: - ir_instructions.append(f"bit[{self.qubit_count}] b;") + bit_count = ( + len(self._measure_targets) + if self._measure_targets is not None + else self.qubit_count + ) + ir_instructions.append(f"bit[{bit_count}] b;") if serialization_properties.qubit_reference_type == QubitReferenceType.VIRTUAL: total_qubits = max(self.qubits).real + 1 @@ -1248,18 +1361,17 @@ def _create_openqasm_header( f"{serialization_properties.qubit_reference_type} supplied." ) - frame_wf_declarations = self._generate_frame_wf_defcal_declarations(gate_definitions) if frame_wf_declarations: ir_instructions.append(frame_wf_declarations) return ir_instructions - def _validate_gate_calbrations_uniqueness( + def _validate_gate_calibrations_uniqueness( self, gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], - frames: dict[Frame], - waveforms: dict[ArbitraryWaveform], + frames: dict[str, Frame], + waveforms: dict[str, Waveform], ) -> None: - for key, calibration in gate_definitions.items(): + for calibration in gate_definitions.values(): for frame in calibration._frames.values(): _validate_uniqueness(frames, frame) frames[frame.id] = frame @@ -1268,45 +1380,59 @@ def _validate_gate_calbrations_uniqueness( waveforms[waveform.id] = waveform def _generate_frame_wf_defcal_declarations( - self, gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] - ) -> Optional[str]: - program = oqpy.Program(None) + self, gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence] | None + ) -> str | None: + """Generates the header where frames, waveforms and defcals are declared. + + It also adds any FreeParameter of the calibrations to the circuit parameter set. + + Args: + gate_definitions (dict[tuple[Gate, QubitSet], PulseSequence] | None): The + calibration data for the device. + + Returns: + str | None: An OpenQASM string + """ + + program = oqpy.Program(None, simplify_constants=False) frames, waveforms = self._get_frames_waveforms_from_instrs(gate_definitions) - if gate_definitions is not None: - self._validate_gate_calbrations_uniqueness(gate_definitions, frames, waveforms) + self._validate_gate_calibrations_uniqueness(gate_definitions, frames, waveforms) # Declare the frames and waveforms across all pulse sequences declarable_frames = [f for f in frames.values() if not f.is_predefined] - if declarable_frames or waveforms or gate_definitions is not None: + if declarable_frames or waveforms or gate_definitions: frame_wf_to_declare = [f._to_oqpy_expression() for f in declarable_frames] frame_wf_to_declare += [wf._to_oqpy_expression() for wf in waveforms.values()] program.declare(frame_wf_to_declare, encal=True) - if gate_definitions is not None: - for key, calibration in gate_definitions.items(): - gate, qubits = key - - # Ignoring parametric gates - # Corresponding defcals with fixed arguments have been added - # in _get_frames_waveforms_from_instrs - if isinstance(gate, Parameterizable) and any( - not isinstance(parameter, (float, int, complex)) - for parameter in gate.parameters - ): - continue - - gate_name = gate._qasm_name - arguments = ( - [calibration._format_parameter_ast(value) for value in gate.parameters] - if isinstance(gate, Parameterizable) - else None - ) - with oqpy.defcal( - program, [oqpy.PhysicalQubits[int(k)] for k in qubits], gate_name, arguments - ): - program += calibration._program + for key, calibration in gate_definitions.items(): + gate, qubits = key + + # Ignoring parametric gates + # Corresponding defcals with fixed arguments have been added + # in _get_frames_waveforms_from_instrs + if isinstance(gate, Parameterizable) and any( + not isinstance(parameter, (float, int, complex)) + for parameter in gate.parameters + ): + continue + + gate_name = gate._qasm_name + arguments = gate.parameters if isinstance(gate, Parameterizable) else [] + + for param in calibration.parameters: + self._parameters.add(param) + arguments = [ + param._to_oqpy_expression() if isinstance(param, FreeParameter) else param + for param in arguments + ] + + with oqpy.defcal( + program, [oqpy.PhysicalQubits[int(k)] for k in qubits], gate_name, arguments + ): + program += calibration._program ast = program.to_ast(encal=False, include_externs=False) return ast_to_qasm(ast) @@ -1314,8 +1440,8 @@ def _generate_frame_wf_defcal_declarations( return None def _get_frames_waveforms_from_instrs( - self, gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] - ) -> tuple[dict[Frame], dict[ArbitraryWaveform]]: + self, gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence] + ) -> tuple[dict[str, Frame], dict[str, Waveform]]: from braket.circuits.gates import PulseGate frames = {} @@ -1329,11 +1455,11 @@ def _get_frames_waveforms_from_instrs( _validate_uniqueness(waveforms, waveform) waveforms[waveform.id] = waveform # this will change with full parametric calibration support - elif isinstance(instruction.operator, Parameterizable) and gate_definitions is not None: + elif isinstance(instruction.operator, Parameterizable): fixed_argument_calibrations = self._add_fixed_argument_calibrations( gate_definitions, instruction ) - gate_definitions.update(fixed_argument_calibrations) + gate_definitions |= fixed_argument_calibrations return frames, waveforms def _add_fixed_argument_calibrations( @@ -1376,7 +1502,7 @@ def _add_fixed_argument_calibrations( instruction.operator.parameters ) == len(gate.parameters): free_parameter_number = sum( - [isinstance(p, FreeParameterExpression) for p in gate.parameters] + isinstance(p, FreeParameterExpression) for p in gate.parameters ) if free_parameter_number == 0: continue @@ -1403,8 +1529,7 @@ def _add_fixed_argument_calibrations( return additional_calibrations def to_unitary(self) -> np.ndarray: - """ - Returns the unitary matrix representation of the entire circuit. + """Returns the unitary matrix representation of the entire circuit. Note: The performance of this method degrades with qubit count. It might be slow for @@ -1431,10 +1556,10 @@ def to_unitary(self) -> np.ndarray: [ 0.70710678+0.j, 0. +0.j, -0.70710678+0.j, 0. +0.j]]) """ - qubits = self.qubits - if not qubits: + if qubits := self.qubits: + return calculate_unitary_big_endian(self.instructions, qubits) + else: return np.zeros(0, dtype=complex) - return calculate_unitary_big_endian(self.instructions, qubits) @property def qubits_frozen(self) -> bool: @@ -1467,8 +1592,7 @@ def _copy(self) -> Circuit: return copy def copy(self) -> Circuit: - """ - Return a shallow copy of the circuit. + """Return a shallow copy of the circuit. Returns: Circuit: A shallow copy of the circuit. @@ -1493,27 +1617,27 @@ def __repr__(self) -> str: ) def __str__(self): - return self.diagram(AsciiCircuitDiagram) + return self.diagram() - def __eq__(self, other): + def __eq__(self, other: Circuit): if isinstance(other, Circuit): return ( self.instructions == other.instructions and self.result_types == other.result_types ) return NotImplemented - def __call__(self, arg: Any | None = None, **kwargs) -> Circuit: - """ - Implements the call function to easily make a bound Circuit. + def __call__(self, arg: Any | None = None, **kwargs: Any) -> Circuit: + """Implements the call function to easily make a bound Circuit. Args: arg (Any | None): A value to bind to all parameters. Defaults to None and can be overridden if the parameter is in kwargs. + **kwargs (Any): The parameter and valued to be bound. Returns: Circuit: A circuit with the specified parameters bound. """ - param_values = dict() + param_values = {} if arg is not None: for param in self.parameters: param_values[str(param)] = arg @@ -1523,8 +1647,7 @@ def __call__(self, arg: Any | None = None, **kwargs) -> Circuit: def subroutine(register: bool = False) -> Callable: - """ - Subroutine is a function that returns instructions, result types, or circuits. + """Subroutine is a function that returns instructions, result types, or circuits. Args: register (bool): If `True`, adds this subroutine into the `Circuit` class. diff --git a/src/braket/circuits/circuit_diagram.py b/src/braket/circuits/circuit_diagram.py index 5b156d290..cc39aa7ee 100644 --- a/src/braket/circuits/circuit_diagram.py +++ b/src/braket/circuits/circuit_diagram.py @@ -24,11 +24,10 @@ class CircuitDiagram(ABC): @staticmethod @abstractmethod def build_diagram(circuit: cir.Circuit) -> str: - """ - Build a diagram for the specified `circuit`. + """Build a diagram for the specified `circuit`. Args: - circuit (Circuit): The circuit to build a diagram for. + circuit (cir.Circuit): The circuit to build a diagram for. Returns: str: String representation for the circuit diagram. diff --git a/src/braket/circuits/circuit_helpers.py b/src/braket/circuits/circuit_helpers.py index f0e3f3144..1a50c4c83 100644 --- a/src/braket/circuits/circuit_helpers.py +++ b/src/braket/circuits/circuit_helpers.py @@ -15,8 +15,7 @@ def validate_circuit_and_shots(circuit: Circuit, shots: int) -> None: - """ - Validates if circuit and shots are correct before running on a device + """Validates if circuit and shots are correct before running on a device Args: circuit (Circuit): circuit to validate @@ -40,7 +39,7 @@ def validate_circuit_and_shots(circuit: Circuit, shots: int) -> None: if not circuit.observables_simultaneously_measurable: raise ValueError("Observables cannot be sampled simultaneously") for rt in circuit.result_types: - if isinstance(rt, ResultType.StateVector) or isinstance(rt, ResultType.Amplitude): + if isinstance(rt, (ResultType.Amplitude, ResultType.StateVector)): raise ValueError("StateVector or Amplitude cannot be specified when shots>0") elif isinstance(rt, ResultType.Probability): num_qubits = len(rt.target) or circuit.qubit_count diff --git a/src/braket/circuits/compiler_directive.py b/src/braket/circuits/compiler_directive.py index 628422c7e..ad2c701c6 100644 --- a/src/braket/circuits/compiler_directive.py +++ b/src/braket/circuits/compiler_directive.py @@ -28,7 +28,8 @@ class CompilerDirective(Operator): """ def __init__(self, ascii_symbols: Sequence[str]): - """ + """Inits a `CompilerDirective`. + Args: ascii_symbols (Sequence[str]): ASCII string symbols for the compiler directiver. These are used when printing a diagram of circuits. @@ -97,7 +98,7 @@ def counterpart(self) -> CompilerDirective: f"Compiler directive {self.name} does not have counterpart implemented" ) - def __eq__(self, other): + def __eq__(self, other: CompilerDirective): return isinstance(other, CompilerDirective) and self.name == other.name def __repr__(self): diff --git a/src/braket/circuits/compiler_directives.py b/src/braket/circuits/compiler_directives.py index 2533537f5..9376d338d 100644 --- a/src/braket/circuits/compiler_directives.py +++ b/src/braket/circuits/compiler_directives.py @@ -18,8 +18,7 @@ class StartVerbatimBox(CompilerDirective): - """ - Prevents the compiler from modifying any ensuing instructions + """Prevents the compiler from modifying any ensuing instructions until the appearance of a corresponding ``EndVerbatimBox``. """ @@ -37,8 +36,7 @@ def _to_openqasm(self) -> str: class EndVerbatimBox(CompilerDirective): - """ - Marks the end of a portion of code following a StartVerbatimBox that prevents the enclosed + """Marks the end of a portion of code following a StartVerbatimBox that prevents the enclosed instructions from being modified by the compiler. """ diff --git a/src/braket/circuits/gate.py b/src/braket/circuits/gate.py index 907c495ce..453b121fd 100644 --- a/src/braket/circuits/gate.py +++ b/src/braket/circuits/gate.py @@ -28,14 +28,14 @@ class Gate(QuantumOperator): - """ - Class `Gate` represents a quantum gate that operates on N qubits. Gates are considered the + """Class `Gate` represents a quantum gate that operates on N qubits. Gates are considered the building blocks of quantum circuits. This class is considered the gate definition containing the metadata that defines what a gate is and what it does. """ def __init__(self, qubit_count: Optional[int], ascii_symbols: Sequence[str]): - """ + """Initializes a `Gate`. + Args: qubit_count (Optional[int]): Number of qubits this gate interacts with. ascii_symbols (Sequence[str]): ASCII string symbols for the gate. These are used when @@ -76,7 +76,7 @@ def to_ir( control_state: Optional[BasisStateInput] = None, power: float = 1, ) -> Any: - """Returns IR object of quantum operator and target + r"""Returns IR object of quantum operator and target Args: target (QubitSet): target qubit(s). @@ -97,6 +97,7 @@ def to_ir( power (float): Integer or fractional power to raise the gate to. Negative powers will be split into an inverse, accompanied by the positive power. Default 1. + Returns: Any: IR object of the quantum operator and target @@ -128,8 +129,7 @@ def to_ir( raise ValueError(f"Supplied ir_type {ir_type} is not supported.") def _to_jaqcd(self, target: QubitSet) -> Any: - """ - Returns the JAQCD representation of the gate. + """Returns the JAQCD representation of the gate. Args: target (QubitSet): target qubit(s). @@ -148,8 +148,7 @@ def _to_openqasm( control_state: Optional[BasisStateInput] = None, power: float = 1, ) -> str: - """ - Returns the openqasm string representation of the gate. + """Returns the OpenQASM string representation of the gate. Args: target (QubitSet): target qubit(s). @@ -208,7 +207,7 @@ def ascii_symbols(self) -> tuple[str, ...]: """tuple[str, ...]: Returns the ascii symbols for the quantum operator.""" return self._ascii_symbols - def __eq__(self, other): + def __eq__(self, other: Gate): return isinstance(other, Gate) and self.name == other.name def __repr__(self): diff --git a/src/braket/circuits/gate_calibrations.py b/src/braket/circuits/gate_calibrations.py index 57013df4a..69ff66254 100644 --- a/src/braket/circuits/gate_calibrations.py +++ b/src/braket/circuits/gate_calibrations.py @@ -27,8 +27,7 @@ class GateCalibrations: - """ - An object containing gate calibration data. The data respresents the mapping on a particular gate + """An object containing gate calibration data. The data represents the mapping on a particular gate on a set of qubits to its calibration to be used by a quantum device. This is represented by a dictionary with keys of `Tuple(Gate, QubitSet)` mapped to a `PulseSequence`. """ # noqa: E501 @@ -37,18 +36,18 @@ def __init__( self, pulse_sequences: dict[tuple[Gate, QubitSet], PulseSequence], ): - """ + """Inits a `GateCalibrations`. + Args: - pulse_sequences (dict[tuple[Gate, QubitSet], PulseSequence]): A mapping containing a key of - `(Gate, QubitSet)` mapped to the corresponding pulse sequence. + pulse_sequences (dict[tuple[Gate, QubitSet], PulseSequence]): A mapping containing a key + of `(Gate, QubitSet)` mapped to the corresponding pulse sequence. - """ # noqa: E501 + """ self.pulse_sequences: dict[tuple[Gate, QubitSet], PulseSequence] = pulse_sequences @property def pulse_sequences(self) -> dict[tuple[Gate, QubitSet], PulseSequence]: - """ - Gets the mapping of (Gate, Qubit) to the corresponding `PulseSequence`. + """Gets the mapping of (Gate, Qubit) to the corresponding `PulseSequence`. Returns: dict[tuple[Gate, QubitSet], PulseSequence]: The calibration data Dictionary. @@ -57,8 +56,7 @@ def pulse_sequences(self) -> dict[tuple[Gate, QubitSet], PulseSequence]: @pulse_sequences.setter def pulse_sequences(self, value: Any) -> None: - """ - Sets the mapping of (Gate, Qubit) to the corresponding `PulseSequence`. + """Sets the mapping of (Gate, Qubit) to the corresponding `PulseSequence`. Args: value(Any): The value for the pulse_sequences property to be set to. @@ -79,8 +77,7 @@ def pulse_sequences(self, value: Any) -> None: ) def copy(self) -> GateCalibrations: - """ - Returns a copy of the object. + """Returns a copy of the object. Returns: GateCalibrations: a copy of the calibrations. @@ -95,8 +92,7 @@ def filter( gates: list[Gate] | None = None, qubits: QubitSet | list[QubitSet] | None = None, ) -> GateCalibrations: - """ - Filters the data based on optional lists of gates and QubitSets. + """Filters the data based on optional lists of gates and QubitSets. Args: gates (list[Gate] | None): An optional list of gates to filter on. @@ -105,7 +101,7 @@ def filter( Returns: GateCalibrations: A filtered GateCalibrations object. - """ # noqa: E501 + """ keys = self.pulse_sequences.keys() if isinstance(qubits, QubitSet): qubits = [qubits] @@ -120,13 +116,15 @@ def filter( ) def to_ir(self, calibration_key: tuple[Gate, QubitSet] | None = None) -> str: - """ - Returns the defcal representation for the `GateCalibrations` object. + """Returns the defcal representation for the `GateCalibrations` object. Args: calibration_key (tuple[Gate, QubitSet] | None): An optional key to get a specific defcal. Default: None + Raises: + ValueError: Key does not exist in the `GateCalibrations` object. + Returns: str: the defcal string for the object. @@ -162,5 +160,5 @@ def _def_cal_gate(self, gate_key: tuple[Gate, QubitSet]) -> str: ] ) - def __eq__(self, other): + def __eq__(self, other: GateCalibrations): return isinstance(other, GateCalibrations) and other.pulse_sequences == self.pulse_sequences diff --git a/src/braket/circuits/gates.py b/src/braket/circuits/gates.py index e955c4bb0..ee5ea684b 100644 --- a/src/braket/circuits/gates.py +++ b/src/braket/circuits/gates.py @@ -24,6 +24,7 @@ from braket.circuits import circuit from braket.circuits.angled_gate import ( AngledGate, + DoubleAngledGate, TripleAngledGate, _get_angles, _multi_angled_ascii_characters, @@ -137,7 +138,7 @@ def h( Gate.register_gate(H) -class I(Gate): # noqa: E742, E261 +class I(Gate): # noqa: E742 r"""Identity gate. Unitary matrix: @@ -220,10 +221,14 @@ class GPhase(AngledGate): Unitary matrix: - .. math:: \mathtt{gphase}(\gamma) = e^(i \gamma) I_1. + .. math:: \mathtt{gphase}(\gamma) = e^{i \gamma} I_1 = \begin{bmatrix} + e^{i \gamma} \end{bmatrix}. Args: angle (Union[FreeParameterExpression, float]): angle in radians. + + Raises: + ValueError: If `angle` is not present """ def __init__(self, angle: Union[FreeParameterExpression, float]): @@ -274,7 +279,8 @@ def gphase( Unitary matrix: - .. math:: \mathtt{gphase}(\gamma) = e^(i \gamma) I_1. + .. math:: \mathtt{gphase}(\gamma) = e^{i \gamma} I_1 = \begin{bmatrix} + e^{i \gamma} \end{bmatrix}. Args: angle (Union[FreeParameterExpression, float]): Phase in radians. @@ -1066,8 +1072,9 @@ def _to_jaqcd(self, target: QubitSet, **kwargs) -> Any: def to_matrix(self) -> np.ndarray: r"""Returns a matrix representation of this gate. + Returns: - ndarray: The matrix representation of this gate. + np.ndarray: The matrix representation of this gate. """ cos = np.cos(self.angle / 2) sin = np.sin(self.angle / 2) @@ -1158,8 +1165,9 @@ def _to_jaqcd(self, target: QubitSet) -> Any: def to_matrix(self) -> np.ndarray: r"""Returns a matrix representation of this gate. + Returns: - ndarray: The matrix representation of this gate. + np.ndarray: The matrix representation of this gate. """ cos = np.cos(self.angle / 2) sin = np.sin(self.angle / 2) @@ -1435,8 +1443,9 @@ def _qasm_name(self) -> str: def to_matrix(self) -> np.ndarray: r"""Returns a matrix representation of this gate. + Returns: - ndarray: The matrix representation of this gate. + np.ndarray: The matrix representation of this gate. """ _theta = self.angle_1 _phi = self.angle_2 @@ -1928,8 +1937,9 @@ def _to_jaqcd(self, target: QubitSet) -> Any: def to_matrix(self) -> np.ndarray: r"""Returns a matrix representation of this gate. + Returns: - ndarray: The matrix representation of this gate. + np.ndarray: The matrix representation of this gate. """ cos = np.cos(self.angle / 2) sin = np.sin(self.angle / 2) @@ -2694,8 +2704,9 @@ def _to_jaqcd(self, target: QubitSet) -> Any: def to_matrix(self) -> np.ndarray: r"""Returns a matrix representation of this gate. + Returns: - ndarray: The matrix representation of this gate. + np.ndarray: The matrix representation of this gate. """ cos = np.cos(self.angle / 2) isin = 1.0j * np.sin(self.angle / 2) @@ -2806,8 +2817,9 @@ def _to_jaqcd(self, target: QubitSet) -> Any: def to_matrix(self) -> np.ndarray: r"""Returns a matrix representation of this gate. + Returns: - ndarray: The matrix representation of this gate. + np.ndarray: The matrix representation of this gate. """ cos = np.cos(self.angle / 2) isin = 1.0j * np.sin(self.angle / 2) @@ -3287,12 +3299,130 @@ def gpi( Gate.register_gate(GPi) +class PRx(DoubleAngledGate): + r"""Phase Rx gate. + + Unitary matrix: + + .. math:: \mathtt{PRx}(\theta,\phi) = \begin{bmatrix} + \cos{(\theta / 2)} & -i e^{-i \phi} \sin{(\theta / 2)} \\ + -i e^{i \phi} \sin{(\theta / 2)} & \cos{(\theta / 2)} + \end{bmatrix}. + + Args: + angle_1 (Union[FreeParameterExpression, float]): The first angle of the gate in + radians or expression representation. + angle_2 (Union[FreeParameterExpression, float]): The second angle of the gate in + radians or expression representation. + """ + + def __init__( + self, + angle_1: Union[FreeParameterExpression, float], + angle_2: Union[FreeParameterExpression, float], + ): + super().__init__( + angle_1=angle_1, + angle_2=angle_2, + qubit_count=None, + ascii_symbols=[_multi_angled_ascii_characters("PRx", angle_1, angle_2)], + ) + + @property + def _qasm_name(self) -> str: + return "prx" + + def to_matrix(self) -> np.ndarray: + """Returns a matrix representation of this gate. + + Returns: + np.ndarray: The matrix representation of this gate. + """ + theta = self.angle_1 + phi = self.angle_2 + return np.array( + [ + [ + np.cos(theta / 2), + -1j * np.exp(-1j * phi) * np.sin(theta / 2), + ], + [ + -1j * np.exp(1j * phi) * np.sin(theta / 2), + np.cos(theta / 2), + ], + ] + ) + + def adjoint(self) -> list[Gate]: + return [PRx(-self.angle_1, self.angle_2)] + + @staticmethod + def fixed_qubit_count() -> int: + return 1 + + def bind_values(self, **kwargs) -> PRx: + return _get_angles(self, **kwargs) + + @staticmethod + @circuit.subroutine(register=True) + def prx( + target: QubitSetInput, + angle_1: Union[FreeParameterExpression, float], + angle_2: Union[FreeParameterExpression, float], + *, + control: Optional[QubitSetInput] = None, + control_state: Optional[BasisStateInput] = None, + power: float = 1, + ) -> Iterable[Instruction]: + r"""PhaseRx gate. + + .. math:: \mathtt{PRx}(\theta,\phi) = \begin{bmatrix} + \cos{(\theta / 2)} & -i e^{-i \phi} \sin{(\theta / 2)} \\ + -i e^{i \phi} \sin{(\theta / 2)} & \cos{(\theta / 2)} + \end{bmatrix}. + + Args: + target (QubitSetInput): Target qubit(s). + angle_1 (Union[FreeParameterExpression, float]): First angle in radians. + angle_2 (Union[FreeParameterExpression, float]): Second angle in radians. + control (Optional[QubitSetInput]): Control qubit(s). Default None. + control_state (Optional[BasisStateInput]): Quantum state on which to control the + operation. Must be a binary sequence of same length as number of qubits in + `control`. Will be ignored if `control` is not present. May be represented as a + string, list, or int. For example "0101", [0, 1, 0, 1], 5 all represent + controlling on qubits 0 and 2 being in the \\|0⟩ state and qubits 1 and 3 being + in the \\|1⟩ state. Default "1" * len(control). + power (float): Integer or fractional power to raise the gate to. Negative + powers will be split into an inverse, accompanied by the positive power. + Default 1. + + Returns: + Iterable[Instruction]: PhaseRx instruction. + + Examples: + >>> circ = Circuit().prx(0, 0.15, 0.25) + """ + return [ + Instruction( + PRx(angle_1, angle_2), + target=qubit, + control=control, + control_state=control_state, + power=power, + ) + for qubit in QubitSet(target) + ] + + +Gate.register_gate(PRx) + + class GPi2(AngledGate): r"""IonQ GPi2 gate. Unitary matrix: - .. math:: \mathtt{GPi2}(\phi) = \begin{bmatrix} + .. math:: \mathtt{GPi2}(\phi) = \frac{1}{\sqrt{2}} \begin{bmatrix} 1 & -i e^{-i \phi} \\ -i e^{i \phi} & 1 \end{bmatrix}. @@ -3342,7 +3472,7 @@ def gpi2( ) -> Iterable[Instruction]: r"""IonQ GPi2 gate. - .. math:: \mathtt{GPi2}(\phi) = \begin{bmatrix} + .. math:: \mathtt{GPi2}(\phi) = \frac{1}{\sqrt{2}} \begin{bmatrix} 1 & -i e^{-i \phi} \\ -i e^{i \phi} & 1 \end{bmatrix}. @@ -3398,7 +3528,7 @@ class MS(TripleAngledGate): angle_1 (Union[FreeParameterExpression, float]): angle in radians. angle_2 (Union[FreeParameterExpression, float]): angle in radians. angle_3 (Union[FreeParameterExpression, float]): angle in radians. - Default value is angle_3=pi/2. + Default value is angle_3=pi/2. """ def __init__( @@ -3554,7 +3684,7 @@ def adjoint(self) -> list[Gate]: def _to_jaqcd(self, target: QubitSet) -> Any: return ir.Unitary.construct( - targets=[qubit for qubit in target], + targets=list(target), matrix=Unitary._transform_matrix_to_ir(self._matrix), ) @@ -3571,10 +3701,8 @@ def _to_openqasm( return f"#pragma braket unitary({formatted_matrix}) {', '.join(qubits)}" - def __eq__(self, other): - if isinstance(other, Unitary): - return self.matrix_equivalence(other) - return False + def __eq__(self, other: Unitary): + return self.matrix_equivalence(other) if isinstance(other, Unitary) else False def __hash__(self): return hash((self.name, str(self._matrix), self.qubit_count)) @@ -3647,8 +3775,7 @@ def parameters(self) -> list[FreeParameter]: return list(self._pulse_sequence.parameters) def bind_values(self, **kwargs) -> PulseGate: - """ - Takes in parameters and returns an object with specified parameters + """Takes in parameters and returns an object with specified parameters replaced with their values. Returns: @@ -3681,7 +3808,7 @@ def pulse_gate( control_state: Optional[BasisStateInput] = None, power: float = 1, ) -> Instruction: - """Arbitrary pulse gate which provides the ability to embed custom pulse sequences + r"""Arbitrary pulse gate which provides the ability to embed custom pulse sequences within circuits. Args: @@ -3721,8 +3848,7 @@ def pulse_gate( def format_complex(number: complex) -> str: - """ - Format a complex number into + im to be consumed by the braket unitary pragma + """Format a complex number into + im to be consumed by the braket unitary pragma Args: number (complex): A complex number. @@ -3731,13 +3857,11 @@ def format_complex(number: complex) -> str: str: The formatted string. """ if number.real: - if number.imag: - imag_sign = "+" if number.imag > 0 else "-" - return f"{number.real} {imag_sign} {abs(number.imag)}im" - else: + if not number.imag: return f"{number.real}" + imag_sign = "+" if number.imag > 0 else "-" + return f"{number.real} {imag_sign} {abs(number.imag)}im" + elif number.imag: + return f"{number.imag}im" else: - if number.imag: - return f"{number.imag}im" - else: - return "0" + return "0" diff --git a/src/braket/circuits/instruction.py b/src/braket/circuits/instruction.py index 0b2f90d01..bedfd0c44 100644 --- a/src/braket/circuits/instruction.py +++ b/src/braket/circuits/instruction.py @@ -29,8 +29,7 @@ class Instruction: - """ - An instruction is a quantum directive that describes the quantum task to perform on a quantum + """An instruction is a quantum directive that describes the quantum task to perform on a quantum device. """ @@ -43,8 +42,7 @@ def __init__( control_state: Optional[BasisStateInput] = None, power: float = 1, ) -> Instruction: - """ - InstructionOperator includes objects of type `Gate` and `Noise` only. + """InstructionOperator includes objects of type `Gate` and `Noise` only. Args: operator (InstructionOperator): Operator for the instruction. @@ -110,30 +108,22 @@ def operator(self) -> InstructionOperator: @property def target(self) -> QubitSet: - """ - QubitSet: Target qubits that the operator is applied to. - """ + """QubitSet: Target qubits that the operator is applied to.""" return self._target @property def control(self) -> QubitSet: - """ - QubitSet: Target qubits that the operator is controlled on. - """ + """QubitSet: Target qubits that the operator is controlled on.""" return self._control @property def control_state(self) -> BasisState: - """ - BasisState: Quantum state that the operator is controlled to. - """ + """BasisState: Quantum state that the operator is controlled to.""" return self._control_state @property def power(self) -> float: - """ - float: Power that the operator is raised to. - """ + """float: Power that the operator is raised to.""" return self._power def adjoint(self) -> list[Instruction]: @@ -168,8 +158,7 @@ def to_ir( ir_type: IRType = IRType.JAQCD, serialization_properties: SerializationProperties | None = None, ) -> Any: - """ - Converts the operator into the canonical intermediate representation. + """Converts the operator into the canonical intermediate representation. If the operator is passed in a request, this method is called before it is passed. Args: @@ -209,8 +198,7 @@ def copy( control_state: Optional[BasisStateInput] = None, power: float = 1, ) -> Instruction: - """ - Return a shallow copy of the instruction. + """Return a shallow copy of the instruction. Note: If `target_mapping` is specified, then `self.target` is mapped to the specified @@ -282,7 +270,7 @@ def __repr__(self): f"'power': {self.power})" ) - def __eq__(self, other): + def __eq__(self, other: Instruction): if isinstance(other, Instruction): return ( self._operator, @@ -299,7 +287,7 @@ def __eq__(self, other): ) return NotImplemented - def __pow__(self, power, modulo=None): + def __pow__(self, power: float, modulo: float = None): new_power = self.power * power if modulo is not None: new_power %= modulo diff --git a/src/braket/circuits/measure.py b/src/braket/circuits/measure.py new file mode 100644 index 000000000..d31555ba9 --- /dev/null +++ b/src/braket/circuits/measure.py @@ -0,0 +1,99 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from __future__ import annotations + +from typing import Any + +from braket.circuits.quantum_operator import QuantumOperator +from braket.circuits.serialization import ( + IRType, + OpenQASMSerializationProperties, + SerializationProperties, +) +from braket.registers.qubit_set import QubitSet + + +class Measure(QuantumOperator): + """Class `Measure` represents a measure operation on targeted qubits""" + + def __init__(self, **kwargs): + """Inits a `Measure`. + + Raises: + ValueError: `qubit_count` is less than 1, `ascii_symbols` are `None`, or + `ascii_symbols` length != `qubit_count` + """ + super().__init__(qubit_count=1, ascii_symbols=["M"]) + self._target_index = kwargs.get("index") + + @property + def ascii_symbols(self) -> tuple[str]: + """tuple[str]: Returns the ascii symbols for the measure.""" + return self._ascii_symbols + + def to_ir( + self, + target: QubitSet | None = None, + ir_type: IRType = IRType.OPENQASM, + serialization_properties: SerializationProperties | None = None, + **kwargs, + ) -> Any: + """Returns IR object of the measure operator. + + Args: + target (QubitSet | None): target qubit(s). Defaults to None + ir_type(IRType) : The IRType to use for converting the measure object to its + IR representation. Defaults to IRType.OpenQASM. + serialization_properties (SerializationProperties | None): The serialization properties + to use while serializing the object to the IR representation. The serialization + properties supplied must correspond to the supplied `ir_type`. Defaults to None. + + Returns: + Any: IR object of the measure operator. + + Raises: + ValueError: If the supplied `ir_type` is not supported. + """ + if ir_type == IRType.JAQCD: + return self._to_jaqcd() + elif ir_type == IRType.OPENQASM: + return self._to_openqasm( + target, serialization_properties or OpenQASMSerializationProperties() ** kwargs + ) + else: + raise ValueError(f"supplied ir_type {ir_type} is not supported.") + + def _to_jaqcd(self) -> Any: + """Returns the JAQCD representation of the measure.""" + raise NotImplementedError("measure instructions are not supported with JAQCD.") + + def _to_openqasm( + self, target: QubitSet, serialization_properties: OpenQASMSerializationProperties + ) -> str: + """Returns the openqasm string representation of the measure.""" + target_qubits = [serialization_properties.format_target(int(qubit)) for qubit in target] + instructions = [] + for idx, qubit in enumerate(target_qubits): + bit_index = ( + self._target_index if self._target_index and len(target_qubits) == 1 else idx + ) + instructions.append(f"b[{bit_index}] = measure {qubit};") + + return "\n".join(instructions) + + def __eq__(self, other: Measure): + return isinstance(other, Measure) + + def __repr__(self): + return self.name diff --git a/src/braket/circuits/moments.py b/src/braket/circuits/moments.py index 6e87db78d..b2dee4151 100644 --- a/src/braket/circuits/moments.py +++ b/src/braket/circuits/moments.py @@ -21,20 +21,21 @@ from braket.circuits.compiler_directive import CompilerDirective from braket.circuits.gate import Gate from braket.circuits.instruction import Instruction +from braket.circuits.measure import Measure from braket.circuits.noise import Noise from braket.registers.qubit import Qubit from braket.registers.qubit_set import QubitSet class MomentType(str, Enum): - """ - The type of moments. + """The type of moments. GATE: a gate NOISE: a noise channel added directly to the circuit GATE_NOISE: a gate-based noise channel INITIALIZATION_NOISE: a initialization noise channel READOUT_NOISE: a readout noise channel COMPILER_DIRECTIVE: an instruction to the compiler, external to the quantum program itself + MEASURE: a measurement """ GATE = "gate" @@ -44,10 +45,12 @@ class MomentType(str, Enum): READOUT_NOISE = "readout_noise" COMPILER_DIRECTIVE = "compiler_directive" GLOBAL_PHASE = "global_phase" + MEASURE = "measure" class MomentsKey(NamedTuple): """Key of the Moments mapping. + Args: time: moment qubits: qubit set @@ -65,8 +68,7 @@ class MomentsKey(NamedTuple): class Moments(Mapping[MomentsKey, Instruction]): - """ - An ordered mapping of `MomentsKey` or `NoiseMomentsKey` to `Instruction`. The + r"""An ordered mapping of `MomentsKey` or `NoiseMomentsKey` to `Instruction`. The core data structure that contains instructions, ordering they are inserted in, and time slices when they occur. `Moments` implements `Mapping` and functions the same as a read-only dictionary. It is mutable only through the `add()` method. @@ -77,7 +79,7 @@ class Moments(Mapping[MomentsKey, Instruction]): method. Args: - instructions (Iterable[Instruction], optional): Instructions to initialize self. + instructions (Iterable[Instruction] | None): Instructions to initialize self. Default = None. Examples: @@ -125,9 +127,8 @@ def qubit_count(self) -> int: @property def qubits(self) -> QubitSet: - """ - QubitSet: Get the qubits used across all of the instructions. The order of qubits is based - on the order in which the instructions were added. + """QubitSet: Get the qubits used across all of the instructions. The order of qubits is + based on the order in which the instructions were added. Note: Don't mutate this object, any changes may impact the behavior of this class and / or @@ -136,8 +137,7 @@ def qubits(self) -> QubitSet: return self._qubits def time_slices(self) -> dict[int, list[Instruction]]: - """ - Get instructions keyed by time. + """Get instructions keyed by time. Returns: dict[int, list[Instruction]]: Key is the time and value is a list of instructions that @@ -148,7 +148,6 @@ def time_slices(self) -> dict[int, list[Instruction]]: every call, with a computational runtime O(N) where N is the number of instructions in self. """ - time_slices = {} self.sort_moments() for key, instruction in self._moments.items(): @@ -161,8 +160,7 @@ def time_slices(self) -> dict[int, list[Instruction]]: def add( self, instructions: Union[Iterable[Instruction], Instruction], noise_index: int = 0 ) -> None: - """ - Add one or more instructions to self. + """Add one or more instructions to self. Args: instructions (Union[Iterable[Instruction], Instruction]): Instructions to add to self. @@ -196,6 +194,14 @@ def _add(self, instruction: Instruction, noise_index: int = 0) -> None: self._number_gphase_in_current_moment, ) self._moments[key] = instruction + elif isinstance(operator, Measure): + qubit_range = instruction.target.union(instruction.control) + time = self._get_qubit_times(self._max_times.keys()) + 1 + self._moments[MomentsKey(time, qubit_range, MomentType.MEASURE, noise_index)] = ( + instruction + ) + self._qubits.update(qubit_range) + self._depth = max(self._depth, time + 1) else: qubit_range = instruction.target.union(instruction.control) time = self._update_qubit_times(qubit_range) @@ -218,6 +224,7 @@ def add_noise( self, instruction: Instruction, input_type: str = "noise", noise_index: int = 0 ) -> None: """Adds noise to a moment. + Args: instruction (Instruction): Instruction to add. input_type (str): One of MomentType. @@ -231,14 +238,13 @@ def add_noise( time = 0 while MomentsKey(time, qubit_range, input_type, noise_index) in self._moments: - noise_index = noise_index + 1 + noise_index += 1 self._moments[MomentsKey(time, qubit_range, input_type, noise_index)] = instruction self._qubits.update(qubit_range) def sort_moments(self) -> None: - """ - Make the disordered moments in order. + """Make the disordered moments in order. 1. Make the readout noise in the end 2. Make the initialization noise at the beginning @@ -251,6 +257,7 @@ def sort_moments(self) -> None: key_readout_noise = [] moment_copy = OrderedDict() sorted_moment = OrderedDict() + last_measure = self._depth for key, instruction in self._moments.items(): moment_copy[key] = instruction @@ -258,6 +265,9 @@ def sort_moments(self) -> None: key_readout_noise.append(key) elif key.moment_type == MomentType.INITIALIZATION_NOISE: key_initialization_noise.append(key) + elif key.moment_type == MomentType.MEASURE: + last_measure = key.time + key_noise.append(key) else: key_noise.append(key) @@ -266,7 +276,7 @@ def sort_moments(self) -> None: for key in key_noise: sorted_moment[key] = moment_copy[key] # find the max time in the circuit and make it the time for readout noise - max_time = max(self._depth - 1, 0) + max_time = max(last_measure - 1, 0) for key in key_readout_noise: sorted_moment[ @@ -280,7 +290,7 @@ def _max_time_for_qubit(self, qubit: Qubit) -> int: return self._max_times.get(qubit, -1) # - # Implement abstract methods, default to calling selfs underlying dictionary + # Implement abstract methods, default to calling `self`'s underlying dictionary # def keys(self) -> KeysView[MomentsKey]: @@ -293,6 +303,7 @@ def items(self) -> ItemsView[MomentsKey, Instruction]: def values(self) -> ValuesView[Instruction]: """Return a view of self's instructions. + Returns: ValuesView[Instruction]: The (in-order) instructions. """ @@ -300,8 +311,7 @@ def values(self) -> ValuesView[Instruction]: return self._moments.values() def get(self, key: MomentsKey, default: Any | None = None) -> Instruction: - """ - Get the instruction in self by key. + """Get the instruction in self by key. Args: key (MomentsKey): Key of the instruction to fetch. @@ -312,7 +322,7 @@ def get(self, key: MomentsKey, default: Any | None = None) -> Instruction: """ return self._moments.get(key, default) - def __getitem__(self, key): + def __getitem__(self, key: MomentsKey): return self._moments.__getitem__(key) def __iter__(self): @@ -321,19 +331,17 @@ def __iter__(self): def __len__(self): return self._moments.__len__() - def __contains__(self, item): + def __contains__(self, item: MomentsKey): return self._moments.__contains__(item) - def __eq__(self, other): + def __eq__(self, other: Moments): if isinstance(other, Moments): return self._moments == other._moments return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Moments): result = self.__eq__(other) - if result is not NotImplemented: - return not result - return NotImplemented + return not result if result is not NotImplemented else NotImplemented def __repr__(self): return self._moments.__repr__() diff --git a/src/braket/circuits/noise.py b/src/braket/circuits/noise.py index af1bb422f..e5d4fdf8a 100644 --- a/src/braket/circuits/noise.py +++ b/src/braket/circuits/noise.py @@ -14,7 +14,7 @@ from __future__ import annotations from collections.abc import Iterable, Sequence -from typing import Any, Optional, Union +from typing import Any, ClassVar, Optional, Union import numpy as np @@ -31,8 +31,7 @@ class Noise(QuantumOperator): - """ - Class `Noise` represents a noise channel that operates on one or multiple qubits. Noise + """Class `Noise` represents a noise channel that operates on one or multiple qubits. Noise are considered as building blocks of quantum circuits that simulate noise. It can be used as an operator in an `Instruction` object. It appears in the diagram when user prints a circuit with `Noise`. This class is considered the noise channel definition containing @@ -40,7 +39,8 @@ class Noise(QuantumOperator): """ def __init__(self, qubit_count: Optional[int], ascii_symbols: Sequence[str]): - """ + """Initializes a `Noise` object. + Args: qubit_count (Optional[int]): Number of qubits this noise channel interacts with. ascii_symbols (Sequence[str]): ASCII string symbols for this noise channel. These @@ -56,8 +56,7 @@ def __init__(self, qubit_count: Optional[int], ascii_symbols: Sequence[str]): @property def name(self) -> str: - """ - Returns the name of the quantum operator + """Returns the name of the quantum operator Returns: str: The name of the quantum operator as a string @@ -79,6 +78,7 @@ def to_ir( serialization_properties (SerializationProperties | None): The serialization properties to use while serializing the object to the IR representation. The serialization properties supplied must correspond to the supplied `ir_type`. Defaults to None. + Returns: Any: IR object of the quantum operator and target @@ -103,8 +103,7 @@ def to_ir( raise ValueError(f"Supplied ir_type {ir_type} is not supported.") def _to_jaqcd(self, target: QubitSet) -> Any: - """ - Returns the JAQCD representation of the noise. + """Returns the JAQCD representation of the noise. Args: target (QubitSet): target qubit(s). @@ -117,8 +116,7 @@ def _to_jaqcd(self, target: QubitSet) -> Any: def _to_openqasm( self, target: QubitSet, serialization_properties: OpenQASMSerializationProperties ) -> str: - """ - Returns the openqasm string representation of the noise. + """Returns the openqasm string representation of the noise. Args: target (QubitSet): target qubit(s). @@ -138,18 +136,16 @@ def to_matrix(self, *args, **kwargs) -> Iterable[np.ndarray]: """ raise NotImplementedError("to_matrix has not been implemented yet.") - def __eq__(self, other): - if isinstance(other, Noise): - return self.name == other.name - return False + def __eq__(self, other: Noise): + return self.name == other.name if isinstance(other, Noise) else False def __repr__(self): return f"{self.name}('qubit_count': {self.qubit_count})" @classmethod def from_dict(cls, noise: dict) -> Noise: - """ - Converts a dictionary representing an object of this class into an instance of this class. + """Converts a dictionary representing an object of this class into an instance of + this class. Args: noise (dict): A dictionary representation of an object of this class. @@ -166,6 +162,7 @@ def from_dict(cls, noise: dict) -> Noise: @classmethod def register_noise(cls, noise: type[Noise]) -> None: """Register a noise implementation by adding it into the Noise class. + Args: noise (type[Noise]): Noise class to register. """ @@ -173,8 +170,7 @@ def register_noise(cls, noise: type[Noise]) -> None: class SingleProbabilisticNoise(Noise, Parameterizable): - """ - Class `SingleProbabilisticNoise` represents the bit/phase flip noise channel on N qubits + """Class `SingleProbabilisticNoise` represents the bit/phase flip noise channel on N qubits parameterized by a single probability. """ @@ -185,7 +181,8 @@ def __init__( ascii_symbols: Sequence[str], max_probability: float = 0.5, ): - """ + """Initializes a `SingleProbabilisticNoise`. + Args: probability (Union[FreeParameterExpression, float]): The probability that the noise occurs. @@ -209,6 +206,7 @@ def __init__( @property def probability(self) -> float: """The probability that parametrizes the noise channel. + Returns: float: The probability that parametrizes the noise channel. """ @@ -222,9 +220,8 @@ def __str__(self): @property def parameters(self) -> list[Union[FreeParameterExpression, float]]: - """ - Returns the parameters associated with the object, either unbound free parameter expressions - or bound values. + """Returns the parameters associated with the object, either unbound free parameter + expressions or bound values. Returns: list[Union[FreeParameterExpression, float]]: The free parameter expressions @@ -232,14 +229,13 @@ def parameters(self) -> list[Union[FreeParameterExpression, float]]: """ return [self._probability] - def __eq__(self, other): - if isinstance(other, type(self)): + def __eq__(self, other: SingleProbabilisticNoise): + if isinstance(other, SingleProbabilisticNoise): return self.name == other.name and self.probability == other.probability return False def bind_values(self, **kwargs) -> SingleProbabilisticNoise: - """ - Takes in parameters and attempts to assign them to values. + """Takes in parameters and attempts to assign them to values. Returns: SingleProbabilisticNoise: A new Noise object of the same type with the requested @@ -251,8 +247,7 @@ def bind_values(self, **kwargs) -> SingleProbabilisticNoise: raise NotImplementedError def to_dict(self) -> dict: - """ - Converts this object into a dictionary representation. + """Converts this object into a dictionary representation. Returns: dict: A dictionary object that represents this object. It can be converted back @@ -267,8 +262,7 @@ def to_dict(self) -> dict: class SingleProbabilisticNoise_34(SingleProbabilisticNoise): - """ - Class `SingleProbabilisticNoise` represents the Depolarizing and TwoQubitDephasing noise + """Class `SingleProbabilisticNoise` represents the Depolarizing and TwoQubitDephasing noise channels parameterized by a single probability. """ @@ -278,7 +272,8 @@ def __init__( qubit_count: Optional[int], ascii_symbols: Sequence[str], ): - """ + """Initializes a `SingleProbabilisticNoise_34`. + Args: probability (Union[FreeParameterExpression, float]): The probability that the noise occurs. @@ -301,8 +296,7 @@ def __init__( class SingleProbabilisticNoise_1516(SingleProbabilisticNoise): - """ - Class `SingleProbabilisticNoise` represents the TwoQubitDepolarizing noise channel + """Class `SingleProbabilisticNoise` represents the TwoQubitDepolarizing noise channel parameterized by a single probability. """ @@ -312,7 +306,8 @@ def __init__( qubit_count: Optional[int], ascii_symbols: Sequence[str], ): - """ + """Initializes a `SingleProbabilisticNoise_1516`. + Args: probability (Union[FreeParameterExpression, float]): The probability that the noise occurs. @@ -335,12 +330,11 @@ def __init__( class MultiQubitPauliNoise(Noise, Parameterizable): - """ - Class `MultiQubitPauliNoise` represents a general multi-qubit Pauli channel, + """Class `MultiQubitPauliNoise` represents a general multi-qubit Pauli channel, parameterized by up to 4**N - 1 probabilities. """ - _allowed_substrings = {"I", "X", "Y", "Z"} + _allowed_substrings: ClassVar = {"I", "X", "Y", "Z"} def __init__( self, @@ -372,7 +366,6 @@ def __init__( TypeError: If the type of the dictionary keys are not strings. If the probabilities are not floats. """ - super().__init__(qubit_count=qubit_count, ascii_symbols=ascii_symbols) self._probabilities = probabilities @@ -395,10 +388,8 @@ def __init__( total_prob += prob if not (1.0 >= total_prob >= 0.0): raise ValueError( - ( - "Total probability must be a real number in the interval [0, 1]. " - f"Total probability was {total_prob}." - ) + "Total probability must be a real number in the interval [0, 1]. " + f"Total probability was {total_prob}." ) @classmethod @@ -409,17 +400,13 @@ def _validate_pauli_string( raise TypeError(f"Type of {pauli_str} was not a string.") if len(pauli_str) != qubit_count: raise ValueError( - ( - "Length of each Pauli string must be equal to number of qubits. " - f"{pauli_str} had length {len(pauli_str)} instead of length {qubit_count}." - ) + "Length of each Pauli string must be equal to number of qubits. " + f"{pauli_str} had length {len(pauli_str)} instead of length {qubit_count}." ) if not set(pauli_str) <= allowed_substrings: raise ValueError( - ( - "Strings must be Pauli strings consisting of only [I, X, Y, Z]. " - f"Received {pauli_str}." - ) + "Strings must be Pauli strings consisting of only [I, X, Y, Z]. " + f"Received {pauli_str}." ) def __repr__(self): @@ -431,14 +418,15 @@ def __repr__(self): def __str__(self): return f"{self.name}({self._probabilities})" - def __eq__(self, other): - if isinstance(other, type(self)): + def __eq__(self, other: MultiQubitPauliNoise): + if isinstance(other, MultiQubitPauliNoise): return self.name == other.name and self._probabilities == other._probabilities return False @property def probabilities(self) -> dict[str, float]: """A map of a Pauli string to its corresponding probability. + Returns: dict[str, float]: A map of a Pauli string to its corresponding probability. """ @@ -446,9 +434,8 @@ def probabilities(self) -> dict[str, float]: @property def parameters(self) -> list[Union[FreeParameterExpression, float]]: - """ - Returns the parameters associated with the object, either unbound free parameter expressions - or bound values. + """Returns the parameters associated with the object, either unbound free parameter + expressions or bound values. Parameters are in alphabetical order of the Pauli strings in `probabilities`. @@ -461,8 +448,7 @@ def parameters(self) -> list[Union[FreeParameterExpression, float]]: ] def bind_values(self, **kwargs) -> MultiQubitPauliNoise: - """ - Takes in parameters and attempts to assign them to values. + """Takes in parameters and attempts to assign them to values. Returns: MultiQubitPauliNoise: A new Noise object of the same type with the requested @@ -474,16 +460,16 @@ def bind_values(self, **kwargs) -> MultiQubitPauliNoise: raise NotImplementedError def to_dict(self) -> dict: - """ - Converts this object into a dictionary representation. + """Converts this object into a dictionary representation. Returns: dict: A dictionary object that represents this object. It can be converted back into this object using the `from_dict()` method. """ - probabilities = dict() - for pauli_string, prob in self._probabilities.items(): - probabilities[pauli_string] = _parameter_to_dict(prob) + probabilities = { + pauli_string: _parameter_to_dict(prob) + for pauli_string, prob in self._probabilities.items() + } return { "__class__": self.__class__.__name__, "probabilities": probabilities, @@ -493,8 +479,7 @@ def to_dict(self) -> dict: class PauliNoise(Noise, Parameterizable): - """ - Class `PauliNoise` represents the a single-qubit Pauli noise channel + """Class `PauliNoise` represents the a single-qubit Pauli noise channel acting on one qubit. It is parameterized by three probabilities. """ @@ -506,7 +491,8 @@ def __init__( qubit_count: Optional[int], ascii_symbols: Sequence[str], ): - """ + """Initializes a `PauliNoise`. + Args: probX (Union[FreeParameterExpression, float]): The X coefficient of the Kraus operators in the channel. @@ -550,13 +536,13 @@ def _get_param_float(param: Union[FreeParameterExpression, float], param_name: s """ if isinstance(param, FreeParameterExpression): return 0 - else: - _validate_param_value(param, param_name) - return float(param) + _validate_param_value(param, param_name) + return float(param) @property def probX(self) -> Union[FreeParameterExpression, float]: - """ + """The probability of a Pauli X error. + Returns: Union[FreeParameterExpression, float]: The probability of a Pauli X error. """ @@ -564,7 +550,8 @@ def probX(self) -> Union[FreeParameterExpression, float]: @property def probY(self) -> Union[FreeParameterExpression, float]: - """ + """The probability of a Pauli Y error. + Returns: Union[FreeParameterExpression, float]: The probability of a Pauli Y error. """ @@ -572,7 +559,8 @@ def probY(self) -> Union[FreeParameterExpression, float]: @property def probZ(self) -> Union[FreeParameterExpression, float]: - """ + """The probability of a Pauli Z error. + Returns: Union[FreeParameterExpression, float]: The probability of a Pauli Z error. """ @@ -591,16 +579,15 @@ def __repr__(self): def __str__(self): return f"{self.name}({self._parameters[0]}, {self._parameters[1]}, {self._parameters[2]})" - def __eq__(self, other): - if isinstance(other, type(self)): + def __eq__(self, other: PauliNoise): + if isinstance(other, PauliNoise): return self.name == other.name and self._parameters == other._parameters return False @property def parameters(self) -> list[Union[FreeParameterExpression, float]]: - """ - Returns the parameters associated with the object, either unbound free parameter expressions - or bound values. + """Returns the parameters associated with the object, either unbound free parameter + expressions or bound values. Parameters are in the order [probX, probY, probZ] @@ -611,8 +598,7 @@ def parameters(self) -> list[Union[FreeParameterExpression, float]]: return self._parameters def bind_values(self, **kwargs) -> PauliNoise: - """ - Takes in parameters and attempts to assign them to values. + """Takes in parameters and attempts to assign them to values. Returns: PauliNoise: A new Noise object of the same type with the requested @@ -624,8 +610,7 @@ def bind_values(self, **kwargs) -> PauliNoise: raise NotImplementedError def to_dict(self) -> dict: - """ - Converts this object into a dictionary representation. + """Converts this object into a dictionary representation. Returns: dict: A dictionary object that represents this object. It can be converted back @@ -642,8 +627,7 @@ def to_dict(self) -> dict: class DampingNoise(Noise, Parameterizable): - """ - Class `DampingNoise` represents a damping noise channel + """Class `DampingNoise` represents a damping noise channel on N qubits parameterized by gamma. """ @@ -653,7 +637,8 @@ def __init__( qubit_count: Optional[int], ascii_symbols: Sequence[str], ): - """ + """Initializes a `DampingNoise`. + Args: gamma (Union[FreeParameterExpression, float]): Probability of damping. qubit_count (Optional[int]): The number of qubits to apply noise. @@ -661,7 +646,7 @@ def __init__( printing a diagram of a circuit. The length must be the same as `qubit_count`, and index ordering is expected to correlate with the target ordering on the instruction. - Raises: + Raises: ValueError: If `qubit_count` < 1, `ascii_symbols` is `None`, @@ -678,6 +663,7 @@ def __init__( @property def gamma(self) -> float: """Probability of damping. + Returns: float: Probability of damping. """ @@ -691,9 +677,8 @@ def __str__(self): @property def parameters(self) -> list[Union[FreeParameterExpression, float]]: - """ - Returns the parameters associated with the object, either unbound free parameter expressions - or bound values. + """Returns the parameters associated with the object, either unbound free parameter + expressions or bound values. Returns: list[Union[FreeParameterExpression, float]]: The free parameter expressions @@ -701,14 +686,13 @@ def parameters(self) -> list[Union[FreeParameterExpression, float]]: """ return [self._gamma] - def __eq__(self, other): - if isinstance(other, type(self)): + def __eq__(self, other: DampingNoise): + if isinstance(other, DampingNoise): return self.name == other.name and self.gamma == other.gamma return False def bind_values(self, **kwargs) -> DampingNoise: - """ - Takes in parameters and attempts to assign them to values. + """Takes in parameters and attempts to assign them to values. Returns: DampingNoise: A new Noise object of the same type with the requested @@ -720,8 +704,7 @@ def bind_values(self, **kwargs) -> DampingNoise: raise NotImplementedError def to_dict(self) -> dict: - """ - Converts this object into a dictionary representation. + """Converts this object into a dictionary representation. Returns: dict: A dictionary object that represents this object. It can be converted back @@ -736,8 +719,7 @@ def to_dict(self) -> dict: class GeneralizedAmplitudeDampingNoise(DampingNoise): - """ - Class `GeneralizedAmplitudeDampingNoise` represents the generalized amplitude damping + """Class `GeneralizedAmplitudeDampingNoise` represents the generalized amplitude damping noise channel on N qubits parameterized by gamma and probability. """ @@ -748,7 +730,8 @@ def __init__( qubit_count: Optional[int], ascii_symbols: Sequence[str], ): - """ + """Inits a `GeneralizedAmplitudeDampingNoise`. + Args: gamma (Union[FreeParameterExpression, float]): Probability of damping. probability (Union[FreeParameterExpression, float]): Probability of the system being @@ -758,7 +741,7 @@ def __init__( printing a diagram of a circuit. The length must be the same as `qubit_count`, and index ordering is expected to correlate with the target ordering on the instruction. - Raises: + Raises: ValueError: If `qubit_count` < 1, `ascii_symbols` is `None`, @@ -776,6 +759,7 @@ def __init__( @property def probability(self) -> float: """Probability of the system being excited by the environment. + Returns: float: Probability of the system being excited by the environment. """ @@ -793,9 +777,8 @@ def __str__(self): @property def parameters(self) -> list[Union[FreeParameterExpression, float]]: - """ - Returns the parameters associated with the object, either unbound free parameter expressions - or bound values. + """Returns the parameters associated with the object, either unbound free parameter + expressions or bound values. Parameters are in the order [gamma, probability] @@ -805,8 +788,8 @@ def parameters(self) -> list[Union[FreeParameterExpression, float]]: """ return [self.gamma, self.probability] - def __eq__(self, other): - if isinstance(other, type(self)): + def __eq__(self, other: GeneralizedAmplitudeDampingNoise): + if isinstance(other, GeneralizedAmplitudeDampingNoise): return ( self.name == other.name and self.gamma == other.gamma @@ -815,8 +798,7 @@ def __eq__(self, other): return False def to_dict(self) -> dict: - """ - Converts this object into a dictionary representation. + """Converts this object into a dictionary representation. Returns: dict: A dictionary object that represents this object. It can be converted back diff --git a/src/braket/circuits/noise_helpers.py b/src/braket/circuits/noise_helpers.py index 06c0b3620..a73b7f338 100644 --- a/src/braket/circuits/noise_helpers.py +++ b/src/braket/circuits/noise_helpers.py @@ -32,19 +32,22 @@ def no_noise_applied_warning(noise_applied: bool) -> None: """Helper function to give a warning is noise is not applied. + Args: noise_applied (bool): True if the noise has been applied. """ - if noise_applied is False: + if not noise_applied: warnings.warn( "Noise is not applied to any gate, as there is no eligible gate in the circuit" " with the input criteria or there is no multi-qubit gate to apply" - " the multi-qubit noise." + " the multi-qubit noise.", + stacklevel=1, ) def wrap_with_list(an_item: Any) -> list[Any]: """Helper function to make the input parameter a list. + Args: an_item (Any): The item to wrap. @@ -61,6 +64,7 @@ def check_noise_target_gates(noise: Noise, target_gates: Iterable[type[Gate]]) - 1. whether all the elements in target_gates are a Gate type; 2. if `noise` is multi-qubit noise and `target_gates` contain gates with the number of qubits is the same as `noise.qubit_count`. + Args: noise (Noise): A Noise class object to be applied to the circuit. target_gates (Iterable[type[Gate]]): Gate class or @@ -93,7 +97,6 @@ def check_noise_target_unitary(noise: Noise, target_unitary: np.ndarray) -> None noise (Noise): A Noise class object to be applied to the circuit. target_unitary (ndarray): matrix of the target unitary gates """ - if not isinstance(target_unitary, np.ndarray): raise TypeError("target_unitary must be a np.ndarray type") @@ -104,11 +107,12 @@ def check_noise_target_unitary(noise: Noise, target_unitary: np.ndarray) -> None def check_noise_target_qubits( circuit: Circuit, target_qubits: Optional[QubitSetInput] = None ) -> QubitSet: - """ - Helper function to check whether all the target_qubits are positive integers. + """Helper function to check whether all the target_qubits are positive integers. + Args: - circuit (Circuit): A ciruit where `noise` is to be checked. + circuit (Circuit): A circuit where `noise` is to be checked. target_qubits (Optional[QubitSetInput]): Index or indices of qubit(s). + Returns: QubitSet: The target qubits. """ @@ -118,7 +122,7 @@ def check_noise_target_qubits( target_qubits = wrap_with_list(target_qubits) if not all(isinstance(q, int) for q in target_qubits): raise TypeError("target_qubits must be integer(s)") - if not all(q >= 0 for q in target_qubits): + if any(q < 0 for q in target_qubits): raise ValueError("target_qubits must contain only non-negative integers.") target_qubits = QubitSet(target_qubits) @@ -129,8 +133,7 @@ def check_noise_target_qubits( def apply_noise_to_moments( circuit: Circuit, noise: Iterable[type[Noise]], target_qubits: QubitSet, position: str ) -> Circuit: - """ - Apply initialization/readout noise to the circuit. + """Apply initialization/readout noise to the circuit. When `noise.qubit_count` == 1, `noise` is added to all qubits in `target_qubits`. @@ -138,7 +141,7 @@ def apply_noise_to_moments( `target_qubits`. Args: - circuit (Circuit): A ciruit where `noise` is applied to. + circuit (Circuit): A circuit to `noise` is applied to. noise (Iterable[type[Noise]]): Noise channel(s) to be applied to the circuit. target_qubits (QubitSet): Index or indices of qubits. `noise` is applied to. @@ -206,11 +209,10 @@ def _apply_noise_to_gates_helper( Returns: tuple[Iterable[Instruction], int, bool]: A tuple of three values: - new_noise_instruction: A list of noise intructions + new_noise_instruction: A list of noise instructions noise_index: The number of noise channels applied to the gate noise_applied: Whether noise is applied or not """ - for noise_channel in noise: if noise_channel.qubit_count == 1: for qubit in intersection: @@ -218,17 +220,16 @@ def _apply_noise_to_gates_helper( noise_index += 1 new_noise_instruction.append((Instruction(noise_channel, qubit), noise_index)) noise_applied = True - else: - # only apply noise to the gates that have the same qubit_count as the noise. - if ( - instruction.operator.qubit_count == noise_channel.qubit_count - and instruction.target.issubset(target_qubits) - ): - noise_index += 1 - new_noise_instruction.append( - (Instruction(noise_channel, instruction.target), noise_index) - ) - noise_applied = True + # only apply noise to the gates that have the same qubit_count as the noise. + elif ( + instruction.operator.qubit_count == noise_channel.qubit_count + and instruction.target.issubset(target_qubits) + ): + noise_index += 1 + new_noise_instruction.append( + (Instruction(noise_channel, instruction.target), noise_index) + ) + noise_applied = True return new_noise_instruction, noise_index, noise_applied @@ -247,7 +248,7 @@ def apply_noise_to_gates( the same number of qubits as `noise.qubit_count`. Args: - circuit (Circuit): A ciruit where `noise` is applied to. + circuit (Circuit): A circuit where `noise` is applied to. noise (Iterable[type[Noise]]): Noise channel(s) to be applied to the circuit. target_gates (Union[Iterable[type[Gate]], ndarray]): List of gates, or a unitary matrix @@ -265,7 +266,6 @@ def apply_noise_to_gates( If no `target_gates` exist in `target_qubits` or in the whole circuit when `target_qubits` is not given. """ - new_moments = Moments() noise_applied = False diff --git a/src/braket/circuits/noise_model/circuit_instruction_criteria.py b/src/braket/circuits/noise_model/circuit_instruction_criteria.py index 170d3e996..4dceeb4fb 100644 --- a/src/braket/circuits/noise_model/circuit_instruction_criteria.py +++ b/src/braket/circuits/noise_model/circuit_instruction_criteria.py @@ -29,6 +29,9 @@ def instruction_matches(self, instruction: Instruction) -> bool: Args: instruction (Instruction): An Instruction to match. + Raises: + NotImplementedError: Not implemented. + Returns: bool: True if an Instruction matches the criteria. """ @@ -38,8 +41,7 @@ def instruction_matches(self, instruction: Instruction) -> bool: def _check_target_in_qubits( qubits: Optional[set[Union[int, tuple[int]]]], target: QubitSetInput ) -> bool: - """ - Returns true if the given targets of an instruction match the given qubit input set. + """Returns true if the given targets of an instruction match the given qubit input set. Args: qubits (Optional[set[Union[int, tuple[int]]]]): The qubits provided to the criteria. @@ -51,6 +53,4 @@ def _check_target_in_qubits( if qubits is None: return True target = [int(item) for item in target] - if len(target) == 1: - return target[0] in qubits - return tuple(target) in qubits + return target[0] in qubits if len(target) == 1 else tuple(target) in qubits diff --git a/src/braket/circuits/noise_model/criteria.py b/src/braket/circuits/noise_model/criteria.py index b9f0d2cc4..889211342 100644 --- a/src/braket/circuits/noise_model/criteria.py +++ b/src/braket/circuits/noise_model/criteria.py @@ -20,8 +20,8 @@ class CriteriaKey(str, Enum): - """ - Specifies the types of keys that a criteria may use to match an instruction, observable, etc. + """Specifies the types of keys that a criteria may use to match an instruction, + observable, etc. """ QUBIT = "QUBIT" @@ -31,8 +31,7 @@ class CriteriaKey(str, Enum): class CriteriaKeyResult(str, Enum): - """ - The get_keys() method may return this enum instead of actual keys for + """The get_keys() method may return this enum instead of actual keys for a given criteria key type. """ @@ -74,10 +73,10 @@ def __eq__(self, other: Criteria): return NotImplemented if self.applicable_key_types() != other.applicable_key_types(): return False - for key_type in self.applicable_key_types(): - if self.get_keys(key_type) != other.get_keys(key_type): - return False - return True + return all( + self.get_keys(key_type) == other.get_keys(key_type) + for key_type in self.applicable_key_types() + ) @abstractmethod def to_dict(self) -> dict: @@ -90,8 +89,8 @@ def to_dict(self) -> dict: @classmethod def from_dict(cls, criteria: dict) -> Criteria: - """ - Converts a dictionary representing an object of this class into an instance of this class. + """Converts a dictionary representing an object of this class into an instance of this + class. Args: criteria (dict): A dictionary representation of an object of this class. diff --git a/src/braket/circuits/noise_model/criteria_input_parsing.py b/src/braket/circuits/noise_model/criteria_input_parsing.py index 11f24619e..456867ce2 100644 --- a/src/braket/circuits/noise_model/criteria_input_parsing.py +++ b/src/braket/circuits/noise_model/criteria_input_parsing.py @@ -19,10 +19,9 @@ def parse_operator_input( - operators: Union[QuantumOperator, Iterable[QuantumOperator]] + operators: Union[QuantumOperator, Iterable[QuantumOperator]], ) -> Optional[set[QuantumOperator]]: - """ - Processes the quantum operator input to __init__ to validate and return a set of + """Processes the quantum operator input to __init__ to validate and return a set of QuantumOperators. Args: @@ -49,8 +48,7 @@ def parse_operator_input( def parse_qubit_input( qubits: Optional[QubitSetInput], expected_qubit_count: Optional[int] = 0 ) -> Optional[set[Union[int, tuple[int]]]]: - """ - Processes the qubit input to __init__ to validate and return a set of qubit targets. + """Processes the qubit input to __init__ to validate and return a set of qubit targets. Args: qubits (Optional[QubitSetInput]): Qubit input. @@ -86,6 +84,4 @@ def parse_qubit_input( if qubit_count == 1: return {item[0] for item in qubits} return {tuple(item) for item in qubits} - if qubit_count > 1: - return {tuple(qubits)} - return set(qubits) + return {tuple(qubits)} if qubit_count > 1 else set(qubits) diff --git a/src/braket/circuits/noise_model/gate_criteria.py b/src/braket/circuits/noise_model/gate_criteria.py index 623c02477..7870e9b6b 100644 --- a/src/braket/circuits/noise_model/gate_criteria.py +++ b/src/braket/circuits/noise_model/gate_criteria.py @@ -33,8 +33,7 @@ def __init__( gates: Optional[Union[Gate, Iterable[Gate]]] = None, qubits: Optional[QubitSetInput] = None, ): - """ - Creates Gate-based Criteria. See instruction_matches() for more details. + """Creates Gate-based Criteria. See instruction_matches() for more details. Args: gates (Optional[Union[Gate, Iterable[Gate]]]): A set of relevant Gates. All the Gates @@ -60,7 +59,8 @@ def __repr__(self): return f"{self.__class__.__name__}(gates={gate_names}, qubits={self._qubits})" def applicable_key_types(self) -> Iterable[CriteriaKey]: - """ + """Returns an Iterable of criteria keys. + Returns: Iterable[CriteriaKey]: This Criteria operates on Gates and Qubits. """ @@ -86,8 +86,8 @@ def get_keys(self, key_type: CriteriaKey) -> Union[CriteriaKeyResult, set[Any]]: return set() def to_dict(self) -> dict: - """ - Converts a dictionary representing an object of this class into an instance of this class. + """Converts a dictionary representing an object of this class into an instance of this + class. Returns: dict: A dictionary representing the serialized version of this Criteria. diff --git a/src/braket/circuits/noise_model/initialization_criteria.py b/src/braket/circuits/noise_model/initialization_criteria.py index e40d4e9de..4bf29a356 100644 --- a/src/braket/circuits/noise_model/initialization_criteria.py +++ b/src/braket/circuits/noise_model/initialization_criteria.py @@ -18,14 +18,11 @@ class InitializationCriteria(Criteria): - """ - Criteria that implement these methods may be used to determine initialization noise. - """ + """Criteria that implement these methods may be used to determine initialization noise.""" @abstractmethod def qubit_intersection(self, qubits: QubitSetInput) -> QubitSetInput: - """ - Returns subset of passed qubits that match the criteria. + """Returns subset of passed qubits that match the criteria. Args: qubits (QubitSetInput): A qubit or set of qubits that may match the criteria. diff --git a/src/braket/circuits/noise_model/noise_model.py b/src/braket/circuits/noise_model/noise_model.py index 8922cda72..e8a603075 100644 --- a/src/braket/circuits/noise_model/noise_model.py +++ b/src/braket/circuits/noise_model/noise_model.py @@ -56,8 +56,8 @@ def to_dict(self) -> dict: @classmethod def from_dict(cls, noise_model_item: dict) -> NoiseModelInstruction: - """ - Converts a dictionary representing an object of this class into an instance of this class. + """Converts a dictionary representing an object of this class into an instance of + this class. Args: noise_model_item (dict): A dictionary representation of an object of this class. @@ -82,8 +82,7 @@ class NoiseModelInstructions: class NoiseModel: - """ - A Noise Model can be thought of as a set of Noise objects, and a corresponding set of + """A Noise Model can be thought of as a set of Noise objects, and a corresponding set of criteria for how each Noise object should be applied to a circuit. For example, a noise model may represent that every H gate that acts on qubit 0 has a 10% probability of a bit flip, and every X gate that acts on qubit 0 has a 20% probability of a bit flip, and 5% probability of @@ -110,8 +109,7 @@ def __str__(self): @property def instructions(self) -> list[NoiseModelInstruction]: - """ - List all the noise in the NoiseModel. + """List all the noise in the NoiseModel. Returns: list[NoiseModelInstruction]: The noise model instructions. @@ -157,8 +155,7 @@ def _add_instruction(self, instruction: NoiseModelInstruction) -> NoiseModel: return self def remove_noise(self, index: int) -> NoiseModel: - """ - Removes the noise and corresponding criteria from the NoiseModel at the given index. + """Removes the noise and corresponding criteria from the NoiseModel at the given index. Args: index (int): The index of the instruction to remove. @@ -175,6 +172,7 @@ def remove_noise(self, index: int) -> NoiseModel: def get_instructions_by_type(self) -> NoiseModelInstructions: """Returns the noise in this model by noise type. + Returns: NoiseModelInstructions: The noise model instructions. """ @@ -200,8 +198,7 @@ def from_filter( gate: Optional[Gate] = None, noise: Optional[type[Noise]] = None, ) -> NoiseModel: - """ - Returns a new NoiseModel from this NoiseModel using a given filter. If no filters are + """Returns a new NoiseModel from this NoiseModel using a given filter. If no filters are specified, the returned NoiseModel will be the same as this one. Args: @@ -235,8 +232,7 @@ class as the given noise class. return new_model def apply(self, circuit: Circuit) -> Circuit: - """ - Applies this noise model to a circuit, and returns a new circuit that's the `noisy` + """Applies this noise model to a circuit, and returns a new circuit that's the `noisy` version of the given circuit. If multiple noise will act on the same instruction, they will be applied in the order they are added to the noise model. @@ -261,8 +257,7 @@ def _apply_gate_noise( circuit: Circuit, gate_noise_instructions: list[NoiseModelInstruction], ) -> Circuit: - """ - Applies the gate noise to return a new circuit that's the `noisy` version of the given + """Applies the gate noise to return a new circuit that's the `noisy` version of the given circuit. Args: @@ -295,8 +290,8 @@ def _apply_init_noise( circuit: Circuit, init_noise_instructions: list[NoiseModelInstruction], ) -> Circuit: - """ - Applies the initialization noise of this noise model to a circuit and returns the circuit. + """Applies the initialization noise of this noise model to a circuit and returns + the circuit. Args: circuit (Circuit): A circuit to apply `noise` to. @@ -320,8 +315,7 @@ def _apply_readout_noise( circuit: Circuit, readout_noise_instructions: list[NoiseModelInstruction], ) -> Circuit: - """ - Applies the readout noise of this noise model to a circuit and returns the circuit. + """Applies the readout noise of this noise model to a circuit and returns the circuit. Args: circuit (Circuit): A circuit to apply `noise` to. @@ -339,8 +333,7 @@ def _apply_readout_noise( def _items_to_string( cls, instructions_title: str, instructions: list[NoiseModelInstruction] ) -> list[str]: - """ - Creates a string representation of a list of instructions. + """Creates a string representation of a list of instructions. Args: instructions_title (str): The title for this list of instructions. @@ -350,16 +343,15 @@ def _items_to_string( list[str]: A list of string representations of the passed instructions. """ results = [] - if len(instructions) > 0: + if instructions: results.append(instructions_title) - for item in instructions: - results.append(f" {item}") + results.extend(f" {item}" for item in instructions) return results @classmethod def from_dict(cls, noise_dict: dict) -> NoiseModel: - """ - Converts a dictionary representing an object of this class into an instance of this class. + """Converts a dictionary representing an object of this class into an instance + of this class. Args: noise_dict (dict): A dictionary representation of an object of this class. diff --git a/src/braket/circuits/noise_model/observable_criteria.py b/src/braket/circuits/noise_model/observable_criteria.py index 2275ccce7..5cb510f2e 100644 --- a/src/braket/circuits/noise_model/observable_criteria.py +++ b/src/braket/circuits/noise_model/observable_criteria.py @@ -33,8 +33,7 @@ def __init__( observables: Optional[Union[Observable, Iterable[Observable]]] = None, qubits: Optional[QubitSetInput] = None, ): - """ - Creates Observable-based Criteria. See instruction_matches() for more details. + """Creates Observable-based Criteria. See instruction_matches() for more details. Args: observables (Optional[Union[Observable, Iterable[Observable]]]): A set of relevant @@ -66,7 +65,8 @@ def __repr__(self): return f"{self.__class__.__name__}(observables={observables_names}, qubits={self._qubits})" def applicable_key_types(self) -> Iterable[CriteriaKey]: - """ + """Returns an Iterable of criteria keys. + Returns: Iterable[CriteriaKey]: This Criteria operates on Observables and Qubits. """ @@ -93,8 +93,8 @@ def get_keys(self, key_type: CriteriaKey) -> Union[CriteriaKeyResult, set[Any]]: return set() def to_dict(self) -> dict: - """ - Converts a dictionary representing an object of this class into an instance of this class. + """Converts a dictionary representing an object of this class into an instance of + this class. Returns: dict: A dictionary representing the serialized version of this Criteria. @@ -116,6 +116,7 @@ def result_type_matches(self, result_type: ResultType) -> bool: Args: result_type (ResultType): A result type or list of result types to match. + Returns: bool: Returns true if the result type is one of the Observables provided in the constructor and the target is a qubit (or set of qubits)provided in the constructor. @@ -131,9 +132,7 @@ def result_type_matches(self, result_type: ResultType) -> bool: if self._qubits is None: return True target = list(result_type.target) - if not target: - return True - return target[0] in self._qubits + return target[0] in self._qubits if target else True @classmethod def from_dict(cls, criteria: dict) -> Criteria: diff --git a/src/braket/circuits/noise_model/qubit_initialization_criteria.py b/src/braket/circuits/noise_model/qubit_initialization_criteria.py index abed13af8..26594ca60 100644 --- a/src/braket/circuits/noise_model/qubit_initialization_criteria.py +++ b/src/braket/circuits/noise_model/qubit_initialization_criteria.py @@ -24,8 +24,7 @@ class QubitInitializationCriteria(InitializationCriteria): """This class models initialization noise Criteria based on qubits.""" def __init__(self, qubits: Optional[QubitSetInput] = None): - """ - Creates initialization noise Qubit-based Criteria. + """Creates initialization noise Qubit-based Criteria. Args: qubits (Optional[QubitSetInput]): A set of relevant qubits. If no qubits @@ -40,7 +39,8 @@ def __repr__(self): return f"{self.__class__.__name__}(qubits={self._qubits})" def applicable_key_types(self) -> Iterable[CriteriaKey]: - """ + """Gets the QUBIT criteria key. + Returns: Iterable[CriteriaKey]: This Criteria operates on Qubits, but is valid for all Gates. """ @@ -54,19 +54,17 @@ def get_keys(self, key_type: CriteriaKey) -> Union[CriteriaKeyResult, set[Any]]: Returns: Union[CriteriaKeyResult, set[Any]]: The return value is based on the key type: - QUBIT will return a set of qubit targets that are relevant to this Critera, or + QUBIT will return a set of qubit targets that are relevant to this Criteria, or CriteriaKeyResult.ALL if the Criteria is relevant for all (possible) qubits. All other keys will return an empty set. """ if key_type == CriteriaKey.QUBIT: - if self._qubits is None: - return CriteriaKeyResult.ALL - return set(self._qubits) + return CriteriaKeyResult.ALL if self._qubits is None else set(self._qubits) return set() def to_dict(self) -> dict: - """ - Converts a dictionary representing an object of this class into an instance of this class. + """Converts a dictionary representing an object of this class into an instance of + this class. Returns: dict: A dictionary representing the serialized version of this Criteria. @@ -78,8 +76,7 @@ def to_dict(self) -> dict: } def qubit_intersection(self, qubits: QubitSetInput) -> QubitSetInput: - """ - Returns subset of passed qubits that match the criteria. + """Returns subset of passed qubits that match the criteria. Args: qubits (QubitSetInput): A qubit or set of qubits that may match the criteria. @@ -94,8 +91,7 @@ def qubit_intersection(self, qubits: QubitSetInput) -> QubitSetInput: @classmethod def from_dict(cls, criteria: dict) -> Criteria: - """ - Deserializes a dictionary into a Criteria object. + """Deserializes a dictionary into a Criteria object. Args: criteria (dict): A dictionary representation of a QubitCriteria. diff --git a/src/braket/circuits/noise_model/result_type_criteria.py b/src/braket/circuits/noise_model/result_type_criteria.py index 77c8d0d68..4d52c5c29 100644 --- a/src/braket/circuits/noise_model/result_type_criteria.py +++ b/src/braket/circuits/noise_model/result_type_criteria.py @@ -26,6 +26,7 @@ def result_type_matches(self, result_type: ResultType) -> bool: Args: result_type (ResultType): A result type or list of result types to match. + Returns: bool: True if the result type matches the criteria. """ diff --git a/src/braket/circuits/noise_model/unitary_gate_criteria.py b/src/braket/circuits/noise_model/unitary_gate_criteria.py index 229fc11ba..34b348e2e 100644 --- a/src/braket/circuits/noise_model/unitary_gate_criteria.py +++ b/src/braket/circuits/noise_model/unitary_gate_criteria.py @@ -26,14 +26,14 @@ class UnitaryGateCriteria(CircuitInstructionCriteria): """This class models noise Criteria based on unitary gates represented as a matrix.""" def __init__(self, unitary: Unitary, qubits: Optional[QubitSetInput] = None): - """ - Creates unitary gate-based Criteria. See instruction_matches() for more details. + """Creates unitary gate-based Criteria. See instruction_matches() for more details. Args: unitary (Unitary): A unitary gate matrix represented as a Braket Unitary. qubits (Optional[QubitSetInput]): A set of relevant qubits. If no qubits are provided, all (possible) qubits are considered to be relevant. - Throws: + + Raises: ValueError: If unitary is not a Unitary type. """ if not isinstance(unitary, Unitary): @@ -48,7 +48,8 @@ def __repr__(self): return f"{self.__class__.__name__}(unitary={self._unitary}, qubits={self._qubits})" def applicable_key_types(self) -> Iterable[CriteriaKey]: - """ + """Returns keys based on criterion. + Returns: Iterable[CriteriaKey]: This Criteria operates on unitary gates and Qubits. """ @@ -75,8 +76,8 @@ def get_keys(self, key_type: CriteriaKey) -> Union[CriteriaKeyResult, set[Any]]: return set() def to_dict(self) -> dict: - """ - Converts a dictionary representing an object of this class into an instance of this class. + """Converts a dictionary representing an object of this class into an instance of + this class. Returns: dict: A dictionary representing the serialized version of this Criteria. diff --git a/src/braket/circuits/noises.py b/src/braket/circuits/noises.py index 572489a9d..a8829f1a4 100644 --- a/src/braket/circuits/noises.py +++ b/src/braket/circuits/noises.py @@ -13,7 +13,7 @@ import itertools from collections.abc import Iterable -from typing import Any, Union +from typing import Any, ClassVar, Union import numpy as np @@ -52,7 +52,7 @@ class BitFlip(SingleProbabilisticNoise): - """Bit flip noise channel which transforms a density matrix :math:`\\rho` according to: + r"""Bit flip noise channel which transforms a density matrix :math:`\\rho` according to: .. math:: \\rho \\Rightarrow (1-p) \\rho + p X \\rho X^{\\dagger} @@ -96,6 +96,7 @@ def _to_openqasm( def to_matrix(self) -> Iterable[np.ndarray]: """Returns a matrix representation of this noise. + Returns: Iterable[ndarray]: A list of matrix representations of this noise. """ @@ -127,9 +128,11 @@ def bit_flip(target: QubitSetInput, probability: float) -> Iterable[Instruction] for qubit in QubitSet(target) ] - def bind_values(self, **kwargs) -> Noise: - """ - Takes in parameters and attempts to assign them to values. + def bind_values(self, **kwargs: Union[FreeParameter, str]) -> Noise: + """Takes in parameters and attempts to assign them to values. + + Args: + **kwargs (Union[FreeParameter, str]): Arbitrary keyword arguments. Returns: Noise: A new Noise object of the same type with the requested @@ -139,8 +142,7 @@ def bind_values(self, **kwargs) -> Noise: @classmethod def from_dict(cls, noise: dict) -> Noise: - """ - Converts a dictionary representation of this class into this class. + """Converts a dictionary representation of this class into this class. Args: noise(dict): The dictionary representation of this noise. @@ -155,7 +157,7 @@ def from_dict(cls, noise: dict) -> Noise: class PhaseFlip(SingleProbabilisticNoise): - """Phase flip noise channel which transforms a density matrix :math:`\\rho` according to: + r"""Phase flip noise channel which transforms a density matrix :math:`\\rho` according to: .. math:: \\rho \\Rightarrow (1-p) \\rho + p X \\rho X^{\\dagger} @@ -199,6 +201,7 @@ def _to_openqasm( def to_matrix(self) -> Iterable[np.ndarray]: """Returns a matrix representation of this noise. + Returns: Iterable[ndarray]: A list of matrix representations of this noise. """ @@ -230,9 +233,11 @@ def phase_flip(target: QubitSetInput, probability: float) -> Iterable[Instructio for qubit in QubitSet(target) ] - def bind_values(self, **kwargs) -> Noise: - """ - Takes in parameters and attempts to assign them to values. + def bind_values(self, **kwargs: Union[FreeParameter, str]) -> Noise: + """Takes in parameters and attempts to assign them to values. + + Args: + **kwargs (Union[FreeParameter, str]): Arbitrary keyword arguments. Returns: Noise: A new Noise object of the same type with the requested @@ -242,8 +247,7 @@ def bind_values(self, **kwargs) -> Noise: @classmethod def from_dict(cls, noise: dict) -> Noise: - """ - Converts a dictionary representation of this class into this class. + """Converts a dictionary representation of this class into this class. Args: noise(dict): The dictionary representation of this noise. @@ -258,7 +262,7 @@ def from_dict(cls, noise: dict) -> Noise: class PauliChannel(PauliNoise): - """Pauli noise channel which transforms a density matrix :math:`\\rho` according to: + r"""Pauli noise channel which transforms a density matrix :math:`\\rho` according to: .. math:: \\rho \\Rightarrow (1-probX-probY-probZ) \\rho @@ -307,6 +311,7 @@ def __init__( probZ: Union[FreeParameterExpression, float], ): """Creates PauliChannel noise. + Args: probX (Union[FreeParameterExpression, float]): X rotation probability. probY (Union[FreeParameterExpression, float]): Y rotation probability. @@ -336,6 +341,7 @@ def _to_openqasm( def to_matrix(self) -> Iterable[np.ndarray]: """Returns a matrix representation of this noise. + Returns: Iterable[ndarray]: A list of matrix representations of this noise. """ @@ -376,8 +382,7 @@ def pauli_channel( ] def bind_values(self, **kwargs) -> Noise: - """ - Takes in parameters and attempts to assign them to values. + """Takes in parameters and attempts to assign them to values. Returns: Noise: A new Noise object of the same type with the requested @@ -391,8 +396,7 @@ def bind_values(self, **kwargs) -> Noise: @classmethod def from_dict(cls, noise: dict) -> Noise: - """ - Converts a dictionary representation of this class into this class. + """Converts a dictionary representation of this class into this class. Args: noise(dict): The dictionary representation of this noise. @@ -411,7 +415,7 @@ def from_dict(cls, noise: dict) -> Noise: class Depolarizing(SingleProbabilisticNoise_34): - """Depolarizing noise channel which transforms a density matrix :math:`\\rho` according to: + r"""Depolarizing noise channel which transforms a density matrix :math:`\\rho` according to: .. math:: \\rho \\Rightarrow (1-p) \\rho @@ -473,6 +477,7 @@ def _to_openqasm( def to_matrix(self) -> Iterable[np.ndarray]: """Returns a matrix representation of this noise. + Returns: Iterable[ndarray]: A list of matrix representations of this noise. """ @@ -507,8 +512,7 @@ def depolarizing(target: QubitSetInput, probability: float) -> Iterable[Instruct ] def bind_values(self, **kwargs) -> Noise: - """ - Takes in parameters and attempts to assign them to values. + """Takes in parameters and attempts to assign them to values. Returns: Noise: A new Noise object of the same type with the requested @@ -518,8 +522,7 @@ def bind_values(self, **kwargs) -> Noise: @classmethod def from_dict(cls, noise: dict) -> Noise: - """ - Converts a dictionary representation of this class into this class. + """Converts a dictionary representation of this class into this class. Args: noise(dict): The dictionary representation of this noise. @@ -534,7 +537,7 @@ def from_dict(cls, noise: dict) -> Noise: class TwoQubitDepolarizing(SingleProbabilisticNoise_1516): - """Two-Qubit Depolarizing noise channel which transforms a + r"""Two-Qubit Depolarizing noise channel which transforms a density matrix :math:`\\rho` according to: .. math:: @@ -605,10 +608,10 @@ def _to_openqasm( def to_matrix(self) -> Iterable[np.ndarray]: """Returns a matrix representation of this noise. + Returns: Iterable[ndarray]: A list of matrix representations of this noise. """ - SI = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=complex) SX = np.array([[0.0, 1.0], [1.0, 0.0]], dtype=complex) SY = np.array([[0.0, -1.0j], [1.0j, 0.0]], dtype=complex) @@ -652,8 +655,7 @@ def two_qubit_depolarizing( ] def bind_values(self, **kwargs) -> Noise: - """ - Takes in parameters and attempts to assign them to values. + """Takes in parameters and attempts to assign them to values. Returns: Noise: A new Noise object of the same type with the requested @@ -663,8 +665,7 @@ def bind_values(self, **kwargs) -> Noise: @classmethod def from_dict(cls, noise: dict) -> Noise: - """ - Converts a dictionary representation of this class into this class. + """Converts a dictionary representation of this class into this class. Args: noise(dict): The dictionary representation of this noise. @@ -679,7 +680,7 @@ def from_dict(cls, noise: dict) -> Noise: class TwoQubitDephasing(SingleProbabilisticNoise_34): - """Two-Qubit Dephasing noise channel which transforms a + r"""Two-Qubit Dephasing noise channel which transforms a density matrix :math:`\\rho` according to: .. math:: @@ -732,6 +733,7 @@ def _to_openqasm( def to_matrix(self) -> Iterable[np.ndarray]: """Returns a matrix representation of this noise. + Returns: Iterable[ndarray]: A list of matrix representations of this noise. """ @@ -771,8 +773,7 @@ def two_qubit_dephasing( ] def bind_values(self, **kwargs) -> Noise: - """ - Takes in parameters and attempts to assign them to values. + """Takes in parameters and attempts to assign them to values. Returns: Noise: A new Noise object of the same type with the requested @@ -782,8 +783,7 @@ def bind_values(self, **kwargs) -> Noise: @classmethod def from_dict(cls, noise: dict) -> Noise: - """ - Converts a dictionary representation of this class into this class. + """Converts a dictionary representation of this class into this class. Args: noise(dict): The dictionary representation of this noise. @@ -798,7 +798,7 @@ def from_dict(cls, noise: dict) -> Noise: class TwoQubitPauliChannel(MultiQubitPauliNoise): - """Two-Qubit Pauli noise channel which transforms a + r"""Two-Qubit Pauli noise channel which transforms a density matrix :math:`\\rho` according to: .. math:: @@ -855,14 +855,14 @@ class TwoQubitPauliChannel(MultiQubitPauliNoise): This noise channel is shown as `PC_2({"pauli_string": probability})` in circuit diagrams. """ - _paulis = { + _paulis: ClassVar = { "I": np.array([[1.0, 0.0], [0.0, 1.0]], dtype=complex), "X": np.array([[0.0, 1.0], [1.0, 0.0]], dtype=complex), "Y": np.array([[0.0, -1.0j], [1.0j, 0.0]], dtype=complex), "Z": np.array([[1.0, 0.0], [0.0, -1.0]], dtype=complex), } _tensor_products_strings = itertools.product(_paulis.keys(), repeat=2) - _names_list = ["".join(x) for x in _tensor_products_strings] + _names_list: ClassVar = ["".join(x) for x in _tensor_products_strings] def __init__(self, probabilities: dict[str, float]): super().__init__( @@ -877,6 +877,7 @@ def __init__(self, probabilities: dict[str, float]): def to_matrix(self) -> Iterable[np.ndarray]: """Returns a matrix representation of this noise. + Returns: Iterable[ndarray]: A list of matrix representations of this noise. """ @@ -930,8 +931,7 @@ def two_qubit_pauli_channel( ] def bind_values(self, **kwargs) -> Noise: - """ - Takes in parameters and attempts to assign them to values. + """Takes in parameters and attempts to assign them to values. Returns: Noise: A new Noise object of the same type with the requested @@ -945,8 +945,7 @@ def bind_values(self, **kwargs) -> Noise: @classmethod def from_dict(cls, noise: dict) -> Noise: - """ - Converts a dictionary representation of this class into this class. + """Converts a dictionary representation of this class into this class. Args: noise(dict): The dictionary representation of this noise. @@ -954,9 +953,10 @@ def from_dict(cls, noise: dict) -> Noise: Returns: Noise: A Noise object that represents the passed in dictionary. """ - probabilities = dict() - for pauli_string, prob in noise["probabilities"].items(): - probabilities[pauli_string] = _parameter_from_dict(prob) + probabilities = { + pauli_string: _parameter_from_dict(prob) + for pauli_string, prob in noise["probabilities"].items() + } return TwoQubitPauliChannel(probabilities=probabilities) @@ -964,7 +964,7 @@ def from_dict(cls, noise: dict) -> Noise: class AmplitudeDamping(DampingNoise): - """AmplitudeDamping noise channel which transforms a density matrix :math:`\\rho` according to: + r"""AmplitudeDamping noise channel which transforms a density matrix :math:`\\rho` according to: .. math:: \\rho \\Rightarrow E_0 \\rho E_0^{\\dagger} + E_1 \\rho E_1^{\\dagger} @@ -1006,6 +1006,7 @@ def _to_openqasm( def to_matrix(self) -> Iterable[np.ndarray]: """Returns a matrix representation of this noise. + Returns: Iterable[ndarray]: A list of matrix representations of this noise. """ @@ -1038,8 +1039,7 @@ def amplitude_damping(target: QubitSetInput, gamma: float) -> Iterable[Instructi ] def bind_values(self, **kwargs) -> Noise: - """ - Takes in parameters and attempts to assign them to values. + """Takes in parameters and attempts to assign them to values. Returns: Noise: A new Noise object of the same type with the requested @@ -1049,8 +1049,7 @@ def bind_values(self, **kwargs) -> Noise: @classmethod def from_dict(cls, noise: dict) -> Noise: - """ - Converts a dictionary representation of this class into this class. + """Converts a dictionary representation of this class into this class. Args: noise(dict): The dictionary representation of this noise. @@ -1065,7 +1064,7 @@ def from_dict(cls, noise: dict) -> Noise: class GeneralizedAmplitudeDamping(GeneralizedAmplitudeDampingNoise): - """Generalized AmplitudeDamping noise channel which transforms a + r"""Generalized AmplitudeDamping noise channel which transforms a density matrix :math:`\\rho` according to: .. math:: \\rho \\Rightarrow E_0 \\rho E_0^{\\dagger} + E_1 \\rho E_1^{\\dagger} @@ -1131,6 +1130,7 @@ def _to_openqasm( def to_matrix(self) -> Iterable[np.ndarray]: """Returns a matrix representation of this noise. + Returns: Iterable[ndarray]: A list of matrix representations of this noise. """ @@ -1175,8 +1175,7 @@ def generalized_amplitude_damping( ] def bind_values(self, **kwargs) -> Noise: - """ - Takes in parameters and attempts to assign them to values. + """Takes in parameters and attempts to assign them to values. Returns: Noise: A new Noise object of the same type with the requested @@ -1188,8 +1187,7 @@ def bind_values(self, **kwargs) -> Noise: @classmethod def from_dict(cls, noise: dict) -> Noise: - """ - Converts a dictionary representation of this class into this class. + """Converts a dictionary representation of this class into this class. Args: noise(dict): The dictionary representation of this noise. @@ -1207,7 +1205,7 @@ def from_dict(cls, noise: dict) -> Noise: class PhaseDamping(DampingNoise): - """Phase damping noise channel which transforms a density matrix :math:`\\rho` according to: + r"""Phase damping noise channel which transforms a density matrix :math:`\\rho` according to: .. math:: \\rho \\Rightarrow E_0 \\rho E_0^{\\dagger} + E_1 \\rho E_1^{\\dagger} @@ -1251,6 +1249,7 @@ def _to_openqasm( def to_matrix(self) -> Iterable[np.ndarray]: """Returns a matrix representation of this noise. + Returns: Iterable[ndarray]: A list of matrix representations of this noise. """ @@ -1282,8 +1281,7 @@ def phase_damping(target: QubitSetInput, gamma: float) -> Iterable[Instruction]: ] def bind_values(self, **kwargs) -> Noise: - """ - Takes in parameters and attempts to assign them to values. + """Takes in parameters and attempts to assign them to values. Returns: Noise: A new Noise object of the same type with the requested @@ -1293,8 +1291,7 @@ def bind_values(self, **kwargs) -> Noise: @classmethod def from_dict(cls, noise: dict) -> Noise: - """ - Converts a dictionary representation of this class into this class. + """Converts a dictionary representation of this class into this class. Args: noise(dict): The dictionary representation of this noise. @@ -1311,24 +1308,27 @@ def from_dict(cls, noise: dict) -> Noise: class Kraus(Noise): """User-defined noise channel that uses the provided matrices as Kraus operators This noise channel is shown as `NK` in circuit diagrams. - - Args: - matrices (Iterable[np.array]): A list of matrices that define a noise - channel. These matrices need to satisfy the requirement of CPTP map. - display_name (str): Name to be used for an instance of this general noise - channel for circuit diagrams. Defaults to `KR`. - - Raises: - ValueError: If any matrix in `matrices` is not a two-dimensional square - matrix, - or has a dimension length which is not a positive exponent of 2, - or the `matrices` do not satisfy CPTP condition. """ def __init__(self, matrices: Iterable[np.ndarray], display_name: str = "KR"): + """Inits `Kraus`. + + Args: + matrices (Iterable[ndarray]): A list of matrices that define a noise + channel. These matrices need to satisfy the requirement of CPTP map. + display_name (str): Name to be used for an instance of this general noise + channel for circuit diagrams. Defaults to `KR`. + + Raises: + ValueError: If any matrix in `matrices` is not a two-dimensional square + matrix, + or has a dimension length which is not a positive exponent of 2, + or the `matrices` do not satisfy CPTP condition. + + """ for matrix in matrices: verify_quantum_operator_matrix_dimensions(matrix) - if not int(np.log2(matrix.shape[0])) == int(np.log2(matrices[0].shape[0])): + if int(np.log2(matrix.shape[0])) != int(np.log2(matrices[0].shape[0])): raise ValueError(f"all matrices in {matrices} must have the same shape") self._matrices = [np.array(matrix, dtype=complex) for matrix in matrices] self._display_name = display_name @@ -1347,6 +1347,7 @@ def __init__(self, matrices: Iterable[np.ndarray], display_name: str = "KR"): def to_matrix(self) -> Iterable[np.ndarray]: """Returns a matrix representation of this noise. + Returns: Iterable[ndarray]: A list of matrix representations of this noise. """ @@ -1354,7 +1355,7 @@ def to_matrix(self) -> Iterable[np.ndarray]: def _to_jaqcd(self, target: QubitSet) -> Any: return ir.Kraus.construct( - targets=[qubit for qubit in target], + targets=list(target), matrices=Kraus._transform_matrix_to_ir(self._matrices), ) @@ -1414,8 +1415,7 @@ def kraus( ) def to_dict(self) -> dict: - """ - Converts this object into a dictionary representation. Not implemented at this time. + """Converts this object into a dictionary representation. Not implemented at this time. Returns: dict: Not implemented at this time.. @@ -1424,8 +1424,7 @@ def to_dict(self) -> dict: @classmethod def from_dict(cls, noise: dict) -> Noise: - """ - Converts a dictionary representation of this class into this class. + """Converts a dictionary representation of this class into this class. Args: noise(dict): The dictionary representation of this noise. @@ -1442,8 +1441,7 @@ def from_dict(cls, noise: dict) -> Noise: def _ascii_representation( noise: str, parameters: list[Union[FreeParameterExpression, float]] ) -> str: - """ - Generates a formatted ascii representation of a noise. + """Generates a formatted ascii representation of a noise. Args: noise (str): The name of the noise. @@ -1452,11 +1450,10 @@ def _ascii_representation( Returns: str: The ascii representation of the noise. """ - param_list = [] - for param in parameters: - param_list.append( - str(param) if isinstance(param, FreeParameterExpression) else "{:.2g}".format(param) - ) + param_list = [ + (str(param) if isinstance(param, FreeParameterExpression) else f"{param:.2g}") + for param in parameters + ] param_str = ",".join(param_list) return f"{noise}({param_str})" diff --git a/src/braket/circuits/observable.py b/src/braket/circuits/observable.py index 03cd0714e..d3f3fc862 100644 --- a/src/braket/circuits/observable.py +++ b/src/braket/circuits/observable.py @@ -31,8 +31,7 @@ class Observable(QuantumOperator): - """ - Class `Observable` to represent a quantum observable. + """Class `Observable` to represent a quantum observable. Objects of this type can be used as input to `ResultType.Sample`, `ResultType.Variance`, `ResultType.Expectation` to specify the measurement basis. @@ -67,7 +66,7 @@ def to_ir( Raises: ValueError: If the supplied `ir_type` is not supported, or if the supplied serialization - properties don't correspond to the `ir_type`. + properties don't correspond to the `ir_type`. """ if ir_type == IRType.JAQCD: return self._to_jaqcd() @@ -94,8 +93,7 @@ def _to_openqasm( serialization_properties: OpenQASMSerializationProperties, target: QubitSet | None = None, ) -> str: - """ - Returns the openqasm string representation of the result type. + """Returns the openqasm string representation of the result type. Args: serialization_properties (OpenQASMSerializationProperties): The serialization properties @@ -109,7 +107,8 @@ def _to_openqasm( @property def coefficient(self) -> int: - """ + """The coefficient of the observable. + Returns: int: coefficient value of the observable. """ @@ -118,6 +117,7 @@ def coefficient(self) -> int: @property def basis_rotation_gates(self) -> tuple[Gate, ...]: """Returns the basis rotation gates for this observable. + Returns: tuple[Gate, ...]: The basis rotation gates for this observable. """ @@ -126,8 +126,9 @@ def basis_rotation_gates(self) -> tuple[Gate, ...]: @property def eigenvalues(self) -> np.ndarray: """Returns the eigenvalues of this observable. + Returns: - ndarray: The eigenvalues of this observable. + np.ndarray: The eigenvalues of this observable. """ raise NotImplementedError @@ -154,13 +155,13 @@ def register_observable(cls, observable: Observable) -> None: """ setattr(cls, observable.__name__, observable) - def __matmul__(self, other) -> Observable.TensorProduct: + def __matmul__(self, other: Observable) -> Observable.TensorProduct: if isinstance(other, Observable): return Observable.TensorProduct([self, other]) raise ValueError("Can only perform tensor products between observables.") - def __mul__(self, other) -> Observable: + def __mul__(self, other: Observable) -> Observable: """Scalar multiplication""" if isinstance(other, numbers.Number): observable_copy = deepcopy(self) @@ -168,16 +169,16 @@ def __mul__(self, other) -> Observable: return observable_copy raise TypeError("Observable coefficients must be numbers.") - def __rmul__(self, other) -> Observable: + def __rmul__(self, other: Observable) -> Observable: return self * other - def __add__(self, other): + def __add__(self, other: Observable): if not isinstance(other, Observable): raise ValueError("Can only perform addition between observables.") return Observable.Sum([self, other]) - def __sub__(self, other): + def __sub__(self, other: Observable): if not isinstance(other, Observable): raise ValueError("Can only perform subtraction between observables.") @@ -186,15 +187,14 @@ def __sub__(self, other): def __repr__(self) -> str: return f"{self.name}('qubit_count': {self.qubit_count})" - def __eq__(self, other) -> bool: + def __eq__(self, other: Observable) -> bool: if isinstance(other, Observable): return self.name == other.name return NotImplemented class StandardObservable(Observable): - """ - Class `StandardObservable` to represent a Pauli-like quantum observable with + """Class `StandardObservable` to represent a Pauli-like quantum observable with eigenvalues of (+1, -1). """ diff --git a/src/braket/circuits/observables.py b/src/braket/circuits/observables.py index dc0ff0782..ae4f36a09 100644 --- a/src/braket/circuits/observables.py +++ b/src/braket/circuits/observables.py @@ -18,7 +18,7 @@ import math import numbers from copy import deepcopy -from typing import Union +from typing import ClassVar, Union import numpy as np @@ -37,9 +37,8 @@ class H(StandardObservable): """Hadamard operation as an observable.""" def __init__(self): - """ - Examples: - >>> Observable.H() + """Examples: + >>> Observable.H() """ super().__init__(ascii_symbols=["H"]) @@ -68,19 +67,18 @@ def to_matrix(self) -> np.ndarray: @property def basis_rotation_gates(self) -> tuple[Gate, ...]: - return tuple([Gate.Ry(-math.pi / 4)]) + return (Gate.Ry(-math.pi / 4),) Observable.register_observable(H) -class I(Observable): # noqa: E742, E261 +class I(Observable): # noqa: E742 """Identity operation as an observable.""" def __init__(self): - """ - Examples: - >>> Observable.I() + """Examples: + >>> Observable.I() """ super().__init__(qubit_count=1, ascii_symbols=["I"]) @@ -112,6 +110,7 @@ def basis_rotation_gates(self) -> tuple[Gate, ...]: @property def eigenvalues(self) -> np.ndarray: """Returns the eigenvalues of this observable. + Returns: np.ndarray: The eigenvalues of this observable. """ @@ -128,9 +127,8 @@ class X(StandardObservable): """Pauli-X operation as an observable.""" def __init__(self): - """ - Examples: - >>> Observable.X() + """Examples: + >>> Observable.X() """ super().__init__(ascii_symbols=["X"]) @@ -157,7 +155,7 @@ def to_matrix(self) -> np.ndarray: @property def basis_rotation_gates(self) -> tuple[Gate, ...]: - return tuple([Gate.H()]) + return (Gate.H(),) Observable.register_observable(X) @@ -167,9 +165,8 @@ class Y(StandardObservable): """Pauli-Y operation as an observable.""" def __init__(self): - """ - Examples: - >>> Observable.Y() + """Examples: + >>> Observable.Y() """ super().__init__(ascii_symbols=["Y"]) @@ -196,7 +193,7 @@ def to_matrix(self) -> np.ndarray: @property def basis_rotation_gates(self) -> tuple[Gate, ...]: - return tuple([Gate.Z(), Gate.S(), Gate.H()]) + return Gate.Z(), Gate.S(), Gate.H() Observable.register_observable(Y) @@ -206,9 +203,8 @@ class Z(StandardObservable): """Pauli-Z operation as an observable.""" def __init__(self): - """ - Examples: - >>> Observable.Z() + """Examples: + >>> Observable.Z() """ super().__init__(ascii_symbols=["Z"]) @@ -245,7 +241,8 @@ class TensorProduct(Observable): """Tensor product of observables""" def __init__(self, observables: list[Observable]): - """ + """Initializes a `TensorProduct`. + Args: observables (list[Observable]): List of observables for tensor product @@ -269,15 +266,14 @@ def __init__(self, observables: list[Observable]): flattened_observables = [] for obs in observables: if isinstance(obs, TensorProduct): - for nested_obs in obs.factors: - flattened_observables.append(nested_obs) + flattened_observables.extend(iter(obs.factors)) # make sure you don't lose coefficient of tensor product flattened_observables[-1] *= obs.coefficient elif isinstance(obs, Sum): raise TypeError("Sum observables not allowed in TensorProduct") else: flattened_observables.append(obs) - qubit_count = sum([obs.qubit_count for obs in flattened_observables]) + qubit_count = sum(obs.qubit_count for obs in flattened_observables) # aggregate all coefficients for the product, since aX @ bY == ab * X @ Y coefficient = np.prod([obs.coefficient for obs in flattened_observables]) unscaled_factors = tuple(obs._unscaled() for obs in flattened_observables) @@ -348,6 +344,7 @@ def to_matrix(self) -> np.ndarray: @property def basis_rotation_gates(self) -> tuple[Gate, ...]: """Returns the basis rotation gates for this observable. + Returns: tuple[Gate, ...]: The basis rotation gates for this observable. """ @@ -359,6 +356,7 @@ def basis_rotation_gates(self) -> tuple[Gate, ...]: @property def eigenvalues(self) -> np.ndarray: """Returns the eigenvalues of this observable. + Returns: np.ndarray: The eigenvalues of this observable. """ @@ -400,7 +398,7 @@ def eigenvalue(self, index: int) -> float: def __repr__(self): return "TensorProduct(" + ", ".join([repr(o) for o in self.factors]) + ")" - def __eq__(self, other): + def __eq__(self, other: TensorProduct): return self.matrix_equivalence(other) @staticmethod @@ -432,7 +430,8 @@ class Sum(Observable): """Sum of observables""" def __init__(self, observables: list[Observable], display_name: str = "Hamiltonian"): - """ + """Inits a `Sum`. + Args: observables (list[Observable]): List of observables for Sum display_name (str): Name to use for an instance of this Sum @@ -447,8 +446,7 @@ def __init__(self, observables: list[Observable], display_name: str = "Hamiltoni flattened_observables = [] for obs in observables: if isinstance(obs, Sum): - for nested_obs in obs.summands: - flattened_observables.append(nested_obs) + flattened_observables.extend(iter(obs.summands)) else: flattened_observables.append(obs) @@ -456,11 +454,11 @@ def __init__(self, observables: list[Observable], display_name: str = "Hamiltoni qubit_count = max(flattened_observables, key=lambda obs: obs.qubit_count).qubit_count super().__init__(qubit_count=qubit_count, ascii_symbols=[display_name] * qubit_count) - def __mul__(self, other) -> Observable: + def __mul__(self, other: numbers.Number) -> Observable: """Scalar multiplication""" if isinstance(other, numbers.Number): sum_copy = deepcopy(self) - for i, obs in enumerate(sum_copy.summands): + for i, _obs in enumerate(sum_copy.summands): sum_copy._summands[i]._coef *= other return sum_copy raise TypeError("Observable coefficients must be numbers.") @@ -514,7 +512,7 @@ def eigenvalue(self, index: int) -> float: def __repr__(self): return "Sum(" + ", ".join([repr(o) for o in self.summands]) + ")" - def __eq__(self, other): + def __eq__(self, other: Sum): return repr(self) == repr(other) @staticmethod @@ -529,12 +527,13 @@ class Hermitian(Observable): """Hermitian matrix as an observable.""" # Cache of eigenpairs - _eigenpairs = {} + _eigenpairs: ClassVar = {} def __init__(self, matrix: np.ndarray, display_name: str = "Hermitian"): - """ + """Inits a `Hermitian`. + Args: - matrix (numpy.ndarray): Hermitian matrix that defines the observable. + matrix (np.ndarray): Hermitian matrix that defines the observable. display_name (str): Name to use for an instance of this Hermitian matrix observable for circuit diagrams. Defaults to `Hermitian`. @@ -594,7 +593,7 @@ def _serialized_matrix_openqasm_matrix(self) -> str: def to_matrix(self) -> np.ndarray: return self.coefficient * self._matrix - def __eq__(self, other) -> bool: + def __eq__(self, other: Hermitian) -> bool: return self.matrix_equivalence(other) @property @@ -604,6 +603,7 @@ def basis_rotation_gates(self) -> tuple[Gate, ...]: @property def eigenvalues(self) -> np.ndarray: """Returns the eigenvalues of this observable. + Returns: np.ndarray: The eigenvalues of this observable. """ @@ -614,8 +614,7 @@ def eigenvalue(self, index: int) -> float: @staticmethod def _get_eigendecomposition(matrix: np.ndarray) -> dict[str, np.ndarray]: - """ - Decomposes the Hermitian matrix into its eigenvectors and associated eigenvalues. + """Decomposes the Hermitian matrix into its eigenvectors and associated eigenvalues. The eigendecomposition is cached so that if another Hermitian observable is created with the same matrix, the eigendecomposition doesn't have to be recalculated. @@ -649,8 +648,7 @@ def __repr__(self): def observable_from_ir(ir_observable: list[Union[str, list[list[list[float]]]]]) -> Observable: - """ - Create an observable from the IR observable list. This can be a tensor product of + """Create an observable from the IR observable list. This can be a tensor product of observables or a single observable. Args: @@ -661,9 +659,8 @@ def observable_from_ir(ir_observable: list[Union[str, list[list[list[float]]]]]) """ if len(ir_observable) == 1: return _observable_from_ir_list_item(ir_observable[0]) - else: - observable = TensorProduct([_observable_from_ir_list_item(obs) for obs in ir_observable]) - return observable + observable = TensorProduct([_observable_from_ir_list_item(obs) for obs in ir_observable]) + return observable def _observable_from_ir_list_item(observable: Union[str, list[list[list[float]]]]) -> Observable: @@ -684,4 +681,4 @@ def _observable_from_ir_list_item(observable: Union[str, list[list[list[float]]] ) return Hermitian(matrix) except Exception as e: - raise ValueError(f"Invalid observable specified: {observable} error: {e}") + raise ValueError(f"Invalid observable specified: {observable} error: {e}") from e diff --git a/src/braket/circuits/operator.py b/src/braket/circuits/operator.py index 06e72c8a8..ccd63ac37 100644 --- a/src/braket/circuits/operator.py +++ b/src/braket/circuits/operator.py @@ -22,14 +22,14 @@ class Operator(ABC): @abstractmethod def name(self) -> str: """The name of the operator. + Returns: str: The name of the operator. """ @abstractmethod def to_ir(self, *args, **kwargs) -> Any: - """ - Converts the operator into the canonical intermediate representation. + """Converts the operator into the canonical intermediate representation. If the operator is passed in a request, this method is called before it is passed. Returns: diff --git a/src/braket/circuits/quantum_operator.py b/src/braket/circuits/quantum_operator.py index df67bf199..b706e1822 100644 --- a/src/braket/circuits/quantum_operator.py +++ b/src/braket/circuits/quantum_operator.py @@ -25,7 +25,8 @@ class QuantumOperator(Operator): """A quantum operator is the definition of a quantum operation for a quantum device.""" def __init__(self, qubit_count: Optional[int], ascii_symbols: Sequence[str]): - """ + """Initializes a `QuantumOperator`. + Args: qubit_count (Optional[int]): Number of qubits this quantum operator acts on. If all instances of the operator act on the same number of qubits, this argument @@ -48,16 +49,15 @@ def __init__(self, qubit_count: Optional[int], ascii_symbols: Sequence[str]): ``fixed_qubit_count`` is implemented and and not equal to ``qubit_count``, or ``len(ascii_symbols) != qubit_count`` """ - fixed_qubit_count = self.fixed_qubit_count() if fixed_qubit_count is NotImplemented: self._qubit_count = qubit_count + elif qubit_count and qubit_count != fixed_qubit_count: + raise ValueError( + f"Provided qubit count {qubit_count}" + "does not equal fixed qubit count {fixed_qubit_count}" + ) else: - if qubit_count and qubit_count != fixed_qubit_count: - raise ValueError( - f"Provided qubit count {qubit_count}" - "does not equal fixed qubit count {fixed_qubit_count}" - ) self._qubit_count = fixed_qubit_count if not isinstance(self._qubit_count, int): @@ -79,8 +79,7 @@ def __init__(self, qubit_count: Optional[int], ascii_symbols: Sequence[str]): @staticmethod def fixed_qubit_count() -> int: - """ - Returns the number of qubits this quantum operator acts on, + """Returns the number of qubits this quantum operator acts on, if instances are guaranteed to act on the same number of qubits. If different instances can act on a different number of qubits, @@ -103,33 +102,45 @@ def ascii_symbols(self) -> tuple[str, ...]: @property def name(self) -> str: - """ - Returns the name of the quantum operator + """Returns the name of the quantum operator Returns: str: The name of the quantum operator as a string """ return self.__class__.__name__ - def to_ir(self, *args, **kwargs) -> Any: + def to_ir(self, *args: Any, **kwargs: Any) -> Any: """Returns IR representation of quantum operator. + Args: + *args (Any): Not Implemented. + **kwargs (Any): Not Implemented. + + Raises: + NotImplementError: Not Implemented. + Returns: Any: The the canonical intermediate representation of the operator. """ raise NotImplementedError("to_ir has not been implemented yet.") - def to_matrix(self, *args, **kwargs) -> np.ndarray: - """Returns a matrix representation of the quantum operator + def to_matrix(self, *args: Any, **kwargs: Any) -> np.ndarray: + """Returns a matrix representation of the quantum operator. + + Args: + *args (Any): Not Implemented. + **kwargs (Any): Not Implemented. + + Raises: + NotImplementError: Not Implemented. Returns: - ndarray: A matrix representation of the quantum operator + np.ndarray: A matrix representation of the quantum operator """ raise NotImplementedError("to_matrix has not been implemented yet.") def matrix_equivalence(self, other: QuantumOperator) -> bool: - """ - Whether the matrix form of two quantum operators are equivalent + """Whether the matrix form of two quantum operators are equivalent Args: other (QuantumOperator): Quantum operator instance to compare this quantum operator to diff --git a/src/braket/circuits/quantum_operator_helpers.py b/src/braket/circuits/quantum_operator_helpers.py index a264d0b38..10c22808e 100644 --- a/src/braket/circuits/quantum_operator_helpers.py +++ b/src/braket/circuits/quantum_operator_helpers.py @@ -18,8 +18,7 @@ def verify_quantum_operator_matrix_dimensions(matrix: np.ndarray) -> None: - """ - Verifies matrix is square and matrix dimensions are positive powers of 2, + """Verifies matrix is square and matrix dimensions are positive powers of 2, raising `ValueError` otherwise. Args: @@ -40,8 +39,7 @@ def verify_quantum_operator_matrix_dimensions(matrix: np.ndarray) -> None: def is_hermitian(matrix: np.ndarray) -> bool: - r""" - Whether matrix is Hermitian + r"""Whether matrix is Hermitian A square matrix :math:`U` is Hermitian if @@ -59,11 +57,10 @@ def is_hermitian(matrix: np.ndarray) -> bool: def is_square_matrix(matrix: np.ndarray) -> bool: - """ - Whether matrix is square, meaning it has exactly two dimensions and the dimensions are equal + """Whether matrix is square, meaning it has exactly two dimensions and the dimensions are equal Args: - matrix (ndarray): matrix to verify + matrix (np.ndarray): matrix to verify Returns: bool: If matrix is square @@ -72,8 +69,7 @@ def is_square_matrix(matrix: np.ndarray) -> bool: def is_unitary(matrix: np.ndarray) -> bool: - r""" - Whether matrix is unitary + r"""Whether matrix is unitary A square matrix :math:`U` is unitary if @@ -83,7 +79,7 @@ def is_unitary(matrix: np.ndarray) -> bool: and :math:`I` is the identity matrix. Args: - matrix (ndarray): matrix to verify + matrix (np.ndarray): matrix to verify Returns: bool: If matrix is unitary @@ -92,8 +88,7 @@ def is_unitary(matrix: np.ndarray) -> bool: def is_cptp(matrices: Iterable[np.ndarray]) -> bool: - """ - Whether a transformation defined by these matrics as Kraus operators is a + """Whether a transformation defined by these matrices as Kraus operators is a completely positive trace preserving (CPTP) map. This is the requirement for a transformation to be a quantum channel. Reference: Section 8.2.3 in Nielsen & Chuang (2010) 10th edition. @@ -104,21 +99,20 @@ def is_cptp(matrices: Iterable[np.ndarray]) -> bool: Returns: bool: If the matrices define a CPTP map. """ - E = sum([np.dot(matrix.T.conjugate(), matrix) for matrix in matrices]) + E = sum(np.dot(matrix.T.conjugate(), matrix) for matrix in matrices) return np.allclose(E, np.eye(*E.shape)) -@lru_cache() +@lru_cache def get_pauli_eigenvalues(num_qubits: int) -> np.ndarray: - """ - Get the eigenvalues of Pauli operators and their tensor products as + """Get the eigenvalues of Pauli operators and their tensor products as an immutable Numpy ndarray. Args: num_qubits (int): the number of qubits the operator acts on Returns: - ndarray: the eigenvalues of a Pauli product operator of the given size + np.ndarray: the eigenvalues of a Pauli product operator of the given size """ if num_qubits == 1: eigs = np.array([1, -1]) diff --git a/src/braket/circuits/result_type.py b/src/braket/circuits/result_type.py index c6877262e..3e9f0dfad 100644 --- a/src/braket/circuits/result_type.py +++ b/src/braket/circuits/result_type.py @@ -28,14 +28,14 @@ class ResultType: - """ - Class `ResultType` represents a requested result type for the circuit. + """Class `ResultType` represents a requested result type for the circuit. This class is considered the result type definition containing the metadata that defines what a requested result type is and what it does. """ def __init__(self, ascii_symbols: list[str]): - """ + """Initializes a `ResultType`. + Args: ascii_symbols (list[str]): ASCII string symbols for the result type. This is used when printing a diagram of circuits. @@ -43,7 +43,6 @@ def __init__(self, ascii_symbols: list[str]): Raises: ValueError: `ascii_symbols` is `None` """ - if ascii_symbols is None: raise ValueError("ascii_symbols must not be None") @@ -56,8 +55,7 @@ def ascii_symbols(self) -> list[str]: @property def name(self) -> str: - """ - Returns the name of the result type + """Returns the name of the result type Returns: str: The name of the result type as a string @@ -73,7 +71,7 @@ def to_ir( """Returns IR object of the result type Args: - ir_type(IRType) : The IRType to use for converting the result type object to its + ir_type(IRType): The IRType to use for converting the result type object to its IR representation. Defaults to IRType.JAQCD. serialization_properties (SerializationProperties | None): The serialization properties to use while serializing the object to the IR representation. The serialization @@ -84,7 +82,7 @@ def to_ir( Raises: ValueError: If the supplied `ir_type` is not supported, or if the supplied serialization - properties don't correspond to the `ir_type`. + properties don't correspond to the `ir_type`. """ if ir_type == IRType.JAQCD: return self._to_jaqcd() @@ -105,13 +103,15 @@ def _to_jaqcd(self) -> Any: raise NotImplementedError("to_jaqcd has not been implemented yet.") def _to_openqasm(self, serialization_properties: OpenQASMSerializationProperties) -> str: - """ - Returns the openqasm string representation of the result type. + """Returns the openqasm string representation of the result type. Args: serialization_properties (OpenQASMSerializationProperties): The serialization properties to use while serializing the object to the IR representation. + Raises: + NotImplementedError: not implemented. + Returns: str: Representing the openqasm representation of the result type. """ @@ -122,8 +122,7 @@ def copy( target_mapping: dict[QubitInput, QubitInput] | None = None, target: QubitSetInput | None = None, ) -> ResultType: - """ - Return a shallow copy of the result type. + """Return a shallow copy of the result type. Note: If `target_mapping` is specified, then `self.target` is mapped to the specified @@ -180,8 +179,7 @@ def __hash__(self) -> int: class ObservableResultType(ResultType): - """ - Result types with observables and targets. + """Result types with observables and targets. If no targets are specified, the observable must only operate on 1 qubit and it will be applied to all qubits in parallel. Otherwise, the number of specified targets must be equivalent to the number of qubits the observable can be applied to. @@ -192,7 +190,8 @@ class ObservableResultType(ResultType): def __init__( self, ascii_symbols: list[str], observable: Observable, target: QubitSetInput | None = None ): - """ + """Initializes an `ObservableResultType`. + Args: ascii_symbols (list[str]): ASCII string symbols for the result type. This is used when printing a diagram of circuits. @@ -215,29 +214,28 @@ def __init__( raise ValueError( f"Observable {self._observable} must only operate on 1 qubit for target=None" ) - else: - if isinstance(observable, Sum): # nested target - if len(target) != len(observable.summands): + elif isinstance(observable, Sum): # nested target + if len(target) != len(observable.summands): + raise ValueError( + "Sum observable's target shape must be a nested list where each term's " + "target length is equal to the observable term's qubits count." + ) + self._target = [QubitSet(term_target) for term_target in target] + for term_target, obs in zip(self._target, observable.summands): + if obs.qubit_count != len(term_target): raise ValueError( "Sum observable's target shape must be a nested list where each term's " "target length is equal to the observable term's qubits count." ) - self._target = [QubitSet(term_target) for term_target in target] - for term_target, obs in zip(target, observable.summands): - if obs.qubit_count != len(term_target): - raise ValueError( - "Sum observable's target shape must be a nested list where each term's " - "target length is equal to the observable term's qubits count." - ) - elif self._observable.qubit_count != len(self._target): - raise ValueError( - f"Observable's qubit count {self._observable.qubit_count} and " - f"the size of the target qubit set {self._target} must be equal" - ) - elif self._observable.qubit_count != len(self.ascii_symbols): - raise ValueError( - "Observable's qubit count and the number of ASCII symbols must be equal" - ) + elif self._observable.qubit_count != len(self._target): + raise ValueError( + f"Observable's qubit count {self._observable.qubit_count} and " + f"the size of the target qubit set {self._target} must be equal" + ) + elif self._observable.qubit_count != len(self.ascii_symbols): + raise ValueError( + "Observable's qubit count and the number of ASCII symbols must be equal" + ) @property def observable(self) -> Observable: @@ -250,12 +248,13 @@ def target(self) -> QubitSet: @target.setter def target(self, target: QubitSetInput) -> None: """Sets the target. + Args: target (QubitSetInput): The new target. """ self._target = QubitSet(target) - def __eq__(self, other) -> bool: + def __eq__(self, other: ObservableResultType) -> bool: if isinstance(other, ObservableResultType): return ( self.name == other.name @@ -275,8 +274,7 @@ def __hash__(self) -> int: class ObservableParameterResultType(ObservableResultType): - """ - Result types with observables, targets and parameters. + """Result types with observables, targets and parameters. If no targets are specified, the observable must only operate on 1 qubit and it will be applied to all qubits in parallel. Otherwise, the number of specified targets must be equivalent to the number of qubits the observable can be applied to. diff --git a/src/braket/circuits/result_types.py b/src/braket/circuits/result_types.py index 0a1a1c630..325fa8f46 100644 --- a/src/braket/circuits/result_types.py +++ b/src/braket/circuits/result_types.py @@ -41,8 +41,7 @@ class StateVector(ResultType): - """ - The full state vector as a requested result type. + """The full state vector as a requested result type. This is available on simulators only when `shots=0`. """ @@ -68,10 +67,8 @@ def state_vector() -> ResultType: """ return ResultType.StateVector() - def __eq__(self, other) -> bool: - if isinstance(other, StateVector): - return True - return False + def __eq__(self, other: StateVector) -> bool: + return isinstance(other, StateVector) def __copy__(self) -> StateVector: return type(self)() @@ -86,13 +83,13 @@ def __hash__(self) -> int: class DensityMatrix(ResultType): - """ - The full density matrix as a requested result type. + """The full density matrix as a requested result type. This is available on simulators only when `shots=0`. """ def __init__(self, target: QubitSetInput | None = None): - """ + """Inits a `DensityMatrix`. + Args: target (QubitSetInput | None): The target qubits of the reduced density matrix. Default is `None`, and the @@ -112,6 +109,7 @@ def target(self) -> QubitSet: @target.setter def target(self, target: QubitSetInput) -> None: """Sets the target qubit set. + Args: target (QubitSetInput): The target qubit set. """ @@ -136,6 +134,7 @@ def _to_openqasm(self, serialization_properties: OpenQASMSerializationProperties @circuit.subroutine(register=True) def density_matrix(target: QubitSetInput | None = None) -> ResultType: """Registers this function into the circuit class. + Args: target (QubitSetInput | None): The target qubits of the reduced density matrix. Default is `None`, and the @@ -149,7 +148,7 @@ def density_matrix(target: QubitSetInput | None = None) -> ResultType: """ return ResultType.DensityMatrix(target=target) - def __eq__(self, other) -> bool: + def __eq__(self, other: DensityMatrix) -> bool: if isinstance(other, DensityMatrix): return self.target == other.target return False @@ -170,8 +169,7 @@ def __hash__(self) -> int: class AdjointGradient(ObservableParameterResultType): - """ - The gradient of the expectation value of the provided observable, applied to target, + """The gradient of the expectation value of the provided observable, applied to target, with respect to the given parameter. """ @@ -181,7 +179,8 @@ def __init__( target: list[QubitSetInput] | None = None, parameters: list[Union[str, FreeParameter]] | None = None, ): - """ + """Inits an `AdjointGradient`. + Args: observable (Observable): The expectation value of this observable is the function against which parameters in the gradient are differentiated. @@ -197,6 +196,7 @@ def __init__( ValueError: If the observable's qubit count does not equal the number of target qubits, or if `target=None` and the observable's qubit count is not 1. + Examples: >>> ResultType.AdjointGradient(observable=Observable.Z(), target=0, parameters=["alpha", "beta"]) @@ -209,7 +209,6 @@ def __init__( >>> parameters=["alpha", "beta"], >>> ) """ - if isinstance(observable, Sum): target_qubits = reduce(QubitSet.union, map(QubitSet, target), QubitSet()) else: @@ -274,13 +273,13 @@ def adjoint_gradient( class Amplitude(ResultType): - """ - The amplitude of the specified quantum states as a requested result type. + """The amplitude of the specified quantum states as a requested result type. This is available on simulators only when `shots=0`. """ def __init__(self, state: list[str]): - """ + """Initializes an `Amplitude`. + Args: state (list[str]): list of quantum states as strings with "0" and "1" @@ -332,10 +331,8 @@ def amplitude(state: list[str]) -> ResultType: """ return ResultType.Amplitude(state=state) - def __eq__(self, other): - if isinstance(other, Amplitude): - return self.state == other.state - return False + def __eq__(self, other: Amplitude): + return self.state == other.state if isinstance(other, Amplitude) else False def __repr__(self): return f"Amplitude(state={self.state})" @@ -362,7 +359,8 @@ class Probability(ResultType): """ def __init__(self, target: QubitSetInput | None = None): - """ + """Inits a `Probability`. + Args: target (QubitSetInput | None): The target qubits that the result type is requested for. Default is `None`, which means all qubits for the @@ -382,6 +380,7 @@ def target(self) -> QubitSet: @target.setter def target(self, target: QubitSetInput) -> None: """Sets the target qubit set. + Args: target (QubitSetInput): The target qubit set. """ @@ -420,10 +419,8 @@ def probability(target: QubitSetInput | None = None) -> ResultType: """ return ResultType.Probability(target=target) - def __eq__(self, other) -> bool: - if isinstance(other, Probability): - return self.target == other.target - return False + def __eq__(self, other: Probability) -> bool: + return self.target == other.target if isinstance(other, Probability) else False def __repr__(self) -> str: return f"Probability(target={self.target})" @@ -452,16 +449,14 @@ class Expectation(ObservableResultType): """ def __init__(self, observable: Observable, target: QubitSetInput | None = None): - """ + """Inits an `Expectation`. + Args: observable (Observable): the observable for the result type target (QubitSetInput | None): Target qubits that the result type is requested for. Default is `None`, which means the observable must operate only on 1 qubit and it is applied to all qubits in parallel. - Raises: - ValueError: If the observable's qubit count does not equal the number of target - qubits, or if `target=None` and the observable's qubit count is not 1. Examples: >>> ResultType.Expectation(observable=Observable.Z(), target=0) @@ -527,17 +522,14 @@ class Sample(ObservableResultType): """ def __init__(self, observable: Observable, target: QubitSetInput | None = None): - """ + """Inits a `Sample`. + Args: observable (Observable): the observable for the result type target (QubitSetInput | None): Target qubits that the result type is requested for. Default is `None`, which means the observable must operate only on 1 qubit and it is applied to all qubits in parallel. - Raises: - ValueError: If the observable's qubit count is not equal to the number of target - qubits, or if `target=None` and the observable's qubit count is not 1. - Examples: >>> ResultType.Sample(observable=Observable.Z(), target=0) @@ -603,7 +595,8 @@ class Variance(ObservableResultType): """ def __init__(self, observable: Observable, target: QubitSetInput | None = None): - """ + """Inits a `Variance`. + Args: observable (Observable): the observable for the result type target (QubitSetInput | None): Target qubits that the diff --git a/src/braket/circuits/serialization.py b/src/braket/circuits/serialization.py index 1e0826e80..fdee7d144 100644 --- a/src/braket/circuits/serialization.py +++ b/src/braket/circuits/serialization.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum @@ -23,8 +24,7 @@ class IRType(str, Enum): class QubitReferenceType(str, Enum): - """ - Defines how qubits should be referenced in the generated OpenQASM string. + """Defines how qubits should be referenced in the generated OpenQASM string. See https://qiskit.github.io/openqasm/language/types.html#quantum-types for details. """ @@ -33,10 +33,29 @@ class QubitReferenceType(str, Enum): PHYSICAL = "PHYSICAL" +class SerializableProgram(ABC): + @abstractmethod + def to_ir( + self, + ir_type: IRType = IRType.OPENQASM, + ) -> str: + """Serializes the program into an intermediate representation. + + Args: + ir_type (IRType): The IRType to use for converting the program to its + IR representation. Defaults to IRType.OPENQASM. + + Raises: + ValueError: Raised if the supplied `ir_type` is not supported. + + Returns: + str: A representation of the program in the `ir_type` format. + """ + + @dataclass class OpenQASMSerializationProperties: - """ - Properties for serializing a circuit to OpenQASM. + """Properties for serializing a circuit to OpenQASM. qubit_reference_type (QubitReferenceType): determines whether to use logical qubits or physical qubits (q[i] vs $i). @@ -46,6 +65,7 @@ class OpenQASMSerializationProperties: def format_target(self, target: int) -> str: """Format a target qubit to the appropriate OpenQASM representation. + Args: target (int): The target qubit. diff --git a/src/braket/circuits/text_diagram_builders/ascii_circuit_diagram.py b/src/braket/circuits/text_diagram_builders/ascii_circuit_diagram.py new file mode 100644 index 000000000..4a7c9565c --- /dev/null +++ b/src/braket/circuits/text_diagram_builders/ascii_circuit_diagram.py @@ -0,0 +1,197 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from __future__ import annotations + +from functools import reduce +from typing import Literal, Union + +import braket.circuits.circuit as cir +from braket.circuits.compiler_directive import CompilerDirective +from braket.circuits.gate import Gate +from braket.circuits.instruction import Instruction +from braket.circuits.result_type import ResultType +from braket.circuits.text_diagram_builders.text_circuit_diagram import TextCircuitDiagram +from braket.registers.qubit_set import QubitSet + + +class AsciiCircuitDiagram(TextCircuitDiagram): + """Builds ASCII string circuit diagrams.""" + + @staticmethod + def build_diagram(circuit: cir.Circuit) -> str: + """Build a text circuit diagram. + + Args: + circuit (Circuit): Circuit for which to build a diagram. + + Returns: + str: string circuit diagram. + """ + return AsciiCircuitDiagram._build(circuit) + + @classmethod + def _vertical_delimiter(cls) -> str: + """Character that connects qubits of multi-qubit gates.""" + return "|" + + @classmethod + def _qubit_line_character(cls) -> str: + """Character used for the qubit line.""" + return "-" + + @classmethod + def _box_pad(cls) -> int: + """number of blank space characters around the gate name.""" + return 0 + + @classmethod + def _qubit_line_spacing_above(cls) -> int: + """number of empty lines above the qubit line.""" + return 1 + + @classmethod + def _qubit_line_spacing_below(cls) -> int: + """number of empty lines below the qubit line.""" + return 0 + + @classmethod + def _duplicate_time_at_bottom(cls, lines: str) -> None: + # duplicate times after an empty line + lines.append(lines[0]) + + @classmethod + def _create_diagram_column( + cls, + circuit_qubits: QubitSet, + items: list[Union[Instruction, ResultType]], + global_phase: float | None = None, + ) -> str: + """Return a column in the ASCII string diagram of the circuit for a given list of items. + + Args: + circuit_qubits (QubitSet): qubits in circuit + items (list[Union[Instruction, ResultType]]): list of instructions or result types + global_phase (float | None): the integrated global phase up to this column + + Returns: + str: an ASCII string diagram for the specified moment in time for a column. + """ + symbols = {qubit: cls._qubit_line_character() for qubit in circuit_qubits} + connections = {qubit: "none" for qubit in circuit_qubits} + + for item in items: + if isinstance(item, ResultType) and not item.target: + target_qubits = circuit_qubits + control_qubits = QubitSet() + target_and_control = target_qubits.union(control_qubits) + qubits = circuit_qubits + ascii_symbols = [item.ascii_symbols[0]] * len(circuit_qubits) + elif isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective): + target_qubits = circuit_qubits + control_qubits = QubitSet() + target_and_control = target_qubits.union(control_qubits) + qubits = circuit_qubits + ascii_symbol = item.ascii_symbols[0] + marker = "*" * len(ascii_symbol) + num_after = len(circuit_qubits) - 1 + after = ["|"] * (num_after - 1) + ([marker] if num_after else []) + ascii_symbols = [ascii_symbol, *after] + elif ( + isinstance(item, Instruction) + and isinstance(item.operator, Gate) + and item.operator.name == "GPhase" + ): + target_qubits = circuit_qubits + control_qubits = QubitSet() + target_and_control = QubitSet() + qubits = circuit_qubits + ascii_symbols = cls._qubit_line_character() * len(circuit_qubits) + else: + if isinstance(item.target, list): + target_qubits = reduce(QubitSet.union, map(QubitSet, item.target), QubitSet()) + else: + target_qubits = item.target + control_qubits = getattr(item, "control", QubitSet()) + control_state = getattr(item, "control_state", "1" * len(control_qubits)) + map_control_qubit_states = dict(zip(control_qubits, control_state)) + + target_and_control = target_qubits.union(control_qubits) + qubits = QubitSet(range(min(target_and_control), max(target_and_control) + 1)) + + ascii_symbols = item.ascii_symbols + + for qubit in qubits: + # Determine if the qubit is part of the item or in the middle of a + # multi qubit item. + if qubit in target_qubits: + item_qubit_index = [ + index for index, q in enumerate(target_qubits) if q == qubit + ][0] + power_string = ( + f"^{power}" + if ( + (power := getattr(item, "power", 1)) != 1 + # this has the limitation of not printing the power + # when a user has a gate genuinely named C, but + # is necessary to enable proper printing of custom + # gates with built-in control qubits + and ascii_symbols[item_qubit_index] != "C" + ) + else "" + ) + symbols[qubit] = ( + f"({ascii_symbols[item_qubit_index]}{power_string})" + if power_string + else ascii_symbols[item_qubit_index] + ) + elif qubit in control_qubits: + symbols[qubit] = "C" if map_control_qubit_states[qubit] else "N" + else: + symbols[qubit] = "|" + + # Set the margin to be a connector if not on the first qubit + if target_and_control and qubit != min(target_and_control): + connections[qubit] = "above" + + output = cls._create_output(symbols, connections, circuit_qubits, global_phase) + return output + + # Ignore flake8 issue caused by Literal["above", "below", "both", "none"] + # flake8: noqa: BCS005 + @classmethod + def _draw_symbol( + cls, symbol: str, symbols_width: int, connection: Literal["above", "below", "both", "none"] + ) -> str: + """Create a string representing the symbol. + + Args: + symbol (str): the gate name + symbols_width (int): size of the expected output. The output will be filled with + cls._qubit_line_character() if needed. + connection (Literal["above", "below", "both", "none"]): character indicating + if the gate also involve a qubit with a lower index. + + Returns: + str: a string representing the symbol. + """ + connection_char = cls._vertical_delimiter() if connection in ["above"] else " " + output = "{0:{width}}\n".format( + connection_char, width=symbols_width + 1 + ) + "{0:{fill}{align}{width}}\n".format( + symbol, + fill=cls._qubit_line_character(), + align="<", + width=symbols_width + 1, + ) + return output diff --git a/src/braket/circuits/text_diagram_builders/text_circuit_diagram.py b/src/braket/circuits/text_diagram_builders/text_circuit_diagram.py new file mode 100644 index 000000000..ad30b34a7 --- /dev/null +++ b/src/braket/circuits/text_diagram_builders/text_circuit_diagram.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Literal, Union + +import braket.circuits.circuit as cir +from braket.circuits.circuit_diagram import CircuitDiagram +from braket.circuits.instruction import Instruction +from braket.circuits.moments import MomentType +from braket.circuits.result_type import ResultType +from braket.circuits.text_diagram_builders.text_circuit_diagram_utils import ( + _add_footers, + _categorize_result_types, + _compute_moment_global_phase, + _group_items, + _prepare_qubit_identifier_column, + _unite_strings, +) +from braket.registers.qubit import Qubit +from braket.registers.qubit_set import QubitSet + + +class TextCircuitDiagram(CircuitDiagram, ABC): + """Abstract base class for text circuit diagrams.""" + + @classmethod + @abstractmethod + def _vertical_delimiter(cls) -> str: + """Character that connects qubits of multi-qubit gates.""" + + @classmethod + @abstractmethod + def _qubit_line_character(cls) -> str: + """Character used for the qubit line.""" + + @classmethod + @abstractmethod + def _box_pad(cls) -> int: + """number of blank space characters around the gate name.""" + + @classmethod + @abstractmethod + def _qubit_line_spacing_above(cls) -> int: + """number of empty lines above the qubit line.""" + + @classmethod + @abstractmethod + def _qubit_line_spacing_below(cls) -> int: + """number of empty lines below the qubit line.""" + + @classmethod + @abstractmethod + def _create_diagram_column( + cls, + circuit_qubits: QubitSet, + items: list[Instruction | ResultType], + global_phase: float | None = None, + ) -> str: + """Return a column in the string diagram of the circuit for a given list of items. + + Args: + circuit_qubits (QubitSet): qubits in circuit + items (list[Instruction | ResultType]): list of instructions or result types + global_phase (float | None): the integrated global phase up to this column + + Returns: + str: a string diagram for the specified moment in time for a column. + """ + + # Ignore flake8 issue caused by Literal["above", "below", "both", "none"] + # flake8: noqa: BCS005 + @classmethod + @abstractmethod + def _draw_symbol( + cls, + symbol: str, + symbols_width: int, + connection: Literal["above", "below", "both", "none"], + ) -> str: + """Create a string representing the symbol inside a box. + + Args: + symbol (str): the gate name + symbols_width (int): size of the expected output. The output will be filled with + cls._qubit_line_character() if needed. + connection (Literal["above", "below", "both", "none"]): specifies if a connection + will be drawn above and/or below the box. + + Returns: + str: a string representing the symbol. + """ + + @classmethod + def _build(cls, circuit: cir.Circuit) -> str: + """Build a text circuit diagram. + + The procedure follows as: + 1. Prepare the first column composed of the qubit identifiers + 2. Construct the circuit as a list of columns by looping through the + time slices. A column is a string with rows separated via '\n' + a. compute the instantaneous global phase + b. create the column corresponding to the current moment + 3. Add result types at the end of the circuit + 4. Join the columns to get a list of qubit lines + 5. Add a list of optional parameters: + a. the total global phase + b. results types that do not have any target such as statevector + c. the list of unassigned parameters + + Args: + circuit (Circuit): Circuit for which to build a diagram. + + Returns: + str: string circuit diagram. + """ + if not circuit.instructions: + return "" + + if all(m.moment_type == MomentType.GLOBAL_PHASE for m in circuit._moments): + return f"Global phase: {circuit.global_phase}" + + circuit_qubits = circuit.qubits + circuit_qubits.sort() + + y_axis_str, global_phase = _prepare_qubit_identifier_column( + circuit, + circuit_qubits, + cls._vertical_delimiter(), + cls._qubit_line_character(), + cls._qubit_line_spacing_above(), + cls._qubit_line_spacing_below(), + ) + + column_strs = [] + + global_phase, additional_result_types = cls._build_columns( + circuit, circuit_qubits, global_phase, column_strs + ) + + # Unite strings + lines = _unite_strings(y_axis_str, column_strs) + cls._duplicate_time_at_bottom(lines) + + return _add_footers(lines, circuit, global_phase, additional_result_types) + + @classmethod + def _build_columns( + cls, + circuit: cir.Circuit, + circuit_qubits: QubitSet, + global_phase: float | None, + column_strs: list, + ) -> tuple[float | None, list[str]]: + time_slices = circuit.moments.time_slices() + + # Moment columns + for time, instructions in time_slices.items(): + global_phase = _compute_moment_global_phase(global_phase, instructions) + moment_str = cls._create_diagram_column_set( + str(time), circuit_qubits, instructions, global_phase + ) + column_strs.append(moment_str) + + # Result type columns + additional_result_types, target_result_types = _categorize_result_types( + circuit.result_types + ) + if target_result_types: + column_strs.append( + cls._create_diagram_column_set( + "Result Types", circuit_qubits, target_result_types, global_phase + ) + ) + return global_phase, additional_result_types + + @classmethod + def _create_diagram_column_set( + cls, + col_title: str, + circuit_qubits: QubitSet, + items: list[Union[Instruction, ResultType]], + global_phase: float | None, + ) -> str: + """Return a set of columns in the string diagram of the circuit for a list of items. + + Args: + col_title (str): title of column set + circuit_qubits (QubitSet): qubits in circuit + items (list[Union[Instruction, ResultType]]): list of instructions or result types + global_phase (float | None): the integrated global phase up to this set + + Returns: + str: A string diagram for the column set. + """ + + # Group items to separate out overlapping multi-qubit items + groupings = _group_items(circuit_qubits, items) + + column_strs = [ + cls._create_diagram_column(circuit_qubits, grouping[1], global_phase) + for grouping in groupings + ] + + # Unite column strings + lines = _unite_strings(column_strs[0], column_strs[1:]) + + # Adjust for column title width + col_title_width = len(col_title) + symbols_width = len(lines[0]) - 1 + if symbols_width < col_title_width: + diff = col_title_width - symbols_width + for i in range(len(lines) - 1): + if lines[i].endswith(cls._qubit_line_character()): + lines[i] += cls._qubit_line_character() * diff + else: + lines[i] += " " + + first_line = "{:^{width}}{vdelim}\n".format( + col_title, width=len(lines[0]) - 1, vdelim=cls._vertical_delimiter() + ) + + return first_line + "\n".join(lines) + + @classmethod + def _create_output( + cls, + symbols: dict[Qubit, str], + margins: dict[Qubit, str], + qubits: QubitSet, + global_phase: float | None, + ) -> str: + """Creates the output for a single column: + a. If there was one or more gphase gate, create a first line with the total global + phase shift ending with the _vertical_delimiter() class attribute, e.g. 0.14| + b. for each qubit, append the text representation produces by cls._draw_symbol + + Args: + symbols (dict[Qubit, str]): dictionary of the gate name for each qubit + margins (dict[Qubit, str]): map of the qubit interconnections. Specific to the + `_draw_symbol` classmethod. + qubits (QubitSet): set of the circuit qubits + global_phase (float | None): total global phase shift added during the moment + + Returns: + str: a string representing a diagram column. + """ + symbols_width = max(len(symbol) for symbol in symbols.values()) + cls._box_pad() + output = "" + + if global_phase is not None: + global_phase_str = ( + f"{global_phase:.2f}" if isinstance(global_phase, float) else str(global_phase) + ) + symbols_width = max([symbols_width, len(global_phase_str)]) + output += "{0:{fill}{align}{width}}{vdelim}\n".format( + global_phase_str, + fill=" ", + align="^", + width=symbols_width, + vdelim=cls._vertical_delimiter(), + ) + + for qubit in qubits: + output += cls._draw_symbol(symbols[qubit], symbols_width, margins[qubit]) + return output diff --git a/src/braket/circuits/text_diagram_builders/text_circuit_diagram_utils.py b/src/braket/circuits/text_diagram_builders/text_circuit_diagram_utils.py new file mode 100644 index 000000000..f261b00b6 --- /dev/null +++ b/src/braket/circuits/text_diagram_builders/text_circuit_diagram_utils.py @@ -0,0 +1,199 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from __future__ import annotations + +from functools import reduce +from typing import Union + +import braket.circuits.circuit as cir +from braket.circuits.compiler_directive import CompilerDirective +from braket.circuits.gate import Gate +from braket.circuits.instruction import Instruction +from braket.circuits.measure import Measure +from braket.circuits.moments import MomentType +from braket.circuits.noise import Noise +from braket.circuits.result_type import ResultType +from braket.registers.qubit_set import QubitSet + + +def _add_footers( + lines: list, + circuit: cir.Circuit, + global_phase: float | None, + additional_result_types: list[str], +) -> str: + if global_phase: + lines.append(f"\nGlobal phase: {global_phase}") + + # Additional result types line on bottom + if additional_result_types: + lines.append(f"\nAdditional result types: {', '.join(additional_result_types)}") + + # A list of parameters in the circuit to the currently assigned values. + if circuit.parameters: + lines.append( + "\nUnassigned parameters: " + f"{sorted(circuit.parameters, key=lambda param: param.name)}." + ) + + return "\n".join(lines) + + +def _prepare_qubit_identifier_column( + circuit: cir.Circuit, + circuit_qubits: QubitSet, + vdelim: str, + qubit_line_char: str, + line_spacing_before: int, + line_spacing_after: int, +) -> tuple[str, float | None]: + # Y Axis Column + y_axis_width = len(str(int(max(circuit_qubits)))) + y_axis_str = "{0:{width}} : {vdelim}\n".format("T", width=y_axis_width + 1, vdelim=vdelim) + + global_phase = None + if any(m.moment_type == MomentType.GLOBAL_PHASE for m in circuit._moments): + y_axis_str += "{0:{width}} : {vdelim}\n".format("GP", width=y_axis_width, vdelim=vdelim) + global_phase = 0 + + for qubit in circuit_qubits: + for _ in range(line_spacing_before): + y_axis_str += "{0:{width}}\n".format(" ", width=y_axis_width + 5) + + y_axis_str += "q{0:{width}} : {qubit_line_char}\n".format( + str(int(qubit)), + width=y_axis_width, + qubit_line_char=qubit_line_char, + ) + + for _ in range(line_spacing_after): + y_axis_str += "{0:{width}}\n".format(" ", width=y_axis_width + 5) + return y_axis_str, global_phase + + +def _unite_strings(first_column: str, column_strs: list[str]) -> list: + lines = first_column.split("\n") + for col_str in column_strs: + for i, line_in_col in enumerate(col_str.split("\n")): + lines[i] += line_in_col + return lines + + +def _compute_moment_global_phase( + global_phase: float | None, items: list[Instruction] +) -> float | None: + """ + Compute the integrated phase at a certain moment. + + Args: + global_phase (float | None): The integrated phase up to the computed moment + items (list[Instruction]): list of instructions + + Returns: + float | None: The updated integrated phase. + """ + moment_phase = sum( + item.operator.angle + for item in items + if ( + isinstance(item, Instruction) + and isinstance(item.operator, Gate) + and item.operator.name == "GPhase" + ) + ) + return global_phase + moment_phase if global_phase is not None else None + + +def _group_items( + circuit_qubits: QubitSet, + items: list[Union[Instruction, ResultType]], +) -> list[tuple[QubitSet, list[Instruction]]]: + """ + Group instructions in a moment + + Args: + circuit_qubits (QubitSet): set of qubits in circuit + items (list[Union[Instruction, ResultType]]): list of instructions or result types + + Returns: + list[tuple[QubitSet, list[Instruction]]]: list of grouped instructions or result types. + """ + groupings = [] + for item in items: + # Can only print Gate, Noise and Measure operators for instructions at the moment + if isinstance(item, Instruction) and not isinstance( + item.operator, (Gate, Noise, CompilerDirective, Measure) + ): + continue + + # As a zero-qubit gate, GPhase can be grouped with anything. We set qubit_range + # to an empty list and we just add it to the first group below. + if ( + isinstance(item, Instruction) + and isinstance(item.operator, Gate) + and item.operator.name == "GPhase" + ): + qubit_range = QubitSet() + elif (isinstance(item, ResultType) and not item.target) or ( + isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective) + ): + qubit_range = circuit_qubits + else: + if isinstance(item.target, list): + target = reduce(QubitSet.union, map(QubitSet, item.target), QubitSet()) + else: + target = item.target + control = getattr(item, "control", QubitSet()) + target_and_control = target.union(control) + qubit_range = QubitSet(range(min(target_and_control), max(target_and_control) + 1)) + + found_grouping = False + for group in groupings: + qubits_added = group[0] + instr_group = group[1] + # Take into account overlapping multi-qubit gates + if not qubits_added.intersection(set(qubit_range)): + instr_group.append(item) + qubits_added.update(qubit_range) + found_grouping = True + break + + if not found_grouping: + groupings.append((qubit_range, [item])) + + return groupings + + +def _categorize_result_types( + result_types: list[ResultType], +) -> tuple[list[str], list[ResultType]]: + """ + Categorize result types into result types with target and those without. + + Args: + result_types (list[ResultType]): list of result types + + Returns: + tuple[list[str], list[ResultType]]: first element is a list of result types + without `target` attribute; second element is a list of result types with + `target` attribute + """ + additional_result_types = [] + target_result_types = [] + for result_type in result_types: + if hasattr(result_type, "target"): + target_result_types.append(result_type) + else: + additional_result_types.extend(result_type.ascii_symbols) + return additional_result_types, target_result_types diff --git a/src/braket/circuits/text_diagram_builders/unicode_circuit_diagram.py b/src/braket/circuits/text_diagram_builders/unicode_circuit_diagram.py new file mode 100644 index 000000000..85567de28 --- /dev/null +++ b/src/braket/circuits/text_diagram_builders/unicode_circuit_diagram.py @@ -0,0 +1,278 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from __future__ import annotations + +from functools import reduce +from typing import Literal + +import braket.circuits.circuit as cir +from braket.circuits.compiler_directive import CompilerDirective +from braket.circuits.gate import Gate +from braket.circuits.instruction import Instruction +from braket.circuits.result_type import ResultType +from braket.circuits.text_diagram_builders.text_circuit_diagram import TextCircuitDiagram +from braket.registers.qubit import Qubit +from braket.registers.qubit_set import QubitSet + + +class UnicodeCircuitDiagram(TextCircuitDiagram): + """Builds string circuit diagrams using box-drawing characters.""" + + @staticmethod + def build_diagram(circuit: cir.Circuit) -> str: + """Build a text circuit diagram. + + Args: + circuit (Circuit): Circuit for which to build a diagram. + + Returns: + str: string circuit diagram. + """ + return UnicodeCircuitDiagram._build(circuit) + + @classmethod + def _vertical_delimiter(cls) -> str: + """Character that connects qubits of multi-qubit gates.""" + return "│" + + @classmethod + def _qubit_line_character(cls) -> str: + """Character used for the qubit line.""" + return "─" + + @classmethod + def _box_pad(cls) -> int: + """number of blank space characters around the gate name.""" + return 4 + + @classmethod + def _qubit_line_spacing_above(cls) -> int: + """number of empty lines above the qubit line.""" + return 1 + + @classmethod + def _qubit_line_spacing_below(cls) -> int: + """number of empty lines below the qubit line.""" + return 1 + + @classmethod + def _duplicate_time_at_bottom(cls, lines: list) -> None: + # Do not add a line after the circuit + # It is safe to do because the last line is empty: _qubit_line_spacing["after"] = 1 + lines[-1] = lines[0] + + @classmethod + def _create_diagram_column( + cls, + circuit_qubits: QubitSet, + items: list[Instruction | ResultType], + global_phase: float | None = None, + ) -> str: + """Return a column in the string diagram of the circuit for a given list of items. + + Args: + circuit_qubits (QubitSet): qubits in circuit + items (list[Instruction | ResultType]): list of instructions or result types + global_phase (float | None): the integrated global phase up to this column + + Returns: + str: a string diagram for the specified moment in time for a column. + """ + symbols = {qubit: cls._qubit_line_character() for qubit in circuit_qubits} + connections = {qubit: "none" for qubit in circuit_qubits} + + for item in items: + ( + target_qubits, + control_qubits, + qubits, + connections, + ascii_symbols, + map_control_qubit_states, + ) = cls._build_parameters(circuit_qubits, item, connections) + + for qubit in qubits: + # Determine if the qubit is part of the item or in the middle of a + # multi qubit item. + if qubit in target_qubits: + item_qubit_index = [ + index for index, q in enumerate(target_qubits) if q == qubit + ][0] + power_string = ( + f"^{power}" + if ( + (power := getattr(item, "power", 1)) != 1 + # this has the limitation of not printing the power + # when a user has a gate genuinely named C, but + # is necessary to enable proper printing of custom + # gates with built-in control qubits + and ascii_symbols[item_qubit_index] != "C" + ) + else "" + ) + symbols[qubit] = ( + f"{ascii_symbols[item_qubit_index]}{power_string}" + if power_string + else ascii_symbols[item_qubit_index] + ) + + elif qubit in control_qubits: + symbols[qubit] = "C" if map_control_qubit_states[qubit] else "N" + else: + symbols[qubit] = "┼" + + output = cls._create_output(symbols, connections, circuit_qubits, global_phase) + return output + + @classmethod + def _build_parameters( + cls, circuit_qubits: QubitSet, item: ResultType | Instruction, connections: dict[Qubit, str] + ) -> tuple: + map_control_qubit_states = {} + + if (isinstance(item, ResultType) and not item.target) or ( + isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective) + ): + target_qubits = circuit_qubits + control_qubits = QubitSet() + qubits = circuit_qubits + ascii_symbols = [item.ascii_symbols[0]] * len(qubits) + cls._update_connections(qubits, connections) + elif ( + isinstance(item, Instruction) + and isinstance(item.operator, Gate) + and item.operator.name == "GPhase" + ): + target_qubits = circuit_qubits + control_qubits = QubitSet() + qubits = circuit_qubits + ascii_symbols = cls._qubit_line_character() * len(circuit_qubits) + else: + if isinstance(item.target, list): + target_qubits = reduce(QubitSet.union, map(QubitSet, item.target), QubitSet()) + else: + target_qubits = item.target + control_qubits = getattr(item, "control", QubitSet()) + control_state = getattr(item, "control_state", "1" * len(control_qubits)) + map_control_qubit_states = dict(zip(control_qubits, control_state)) + + target_and_control = target_qubits.union(control_qubits) + qubits = QubitSet(range(min(target_and_control), max(target_and_control) + 1)) + ascii_symbols = item.ascii_symbols + cls._update_connections(qubits, connections) + + return ( + target_qubits, + control_qubits, + qubits, + connections, + ascii_symbols, + map_control_qubit_states, + ) + + @staticmethod + def _update_connections(qubits: QubitSet, connections: dict[Qubit, str]) -> None: + if len(qubits) > 1: + connections |= {qubit: "both" for qubit in qubits[1:-1]} + connections[qubits[-1]] = "above" + connections[qubits[0]] = "below" + + # Ignore flake8 issue caused by Literal["above", "below", "both", "none"] + # flake8: noqa: BCS005 + @classmethod + def _draw_symbol( + cls, + symbol: str, + symbols_width: int, + connection: Literal["above", "below", "both", "none"], + ) -> str: + """Create a string representing the symbol inside a box. + + Args: + symbol (str): the gate name + symbols_width (int): size of the expected output. The output will be filled with + cls._qubit_line_character() if needed. + connection (Literal["above", "below", "both", "none"]): specifies if a connection + will be drawn above and/or below the box. + + Returns: + str: a string representing the symbol. + """ + top = "" + bottom = "" + if symbol in {"C", "N", "SWAP"}: + if connection in ["above", "both"]: + top = _fill_symbol(cls._vertical_delimiter(), " ") + if connection in ["below", "both"]: + bottom = _fill_symbol(cls._vertical_delimiter(), " ") + new_symbol = {"C": "●", "N": "◯", "SWAP": "x"} + # replace SWAP by x + # the size of the moment remains as if there was a box with 4 characters inside + symbol = _fill_symbol(new_symbol[symbol], cls._qubit_line_character()) + elif symbol in ["StartVerbatim", "EndVerbatim"]: + top, symbol, bottom = cls._build_verbatim_box(symbol, connection) + elif symbol == "┼": + top = bottom = _fill_symbol(cls._vertical_delimiter(), " ") + symbol = _fill_symbol(f"{symbol}", cls._qubit_line_character()) + elif symbol != cls._qubit_line_character(): + top, symbol, bottom = cls._build_box(symbol, connection) + + output = f"{_fill_symbol(top, ' ', symbols_width)} \n" + output += f"{_fill_symbol(symbol, cls._qubit_line_character(), symbols_width)}{cls._qubit_line_character()}\n" + output += f"{_fill_symbol(bottom, ' ', symbols_width)} \n" + return output + + @staticmethod + def _build_box( + symbol: str, connection: Literal["above", "below", "both", "none"] + ) -> tuple[str, str, str]: + top_edge_symbol = "┴" if connection in ["above", "both"] else "─" + top = f"┌─{_fill_symbol(top_edge_symbol, '─', len(symbol))}─┐" + + bottom_edge_symbol = "┬" if connection in ["below", "both"] else "─" + bottom = f"└─{_fill_symbol(bottom_edge_symbol, '─', len(symbol))}─┘" + + symbol = f"┤ {symbol} ├" + return top, symbol, bottom + + @classmethod + def _build_verbatim_box( + cls, + symbol: Literal["StartVerbatim", "EndVerbatim"], + connection: Literal["above", "below", "both", "none"], + ) -> str: + top = "" + bottom = "" + if connection == "below": + bottom = "║" + elif connection == "both": + top = bottom = "║" + symbol = "║" + elif connection == "above": + top = "║" + symbol = "╨" + top = _fill_symbol(top, " ") + symbol = _fill_symbol(symbol, cls._qubit_line_character()) + bottom = _fill_symbol(bottom, " ") + + return top, symbol, bottom + + +def _fill_symbol(symbol: str, filler: str, width: int | None = None) -> str: + return "{0:{fill}{align}{width}}".format( + symbol, + fill=filler, + align="^", + width=width if width is not None else len(symbol), + ) diff --git a/src/braket/circuits/translations.py b/src/braket/circuits/translations.py index bbb194be3..78bb7eed0 100644 --- a/src/braket/circuits/translations.py +++ b/src/braket/circuits/translations.py @@ -15,10 +15,9 @@ from typing import Union import braket.circuits.gates as braket_gates -import braket.circuits.noises as noises -import braket.circuits.result_types as ResultTypes +import braket.circuits.result_types as ResultTypes # noqa: N812 import braket.ir.jaqcd.shared_models as models -from braket.circuits import Observable, observables +from braket.circuits import Observable, noises, observables from braket.ir.jaqcd import ( Amplitude, DensityMatrix, @@ -68,6 +67,7 @@ "cswap": braket_gates.CSwap, "gpi": braket_gates.GPi, "gpi2": braket_gates.GPi2, + "prx": braket_gates.PRx, "ms": braket_gates.MS, "unitary": braket_gates.Unitary, } @@ -84,8 +84,29 @@ "phase_damping": noises.PhaseDamping, } +SUPPORTED_NOISE_PRAGMA_TO_NOISE = { + "braket_noise_bit_flip": noises.BitFlip, + "braket_noise_phase_flip": noises.PhaseFlip, + "braket_noise_pauli_channel": noises.PauliChannel, + "braket_noise_depolarizing": noises.Depolarizing, + "braket_noise_two_qubit_depolarizing": noises.TwoQubitDepolarizing, + "braket_noise_two_qubit_dephasing": noises.TwoQubitDephasing, + "braket_noise_amplitude_damping": noises.AmplitudeDamping, + "braket_noise_generalized_amplitude_damping": noises.GeneralizedAmplitudeDamping, + "braket_noise_phase_damping": noises.PhaseDamping, + "braket_noise_kraus": noises.Kraus, +} + def get_observable(obs: Union[models.Observable, list]) -> Observable: + """Gets the observable. + + Args: + obs (Union[Observable, list]): The observable(s) to get translated. + + Returns: + Observable: The translated observable. + """ return _get_observable(obs) @@ -127,39 +148,39 @@ def braket_result_to_result_type(result: Results) -> None: @_braket_result_to_result_type.register(Amplitude) -def _(result): +def _(result: Results) -> Amplitude: return ResultTypes.Amplitude(state=result.states) @_braket_result_to_result_type.register(Expectation) -def _(result): +def _(result: Results) -> Expectation: tensor_product = get_tensor_product(result.observable) return ResultTypes.Expectation(observable=tensor_product, target=result.targets) @_braket_result_to_result_type.register(Probability) -def _(result): +def _(result: Results) -> Probability: return ResultTypes.Probability(result.targets) @_braket_result_to_result_type.register(Sample) -def _(result): +def _(result: Results) -> Sample: tensor_product = get_tensor_product(result.observable) return ResultTypes.Sample(observable=tensor_product, target=result.targets) @_braket_result_to_result_type.register(StateVector) -def _(result): +def _(result: Results) -> StateVector: return ResultTypes.StateVector() @_braket_result_to_result_type.register(DensityMatrix) -def _(result): +def _(result: Results): return ResultTypes.DensityMatrix(target=result.targets) @_braket_result_to_result_type.register(Variance) -def _(result): +def _(result: Results): tensor_product = get_tensor_product(result.observable) return ResultTypes.Variance(observable=tensor_product, target=result.targets) diff --git a/src/braket/circuits/unitary_calculation.py b/src/braket/circuits/unitary_calculation.py index ebc3c7878..9fa404284 100644 --- a/src/braket/circuits/unitary_calculation.py +++ b/src/braket/circuits/unitary_calculation.py @@ -26,8 +26,7 @@ def calculate_unitary_big_endian( instructions: Iterable[Instruction], qubits: QubitSet ) -> np.ndarray: - """ - Returns the unitary matrix representation for all the `instructions` on qubits `qubits`. + """Returns the unitary matrix representation for all the `instruction`s on qubits `qubits`. Note: The performance of this method degrades with qubit count. It might be slow for diff --git a/src/braket/devices/device.py b/src/braket/devices/device.py index 49f510edf..dbc7b6b35 100644 --- a/src/braket/devices/device.py +++ b/src/braket/devices/device.py @@ -11,11 +11,17 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +import warnings from abc import ABC, abstractmethod -from typing import Optional, Union +from typing import Any, Optional, Union +from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation from braket.annealing.problem import Problem -from braket.circuits import Circuit +from braket.circuits import Circuit, Noise +from braket.circuits.noise_model import NoiseModel +from braket.circuits.translations import SUPPORTED_NOISE_PRAGMA_TO_NOISE +from braket.device_schema import DeviceActionType +from braket.ir.openqasm import Program from braket.tasks.quantum_task import QuantumTask from braket.tasks.quantum_task_batch import QuantumTaskBatch @@ -24,7 +30,8 @@ class Device(ABC): """An abstraction over quantum devices that includes quantum computers and simulators.""" def __init__(self, name: str, status: str): - """ + """Initializes a `Device`. + Args: name (str): Name of quantum device status (str): Status of quantum device @@ -39,7 +46,7 @@ def run( shots: Optional[int], inputs: Optional[dict[str, float]], *args, - **kwargs + **kwargs, ) -> QuantumTask: """Run a quantum task specification on this quantum device. A quantum task can be a circuit or an annealing problem. @@ -52,6 +59,8 @@ def run( inputs (Optional[dict[str, float]]): Inputs to be passed along with the IR. If IR is an OpenQASM Program, the inputs will be updated with this value. Not all devices and IR formats support inputs. Default: {}. + *args (Any): Arbitrary arguments. + **kwargs (Any): Arbitrary keyword arguments. Returns: QuantumTask: The QuantumTask tracking task execution on this device @@ -67,8 +76,8 @@ def run_batch( shots: Optional[int], max_parallel: Optional[int], inputs: Optional[Union[dict[str, float], list[dict[str, float]]]], - *args, - **kwargs + *args: Any, + **kwargs: Any, ) -> QuantumTaskBatch: """Executes a batch of quantum tasks in parallel @@ -82,6 +91,8 @@ def run_batch( inputs (Optional[Union[dict[str, float], list[dict[str, float]]]]): Inputs to be passed along with the IR. If the IR supports inputs, the inputs will be updated with this value. + *args (Any): Arbitrary arguments. + **kwargs (Any): Arbitrary keyword arguments. Returns: QuantumTaskBatch: A batch containing all of the qauntum tasks run @@ -104,3 +115,36 @@ def status(self) -> str: str: The status of this Device """ return self._status + + def _validate_device_noise_model_support(self, noise_model: NoiseModel) -> None: + supported_noises = { + SUPPORTED_NOISE_PRAGMA_TO_NOISE[pragma].__name__ + for pragma in self.properties.action[DeviceActionType.OPENQASM].supportedPragmas + if pragma in SUPPORTED_NOISE_PRAGMA_TO_NOISE + } + noise_operators = {noise_instr.noise.name for noise_instr in noise_model._instructions} + if not noise_operators <= supported_noises: + raise ValueError( + f"{self.name} does not support noise simulation or the noise model includes noise " + + f"that is not supported by {self.name}." + ) + + def _apply_noise_model_to_circuit( + self, task_specification: Union[Circuit, Problem, Program, AnalogHamiltonianSimulation] + ) -> None: + if isinstance(task_specification, Circuit): + for instruction in task_specification.instructions: + if isinstance(instruction.operator, Noise): + warnings.warn( + "The noise model of the device is applied to a circuit that already has" + " noise instructions." + ) + break + task_specification = self._noise_model.apply(task_specification) + else: + warnings.warn( + "Noise model is only applicable to circuits. The type of the task specification is" + f" {task_specification.__class__.__name__}. The noise model of the device does not" + " apply." + ) + return task_specification diff --git a/src/braket/devices/devices.py b/src/braket/devices/devices.py index f4fe8ff8d..fa2c6d025 100644 --- a/src/braket/devices/devices.py +++ b/src/braket/devices/devices.py @@ -27,6 +27,9 @@ class _DWave(str, Enum): _Advantage6 = "arn:aws:braket:us-west-2::device/qpu/d-wave/Advantage_system6" _DW2000Q6 = "arn:aws:braket:::device/qpu/d-wave/DW_2000Q_6" + class _IQM(str, Enum): + Garnet = "arn:aws:braket:eu-north-1::device/qpu/iqm/Garnet" + class _IonQ(str, Enum): Harmony = "arn:aws:braket:us-east-1::device/qpu/ionq/Harmony" Aria1 = "arn:aws:braket:us-east-1::device/qpu/ionq/Aria-1" @@ -54,6 +57,7 @@ class _Xanadu(str, Enum): Amazon = _Amazon # DWave = _DWave IonQ = _IonQ + IQM = _IQM OQC = _OQC QuEra = _QuEra Rigetti = _Rigetti diff --git a/src/braket/devices/local_simulator.py b/src/braket/devices/local_simulator.py index c719978ac..15ec904de 100644 --- a/src/braket/devices/local_simulator.py +++ b/src/braket/devices/local_simulator.py @@ -13,23 +13,23 @@ from __future__ import annotations +import sys from functools import singledispatchmethod from itertools import repeat from multiprocessing import Pool from os import cpu_count -from typing import Optional, Union - -import pkg_resources +from typing import Any, Optional, Union from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation from braket.annealing.problem import Problem from braket.circuits import Circuit from braket.circuits.circuit_helpers import validate_circuit_and_shots -from braket.circuits.serialization import IRType +from braket.circuits.noise_model import NoiseModel +from braket.circuits.serialization import IRType, SerializableProgram from braket.device_schema import DeviceActionType, DeviceCapabilities from braket.devices.device import Device from braket.ir.ahs import Program as AHSProgram -from braket.ir.openqasm import Program +from braket.ir.openqasm import Program as OpenQASMProgram from braket.simulator import BraketSimulator from braket.tasks import AnnealingQuantumTaskResult, GateModelQuantumTaskResult from braket.tasks.analog_hamiltonian_simulation_quantum_task_result import ( @@ -38,9 +38,12 @@ from braket.tasks.local_quantum_task import LocalQuantumTask from braket.tasks.local_quantum_task_batch import LocalQuantumTaskBatch -_simulator_devices = { - entry.name: entry for entry in pkg_resources.iter_entry_points("braket.simulators") -} +if sys.version_info.minor == 9: + from backports.entry_points_selectable import entry_points +else: + from importlib.metadata import entry_points + +_simulator_devices = {entry.name: entry for entry in entry_points(group="braket.simulators")} class LocalSimulator(Device): @@ -50,12 +53,20 @@ class LocalSimulator(Device): results using constructs from the SDK rather than Braket IR. """ - def __init__(self, backend: Union[str, BraketSimulator] = "default"): - """ + def __init__( + self, + backend: Union[str, BraketSimulator] = "default", + noise_model: Optional[NoiseModel] = None, + ): + """Initializes a `LocalSimulator`. + Args: backend (Union[str, BraketSimulator]): The name of the simulator backend or the actual simulator instance to use for simulation. Defaults to the `default` simulator backend name. + noise_model (Optional[NoiseModel]): The Braket noise model to apply to the circuit + before execution. Noise model can only be added to the devices that support + noise simulation. """ delegate = self._get_simulator(backend) super().__init__( @@ -63,19 +74,24 @@ def __init__(self, backend: Union[str, BraketSimulator] = "default"): status="AVAILABLE", ) self._delegate = delegate + if noise_model: + self._validate_device_noise_model_support(noise_model) + self._noise_model = noise_model def run( self, - task_specification: Union[Circuit, Problem, Program, AnalogHamiltonianSimulation], + task_specification: Union[ + Circuit, Problem, OpenQASMProgram, AnalogHamiltonianSimulation, SerializableProgram + ], shots: int = 0, inputs: Optional[dict[str, float]] = None, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> LocalQuantumTask: """Runs the given task with the wrapped local simulator. Args: - task_specification (Union[Circuit, Problem, Program, AnalogHamiltonianSimulation]): + task_specification (Union[Circuit, Problem, OpenQASMProgram, AnalogHamiltonianSimulation, SerializableProgram]): # noqa E501 The quantum task specification. shots (int): The number of times to run the circuit or annealing problem. Default is 0, which means that the simulator will compute the exact @@ -84,6 +100,8 @@ def run( inputs (Optional[dict[str, float]]): Inputs to be passed along with the IR. If the IR supports inputs, the inputs will be updated with this value. Default: {}. + *args (Any): Arbitrary arguments. + **kwargs(Any): Arbitrary keyword arguments. Returns: LocalQuantumTask: A LocalQuantumTask object containing the results @@ -98,14 +116,26 @@ def run( >>> device = LocalSimulator("default") >>> device.run(circuit, shots=1000) """ + if self._noise_model: + task_specification = self._apply_noise_model_to_circuit(task_specification) result = self._run_internal(task_specification, shots, inputs=inputs, *args, **kwargs) return LocalQuantumTask(result) - def run_batch( + def run_batch( # noqa: C901 self, task_specifications: Union[ - Union[Circuit, Problem, Program, AnalogHamiltonianSimulation], - list[Union[Circuit, Problem, Program, AnalogHamiltonianSimulation]], + Union[ + Circuit, Problem, OpenQASMProgram, AnalogHamiltonianSimulation, SerializableProgram + ], + list[ + Union[ + Circuit, + Problem, + OpenQASMProgram, + AnalogHamiltonianSimulation, + SerializableProgram, + ] + ], ], shots: Optional[int] = 0, max_parallel: Optional[int] = None, @@ -116,7 +146,7 @@ def run_batch( """Executes a batch of quantum tasks in parallel Args: - task_specifications (Union[Union[Circuit, Problem, Program, AnalogHamiltonianSimulation], list[Union[Circuit, Problem, Program, AnalogHamiltonianSimulation]]]): # noqa + task_specifications (Union[Union[Circuit, Problem, OpenQASMProgram, AnalogHamiltonianSimulation, SerializableProgram], list[Union[Circuit, Problem, OpenQASMProgram, AnalogHamiltonianSimulation, SerializableProgram]]]): # noqa Single instance or list of quantum task specification. shots (Optional[int]): The number of times to run the quantum task. Default: 0. @@ -131,24 +161,27 @@ def run_batch( See Also: `braket.tasks.local_quantum_task_batch.LocalQuantumTaskBatch` - """ + """ # noqa E501 inputs = inputs or {} + if self._noise_model: + task_specifications = [ + self._apply_noise_model_to_circuit(task_specification) + for task_specification in task_specifications + ] + if not max_parallel: max_parallel = cpu_count() single_task = isinstance( task_specifications, - (Circuit, Program, Problem, AnalogHamiltonianSimulation), + (Circuit, OpenQASMProgram, Problem, AnalogHamiltonianSimulation), ) single_input = isinstance(inputs, dict) - if not single_task and not single_input: - if len(task_specifications) != len(inputs): - raise ValueError( - "Multiple inputs and task specifications must " "be equal in number." - ) + if not single_task and not single_input and len(task_specifications) != len(inputs): + raise ValueError("Multiple inputs and task specifications must be equal in number.") if single_task: task_specifications = repeat(task_specifications) @@ -165,8 +198,7 @@ def run_batch( for task_specification, input_map in tasks_and_inputs: if isinstance(task_specification, Circuit): param_names = {param.name for param in task_specification.parameters} - unbounded_parameters = param_names - set(input_map.keys()) - if unbounded_parameters: + if unbounded_parameters := param_names - set(input_map.keys()): raise ValueError( f"Cannot execute circuit with unbound parameters: " f"{unbounded_parameters}" @@ -184,7 +216,8 @@ def properties(self) -> DeviceCapabilities: Please see `braket.device_schema` in amazon-braket-schemas-python_ - .. _amazon-braket-schemas-python: https://github.com/aws/amazon-braket-schemas-python""" + .. _amazon-braket-schemas-python: https://github.com/aws/amazon-braket-schemas-python + """ return self._delegate.properties @staticmethod @@ -199,7 +232,9 @@ def registered_backends() -> set[str]: def _run_internal_wrap( self, - task_specification: Union[Circuit, Problem, Program, AnalogHamiltonianSimulation], + task_specification: Union[ + Circuit, Problem, OpenQASMProgram, AnalogHamiltonianSimulation, SerializableProgram + ], shots: Optional[int] = None, inputs: Optional[dict[str, float]] = None, *args, @@ -214,13 +249,12 @@ def _get_simulator(self, simulator: Union[str, BraketSimulator]) -> LocalSimulat @_get_simulator.register def _(self, backend_name: str): - if backend_name in _simulator_devices: - device_class = _simulator_devices[backend_name].load() - return device_class() - else: + if backend_name not in _simulator_devices: raise ValueError( f"Only the following devices are available {_simulator_devices.keys()}" ) + device_class = _simulator_devices[backend_name].load() + return device_class() @_get_simulator.register def _(self, backend_impl: BraketSimulator): @@ -230,7 +264,12 @@ def _(self, backend_impl: BraketSimulator): def _run_internal( self, task_specification: Union[ - Circuit, Problem, Program, AnalogHamiltonianSimulation, AHSProgram + Circuit, + Problem, + OpenQASMProgram, + AnalogHamiltonianSimulation, + AHSProgram, + SerializableProgram, ], shots: Optional[int] = None, *args, @@ -276,7 +315,7 @@ def _(self, problem: Problem, shots: Optional[int] = None, *args, **kwargs): @_run_internal.register def _( self, - program: Program, + program: OpenQASMProgram, shots: Optional[int] = None, inputs: Optional[dict[str, float]] = None, *args, @@ -288,13 +327,30 @@ def _( if inputs: inputs_copy = program.inputs.copy() if program.inputs is not None else {} inputs_copy.update(inputs) - program = Program( + program = OpenQASMProgram( source=program.source, inputs=inputs_copy, ) + results = simulator.run(program, shots, *args, **kwargs) + + if isinstance(results, GateModelQuantumTaskResult): + return results + return GateModelQuantumTaskResult.from_object(results) + @_run_internal.register + def _( + self, + program: SerializableProgram, + shots: Optional[int] = None, + inputs: Optional[dict[str, float]] = None, + *args, + **kwargs, + ): + program = OpenQASMProgram(source=program.to_ir(ir_type=IRType.OPENQASM)) + return self._run_internal(program, shots, inputs, *args, **kwargs) + @_run_internal.register def _( self, diff --git a/src/braket/error_mitigation/debias.py b/src/braket/error_mitigation/debias.py index 305bf7b78..8beddc7ef 100644 --- a/src/braket/error_mitigation/debias.py +++ b/src/braket/error_mitigation/debias.py @@ -16,9 +16,7 @@ class Debias(ErrorMitigation): - """ - The debias error mitigation scheme. This scheme takes no parameters. - """ + """The debias error mitigation scheme. This scheme takes no parameters.""" def serialize(self) -> list[error_mitigation.Debias]: return [error_mitigation.Debias()] diff --git a/src/braket/error_mitigation/error_mitigation.py b/src/braket/error_mitigation/error_mitigation.py index 79b1f3e30..95e6b6582 100644 --- a/src/braket/error_mitigation/error_mitigation.py +++ b/src/braket/error_mitigation/error_mitigation.py @@ -16,9 +16,14 @@ class ErrorMitigation: def serialize(self) -> list[error_mitigation.ErrorMitigationScheme]: - """ + """This returns a list of service-readable error mitigation + scheme descriptions. + Returns: list[ErrorMitigationScheme]: A list of service-readable error mitigation scheme descriptions. + + Raises: + NotImplementedError: Not implemented in the base class. """ raise NotImplementedError("serialize is not implemented.") diff --git a/src/braket/ipython_utils.py b/src/braket/ipython_utils.py index 20100d944..d850ee85c 100644 --- a/src/braket/ipython_utils.py +++ b/src/braket/ipython_utils.py @@ -15,8 +15,7 @@ def running_in_jupyter() -> bool: - """ - Determine if running within Jupyter. + """Determine if running within Jupyter. Inspired by https://github.com/ipython/ipython/issues/11694 @@ -24,8 +23,6 @@ def running_in_jupyter() -> bool: bool: True if running in Jupyter, else False. """ in_ipython = False - in_ipython_kernel = False - # if IPython hasn't been imported, there's nothing to check if "IPython" in sys.modules: get_ipython = sys.modules["IPython"].__dict__["get_ipython"] @@ -33,7 +30,4 @@ def running_in_jupyter() -> bool: ip = get_ipython() in_ipython = ip is not None - if in_ipython: - in_ipython_kernel = getattr(ip, "kernel", None) is not None - - return in_ipython_kernel + return getattr(ip, "kernel", None) is not None if in_ipython else False diff --git a/src/braket/jobs/config.py b/src/braket/jobs/config.py index a598388e4..7c84b42dd 100644 --- a/src/braket/jobs/config.py +++ b/src/braket/jobs/config.py @@ -54,8 +54,8 @@ class DeviceConfig: class S3DataSourceConfig: - """ - Data source for data that lives on S3 + """Data source for data that lives on S3. + Attributes: config (dict[str, dict]): config passed to the Braket API """ diff --git a/src/braket/jobs/data_persistence.py b/src/braket/jobs/data_persistence.py index f2ec9b6fa..0386ed7a7 100644 --- a/src/braket/jobs/data_persistence.py +++ b/src/braket/jobs/data_persistence.py @@ -26,9 +26,8 @@ def save_job_checkpoint( checkpoint_file_suffix: str = "", data_format: PersistedJobDataFormat = PersistedJobDataFormat.PLAINTEXT, ) -> None: - """ - Saves the specified `checkpoint_data` to the local output directory, specified by the container - environment variable `CHECKPOINT_DIR`, with the filename + """Saves the specified `checkpoint_data` to the local output directory, specified by + the container environment variable `CHECKPOINT_DIR`, with the filename `f"{job_name}(_{checkpoint_file_suffix}).json"`. The `job_name` refers to the name of the current job and is retrieved from the container environment variable `JOB_NAME`. The `checkpoint_data` values are serialized to the specified `data_format`. @@ -68,8 +67,7 @@ def save_job_checkpoint( def load_job_checkpoint( job_name: str | None = None, checkpoint_file_suffix: str = "" ) -> dict[str, Any]: - """ - Loads the job checkpoint data stored for the job named 'job_name', with the checkpoint + """Loads the job checkpoint data stored for the job named 'job_name', with the checkpoint file that ends with the `checkpoint_file_suffix`. The `job_name` can refer to any job whose checkpoint data you expect to be available in the file path specified by the `CHECKPOINT_DIR` container environment variable. If not provided, this function will use the currently running @@ -104,7 +102,7 @@ def load_job_checkpoint( if checkpoint_file_suffix else f"{checkpoint_directory}/{job_name}.json" ) - with open(checkpoint_file_path, "r") as f: + with open(checkpoint_file_path) as f: persisted_data = PersistedJobData.parse_raw(f.read()) deserialized_data = deserialize_values( persisted_data.dataDictionary, persisted_data.dataFormat @@ -115,7 +113,7 @@ def load_job_checkpoint( def _load_persisted_data(filename: str | Path | None = None) -> PersistedJobData: filename = filename or Path(get_results_dir()) / "results.json" try: - with open(filename, mode="r") as f: + with open(filename) as f: return PersistedJobData.parse_raw(f.read()) except FileNotFoundError: return PersistedJobData( @@ -125,8 +123,7 @@ def _load_persisted_data(filename: str | Path | None = None) -> PersistedJobData def load_job_result(filename: str | Path | None = None) -> dict[str, Any]: - """ - Loads job result of currently running job. + """Loads job result of currently running job. Args: filename (str | Path | None): Location of job results. Default `results.json` in job @@ -145,8 +142,7 @@ def save_job_result( result_data: dict[str, Any] | Any, data_format: PersistedJobDataFormat | None = None, ) -> None: - """ - Saves the `result_data` to the local output directory that is specified by the container + """Saves the `result_data` to the local output directory that is specified by the container environment variable `AMZN_BRAKET_JOB_RESULTS_DIR`, with the filename 'results.json'. The `result_data` values are serialized to the specified `data_format`. @@ -160,6 +156,9 @@ def save_job_result( data_format (PersistedJobDataFormat | None): The data format used to serialize the values. Note that for `PICKLED` data formats, the values are base64 encoded after serialization. Default: PersistedJobDataFormat.PLAINTEXT. + + Raises: + TypeError: Unsupported data format. """ if not isinstance(result_data, dict): result_data = {"result": result_data} diff --git a/src/braket/jobs/environment_variables.py b/src/braket/jobs/environment_variables.py index 4fba9315c..6d7d18364 100644 --- a/src/braket/jobs/environment_variables.py +++ b/src/braket/jobs/environment_variables.py @@ -16,8 +16,7 @@ def get_job_name() -> str: - """ - Get the name of the current job. + """Get the name of the current job. Returns: str: The name of the job if in a job, else an empty string. @@ -26,8 +25,7 @@ def get_job_name() -> str: def get_job_device_arn() -> str: - """ - Get the device ARN of the current job. If not in a job, default to "local:none/none". + """Get the device ARN of the current job. If not in a job, default to "local:none/none". Returns: str: The device ARN of the current job or "local:none/none". @@ -36,8 +34,7 @@ def get_job_device_arn() -> str: def get_input_data_dir(channel: str = "input") -> str: - """ - Get the job input data directory. + """Get the job input data directory. Args: channel (str): The name of the input channel. Default value @@ -47,14 +44,11 @@ def get_input_data_dir(channel: str = "input") -> str: str: The input directory, defaulting to current working directory. """ input_dir = os.getenv("AMZN_BRAKET_INPUT_DIR", ".") - if input_dir != ".": - return f"{input_dir}/{channel}" - return input_dir + return f"{input_dir}/{channel}" if input_dir != "." else input_dir def get_results_dir() -> str: - """ - Get the job result directory. + """Get the job result directory. Returns: str: The results directory, defaulting to current working directory. @@ -63,8 +57,7 @@ def get_results_dir() -> str: def get_checkpoint_dir() -> str: - """ - Get the job checkpoint directory. + """Get the job checkpoint directory. Returns: str: The checkpoint directory, defaulting to current working directory. @@ -73,13 +66,12 @@ def get_checkpoint_dir() -> str: def get_hyperparameters() -> dict[str, str]: - """ - Get the job hyperparameters as a dict, with the values stringified. + """Get the job hyperparameters as a dict, with the values stringified. Returns: dict[str, str]: The hyperparameters of the job. """ if "AMZN_BRAKET_HP_FILE" in os.environ: - with open(os.getenv("AMZN_BRAKET_HP_FILE"), "r") as f: + with open(os.getenv("AMZN_BRAKET_HP_FILE")) as f: return json.load(f) return {} diff --git a/src/braket/jobs/hybrid_job.py b/src/braket/jobs/hybrid_job.py index b8e1e58bf..77f5f43d0 100644 --- a/src/braket/jobs/hybrid_job.py +++ b/src/braket/jobs/hybrid_job.py @@ -168,9 +168,13 @@ def hybrid_job( def _hybrid_job(entry_point: Callable) -> Callable: @functools.wraps(entry_point) - def job_wrapper(*args, **kwargs) -> Callable: - """ - The job wrapper. + def job_wrapper(*args: Any, **kwargs: Any) -> Callable: + """The job wrapper. + + Args: + *args (Any): Arbitrary arguments. + **kwargs (Any): Arbitrary keyword arguments. + Returns: Callable: the callable for creating a Hybrid Job. """ @@ -243,7 +247,7 @@ def _validate_python_version(image_uri: str | None, aws_session: AwsSession | No image_uri = image_uri or retrieve_image(Framework.BASE, aws_session.region) tag = aws_session.get_full_image_tag(image_uri) major_version, minor_version = re.search(r"-py(\d)(\d+)-", tag).groups() - if not (sys.version_info.major, sys.version_info.minor) == ( + if (sys.version_info.major, sys.version_info.minor) != ( int(major_version), int(minor_version), ): @@ -322,7 +326,8 @@ def _log_hyperparameters(entry_point: Callable, args: tuple, kwargs: dict) -> di hyperparameters.update(**value) else: warnings.warn( - "Positional only arguments will not be logged to the hyperparameters file." + "Positional only arguments will not be logged to the hyperparameters file.", + stacklevel=1, ) return {name: _sanitize(value) for name, value in hyperparameters.items()} @@ -351,8 +356,7 @@ def _sanitize(hyperparameter: Any) -> str: def _process_input_data(input_data: dict) -> list[str]: - """ - Create symlinks to data + """Create symlinks to data. Logic chart for how the service moves files into the data directory on the instance: input data matches exactly one file: cwd/filename -> channel/filename @@ -365,9 +369,7 @@ def _process_input_data(input_data: dict) -> list[str]: input_data = {"input": input_data} def matches(prefix: str) -> list[str]: - return [ - str(path) for path in Path(prefix).parent.iterdir() if str(path).startswith(str(prefix)) - ] + return [str(path) for path in Path(prefix).parent.iterdir() if str(path).startswith(prefix)] def is_prefix(path: str) -> bool: return len(matches(path)) > 1 or not Path(path).exists() @@ -384,7 +386,7 @@ def is_prefix(path: str) -> bool: f"the working directory. Use `get_input_data_dir({channel_arg})` to read " f"input data from S3 source inside the job container." ) - elif is_prefix(data): + elif is_prefix(str(data)): prefix_channels.add(channel) elif Path(data).is_dir(): directory_channels.add(channel) diff --git a/src/braket/jobs/image_uri_config/base.json b/src/braket/jobs/image_uri_config/base.json index eb71e60fd..c7aef2be2 100644 --- a/src/braket/jobs/image_uri_config/base.json +++ b/src/braket/jobs/image_uri_config/base.json @@ -5,6 +5,7 @@ "us-east-1", "us-west-1", "us-west-2", - "eu-west-2" + "eu-west-2", + "eu-north-1" ] } diff --git a/src/braket/jobs/image_uri_config/pl_pytorch.json b/src/braket/jobs/image_uri_config/pl_pytorch.json index c7e28fbde..0a00e8537 100644 --- a/src/braket/jobs/image_uri_config/pl_pytorch.json +++ b/src/braket/jobs/image_uri_config/pl_pytorch.json @@ -5,6 +5,7 @@ "us-east-1", "us-west-1", "us-west-2", - "eu-west-2" + "eu-west-2", + "eu-north-1" ] } diff --git a/src/braket/jobs/image_uri_config/pl_tensorflow.json b/src/braket/jobs/image_uri_config/pl_tensorflow.json index 3278a8712..c43792e8a 100644 --- a/src/braket/jobs/image_uri_config/pl_tensorflow.json +++ b/src/braket/jobs/image_uri_config/pl_tensorflow.json @@ -5,6 +5,7 @@ "us-east-1", "us-west-1", "us-west-2", - "eu-west-2" + "eu-west-2", + "eu-north-1" ] } diff --git a/src/braket/jobs/image_uris.py b/src/braket/jobs/image_uris.py index 3a3346abe..af6c5012a 100644 --- a/src/braket/jobs/image_uris.py +++ b/src/braket/jobs/image_uris.py @@ -15,7 +15,6 @@ import os from enum import Enum from functools import cache -from typing import Dict, Set class Framework(str, Enum): @@ -26,7 +25,15 @@ class Framework(str, Enum): PL_PYTORCH = "PL_PYTORCH" -def built_in_images(region: str) -> Set[str]: +def built_in_images(region: str) -> set[str]: + """Checks a region for built in Braket images. + + Args: + region (str): The AWS region to check for images + + Returns: + set[str]: returns a set of built images + """ return {retrieve_image(framework, region) for framework in Framework} @@ -53,25 +60,25 @@ def retrieve_image(framework: Framework, region: str) -> str: return f"{registry}.dkr.ecr.{region}.amazonaws.com/{tag}" -def _config_for_framework(framework: Framework) -> Dict[str, str]: +def _config_for_framework(framework: Framework) -> dict[str, str]: """Loads the JSON config for the given framework. Args: framework (Framework): The framework whose config needs to be loaded. Returns: - Dict[str, str]: Dict that contains the configuration for the specified framework. + dict[str, str]: Dict that contains the configuration for the specified framework. """ fname = os.path.join(os.path.dirname(__file__), "image_uri_config", f"{framework.lower()}.json") with open(fname) as f: return json.load(f) -def _registry_for_region(config: Dict[str, str], region: str) -> str: +def _registry_for_region(config: dict[str, str], region: str) -> str: """Retrieves the registry for the specified region from the configuration. Args: - config (Dict[str, str]): Dict containing the framework configuration. + config (dict[str, str]): Dict containing the framework configuration. region (str): str that specifies the region for which the registry is retrieved. Returns: diff --git a/src/braket/jobs/local/local_job.py b/src/braket/jobs/local/local_job.py index f516d9693..4dd15607a 100644 --- a/src/braket/jobs/local/local_job.py +++ b/src/braket/jobs/local/local_job.py @@ -117,6 +117,9 @@ def create( container image. Optional. Default: True. + Raises: + ValueError: Local directory with the job name already exists. + Returns: LocalQuantumJob: The representation of a local Braket Hybrid Job. """ @@ -166,11 +169,15 @@ def create( return LocalQuantumJob(f"local:job/{job_name}", run_log) def __init__(self, arn: str, run_log: str | None = None): - """ + """Initializes a `LocalQuantumJob`. + Args: arn (str): The ARN of the hybrid job. - run_log (str | None): The container output log of running the hybrid job with the - given arn. + run_log (str | None): The container output log of running the hybrid job with the given + arn. + + Raises: + ValueError: Local job is not found. """ if not arn.startswith("local:job/"): raise ValueError(f"Arn {arn} is not a valid local job arn") @@ -194,24 +201,31 @@ def name(self) -> str: def run_log(self) -> str: """Gets the run output log from running the hybrid job. + Raises: + ValueError: The log file is not found. + Returns: str: The container output log from running the hybrid job. """ if not self._run_log: try: - with open(os.path.join(self.name, "log.txt"), "r") as log_file: + with open(os.path.join(self.name, "log.txt")) as log_file: self._run_log = log_file.read() - except FileNotFoundError: - raise ValueError(f"Unable to find logs in the local job directory {self.name}.") + except FileNotFoundError as e: + raise ValueError( + f"Unable to find logs in the local job directory {self.name}." + ) from e return self._run_log def state(self, use_cached_value: bool = False) -> str: """The state of the hybrid job. + Args: use_cached_value (bool): If `True`, uses the value most recently retrieved value from the Amazon Braket `GetJob` operation. If `False`, calls the `GetJob` operation to retrieve metadata, which also updates the cached value. Default = `False`. + Returns: str: Returns "COMPLETED". """ @@ -219,22 +233,23 @@ def state(self, use_cached_value: bool = False) -> str: def metadata(self, use_cached_value: bool = False) -> dict[str, Any]: """When running the hybrid job in local mode, the metadata is not available. + Args: use_cached_value (bool): If `True`, uses the value most recently retrieved from the Amazon Braket `GetJob` operation, if it exists; if does not exist, `GetJob` is called to retrieve the metadata. If `False`, always calls `GetJob`, which also updates the cached value. Default: `False`. + Returns: dict[str, Any]: None """ - pass def cancel(self) -> str: """When running the hybrid job in local mode, the cancelling a running is not possible. + Returns: str: None """ - pass def download_result( self, @@ -253,14 +268,13 @@ def download_result( poll_interval_seconds (float): The polling interval, in seconds, for `result()`. Default: 5 seconds. """ - pass def result( self, poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL, ) -> dict[str, Any]: - """Retrieves the hybrid job result persisted using save_job_result() function. + """Retrieves the `LocalQuantumJob` result persisted using `save_job_result` function. Args: poll_timeout_seconds (float): The polling timeout, in seconds, for `result()`. @@ -268,18 +282,23 @@ def result( poll_interval_seconds (float): The polling interval, in seconds, for `result()`. Default: 5 seconds. + Raises: + ValueError: The local job directory does not exist. + Returns: dict[str, Any]: Dict specifying the hybrid job results. """ try: - with open(os.path.join(self.name, "results.json"), "r") as f: + with open(os.path.join(self.name, "results.json")) as f: persisted_data = PersistedJobData.parse_raw(f.read()) deserialized_data = deserialize_values( persisted_data.dataDictionary, persisted_data.dataFormat ) return deserialized_data - except FileNotFoundError: - raise ValueError(f"Unable to find results in the local job directory {self.name}.") + except FileNotFoundError as e: + raise ValueError( + f"Unable to find results in the local job directory {self.name}." + ) from e def metrics( self, diff --git a/src/braket/jobs/local/local_job_container.py b/src/braket/jobs/local/local_job_container.py index ea5625623..6d9d08f4f 100644 --- a/src/braket/jobs/local/local_job_container.py +++ b/src/braket/jobs/local/local_job_container.py @@ -17,12 +17,11 @@ import subprocess from logging import Logger, getLogger from pathlib import PurePosixPath -from typing import Dict, List from braket.aws.aws_session import AwsSession -class _LocalJobContainer(object): +class _LocalJobContainer: """Uses docker CLI to run Braket Hybrid Jobs on a local docker container.""" ECR_URI_PATTERN = r"^((\d+)\.dkr\.ecr\.([^.]+)\.[^/]*)/([^:]*):(.*)$" @@ -39,13 +38,14 @@ def __init__( container. The function "end_session" must be called when the container is no longer needed. + Args: image_uri (str): The URI of the container image to run. aws_session (AwsSession | None): AwsSession for connecting to AWS Services. Default: AwsSession() logger (Logger): Logger object with which to write logs. Default: `getLogger(__name__)` - force_update (bool): Try to update the container, if an update is availble. + force_update (bool): Try to update the container, if an update is available. Default: False """ self._aws_session = aws_session or AwsSession() @@ -65,16 +65,17 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._end_session() @staticmethod - def _envs_to_list(environment_variables: Dict[str, str]) -> List[str]: + def _envs_to_list(environment_variables: dict[str, str]) -> list[str]: """Converts a dictionary environment variables to a list of parameters that can be passed to the container exec/run commands to ensure those env variables are available in the container. Args: - environment_variables (Dict[str, str]): A dictionary of environment variables and + environment_variables (dict[str, str]): A dictionary of environment variables and their values. + Returns: - List[str]: The list of parameters to use when running a hybrid job that will include the + list[str]: The list of parameters to use when running a hybrid job that will include the provided environment variables as part of the runtime. """ env_list = [] @@ -84,12 +85,12 @@ def _envs_to_list(environment_variables: Dict[str, str]) -> List[str]: return env_list @staticmethod - def _check_output_formatted(command: List[str]) -> str: + def _check_output_formatted(command: list[str]) -> str: """This is a wrapper around the subprocess.check_output command that decodes the output to UTF-8 encoding. Args: - command(List[str]): The command to run. + command(list[str]): The command to run. Returns: str: The UTF-8 encoded output of running the command. @@ -103,6 +104,9 @@ def _login_to_ecr(self, account_id: str, ecr_url: str) -> None: Args: account_id(str): The customer account ID. ecr_url(str): The URL of the ECR repo to log into. + + Raises: + ValueError: Invalid permissions to pull container. """ ecr_client = self._aws_session.ecr_client authorization_data_result = ecr_client.get_authorization_token(registryIds=[account_id]) @@ -121,6 +125,9 @@ def _pull_image(self, image_uri: str) -> None: Args: image_uri(str): The URI of the ECR image to pull. + + Raises: + ValueError: Invalid ECR URL. """ ecr_pattern = re.compile(self.ECR_URI_PATTERN) ecr_pattern_match = ecr_pattern.match(image_uri) @@ -131,8 +138,8 @@ def _pull_image(self, image_uri: str) -> None: "Please pull down the container, or specify a valid ECR URL, " "before proceeding." ) - ecr_url = ecr_pattern_match.group(1) - account_id = ecr_pattern_match.group(2) + ecr_url = ecr_pattern_match[1] + account_id = ecr_pattern_match[2] self._login_to_ecr(account_id, ecr_url) self._logger.warning("Pulling docker container image. This may take a while.") subprocess.run(["docker", "pull", image_uri]) @@ -145,6 +152,9 @@ def _start_container(self, image_uri: str, force_update: bool) -> str: image_uri(str): The URI of the ECR image to run. force_update(bool): Do a docker pull, even if the image is local, in order to update. + Raises: + ValueError: Invalid local image URI. + Returns: str: The name of the running container, which can be used to execute further commands. """ @@ -230,13 +240,16 @@ def copy_from(self, source: str, destination: str) -> None: def run_local_job( self, - environment_variables: Dict[str, str], + environment_variables: dict[str, str], ) -> None: """Runs a Braket Hybrid job in a local container. Args: - environment_variables (Dict[str, str]): The environment variables to make available + environment_variables (dict[str, str]): The environment variables to make available as part of running the hybrid job. + + Raises: + ValueError: `start_program_name` is not found. """ start_program_name = self._check_output_formatted( ["docker", "exec", self._container_name, "printenv", "SAGEMAKER_PROGRAM"] diff --git a/src/braket/jobs/local/local_job_container_setup.py b/src/braket/jobs/local/local_job_container_setup.py index 7505dcbf5..65cef387c 100644 --- a/src/braket/jobs/local/local_job_container_setup.py +++ b/src/braket/jobs/local/local_job_container_setup.py @@ -13,17 +13,18 @@ import json import tempfile +from collections.abc import Iterable from logging import Logger, getLogger from pathlib import Path -from typing import Any, Dict, Iterable +from typing import Any from braket.aws.aws_session import AwsSession from braket.jobs.local.local_job_container import _LocalJobContainer def setup_container( - container: _LocalJobContainer, aws_session: AwsSession, **creation_kwargs -) -> Dict[str, str]: + container: _LocalJobContainer, aws_session: AwsSession, **creation_kwargs: str +) -> dict[str, str]: """Sets up a container with prerequisites for running a Braket Hybrid Job. The prerequisites are based on the options the customer has chosen for the hybrid job. Similarly, any environment variables that are needed during runtime will be returned by this function. @@ -31,15 +32,16 @@ def setup_container( Args: container(_LocalJobContainer): The container that will run the braket hybrid job. aws_session (AwsSession): AwsSession for connecting to AWS Services. + **creation_kwargs (str): Arbitrary keyword arguments. Returns: - Dict[str, str]: A dictionary of environment variables that reflect Braket Hybrid Jobs + dict[str, str]: A dictionary of environment variables that reflect Braket Hybrid Jobs options requested by the customer. """ logger = getLogger(__name__) _create_expected_paths(container, **creation_kwargs) run_environment_variables = {} - run_environment_variables.update(_get_env_credentials(aws_session, logger)) + run_environment_variables |= _get_env_credentials(aws_session, logger) run_environment_variables.update( _get_env_script_mode_config(creation_kwargs["algorithmSpecification"]["scriptModeConfig"]) ) @@ -51,17 +53,18 @@ def setup_container( return run_environment_variables -def _create_expected_paths(container: _LocalJobContainer, **creation_kwargs) -> None: +def _create_expected_paths(container: _LocalJobContainer, **creation_kwargs: str) -> None: """Creates the basic paths required for Braket Hybrid Jobs to run. Args: container(_LocalJobContainer): The container that will run the braket hybrid job. + **creation_kwargs (str): Arbitrary keyword arguments. """ container.makedir("/opt/ml/model") container.makedir(creation_kwargs["checkpointConfig"]["localPath"]) -def _get_env_credentials(aws_session: AwsSession, logger: Logger) -> Dict[str, str]: +def _get_env_credentials(aws_session: AwsSession, logger: Logger) -> dict[str, str]: """Gets the account credentials from boto so they can be added as environment variables to the running container. @@ -70,7 +73,7 @@ def _get_env_credentials(aws_session: AwsSession, logger: Logger) -> Dict[str, s logger (Logger): Logger object with which to write logs. Default is `getLogger(__name__)` Returns: - Dict[str, str]: The set of key/value pairs that should be added as environment variables + dict[str, str]: The set of key/value pairs that should be added as environment variables to the running container. """ credentials = aws_session.boto_session.get_credentials() @@ -90,15 +93,15 @@ def _get_env_credentials(aws_session: AwsSession, logger: Logger) -> Dict[str, s } -def _get_env_script_mode_config(script_mode_config: Dict[str, str]) -> Dict[str, str]: +def _get_env_script_mode_config(script_mode_config: dict[str, str]) -> dict[str, str]: """Gets the environment variables related to the customer script mode config. Args: - script_mode_config (Dict[str, str]): The values for scriptModeConfig in the boto3 input + script_mode_config (dict[str, str]): The values for scriptModeConfig in the boto3 input parameters for running a Braket Hybrid Job. Returns: - Dict[str, str]: The set of key/value pairs that should be added as environment variables + dict[str, str]: The set of key/value pairs that should be added as environment variables to the running container. """ result = { @@ -110,15 +113,16 @@ def _get_env_script_mode_config(script_mode_config: Dict[str, str]) -> Dict[str, return result -def _get_env_default_vars(aws_session: AwsSession, **creation_kwargs) -> Dict[str, str]: +def _get_env_default_vars(aws_session: AwsSession, **creation_kwargs: str) -> dict[str, str]: """This function gets the remaining 'simple' env variables, that don't require any additional logic to determine what they are or when they should be added as env variables. Args: aws_session (AwsSession): AwsSession for connecting to AWS Services. + **creation_kwargs (str): Arbitrary keyword arguments. Returns: - Dict[str, str]: The set of key/value pairs that should be added as environment variables + dict[str, str]: The set of key/value pairs that should be added as environment variables to the running container. """ job_name = creation_kwargs["jobName"] @@ -135,12 +139,12 @@ def _get_env_default_vars(aws_session: AwsSession, **creation_kwargs) -> Dict[st } -def _get_env_hyperparameters() -> Dict[str, str]: +def _get_env_hyperparameters() -> dict[str, str]: """Gets the env variable for hyperparameters. This should only be added if the customer has provided hyperpameters to the hybrid job. Returns: - Dict[str, str]: The set of key/value pairs that should be added as environment variables + dict[str, str]: The set of key/value pairs that should be added as environment variables to the running container. """ return { @@ -148,12 +152,12 @@ def _get_env_hyperparameters() -> Dict[str, str]: } -def _get_env_input_data() -> Dict[str, str]: +def _get_env_input_data() -> dict[str, str]: """Gets the env variable for input data. This should only be added if the customer has provided input data to the hybrid job. Returns: - Dict[str, str]: The set of key/value pairs that should be added as environment variables + dict[str, str]: The set of key/value pairs that should be added as environment variables to the running container. """ return { @@ -161,12 +165,13 @@ def _get_env_input_data() -> Dict[str, str]: } -def _copy_hyperparameters(container: _LocalJobContainer, **creation_kwargs) -> bool: +def _copy_hyperparameters(container: _LocalJobContainer, **creation_kwargs: str) -> bool: """If hyperpameters are present, this function will store them as a JSON object in the container in the appropriate location on disk. Args: container(_LocalJobContainer): The container to save hyperparameters to. + **creation_kwargs (str): Arbitrary keyword arguments. Returns: bool: True if any hyperparameters were copied to the container. @@ -185,15 +190,20 @@ def _copy_hyperparameters(container: _LocalJobContainer, **creation_kwargs) -> b def _download_input_data( aws_session: AwsSession, download_dir: str, - input_data: Dict[str, Any], + input_data: dict[str, Any], ) -> None: """Downloads input data for a hybrid job. Args: aws_session (AwsSession): AwsSession for connecting to AWS Services. download_dir (str): The directory path to download to. - input_data (Dict[str, Any]): One of the input data in the boto3 input parameters for + input_data (dict[str, Any]): One of the input data in the boto3 input parameters for running a Braket Hybrid Job. + + Raises: + ValueError: File already exists. + RuntimeError: The item is not found. + """ # If s3 prefix is the full name of a directory and all keys are inside # that directory, the contents of said directory will be copied into a @@ -212,8 +222,10 @@ def _download_input_data( found_item = False try: Path(download_dir, channel_name).mkdir() - except FileExistsError: - raise ValueError(f"Duplicate channel names not allowed for input data: {channel_name}") + except FileExistsError as e: + raise ValueError( + f"Duplicate channel names not allowed for input data: {channel_name}" + ) from e for s3_key in s3_keys: relative_key = Path(s3_key).relative_to(top_level) download_path = Path(download_dir, channel_name, relative_key) @@ -243,7 +255,7 @@ def _is_dir(prefix: str, keys: Iterable[str]) -> bool: def _copy_input_data_list( - container: _LocalJobContainer, aws_session: AwsSession, **creation_kwargs + container: _LocalJobContainer, aws_session: AwsSession, **creation_kwargs: str ) -> bool: """If the input data list is not empty, this function will download the input files and store them in the container. @@ -251,6 +263,7 @@ def _copy_input_data_list( Args: container (_LocalJobContainer): The container to save input data to. aws_session (AwsSession): AwsSession for connecting to AWS Services. + **creation_kwargs (str): Arbitrary keyword arguments. Returns: bool: True if any input data was copied to the container. diff --git a/src/braket/jobs/logs.py b/src/braket/jobs/logs.py index 734d51123..9aa7dfaca 100644 --- a/src/braket/jobs/logs.py +++ b/src/braket/jobs/logs.py @@ -14,30 +14,31 @@ import collections import os import sys +from collections.abc import Generator ############################################################################## # # Support for reading logs # ############################################################################## -from typing import Dict, List, Optional, Tuple +from typing import ClassVar, Optional from botocore.exceptions import ClientError from braket.aws.aws_session import AwsSession -class ColorWrap(object): +class ColorWrap: """A callable that prints text in a different color depending on the instance. Up to 5 if the standard output is a terminal or a Jupyter notebook cell. """ # For what color each number represents, see # https://misc.flogisoft.com/bash/tip_colors_and_formatting#colors - _stream_colors = [34, 35, 32, 36, 33] + _stream_colors: ClassVar = [34, 35, 32, 36, 33] - def __init__(self, force=False): - """Initialize the class. + def __init__(self, force: bool = False): + """Initialize a `ColorWrap`. Args: force (bool): If True, the render output is colorized wherever the @@ -45,7 +46,7 @@ def __init__(self, force=False): """ self.colorize = force or sys.stdout.isatty() or os.environ.get("JPY_PARENT_PID", None) - def __call__(self, index, s): + def __call__(self, index: int, s: str): """Prints the string, colorized or not, depending on the environment. Args: @@ -73,8 +74,8 @@ def _color_wrap(self, index: int, s: str) -> None: def multi_stream_iter( - aws_session: AwsSession, log_group: str, streams: List[str], positions: Dict[str, Position] -) -> Tuple[int, Dict]: + aws_session: AwsSession, log_group: str, streams: list[str], positions: dict[str, Position] +) -> Generator[tuple[int, dict]]: """Iterates over the available events coming from a set of log streams. Log streams are in a single log group interleaving the events from each stream, so they yield in timestamp order. @@ -82,13 +83,13 @@ def multi_stream_iter( Args: aws_session (AwsSession): The AwsSession for interfacing with CloudWatch. log_group (str): The name of the log group. - streams (List[str]): A list of the log stream names. The the stream number is + streams (list[str]): A list of the log stream names. The the stream number is the position of the stream in this list. - positions (Dict[str, Position]): A list of (timestamp, skip) pairs which represent + positions (dict[str, Position]): A list of (timestamp, skip) pairs which represent the last record read from each stream. Yields: - Tuple[int, Dict]: A tuple of (stream number, cloudwatch log event). + Generator[tuple[int, dict]]: A tuple of (stream number, cloudwatch log event). """ event_iters = [ log_stream(aws_session, log_group, s, positions[s].timestamp, positions[s].skip) @@ -112,7 +113,7 @@ def multi_stream_iter( def log_stream( aws_session: AwsSession, log_group: str, stream_name: str, start_time: int = 0, skip: int = 0 -) -> Dict: +) -> Generator[dict]: """A generator for log items in a single stream. This yields all the items that are available at the current moment. @@ -125,12 +126,11 @@ def log_stream( when there are multiple entries at the same timestamp.) Yields: - Dict: A CloudWatch log event with the following key-value pairs: + Generator[dict]: A CloudWatch log event with the following key-value pairs: 'timestamp' (int): The time of the event. 'message' (str): The log event data. 'ingestionTime' (int): The time the event was ingested. """ - next_token = None event_count = 1 @@ -151,16 +151,15 @@ def log_stream( else: skip = skip - event_count events = [] - for ev in events: - yield ev + yield from events def flush_log_streams( # noqa C901 aws_session: AwsSession, log_group: str, stream_prefix: str, - stream_names: List[str], - positions: Dict[str, Position], + stream_names: list[str], + positions: dict[str, Position], stream_count: int, has_streams: bool, color_wrap: ColorWrap, @@ -173,11 +172,11 @@ def flush_log_streams( # noqa C901 aws_session (AwsSession): The AwsSession for interfacing with CloudWatch. log_group (str): The name of the log group. stream_prefix (str): The prefix for log streams to flush. - stream_names (List[str]): A list of the log stream names. The position of the stream in + stream_names (list[str]): A list of the log stream names. The position of the stream in this list is the stream number. If incomplete, the function will check for remaining streams and mutate this list to add stream names when available, up to the `stream_count` limit. - positions (Dict[str, Position]): A dict mapping stream numbers to (timestamp, skip) pairs + positions (dict[str, Position]): A dict mapping stream numbers to (timestamp, skip) pairs which represent the last record read from each stream. The function will update this list after being called to represent the new last record read from each stream. stream_count (int): The number of streams expected. @@ -189,6 +188,9 @@ def flush_log_streams( # noqa C901 queue_position (Optional[str]): The current queue position. This is not passed in if the job is ran with `quiet=True` + Raises: + Exception: Any exception found besides a ResourceNotFoundException. + Returns: bool: Returns 'True' if any streams have been flushed. """ @@ -208,9 +210,9 @@ def flush_log_streams( # noqa C901 if s["logStreamName"] not in stream_names ] stream_names.extend(new_streams) - positions.update( - [(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions] - ) + positions |= [ + (s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions + ] except ClientError as e: # On the very first training job run on an account, there's no # log group until the container starts logging, so ignore any @@ -219,7 +221,7 @@ def flush_log_streams( # noqa C901 if err.get("Code") != "ResourceNotFoundException": raise - if len(stream_names) > 0: + if stream_names: if not has_streams: print() has_streams = True diff --git a/src/braket/jobs/metrics.py b/src/braket/jobs/metrics.py index 462501cb6..991370f35 100644 --- a/src/braket/jobs/metrics.py +++ b/src/braket/jobs/metrics.py @@ -21,19 +21,17 @@ def log_metric( timestamp: Optional[float] = None, iteration_number: Optional[int] = None, ) -> None: - """ - Records Braket Hybrid Job metrics. + """Records Braket Hybrid Job metrics. Args: - metric_name (str) : The name of the metric. + metric_name (str): The name of the metric. - value (Union[float, int]) : The value of the metric. + value (Union[float, int]): The value of the metric. - timestamp (Optional[float]) : The time the metric data was received, expressed - as the number of seconds - since the epoch. Default: Current system time. + timestamp (Optional[float]): The time the metric data was received, expressed + as the number of seconds since the epoch. Default: Current system time. - iteration_number (Optional[int]) : The iteration number of the metric. + iteration_number (Optional[int]): The iteration number of the metric. """ logged_timestamp = timestamp or time.time() metric_list = [f"Metrics - timestamp={logged_timestamp}; {metric_name}={value};"] diff --git a/src/braket/jobs/metrics_data/cwl_insights_metrics_fetcher.py b/src/braket/jobs/metrics_data/cwl_insights_metrics_fetcher.py index b32cd6b9c..8f5d3dcd5 100644 --- a/src/braket/jobs/metrics_data/cwl_insights_metrics_fetcher.py +++ b/src/braket/jobs/metrics_data/cwl_insights_metrics_fetcher.py @@ -15,7 +15,7 @@ import time from logging import Logger, getLogger -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from braket.aws.aws_session import AwsSession from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType @@ -23,7 +23,7 @@ from braket.jobs.metrics_data.log_metrics_parser import LogMetricsParser -class CwlInsightsMetricsFetcher(object): +class CwlInsightsMetricsFetcher: LOG_GROUP_NAME = "/aws/braket/jobs" QUERY_DEFAULT_JOB_DURATION = 3 * 60 * 60 @@ -34,7 +34,8 @@ def __init__( poll_interval_seconds: float = 1, logger: Logger = getLogger(__name__), ): - """ + """Initializes a `CwlInsightsMetricsFetcher`. + Args: aws_session (AwsSession): AwsSession to connect to AWS with. poll_timeout_seconds (float): The polling timeout for retrieving the metrics, @@ -52,32 +53,33 @@ def __init__( @staticmethod def _get_element_from_log_line( - element_name: str, log_line: List[Dict[str, Any]] + element_name: str, log_line: list[dict[str, Any]] ) -> Optional[str]: - """ - Finds and returns an element of a log line from CloudWatch Insights results. + """Finds and returns an element of a log line from CloudWatch Insights results. Args: element_name (str): The element to find. - log_line (List[Dict[str, Any]]): An iterator for RegEx matches on a log line. + log_line (list[dict[str, Any]]): An iterator for RegEx matches on a log line. Returns: - Optional[str] : The value of the element with the element name, or None if no such + Optional[str]: The value of the element with the element name, or None if no such element is found. """ return next( (element["value"] for element in log_line if element["field"] == element_name), None ) - def _get_metrics_results_sync(self, query_id: str) -> List[Any]: - """ - Waits for the CloudWatch Insights query to complete and then returns all the results. + def _get_metrics_results_sync(self, query_id: str) -> list[Any]: + """Waits for the CloudWatch Insights query to complete and then returns all the results. Args: query_id (str): CloudWatch Insights query ID. + Raises: + MetricsRetrievalError: Raised if the query is Failed or Cancelled. + Returns: - List[Any]: The results from CloudWatch insights 'GetQueryResults' operation. + list[Any]: The results from CloudWatch insights 'GetQueryResults' operation. """ timeout_time = time.time() + self._poll_timeout_seconds while time.time() < timeout_time: @@ -92,38 +94,35 @@ def _get_metrics_results_sync(self, query_id: str) -> List[Any]: self._logger.warning(f"Timed out waiting for query {query_id}.") return [] - def _parse_log_line(self, result_entry: List[Dict[str, Any]], parser: LogMetricsParser) -> None: - """ - Parses the single entry from CloudWatch Insights results and adds any metrics it finds + def _parse_log_line(self, result_entry: list[dict[str, Any]], parser: LogMetricsParser) -> None: + """Parses the single entry from CloudWatch Insights results and adds any metrics it finds to 'all_metrics' along with the timestamp for the entry. Args: - result_entry (List[Dict[str, Any]]): A structured result from calling CloudWatch + result_entry (list[dict[str, Any]]): A structured result from calling CloudWatch Insights to get logs that contain metrics. A single entry contains the message (the actual line logged to output), the timestamp (generated by CloudWatch Logs), and other metadata that we (currently) do not use. parser (LogMetricsParser) : The CWL metrics parser. """ - message = self._get_element_from_log_line("@message", result_entry) - if message: + if message := self._get_element_from_log_line("@message", result_entry): timestamp = self._get_element_from_log_line("@timestamp", result_entry) parser.parse_log_message(timestamp, message) def _parse_log_query_results( - self, results: List[Any], metric_type: MetricType, statistic: MetricStatistic - ) -> Dict[str, List[Union[str, float, int]]]: - """ - Parses CloudWatch Insights results and returns all found metrics. + self, results: list[Any], metric_type: MetricType, statistic: MetricStatistic + ) -> dict[str, list[Union[str, float, int]]]: + """Parses CloudWatch Insights results and returns all found metrics. Args: - results (List[Any]): A structured result from calling CloudWatch Insights to get + results (list[Any]): A structured result from calling CloudWatch Insights to get logs that contain metrics. metric_type (MetricType): The type of metrics to get. statistic (MetricStatistic): The statistic to determine which metric value to use when there is a conflict. Returns: - Dict[str, List[Union[str, float, int]]] : The metrics data. + dict[str, list[Union[str, float, int]]]: The metrics data. """ parser = LogMetricsParser() for result in results: @@ -137,9 +136,9 @@ def get_metrics_for_job( statistic: MetricStatistic = MetricStatistic.MAX, job_start_time: int | None = None, job_end_time: int | None = None, - ) -> Dict[str, List[Union[str, float, int]]]: - """ - Synchronously retrieves all the algorithm metrics logged by a given Hybrid Job. + stream_prefix: str | None = None, + ) -> dict[str, list[Union[str, float, int]]]: + """Synchronously retrieves all the algorithm metrics logged by a given Hybrid Job. Args: job_name (str): The name of the Hybrid Job. The name must be exact to ensure only the @@ -151,9 +150,11 @@ def get_metrics_for_job( Default: 3 hours before job_end_time. job_end_time (int | None): If the hybrid job is complete, this should be the time at which the hybrid job finished. Default: current time. + stream_prefix (str | None): If a logs prefix is provided, it will be used instead + of the job name. Returns: - Dict[str, List[Union[str, float, int]]] : The metrics data, where the keys + dict[str, list[Union[str, float, int]]]: The metrics data, where the keys are the column names and the values are a list containing the values in each row. Example: @@ -167,11 +168,11 @@ def get_metrics_for_job( query_end_time = job_end_time or int(time.time()) query_start_time = job_start_time or query_end_time - self.QUERY_DEFAULT_JOB_DURATION - # The hybrid job name needs to be unique to prevent jobs with similar names from being - # conflated. + stream_prefix = (stream_prefix or job_name).replace("/", "\\/") + query = ( f"fields @timestamp, @message " - f"| filter @logStream like /^{job_name}\\// " + f"| filter @logStream like /^{stream_prefix}\\// " f"| filter @message like /Metrics - /" ) diff --git a/src/braket/jobs/metrics_data/cwl_metrics_fetcher.py b/src/braket/jobs/metrics_data/cwl_metrics_fetcher.py index 5e3ef28f2..e8da4ff89 100644 --- a/src/braket/jobs/metrics_data/cwl_metrics_fetcher.py +++ b/src/braket/jobs/metrics_data/cwl_metrics_fetcher.py @@ -13,14 +13,14 @@ import time from logging import Logger, getLogger -from typing import Dict, List, Union +from typing import Union from braket.aws.aws_session import AwsSession from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType from braket.jobs.metrics_data.log_metrics_parser import LogMetricsParser -class CwlMetricsFetcher(object): +class CwlMetricsFetcher: LOG_GROUP_NAME = "/aws/braket/jobs" def __init__( @@ -29,7 +29,8 @@ def __init__( poll_timeout_seconds: float = 10, logger: Logger = getLogger(__name__), ): - """ + """Initializes a `CwlMetricsFetcher`. + Args: aws_session (AwsSession): AwsSession to connect to AWS with. poll_timeout_seconds (float): The polling timeout for retrieving the metrics, @@ -44,8 +45,7 @@ def __init__( @staticmethod def _is_metrics_message(message: str) -> bool: - """ - Returns true if a given message is designated as containing Metrics. + """Returns true if a given message is designated as containing Metrics. Args: message (str): The message to check. @@ -53,9 +53,7 @@ def _is_metrics_message(message: str) -> bool: Returns: bool: True if the given message is designated as containing Metrics; False otherwise. """ - if message: - return "Metrics -" in message - return False + return "Metrics -" in message if message else False def _parse_metrics_from_log_stream( self, @@ -63,8 +61,7 @@ def _parse_metrics_from_log_stream( timeout_time: float, parser: LogMetricsParser, ) -> None: - """ - Synchronously retrieves the algorithm metrics logged in a given hybrid job log stream. + """Synchronously retrieves the algorithm metrics logged in a given hybrid job log stream. Args: stream_name (str): The name of the log stream. @@ -93,34 +90,32 @@ def _parse_metrics_from_log_stream( kwargs["nextToken"] = next_token self._logger.warning("Timed out waiting for all metrics. Data may be incomplete.") - def _get_log_streams_for_job(self, job_name: str, timeout_time: float) -> List[str]: - """ - Retrieves the list of log streams relevant to a hybrid job. + def _get_log_streams_for_job(self, job_name: str, timeout_time: float) -> list[str]: + """Retrieves the list of log streams relevant to a hybrid job. Args: job_name (str): The name of the hybrid job. timeout_time (float) : Metrics cease getting streamed if the current time exceeds the timeout time. + Returns: - List[str] : A list of log stream names for the given hybrid job. + list[str]: A list of log stream names for the given hybrid job. """ kwargs = { "logGroupName": self.LOG_GROUP_NAME, - "logStreamNamePrefix": job_name + "/algo-", + "logStreamNamePrefix": f"{job_name}/algo-", } log_streams = [] while time.time() < timeout_time: response = self._logs_client.describe_log_streams(**kwargs) - streams = response.get("logStreams") - if streams: + if streams := response.get("logStreams"): for stream in streams: - name = stream.get("logStreamName") - if name: + if name := stream.get("logStreamName"): log_streams.append(name) - next_token = response.get("nextToken") - if not next_token: + if next_token := response.get("nextToken"): + kwargs["nextToken"] = next_token + else: return log_streams - kwargs["nextToken"] = next_token self._logger.warning("Timed out waiting for all metrics. Data may be incomplete.") return log_streams @@ -129,9 +124,8 @@ def get_metrics_for_job( job_name: str, metric_type: MetricType = MetricType.TIMESTAMP, statistic: MetricStatistic = MetricStatistic.MAX, - ) -> Dict[str, List[Union[str, float, int]]]: - """ - Synchronously retrieves all the algorithm metrics logged by a given Hybrid Job. + ) -> dict[str, list[Union[str, float, int]]]: + """Synchronously retrieves all the algorithm metrics logged by a given Hybrid Job. Args: job_name (str): The name of the Hybrid Job. The name must be exact to ensure only the @@ -141,7 +135,7 @@ def get_metrics_for_job( when there is a conflict. Default is MetricStatistic.MAX. Returns: - Dict[str, List[Union[str, float, int]]] : The metrics data, where the keys + dict[str, list[Union[str, float, int]]]: The metrics data, where the keys are the column names and the values are a list containing the values in each row. Example: diff --git a/src/braket/jobs/metrics_data/exceptions.py b/src/braket/jobs/metrics_data/exceptions.py index 677a3a447..41cbf0491 100644 --- a/src/braket/jobs/metrics_data/exceptions.py +++ b/src/braket/jobs/metrics_data/exceptions.py @@ -14,5 +14,3 @@ class MetricsRetrievalError(Exception): """Raised when retrieving metrics fails.""" - - pass diff --git a/src/braket/jobs/metrics_data/log_metrics_parser.py b/src/braket/jobs/metrics_data/log_metrics_parser.py index 7187486c7..1ff5b4d49 100644 --- a/src/braket/jobs/metrics_data/log_metrics_parser.py +++ b/src/braket/jobs/metrics_data/log_metrics_parser.py @@ -12,15 +12,15 @@ # language governing permissions and limitations under the License. import re +from collections.abc import Iterator from logging import Logger, getLogger -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import Optional, Union from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType -class LogMetricsParser(object): - """ - This class is used to parse metrics from log lines, and return them in a more +class LogMetricsParser: + """This class is used to parse metrics from log lines, and return them in a more convenient format. """ @@ -43,8 +43,7 @@ def _get_value( new_value: Union[str, float, int], statistic: MetricStatistic, ) -> Union[str, float, int]: - """ - Gets the value based on a statistic. + """Gets the value based on a statistic. Args: current_value (Optional[Union[str, float, int]]): The current value. @@ -64,15 +63,14 @@ def _get_value( def _get_metrics_from_log_line_matches( self, all_matches: Iterator - ) -> Dict[str, Union[str, float, int]]: - """ - Converts matches from a RegEx to a set of metrics. + ) -> dict[str, Union[str, float, int]]: + """Converts matches from a RegEx to a set of metrics. Args: all_matches (Iterator): An iterator for RegEx matches on a log line. Returns: - Dict[str, Union[str, float, int]]: The set of metrics found by the RegEx. The result + dict[str, Union[str, float, int]]: The set of metrics found by the RegEx. The result is in the format { : }. This implies that multiple metrics with the same name are deduped to the last instance of that metric. """ @@ -87,8 +85,7 @@ def _get_metrics_from_log_line_matches( return metrics def parse_log_message(self, timestamp: str, message: str) -> None: - """ - Parses a line from logs, adding all the metrics that have been logged + """Parses a line from logs, adding all the metrics that have been logged on that line. The timestamp is also added to match the corresponding values. Args: @@ -104,26 +101,25 @@ def parse_log_message(self, timestamp: str, message: str) -> None: return if timestamp and self.TIMESTAMP not in parsed_metrics: parsed_metrics[self.TIMESTAMP] = timestamp - node_match = self.NODE_TAG.match(message) - if node_match: + if node_match := self.NODE_TAG.match(message): parsed_metrics[self.NODE_ID] = node_match.group(1) self.all_metrics.append(parsed_metrics) def get_columns_and_pivot_indices( self, pivot: str - ) -> Tuple[Dict[str, List[Union[str, float, int]]], Dict[Tuple[int, str], int]]: - """ - Parses the metrics to find all the metrics that have the pivot column. The values of the + ) -> tuple[dict[str, list[Union[str, float, int]]], dict[tuple[int, str], int]]: + """Parses the metrics to find all the metrics that have the pivot column. The values of the pivot column are paired with the node_id and assigned a row index, so that all metrics with the same pivot value and node_id are stored in the same row. + Args: pivot (str): The name of the pivot column. Must be TIMESTAMP or ITERATION_NUMBER. Returns: - Tuple[Dict[str, List[Union[str, float, int]]], Dict[Tuple[int, str], int]]: Contains: - The Dict[str, List[Any]] is the result table with all the metrics values initialized + tuple[dict[str, list[Union[str, float, int]]], dict[tuple[int, str], int]]: Contains: + The dict[str, list[Any]] is the result table with all the metrics values initialized to None. - The Dict[Tuple[int, str], int] is the list of pivot indices, where the value of a + The dict[tuple[int, str], int] is the list of pivot indices, where the value of a pivot column and node_id is mapped to a row index. """ row_count = 0 @@ -144,9 +140,8 @@ def get_columns_and_pivot_indices( def get_metric_data_with_pivot( self, pivot: str, statistic: MetricStatistic - ) -> Dict[str, List[Union[str, float, int]]]: - """ - Gets the metric data for a given pivot column name. Metrics without the pivot column + ) -> dict[str, list[Union[str, float, int]]]: + """Gets the metric data for a given pivot column name. Metrics without the pivot column are not included in the results. Metrics that have the same value in the pivot column from the same node are returned in the same row. Metrics from different nodes are stored in different rows. If the metric has multiple values for the row, the statistic is used @@ -169,7 +164,7 @@ def get_metric_data_with_pivot( statistic (MetricStatistic): The statistic to determine which value to use. Returns: - Dict[str, List[Union[str, float, int]]] : The metrics data. + dict[str, list[Union[str, float, int]]]: The metrics data. """ table, pivot_indices = self.get_columns_and_pivot_indices(pivot) for metric in self.all_metrics: @@ -184,9 +179,8 @@ def get_metric_data_with_pivot( def get_parsed_metrics( self, metric_type: MetricType, statistic: MetricStatistic - ) -> Dict[str, List[Union[str, float, int]]]: - """ - Gets all the metrics data, where the keys are the column names and the values are a list + ) -> dict[str, list[Union[str, float, int]]]: + """Gets all the metrics data, where the keys are the column names and the values are a list containing the values in each row. Args: @@ -196,7 +190,7 @@ def get_parsed_metrics( when there is a conflict. Returns: - Dict[str, List[Union[str, float, int]]] : The metrics data. + dict[str, list[Union[str, float, int]]]: The metrics data. Example: timestamp energy diff --git a/src/braket/jobs/quantum_job.py b/src/braket/jobs/quantum_job.py index a84118991..32c660bcf 100644 --- a/src/braket/jobs/quantum_job.py +++ b/src/braket/jobs/quantum_job.py @@ -13,7 +13,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType @@ -26,6 +26,7 @@ class QuantumJob(ABC): @abstractmethod def arn(self) -> str: """The ARN (Amazon Resource Name) of the hybrid job. + Returns: str: The ARN (Amazon Resource Name) of the hybrid job. """ @@ -34,6 +35,7 @@ def arn(self) -> str: @abstractmethod def name(self) -> str: """The name of the hybrid job. + Returns: str: The name of the hybrid job. """ @@ -47,6 +49,7 @@ def state(self, use_cached_value: bool = False) -> str: value from the Amazon Braket `GetJob` operation. If `False`, calls the `GetJob` operation to retrieve metadata, which also updates the cached value. Default = `False`. + Returns: str: The value of `status` in `metadata()`. This is the value of the `status` key in the Amazon Braket `GetJob` operation. @@ -95,7 +98,7 @@ def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None: # Cloudwatch after the job was marked complete. @abstractmethod - def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]: + def metadata(self, use_cached_value: bool = False) -> dict[str, Any]: """Gets the job metadata defined in Amazon Braket. Args: @@ -103,8 +106,9 @@ def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]: from the Amazon Braket `GetJob` operation, if it exists; if does not exist, `GetJob` is called to retrieve the metadata. If `False`, always calls `GetJob`, which also updates the cached value. Default: `False`. + Returns: - Dict[str, Any]: Dict that specifies the hybrid job metadata defined in Amazon Braket. + dict[str, Any]: Dict that specifies the hybrid job metadata defined in Amazon Braket. """ @abstractmethod @@ -112,7 +116,7 @@ def metrics( self, metric_type: MetricType = MetricType.TIMESTAMP, statistic: MetricStatistic = MetricStatistic.MAX, - ) -> Dict[str, List[Any]]: + ) -> dict[str, list[Any]]: """Gets all the metrics data, where the keys are the column names, and the values are a list containing the values in each row. @@ -123,7 +127,7 @@ def metrics( when there is a conflict. Default: MetricStatistic.MAX. Returns: - Dict[str, List[Any]] : The metrics data. + dict[str, list[Any]]: The metrics data. Example: timestamp energy @@ -150,7 +154,7 @@ def result( self, poll_timeout_seconds: float = DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Retrieves the hybrid job result persisted using save_job_result() function. Args: @@ -162,7 +166,7 @@ def result( Returns: - Dict[str, Any]: Dict specifying the hybrid job results. + dict[str, Any]: Dict specifying the hybrid job results. Raises: RuntimeError: if hybrid job is in a FAILED or CANCELLED state. diff --git a/src/braket/jobs/quantum_job_creation.py b/src/braket/jobs/quantum_job_creation.py index 9e18faeab..3c4a01b5c 100644 --- a/src/braket/jobs/quantum_job_creation.py +++ b/src/braket/jobs/quantum_job_creation.py @@ -224,7 +224,7 @@ def prepare_quantum_job( "sagemaker_distributed_dataparallel_enabled": "true", "sagemaker_instance_type": instance_config.instanceType, } - hyperparameters.update(distributed_hyperparams) + hyperparameters |= distributed_hyperparams create_job_kwargs = { "jobName": job_name, @@ -241,16 +241,12 @@ def prepare_quantum_job( } if reservation_arn: - create_job_kwargs.update( + create_job_kwargs["associations"] = [ { - "associations": [ - { - "arn": reservation_arn, - "type": "RESERVATION_TIME_WINDOW_ARN", - } - ] + "arn": reservation_arn, + "type": "RESERVATION_TIME_WINDOW_ARN", } - ) + ] return create_job_kwargs @@ -258,8 +254,7 @@ def prepare_quantum_job( def _generate_default_job_name( image_uri: str | None = None, func: Callable | None = None, timestamp: int | str | None = None ) -> str: - """ - Generate default job name using the image uri and entrypoint function. + """Generate default job name using the image uri and entrypoint function. Args: image_uri (str | None): URI for the image container. @@ -269,33 +264,33 @@ def _generate_default_job_name( Returns: str: Hybrid job name. """ - max_length = 50 timestamp = timestamp if timestamp is not None else str(int(time.time() * 1000)) if func: name = func.__name__.replace("_", "-") + max_length = 50 if len(name) + len(timestamp) > max_length: name = name[: max_length - len(timestamp) - 1] warnings.warn( - f"Job name exceeded {max_length} characters. Truncating name to {name}-{timestamp}." + f"Job name exceeded {max_length} characters. " + f"Truncating name to {name}-{timestamp}.", + stacklevel=1, ) + elif not image_uri: + name = "braket-job-default" else: - if not image_uri: - name = "braket-job-default" - else: - job_type_match = re.search("/amazon-braket-(.*)-jobs:", image_uri) or re.search( - "/amazon-braket-([^:/]*)", image_uri - ) - container = f"-{job_type_match.groups()[0]}" if job_type_match else "" - name = f"braket-job{container}" + job_type_match = re.search("/amazon-braket-(.*)-jobs:", image_uri) or re.search( + "/amazon-braket-([^:/]*)", image_uri + ) + container = f"-{job_type_match.groups()[0]}" if job_type_match else "" + name = f"braket-job{container}" return f"{name}-{timestamp}" def _process_s3_source_module( source_module: str, entry_point: str, aws_session: AwsSession, code_location: str ) -> None: - """ - Check that the source module is an S3 URI of the correct type and that entry point is + """Check that the source module is an S3 URI of the correct type and that entry point is provided. Args: @@ -304,6 +299,9 @@ def _process_s3_source_module( aws_session (AwsSession): AwsSession to copy source module to code location. code_location (str): S3 URI pointing to the location where the code will be copied to. + + Raises: + ValueError: The entry point is None or does not end with .tar.gz. """ if entry_point is None: raise ValueError("If source_module is an S3 URI, entry_point must be provided.") @@ -318,9 +316,9 @@ def _process_s3_source_module( def _process_local_source_module( source_module: str, entry_point: str, aws_session: AwsSession, code_location: str ) -> str: - """ - Check that entry point is valid with respect to source module, or provide a default + """Check that entry point is valid with respect to source module, or provide a default value if entry point is not given. Tar and upload source module to code location in S3. + Args: source_module (str): Local path pointing to the source module. entry_point (str): Entry point relative to the source module. @@ -328,14 +326,17 @@ def _process_local_source_module( code_location (str): S3 URI pointing to the location where the code will be uploaded to. + Raises: + ValueError: Raised if the source module file is not found. + Returns: str: Entry point. """ try: # raises FileNotFoundError if not found abs_path_source_module = Path(source_module).resolve(strict=True) - except FileNotFoundError: - raise ValueError(f"Source module not found: {source_module}") + except FileNotFoundError as e: + raise ValueError(f"Source module not found: {source_module}") from e entry_point = entry_point or abs_path_source_module.stem _validate_entry_point(abs_path_source_module, entry_point) @@ -344,12 +345,14 @@ def _process_local_source_module( def _validate_entry_point(source_module_path: Path, entry_point: str) -> None: - """ - Confirm that a valid entry point relative to source module is given. + """Confirm that a valid entry point relative to source module is given. Args: source_module_path (Path): Path to source module. entry_point (str): Entry point relative to source module. + + Raises: + ValueError: Raised if the module was not found. """ importable, _, _method = entry_point.partition(":") sys.path.append(str(source_module_path.parent)) @@ -357,10 +360,10 @@ def _validate_entry_point(source_module_path: Path, entry_point: str) -> None: # second argument allows relative imports importlib.invalidate_caches() module = importlib.util.find_spec(importable, source_module_path.stem) - assert module is not None - # if entry point is nested (ie contains '.'), parent modules are imported - except (ModuleNotFoundError, AssertionError): - raise ValueError(f"Entry point module was not found: {importable}") + if module is None: + raise AssertionError + except (ModuleNotFoundError, AssertionError) as e: + raise ValueError(f"Entry point module was not found: {importable}") from e finally: sys.path.pop() @@ -368,8 +371,7 @@ def _validate_entry_point(source_module_path: Path, entry_point: str) -> None: def _tar_and_upload_to_code_location( source_module_path: Path, aws_session: AwsSession, code_location: str ) -> None: - """ - Tar and upload source module to code location. + """Tar and upload source module to code location. Args: source_module_path (Path): Path to source module. @@ -384,12 +386,14 @@ def _tar_and_upload_to_code_location( def _validate_params(dict_arr: dict[str, tuple[any, any]]) -> None: - """ - Validate that config parameters are of the right type. + """Validate that config parameters are of the right type. Args: dict_arr (dict[str, tuple[any, any]]): dict mapping parameter names to a tuple containing the provided value and expected type. + + Raises: + ValueError: If the user_input is not the same as the expected data type. """ for parameter_name, value_tuple in dict_arr.items(): user_input, expected_datatype = value_tuple @@ -407,8 +411,8 @@ def _process_input_data( aws_session: AwsSession, subdirectory: str, ) -> list[dict[str, Any]]: - """ - Convert input data into a list of dicts compatible with the Braket API. + """Convert input data into a list of dicts compatible with the Braket API. + Args: input_data (str | dict | S3DataSourceConfig): Either a channel definition or a dictionary mapping channel names to channel definitions, where a channel definition @@ -437,8 +441,8 @@ def _process_channel( channel_name: str, subdirectory: str, ) -> S3DataSourceConfig: - """ - Convert a location to an S3DataSourceConfig, uploading local data to S3, if necessary. + """Convert a location to an S3DataSourceConfig, uploading local data to S3, if necessary. + Args: location (str): Local prefix or S3 prefix. job_name (str): Hybrid job name. @@ -451,26 +455,24 @@ def _process_channel( """ if AwsSession.is_s3_uri(location): return S3DataSourceConfig(location) - else: - # local prefix "path/to/prefix" will be mapped to - # s3://bucket/jobs/job-name/subdirectory/data/input/prefix - location_name = Path(location).name - s3_prefix = AwsSession.construct_s3_uri( - aws_session.default_bucket(), - "jobs", - job_name, - subdirectory, - "data", - channel_name, - location_name, - ) - aws_session.upload_local_data(location, s3_prefix) - return S3DataSourceConfig(s3_prefix) + # local prefix "path/to/prefix" will be mapped to + # s3://bucket/jobs/job-name/subdirectory/data/input/prefix + location_name = Path(location).name + s3_prefix = AwsSession.construct_s3_uri( + aws_session.default_bucket(), + "jobs", + job_name, + subdirectory, + "data", + channel_name, + location_name, + ) + aws_session.upload_local_data(location, s3_prefix) + return S3DataSourceConfig(s3_prefix) def _convert_input_to_config(input_data: dict[str, S3DataSourceConfig]) -> list[dict[str, Any]]: - """ - Convert a dictionary mapping channel names to S3DataSourceConfigs into a list of channel + """Convert a dictionary mapping channel names to S3DataSourceConfigs into a list of channel configs compatible with the Braket API. Args: diff --git a/src/braket/jobs/serialization.py b/src/braket/jobs/serialization.py index f8c854d03..179a44970 100644 --- a/src/braket/jobs/serialization.py +++ b/src/braket/jobs/serialization.py @@ -13,26 +13,25 @@ import codecs import pickle -from typing import Any, Dict +from typing import Any from braket.jobs_data import PersistedJobDataFormat def serialize_values( - data_dictionary: Dict[str, Any], data_format: PersistedJobDataFormat -) -> Dict[str, Any]: - """ - Serializes the `data_dictionary` values to the format specified by `data_format`. + data_dictionary: dict[str, Any], data_format: PersistedJobDataFormat +) -> dict[str, Any]: + """Serializes the `data_dictionary` values to the format specified by `data_format`. Args: - data_dictionary (Dict[str, Any]): Dict whose values are to be serialized. + data_dictionary (dict[str, Any]): Dict whose values are to be serialized. data_format (PersistedJobDataFormat): The data format used to serialize the values. Note that for `PICKLED` data formats, the values are base64 encoded after serialization, so that they represent valid UTF-8 text and are compatible with `PersistedJobData.json()`. Returns: - Dict[str, Any]: Dict with same keys as `data_dictionary` and values serialized to + dict[str, Any]: Dict with same keys as `data_dictionary` and values serialized to the specified `data_format`. """ return ( @@ -46,18 +45,17 @@ def serialize_values( def deserialize_values( - data_dictionary: Dict[str, Any], data_format: PersistedJobDataFormat -) -> Dict[str, Any]: - """ - Deserializes the `data_dictionary` values from the format specified by `data_format`. + data_dictionary: dict[str, Any], data_format: PersistedJobDataFormat +) -> dict[str, Any]: + """Deserializes the `data_dictionary` values from the format specified by `data_format`. Args: - data_dictionary (Dict[str, Any]): Dict whose values are to be deserialized. + data_dictionary (dict[str, Any]): Dict whose values are to be deserialized. data_format (PersistedJobDataFormat): The data format that the `data_dictionary` values are currently serialized with. Returns: - Dict[str, Any]: Dict with same keys as `data_dictionary` and values deserialized from + dict[str, Any]: Dict with same keys as `data_dictionary` and values deserialized from the specified `data_format` to plaintext. """ return ( diff --git a/src/braket/parametric/free_parameter.py b/src/braket/parametric/free_parameter.py index db22f5f60..1f3a69e72 100644 --- a/src/braket/parametric/free_parameter.py +++ b/src/braket/parametric/free_parameter.py @@ -22,8 +22,7 @@ class FreeParameter(FreeParameterExpression): - """ - Class 'FreeParameter' + """Class 'FreeParameter' Free parameters can be used in parameterized circuits. Objects that can take a parameter all inherit from :class:'Parameterizable'. The FreeParameter can be swapped in to a circuit @@ -39,8 +38,7 @@ class FreeParameter(FreeParameterExpression): """ def __init__(self, name: str): - """ - Initializes a new :class:'FreeParameter' object. + """Initializes a new :class:'FreeParameter' object. Args: name (str): Name of the :class:'FreeParameter'. Can be a unicode value. @@ -49,19 +47,16 @@ def __init__(self, name: str): >>> param1 = FreeParameter("theta") >>> param1 = FreeParameter("\u03B8") """ - self._name = Symbol(name) + self._set_name(name) super().__init__(expression=self._name) @property def name(self) -> str: - """ - str: Name of this parameter. - """ + """str: Name of this parameter.""" return self._name.name def subs(self, parameter_values: dict[str, Number]) -> Union[FreeParameter, Number]: - """ - Substitutes a value in if the parameter exists within the mapping. + """Substitutes a value in if the parameter exists within the mapping. Args: parameter_values (dict[str, Number]): A mapping of parameter to its @@ -71,7 +66,7 @@ def subs(self, parameter_values: dict[str, Number]) -> Union[FreeParameter, Numb Union[FreeParameter, Number]: The substituted value if this parameter is in parameter_values, otherwise returns self """ - return parameter_values[self.name] if self.name in parameter_values else self + return parameter_values.get(self.name, self) def __str__(self): return str(self.name) @@ -79,20 +74,28 @@ def __str__(self): def __hash__(self) -> int: return hash(tuple(self.name)) - def __eq__(self, other): + def __eq__(self, other: FreeParameter): if isinstance(other, FreeParameter): return self._name == other._name return super().__eq__(other) def __repr__(self) -> str: - """ - The representation of the :class:'FreeParameter'. + """The representation of the :class:'FreeParameter'. Returns: str: The name of the class:'FreeParameter' to represent the class. """ return self.name + def _set_name(self, name: str) -> None: + if not name: + raise ValueError("FreeParameter names must be non empty") + if not isinstance(name, str): + raise TypeError("FreeParameter names must be strings") + if not name[0].isalpha() and name[0] != "_": + raise ValueError("FreeParameter names must start with a letter or an underscore") + self._name = Symbol(name) + def to_dict(self) -> dict: return { "__class__": self.__class__.__name__, diff --git a/src/braket/parametric/free_parameter_expression.py b/src/braket/parametric/free_parameter_expression.py index cd5fd7f89..fdd2f5474 100644 --- a/src/braket/parametric/free_parameter_expression.py +++ b/src/braket/parametric/free_parameter_expression.py @@ -14,15 +14,18 @@ from __future__ import annotations import ast +import operator +from functools import reduce from numbers import Number from typing import Any, Union -from sympy import Expr, Float, Symbol, sympify +import sympy +from oqpy.base import OQPyExpression +from oqpy.classical_types import FloatVar class FreeParameterExpression: - """ - Class 'FreeParameterExpression' + """Class 'FreeParameterExpression' Objects that can take a parameter all inherit from :class:'Parameterizable'. FreeParametersExpressions can hold FreeParameters that can later be @@ -30,9 +33,8 @@ class FreeParameterExpression: present will NOT run. Values must be substituted prior to execution. """ - def __init__(self, expression: Union[FreeParameterExpression, Number, Expr, str]): - """ - Initializes a FreeParameterExpression. Best practice is to initialize using + def __init__(self, expression: Union[FreeParameterExpression, Number, sympy.Expr, str]): + """Initializes a FreeParameterExpression. Best practice is to initialize using FreeParameters and Numbers. Not meant to be initialized directly. Below are examples of how FreeParameterExpressions should be made. @@ -40,6 +42,10 @@ def __init__(self, expression: Union[FreeParameterExpression, Number, Expr, str] Args: expression (Union[FreeParameterExpression, Number, Expr, str]): The expression to use. + Raises: + NotImplementedError: Raised if the expression is not of type + [FreeParameterExpression, Number, Expr, str] + Examples: >>> expression_1 = FreeParameter("theta") * FreeParameter("alpha") >>> expression_2 = 1 + FreeParameter("beta") + 2 * FreeParameter("alpha") @@ -53,7 +59,7 @@ def __init__(self, expression: Union[FreeParameterExpression, Number, Expr, str] } if isinstance(expression, FreeParameterExpression): self._expression = expression.expression - elif isinstance(expression, (Number, Expr)): + elif isinstance(expression, (Number, sympy.Expr)): self._expression = expression elif isinstance(expression, str): self._expression = self._parse_string_expression(expression).expression @@ -61,8 +67,9 @@ def __init__(self, expression: Union[FreeParameterExpression, Number, Expr, str] raise NotImplementedError @property - def expression(self) -> Union[Number, Expr]: + def expression(self) -> Union[Number, sympy.Expr]: """Gets the expression. + Returns: Union[Number, Expr]: The expression for the FreeParameterExpression. """ @@ -70,7 +77,7 @@ def expression(self) -> Union[Number, Expr]: def subs( self, parameter_values: dict[str, Number] - ) -> Union[FreeParameterExpression, Number, Expr]: + ) -> Union[FreeParameterExpression, Number, sympy.Expr]: """ Similar to a substitution in Sympy. Parameters are swapped for corresponding values or expressions from the dictionary. @@ -83,7 +90,7 @@ def subs( Union[FreeParameterExpression, Number, Expr]: A numerical value if there are no symbols left in the expression otherwise returns a new FreeParameterExpression. """ - new_parameter_values = dict() + new_parameter_values = {} for key, val in parameter_values.items(): if issubclass(type(key), FreeParameterExpression): new_parameter_values[key.expression] = val @@ -100,10 +107,10 @@ def _parse_string_expression(self, expression: str) -> FreeParameterExpression: return self._eval_operation(ast.parse(expression, mode="eval").body) def _eval_operation(self, node: Any) -> FreeParameterExpression: - if isinstance(node, ast.Num): + if isinstance(node, ast.Constant): return FreeParameterExpression(node.n) elif isinstance(node, ast.Name): - return FreeParameterExpression(Symbol(node.id)) + return FreeParameterExpression(sympy.Symbol(node.id)) elif isinstance(node, ast.BinOp): if type(node.op) not in self._operations.keys(): raise ValueError(f"Unsupported binary operation: {type(node.op)}") @@ -117,80 +124,112 @@ def _eval_operation(self, node: Any) -> FreeParameterExpression: else: raise ValueError(f"Unsupported string detected: {node}") - def __add__(self, other): + def __add__(self, other: FreeParameterExpression): if issubclass(type(other), FreeParameterExpression): return FreeParameterExpression(self.expression + other.expression) else: return FreeParameterExpression(self.expression + other) - def __radd__(self, other): + def __radd__(self, other: FreeParameterExpression): return FreeParameterExpression(other + self.expression) - def __sub__(self, other): + def __sub__(self, other: FreeParameterExpression): if issubclass(type(other), FreeParameterExpression): return FreeParameterExpression(self.expression - other.expression) else: return FreeParameterExpression(self.expression - other) - def __rsub__(self, other): + def __rsub__(self, other: FreeParameterExpression): return FreeParameterExpression(other - self.expression) - def __mul__(self, other): + def __mul__(self, other: FreeParameterExpression): if issubclass(type(other), FreeParameterExpression): return FreeParameterExpression(self.expression * other.expression) else: return FreeParameterExpression(self.expression * other) - def __rmul__(self, other): + def __rmul__(self, other: FreeParameterExpression): return FreeParameterExpression(other * self.expression) - def __pow__(self, other, modulo=None): + def __truediv__(self, other): + if issubclass(type(other), FreeParameterExpression): + return FreeParameterExpression(self.expression / other.expression) + else: + return FreeParameterExpression(self.expression / other) + + def __rtruediv__(self, other: FreeParameterExpression): + return FreeParameterExpression(other / self.expression) + + def __pow__(self, other: FreeParameterExpression, modulo: float = None): if issubclass(type(other), FreeParameterExpression): return FreeParameterExpression(self.expression**other.expression) else: return FreeParameterExpression(self.expression**other) - def __rpow__(self, other): + def __rpow__(self, other: FreeParameterExpression): return FreeParameterExpression(other**self.expression) def __neg__(self): return FreeParameterExpression(-1 * self.expression) - def __eq__(self, other): + def __eq__(self, other: FreeParameterExpression): if isinstance(other, FreeParameterExpression): - return sympify(self.expression).equals(sympify(other.expression)) + return sympy.sympify(self.expression).equals(sympy.sympify(other.expression)) return False def __repr__(self) -> str: - """ - The representation of the :class:'FreeParameterExpression'. + """The representation of the :class:'FreeParameterExpression'. Returns: str: The expression of the class:'FreeParameterExpression' to represent the class. """ return repr(self.expression) + def _to_oqpy_expression(self) -> OQPyExpression: + """Transforms into an OQPyExpression. -def subs_if_free_parameter(parameter: Any, **kwargs) -> Any: + Returns: + OQPyExpression: The AST node. + """ + ops = {sympy.Add: operator.add, sympy.Mul: operator.mul, sympy.Pow: operator.pow} + if isinstance(self.expression, tuple(ops)): + return reduce( + ops[type(self.expression)], + map( + lambda x: FreeParameterExpression(x)._to_oqpy_expression(), self.expression.args + ), + ) + elif isinstance(self.expression, sympy.Number): + return float(self.expression) + else: + fvar = FloatVar( + name=self.expression.name, init_expression="input", needs_declaration=False + ) + fvar.size = None + fvar.type.size = None + return fvar + + +def subs_if_free_parameter(parameter: Any, **kwargs: Union[FreeParameterExpression, str]) -> Any: """Substitute a free parameter with the given kwargs, if any. + Args: parameter (Any): The parameter. - ``**kwargs``: The kwargs to use to substitute. + **kwargs (Union[FreeParameterExpression, str]): The kwargs to use to substitute. Returns: Any: The substituted parameters. """ if isinstance(parameter, FreeParameterExpression): substituted = parameter.subs(kwargs) - if isinstance(substituted, Float): + if isinstance(substituted, sympy.Number): substituted = float(substituted) return substituted return parameter def _is_float(argument: str) -> bool: - """ - Checks if a string can be cast into a float. + """Checks if a string can be cast into a float. Args: argument (str): String to check. diff --git a/src/braket/parametric/parameterizable.py b/src/braket/parametric/parameterizable.py index bd5dbc5a7..90c4dc589 100644 --- a/src/braket/parametric/parameterizable.py +++ b/src/braket/parametric/parameterizable.py @@ -21,8 +21,7 @@ class Parameterizable(ABC): - """ - A parameterized object is the abstract definition of an object + """A parameterized object is the abstract definition of an object that can take in FreeParameterExpressions. """ @@ -38,11 +37,13 @@ def parameters(self) -> list[Union[FreeParameterExpression, FreeParameter, float """ @abstractmethod - def bind_values(self, **kwargs) -> Any: - """ - Takes in parameters and returns an object with specified parameters + def bind_values(self, **kwargs: Union[FreeParameter, str]) -> Any: + """Takes in parameters and returns an object with specified parameters replaced with their values. + Args: + **kwargs (Union[FreeParameter, str]): Arbitrary keyword arguments. + Returns: Any: The result object will depend on the implementation of the object being bound. """ diff --git a/src/braket/pulse/ast/approximation_parser.py b/src/braket/pulse/ast/approximation_parser.py index 4398d2816..d2dcf65e8 100644 --- a/src/braket/pulse/ast/approximation_parser.py +++ b/src/braket/pulse/ast/approximation_parser.py @@ -15,7 +15,7 @@ from collections import defaultdict from collections.abc import KeysView from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, ClassVar, Optional, Union import numpy as np from openpulse import ast @@ -50,15 +50,16 @@ class _ParseState: class _ApproximationParser(QASMVisitor[_ParseState]): """Walk the AST and build the output signal amplitude, frequency and phases - for each channel.""" + for each channel. + """ - TIME_UNIT_TO_EXP = {"dt": 4, "ns": 3, "us": 2, "ms": 1, "s": 0} + TIME_UNIT_TO_EXP: ClassVar = {"dt": 4, "ns": 3, "us": 2, "ms": 1, "s": 0} def __init__(self, program: Program, frames: dict[str, Frame]): self.amplitudes = defaultdict(TimeSeries) self.frequencies = defaultdict(TimeSeries) self.phases = defaultdict(TimeSeries) - context = _ParseState(variables=dict(), frame_data=_init_frame_data(frames)) + context = _ParseState(variables={}, frame_data=_init_frame_data(frames)) self._qubit_frames_mapping: dict[str, list[str]] = _init_qubit_frame_mapping(frames) self.visit(program.to_ast(include_externs=False), context) @@ -66,11 +67,13 @@ def visit( self, node: Union[ast.QASMNode, ast.Expression], context: Optional[_ParseState] = None ) -> Any: """Visit a node. + Args: node (Union[ast.QASMNode, ast.Expression]): The node to visit. context (Optional[_ParseState]): The parse state context. + Returns: - Any: The parse return value. + Any: The parsed return value. """ return super().visit(node, context) @@ -103,6 +106,7 @@ def _delay_frame(self, frame_id: str, to_delay_time: float, context: _ParseState def visit_Program(self, node: ast.Program, context: _ParseState = None) -> None: """Visit a Program. + Args: node (ast.Program): The program. context (_ParseState): The parse state context. @@ -112,38 +116,49 @@ def visit_Program(self, node: ast.Program, context: _ParseState = None) -> None: def visit_ExpressionStatement(self, node: ast.ExpressionStatement, context: _ParseState) -> Any: """Visit an Expression. + Args: node (ast.ExpressionStatement): The expression. context (_ParseState): The parse state context. + + Returns: + Any: The parsed return value. """ return self.visit(node.expression, context) # need to check def visit_ClassicalDeclaration( self, node: ast.ClassicalDeclaration, context: _ParseState - ) -> None: + ) -> Union[dict, None]: """Visit a Classical Declaration. node.type, node.identifier, node.init_expression angle[20] a = 1+2; waveform wf = []; port a; + Args: node (ast.ClassicalDeclaration): The classical declaration. context (_ParseState): The parse state context. + + Raises: + NotImplementedError: Raised if the node is not a PortType, FrameType, or + WaveformType. + + Returns: + Union[dict, None]: Returns a dict if WaveformType, None otherwise. """ identifier = self.visit(node.identifier, context) if type(node.type) == ast.WaveformType: context.variables[identifier] = self.visit(node.init_expression, context) elif type(node.type) == ast.FrameType: pass - elif type(node.type) == ast.PortType: - pass - else: + elif type(node.type) != ast.PortType: raise NotImplementedError def visit_DelayInstruction(self, node: ast.DelayInstruction, context: _ParseState) -> None: """Visit a Delay Instruction. node.duration, node.qubits delay[100ns] $0; + Args: node (ast.DelayInstruction): The classical declaration. context (_ParseState): The parse state context. @@ -154,7 +169,7 @@ def visit_DelayInstruction(self, node: ast.DelayInstruction, context: _ParseStat # barrier without arguments is applied to all the frames of the context frames = list(context.frame_data.keys()) dts = [context.frame_data[frame_id].dt for frame_id in frames] - max_time = max([context.frame_data[frame_id].current_time for frame_id in frames]) + max_time = max(context.frame_data[frame_id].current_time for frame_id in frames) # All frames are delayed till the first multiple of the LCM([port.dts]) # after the longest time of all considered frames lcm = _lcm_floats(*dts) @@ -168,16 +183,20 @@ def visit_QuantumBarrier(self, node: ast.QuantumBarrier, context: _ParseState) - barrier $0; barrier; barrier frame, frame1; + Args: node (ast.QuantumBarrier): The quantum barrier. context (_ParseState): The parse state context. + + Returns: + None: No return value. """ frames = self._get_frame_parameters(node.qubits, context) if len(frames) == 0: # barrier without arguments is applied to all the frames of the context frames = list(context.frame_data.keys()) dts = [context.frame_data[frame_id].dt for frame_id in frames] - max_time = max([context.frame_data[frame_id].current_time for frame_id in frames]) + max_time = max(context.frame_data[frame_id].current_time for frame_id in frames) # All frames are delayed till the first multiple of the LCM([port.dts]) # after the longest time of all considered frames lcm = _lcm_floats(*dts) @@ -190,9 +209,13 @@ def visit_FunctionCall(self, node: ast.FunctionCall, context: _ParseState) -> An """Visit a Quantum Barrier. node.name, node.arguments f(args,arg2) + Args: node (ast.FunctionCall): The function call. context (_ParseState): The parse state context. + + Returns: + Any: The parsed return value. """ func_name = node.name.name return getattr(self, func_name)(node, context) @@ -204,6 +227,9 @@ def visit_Identifier(self, node: ast.Identifier, context: _ParseState) -> Any: Args: node (ast.Identifier): The identifier. context (_ParseState): The parse state context. + + Returns: + Any: The parsed return value. """ if node.name in context.variables: return context.variables[node.name] @@ -214,9 +240,16 @@ def visit_UnaryExpression(self, node: ast.UnaryExpression, context: _ParseState) """Visit Unary Expression. node.op, node.expression ~ ! - + Args: node (ast.UnaryExpression): The unary expression. context (_ParseState): The parse state context. + + Returns: + bool: The parsed boolean operator. + + Raises: + NotImplementedError: Raised for unsupported boolean operators. """ if node.op == ast.UnaryOperator["-"]: return -1 * self.visit(node.expression, context) @@ -234,9 +267,17 @@ def visit_BinaryExpression(self, node: ast.BinaryExpression, context: _ParseStat 1+2 a.b > < >= <= == != && || | ^ & << >> + - * / % ** . + Args: node (ast.BinaryExpression): The binary expression. context (_ParseState): The parse state context. + + Raises: + NotImplementedError: Raised if the binary operator is not in + [> < >= <= == != && || | ^ & << >> + - * / % ** ] + + Returns: + Any: The parsed binary operator. """ lhs = self.visit(node.lhs, context) rhs = self.visit(node.rhs, context) @@ -284,63 +325,85 @@ def visit_BinaryExpression(self, node: ast.BinaryExpression, context: _ParseStat else: raise NotImplementedError - def visit_ArrayLiteral(self, node: ast.ArrayLiteral, context: _ParseState) -> Any: + def visit_ArrayLiteral(self, node: ast.ArrayLiteral, context: _ParseState) -> list[Any]: """Visit Array Literal. node.values {1,2,4} + Args: node (ast.ArrayLiteral): The array literal. context (_ParseState): The parse state context. + + Returns: + list[Any]: The parsed ArrayLiteral. """ return [self.visit(e, context) for e in node.values] - def visit_IntegerLiteral(self, node: ast.IntegerLiteral, context: _ParseState) -> Any: + def visit_IntegerLiteral(self, node: ast.IntegerLiteral, context: _ParseState) -> int: """Visit Integer Literal. node.value 1 Args: node (ast.IntegerLiteral): The integer literal. context (_ParseState): The parse state context. + + Returns: + int: The parsed int value. """ return int(node.value) - def visit_ImaginaryLiteral(self, node: ast.ImaginaryLiteral, context: _ParseState) -> Any: + def visit_ImaginaryLiteral(self, node: ast.ImaginaryLiteral, context: _ParseState) -> complex: """Visit Imaginary Number Literal. node.value 1.3im Args: - node (ast.visit_ImaginaryLiteral): The imaginary number literal. + node (ast.ImaginaryLiteral): The imaginary number literal. context (_ParseState): The parse state context. + + Returns: + complex: The parsed complex value. """ return complex(node.value * 1j) - def visit_FloatLiteral(self, node: ast.FloatLiteral, context: _ParseState) -> Any: + def visit_FloatLiteral(self, node: ast.FloatLiteral, context: _ParseState) -> float: """Visit Float Literal. node.value 1.1 Args: node (ast.FloatLiteral): The float literal. context (_ParseState): The parse state context. + + Returns: + float: The parsed float value. """ return float(node.value) - def visit_BooleanLiteral(self, node: ast.BooleanLiteral, context: _ParseState) -> Any: + def visit_BooleanLiteral(self, node: ast.BooleanLiteral, context: _ParseState) -> bool: """Visit Boolean Literal. node.value true Args: node (ast.BooleanLiteral): The boolean literal. context (_ParseState): The parse state context. + + Returns: + bool: The parsed boolean value. """ - return True if node.value else False + return bool(node.value) - def visit_DurationLiteral(self, node: ast.DurationLiteral, context: _ParseState) -> Any: + def visit_DurationLiteral(self, node: ast.DurationLiteral, context: _ParseState) -> float: """Visit Duration Literal. node.value, node.unit (node.unit.name, node.unit.value) 1 Args: node (ast.DurationLiteral): The duration literal. context (_ParseState): The parse state context. + + Raises: + ValueError: Raised based on time unit not being in `self.TIME_UNIT_TO_EXP`. + + Returns: + float: The duration represented as a float """ if node.unit.name not in self.TIME_UNIT_TO_EXP: raise ValueError(f"Unexpected duration specified: {node.unit.name}:{node.unit.value}") @@ -351,6 +414,7 @@ def visit_DurationLiteral(self, node: ast.DurationLiteral, context: _ParseState) def set_frequency(self, node: ast.FunctionCall, context: _ParseState) -> None: """A 'set_frequency' Function call. + Args: node (ast.FunctionCall): The function call node. context (_ParseState): The parse state. @@ -361,6 +425,7 @@ def set_frequency(self, node: ast.FunctionCall, context: _ParseState) -> None: def shift_frequency(self, node: ast.FunctionCall, context: _ParseState) -> None: """A 'shift_frequency' Function call. + Args: node (ast.FunctionCall): The function call node. context (_ParseState): The parse state. @@ -371,6 +436,7 @@ def shift_frequency(self, node: ast.FunctionCall, context: _ParseState) -> None: def set_phase(self, node: ast.FunctionCall, context: _ParseState) -> None: """A 'set_phase' Function call. + Args: node (ast.FunctionCall): The function call node. context (_ParseState): The parse state. @@ -381,6 +447,7 @@ def set_phase(self, node: ast.FunctionCall, context: _ParseState) -> None: def shift_phase(self, node: ast.FunctionCall, context: _ParseState) -> None: """A 'shift_phase' Function call. + Args: node (ast.FunctionCall): The function call node. context (_ParseState): The parse state. @@ -392,6 +459,7 @@ def shift_phase(self, node: ast.FunctionCall, context: _ParseState) -> None: def set_scale(self, node: ast.FunctionCall, context: _ParseState) -> None: """A 'set_scale' Function call. + Args: node (ast.FunctionCall): The function call node. context (_ParseState): The parse state. @@ -402,17 +470,25 @@ def set_scale(self, node: ast.FunctionCall, context: _ParseState) -> None: def capture_v0(self, node: ast.FunctionCall, context: _ParseState) -> None: """A 'capture_v0' Function call. + Args: node (ast.FunctionCall): The function call node. context (_ParseState): The parse state. """ - pass def play(self, node: ast.FunctionCall, context: _ParseState) -> None: """A 'play' Function call. + Args: node (ast.FunctionCall): The function call node. context (_ParseState): The parse state. + + Raises: + NotImplementedError: Raises if not of type + [ast.Identifier, ast.FunctionCall, ast.ArrayLiteral] + + Returns: + None: Returns None """ frame_id = self.visit(node.arguments[0], context) if isinstance(node.arguments[1], ast.ArrayLiteral): @@ -436,9 +512,11 @@ def play(self, node: ast.FunctionCall, context: _ParseState) -> None: def constant(self, node: ast.FunctionCall, context: _ParseState) -> Waveform: """A 'constant' Waveform Function call. + Args: node (ast.FunctionCall): The function call node. context (_ParseState): The parse state. + Returns: Waveform: The waveform object representing the function call. """ @@ -447,9 +525,11 @@ def constant(self, node: ast.FunctionCall, context: _ParseState) -> Waveform: def gaussian(self, node: ast.FunctionCall, context: _ParseState) -> Waveform: """A 'gaussian' Waveform Function call. + Args: node (ast.FunctionCall): The function call node. context (_ParseState): The parse state. + Returns: Waveform: The waveform object representing the function call. """ @@ -458,9 +538,11 @@ def gaussian(self, node: ast.FunctionCall, context: _ParseState) -> Waveform: def drag_gaussian(self, node: ast.FunctionCall, context: _ParseState) -> Waveform: """A 'drag_gaussian' Waveform Function call. + Args: node (ast.FunctionCall): The function call node. context (_ParseState): The parse state. + Returns: Waveform: The waveform object representing the function call. """ @@ -469,17 +551,16 @@ def drag_gaussian(self, node: ast.FunctionCall, context: _ParseState) -> Wavefor def _init_frame_data(frames: dict[str, Frame]) -> dict[str, _FrameState]: - frame_states = dict() - for frameId, frame in frames.items(): - frame_states[frameId] = _FrameState( - frame.port.dt, frame.frequency, frame.phase % (2 * np.pi) - ) + frame_states = { + frameId: _FrameState(frame.port.dt, frame.frequency, frame.phase % (2 * np.pi)) + for frameId, frame in frames.items() + } return frame_states def _init_qubit_frame_mapping(frames: dict[str, Frame]) -> dict[str, list[str]]: mapping = {} - for frameId in frames.keys(): + for frameId in frames: if m := ( re.search(r"q(\d+)_q(\d+)_[a-z_]+", frameId) or re.search(r"[rq](\d+)_[a-z_]+", frameId) ): @@ -500,8 +581,10 @@ def _lcm_floats(*dts: list[float]) -> float: Args: *dts (list[float]): list of time resolutions - """ + Returns: + float: The LCM of time increments for a list of frames. + """ sample_rates = [round(1 / dt) for dt in dts] res_gcd = sample_rates[0] for sr in sample_rates[1:]: diff --git a/src/braket/pulse/ast/free_parameters.py b/src/braket/pulse/ast/free_parameters.py index 1581ddd88..1750275ac 100644 --- a/src/braket/pulse/ast/free_parameters.py +++ b/src/braket/pulse/ast/free_parameters.py @@ -11,60 +11,85 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +import operator from typing import Union from openpulse import ast -from openqasm3.ast import DurationLiteral from openqasm3.visitor import QASMTransformer - -from braket.parametric.free_parameter_expression import FreeParameterExpression - - -class _FreeParameterExpressionIdentifier(ast.Identifier): - """Dummy AST node with FreeParameterExpression instance attached""" - - def __init__(self, expression: FreeParameterExpression): - super().__init__(name=f"FreeParameterExpression({expression})") - self._expression = expression - - @property - def expression(self) -> FreeParameterExpression: - return self._expression +from oqpy.program import Program +from oqpy.timing import OQDurationLiteral class _FreeParameterTransformer(QASMTransformer): """Walk the AST and evaluate FreeParameterExpressions.""" - def __init__(self, param_values: dict[str, float]): + def __init__(self, param_values: dict[str, float], program: Program): self.param_values = param_values + self.program = program super().__init__() - def visit__FreeParameterExpressionIdentifier( + def visit_Identifier( self, identifier: ast.Identifier - ) -> Union[_FreeParameterExpressionIdentifier, ast.FloatLiteral]: - """Visit a FreeParameterExpressionIdentifier. + ) -> Union[ast.Identifier, ast.FloatLiteral]: + """Visit an Identifier. + + If the Identifier is used to hold a `FreeParameterExpression`, it will be simplified + using the given parameter values. + Args: identifier (Identifier): The identifier. Returns: - Union[_FreeParameterExpressionIdentifier, FloatLiteral]: The transformed expression. + Union[Identifier, FloatLiteral]: The transformed identifier. + """ + if identifier.name in self.param_values: + return ast.FloatLiteral(float(self.param_values[identifier.name])) + return identifier + + def visit_BinaryExpression( + self, node: ast.BinaryExpression + ) -> Union[ast.BinaryExpression, ast.FloatLiteral]: + """Visit a BinaryExpression. + + Visit the operands and simplify if they are literals. + + Args: + node (BinaryExpression): The node. + + Returns: + Union[BinaryExpression, FloatLiteral]: The transformed identifier. """ - new_value = identifier.expression.subs(self.param_values) - if isinstance(new_value, FreeParameterExpression): - return _FreeParameterExpressionIdentifier(new_value) - else: - return ast.FloatLiteral(new_value) - - def visit_DurationLiteral(self, duration_literal: DurationLiteral) -> DurationLiteral: - """Visit Duration Literal. - node.value, node.unit (node.unit.name, node.unit.value) - 1 + lhs = self.visit(node.lhs) + rhs = self.visit(node.rhs) + if isinstance(lhs, ast.FloatLiteral): + ops = { + ast.BinaryOperator["+"]: operator.add, + ast.BinaryOperator["*"]: operator.mul, + ast.BinaryOperator["**"]: operator.pow, + } + if isinstance(rhs, ast.FloatLiteral): + return ast.FloatLiteral(ops[node.op](lhs.value, rhs.value)) + elif isinstance(rhs, ast.DurationLiteral) and node.op == ast.BinaryOperator["*"]: + return OQDurationLiteral(lhs.value * rhs.value).to_ast(self.program) + return ast.BinaryExpression(op=node.op, lhs=lhs, rhs=rhs) + + def visit_UnaryExpression( + self, node: ast.UnaryExpression + ) -> Union[ast.UnaryExpression, ast.FloatLiteral]: + """Visit an UnaryExpression. + + Visit the operand and simplify if it is a literal. + Args: - duration_literal (DurationLiteral): The duration literal. + node (UnaryExpression): The node. + Returns: - DurationLiteral: The transformed duration literal. + Union[UnaryExpression, FloatLiteral]: The transformed identifier. """ - duration = duration_literal.value - if not isinstance(duration, FreeParameterExpression): - return duration_literal - return DurationLiteral(duration.subs(self.param_values), duration_literal.unit) + expression = self.visit(node.expression) + if ( + isinstance(expression, (ast.FloatLiteral, ast.DurationLiteral)) + and node.op == ast.UnaryOperator["-"] + ): + return type(expression)(-expression.value) + return ast.UnaryExpression(op=node.op, expression=node.expression) # pragma: no cover diff --git a/src/braket/pulse/ast/qasm_parser.py b/src/braket/pulse/ast/qasm_parser.py index bd1b26e40..25832aa58 100644 --- a/src/braket/pulse/ast/qasm_parser.py +++ b/src/braket/pulse/ast/qasm_parser.py @@ -15,11 +15,8 @@ from openpulse import ast from openpulse.printer import Printer -from openqasm3.ast import DurationLiteral from openqasm3.printer import PrinterState -from braket.parametric.free_parameter_expression import FreeParameterExpression - class _PulsePrinter(Printer): """Walks the AST and prints it to an OpenQASM3 string.""" @@ -27,29 +24,13 @@ class _PulsePrinter(Printer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def visit__FreeParameterExpressionIdentifier( - self, node: ast.Identifier, context: PrinterState - ) -> None: - """Visit a FreeParameterExpressionIdentifier. + def visit_Identifier(self, node: ast.Identifier, context: PrinterState) -> None: + """Visit an Identifier. Args: node (ast.Identifier): The identifier. context (PrinterState): The printer state context. """ - self.stream.write(str(node.expression.expression)) - - def visit_DurationLiteral(self, node: DurationLiteral, context: PrinterState) -> None: - """Visit Duration Literal. - node.value, node.unit (node.unit.name, node.unit.value) - 1 - Args: - node (ast.DurationLiteral): The duration literal. - context (PrinterState): The printer state context. - """ - duration = node.value - if isinstance(duration, FreeParameterExpression): - self.stream.write(f"({duration.expression}){node.unit.name}") - else: - super().visit_DurationLiteral(node, context) + self.stream.write(str(node.name)) def visit_ClassicalDeclaration( self, node: ast.ClassicalDeclaration, context: PrinterState @@ -59,6 +40,7 @@ def visit_ClassicalDeclaration( angle[20] a = 1+2; waveform wf = []; port a; + Args: node (ast.ClassicalDeclaration): The classical declaration. context (PrinterState): The printer state context. @@ -72,7 +54,7 @@ def ast_to_qasm(ast: ast.Program) -> str: """Converts an AST program to OpenQASM Args: - ast (Program): The AST program. + ast (ast.Program): The AST program. Returns: str: a str representing the OpenPulse program encoding the program. diff --git a/src/braket/pulse/ast/qasm_transformer.py b/src/braket/pulse/ast/qasm_transformer.py index f5e350883..40a6d25d5 100644 --- a/src/braket/pulse/ast/qasm_transformer.py +++ b/src/braket/pulse/ast/qasm_transformer.py @@ -18,8 +18,7 @@ class _IRQASMTransformer(QASMTransformer): - """ - QASMTransformer which walks the AST and makes the necessary modifications needed + """QASMTransformer which walks the AST and makes the necessary modifications needed for IR generation. Currently, it performs the following operations: * Replaces capture_v0 function calls with assignment statements, assigning the readout value to a bit register element. @@ -32,28 +31,29 @@ def __init__(self, register_identifier: Optional[str] = None): def visit_ExpressionStatement(self, expression_statement: ast.ExpressionStatement) -> Any: """Visit an Expression. + Args: expression_statement (ast.ExpressionStatement): The expression statement. + Returns: Any: The expression statement. """ if ( - isinstance(expression_statement.expression, ast.FunctionCall) - and expression_statement.expression.name.name == "capture_v0" - and self._register_identifier + not isinstance(expression_statement.expression, ast.FunctionCall) + or expression_statement.expression.name.name != "capture_v0" + or not self._register_identifier ): - # For capture_v0 nodes, it replaces it with classical assignment statements - # of the form: - # b[0] = capture_v0(...) - # b[1] = capture_v0(...) - new_val = ast.ClassicalAssignment( - # Ideally should use IndexedIdentifier here, but this works since it is just - # for printing. - ast.Identifier(name=f"{self._register_identifier}[{self._capture_v0_count}]"), - ast.AssignmentOperator["="], - expression_statement.expression, - ) - self._capture_v0_count += 1 - return new_val - else: return expression_statement + # For capture_v0 nodes, it replaces it with classical assignment statements + # of the form: + # b[0] = capture_v0(...) + # b[1] = capture_v0(...) + new_val = ast.ClassicalAssignment( + # Ideally should use IndexedIdentifier here, but this works since it is just + # for printing. + ast.Identifier(name=f"{self._register_identifier}[{self._capture_v0_count}]"), + ast.AssignmentOperator["="], + expression_statement.expression, + ) + self._capture_v0_count += 1 + return new_val diff --git a/src/braket/pulse/frame.py b/src/braket/pulse/frame.py index b84e4a440..63700d22e 100644 --- a/src/braket/pulse/frame.py +++ b/src/braket/pulse/frame.py @@ -11,6 +11,8 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +from __future__ import annotations + import math from typing import Any, Optional @@ -21,9 +23,9 @@ class Frame: - """ - Frame tracks the frame of reference, when interacting with the qubits, throughout the execution - of a program. See https://openqasm.com/language/openpulse.html#frames for more details. + """Frame tracks the frame of reference, when interacting with the qubits, throughout the + execution of a program. See https://openqasm.com/language/openpulse.html#frames for more + details. """ def __init__( @@ -35,7 +37,8 @@ def __init__( is_predefined: bool = False, properties: Optional[dict[str, Any]] = None, ): - """ + """Initializes a Frame. + Args: frame_id (str): str identifying a unique frame. port (Port): port that this frame is attached to. @@ -58,7 +61,7 @@ def id(self) -> str: """Returns a str indicating the frame id.""" return self._frame_id - def __eq__(self, other) -> bool: + def __eq__(self, other: Frame) -> bool: return ( ( (self.id == other.id) diff --git a/src/braket/pulse/port.py b/src/braket/pulse/port.py index 2b1760415..99b1acca5 100644 --- a/src/braket/pulse/port.py +++ b/src/braket/pulse/port.py @@ -11,6 +11,8 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +from __future__ import annotations + from typing import Any, Optional from oqpy import PortVar @@ -18,13 +20,13 @@ class Port: - """ - Ports represent any input or output component meant to manipulate and observe qubits on + """Ports represent any input or output component meant to manipulate and observe qubits on a device. See https://openqasm.com/language/openpulse.html#ports for more details. """ def __init__(self, port_id: str, dt: float, properties: Optional[dict[str, Any]] = None): - """ + """Initializes a Port. + Args: port_id (str): str identifying a unique port on the device. dt (float): The smallest time step that may be used on the control hardware. @@ -45,7 +47,7 @@ def dt(self) -> float: """Returns the smallest time step that may be used on the control hardware.""" return self._dt - def __eq__(self, other) -> bool: + def __eq__(self, other: Port) -> bool: return self.id == other.id if isinstance(other, Port) else False def _to_oqpy_expression(self) -> OQPyExpression: diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index 98ea3c381..9d43127a0 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -20,16 +20,13 @@ from openpulse import ast from oqpy import BitVar, PhysicalQubits, Program -from oqpy.timing import OQDurationLiteral +from sympy import Expr from braket.parametric.free_parameter import FreeParameter from braket.parametric.free_parameter_expression import FreeParameterExpression from braket.parametric.parameterizable import Parameterizable from braket.pulse.ast.approximation_parser import _ApproximationParser -from braket.pulse.ast.free_parameters import ( - _FreeParameterExpressionIdentifier, - _FreeParameterTransformer, -) +from braket.pulse.ast.free_parameters import _FreeParameterTransformer from braket.pulse.ast.qasm_parser import ast_to_qasm from braket.pulse.ast.qasm_transformer import _IRQASMTransformer from braket.pulse.frame import Frame @@ -39,14 +36,13 @@ class PulseSequence: - """ - A representation of a collection of instructions to be performed on a quantum device + """A representation of a collection of instructions to be performed on a quantum device and the requested results. """ def __init__(self): self._capture_v0_count = 0 - self._program = Program() + self._program = Program(simplify_constants=False) self._frames = {} self._waveforms = {} self._free_parameters = set() @@ -74,8 +70,7 @@ def parameters(self) -> set[FreeParameter]: def set_frequency( self, frame: Frame, frequency: Union[float, FreeParameterExpression] ) -> PulseSequence: - """ - Adds an instruction to set the frequency of the frame to the specified `frequency` value. + """Adds an instruction to set the frequency of the frame to the specified `frequency` value. Args: frame (Frame): Frame for which the frequency needs to be set. @@ -85,17 +80,17 @@ def set_frequency( Returns: PulseSequence: self, with the instruction added. """ - _validate_uniqueness(self._frames, frame) - self._program.set_frequency(frame=frame, freq=self._format_parameter_ast(frequency)) + self._register_free_parameters(frequency) + self._program.set_frequency(frame=frame, freq=frequency) self._frames[frame.id] = frame return self def shift_frequency( self, frame: Frame, frequency: Union[float, FreeParameterExpression] ) -> PulseSequence: - """ - Adds an instruction to shift the frequency of the frame by the specified `frequency` value. + """Adds an instruction to shift the frequency of the frame by the specified `frequency` + value. Args: frame (Frame): Frame for which the frequency needs to be shifted. @@ -106,15 +101,15 @@ def shift_frequency( PulseSequence: self, with the instruction added. """ _validate_uniqueness(self._frames, frame) - self._program.shift_frequency(frame=frame, freq=self._format_parameter_ast(frequency)) + self._register_free_parameters(frequency) + self._program.shift_frequency(frame=frame, freq=frequency) self._frames[frame.id] = frame return self def set_phase( self, frame: Frame, phase: Union[float, FreeParameterExpression] ) -> PulseSequence: - """ - Adds an instruction to set the phase of the frame to the specified `phase` value. + """Adds an instruction to set the phase of the frame to the specified `phase` value. Args: frame (Frame): Frame for which the frequency needs to be set. @@ -125,15 +120,15 @@ def set_phase( PulseSequence: self, with the instruction added. """ _validate_uniqueness(self._frames, frame) - self._program.set_phase(frame=frame, phase=self._format_parameter_ast(phase)) + self._register_free_parameters(phase) + self._program.set_phase(frame=frame, phase=phase) self._frames[frame.id] = frame return self def shift_phase( self, frame: Frame, phase: Union[float, FreeParameterExpression] ) -> PulseSequence: - """ - Adds an instruction to shift the phase of the frame by the specified `phase` value. + """Adds an instruction to shift the phase of the frame by the specified `phase` value. Args: frame (Frame): Frame for which the phase needs to be shifted. @@ -144,15 +139,15 @@ def shift_phase( PulseSequence: self, with the instruction added. """ _validate_uniqueness(self._frames, frame) - self._program.shift_phase(frame=frame, phase=self._format_parameter_ast(phase)) + self._register_free_parameters(phase) + self._program.shift_phase(frame=frame, phase=phase) self._frames[frame.id] = frame return self def set_scale( self, frame: Frame, scale: Union[float, FreeParameterExpression] ) -> PulseSequence: - """ - Adds an instruction to set the scale on the frame to the specified `scale` value. + """Adds an instruction to set the scale on the frame to the specified `scale` value. Args: frame (Frame): Frame for which the scale needs to be set. @@ -163,7 +158,8 @@ def set_scale( PulseSequence: self, with the instruction added. """ _validate_uniqueness(self._frames, frame) - self._program.set_scale(frame=frame, scale=self._format_parameter_ast(scale)) + self._register_free_parameters(scale) + self._program.set_scale(frame=frame, scale=scale) self._frames[frame.id] = frame return self @@ -172,21 +168,18 @@ def delay( qubits_or_frames: Union[Frame, list[Frame], QubitSet], duration: Union[float, FreeParameterExpression], ) -> PulseSequence: - """ - Adds an instruction to advance the frame clock by the specified `duration` value. + """Adds an instruction to advance the frame clock by the specified `duration` value. Args: qubits_or_frames (Union[Frame, list[Frame], QubitSet]): Qubits or frame(s) on which the delay needs to be introduced. duration (Union[float, FreeParameterExpression]): value (in seconds) defining the duration of the delay. + Returns: PulseSequence: self, with the instruction added. """ - if isinstance(duration, FreeParameterExpression): - for p in duration.expression.free_symbols: - self._free_parameters.add(FreeParameter(p.name)) - duration = OQDurationLiteral(duration) + self._register_free_parameters(duration) if not isinstance(qubits_or_frames, QubitSet): if not isinstance(qubits_or_frames, list): qubits_or_frames = [qubits_or_frames] @@ -195,13 +188,12 @@ def delay( for frame in qubits_or_frames: self._frames[frame.id] = frame else: - physical_qubits = list(PhysicalQubits[int(x)] for x in qubits_or_frames) + physical_qubits = [PhysicalQubits[int(x)] for x in qubits_or_frames] self._program.delay(time=duration, qubits_or_frames=physical_qubits) return self def barrier(self, qubits_or_frames: Union[list[Frame], QubitSet]) -> PulseSequence: - """ - Adds an instruction to align the frame clocks to the latest time across all the specified + """Adds an instruction to align the frame clocks to the latest time across all the specified frames. Args: @@ -217,7 +209,7 @@ def barrier(self, qubits_or_frames: Union[list[Frame], QubitSet]) -> PulseSequen for frame in qubits_or_frames: self._frames[frame.id] = frame else: - physical_qubits = list(PhysicalQubits[int(x)] for x in qubits_or_frames) + physical_qubits = [PhysicalQubits[int(x)] for x in qubits_or_frames] self._program.barrier(qubits_or_frames=physical_qubits) return self @@ -234,19 +226,16 @@ def play(self, frame: Frame, waveform: Waveform) -> PulseSequence: """ _validate_uniqueness(self._frames, frame) _validate_uniqueness(self._waveforms, waveform) - self._program.play(frame=frame, waveform=waveform) if isinstance(waveform, Parameterizable): for param in waveform.parameters: - if isinstance(param, FreeParameterExpression): - for p in param.expression.free_symbols: - self._free_parameters.add(FreeParameter(p.name)) + self._register_free_parameters(param) + self._program.play(frame=frame, waveform=waveform) self._frames[frame.id] = frame self._waveforms[waveform.id] = waveform return self def capture_v0(self, frame: Frame) -> PulseSequence: - """ - Adds an instruction to capture the bit output from measuring the specified frame. + """Adds an instruction to capture the bit output from measuring the specified frame. Args: frame (Frame): Frame on which the capture operation needs @@ -262,8 +251,7 @@ def capture_v0(self, frame: Frame) -> PulseSequence: return self def make_bound_pulse_sequence(self, param_values: dict[str, float]) -> PulseSequence: - """ - Binds FreeParameters based upon their name and values passed in. If parameters + """Binds FreeParameters based upon their name and values passed in. If parameters share the same name, all the parameters of that name will be set to the mapped value. Args: @@ -276,11 +264,13 @@ def make_bound_pulse_sequence(self, param_values: dict[str, float]) -> PulseSequ """ program = deepcopy(self._program) tree: ast.Program = program.to_ast(include_externs=False, ignore_needs_declaration=True) - new_tree: ast.Program = _FreeParameterTransformer(param_values).visit(tree) + new_tree: ast.Program = _FreeParameterTransformer(param_values, program).visit(tree) - new_program = Program() + new_program = Program(simplify_constants=False) new_program.declared_vars = program.declared_vars new_program.undeclared_vars = program.undeclared_vars + for param_name in param_values: + new_program.undeclared_vars.pop(param_name, None) for x in new_tree.statements: new_program._add_statement(x) @@ -300,19 +290,32 @@ def make_bound_pulse_sequence(self, param_values: dict[str, float]) -> PulseSequ ]._to_oqpy_expression() new_pulse_sequence._capture_v0_count = self._capture_v0_count - new_pulse_sequence._free_parameters = set( - [p for p in self._free_parameters if p.name not in param_values] - ) + new_pulse_sequence._free_parameters = { + p for p in self._free_parameters if p.name not in param_values + } return new_pulse_sequence - def to_ir(self) -> str: + def to_ir(self, sort_input_parameters: bool = False) -> str: """Converts this OpenPulse problem into IR representation. + Args: + sort_input_parameters (bool): whether input parameters should be printed + in a sorted order. Defaults to False. + Returns: str: a str representing the OpenPulse program encoding the PulseSequence. """ program = deepcopy(self._program) + program.autodeclare(encal=False) + parameters = ( + sorted(self.parameters, key=lambda p: p.name, reverse=True) + if sort_input_parameters + else self.parameters + ) + for param in parameters: + program.declare(param._to_oqpy_expression(), to_beginning=True) + if self._capture_v0_count: register_identifier = "psb" program.declare( @@ -324,24 +327,25 @@ def to_ir(self) -> str: tree = program.to_ast(encal=True, include_externs=False) return ast_to_qasm(tree) - def _format_parameter_ast( - self, parameter: Union[float, FreeParameterExpression] - ) -> Union[float, _FreeParameterExpressionIdentifier]: - if isinstance(parameter, FreeParameterExpression): + def _register_free_parameters( + self, + parameter: Union[float, FreeParameterExpression], + ) -> None: + if isinstance(parameter, FreeParameterExpression) and isinstance( + parameter.expression, Expr + ): for p in parameter.expression.free_symbols: self._free_parameters.add(FreeParameter(p.name)) - return _FreeParameterExpressionIdentifier(parameter) - return parameter def _parse_arg_from_calibration_schema( self, argument: dict, waveforms: dict[Waveform], frames: dict[Frame] ) -> Any: nonprimitive_arg_type = { - "frame": getattr(frames, "get"), - "waveform": getattr(waveforms, "get"), + "frame": frames.get, + "waveform": waveforms.get, "expr": FreeParameterExpression, } - if argument["type"] in nonprimitive_arg_type.keys(): + if argument["type"] in nonprimitive_arg_type: return nonprimitive_arg_type[argument["type"]](argument["value"]) else: return getattr(builtins, argument["type"])(argument["value"]) @@ -350,67 +354,68 @@ def _parse_arg_from_calibration_schema( def _parse_from_calibration_schema( cls, calibration: dict, waveforms: dict[Waveform], frames: dict[Frame] ) -> PulseSequence: - """ - Parsing a JSON input based on https://github.com/aws/amazon-braket-schemas-python/blob/main/src/braket/device_schema/pulse/native_gate_calibrations_v1.py#L26. + """Parsing a JSON input based on https://github.com/aws/amazon-braket-schemas-python/blob/main/src/braket/device_schema/pulse/native_gate_calibrations_v1.py#L26. # noqa: E501 Args: calibration (dict): The pulse instruction to parse waveforms (dict[Waveform]): The waveforms supplied for the pulse sequences. frames (dict[Frame]): A dictionary of frame objects to use. + Raises: + ValueError: If the requested instruction has not been implemented for pulses. + Returns: PulseSequence: The parse sequence obtain from parsing a pulse instruction. - """ # noqa: E501 + """ calibration_sequence = cls() for instr in calibration: - if hasattr(PulseSequence, f"{instr['name']}"): - instr_function = getattr(calibration_sequence, instr["name"]) - instr_args_keys = signature(instr_function).parameters.keys() - instr_args = {} - if instr["arguments"] is not None: - for argument in instr["arguments"]: - if argument["name"] in {"qubit", "frame"} and instr["name"] in { - "barrier", - "delay", - }: - argument_value = ( - [frames[argument["value"]]] - if argument["name"] == "frame" - else instr_args.get("qubits_or_frames", QubitSet()) - ) - # QubitSet is an IndexedSet so the ordering matters - if argument["name"] == "frame": - argument_value = ( - instr_args.get("qubits_or_frames", []) + argument_value - ) - else: - argument_value.update(QubitSet(int(argument["value"]))) - instr_args["qubits_or_frames"] = argument_value - elif argument["name"] in instr_args_keys: - instr_args[argument["name"]] = ( - calibration_sequence._parse_arg_from_calibration_schema( - argument, waveforms, frames - ) + if not hasattr(PulseSequence, f"{instr['name']}"): + raise ValueError(f"The {instr['name']} instruction has not been implemented") + instr_function = getattr(calibration_sequence, instr["name"]) + instr_args_keys = signature(instr_function).parameters.keys() + instr_args = {} + if instr["arguments"] is not None: + for argument in instr["arguments"]: + if argument["name"] in {"qubit", "frame"} and instr["name"] in { + "barrier", + "delay", + }: + argument_value = ( + [frames[argument["value"]]] + if argument["name"] == "frame" + else instr_args.get("qubits_or_frames", QubitSet()) + ) + # QubitSet is an IndexedSet so the ordering matters + if argument["name"] == "frame": + argument_value = instr_args.get("qubits_or_frames", []) + argument_value + else: + argument_value.update(QubitSet(int(argument["value"]))) + instr_args["qubits_or_frames"] = argument_value + elif argument["name"] in instr_args_keys: + instr_args[argument["name"]] = ( + calibration_sequence._parse_arg_from_calibration_schema( + argument, waveforms, frames ) - else: - instr_args["qubits_or_frames"] = [] - instr_function(**instr_args) + ) else: - raise ValueError(f"The {instr['name']} instruction has not been implemented") + instr_args["qubits_or_frames"] = [] + instr_function(**instr_args) return calibration_sequence - def __call__(self, arg: Any | None = None, **kwargs) -> PulseSequence: - """ - Implements the call function to easily make a bound PulseSequence. + def __call__( + self, arg: Any | None = None, **kwargs: Union[FreeParameter, str] + ) -> PulseSequence: + """Implements the call function to easily make a bound PulseSequence. Args: arg (Any | None): A value to bind to all parameters. Defaults to None and can be overridden if the parameter is in kwargs. + **kwargs (Union[FreeParameter, str]): Arbitrary keyword arguments. Returns: PulseSequence: A pulse sequence with the specified parameters bound. """ - param_values = dict() + param_values = {} if arg is not None: for param in self.parameters: param_values[str(param)] = arg @@ -418,11 +423,12 @@ def __call__(self, arg: Any | None = None, **kwargs) -> PulseSequence: param_values[str(key)] = val return self.make_bound_pulse_sequence(param_values) - def __eq__(self, other): + def __eq__(self, other: PulseSequence): + sort_input_parameters = True return ( isinstance(other, PulseSequence) and self.parameters == other.parameters - and self.to_ir() == other.to_ir() + and self.to_ir(sort_input_parameters) == other.to_ir(sort_input_parameters) ) diff --git a/src/braket/pulse/waveforms.py b/src/braket/pulse/waveforms.py index dbf89e146..915d187a8 100644 --- a/src/braket/pulse/waveforms.py +++ b/src/braket/pulse/waveforms.py @@ -21,7 +21,6 @@ import numpy as np from oqpy import WaveformVar, bool_, complex128, declare_waveform_generator, duration, float64 from oqpy.base import OQPyExpression -from oqpy.timing import OQDurationLiteral from braket.parametric.free_parameter import FreeParameter from braket.parametric.free_parameter_expression import ( @@ -29,15 +28,13 @@ subs_if_free_parameter, ) from braket.parametric.parameterizable import Parameterizable -from braket.pulse.ast.free_parameters import _FreeParameterExpressionIdentifier class Waveform(ABC): - """ - A waveform is a time-dependent envelope that can be used to emit signals on an output port + """A waveform is a time-dependent envelope that can be used to emit signals on an output port or receive signals from an input port. As such, when transmitting signals to the qubit, a frame determines time at which the waveform envelope is emitted, its carrier frequency, and - it’s phase offset. When capturing signals from a qubit, at minimum a frame determines the + it's phase offset. When capturing signals from a qubit, at minimum a frame determines the time at which the signal is captured. See https://openqasm.com/language/openpulse.html#waveforms for more details. """ @@ -49,32 +46,35 @@ def _to_oqpy_expression(self) -> OQPyExpression: @abstractmethod def sample(self, dt: float) -> np.ndarray: """Generates a sample of amplitudes for this Waveform based on the given time resolution. + Args: dt (float): The time resolution. + Returns: - ndarray: The sample amplitudes for this waveform. + np.ndarray: The sample amplitudes for this waveform. """ @staticmethod @abstractmethod def _from_calibration_schema(waveform_json: dict) -> Waveform: - """ - Parses a JSON input and returns the BDK waveform. See https://github.com/aws/amazon-braket-schemas-python/blob/main/src/braket/device_schema/pulse/native_gate_calibrations_v1.py#L104 + """Parses a JSON input and returns the BDK waveform. See https://github.com/aws/amazon-braket-schemas-python/blob/main/src/braket/device_schema/pulse/native_gate_calibrations_v1.py#L104 # noqa: E501 Args: waveform_json (dict): A JSON object with the needed parameters for making the Waveform. Returns: Waveform: A Waveform object parsed from the supplied JSON. - """ # noqa: E501 + """ class ArbitraryWaveform(Waveform): """An arbitrary waveform with amplitudes at each timestep explicitly specified using - an array.""" + an array. + """ def __init__(self, amplitudes: list[complex], id: Optional[str] = None): - """ + """Initializes an `ArbitraryWaveform`. + Args: amplitudes (list[complex]): Array of complex values specifying the waveform amplitude at each timestep. The timestep is determined by the sampling rate @@ -85,7 +85,10 @@ def __init__(self, amplitudes: list[complex], id: Optional[str] = None): self.amplitudes = list(amplitudes) self.id = id or _make_identifier_name() - def __eq__(self, other): + def __repr__(self) -> str: + return f"ArbitraryWaveform('id': {self.id}, 'amplitudes': {self.amplitudes})" + + def __eq__(self, other: ArbitraryWaveform): return isinstance(other, ArbitraryWaveform) and (self.amplitudes, self.id) == ( other.amplitudes, other.id, @@ -93,6 +96,7 @@ def __eq__(self, other): def _to_oqpy_expression(self) -> OQPyExpression: """Returns an OQPyExpression defining this waveform. + Returns: OQPyExpression: The OQPyExpression. """ @@ -100,10 +104,15 @@ def _to_oqpy_expression(self) -> OQPyExpression: def sample(self, dt: float) -> np.ndarray: """Generates a sample of amplitudes for this Waveform based on the given time resolution. + Args: dt (float): The time resolution. + + Raises: + NotImplementedError: This class does not implement sample. + Returns: - ndarray: The sample amplitudes for this waveform. + np.ndarray: The sample amplitudes for this waveform. """ raise NotImplementedError @@ -116,12 +125,14 @@ def _from_calibration_schema(waveform_json: dict) -> ArbitraryWaveform: class ConstantWaveform(Waveform, Parameterizable): """A constant waveform which holds the supplied `iq` value as its amplitude for the - specified length.""" + specified length. + """ def __init__( self, length: Union[float, FreeParameterExpression], iq: complex, id: Optional[str] = None ): - """ + """Initializes a `ConstantWaveform`. + Args: length (Union[float, FreeParameterExpression]): Value (in seconds) specifying the duration of the waveform. @@ -133,16 +144,26 @@ def __init__( self.iq = iq self.id = id or _make_identifier_name() + def __repr__(self) -> str: + return f"ConstantWaveform('id': {self.id}, 'length': {self.length}, 'iq': {self.iq})" + @property def parameters(self) -> list[Union[FreeParameterExpression, FreeParameter, float]]: """Returns the parameters associated with the object, either unbound free parameter - expressions or bound values.""" + expressions or bound values. + + Returns: + list[Union[FreeParameterExpression, FreeParameter, float]]: a list of parameters. + """ return [self.length] - def bind_values(self, **kwargs) -> ConstantWaveform: + def bind_values(self, **kwargs: Union[FreeParameter, str]) -> ConstantWaveform: """Takes in parameters and returns an object with specified parameters replaced with their values. + Args: + **kwargs (Union[FreeParameter, str]): Arbitrary keyword arguments. + Returns: ConstantWaveform: A copy of this waveform with the requested parameters bound. """ @@ -153,7 +174,7 @@ def bind_values(self, **kwargs) -> ConstantWaveform: } return ConstantWaveform(**constructor_kwargs) - def __eq__(self, other): + def __eq__(self, other: ConstantWaveform): return isinstance(other, ConstantWaveform) and (self.length, self.iq, self.id) == ( other.length, other.iq, @@ -162,6 +183,7 @@ def __eq__(self, other): def _to_oqpy_expression(self) -> OQPyExpression: """Returns an OQPyExpression defining this waveform. + Returns: OQPyExpression: The OQPyExpression. """ @@ -169,16 +191,18 @@ def _to_oqpy_expression(self) -> OQPyExpression: "constant", [("length", duration), ("iq", complex128)] ) return WaveformVar( - init_expression=constant_generator(_map_to_oqpy_type(self.length, True), self.iq), + init_expression=constant_generator(self.length, self.iq), name=self.id, ) def sample(self, dt: float) -> np.ndarray: """Generates a sample of amplitudes for this Waveform based on the given time resolution. + Args: dt (float): The time resolution. + Returns: - ndarray: The sample amplitudes for this waveform. + np.ndarray: The sample amplitudes for this waveform. """ # Amplitudes should be gated by [0:self.length] sample_range = np.arange(0, self.length, dt) @@ -217,7 +241,8 @@ def __init__( zero_at_edges: bool = False, id: Optional[str] = None, ): - """ + """Initializes a `DragGaussianWaveform`. + Args: length (Union[float, FreeParameterExpression]): Value (in seconds) specifying the duration of the waveform. @@ -238,16 +263,27 @@ def __init__( self.zero_at_edges = zero_at_edges self.id = id or _make_identifier_name() + def __repr__(self) -> str: + return ( + f"DragGaussianWaveform('id': {self.id}, 'length': {self.length}, " + f"'sigma': {self.sigma}, 'beta': {self.beta}, 'amplitude': {self.amplitude}, " + f"'zero_at_edges': {self.zero_at_edges})" + ) + @property def parameters(self) -> list[Union[FreeParameterExpression, FreeParameter, float]]: """Returns the parameters associated with the object, either unbound free parameter - expressions or bound values.""" + expressions or bound values. + """ return [self.length, self.sigma, self.beta, self.amplitude] - def bind_values(self, **kwargs) -> DragGaussianWaveform: + def bind_values(self, **kwargs: Union[FreeParameter, str]) -> DragGaussianWaveform: """Takes in parameters and returns an object with specified parameters replaced with their values. + Args: + **kwargs (Union[FreeParameter, str]): Arbitrary keyword arguments. + Returns: DragGaussianWaveform: A copy of this waveform with the requested parameters bound. """ @@ -261,7 +297,7 @@ def bind_values(self, **kwargs) -> DragGaussianWaveform: } return DragGaussianWaveform(**constructor_kwargs) - def __eq__(self, other): + def __eq__(self, other: DragGaussianWaveform): return isinstance(other, DragGaussianWaveform) and ( self.length, self.sigma, @@ -273,6 +309,7 @@ def __eq__(self, other): def _to_oqpy_expression(self) -> OQPyExpression: """Returns an OQPyExpression defining this waveform. + Returns: OQPyExpression: The OQPyExpression. """ @@ -288,10 +325,10 @@ def _to_oqpy_expression(self) -> OQPyExpression: ) return WaveformVar( init_expression=drag_gaussian_generator( - _map_to_oqpy_type(self.length, True), - _map_to_oqpy_type(self.sigma, True), - _map_to_oqpy_type(self.beta), - _map_to_oqpy_type(self.amplitude), + self.length, + self.sigma, + self.beta, + self.amplitude, self.zero_at_edges, ), name=self.id, @@ -299,10 +336,12 @@ def _to_oqpy_expression(self) -> OQPyExpression: def sample(self, dt: float) -> np.ndarray: """Generates a sample of amplitudes for this Waveform based on the given time resolution. + Args: dt (float): The time resolution. + Returns: - ndarray: The sample amplitudes for this waveform. + np.ndarray: The sample amplitudes for this waveform. """ sample_range = np.arange(0, self.length, dt) t0 = self.length / 2 @@ -343,7 +382,8 @@ def __init__( zero_at_edges: bool = False, id: Optional[str] = None, ): - """ + """Initializes a `GaussianWaveform`. + Args: length (Union[float, FreeParameterExpression]): Value (in seconds) specifying the duration of the waveform. @@ -362,16 +402,26 @@ def __init__( self.zero_at_edges = zero_at_edges self.id = id or _make_identifier_name() + def __repr__(self) -> str: + return ( + f"GaussianWaveform('id': {self.id}, 'length': {self.length}, 'sigma': {self.sigma}, " + f"'amplitude': {self.amplitude}, 'zero_at_edges': {self.zero_at_edges})" + ) + @property def parameters(self) -> list[Union[FreeParameterExpression, FreeParameter, float]]: """Returns the parameters associated with the object, either unbound free parameter - expressions or bound values.""" + expressions or bound values. + """ return [self.length, self.sigma, self.amplitude] - def bind_values(self, **kwargs) -> GaussianWaveform: + def bind_values(self, **kwargs: Union[FreeParameter, str]) -> GaussianWaveform: """Takes in parameters and returns an object with specified parameters replaced with their values. + Args: + **kwargs (Union[FreeParameter, str]): Arbitrary keyword arguments. + Returns: GaussianWaveform: A copy of this waveform with the requested parameters bound. """ @@ -384,7 +434,7 @@ def bind_values(self, **kwargs) -> GaussianWaveform: } return GaussianWaveform(**constructor_kwargs) - def __eq__(self, other): + def __eq__(self, other: GaussianWaveform): return isinstance(other, GaussianWaveform) and ( self.length, self.sigma, @@ -395,6 +445,7 @@ def __eq__(self, other): def _to_oqpy_expression(self) -> OQPyExpression: """Returns an OQPyExpression defining this waveform. + Returns: OQPyExpression: The OQPyExpression. """ @@ -409,9 +460,9 @@ def _to_oqpy_expression(self) -> OQPyExpression: ) return WaveformVar( init_expression=gaussian_generator( - _map_to_oqpy_type(self.length, True), - _map_to_oqpy_type(self.sigma, True), - _map_to_oqpy_type(self.amplitude), + self.length, + self.sigma, + self.amplitude, self.zero_at_edges, ), name=self.id, @@ -419,10 +470,12 @@ def _to_oqpy_expression(self) -> OQPyExpression: def sample(self, dt: float) -> np.ndarray: """Generates a sample of amplitudes for this Waveform based on the given time resolution. + Args: dt (float): The time resolution. + Returns: - ndarray: The sample amplitudes for this waveform. + np.ndarray: The sample amplitudes for this waveform. """ sample_range = np.arange(0, self.length, dt) t0 = self.length / 2 @@ -449,19 +502,7 @@ def _from_calibration_schema(waveform_json: dict) -> GaussianWaveform: def _make_identifier_name() -> str: - return "".join([random.choice(string.ascii_letters) for _ in range(10)]) - - -def _map_to_oqpy_type( - parameter: Union[FreeParameterExpression, float], is_duration_type: bool = False -) -> Union[_FreeParameterExpressionIdentifier, OQPyExpression]: - if isinstance(parameter, FreeParameterExpression): - return ( - OQDurationLiteral(parameter) - if is_duration_type - else _FreeParameterExpressionIdentifier(parameter) - ) - return parameter + return "".join([random.choice(string.ascii_letters) for _ in range(10)]) # noqa S311 def _parse_waveform_from_calibration_schema(waveform: dict) -> Waveform: @@ -471,10 +512,9 @@ def _parse_waveform_from_calibration_schema(waveform: dict) -> Waveform: "gaussian": GaussianWaveform._from_calibration_schema, "constant": ConstantWaveform._from_calibration_schema, } - if "amplitudes" in waveform.keys(): + if "amplitudes" in waveform: waveform["name"] = "arbitrary" if waveform["name"] in waveform_names: return waveform_names[waveform["name"]](waveform) - else: - id = waveform["waveformId"] - raise ValueError(f"The waveform {id} of cannot be constructed") + waveform_id = waveform["waveformId"] + raise ValueError(f"The waveform {waveform_id} of cannot be constructed") diff --git a/src/braket/quantum_information/pauli_string.py b/src/braket/quantum_information/pauli_string.py index f1ca55406..0de8e8e3b 100644 --- a/src/braket/quantum_information/pauli_string.py +++ b/src/braket/quantum_information/pauli_string.py @@ -17,7 +17,7 @@ from typing import Optional, Union from braket.circuits.circuit import Circuit -from braket.circuits.observables import TensorProduct, X, Y, Z +from braket.circuits.observables import I, TensorProduct, X, Y, Z _IDENTITY = "I" _PAULI_X = "X" @@ -29,22 +29,25 @@ "Y": {"X": ["Z", -1j], "Z": ["X", 1j]}, "Z": {"X": ["Y", 1j], "Y": ["X", -1j]}, } +_ID_OBS = I() _PAULI_OBSERVABLES = {_PAULI_X: X(), _PAULI_Y: Y(), _PAULI_Z: Z()} _SIGN_MAP = {"+": 1, "-": -1} class PauliString: - """ - A lightweight representation of a Pauli string with its phase. - """ + """A lightweight representation of a Pauli string with its phase.""" def __init__(self, pauli_string: Union[str, PauliString]): - """ + """Initializes a `PauliString`. + Args: pauli_string (Union[str, PauliString]): The representation of the pauli word, either a string or another PauliString object. A valid string consists of an optional phase, specified by an optional sign +/- followed by an uppercase string in {I, X, Y, Z}. Example valid strings are: XYZ, +YIZY, -YX + + Raises: + ValueError: If the Pauli String is empty. """ if not pauli_string: raise ValueError("pauli_string must not be empty") @@ -74,14 +77,29 @@ def qubit_count(self) -> int: """int: The number of qubits this Pauli string acts on.""" return self._qubit_count - def to_unsigned_observable(self) -> TensorProduct: + def to_unsigned_observable(self, include_trivial: bool = False) -> TensorProduct: """Returns the observable corresponding to the unsigned part of the Pauli string. For example, for a Pauli string -XYZ, the corresponding observable is X ⊗ Y ⊗ Z. + Args: + include_trivial (bool): Whether to include explicit identity factors in the observable. + Default: False. + Returns: TensorProduct: The tensor product of the unsigned factors in the Pauli string. """ + if include_trivial: + return TensorProduct( + [ + ( + _PAULI_OBSERVABLES[self._nontrivial[qubit]] + if qubit in self._nontrivial + else _ID_OBS + ) + for qubit in range(self._qubit_count) + ] + ) return TensorProduct( [_PAULI_OBSERVABLES[self._nontrivial[qubit]] for qubit in sorted(self._nontrivial)] ) @@ -190,7 +208,7 @@ def dot(self, other: PauliString, inplace: bool = False) -> PauliString: # ignore complex global phase if phase_result.real < 0 or phase_result.imag < 0: - pauli_result = "-" + pauli_result + pauli_result = f"-{pauli_result}" out_pauli_string = PauliString(pauli_result) if inplace: @@ -327,7 +345,7 @@ def to_circuit(self) -> Circuit: circ = circ.z(qubit) return circ - def __eq__(self, other): + def __eq__(self, other: PauliString): if isinstance(other, PauliString): return ( self._phase == other._phase @@ -336,7 +354,7 @@ def __eq__(self, other): ) return False - def __getitem__(self, item): + def __getitem__(self, item: int): if item >= self._qubit_count: raise IndexError(item) return _PAULI_INDICES[self._nontrivial.get(item, "I")] diff --git a/src/braket/registers/qubit.py b/src/braket/registers/qubit.py index 479a453e3..4c91ebd25 100644 --- a/src/braket/registers/qubit.py +++ b/src/braket/registers/qubit.py @@ -20,19 +20,22 @@ class Qubit(int): - """ - A quantum bit index. The index of this qubit is locally scoped towards the contained + """A quantum bit index. The index of this qubit is locally scoped towards the contained circuit. This may not be the exact qubit index on the quantum device. """ - def __new__(cls, index: int): - """ + def __new__(cls, index: int) -> Qubit: + """Creates a new `Qubit`. + Args: index (int): Index of the qubit. Raises: ValueError: If `index` is less than zero. + Returns: + Qubit: Returns a new Qubit object. + Examples: >>> Qubit(0) >>> Qubit(1) @@ -51,8 +54,7 @@ def __str__(self): @staticmethod def new(qubit: QubitInput) -> Qubit: - """ - Helper constructor - if input is a `Qubit` it returns the same value, + """Helper constructor - if input is a `Qubit` it returns the same value, else a new `Qubit` is constructed. Args: @@ -61,8 +63,4 @@ def new(qubit: QubitInput) -> Qubit: Returns: Qubit: The qubit. """ - - if isinstance(qubit, Qubit): - return qubit - else: - return Qubit(qubit) + return qubit if isinstance(qubit, Qubit) else Qubit(qubit) diff --git a/src/braket/registers/qubit_set.py b/src/braket/registers/qubit_set.py index 0f9d0a7ac..e7bed6aa8 100644 --- a/src/braket/registers/qubit_set.py +++ b/src/braket/registers/qubit_set.py @@ -32,8 +32,7 @@ def _flatten(other: Any) -> Any: class QubitSet(IndexedSet): - """ - An ordered, unique set of quantum bits. + """An ordered, unique set of quantum bits. Note: QubitSet implements `__hash__()` but is a mutable object, therefore be careful when @@ -41,7 +40,8 @@ class QubitSet(IndexedSet): """ def __init__(self, qubits: QubitSetInput | None = None): - """ + """Initializes a `QubitSet`. + Args: qubits (QubitSetInput | None): Qubits to be included in the `QubitSet`. Default is `None`. @@ -63,13 +63,11 @@ def __init__(self, qubits: QubitSetInput | None = None): Qubit(2) Qubit(3) """ - _qubits = [Qubit.new(qubit) for qubit in _flatten(qubits)] if qubits is not None else None super().__init__(_qubits) def map(self, mapping: dict[QubitInput, QubitInput]) -> QubitSet: - """ - Creates a new `QubitSet` where this instance's qubits are mapped to the values in `mapping`. + """Creates a new `QubitSet` where this instance's qubits are mapped to the values in `mapping`. If this instance contains a qubit that is not in the `mapping` that qubit is not modified. Args: @@ -85,8 +83,7 @@ def map(self, mapping: dict[QubitInput, QubitInput]) -> QubitSet: >>> mapping = {0: 10, Qubit(1): Qubit(11)} >>> qubits.map(mapping) QubitSet([Qubit(10), Qubit(11)]) - """ - + """ # noqa E501 new_qubits = [mapping.get(qubit, qubit) for qubit in self] return QubitSet(new_qubits) diff --git a/src/braket/tasks/__init__.py b/src/braket/tasks/__init__.py index bb6cc6e77..d40b0547c 100644 --- a/src/braket/tasks/__init__.py +++ b/src/braket/tasks/__init__.py @@ -11,7 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -import braket.ipython_utils as ipython_utils +from braket import ipython_utils from braket.tasks.analog_hamiltonian_simulation_quantum_task_result import ( # noqa: F401 AnalogHamiltonianSimulationQuantumTaskResult, AnalogHamiltonianSimulationShotStatus, diff --git a/src/braket/tasks/analog_hamiltonian_simulation_quantum_task_result.py b/src/braket/tasks/analog_hamiltonian_simulation_quantum_task_result.py index abc39753d..7bfb57eb3 100644 --- a/src/braket/tasks/analog_hamiltonian_simulation_quantum_task_result.py +++ b/src/braket/tasks/analog_hamiltonian_simulation_quantum_task_result.py @@ -38,7 +38,7 @@ class ShotResult: pre_sequence: np.ndarray = None post_sequence: np.ndarray = None - def __eq__(self, other) -> bool: + def __eq__(self, other: ShotResult) -> bool: if isinstance(other, ShotResult): return ( self.status == other.status @@ -54,7 +54,7 @@ class AnalogHamiltonianSimulationQuantumTaskResult: additional_metadata: AdditionalMetadata measurements: list[ShotResult] = None - def __eq__(self, other) -> bool: + def __eq__(self, other: AnalogHamiltonianSimulationQuantumTaskResult) -> bool: if isinstance(other, AnalogHamiltonianSimulationQuantumTaskResult): return ( self.task_metadata.id == other.task_metadata.id @@ -116,9 +116,8 @@ def get_counts(self) -> dict[str, int]: Returns: dict[str, int]: number of times each state configuration is measured. Returns None if none of shot measurements are successful. - Only succesful shots contribute to the state count. + Only successful shots contribute to the state count. """ - state_counts = Counter() states = ["e", "r", "g"] for shot in self.measurements: @@ -129,8 +128,8 @@ def get_counts(self) -> dict[str, int]: state_idx = [ 0 if pre_i == 0 else 1 if post_i == 0 else 2 for pre_i, post_i in zip(pre, post) ] - state = "".join(map(lambda s_idx: states[s_idx], state_idx)) - state_counts.update((state,)) + state = "".join(states[s_idx] for s_idx in state_idx) + state_counts.update([state]) return dict(state_counts) @@ -138,9 +137,8 @@ def get_avg_density(self) -> np.ndarray: """Get the average Rydberg state densities from the result Returns: - ndarray: The average densities from the result + np.ndarray: The average densities from the result """ - counts = self.get_counts() N_ryd, N_ground = [], [] diff --git a/src/braket/tasks/annealing_quantum_task_result.py b/src/braket/tasks/annealing_quantum_task_result.py index 1a5fbc2cb..bf5e9900d 100644 --- a/src/braket/tasks/annealing_quantum_task_result.py +++ b/src/braket/tasks/annealing_quantum_task_result.py @@ -13,10 +13,11 @@ from __future__ import annotations +from collections.abc import Generator from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Optional -import numpy +import numpy as np from braket.annealing import ProblemType from braket.task_result import AdditionalMetadata, AnnealingTaskResult, TaskMetadata @@ -24,12 +25,11 @@ @dataclass class AnnealingQuantumTaskResult: - """ - Result of an annealing problem quantum task execution. This class is intended + """Result of an annealing problem quantum task execution. This class is intended to be initialized by a QuantumTask class. Args: - record_array (numpy.recarray): numpy array with keys 'solution' (numpy.ndarray) + record_array (np.recarray): numpy array with keys 'solution' (np.ndarray) where row is solution, column is value of the variable, 'solution_count' (numpy.ndarray) the number of times the solutions occurred, and 'value' (numpy.ndarray) the output or energy of the solutions. @@ -39,7 +39,7 @@ class AnnealingQuantumTaskResult: additional_metadata (AdditionalMetadata): Additional metadata about the quantum task """ - record_array: numpy.recarray + record_array: np.recarray variable_count: int problem_type: ProblemType task_metadata: TaskMetadata @@ -47,38 +47,37 @@ class AnnealingQuantumTaskResult: def data( self, - selected_fields: Optional[List[str]] = None, + selected_fields: Optional[list[str]] = None, sorted_by: str = "value", reverse: bool = False, - ) -> Tuple: - """ - Iterate over the data in record_array + ) -> Generator[tuple]: + """Yields the data in record_array Args: - selected_fields (Optional[List[str]]): selected fields to return. + selected_fields (Optional[list[str]]): selected fields to return. Options are 'solution', 'value', and 'solution_count'. Default is None. sorted_by (str): Sorts the data by this field. Options are 'solution', 'value', and 'solution_count'. Default is 'value'. reverse (bool): If True, returns the data in reverse order. Default is False. Yields: - Tuple: data in record_array + Generator[tuple]: data in record_array """ if selected_fields is None: selected_fields = ["solution", "value", "solution_count"] if sorted_by is None: - order = numpy.arange(len(self.record_array)) + order = np.arange(len(self.record_array)) else: - order = numpy.argsort(self.record_array[sorted_by]) + order = np.argsort(self.record_array[sorted_by]) if reverse: - order = numpy.flip(order) + order = np.flip(order) for i in order: yield tuple(self.record_array[field][i] for field in selected_fields) - def __eq__(self, other) -> bool: + def __eq__(self, other: AnnealingQuantumTaskResult) -> bool: if isinstance(other, AnnealingQuantumTaskResult): # __eq__ on numpy arrays results in an array of booleans and therefore can't use # the default dataclass __eq__ implementation. Override equals to check if all @@ -100,8 +99,7 @@ def __eq__(self, other) -> bool: @staticmethod def from_object(result: AnnealingTaskResult) -> AnnealingQuantumTaskResult: - """ - Create AnnealingQuantumTaskResult from AnnealingTaskResult object + """Create AnnealingQuantumTaskResult from AnnealingTaskResult object Args: result (AnnealingTaskResult): AnnealingTaskResult object @@ -114,8 +112,7 @@ def from_object(result: AnnealingTaskResult) -> AnnealingQuantumTaskResult: @staticmethod def from_string(result: str) -> AnnealingQuantumTaskResult: - """ - Create AnnealingQuantumTaskResult from string + """Create AnnealingQuantumTaskResult from string Args: result (str): JSON object string @@ -127,12 +124,12 @@ def from_string(result: str) -> AnnealingQuantumTaskResult: @classmethod def _from_object(cls, result: AnnealingTaskResult) -> AnnealingQuantumTaskResult: - solutions = numpy.asarray(result.solutions, dtype=int) - values = numpy.asarray(result.values, dtype=float) + solutions = np.asarray(result.solutions, dtype=int) + values = np.asarray(result.values, dtype=float) if not result.solutionCounts: - solution_counts = numpy.ones(len(solutions), dtype=int) + solution_counts = np.ones(len(solutions), dtype=int) else: - solution_counts = numpy.asarray(result.solutionCounts, dtype=int) + solution_counts = np.asarray(result.solutionCounts, dtype=int) record_array = AnnealingQuantumTaskResult._create_record_array( solutions, solution_counts, values ) @@ -150,15 +147,17 @@ def _from_object(cls, result: AnnealingTaskResult) -> AnnealingQuantumTaskResult @staticmethod def _create_record_array( - solutions: numpy.ndarray, solution_counts: numpy.ndarray, values: numpy.ndarray - ) -> numpy.recarray: - """ - Create a solutions record for AnnealingQuantumTaskResult + solutions: np.ndarray, solution_counts: np.ndarray, values: np.ndarray + ) -> np.recarray: + """Create a solutions record for AnnealingQuantumTaskResult Args: - solutions (numpy.ndarray): row is solution, column is value of the variable - solution_counts (numpy.ndarray): list of number of times the solutions occurred - values (numpy.ndarray): list of the output or energy of the solutions + solutions (np.ndarray): row is solution, column is value of the variable + solution_counts (np.ndarray): list of number of times the solutions occurred + values (np.ndarray): list of the output or energy of the solutions + + Returns: + np.recarray: A record array for solutions, value, and solution_count. """ num_solutions, variable_count = solutions.shape datatypes = [ @@ -167,7 +166,7 @@ def _create_record_array( ("solution_count", solution_counts.dtype), ] - record = numpy.rec.array(numpy.zeros(num_solutions, dtype=datatypes)) + record = np.rec.array(np.zeros(num_solutions, dtype=datatypes)) record["solution"] = solutions record["value"] = values record["solution_count"] = solution_counts diff --git a/src/braket/tasks/gate_model_quantum_task_result.py b/src/braket/tasks/gate_model_quantum_task_result.py index 631ca45e7..81f90ae7b 100644 --- a/src/braket/tasks/gate_model_quantum_task_result.py +++ b/src/braket/tasks/gate_model_quantum_task_result.py @@ -36,8 +36,7 @@ @dataclass class GateModelQuantumTaskResult: - """ - Result of a gate model quantum task execution. This class is intended + """Result of a gate model quantum task execution. This class is intended to be initialized by a QuantumTask class. Args: @@ -96,16 +95,15 @@ class GateModelQuantumTaskResult: def __post_init__(self): if self.result_types is not None: - self._result_types_indices = dict( - (GateModelQuantumTaskResult._result_type_hash(rt.type), i) + self._result_types_indices = { + GateModelQuantumTaskResult._result_type_hash(rt.type): i for i, rt in enumerate(self.result_types) - ) + } else: self._result_types_indices = {} def get_value_by_result_type(self, result_type: ResultType) -> Any: - """ - Get value by result type. The result type must have already been + """Get value by result type. The result type must have already been requested in the circuit sent to the device for this quantum task result. Args: @@ -123,20 +121,19 @@ def get_value_by_result_type(self, result_type: ResultType) -> Any: rt_hash = GateModelQuantumTaskResult._result_type_hash(rt_ir) result_type_index = self._result_types_indices[rt_hash] return self.values[result_type_index] - except KeyError: + except KeyError as e: raise ValueError( "Result type not found in result. " - + "Result types must be added to circuit before circuit is run on device." - ) + "Result types must be added to circuit before circuit is run on device." + ) from e - def __eq__(self, other) -> bool: + def __eq__(self, other: GateModelQuantumTaskResult) -> bool: if isinstance(other, GateModelQuantumTaskResult): return self.task_metadata.id == other.task_metadata.id return NotImplemented def get_compiled_circuit(self) -> Optional[str]: - """ - Get the compiled circuit, if one is available. + """Get the compiled circuit, if one is available. Returns: Optional[str]: The compiled circuit or None. @@ -153,27 +150,25 @@ def get_compiled_circuit(self) -> Optional[str]: @staticmethod def measurement_counts_from_measurements(measurements: np.ndarray) -> Counter: - """ - Creates measurement counts from measurements + """Creates measurement counts from measurements Args: - measurements (ndarray): 2d array - row is shot and column is qubit. + measurements (np.ndarray): 2d array - row is shot and column is qubit. Returns: Counter: A Counter of measurements. Key is the measurements in a big endian binary string. Value is the number of times that measurement occurred. """ - bitstrings = [] - for j in range(len(measurements)): - bitstrings.append("".join([str(element) for element in measurements[j]])) + bitstrings = [ + "".join([str(element) for element in measurements[j]]) for j in range(len(measurements)) + ] return Counter(bitstrings) @staticmethod def measurement_probabilities_from_measurement_counts( measurement_counts: Counter, ) -> dict[str, float]: - """ - Creates measurement probabilities from measurement counts + """Creates measurement probabilities from measurement counts Args: measurement_counts (Counter): A Counter of measurements. Key is the measurements @@ -184,19 +179,18 @@ def measurement_probabilities_from_measurement_counts( dict[str, float]: A dictionary of probabilistic results. Key is the measurements in a big endian binary string. Value is the probability the measurement occurred. """ - measurement_probabilities = {} shots = sum(measurement_counts.values()) - for key, count in measurement_counts.items(): - measurement_probabilities[key] = count / shots + measurement_probabilities = { + key: count / shots for key, count in measurement_counts.items() + } return measurement_probabilities @staticmethod def measurements_from_measurement_probabilities( measurement_probabilities: dict[str, float], shots: int ) -> np.ndarray: - """ - Creates measurements from measurement probabilities. + """Creates measurements from measurement probabilities. Args: measurement_probabilities (dict[str, float]): A dictionary of probabilistic results. @@ -205,7 +199,7 @@ def measurements_from_measurement_probabilities( shots (int): number of iterations on device. Returns: - ndarray: A dictionary of probabilistic results. + np.ndarray: A dictionary of probabilistic results. Key is the measurements in a big endian binary string. Value is the probability the measurement occurred. """ @@ -220,8 +214,7 @@ def measurements_from_measurement_probabilities( @staticmethod def from_object(result: GateModelTaskResult) -> GateModelQuantumTaskResult: - """ - Create GateModelQuantumTaskResult from GateModelTaskResult object. + """Create GateModelQuantumTaskResult from GateModelTaskResult object. Args: result (GateModelTaskResult): GateModelTaskResult object @@ -237,8 +230,7 @@ def from_object(result: GateModelTaskResult) -> GateModelQuantumTaskResult: @staticmethod def from_string(result: str) -> GateModelQuantumTaskResult: - """ - Create GateModelQuantumTaskResult from string. + """Create GateModelQuantumTaskResult from string. Args: result (str): JSON object string, with GateModelQuantumTaskResult attributes as keys. @@ -294,11 +286,6 @@ def _from_object_internal_computational_basis_sampling( " the result obj", ) measured_qubits = result.measuredQubits - if len(measured_qubits) != measurements.shape[1]: - raise ValueError( - f"Measured qubits {measured_qubits} is not equivalent to number of qubits " - + f"{measurements.shape[1]} in measurements" - ) if result.resultTypes: # Jaqcd does not return anything in the resultTypes schema field since the # result types are easily parsable from the IR. However, an OpenQASM program @@ -350,8 +337,7 @@ def _from_dict_internal_simulator_only( @staticmethod def cast_result_types(gate_model_task_result: GateModelTaskResult) -> None: - """ - Casts the result types to the types expected by the SDK. + """Casts the result types to the types expected by the SDK. Args: gate_model_task_result (GateModelTaskResult): GateModelTaskResult representing the @@ -360,13 +346,14 @@ def cast_result_types(gate_model_task_result: GateModelTaskResult) -> None: if gate_model_task_result.resultTypes: for result_type in gate_model_task_result.resultTypes: type = result_type.type.type - if type == "probability": + if type == "amplitude": + for state in result_type.value: + result_type.value[state] = complex(*result_type.value[state]) + + elif type == "probability": result_type.value = np.array(result_type.value) elif type == "statevector": result_type.value = np.array([complex(*value) for value in result_type.value]) - elif type == "amplitude": - for state in result_type.value: - result_type.value[state] = complex(*result_type.value[state]) @staticmethod def _calculate_result_types( diff --git a/src/braket/tasks/local_quantum_task.py b/src/braket/tasks/local_quantum_task.py index 69a0f1bd6..8417c71b2 100644 --- a/src/braket/tasks/local_quantum_task.py +++ b/src/braket/tasks/local_quantum_task.py @@ -39,6 +39,11 @@ def __init__( @property def id(self) -> str: + """Gets the task ID. + + Returns: + str: The ID of the task. + """ return str(self._id) def cancel(self) -> None: @@ -46,6 +51,11 @@ def cancel(self) -> None: raise NotImplementedError("Cannot cancel completed local task") def state(self) -> str: + """Gets the state of the task. + + Returns: + str: Returns COMPLETED + """ return "COMPLETED" def result( @@ -57,8 +67,12 @@ def result( def async_result(self) -> asyncio.Task: """Get the quantum task result asynchronously. + + Raises: + NotImplementedError: Asynchronous local simulation unsupported + Returns: - Task: Get the quantum task result asynchronously. + asyncio.Task: Get the quantum task result asynchronously. """ # TODO: Allow for asynchronous simulation raise NotImplementedError("Asynchronous local simulation unsupported") diff --git a/src/braket/tasks/photonic_model_quantum_task_result.py b/src/braket/tasks/photonic_model_quantum_task_result.py index 4d2fa0201..59eba68ac 100644 --- a/src/braket/tasks/photonic_model_quantum_task_result.py +++ b/src/braket/tasks/photonic_model_quantum_task_result.py @@ -26,15 +26,14 @@ class PhotonicModelQuantumTaskResult: additional_metadata: AdditionalMetadata measurements: np.ndarray = None - def __eq__(self, other) -> bool: + def __eq__(self, other: PhotonicModelQuantumTaskResult) -> bool: if isinstance(other, PhotonicModelQuantumTaskResult): return self.task_metadata.id == other.task_metadata.id return NotImplemented @staticmethod def from_object(result: PhotonicModelTaskResult) -> PhotonicModelQuantumTaskResult: - """ - Create PhotonicModelQuantumTaskResult from PhotonicModelTaskResult object. + """Create PhotonicModelQuantumTaskResult from PhotonicModelTaskResult object. Args: result (PhotonicModelTaskResult): PhotonicModelTaskResult object diff --git a/src/braket/tasks/quantum_task.py b/src/braket/tasks/quantum_task.py index c306078ca..f07a1be7b 100644 --- a/src/braket/tasks/quantum_task.py +++ b/src/braket/tasks/quantum_task.py @@ -27,6 +27,7 @@ class QuantumTask(ABC): @abstractmethod def id(self) -> str: """Get the quantum task ID. + Returns: str: The quantum task ID. """ @@ -38,6 +39,7 @@ def cancel(self) -> None: @abstractmethod def state(self) -> str: """Get the state of the quantum task. + Returns: str: State of the quantum task. """ @@ -49,22 +51,23 @@ def result( GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult ]: """Get the quantum task result. + Returns: - Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult]: # noqa - Get the quantum task result. Call async_result if you want the result in an + Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult]: Get + the quantum task result. Call async_result if you want the result in an asynchronous way. - """ + """ # noqa E501 @abstractmethod def async_result(self) -> asyncio.Task: """Get the quantum task result asynchronously. + Returns: - Task: Get the quantum task result asynchronously. + asyncio.Task: Get the quantum task result asynchronously. """ - def metadata(self, use_cached_value: bool = False) -> dict[str, Any]: - """ - Get task metadata. + def metadata(self, use_cached_value: bool = False) -> dict[str, Any]: # noqa B027 + """Get task metadata. Args: use_cached_value (bool): If True, uses the value retrieved from the previous diff --git a/src/braket/tasks/quantum_task_batch.py b/src/braket/tasks/quantum_task_batch.py index ff0a0b82a..790f6e19a 100644 --- a/src/braket/tasks/quantum_task_batch.py +++ b/src/braket/tasks/quantum_task_batch.py @@ -31,7 +31,8 @@ def results( ] ]: """Get the quantum task results. + Returns: - list[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult]]:: # noqa - Get the quantum task results. - """ + list[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult]]: Get + the quantum task results. + """ # noqa: E501 diff --git a/src/braket/timings/time_series.py b/src/braket/timings/time_series.py index 9d2b9cf00..67797d6c4 100644 --- a/src/braket/timings/time_series.py +++ b/src/braket/timings/time_series.py @@ -19,7 +19,7 @@ from decimal import Decimal from enum import Enum from numbers import Number -from typing import Union +from typing import Optional @dataclass @@ -105,6 +105,9 @@ def from_lists(times: list[float], values: list[float]) -> TimeSeries: Returns: TimeSeries: time series constructed from lists + + Raises: + ValueError: If the len of `times` does not equal len of `values`. """ if len(times) != len(values): raise ValueError( @@ -118,12 +121,12 @@ def from_lists(times: list[float], values: list[float]) -> TimeSeries: return ts @staticmethod - def constant_like(times: Union[list[float], TimeSeries], constant: float = 0.0) -> TimeSeries: + def constant_like(times: list | float | TimeSeries, constant: float = 0.0) -> TimeSeries: """Obtain a constant time series given another time series or the list of time points, - and the constant values + and the constant values. Args: - times (Union[list[float], TimeSeries]): list of time points or a time series + times (list | float | TimeSeries): list of time points or a time series constant (float): constant value Returns: @@ -149,11 +152,15 @@ def concatenate(self, other: TimeSeries) -> TimeSeries: Notes: Keeps the time points in both time series unchanged. Assumes that the time points in the first TimeSeries - are at earler times then the time points in the second TimeSeries. + are at earlier times then the time points in the second TimeSeries. Returns: TimeSeries: The concatenated time series. + Raises: + ValueError: If the timeseries is not empty and time points in the first + TimeSeries are not strictly smaller than in the second. + Example: :: time_series_1 = TimeSeries.from_lists(times=[0, 0.1], values=[1, 2]) @@ -165,7 +172,6 @@ def concatenate(self, other: TimeSeries) -> TimeSeries: concat_ts.times() = [0, 0.1, 0.2, 0.3] concat_ts.values() = [1, 2, 4, 5] """ - not_empty_ts = len(other.times()) * len(self.times()) != 0 if not_empty_ts and min(other.times()) <= max(self.times()): raise ValueError( @@ -202,6 +208,9 @@ def stitch( Returns: TimeSeries: The stitched time series. + Raises: + ValueError: If boundary is not one of {"mean", "left", "right"}. + Example (StitchBoundaryCondition.MEAN): :: time_series_1 = TimeSeries.from_lists(times=[0, 0.1], values=[1, 2]) @@ -229,7 +238,6 @@ def stitch( stitch_ts.times() = [0, 0.1, 0.3] stitch_ts.values() = [1, 4, 5] """ - if len(self.times()) == 0: return TimeSeries.from_lists(times=other.times(), values=other.values()) if len(other.times()) == 0: @@ -260,24 +268,33 @@ def stitch( return new_time_series - def discretize(self, time_resolution: Decimal, value_resolution: Decimal) -> TimeSeries: + def discretize( + self, time_resolution: Optional[Decimal], value_resolution: Optional[Decimal] + ) -> TimeSeries: """Creates a discretized version of the time series, rounding all times and values to the closest multiple of the corresponding resolution. Args: - time_resolution (Decimal): Time resolution - value_resolution (Decimal): Value resolution + time_resolution (Optional[Decimal]): Time resolution + value_resolution (Optional[Decimal]): Value resolution Returns: TimeSeries: A new discretized time series. """ discretized_ts = TimeSeries() for item in self: - discretized_ts.put( - time=round(Decimal(item.time) / time_resolution) * time_resolution, - value=round(Decimal(item.value) / value_resolution) * value_resolution, - ) + if time_resolution is None: + discretized_time = Decimal(item.time) + else: + discretized_time = round(Decimal(item.time) / time_resolution) * time_resolution + + if value_resolution is None: + discretized_value = Decimal(item.value) + else: + discretized_value = round(Decimal(item.value) / value_resolution) * value_resolution + + discretized_ts.put(time=discretized_time, value=discretized_value) return discretized_ts @staticmethod @@ -287,18 +304,20 @@ def periodic_signal(times: list[float], values: list[float], num_repeat: int = 1 Args: times (list[float]): List of time points in a single block values (list[float]): Values for the time series in a single block - num_repeat (int): Number of block repeatitions + num_repeat (int): Number of block repetitions + + Raises: + ValueError: If the first and last values are not the same Returns: TimeSeries: A new periodic time series. """ - - if not (values[0] == values[-1]): - raise ValueError("The first and last values must coinscide to guarantee periodicity") + if values[0] != values[-1]: + raise ValueError("The first and last values must coincide to guarantee periodicity") new_time_series = TimeSeries() repeating_block = TimeSeries.from_lists(times=times, values=values) - for index in range(num_repeat): + for _index in range(num_repeat): new_time_series = new_time_series.stitch(repeating_block) return new_time_series @@ -316,6 +335,9 @@ def trapezoidal_signal( slew_rate_max (float): The maximum slew rate time_separation_min (float): The minimum separation of time points + Raises: + ValueError: If the time separation is negative + Returns: TimeSeries: A trapezoidal time series @@ -324,7 +346,6 @@ def trapezoidal_signal( f(t) from t=0 to t=T, where T is the duration. We also assume the trapezoidal time series starts and ends at zero. """ - if area <= 0.0: raise ValueError("The area of the trapezoidal time series has to be positive.") if value_max <= 0.0: @@ -370,8 +391,7 @@ def trapezoidal_signal( # TODO: Verify if this belongs here. def _all_close(first: TimeSeries, second: TimeSeries, tolerance: Number = 1e-7) -> bool: - """ - Returns True if the times and values in two time series are all within (less than) + """Returns True if the times and values in two time series are all within (less than) a given tolerance range. The values in the TimeSeries must be numbers that can be subtracted from each-other, support getting the absolute value, and can be compared against the tolerance. diff --git a/src/braket/tracking/pricing.py b/src/braket/tracking/pricing.py index 2b049b251..c269208c2 100644 --- a/src/braket/tracking/pricing.py +++ b/src/braket/tracking/pricing.py @@ -53,9 +53,13 @@ def get_prices(self) -> None: text_response.readline() self._price_list = list(csv.DictReader(text_response)) - @lru_cache() - def price_search(self, **kwargs) -> list[dict[str, str]]: + @lru_cache + def price_search(self, **kwargs: str) -> list[dict[str, str]]: """Searches the price list for a given set of parameters. + + Args: + **kwargs (str): Arbitrary keyword arguments. + Returns: list[dict[str, str]]: The price list. """ diff --git a/src/braket/tracking/tracker.py b/src/braket/tracking/tracker.py index 84b294ad2..47c13625a 100644 --- a/src/braket/tracking/tracker.py +++ b/src/braket/tracking/tracker.py @@ -30,8 +30,7 @@ class Tracker: - """ - Amazon Braket cost tracker. + """Amazon Braket cost tracker. Use this class to track costs incurred from quantum tasks on Amazon Braket. """ @@ -47,6 +46,7 @@ def __exit__(self, *args): def start(self) -> Tracker: """Start tracking resources with this tracker. + Returns: Tracker: self. """ @@ -54,6 +54,7 @@ def start(self) -> Tracker: def stop(self) -> Tracker: """Stop tracking resources with this tracker. + Returns: Tracker: self. """ @@ -61,14 +62,14 @@ def stop(self) -> Tracker: def receive_event(self, event: _TaskCreationEvent) -> None: """Process a Tack Creation Event. + Args: event (_TaskCreationEvent): The event to process. """ self._recieve_internal(event) def tracked_resources(self) -> list[str]: - """ - Resources tracked by this tracker. + """Resources tracked by this tracker. Returns: list[str]: The list of quantum task ids for quantum tasks tracked by this tracker. @@ -76,8 +77,7 @@ def tracked_resources(self) -> list[str]: return list(self._resources.keys()) def qpu_tasks_cost(self) -> Decimal: - """ - Estimate cost of all quantum tasks tracked by this tracker that use Braket qpu devices. + """Estimate cost of all quantum tasks tracked by this tracker that use Braket qpu devices. Note: Charges shown are estimates based on your Amazon Braket simulator and quantum processing unit (QPU) task usage. Estimated charges shown may differ from your actual @@ -95,8 +95,8 @@ def qpu_tasks_cost(self) -> Decimal: return total_cost def simulator_tasks_cost(self) -> Decimal: - """ - Estimate cost of all quantum tasks tracked by this tracker using Braket simulator devices. + """Estimate cost of all quantum tasks tracked by this tracker using Braket simulator + devices. Note: The cost of a simulator quantum task is not available until after the results for the task have been fetched. Call `result()` on an `AwsQuantumTask` before estimating its cost @@ -118,12 +118,11 @@ def simulator_tasks_cost(self) -> Decimal: return total_cost def quantum_tasks_statistics(self) -> dict[str, dict[str, Any]]: - """ - Get a summary of quantum tasks grouped by device. + """Get a summary of quantum tasks grouped by device. Returns: - dict[str,dict[str,Any]] : A dictionary where each key is a device arn, and maps to - a dictionary sumarizing the quantum tasks run on the device. The summary includes the + dict[str, dict[str, Any]]: A dictionary where each key is a device arn, and maps to + a dictionary summarizing the quantum tasks run on the device. The summary includes the total shots sent to the device and the most recent status of the quantum tasks created on this device. For finished quantum tasks on simulator devices, the summary also includes the duration of the simulation. @@ -272,7 +271,7 @@ def _get_simulator_task_cost(task_arn: str, details: dict) -> Decimal: product_family = "Simulator Task" operation = "CompleteTask" if details["status"] == "FAILED" and device_name == "TN1": - # Rehersal step of TN1 can fail and charges still apply. + # Rehearsal step of TN1 can fail and charges still apply. operation = "FailedTask" search_dict = { diff --git a/src/braket/tracking/tracking_context.py b/src/braket/tracking/tracking_context.py index 128af025d..37f4a3dc0 100644 --- a/src/braket/tracking/tracking_context.py +++ b/src/braket/tracking/tracking_context.py @@ -20,6 +20,7 @@ def __init__(self): def register_tracker(self, tracker: Tracker) -> None: # noqa F821 """Registers a tracker. + Args: tracker (Tracker): The tracker. """ @@ -27,6 +28,7 @@ def register_tracker(self, tracker: Tracker) -> None: # noqa F821 def deregister_tracker(self, tracker: Tracker) -> None: # noqa F821 """Deregisters a tracker. + Args: tracker (Tracker): The tracker. """ @@ -34,13 +36,19 @@ def deregister_tracker(self, tracker: Tracker) -> None: # noqa F821 def broadcast_event(self, event: _TrackingEvent) -> None: # noqa F821 """Broadcasts an event to all trackers. + Args: event (_TrackingEvent): The event to broadcast. """ for tracker in self._trackers: tracker.receive_event(event) - def active_trackers(self) -> None: + def active_trackers(self) -> set: + """Gets the active trackers. + + Returns: + set: The set of active trackers. + """ return self._trackers diff --git a/test/integ_tests/conftest.py b/test/integ_tests/conftest.py index 3c9edd2f8..d187ca3c4 100644 --- a/test/integ_tests/conftest.py +++ b/test/integ_tests/conftest.py @@ -12,18 +12,74 @@ # language governing permissions and limitations under the License. import os +import random +import string import boto3 import pytest from botocore.exceptions import ClientError +from braket.aws.aws_device import AwsDevice +from braket.aws.aws_quantum_job import AwsQuantumJob from braket.aws.aws_session import AwsSession +SV1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/sv1" +DM1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/dm1" +TN1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/tn1" +SIMULATOR_ARNS = [SV1_ARN, DM1_ARN, TN1_ARN] + +job_complete_name = "".join(random.choices(string.ascii_lowercase + string.digits, k=12)) +job_fail_name = "".join(random.choices(string.ascii_lowercase + string.digits, k=12)) + + +def pytest_configure_node(node): + """xdist hook""" + node.workerinput["JOB_COMPLETED_NAME"] = job_complete_name + node.workerinput["JOB_FAILED_NAME"] = job_fail_name + if endpoint := os.getenv("BRAKET_ENDPOINT"): + node.workerinput["BRAKET_ENDPOINT"] = endpoint + node.workerinput["AWS_REGION"] = os.getenv("AWS_REGION") + + +def pytest_xdist_node_collection_finished(ids): + """Uses the pytest xdist hook to check whether tests with jobs are to be ran. + If they are, the first reporting worker sets a flag that it created the tests + to avoid concurrency limits. This is the first time in the pytest setup the + controller has all the tests to be ran from the worker nodes. + """ + run_jobs = any("job" in test for test in ids) + profile_name = os.environ["AWS_PROFILE"] + region_name = os.getenv("AWS_REGION") + aws_session = AwsSession( + boto3.session.Session(profile_name=profile_name, region_name=region_name) + ) + if run_jobs and os.getenv("JOBS_STARTED") is None and region_name != "eu-north-1": + AwsQuantumJob.create( + "arn:aws:braket:::device/quantum-simulator/amazon/sv1", + job_name=job_fail_name, + source_module="test/integ_tests/job_test_script.py", + entry_point="job_test_script:start_here", + aws_session=aws_session, + wait_until_complete=False, + hyperparameters={"test_case": "failed"}, + ) + AwsQuantumJob.create( + "arn:aws:braket:::device/quantum-simulator/amazon/sv1", + job_name=job_complete_name, + source_module="test/integ_tests/job_test_script.py", + entry_point="job_test_script:start_here", + aws_session=aws_session, + wait_until_complete=False, + hyperparameters={"test_case": "completed"}, + ) + os.environ["JOBS_STARTED"] = "True" + @pytest.fixture(scope="session") -def boto_session(): +def boto_session(request): profile_name = os.environ["AWS_PROFILE"] - return boto3.session.Session(profile_name=profile_name) + region_name = request.config.workerinput["AWS_REGION"] + return boto3.session.Session(profile_name=profile_name, region_name=region_name) @pytest.fixture(scope="session") @@ -82,3 +138,54 @@ def s3_prefix(): @pytest.fixture(scope="module") def s3_destination_folder(s3_bucket, s3_prefix): return AwsSession.S3DestinationFolder(s3_bucket, s3_prefix) + + +@pytest.fixture(scope="session") +def braket_simulators(aws_session): + return ( + {simulator_arn: AwsDevice(simulator_arn, aws_session) for simulator_arn in SIMULATOR_ARNS} + if aws_session.region != "eu-north-1" + else None + ) + + +@pytest.fixture(scope="session") +def braket_devices(): + return AwsDevice.get_devices(statuses=["RETIRED", "ONLINE", "OFFLINE"]) + + +@pytest.fixture(scope="session", autouse=True) +def created_braket_devices(aws_session, braket_devices): + return {device.arn: device for device in braket_devices} + + +@pytest.fixture(scope="session") +def job_completed_name(request): + return request.config.workerinput["JOB_COMPLETED_NAME"] + + +@pytest.fixture(scope="session") +def job_failed_name(request): + return request.config.workerinput["JOB_FAILED_NAME"] + + +@pytest.fixture(scope="session", autouse=True) +def completed_quantum_job(job_completed_name): + job_arn = [ + job["jobArn"] + for job in boto3.client("braket").search_jobs(filters=[])["jobs"] + if job["jobName"] == job_completed_name + ][0] + + return AwsQuantumJob(arn=job_arn) + + +@pytest.fixture(scope="session", autouse=True) +def failed_quantum_job(job_failed_name): + job_arn = [ + job["jobArn"] + for job in boto3.client("braket").search_jobs(filters=[])["jobs"] + if job["jobName"] == job_failed_name + ][0] + + return AwsQuantumJob(arn=job_arn) diff --git a/test/integ_tests/gate_model_device_testing_utils.py b/test/integ_tests/gate_model_device_testing_utils.py index 6ccd6f05b..901cc819e 100644 --- a/test/integ_tests/gate_model_device_testing_utils.py +++ b/test/integ_tests/gate_model_device_testing_utils.py @@ -13,7 +13,7 @@ import concurrent.futures import math -from typing import Any, Dict, Union +from typing import Any, Union import numpy as np @@ -26,11 +26,11 @@ from braket.tasks import GateModelQuantumTaskResult -def get_tol(shots: int) -> Dict[str, float]: - return {"atol": 0.1, "rtol": 0.15} if shots else {"atol": 0.01, "rtol": 0} +def get_tol(shots: int) -> dict[str, float]: + return {"atol": 0.2, "rtol": 0.3} if shots else {"atol": 0.01, "rtol": 0} -def qubit_ordering_testing(device: Device, run_kwargs: Dict[str, Any]): +def qubit_ordering_testing(device: Device, run_kwargs: dict[str, Any]): # |110> should get back value of "110" state_110 = Circuit().x(0).x(1).i(2) result = device.run(state_110, **run_kwargs).result() @@ -51,8 +51,8 @@ def qubit_ordering_testing(device: Device, run_kwargs: Dict[str, Any]): def no_result_types_testing( program: Union[Circuit, OpenQasmProgram], device: Device, - run_kwargs: Dict[str, Any], - expected: Dict[str, float], + run_kwargs: dict[str, Any], + expected: dict[str, float], ): shots = run_kwargs["shots"] tol = get_tol(shots) @@ -63,14 +63,14 @@ def no_result_types_testing( assert len(result.measurements) == shots -def no_result_types_bell_pair_testing(device: Device, run_kwargs: Dict[str, Any]): +def no_result_types_bell_pair_testing(device: Device, run_kwargs: dict[str, Any]): bell = Circuit().h(0).cnot(0, 1) bell_qasm = bell.to_ir(ir_type=IRType.OPENQASM) for task in (bell, bell_qasm): no_result_types_testing(task, device, run_kwargs, {"00": 0.5, "11": 0.5}) -def result_types_observable_not_in_instructions(device: Device, run_kwargs: Dict[str, Any]): +def result_types_observable_not_in_instructions(device: Device, run_kwargs: dict[str, Any]): shots = run_kwargs["shots"] tol = get_tol(shots) bell = ( @@ -81,8 +81,8 @@ def result_types_observable_not_in_instructions(device: Device, run_kwargs: Dict .variance(observable=Observable.Y(), target=[3]) ) bell_qasm = bell.to_ir(ir_type=IRType.OPENQASM) - for task in (bell, bell_qasm): - result = device.run(task, **run_kwargs).result() + results = device.run_batch([bell, bell_qasm], **run_kwargs).results() + for result in results: assert np.allclose(result.values[0], 0, **tol) assert np.allclose(result.values[1], 1, **tol) @@ -90,7 +90,7 @@ def result_types_observable_not_in_instructions(device: Device, run_kwargs: Dict def result_types_zero_shots_bell_pair_testing( device: Device, include_state_vector: bool, - run_kwargs: Dict[str, Any], + run_kwargs: dict[str, Any], include_amplitude: bool = True, ): circuit = ( @@ -103,9 +103,9 @@ def result_types_zero_shots_bell_pair_testing( circuit.amplitude(["01", "10", "00", "11"]) if include_state_vector: circuit.state_vector() - tasks = (circuit, circuit.to_ir(ir_type=IRType.OPENQASM)) - for task in tasks: - result = device.run(task, **run_kwargs).result() + tasks = [circuit, circuit.to_ir(ir_type=IRType.OPENQASM)] + results = device.run_batch(tasks, **run_kwargs).results() + for result in results: assert len(result.result_types) == 3 if include_state_vector else 2 assert np.allclose( result.get_value_by_result_type( @@ -128,7 +128,7 @@ def result_types_zero_shots_bell_pair_testing( assert np.isclose(amplitude["11"], 1 / np.sqrt(2)) -def result_types_bell_pair_full_probability_testing(device: Device, run_kwargs: Dict[str, Any]): +def result_types_bell_pair_full_probability_testing(device: Device, run_kwargs: dict[str, Any]): shots = run_kwargs["shots"] tol = get_tol(shots) circuit = Circuit().h(0).cnot(0, 1).probability() @@ -139,11 +139,11 @@ def result_types_bell_pair_full_probability_testing(device: Device, run_kwargs: assert np.allclose( result.get_value_by_result_type(ResultType.Probability()), np.array([0.5, 0, 0, 0.5]), - **tol + **tol, ) -def result_types_bell_pair_marginal_probability_testing(device: Device, run_kwargs: Dict[str, Any]): +def result_types_bell_pair_marginal_probability_testing(device: Device, run_kwargs: dict[str, Any]): shots = run_kwargs["shots"] tol = get_tol(shots) circuit = Circuit().h(0).cnot(0, 1).probability(0) @@ -154,11 +154,11 @@ def result_types_bell_pair_marginal_probability_testing(device: Device, run_kwar assert np.allclose( result.get_value_by_result_type(ResultType.Probability(target=0)), np.array([0.5, 0.5]), - **tol + **tol, ) -def result_types_nonzero_shots_bell_pair_testing(device: Device, run_kwargs: Dict[str, Any]): +def result_types_nonzero_shots_bell_pair_testing(device: Device, run_kwargs: dict[str, Any]): circuit = ( Circuit() .h(0) @@ -188,7 +188,7 @@ def result_types_nonzero_shots_bell_pair_testing(device: Device, run_kwargs: Dic def result_types_hermitian_testing( - device: Device, run_kwargs: Dict[str, Any], test_program: bool = True + device: Device, run_kwargs: dict[str, Any], test_program: bool = True ): shots = run_kwargs["shots"] theta = 0.543 @@ -202,7 +202,7 @@ def result_types_hermitian_testing( ) if shots: circuit.add_result_type(ResultType.Sample(Observable.Hermitian(array), 0)) - tasks = (circuit,) if not test_program else (circuit, circuit.to_ir(ir_type=IRType.OPENQASM)) + tasks = (circuit, circuit.to_ir(ir_type=IRType.OPENQASM)) if test_program else (circuit,) for task in tasks: result = device.run(task, **run_kwargs).result() @@ -215,7 +215,7 @@ def result_types_hermitian_testing( def result_types_all_selected_testing( - device: Device, run_kwargs: Dict[str, Any], test_program: bool = True + device: Device, run_kwargs: dict[str, Any], test_program: bool = True ): shots = run_kwargs["shots"] theta = 0.543 @@ -231,7 +231,7 @@ def result_types_all_selected_testing( if shots: circuit.add_result_type(ResultType.Sample(Observable.Hermitian(array), 1)) - tasks = (circuit,) if not test_program else (circuit, circuit.to_ir(ir_type=IRType.OPENQASM)) + tasks = (circuit, circuit.to_ir(ir_type=IRType.OPENQASM)) if test_program else (circuit,) for task in tasks: result = device.run(task, **run_kwargs).result() @@ -280,7 +280,7 @@ def assert_variance_expectation_sample_result( assert np.allclose(variance, expected_var, **tol) -def result_types_tensor_x_y_testing(device: Device, run_kwargs: Dict[str, Any]): +def result_types_tensor_x_y_testing(device: Device, run_kwargs: dict[str, Any]): shots = run_kwargs["shots"] theta = 0.432 phi = 0.123 @@ -308,7 +308,7 @@ def result_types_tensor_x_y_testing(device: Device, run_kwargs: Dict[str, Any]): ) -def result_types_tensor_z_z_testing(device: Device, run_kwargs: Dict[str, Any]): +def result_types_tensor_z_z_testing(device: Device, run_kwargs: dict[str, Any]): shots = run_kwargs["shots"] theta = 0.432 phi = 0.123 @@ -329,7 +329,7 @@ def result_types_tensor_z_z_testing(device: Device, run_kwargs: Dict[str, Any]): ) -def result_types_tensor_hermitian_hermitian_testing(device: Device, run_kwargs: Dict[str, Any]): +def result_types_tensor_hermitian_hermitian_testing(device: Device, run_kwargs: dict[str, Any]): shots = run_kwargs["shots"] theta = 0.432 phi = 0.123 @@ -359,7 +359,7 @@ def result_types_tensor_hermitian_hermitian_testing(device: Device, run_kwargs: ) -def result_types_tensor_z_h_y_testing(device: Device, run_kwargs: Dict[str, Any]): +def result_types_tensor_z_h_y_testing(device: Device, run_kwargs: dict[str, Any]): shots = run_kwargs["shots"] theta = 0.432 phi = 0.123 @@ -386,7 +386,7 @@ def result_types_tensor_z_h_y_testing(device: Device, run_kwargs: Dict[str, Any] ) -def result_types_tensor_z_hermitian_testing(device: Device, run_kwargs: Dict[str, Any]): +def result_types_tensor_z_hermitian_testing(device: Device, run_kwargs: dict[str, Any]): shots = run_kwargs["shots"] theta = 0.432 phi = 0.123 @@ -450,7 +450,7 @@ def result_types_tensor_z_hermitian_testing(device: Device, run_kwargs: Dict[str ) -def result_types_tensor_y_hermitian_testing(device: Device, run_kwargs: Dict[str, Any]): +def result_types_tensor_y_hermitian_testing(device: Device, run_kwargs: dict[str, Any]): shots = run_kwargs["shots"] theta = 0.432 phi = 0.123 @@ -479,7 +479,7 @@ def result_types_tensor_y_hermitian_testing(device: Device, run_kwargs: Dict[str ) -def result_types_noncommuting_testing(device: Device, run_kwargs: Dict[str, Any]): +def result_types_noncommuting_testing(device: Device, run_kwargs: dict[str, Any]): shots = 0 theta = 0.432 phi = 0.123 @@ -525,7 +525,7 @@ def result_types_noncommuting_testing(device: Device, run_kwargs: Dict[str, Any] assert np.allclose(result.values[3], expected_mean3) -def result_types_noncommuting_flipped_targets_testing(device: Device, run_kwargs: Dict[str, Any]): +def result_types_noncommuting_flipped_targets_testing(device: Device, run_kwargs: dict[str, Any]): circuit = ( Circuit() .h(0) @@ -540,7 +540,7 @@ def result_types_noncommuting_flipped_targets_testing(device: Device, run_kwargs assert np.allclose(result.values[1], np.sqrt(2) / 2) -def result_types_noncommuting_all(device: Device, run_kwargs: Dict[str, Any]): +def result_types_noncommuting_all(device: Device, run_kwargs: dict[str, Any]): array = np.array([[1, 2j], [-2j, 0]]) circuit = ( Circuit() @@ -556,7 +556,7 @@ def result_types_noncommuting_all(device: Device, run_kwargs: Dict[str, Any]): assert np.allclose(result.values[1], [0, 0]) -def multithreaded_bell_pair_testing(device: Device, run_kwargs: Dict[str, Any]): +def multithreaded_bell_pair_testing(device: Device, run_kwargs: dict[str, Any]): shots = run_kwargs["shots"] tol = get_tol(shots) bell = Circuit().h(0).cnot(0, 1) @@ -581,7 +581,7 @@ def run_circuit(circuit): assert len(result.measurements) == shots -def noisy_circuit_1qubit_noise_full_probability(device: Device, run_kwargs: Dict[str, Any]): +def noisy_circuit_1qubit_noise_full_probability(device: Device, run_kwargs: dict[str, Any]): shots = run_kwargs["shots"] tol = get_tol(shots) circuit = Circuit().x(0).x(1).bit_flip(0, 0.1).probability() @@ -592,11 +592,11 @@ def noisy_circuit_1qubit_noise_full_probability(device: Device, run_kwargs: Dict assert np.allclose( result.get_value_by_result_type(ResultType.Probability()), np.array([0.0, 0.1, 0, 0.9]), - **tol + **tol, ) -def noisy_circuit_2qubit_noise_full_probability(device: Device, run_kwargs: Dict[str, Any]): +def noisy_circuit_2qubit_noise_full_probability(device: Device, run_kwargs: dict[str, Any]): shots = run_kwargs["shots"] tol = get_tol(shots) K0 = np.eye(4) * np.sqrt(0.9) @@ -611,11 +611,11 @@ def noisy_circuit_2qubit_noise_full_probability(device: Device, run_kwargs: Dict assert np.allclose( result.get_value_by_result_type(ResultType.Probability()), np.array([0.1, 0.0, 0, 0.9]), - **tol + **tol, ) -def batch_bell_pair_testing(device: AwsDevice, run_kwargs: Dict[str, Any]): +def batch_bell_pair_testing(device: AwsDevice, run_kwargs: dict[str, Any]): shots = run_kwargs["shots"] tol = get_tol(shots) circuits = [Circuit().h(0).cnot(0, 1) for _ in range(10)] @@ -630,7 +630,7 @@ def batch_bell_pair_testing(device: AwsDevice, run_kwargs: Dict[str, Any]): assert [task.result() for task in batch.tasks] == results -def bell_pair_openqasm_testing(device: AwsDevice, run_kwargs: Dict[str, Any]): +def bell_pair_openqasm_testing(device: AwsDevice, run_kwargs: dict[str, Any]): openqasm_string = ( "OPENQASM 3;" "qubit[2] q;" @@ -649,7 +649,7 @@ def bell_pair_openqasm_testing(device: AwsDevice, run_kwargs: Dict[str, Any]): def openqasm_noisy_circuit_1qubit_noise_full_probability( - device: Device, run_kwargs: Dict[str, Any] + device: Device, run_kwargs: dict[str, Any] ): shots = run_kwargs["shots"] tol = get_tol(shots) @@ -671,11 +671,11 @@ def openqasm_noisy_circuit_1qubit_noise_full_probability( assert np.allclose( result.get_value_by_result_type(ResultType.Probability(target=[0, 1])), np.array([0.0, 0.1, 0, 0.9]), - **tol + **tol, ) -def openqasm_result_types_bell_pair_testing(device: Device, run_kwargs: Dict[str, Any]): +def openqasm_result_types_bell_pair_testing(device: Device, run_kwargs: dict[str, Any]): openqasm_string = ( "OPENQASM 3;" "qubit[2] q;" diff --git a/test/integ_tests/job_test_script.py b/test/integ_tests/job_test_script.py index 95b890d60..d2a74e6cc 100644 --- a/test/integ_tests/job_test_script.py +++ b/test/integ_tests/job_test_script.py @@ -43,8 +43,8 @@ def completed_job_script(): device = AwsDevice(get_job_device_arn()) bell = Circuit().h(0).cnot(0, 1) - for count in range(5): - task = device.run(bell, shots=100) + for _ in range(3): + task = device.run(bell, shots=10) print(task.result().measurement_counts) save_job_result({"converged": True, "energy": -0.2}) save_job_checkpoint({"some_data": "abc"}, checkpoint_file_suffix="plain_data") diff --git a/test/integ_tests/job_testing_utils.py b/test/integ_tests/job_testing_utils.py new file mode 100644 index 000000000..4493df180 --- /dev/null +++ b/test/integ_tests/job_testing_utils.py @@ -0,0 +1,25 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import re + +from braket.aws import AwsSession +from braket.jobs import Framework, retrieve_image + + +def decorator_python_version(): + aws_session = AwsSession() + image_uri = retrieve_image(Framework.BASE, aws_session.region) + tag = aws_session.get_full_image_tag(image_uri) + major_version, minor_version = re.search(r"-py(\d)(\d+)-", tag).groups() + return int(major_version), int(minor_version) diff --git a/test/integ_tests/test_cost_tracking.py b/test/integ_tests/test_cost_tracking.py index 7e97c6f78..0c06f297a 100644 --- a/test/integ_tests/test_cost_tracking.py +++ b/test/integ_tests/test_cost_tracking.py @@ -18,14 +18,11 @@ import pytest from botocore.exceptions import ClientError -from braket.aws import AwsDevice, AwsSession +from braket.aws import AwsDevice, AwsDeviceType, AwsSession from braket.circuits import Circuit -from braket.devices import Devices from braket.tracking import Tracker from braket.tracking.tracker import MIN_SIMULATOR_DURATION -_RESERVATION_ONLY_DEVICES = {Devices.IonQ.Forte1} - @pytest.mark.parametrize( "qpu", @@ -94,24 +91,28 @@ def test_all_devices_price_search(): tasks = {} for region in AwsDevice.REGIONS: s = AwsSession(boto3.Session(region_name=region)) - for device in [device for device in devices if device.arn not in _RESERVATION_ONLY_DEVICES]: - try: - s.get_device(device.arn) - - # If we are here, device can create tasks in region - details = { - "shots": 100, - "device": device.arn, - "billed_duration": MIN_SIMULATOR_DURATION, - "job_task": False, - "status": "COMPLETED", - } - tasks[f"task:for:{device.name}:{region}"] = details.copy() - details["job_task"] = True - tasks[f"jobtask:for:{device.name}:{region}"] = details - except s.braket_client.exceptions.ResourceNotFoundException: - # device does not exist in region, so nothing to test + # Skip devices with empty execution windows + for device in [device for device in devices if device.properties.service.executionWindows]: + if region == "eu-north-1" and device.type == AwsDeviceType.SIMULATOR: pass + else: + try: + s.get_device(device.arn) + + # If we are here, device can create tasks in region + details = { + "shots": 100, + "device": device.arn, + "billed_duration": MIN_SIMULATOR_DURATION, + "job_task": False, + "status": "COMPLETED", + } + tasks[f"task:for:{device.name}:{region}"] = details.copy() + details["job_task"] = True + tasks[f"jobtask:for:{device.name}:{region}"] = details + except s.braket_client.exceptions.ResourceNotFoundException: + # device does not exist in region, so nothing to test + pass t = Tracker() t._resources = tasks diff --git a/test/integ_tests/test_create_local_quantum_job.py b/test/integ_tests/test_create_local_quantum_job.py index ad91d0a03..16c001f35 100644 --- a/test/integ_tests/test_create_local_quantum_job.py +++ b/test/integ_tests/test_create_local_quantum_job.py @@ -81,7 +81,7 @@ def test_completed_local_job(aws_session, capsys): }, ), ]: - with open(file_name, "r") as f: + with open(file_name) as f: assert json.loads(f.read()) == expected_data # Capture logs @@ -99,6 +99,7 @@ def test_completed_local_job(aws_session, capsys): for data in logs_to_validate: assert data in log_data + finally: os.chdir(current_dir) diff --git a/test/integ_tests/test_create_quantum_job.py b/test/integ_tests/test_create_quantum_job.py index 02c16313b..ce88d122b 100644 --- a/test/integ_tests/test_create_quantum_job.py +++ b/test/integ_tests/test_create_quantum_job.py @@ -16,47 +16,37 @@ import re import sys import tempfile +import time from pathlib import Path import job_test_script import pytest from job_test_module.job_test_submodule.job_test_submodule_file import submodule_helper +from job_testing_utils import decorator_python_version -from braket.aws import AwsSession from braket.aws.aws_quantum_job import AwsQuantumJob from braket.devices import Devices -from braket.jobs import Framework, get_input_data_dir, hybrid_job, retrieve_image, save_job_result +from braket.jobs import get_input_data_dir, hybrid_job, save_job_result -def decorator_python_version(): - aws_session = AwsSession() - image_uri = retrieve_image(Framework.BASE, aws_session.region) - tag = aws_session.get_full_image_tag(image_uri) - major_version, minor_version = re.search(r"-py(\d)(\d+)-", tag).groups() - return int(major_version), int(minor_version) - - -def test_failed_quantum_job(aws_session, capsys): +def test_failed_quantum_job(aws_session, capsys, failed_quantum_job): """Asserts the hybrid job is failed with the output, checkpoints, quantum tasks not created in bucket and only input is uploaded to s3. Validate the results/download results have the response raising RuntimeError. Also, check if the logs displays the Assertion Error. """ - - job = AwsQuantumJob.create( - "arn:aws:braket:::device/quantum-simulator/amazon/sv1", - source_module="test/integ_tests/job_test_script.py", - entry_point="job_test_script:start_here", - aws_session=aws_session, - wait_until_complete=True, - hyperparameters={"test_case": "failed"}, - ) + job = failed_quantum_job + job_name = job.name pattern = f"^arn:aws:braket:{aws_session.region}:\\d{{12}}:job/[a-z0-9-]+$" assert re.match(pattern=pattern, string=job.arn) # Check job is in failed state. - assert job.state() == "FAILED" + while True: + time.sleep(5) + if job.state() in AwsQuantumJob.TERMINAL_STATES: + break + assert job.state(use_cached_value=True) == "FAILED" # Check whether the respective folder with files are created for script, # output, tasks and checkpoints. @@ -65,7 +55,7 @@ def test_failed_quantum_job(aws_session, capsys): subdirectory = re.match( rf"s3://{s3_bucket}/jobs/{job.name}/(\d+)/script/source.tar.gz", job.metadata()["algorithmSpecification"]["scriptModeConfig"]["s3Uri"], - ).group(1) + )[1] keys = aws_session.list_keys( bucket=s3_bucket, prefix=f"jobs/{job_name}/{subdirectory}/", @@ -97,27 +87,22 @@ def test_failed_quantum_job(aws_session, capsys): ) -def test_completed_quantum_job(aws_session, capsys): +def test_completed_quantum_job(aws_session, capsys, completed_quantum_job): """Asserts the hybrid job is completed with the output, checkpoints, quantum tasks and script folder created in S3 for respective hybrid job. Validate the results are downloaded and results are what we expect. Also, assert that logs contains all the necessary steps for setup and running the hybrid job and is displayed to the user. """ - job = AwsQuantumJob.create( - "arn:aws:braket:::device/quantum-simulator/amazon/sv1", - source_module="test/integ_tests/job_test_script.py", - entry_point="job_test_script:start_here", - wait_until_complete=True, - aws_session=aws_session, - hyperparameters={"test_case": "completed"}, - ) - + job = completed_quantum_job + job_name = job.name pattern = f"^arn:aws:braket:{aws_session.region}:\\d{{12}}:job/[a-z0-9-]+$" assert re.match(pattern=pattern, string=job.arn) - # check job is in completed state. - assert job.state() == "COMPLETED" + # Check the job has completed + job.result() + + assert job.state(use_cached_value=True) == "COMPLETED" # Check whether the respective folder with files are created for script, # output, tasks and checkpoints. @@ -126,7 +111,7 @@ def test_completed_quantum_job(aws_session, capsys): subdirectory = re.match( rf"s3://{s3_bucket}/jobs/{job.name}/(\d+)/script/source.tar.gz", job.metadata()["algorithmSpecification"]["scriptModeConfig"]["s3Uri"], - ).group(1) + )[1] keys = aws_session.list_keys( bucket=s3_bucket, prefix=f"jobs/{job_name}/{subdirectory}/", @@ -179,19 +164,11 @@ def test_completed_quantum_job(aws_session, capsys): == expected_data ) - # Check downloaded results exists in the file system after the call. - downloaded_result = f"{job_name}/{AwsQuantumJob.RESULTS_FILENAME}" current_dir = Path.cwd() with tempfile.TemporaryDirectory() as temp_dir: os.chdir(temp_dir) try: - job.download_result() - assert ( - Path(AwsQuantumJob.RESULTS_TAR_FILENAME).exists() - and Path(downloaded_result).exists() - ) - # Check results match the expectations. assert job.result() == {"converged": True, "energy": -0.2} finally: @@ -235,9 +212,9 @@ def __str__(self): input_data=str(Path("test", "integ_tests", "requirements")), ) def decorator_job(a, b: int, c=0, d: float = 1.0, **extras): - with open(Path(get_input_data_dir()) / "requirements.txt", "r") as f: + with open(Path(get_input_data_dir()) / "requirements.txt") as f: assert f.readlines() == ["pytest\n"] - with open(Path("test", "integ_tests", "requirements.txt"), "r") as f: + with open(Path("test", "integ_tests", "requirements.txt")) as f: assert f.readlines() == ["pytest\n"] assert dir(pytest) assert a.attribute == "value" @@ -247,7 +224,7 @@ def decorator_job(a, b: int, c=0, d: float = 1.0, **extras): assert extras["extra_arg"] == "extra_value" hp_file = os.environ["AMZN_BRAKET_HP_FILE"] - with open(hp_file, "r") as f: + with open(hp_file) as f: hyperparameters = json.load(f) assert hyperparameters == { "a": "MyClass{value}", @@ -270,7 +247,7 @@ def decorator_job(a, b: int, c=0, d: float = 1.0, **extras): os.chdir(temp_dir) try: job.download_result() - with open(Path(job.name, "test", "output_file.txt"), "r") as f: + with open(Path(job.name, "test", "output_file.txt")) as f: assert f.read() == "hello" assert ( Path(job.name, "results.json").exists() @@ -301,12 +278,12 @@ def test_decorator_job_submodule(): }, ) def decorator_job_submodule(): - with open(Path(get_input_data_dir("my_input")) / "requirements.txt", "r") as f: + with open(Path(get_input_data_dir("my_input")) / "requirements.txt") as f: assert f.readlines() == ["pytest\n"] - with open(Path("test", "integ_tests", "requirements.txt"), "r") as f: + with open(Path("test", "integ_tests", "requirements.txt")) as f: assert f.readlines() == ["pytest\n"] with open( - Path(get_input_data_dir("my_dir")) / "job_test_submodule" / "requirements.txt", "r" + Path(get_input_data_dir("my_dir")) / "job_test_submodule" / "requirements.txt" ) as f: assert f.readlines() == ["pytest\n"] with open( @@ -317,7 +294,6 @@ def decorator_job_submodule(): "job_test_submodule", "requirements.txt", ), - "r", ) as f: assert f.readlines() == ["pytest\n"] assert dir(pytest) diff --git a/test/integ_tests/test_density_matrix_simulator.py b/test/integ_tests/test_density_matrix_simulator.py index b377bdebe..1fba5ebcb 100644 --- a/test/integ_tests/test_density_matrix_simulator.py +++ b/test/integ_tests/test_density_matrix_simulator.py @@ -7,7 +7,7 @@ from braket.aws import AwsDevice from braket.circuits import Circuit, Noise, Observable -SHOTS = 1000 +SHOTS = 500 DM1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/dm1" SIMULATOR_ARNS = [DM1_ARN] diff --git a/test/integ_tests/test_device_creation.py b/test/integ_tests/test_device_creation.py index 4cb7de2b1..540c09f61 100644 --- a/test/integ_tests/test_device_creation.py +++ b/test/integ_tests/test_device_creation.py @@ -11,7 +11,6 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from typing import List, Set import pytest @@ -28,8 +27,8 @@ @pytest.mark.parametrize( "arn", [(RIGETTI_ARN), (IONQ_ARN), (OQC_ARN), (SIMULATOR_ARN), (PULSE_ARN)] ) -def test_device_creation(arn, aws_session): - device = AwsDevice(arn, aws_session=aws_session) +def test_device_creation(arn, created_braket_devices): + device = created_braket_devices[arn] assert device.arn == arn assert device.name assert device.status @@ -39,17 +38,17 @@ def test_device_creation(arn, aws_session): @pytest.mark.parametrize("arn", [(PULSE_ARN)]) -def test_device_pulse_properties(arn, aws_session): - device = AwsDevice(arn, aws_session=aws_session) +def test_device_pulse_properties(arn, aws_session, created_braket_devices): + device = created_braket_devices[arn] assert device.ports assert device.frames -def test_device_across_regions(aws_session): +def test_device_across_regions(aws_session, created_braket_devices): # assert QPUs across different regions can be created using the same aws_session - AwsDevice(RIGETTI_ARN, aws_session) - AwsDevice(IONQ_ARN, aws_session) - AwsDevice(OQC_ARN, aws_session) + created_braket_devices[RIGETTI_ARN] + created_braket_devices[IONQ_ARN] + created_braket_devices[OQC_ARN] @pytest.mark.parametrize("arn", [(RIGETTI_ARN), (IONQ_ARN), (OQC_ARN), (SIMULATOR_ARN)]) @@ -59,8 +58,8 @@ def test_get_devices_arn(arn): @pytest.mark.parametrize("arn", [(PULSE_ARN)]) -def test_device_gate_calibrations(arn, aws_session): - device = AwsDevice(arn, aws_session=aws_session) +def test_device_gate_calibrations(arn, aws_session, created_braket_devices): + device = created_braket_devices[arn] assert device.gate_calibrations @@ -76,8 +75,8 @@ def test_get_devices_others(): assert result.status in statuses -def test_get_devices_all(): - result_arns = [result.arn for result in AwsDevice.get_devices()] +def test_get_devices_all(braket_devices): + result_arns = [result.arn for result in braket_devices] for arn in [RIGETTI_ARN, IONQ_ARN, SIMULATOR_ARN, OQC_ARN]: assert arn in result_arns @@ -108,15 +107,14 @@ def _get_device_name(device: AwsDevice) -> str: return device_name -def _get_active_providers(aws_devices: List[AwsDevice]) -> Set[str]: - active_providers = set() - for device in aws_devices: - if device.status != "RETIRED": - active_providers.add(_get_provider_name(device)) +def _get_active_providers(aws_devices: list[AwsDevice]) -> set[str]: + active_providers = { + _get_provider_name(device) for device in aws_devices if device.status != "RETIRED" + } return active_providers -def _validate_device(device: AwsDevice, active_providers: Set[str]): +def _validate_device(device: AwsDevice, active_providers: set[str]): provider_name = _get_provider_name(device) if provider_name not in active_providers: provider_name = f"_{provider_name}" @@ -127,17 +125,16 @@ def _validate_device(device: AwsDevice, active_providers: Set[str]): assert getattr(getattr(Devices, provider_name), device_name) == device.arn -def test_device_enum(): - aws_devices = AwsDevice.get_devices() - active_providers = _get_active_providers(aws_devices) +def test_device_enum(braket_devices, created_braket_devices): + active_providers = _get_active_providers(braket_devices) # validate all devices in API - for device in aws_devices: + for device in braket_devices: _validate_device(device, active_providers) # validate all devices in enum providers = [getattr(Devices, attr) for attr in dir(Devices) if not attr.startswith("__")] for provider in providers: for device_arn in provider: - device = AwsDevice(device_arn) + device = created_braket_devices[device_arn] _validate_device(device, active_providers) diff --git a/test/integ_tests/test_measure.py b/test/integ_tests/test_measure.py new file mode 100644 index 000000000..b7fef275b --- /dev/null +++ b/test/integ_tests/test_measure.py @@ -0,0 +1,85 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import re + +import pytest +from botocore.exceptions import ClientError + +from braket.aws.aws_device import AwsDevice +from braket.circuits.circuit import Circuit +from braket.devices import LocalSimulator + +DEVICE = LocalSimulator() +SHOTS = 8000 + +IONQ_ARN = "arn:aws:braket:us-east-1::device/qpu/ionq/Harmony" +SIMULATOR_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/sv1" +OQC_ARN = "arn:aws:braket:eu-west-2::device/qpu/oqc/Lucy" + + +@pytest.mark.parametrize("arn", [(IONQ_ARN), (SIMULATOR_ARN)]) +def test_unsupported_devices(arn): + device = AwsDevice(arn) + if device.status == "OFFLINE": + pytest.skip("Device offline") + + circ = Circuit().h(0).cnot(0, 1).h(2).measure([0, 1]) + error_string = re.escape( + "An error occurred (ValidationException) when calling the " + "CreateQuantumTask operation: Device requires all qubits in the program to be measured. " + "This may be caused by declaring non-contiguous qubits or measuring partial qubits" + ) + with pytest.raises(ClientError, match=error_string): + device.run(circ, shots=1000) + + +@pytest.mark.parametrize("sim", [("braket_sv"), ("braket_dm")]) +def test_measure_on_local_sim(sim): + circ = Circuit().h(0).cnot(0, 1).h(2).measure([0, 1]) + device = LocalSimulator(sim) + result = device.run(circ, SHOTS).result() + assert len(result.measurements[0]) == 2 + assert result.measured_qubits == [0, 1] + + +@pytest.mark.parametrize("arn", [(OQC_ARN)]) +def test_measure_on_supported_devices(arn): + device = AwsDevice(arn) + if not device.is_available: + pytest.skip("Device offline") + circ = Circuit().h(0).cnot(0, 1).measure([0]) + result = device.run(circ, SHOTS).result() + assert len(result.measurements[0]) == 1 + assert result.measured_qubits == [0] + + +@pytest.mark.parametrize( + "circuit, expected_measured_qubits", + [ + (Circuit().h(0).cnot(0, 1).cnot(1, 2).cnot(2, 3).measure([0, 1, 3]), [0, 1, 3]), + (Circuit().h(0).measure(0), [0]), + ], +) +def test_measure_targets(circuit, expected_measured_qubits): + result = DEVICE.run(circuit, SHOTS).result() + assert result.measured_qubits == expected_measured_qubits + assert len(result.measurements[0]) == len(expected_measured_qubits) + + +def test_measure_with_noise(): + device = LocalSimulator("braket_dm") + circuit = Circuit().x(0).x(1).bit_flip(0, probability=0.1).measure(0) + result = device.run(circuit, SHOTS).result() + assert result.measured_qubits == [0] + assert len(result.measurements[0]) == 1 diff --git a/test/integ_tests/test_queue_information.py b/test/integ_tests/test_queue_information.py index 3398fde40..7e7590a80 100644 --- a/test/integ_tests/test_queue_information.py +++ b/test/integ_tests/test_queue_information.py @@ -11,7 +11,8 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from braket.aws import AwsDevice, AwsQuantumJob + +from braket.aws import AwsDevice from braket.aws.queue_information import ( HybridJobQueueInfo, QuantumTaskQueueInfo, @@ -47,15 +48,11 @@ def test_task_queue_position(): assert queue_information.message is None -def test_job_queue_position(aws_session): - job = AwsQuantumJob.create( - device=Devices.Amazon.SV1, - source_module="test/integ_tests/job_test_script.py", - entry_point="job_test_script:start_here", - aws_session=aws_session, - wait_until_complete=True, - hyperparameters={"test_case": "completed"}, - ) +def test_job_queue_position(aws_session, completed_quantum_job): + job = completed_quantum_job + + # Check the job is complete + job.result() # call the queue_position method. queue_information = job.queue_position() diff --git a/test/integ_tests/test_reservation_arn.py b/test/integ_tests/test_reservation_arn.py index e0736f802..64135f76e 100644 --- a/test/integ_tests/test_reservation_arn.py +++ b/test/integ_tests/test_reservation_arn.py @@ -15,10 +15,9 @@ import pytest from botocore.exceptions import ClientError -from test_create_quantum_job import decorator_python_version +from job_testing_utils import decorator_python_version -from braket.aws import AwsDevice -from braket.aws.aws_quantum_job import AwsQuantumJob +from braket.aws import AwsDevice, DirectReservation from braket.circuits import Circuit from braket.devices import Devices from braket.jobs import get_job_device_arn, hybrid_job @@ -37,11 +36,11 @@ def test_create_task_via_invalid_reservation_arn_on_qpu(reservation_arn): device = AwsDevice(Devices.IonQ.Harmony) with pytest.raises(ClientError, match="Reservation arn is invalid"): - device.run( - circuit, - shots=10, - reservation_arn=reservation_arn, - ) + device.run(circuit, shots=10, reservation_arn=reservation_arn) + + with pytest.raises(ClientError, match="Reservation arn is invalid"): + with DirectReservation(device, reservation_arn=reservation_arn): + device.run(circuit, shots=10) def test_create_task_via_reservation_arn_on_simulator(reservation_arn): @@ -49,24 +48,11 @@ def test_create_task_via_reservation_arn_on_simulator(reservation_arn): device = AwsDevice(Devices.Amazon.SV1) with pytest.raises(ClientError, match="Braket Direct is not supported for"): - device.run( - circuit, - shots=10, - reservation_arn=reservation_arn, - ) + device.run(circuit, shots=10, reservation_arn=reservation_arn) - -def test_create_job_via_invalid_reservation_arn_on_qpu(aws_session, reservation_arn): - with pytest.raises(ClientError, match="Reservation arn is invalid"): - AwsQuantumJob.create( - device=Devices.IonQ.Harmony, - source_module="test/integ_tests/job_test_script.py", - entry_point="job_test_script:start_here", - wait_until_complete=True, - aws_session=aws_session, - hyperparameters={"test_case": "completed"}, - reservation_arn=reservation_arn, - ) + with pytest.raises(ClientError, match="Braket Direct is not supported for"): + with DirectReservation(device, reservation_arn=reservation_arn): + device.run(circuit, shots=10) @pytest.mark.xfail( diff --git a/test/integ_tests/test_simulator_quantum_task.py b/test/integ_tests/test_simulator_quantum_task.py index 7b9e7d208..0dd4f3f30 100644 --- a/test/integ_tests/test_simulator_quantum_task.py +++ b/test/integ_tests/test_simulator_quantum_task.py @@ -46,7 +46,7 @@ # shots-based tests in this file have the capacity to fail rarely due to probabilistic checks. # this parameter can be adjusted if we find tests regularly failing. -SHOTS = 9000 +SHOTS = 5000 SV1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/sv1" DM1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/dm1" SIMULATOR_ARNS = [SV1_ARN, DM1_ARN] @@ -54,16 +54,18 @@ @pytest.mark.parametrize("simulator_arn", SIMULATOR_ARNS) -def test_no_result_types_bell_pair(simulator_arn, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_no_result_types_bell_pair( + simulator_arn, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] no_result_types_bell_pair_testing( device, {"shots": SHOTS, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn", SIMULATOR_ARNS) -def test_qubit_ordering(simulator_arn, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_qubit_ordering(simulator_arn, aws_session, s3_destination_folder, braket_simulators): + device = braket_simulators[simulator_arn] qubit_ordering_testing(device, {"shots": SHOTS, "s3_destination_folder": s3_destination_folder}) @@ -71,9 +73,9 @@ def test_qubit_ordering(simulator_arn, aws_session, s3_destination_folder): "simulator_arn, include_amplitude", list(zip(SIMULATOR_ARNS, [True, False])) ) def test_result_types_no_shots( - simulator_arn, include_amplitude, aws_session, s3_destination_folder + simulator_arn, include_amplitude, aws_session, s3_destination_folder, braket_simulators ): - device = AwsDevice(simulator_arn, aws_session) + device = braket_simulators[simulator_arn] result_types_zero_shots_bell_pair_testing( device, False, @@ -83,16 +85,20 @@ def test_result_types_no_shots( @pytest.mark.parametrize("simulator_arn", SIMULATOR_ARNS) -def test_result_types_nonzero_shots_bell_pair(simulator_arn, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_result_types_nonzero_shots_bell_pair( + simulator_arn, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] result_types_nonzero_shots_bell_pair_testing( device, {"shots": SHOTS, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn", SIMULATOR_ARNS) -def test_result_types_bell_pair_full_probability(simulator_arn, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_result_types_bell_pair_full_probability( + simulator_arn, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] result_types_bell_pair_full_probability_testing( device, {"shots": SHOTS, "s3_destination_folder": s3_destination_folder} ) @@ -100,41 +106,49 @@ def test_result_types_bell_pair_full_probability(simulator_arn, aws_session, s3_ @pytest.mark.parametrize("simulator_arn", SIMULATOR_ARNS) def test_result_types_bell_pair_marginal_probability( - simulator_arn, aws_session, s3_destination_folder + simulator_arn, aws_session, s3_destination_folder, braket_simulators ): - device = AwsDevice(simulator_arn, aws_session) + device = braket_simulators[simulator_arn] result_types_bell_pair_marginal_probability_testing( device, {"shots": SHOTS, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn,shots", ARNS_WITH_SHOTS) -def test_result_types_tensor_x_y(simulator_arn, shots, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_result_types_tensor_x_y( + simulator_arn, shots, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] result_types_tensor_x_y_testing( device, {"shots": shots, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn,shots", ARNS_WITH_SHOTS) -def test_result_types_tensor_z_h_y(simulator_arn, shots, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_result_types_tensor_z_h_y( + simulator_arn, shots, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] result_types_tensor_z_h_y_testing( device, {"shots": shots, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn,shots", ARNS_WITH_SHOTS) -def test_result_types_hermitian(simulator_arn, shots, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_result_types_hermitian( + simulator_arn, shots, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] result_types_hermitian_testing( device, {"shots": shots, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn,shots", ARNS_WITH_SHOTS) -def test_result_types_tensor_z_z(simulator_arn, shots, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_result_types_tensor_z_z( + simulator_arn, shots, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] result_types_tensor_z_z_testing( device, {"shots": shots, "s3_destination_folder": s3_destination_folder} ) @@ -142,103 +156,117 @@ def test_result_types_tensor_z_z(simulator_arn, shots, aws_session, s3_destinati @pytest.mark.parametrize("simulator_arn,shots", ARNS_WITH_SHOTS) def test_result_types_tensor_hermitian_hermitian( - simulator_arn, shots, aws_session, s3_destination_folder + simulator_arn, shots, aws_session, s3_destination_folder, braket_simulators ): - device = AwsDevice(simulator_arn, aws_session) + device = braket_simulators[simulator_arn] result_types_tensor_hermitian_hermitian_testing( device, {"shots": shots, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn,shots", ARNS_WITH_SHOTS) -def test_result_types_tensor_y_hermitian(simulator_arn, shots, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_result_types_tensor_y_hermitian( + simulator_arn, shots, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] result_types_tensor_y_hermitian_testing( device, {"shots": shots, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn,shots", ARNS_WITH_SHOTS) -def test_result_types_tensor_z_hermitian(simulator_arn, shots, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_result_types_tensor_z_hermitian( + simulator_arn, shots, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] result_types_tensor_z_hermitian_testing( device, {"shots": shots, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn,shots", ARNS_WITH_SHOTS) -def test_result_types_all_selected(simulator_arn, shots, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_result_types_all_selected( + simulator_arn, shots, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] result_types_all_selected_testing( device, {"shots": shots, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn", SIMULATOR_ARNS) -def test_result_types_noncommuting(simulator_arn, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_result_types_noncommuting( + simulator_arn, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] result_types_noncommuting_testing(device, {"s3_destination_folder": s3_destination_folder}) @pytest.mark.parametrize("simulator_arn", SIMULATOR_ARNS) def test_result_types_noncommuting_flipped_targets( - simulator_arn, aws_session, s3_destination_folder + simulator_arn, aws_session, s3_destination_folder, braket_simulators ): - device = AwsDevice(simulator_arn, aws_session) + device = braket_simulators[simulator_arn] result_types_noncommuting_flipped_targets_testing( device, {"s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn", SIMULATOR_ARNS) -def test_result_types_noncommuting_all(simulator_arn, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_result_types_noncommuting_all( + simulator_arn, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] result_types_noncommuting_all(device, {"s3_destination_folder": s3_destination_folder}) @pytest.mark.parametrize("simulator_arn,shots", ARNS_WITH_SHOTS) def test_result_types_observable_not_in_instructions( - simulator_arn, shots, aws_session, s3_destination_folder + simulator_arn, shots, aws_session, s3_destination_folder, braket_simulators ): - device = AwsDevice(simulator_arn, aws_session) + device = braket_simulators[simulator_arn] result_types_observable_not_in_instructions( device, {"shots": shots, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn", SIMULATOR_ARNS) -def test_multithreaded_bell_pair(simulator_arn, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_multithreaded_bell_pair( + simulator_arn, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] multithreaded_bell_pair_testing( device, {"shots": SHOTS, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn", SIMULATOR_ARNS) -def test_batch_bell_pair(simulator_arn, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_batch_bell_pair(simulator_arn, aws_session, s3_destination_folder, braket_simulators): + device = braket_simulators[simulator_arn] batch_bell_pair_testing( device, {"shots": SHOTS, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn", SIMULATOR_ARNS) -def test_bell_pair_openqasm(simulator_arn, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_bell_pair_openqasm(simulator_arn, aws_session, s3_destination_folder, braket_simulators): + device = braket_simulators[simulator_arn] bell_pair_openqasm_testing( device, {"shots": SHOTS, "s3_destination_folder": s3_destination_folder} ) @pytest.mark.parametrize("simulator_arn", SIMULATOR_ARNS) -def test_bell_pair_openqasm_results(simulator_arn, aws_session, s3_destination_folder): - device = AwsDevice(simulator_arn, aws_session) +def test_bell_pair_openqasm_results( + simulator_arn, aws_session, s3_destination_folder, braket_simulators +): + device = braket_simulators[simulator_arn] openqasm_result_types_bell_pair_testing( device, {"shots": SHOTS, "s3_destination_folder": s3_destination_folder} ) -def test_openqasm_probability_results(aws_session, s3_destination_folder): +def test_openqasm_probability_results(aws_session, s3_destination_folder, braket_simulators): device = AwsDevice("arn:aws:braket:::device/quantum-simulator/amazon/dm1", aws_session) openqasm_noisy_circuit_1qubit_noise_full_probability( device, {"shots": SHOTS, "s3_destination_folder": s3_destination_folder} @@ -247,10 +275,12 @@ def test_openqasm_probability_results(aws_session, s3_destination_folder): @pytest.mark.parametrize("simulator_arn", SIMULATOR_ARNS) @pytest.mark.parametrize("num_layers", [50, 100, 500, 1000]) -def test_many_layers(simulator_arn, num_layers, aws_session, s3_destination_folder): +def test_many_layers( + simulator_arn, num_layers, aws_session, s3_destination_folder, braket_simulators +): num_qubits = 10 circuit = many_layers(num_qubits, num_layers) - device = AwsDevice(simulator_arn, aws_session) + device = braket_simulators[simulator_arn] tol = get_tol(SHOTS) result = device.run(circuit, shots=SHOTS, s3_destination_folder=s3_destination_folder).result() diff --git a/test/integ_tests/test_tensor_network_simulator.py b/test/integ_tests/test_tensor_network_simulator.py index 093781b62..5c406eadd 100644 --- a/test/integ_tests/test_tensor_network_simulator.py +++ b/test/integ_tests/test_tensor_network_simulator.py @@ -20,7 +20,7 @@ from braket.aws import AwsDevice from braket.circuits import Circuit -SHOTS = 1000 +SHOTS = 100 TN1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/tn1" SIMULATOR_ARNS = [TN1_ARN] diff --git a/test/unit_tests/braket/ahs/test_analog_hamiltonian_simulation.py b/test/unit_tests/braket/ahs/test_analog_hamiltonian_simulation.py index bdc3e92f2..83178c120 100644 --- a/test/unit_tests/braket/ahs/test_analog_hamiltonian_simulation.py +++ b/test/unit_tests/braket/ahs/test_analog_hamiltonian_simulation.py @@ -23,7 +23,7 @@ AtomArrangement, DiscretizationError, DrivingField, - ShiftingField, + LocalDetuning, SiteType, ) from braket.ahs.atom_arrangement import AtomArrangementItem @@ -61,8 +61,8 @@ def driving_field(): @pytest.fixture -def shifting_field(): - return ShiftingField( +def local_detuning(): + return LocalDetuning( Field( TimeSeries().put(0.0, -1.25664e8).put(3.0e-6, 1.25664e8), Pattern([0.5, 1.0, 0.5, 0.5, 0.5, 0.5]), @@ -78,8 +78,8 @@ def test_create(): assert mock1 == ahs.hamiltonian -def test_to_ir(register, driving_field, shifting_field): - hamiltonian = driving_field + shifting_field +def test_to_ir(register, driving_field, local_detuning): + hamiltonian = driving_field + local_detuning ahs = AnalogHamiltonianSimulation(register=register, hamiltonian=hamiltonian) problem = ahs.to_ir() assert Program.parse_raw(problem.json()) == problem @@ -123,8 +123,8 @@ def test_invalid_action_name(): AnalogHamiltonianSimulation(register=Mock(), hamiltonian=Mock()).discretize(device) -def test_discretize(register, driving_field, shifting_field): - hamiltonian = driving_field + shifting_field +def test_discretize(register, driving_field, local_detuning): + hamiltonian = driving_field + local_detuning ahs = AnalogHamiltonianSimulation(register=register, hamiltonian=hamiltonian) action = Mock() @@ -141,8 +141,6 @@ def test_discretize(register, driving_field, shifting_field): device.properties.paradigm.rydberg.rydbergGlobal.phaseResolution = Decimal("5E-7") device.properties.paradigm.rydberg.rydbergLocal.timeResolution = Decimal("1E-9") - device.properties.paradigm.rydberg.rydbergLocal.commonDetuningResolution = Decimal("2000.0") - device.properties.paradigm.rydberg.rydbergLocal.localDetuningResolution = Decimal("0.01") discretized_ahs = ahs.discretize(device) discretized_ir = discretized_ahs.to_ir() @@ -177,11 +175,12 @@ def test_discretize(register, driving_field, shifting_field): "values": ["-125664000.0", "-125664000.0", "125664000.0", "125664000.0"], }, } - assert discretized_json["hamiltonian"]["shiftingFields"][0]["magnitude"] == { - "pattern": ["0.50", "1.00", "0.50", "0.50", "0.50", "0.50"], + local_detuning = discretized_json["hamiltonian"]["localDetuning"][0]["magnitude"] + assert local_detuning == { + "pattern": ["0.5", "1", "0.5", "0.5", "0.5", "0.5"], "time_series": { "times": ["0E-9", "0.000003000"], - "values": ["-125664000.0", "125664000.0"], + "values": ["-125664000", "125664000"], }, } diff --git a/test/unit_tests/braket/ahs/test_atom_arrangement.py b/test/unit_tests/braket/ahs/test_atom_arrangement.py index 425458547..06a926163 100644 --- a/test/unit_tests/braket/ahs/test_atom_arrangement.py +++ b/test/unit_tests/braket/ahs/test_atom_arrangement.py @@ -52,9 +52,7 @@ def test_iteration(): atom_arrangement = AtomArrangement() for value in values: atom_arrangement.add(value) - returned_values = [] - for site in atom_arrangement: - returned_values.append(site.coordinate) + returned_values = [site.coordinate for site in atom_arrangement] assert values == returned_values diff --git a/test/unit_tests/braket/ahs/test_field.py b/test/unit_tests/braket/ahs/test_field.py index 4212ba336..2ff6714ce 100644 --- a/test/unit_tests/braket/ahs/test_field.py +++ b/test/unit_tests/braket/ahs/test_field.py @@ -16,7 +16,6 @@ import pytest -from braket.ahs.discretization_types import DiscretizationError from braket.ahs.field import Field from braket.ahs.pattern import Pattern from braket.timings.time_series import TimeSeries @@ -80,6 +79,12 @@ def test_discretize( [ (Decimal("0.1"), Decimal("10"), Decimal("0.5")), (Decimal("10"), Decimal("20"), None), + (Decimal("0.1"), None, Decimal("0.5")), + (None, Decimal("10"), Decimal("0.5")), + (None, None, Decimal("0.5")), + (None, Decimal("10"), None), + (Decimal("0.1"), None, None), + (None, None, None), (Decimal("100"), Decimal("0.1"), Decimal("1")), ], ) @@ -93,14 +98,3 @@ def test_uniform_field( ) or expected.pattern.series == actual.pattern.series assert expected.time_series.times() == actual.time_series.times() assert expected.time_series.values() == actual.time_series.values() - - -@pytest.mark.parametrize( - "time_res, value_res, pattern_res", - [ - (Decimal("10"), Decimal("20"), None), - ], -) -@pytest.mark.xfail(raises=DiscretizationError) -def test_invalid_pattern_res(default_field, time_res, value_res, pattern_res): - default_field.discretize(time_res, value_res, pattern_res) diff --git a/test/unit_tests/braket/ahs/test_shifting_field.py b/test/unit_tests/braket/ahs/test_local_detuning.py similarity index 70% rename from test/unit_tests/braket/ahs/test_shifting_field.py rename to test/unit_tests/braket/ahs/test_local_detuning.py index 0249b8758..8768dce64 100644 --- a/test/unit_tests/braket/ahs/test_shifting_field.py +++ b/test/unit_tests/braket/ahs/test_local_detuning.py @@ -16,49 +16,49 @@ import pytest from braket.ahs.hamiltonian import Hamiltonian -from braket.ahs.shifting_field import ShiftingField +from braket.ahs.local_detuning import LocalDetuning from braket.timings.time_series import StitchBoundaryCondition @pytest.fixture -def default_shifting_field(): - return ShiftingField(Mock()) +def default_local_detuning(): + return LocalDetuning(Mock()) def test_create(): mock0 = Mock() - field = ShiftingField(magnitude=mock0) + field = LocalDetuning(magnitude=mock0) assert mock0 == field.magnitude -def test_add_hamiltonian(default_shifting_field): - expected = [default_shifting_field, Mock(), Mock(), Mock()] +def test_add_hamiltonian(default_local_detuning): + expected = [default_local_detuning, Mock(), Mock(), Mock()] result = expected[0] + Hamiltonian([expected[1], expected[2], expected[3]]) assert result.terms == expected -def test_add_to_hamiltonian(default_shifting_field): - expected = [Mock(), Mock(), Mock(), default_shifting_field] +def test_add_to_hamiltonian(default_local_detuning): + expected = [Mock(), Mock(), Mock(), default_local_detuning] result = Hamiltonian([expected[0], expected[1], expected[2]]) + expected[3] assert result.terms == expected def test_add_to_other(): - field0 = ShiftingField(Mock()) - field1 = ShiftingField(Mock()) + field0 = LocalDetuning(Mock()) + field1 = LocalDetuning(Mock()) result = field0 + field1 assert type(result) is Hamiltonian assert result.terms == [field0, field1] -def test_add_to_self(default_shifting_field): - result = default_shifting_field + default_shifting_field +def test_add_to_self(default_local_detuning): + result = default_local_detuning + default_local_detuning assert type(result) is Hamiltonian - assert result.terms == [default_shifting_field, default_shifting_field] + assert result.terms == [default_local_detuning, default_local_detuning] -def test_iadd_to_other(default_shifting_field): - expected = [Mock(), Mock(), Mock(), default_shifting_field] +def test_iadd_to_other(default_local_detuning): + expected = [Mock(), Mock(), Mock(), default_local_detuning] other = Hamiltonian([expected[0], expected[1], expected[2]]) other += expected[3] assert other.terms == expected @@ -69,7 +69,7 @@ def test_from_lists(): glob_amplitude = [0.5, 0.8, 0.9, 1.0] pattern = [0.3, 0.7, 0.6, -0.5, 0, 1.6] - sh_field = ShiftingField.from_lists(times, glob_amplitude, pattern) + sh_field = LocalDetuning.from_lists(times, glob_amplitude, pattern) assert sh_field.magnitude.time_series.values() == glob_amplitude assert sh_field.magnitude.pattern.series == pattern @@ -82,7 +82,7 @@ def test_from_lists_not_eq_length(): glob_amplitude = [0.5, 0.8, 0.9, 1.0] pattern = [0.3, 0.7, 0.6, -0.5, 0, 1.6] - ShiftingField.from_lists(times, glob_amplitude, pattern) + LocalDetuning.from_lists(times, glob_amplitude, pattern) def test_stitch(): @@ -94,8 +94,8 @@ def test_stitch(): glob_amplitude_2 = [0.5, 0.8, 0.9, 1.0] pattern_2 = pattern_1 - sh_field_1 = ShiftingField.from_lists(times_1, glob_amplitude_1, pattern_1) - sh_field_2 = ShiftingField.from_lists(times_2, glob_amplitude_2, pattern_2) + sh_field_1 = LocalDetuning.from_lists(times_1, glob_amplitude_1, pattern_1) + sh_field_2 = LocalDetuning.from_lists(times_2, glob_amplitude_2, pattern_2) new_sh_field = sh_field_1.stitch(sh_field_2, boundary=StitchBoundaryCondition.LEFT) @@ -116,8 +116,8 @@ def test_stitch_not_eq_pattern(): glob_amplitude_2 = [0.5, 0.8, 0.9, 1.0] pattern_2 = [-0.3, 0.7, 0.6, -0.5, 0, 1.6] - sh_field_1 = ShiftingField.from_lists(times_1, glob_amplitude_1, pattern_1) - sh_field_2 = ShiftingField.from_lists(times_2, glob_amplitude_2, pattern_2) + sh_field_1 = LocalDetuning.from_lists(times_1, glob_amplitude_1, pattern_1) + sh_field_2 = LocalDetuning.from_lists(times_2, glob_amplitude_2, pattern_2) sh_field_1.stitch(sh_field_2) @@ -125,17 +125,15 @@ def test_stitch_not_eq_pattern(): def test_discretize(): magnitude_mock = Mock() mock_properties = Mock() - field = ShiftingField(magnitude=magnitude_mock) + field = LocalDetuning(magnitude=magnitude_mock) discretized_field = field.discretize(mock_properties) magnitude_mock.discretize.assert_called_with( time_resolution=mock_properties.rydberg.rydbergLocal.timeResolution, - value_resolution=mock_properties.rydberg.rydbergLocal.commonDetuningResolution, - pattern_resolution=mock_properties.rydberg.rydbergLocal.localDetuningResolution, ) assert field is not discretized_field assert discretized_field.magnitude == magnitude_mock.discretize.return_value @pytest.mark.xfail(raises=ValueError) -def test_iadd_to_itself(default_shifting_field): - default_shifting_field += Hamiltonian(Mock()) +def test_iadd_to_itself(default_local_detuning): + default_local_detuning += Hamiltonian(Mock()) diff --git a/test/unit_tests/braket/ahs/test_pattern.py b/test/unit_tests/braket/ahs/test_pattern.py index d84f3a925..920f2cc29 100644 --- a/test/unit_tests/braket/ahs/test_pattern.py +++ b/test/unit_tests/braket/ahs/test_pattern.py @@ -20,7 +20,15 @@ @pytest.fixture def default_values(): - return [0, 0.1, 1, 0.5, 0.2, 0.001, 1e-10] + return [ + Decimal(0), + Decimal("0.1"), + Decimal(1), + Decimal("0.5"), + Decimal("0.2"), + Decimal("0.001"), + Decimal("1e-10"), + ] @pytest.fixture @@ -38,6 +46,18 @@ def test_create(): "res, expected_series", [ # default pattern: [0, 0.1, 1, 0.5, 0.2, 0.001, 1e-10] + ( + None, + [ + Decimal("0"), + Decimal("0.1"), + Decimal("1"), + Decimal("0.5"), + Decimal("0.2"), + Decimal("0.001"), + Decimal("1e-10"), + ], + ), ( Decimal("0.001"), [ diff --git a/test/unit_tests/braket/aws/common_test_utils.py b/test/unit_tests/braket/aws/common_test_utils.py index f1d6d96df..aaca559f5 100644 --- a/test/unit_tests/braket/aws/common_test_utils.py +++ b/test/unit_tests/braket/aws/common_test_utils.py @@ -21,6 +21,7 @@ IONQ_ARN = "arn:aws:braket:us-east-1::device/qpu/ionq/Harmony" OQC_ARN = "arn:aws:braket:eu-west-2::device/qpu/oqc/Lucy" SV1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/sv1" +DM1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/dm1" TN1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/tn1" XANADU_ARN = "arn:aws:braket:us-east-1::device/qpu/xanadu/Borealis" @@ -221,7 +222,7 @@ def run_and_assert( run_args.append(inputs) if gate_definitions is not None: run_args.append(gate_definitions) - run_args += extra_args if extra_args else [] + run_args += extra_args or [] run_kwargs = extra_kwargs or {} if reservation_arn: run_kwargs.update({"reservation_arn": reservation_arn}) @@ -294,7 +295,7 @@ def run_batch_and_assert( run_args.append(inputs) if gate_definitions is not None: run_args.append(gate_definitions) - run_args += extra_args if extra_args else [] + run_args += extra_args or [] run_kwargs = extra_kwargs or {} if reservation_arn: run_kwargs.update({"reservation_arn": reservation_arn}) @@ -349,7 +350,7 @@ def _create_task_args_and_kwargs( s3_folder if s3_folder is not None else default_s3_folder, shots if shots is not None else default_shots, ] - create_args += extra_args if extra_args else [] + create_args += extra_args or [] create_kwargs = extra_kwargs or {} create_kwargs.update( { diff --git a/test/unit_tests/braket/aws/test_aws_device.py b/test/unit_tests/braket/aws/test_aws_device.py index 29adf2517..a778dfb5f 100644 --- a/test/unit_tests/braket/aws/test_aws_device.py +++ b/test/unit_tests/braket/aws/test_aws_device.py @@ -13,6 +13,7 @@ import io import json import os +import textwrap from datetime import datetime from unittest.mock import Mock, PropertyMock, patch from urllib.error import URLError @@ -21,6 +22,7 @@ import pytest from botocore.exceptions import ClientError from common_test_utils import ( + DM1_ARN, DWAVE_ARN, IONQ_ARN, OQC_ARN, @@ -35,8 +37,9 @@ from braket.aws import AwsDevice, AwsDeviceType, AwsQuantumTask from braket.aws.queue_information import QueueDepthInfo, QueueType -from braket.circuits import Circuit, FreeParameter, Gate, QubitSet +from braket.circuits import Circuit, FreeParameter, Gate, Noise, QubitSet from braket.circuits.gate_calibrations import GateCalibrations +from braket.circuits.noise_model import GateCriteria, NoiseModel from braket.device_schema.device_execution_window import DeviceExecutionWindow from braket.device_schema.dwave import DwaveDeviceCapabilities from braket.device_schema.rigetti import RigettiDeviceCapabilities @@ -359,7 +362,59 @@ def test_d_wave_schema(): "actionType": "braket.ir.jaqcd.program", "version": ["1"], "supportedOperations": ["H"], - } + }, + }, + "paradigm": {"qubitCount": 30}, + "deviceParameters": {}, +} + +MOCK_GATE_MODEL_NOISE_SIMULATOR_CAPABILITIES_JSON = { + "braketSchemaHeader": { + "name": "braket.device_schema.simulators.gate_model_simulator_device_capabilities", + "version": "1", + }, + "service": { + "executionWindows": [ + { + "executionDay": "Everyday", + "windowStartHour": "11:00", + "windowEndHour": "12:00", + } + ], + "shotsRange": [1, 10], + }, + "action": { + "braket.ir.openqasm.program": { + "actionType": "braket.ir.openqasm.program", + "version": ["1"], + "supportedOperations": ["rx", "ry", "h", "cy", "cnot", "unitary"], + "supportedResultTypes": [ + { + "name": "StateVector", + "observables": ["x", "y", "z"], + "minShots": 0, + "maxShots": 0, + }, + ], + "supportedPragmas": [ + "braket_noise_bit_flip", + "braket_noise_depolarizing", + "braket_noise_kraus", + "braket_noise_pauli_channel", + "braket_noise_generalized_amplitude_damping", + "braket_noise_amplitude_damping", + "braket_noise_phase_flip", + "braket_noise_phase_damping", + "braket_noise_two_qubit_dephasing", + "braket_noise_two_qubit_depolarizing", + "braket_unitary_matrix", + "braket_result_type_sample", + "braket_result_type_expectation", + "braket_result_type_variance", + "braket_result_type_probability", + "braket_result_type_density_matrix", + ], + }, }, "paradigm": {"qubitCount": 30}, "deviceParameters": {}, @@ -376,6 +431,18 @@ def test_gate_model_sim_schema(): ) +MOCK_GATE_MODEL_NOISE_SIMULATOR_CAPABILITIES = GateModelSimulatorDeviceCapabilities.parse_obj( + MOCK_GATE_MODEL_NOISE_SIMULATOR_CAPABILITIES_JSON +) + + +def test_gate_model_sim_schema(): + validate( + MOCK_GATE_MODEL_NOISE_SIMULATOR_CAPABILITIES_JSON, + GateModelSimulatorDeviceCapabilities.schema(), + ) + + MOCK_GATE_MODEL_SIMULATOR = { "deviceName": "SV1", "deviceType": "SIMULATOR", @@ -384,6 +451,16 @@ def test_gate_model_sim_schema(): "deviceCapabilities": MOCK_GATE_MODEL_SIMULATOR_CAPABILITIES.json(), } + +MOCK_GATE_MODEL_NOISE_SIMULATOR = { + "deviceName": "DM1", + "deviceType": "SIMULATOR", + "providerName": "provider1", + "deviceStatus": "ONLINE", + "deviceCapabilities": MOCK_GATE_MODEL_NOISE_SIMULATOR_CAPABILITIES.json(), +} + + MOCK_DEFAULT_S3_DESTINATION_FOLDER = ( "amazon-braket-us-test-1-00000000", "tasks", @@ -756,7 +833,7 @@ def test_gate_calibration_refresh_no_url(arn): mock_session.region = RIGETTI_REGION device = AwsDevice(arn, mock_session) - assert device.refresh_gate_calibrations() == None + assert device.refresh_gate_calibrations() is None @patch("urllib.request.urlopen") @@ -868,7 +945,7 @@ def test_repr(arn): mock_session.get_device.return_value = MOCK_GATE_MODEL_QPU_1 mock_session.region = RIGETTI_REGION device = AwsDevice(arn, mock_session) - expected = "Device('name': {}, 'arn': {})".format(device.name, device.arn) + expected = f"Device('name': {device.name}, 'arn': {device.arn})" assert repr(device) == expected @@ -1061,7 +1138,7 @@ def test_run_param_circuit_with_reservation_arn_batch_task( 43200, 0.25, inputs, - None, + {}, reservation_arn="arn:aws:braket:us-west-2:123456789123:reservation/a1b123cd-45e6-789f-gh01-i234567jk8l9", ) @@ -1093,6 +1170,7 @@ def test_run_param_circuit_with_inputs_batch_task( 43200, 0.25, inputs, + {}, ) @@ -1226,7 +1304,9 @@ def test_batch_circuit_with_task_and_input_mismatch( inputs = [{"beta": 0.2}, {"gamma": 0.1}, {"theta": 0.2}] circ_1 = Circuit().ry(angle=3, target=0) task_specifications = [[circ_1, single_circuit_input], openqasm_program] - wrong_number_of_inputs = "Multiple inputs and task specifications must " "be equal in number." + wrong_number_of_inputs = ( + "Multiple inputs, task specifications and gate definitions must be equal in length." + ) with pytest.raises(ValueError, match=wrong_number_of_inputs): _run_batch_and_assert( @@ -1241,6 +1321,7 @@ def test_batch_circuit_with_task_and_input_mismatch( 43200, 0.25, inputs, + {}, ) @@ -1417,7 +1498,7 @@ def test_run_with_positional_args_and_kwargs( 86400, 0.25, {}, - ["foo"], + {}, "arn:aws:braket:us-west-2:123456789123:reservation/a1b123cd-45e6-789f-gh01-i234567jk8l9", None, {"bar": 1, "baz": 2}, @@ -1457,6 +1538,7 @@ def test_run_batch_no_extra( 43200, 0.25, {}, + {}, ) @@ -1483,6 +1565,7 @@ def test_run_batch_with_shots( 43200, 0.25, {}, + {}, ) @@ -1509,6 +1592,7 @@ def test_run_batch_with_max_parallel_and_kwargs( 43200, 0.25, inputs={"theta": 0.2}, + gate_definitions={}, extra_kwargs={"bar": 1, "baz": 2}, ) @@ -1669,6 +1753,16 @@ def test_get_devices(mock_copy_session, aws_session): "providerName": "OQC", } ], + # eu-north-1 + [ + { + "deviceArn": SV1_ARN, + "deviceName": "SV1", + "deviceType": "SIMULATOR", + "deviceStatus": "ONLINE", + "providerName": "Amazon Braket", + }, + ], # Only two regions to search outside of current ValueError("should not be reachable"), ] @@ -1679,7 +1773,7 @@ def test_get_devices(mock_copy_session, aws_session): ValueError("should not be reachable"), ] mock_copy_session.return_value = session_for_region - # Search order: us-east-1, us-west-1, us-west-2, eu-west-2 + # Search order: us-east-1, us-west-1, us-west-2, eu-west-2, eu-north-1 results = AwsDevice.get_devices( arns=[SV1_ARN, DWAVE_ARN, IONQ_ARN, OQC_ARN], provider_names=["Amazon Braket", "D-Wave", "IonQ", "OQC"], @@ -1774,6 +1868,16 @@ def test_get_devices_with_error_in_region(mock_copy_session, aws_session): "providerName": "OQC", } ], + # eu-north-1 + [ + { + "deviceArn": SV1_ARN, + "deviceName": "SV1", + "deviceType": "SIMULATOR", + "deviceStatus": "ONLINE", + "providerName": "Amazon Braket", + }, + ], # Only two regions to search outside of current ValueError("should not be reachable"), ] @@ -1783,7 +1887,7 @@ def test_get_devices_with_error_in_region(mock_copy_session, aws_session): ValueError("should not be reachable"), ] mock_copy_session.return_value = session_for_region - # Search order: us-east-1, us-west-1, us-west-2, eu-west-2 + # Search order: us-east-1, us-west-1, us-west-2, eu-west-2, eu-north-1 results = AwsDevice.get_devices( statuses=["ONLINE"], aws_session=aws_session, @@ -1798,7 +1902,8 @@ def test_get_devices_invalid_order_by(): @patch("braket.aws.aws_device.datetime") def test_get_device_availability(mock_utc_now): - class Expando(object): + + class Expando: pass class MockDevice(AwsDevice): @@ -1806,19 +1911,18 @@ def __init__(self, status, *execution_window_args): self._status = status self._properties = Expando() self._properties.service = Expando() - execution_windows = [] - for execution_day, window_start_hour, window_end_hour in execution_window_args: - execution_windows.append( - DeviceExecutionWindow.parse_raw( - json.dumps( - { - "executionDay": execution_day, - "windowStartHour": window_start_hour, - "windowEndHour": window_end_hour, - } - ) + execution_windows = [ + DeviceExecutionWindow.parse_raw( + json.dumps( + { + "executionDay": execution_day, + "windowStartHour": window_start_hour, + "windowEndHour": window_end_hour, + } ) ) + for execution_day, window_start_hour, window_end_hour in execution_window_args + ] self._properties.service.executionWindows = execution_windows test_sets = ( @@ -1957,7 +2061,7 @@ def test_device_topology_graph_data(get_device_data, expected_graph, arn): def test_device_no_href(): mock_session = Mock() mock_session.get_device.return_value = MOCK_GATE_MODEL_QPU_1 - device = AwsDevice(DWAVE_ARN, mock_session) + AwsDevice(DWAVE_ARN, mock_session) def test_parse_calibration_data(): @@ -2070,3 +2174,78 @@ def test_queue_depth(arn): quantum_tasks={QueueType.NORMAL: "19", QueueType.PRIORITY: "3"}, jobs="0 (3 prioritized job(s) running)", ) + + +@pytest.fixture +def noise_model(): + return ( + NoiseModel() + .add_noise(Noise.BitFlip(0.05), GateCriteria(Gate.H)) + .add_noise(Noise.TwoQubitDepolarizing(0.10), GateCriteria(Gate.CNot)) + ) + + +@patch.dict( + os.environ, + {"AMZN_BRAKET_TASK_RESULTS_S3_URI": "s3://env_bucket/env/path"}, +) +@patch("braket.aws.aws_device.AwsSession") +@patch("braket.aws.aws_quantum_task.AwsQuantumTask.create") +def test_run_with_noise_model(aws_quantum_task_mock, aws_session_init, aws_session, noise_model): + arn = DM1_ARN + aws_session_init.return_value = aws_session + aws_session.get_device.return_value = MOCK_GATE_MODEL_NOISE_SIMULATOR + device = AwsDevice(arn, noise_model=noise_model) + circuit = Circuit().h(0).cnot(0, 1) + _ = device.run(circuit) + + expected_circuit = textwrap.dedent( + """ + OPENQASM 3.0; + bit[2] b; + qubit[2] q; + h q[0]; + #pragma braket noise bit_flip(0.05) q[0] + cnot q[0], q[1]; + #pragma braket noise two_qubit_depolarizing(0.1) q[0], q[1] + b[0] = measure q[0]; + b[1] = measure q[1]; + """ + ).strip() + + expected_circuit = Circuit().h(0).bit_flip(0, 0.05).cnot(0, 1).two_qubit_depolarizing(0, 1, 0.1) + assert aws_quantum_task_mock.call_args_list[0][0][2] == expected_circuit + + +@patch.dict( + os.environ, + {"AMZN_BRAKET_TASK_RESULTS_S3_URI": "s3://env_bucket/env/path"}, +) +@patch("braket.aws.aws_device.AwsSession") +@patch("braket.aws.aws_quantum_task.AwsQuantumTask.create") +def test_run_batch_with_noise_model( + aws_quantum_task_mock, aws_session_init, aws_session, noise_model +): + arn = DM1_ARN + aws_session_init.return_value = aws_session + aws_session.get_device.return_value = MOCK_GATE_MODEL_NOISE_SIMULATOR + device = AwsDevice(arn, noise_model=noise_model) + circuit = Circuit().h(0).cnot(0, 1) + _ = device.run_batch([circuit] * 2) + + expected_circuit = textwrap.dedent( + """ + OPENQASM 3.0; + bit[2] b; + qubit[2] q; + h q[0]; + #pragma braket noise bit_flip(0.05) q[0] + cnot q[0], q[1]; + #pragma braket noise two_qubit_depolarizing(0.1) q[0], q[1] + b[0] = measure q[0]; + b[1] = measure q[1]; + """ + ).strip() + + expected_circuit = Circuit().h(0).bit_flip(0, 0.05).cnot(0, 1).two_qubit_depolarizing(0, 1, 0.1) + assert aws_quantum_task_mock.call_args_list[0][0][2] == expected_circuit diff --git a/test/unit_tests/braket/aws/test_aws_quantum_job.py b/test/unit_tests/braket/aws/test_aws_quantum_job.py index 7f9dc1a84..67ca98228 100644 --- a/test/unit_tests/braket/aws/test_aws_quantum_job.py +++ b/test/unit_tests/braket/aws/test_aws_quantum_job.py @@ -364,7 +364,7 @@ def test_download_result_when_extract_path_not_provided( job_name = job_metadata["jobName"] quantum_job.download_result() - with open(f"{job_name}/results.json", "r") as file: + with open(f"{job_name}/results.json") as file: actual_data = json.loads(file.read())["dataDictionary"] assert expected_saved_data == actual_data @@ -382,7 +382,7 @@ def test_download_result_when_extract_path_provided( with tempfile.TemporaryDirectory() as temp_dir: quantum_job.download_result(temp_dir) - with open(f"{temp_dir}/{job_name}/results.json", "r") as file: + with open(f"{temp_dir}/{job_name}/results.json") as file: actual_data = json.loads(file.read())["dataDictionary"] assert expected_saved_data == actual_data @@ -1104,3 +1104,26 @@ def test_bad_device_arn_format(aws_session): with pytest.raises(ValueError, match=device_not_found): AwsQuantumJob._initialize_session(aws_session, "bad-arn-format", logger) + + +def test_logs_prefix(job_region, quantum_job_name, aws_session, generate_get_job_response): + aws_session.get_job.return_value = generate_get_job_response(jobName=quantum_job_name) + + # old jobs with the `arn:.../job-name` style ARN use `job-name/` as the logs prefix + name_arn = f"arn:aws:braket:{job_region}:875981177017:job/{quantum_job_name}" + quantum_job = AwsQuantumJob(name_arn, aws_session) + assert quantum_job._logs_prefix == f"{quantum_job_name}" + + # jobs with the `arn:.../uuid` style ARN use `job-name/uuid/` as the logs prefix + uuid_1 = "UUID-123456789" + uuid_2 = "UUID-987654321" + uuid_arn_1 = f"arn:aws:braket:{job_region}:875981177017:job/{uuid_1}" + uuid_job_1 = AwsQuantumJob(uuid_arn_1, aws_session) + uuid_arn_2 = f"arn:aws:braket:{job_region}:875981177017:job/{uuid_2}" + uuid_job_2 = AwsQuantumJob(uuid_arn_2, aws_session) + assert ( + uuid_job_1._logs_prefix + == f"{quantum_job_name}/{uuid_1}" + != uuid_job_2._logs_prefix + == f"{quantum_job_name}/{uuid_2}" + ) diff --git a/test/unit_tests/braket/aws/test_aws_quantum_task.py b/test/unit_tests/braket/aws/test_aws_quantum_task.py index e96af57f5..28032d943 100644 --- a/test/unit_tests/braket/aws/test_aws_quantum_task.py +++ b/test/unit_tests/braket/aws/test_aws_quantum_task.py @@ -33,6 +33,7 @@ IRType, OpenQASMSerializationProperties, QubitReferenceType, + SerializableProgram, ) from braket.device_schema import GateModelParameters, error_mitigation from braket.device_schema.dwave import ( @@ -123,6 +124,19 @@ def openqasm_program(): return OpenQASMProgram(source="OPENQASM 3.0; h $0;") +class DummySerializableProgram(SerializableProgram): + def __init__(self, source: str): + self.source = source + + def to_ir(self, ir_type: IRType = IRType.OPENQASM) -> str: + return self.source + + +@pytest.fixture +def serializable_program(): + return DummySerializableProgram(source="OPENQASM 3.0; h $0;") + + @pytest.fixture def blackbird_program(): return BlackbirdProgram(source="Vac | q[0]") @@ -172,7 +186,7 @@ def test_equality(arn, aws_session): def test_str(quantum_task): - expected = "AwsQuantumTask('id/taskArn':'{}')".format(quantum_task.id) + expected = f"AwsQuantumTask('id/taskArn':'{quantum_task.id}')" assert str(quantum_task) == expected @@ -614,6 +628,20 @@ def test_create_openqasm_program_em_serialized(aws_session, arn, openqasm_progra ) +def test_create_serializable_program(aws_session, arn, serializable_program): + aws_session.create_quantum_task.return_value = arn + shots = 21 + AwsQuantumTask.create(aws_session, SIMULATOR_ARN, serializable_program, S3_TARGET, shots) + + _assert_create_quantum_task_called_with( + aws_session, + SIMULATOR_ARN, + OpenQASMProgram(source=serializable_program.to_ir()).json(), + S3_TARGET, + shots, + ) + + def test_create_blackbird_program(aws_session, arn, blackbird_program): aws_session.create_quantum_task.return_value = arn shots = 21 @@ -676,7 +704,7 @@ def test_create_pulse_sequence(aws_session, arn, pulse_sequence): "}", ] ) - expected_program = OpenQASMProgram(source=expected_openqasm) + expected_program = OpenQASMProgram(source=expected_openqasm, inputs={}) aws_session.create_quantum_task.return_value = arn AwsQuantumTask.create(aws_session, SIMULATOR_ARN, pulse_sequence, S3_TARGET, 10) @@ -1216,20 +1244,16 @@ def _assert_create_quantum_task_called_with( } if device_parameters is not None: - test_kwargs.update({"deviceParameters": device_parameters.json(exclude_none=True)}) + test_kwargs["deviceParameters"] = device_parameters.json(exclude_none=True) if tags is not None: - test_kwargs.update({"tags": tags}) + test_kwargs["tags"] = tags if reservation_arn: - test_kwargs.update( + test_kwargs["associations"] = [ { - "associations": [ - { - "arn": reservation_arn, - "type": "RESERVATION_TIME_WINDOW_ARN", - } - ] + "arn": reservation_arn, + "type": "RESERVATION_TIME_WINDOW_ARN", } - ) + ] aws_session.create_quantum_task.assert_called_with(**test_kwargs) diff --git a/test/unit_tests/braket/aws/test_aws_session.py b/test/unit_tests/braket/aws/test_aws_session.py index c61d22606..fb0b309fb 100644 --- a/test/unit_tests/braket/aws/test_aws_session.py +++ b/test/unit_tests/braket/aws/test_aws_session.py @@ -58,7 +58,6 @@ def aws_session(boto_session, braket_client, account_id): _aws_session._sts.get_caller_identity.return_value = { "Account": account_id, } - _aws_session._s3 = Mock() return _aws_session @@ -411,7 +410,7 @@ def test_create_quantum_task_with_job_token(aws_session): } with patch.dict(os.environ, {"AMZN_BRAKET_JOB_TOKEN": job_token}): assert aws_session.create_quantum_task(**kwargs) == arn - kwargs.update({"jobToken": job_token}) + kwargs["jobToken"] = job_token aws_session.braket_client.create_quantum_task.assert_called_with(**kwargs) @@ -998,6 +997,12 @@ def test_upload_to_s3(aws_session): aws_session._s3.upload_file.assert_called_with(filename, bucket, key) +def test_account_id_idempotency(aws_session, account_id): + acc_id = aws_session.account_id + assert acc_id == aws_session.account_id + assert acc_id == account_id + + def test_upload_local_data(aws_session): with tempfile.TemporaryDirectory() as temp_dir: os.chdir(temp_dir) @@ -1279,10 +1284,10 @@ def test_describe_log_streams(aws_session, limit, next_token): } if limit: - describe_log_stream_args.update({"limit": limit}) + describe_log_stream_args["limit"] = limit if next_token: - describe_log_stream_args.update({"nextToken": next_token}) + describe_log_stream_args["nextToken"] = next_token aws_session.describe_log_streams(log_group, log_stream_prefix, limit, next_token) @@ -1309,7 +1314,7 @@ def test_get_log_events(aws_session, next_token): } if next_token: - log_events_args.update({"nextToken": next_token}) + log_events_args["nextToken"] = next_token aws_session.get_log_events(log_group, log_stream_name, start_time, start_from_head, next_token) diff --git a/test/unit_tests/braket/aws/test_direct_reservations.py b/test/unit_tests/braket/aws/test_direct_reservations.py new file mode 100644 index 000000000..332421e7a --- /dev/null +++ b/test/unit_tests/braket/aws/test_direct_reservations.py @@ -0,0 +1,181 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import os +from unittest.mock import MagicMock, patch + +import pytest + +from braket.aws import AwsDevice, AwsSession, DirectReservation +from braket.devices import LocalSimulator + +RESERVATION_ARN = "arn:aws:braket:us-east-1:123456789:reservation/uuid" +DEVICE_ARN = "arn:aws:braket:us-east-1:123456789:device/qpu/ionq/Forte-1" +VALUE_ERROR_MESSAGE = "Device must be an AwsDevice or its ARN, or a local simulator device." +RUNTIME_ERROR_MESSAGE = "Another reservation is already active." + + +@pytest.fixture +def aws_device(): + mock_device = MagicMock(spec=AwsDevice) + mock_device._arn = DEVICE_ARN + type(mock_device).arn = property(lambda x: DEVICE_ARN) + return mock_device + + +def test_direct_reservation_aws_device(aws_device): + with DirectReservation(aws_device, RESERVATION_ARN) as reservation: + assert reservation.device_arn == DEVICE_ARN + assert reservation.reservation_arn == RESERVATION_ARN + assert reservation._is_active + + +def test_direct_reservation_device_str(aws_device): + with patch( + "braket.aws.AwsDevice.__init__", + side_effect=lambda self, *args, **kwargs: setattr(self, "_arn", DEVICE_ARN), + autospec=True, + ): + with patch("braket.aws.AwsDevice", return_value=aws_device, autospec=True): + with DirectReservation(DEVICE_ARN, RESERVATION_ARN) as reservation: + assert reservation.device_arn == DEVICE_ARN + assert reservation.reservation_arn == RESERVATION_ARN + assert reservation._is_active + + +def test_direct_reservation_local_simulator(): + mock_device = MagicMock(spec=LocalSimulator) + with pytest.warns(UserWarning): + with DirectReservation(mock_device, RESERVATION_ARN) as reservation: + assert os.environ["AMZN_BRAKET_RESERVATION_DEVICE_ARN"] == "" + assert os.environ["AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN"] == RESERVATION_ARN + assert reservation._is_active is True + + +@pytest.mark.parametrize("device", [123, False, [aws_device], {"a": 1}]) +def test_direct_reservation_invalid_inputs(device): + with pytest.raises(TypeError): + DirectReservation(device, RESERVATION_ARN) + + +def test_direct_reservation_local_no_reservation(): + mock_device = MagicMock(spec=LocalSimulator) + mock_device.create_quantum_task = MagicMock() + kwargs = { + "program": {"ir": '{"instructions":[]}', "qubitCount": 4}, + "shots": 1, + } + with DirectReservation(mock_device, None): + mock_device.create_quantum_task(**kwargs) + mock_device.create_quantum_task.assert_called_once_with(**kwargs) + + +def test_context_management(aws_device): + with DirectReservation(aws_device, RESERVATION_ARN): + assert os.getenv("AMZN_BRAKET_RESERVATION_DEVICE_ARN") == DEVICE_ARN + assert os.getenv("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN") == RESERVATION_ARN + assert not os.getenv("AMZN_BRAKET_RESERVATION_DEVICE_ARN") + assert not os.getenv("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN") + + +def test_start_reservation_already_active(aws_device): + reservation = DirectReservation(aws_device, RESERVATION_ARN) + reservation.start() + with pytest.raises(RuntimeError, match=RUNTIME_ERROR_MESSAGE): + reservation.start() + reservation.stop() + + +def test_stop_reservation_not_active(aws_device): + reservation = DirectReservation(aws_device, RESERVATION_ARN) + with pytest.warns(UserWarning): + reservation.stop() + + +def test_multiple_start_stop_cycles(aws_device): + reservation = DirectReservation(aws_device, RESERVATION_ARN) + reservation.start() + reservation.stop() + reservation.start() + reservation.stop() + assert not os.getenv("AMZN_BRAKET_RESERVATION_DEVICE_ARN") + assert not os.getenv("AMZN_BRAKET_RESERVATION_TIME_WINDOW_ARN") + + +def test_two_direct_reservations(aws_device): + with pytest.raises(RuntimeError, match=RUNTIME_ERROR_MESSAGE): + with DirectReservation(aws_device, RESERVATION_ARN): + with DirectReservation(aws_device, "reservation_arn_example_2"): + pass + + +def test_create_quantum_task_with_correct_device_and_reservation(aws_device): + kwargs = {"deviceArn": DEVICE_ARN, "shots": 1} + with patch("boto3.client"): + mock_client = MagicMock() + aws_session = AwsSession(braket_client=mock_client) + with DirectReservation(aws_device, RESERVATION_ARN): + aws_session.create_quantum_task(**kwargs) + kwargs["associations"] = [ + { + "arn": RESERVATION_ARN, + "type": "RESERVATION_TIME_WINDOW_ARN", + } + ] + mock_client.create_quantum_task.assert_called_once_with(**kwargs) + + +def test_warning_for_overridden_reservation_arn(aws_device): + kwargs = { + "deviceArn": DEVICE_ARN, + "shots": 1, + "associations": [ + { + "arn": "task_reservation_arn", + "type": "RESERVATION_TIME_WINDOW_ARN", + } + ], + } + correct_kwargs = { + "deviceArn": DEVICE_ARN, + "shots": 1, + "associations": [ + { + "arn": RESERVATION_ARN, + "type": "RESERVATION_TIME_WINDOW_ARN", + } + ], + } + with patch("boto3.client"): + mock_client = MagicMock() + aws_session = AwsSession(braket_client=mock_client) + with pytest.warns( + UserWarning, + match="A reservation ARN was passed to 'CreateQuantumTask', but it is being overridden", + ): + with DirectReservation(aws_device, RESERVATION_ARN): + aws_session.create_quantum_task(**kwargs) + mock_client.create_quantum_task.assert_called_once_with(**correct_kwargs) + + +def test_warning_not_triggered_wrong_association_type(): + kwargs = { + "deviceArn": DEVICE_ARN, + "shots": 1, + "associations": [{"type": "OTHER_TYPE"}], + } + with patch("boto3.client"): + mock_client = MagicMock() + aws_session = AwsSession(braket_client=mock_client) + aws_session.create_quantum_task(**kwargs) + mock_client.create_quantum_task.assert_called_once_with(**kwargs) diff --git a/test/unit_tests/braket/circuits/test_angled_gate.py b/test/unit_tests/braket/circuits/test_angled_gate.py index 4e093e5b4..4c5252c02 100644 --- a/test/unit_tests/braket/circuits/test_angled_gate.py +++ b/test/unit_tests/braket/circuits/test_angled_gate.py @@ -15,7 +15,7 @@ import numpy as np import pytest -from pydantic import BaseModel +from pydantic.v1 import BaseModel from braket.circuits import AngledGate, FreeParameter, FreeParameterExpression, Gate from braket.circuits.angled_gate import DoubleAngledGate, TripleAngledGate @@ -131,7 +131,7 @@ def test_np_float_angle_json(): angled_gate = AngledGate(angle=np.float32(0.15), qubit_count=1, ascii_symbols=["foo"]) angled_gate_json = BaseModel.construct(target=[0], angle=angled_gate.angle).json() match = re.match(r'\{"target": \[0], "angle": (\d*\.?\d*)}', angled_gate_json) - angle_value = float(match.group(1)) + angle_value = float(match[1]) assert angle_value == angled_gate.angle diff --git a/test/unit_tests/braket/circuits/test_ascii_circuit_diagram.py b/test/unit_tests/braket/circuits/test_ascii_circuit_diagram.py index 916bfb050..aec875568 100644 --- a/test/unit_tests/braket/circuits/test_ascii_circuit_diagram.py +++ b/test/unit_tests/braket/circuits/test_ascii_circuit_diagram.py @@ -20,12 +20,17 @@ FreeParameter, Gate, Instruction, + Noise, Observable, Operator, ) from braket.pulse import Frame, Port, PulseSequence +def _assert_correct_diagram(circ, expected): + assert AsciiCircuitDiagram.build_diagram(circ) == "\n".join(expected) + + def test_empty_circuit(): assert AsciiCircuitDiagram.build_diagram(Circuit()) == "" @@ -787,10 +792,6 @@ def test_pulse_gate_multi_qubit_circuit(): _assert_correct_diagram(circ, expected) -def _assert_correct_diagram(circ, expected): - assert AsciiCircuitDiagram.build_diagram(circ) == "\n".join(expected) - - def test_circuit_with_nested_target_list(): circ = ( Circuit() @@ -872,3 +873,86 @@ def __init__(self): "T : | 0 | 1 | 2 | 3 | 4 |", ) _assert_correct_diagram(circ, expected) + + +def test_measure(): + circ = Circuit().h(0).cnot(0, 1).measure([0]) + expected = ( + "T : |0|1|2|", + " ", + "q0 : -H-C-M-", + " | ", + "q1 : ---X---", + "", + "T : |0|1|2|", + ) + _assert_correct_diagram(circ, expected) + + +def test_measure_multiple_targets(): + circ = Circuit().h(0).cnot(0, 1).cnot(1, 2).cnot(2, 3).measure([0, 2, 3]) + expected = ( + "T : |0|1|2|3|4|", + " ", + "q0 : -H-C-----M-", + " | ", + "q1 : ---X-C-----", + " | ", + "q2 : -----X-C-M-", + " | ", + "q3 : -------X-M-", + "", + "T : |0|1|2|3|4|", + ) + _assert_correct_diagram(circ, expected) + + +def test_measure_multiple_instructions_after(): + circ = ( + Circuit() + .h(0) + .cnot(0, 1) + .cnot(1, 2) + .cnot(2, 3) + .measure(0) + .measure(1) + .h(3) + .cnot(3, 4) + .measure([2, 3]) + ) + expected = ( + "T : |0|1|2|3|4|5|6|", + " ", + "q0 : -H-C-----M-----", + " | ", + "q1 : ---X-C---M-----", + " | ", + "q2 : -----X-C-----M-", + " | ", + "q3 : -------X-H-C-M-", + " | ", + "q4 : -----------X---", + "", + "T : |0|1|2|3|4|5|6|", + ) + _assert_correct_diagram(circ, expected) + + +def test_measure_with_readout_noise(): + circ = ( + Circuit() + .h(0) + .cnot(0, 1) + .apply_readout_noise(Noise.BitFlip(probability=0.1), target_qubits=1) + .measure([0, 1]) + ) + expected = ( + "T : |0| 1 |2|", + " ", + "q0 : -H-C---------M-", + " | ", + "q1 : ---X-BF(0.1)-M-", + "", + "T : |0| 1 |2|", + ) + _assert_correct_diagram(circ, expected) diff --git a/test/unit_tests/braket/circuits/test_circuit.py b/test/unit_tests/braket/circuits/test_circuit.py index 928cff757..71eecd1f1 100644 --- a/test/unit_tests/braket/circuits/test_circuit.py +++ b/test/unit_tests/braket/circuits/test_circuit.py @@ -18,15 +18,17 @@ import braket.ir.jaqcd as jaqcd from braket.circuits import ( - AsciiCircuitDiagram, Circuit, FreeParameter, + FreeParameterExpression, Gate, Instruction, Moments, + Noise, Observable, QubitSet, ResultType, + UnicodeCircuitDiagram, circuit, compiler_directives, gates, @@ -34,6 +36,8 @@ observables, ) from braket.circuits.gate_calibrations import GateCalibrations +from braket.circuits.measure import Measure +from braket.circuits.noises import BitFlip from braket.circuits.parameterizable import Parameterizable from braket.circuits.serialization import ( IRType, @@ -147,6 +151,25 @@ def pulse_sequence_2(predefined_frame_1): ) +@pytest.fixture +def pulse_sequence_3(predefined_frame_1): + return ( + PulseSequence() + .shift_phase( + predefined_frame_1, + FreeParameter("alpha"), + ) + .shift_phase( + predefined_frame_1, + FreeParameter("beta"), + ) + .play( + predefined_frame_1, + DragGaussianWaveform(length=3e-3, sigma=0.4, beta=0.2, id="drag_gauss_wf"), + ) + ) + + @pytest.fixture def gate_calibrations(pulse_sequence, pulse_sequence_2): calibration_key = (Gate.Z(), QubitSet([0, 1])) @@ -179,7 +202,7 @@ def test_repr_result_types(cnot_prob): def test_str(h): - expected = AsciiCircuitDiagram.build_diagram(h) + expected = UnicodeCircuitDiagram.build_diagram(h) assert str(h) == expected @@ -222,9 +245,7 @@ def test_call_one_param_not_bound(): circ = Circuit().h(0).rx(angle=theta, target=1).ry(angle=alpha, target=0) new_circ = circ(theta=1) expected_circ = Circuit().h(0).rx(angle=1, target=1).ry(angle=alpha, target=0) - expected_parameters = set() - expected_parameters.add(alpha) - + expected_parameters = {alpha} assert new_circ == expected_circ and new_circ.parameters == expected_parameters @@ -550,6 +571,326 @@ def test_add_verbatim_box_result_types(): ) +def test_measure(): + circ = Circuit().h(0).cnot(0, 1).measure([0]) + expected = ( + Circuit() + .add_instruction(Instruction(Gate.H(), 0)) + .add_instruction(Instruction(Gate.CNot(), [0, 1])) + .add_instruction(Instruction(Measure(), 0)) + ) + assert circ == expected + + +def test_measure_int(): + circ = Circuit().h(0).cnot(0, 1).measure(0) + expected = ( + Circuit() + .add_instruction(Instruction(Gate.H(), 0)) + .add_instruction(Instruction(Gate.CNot(), [0, 1])) + .add_instruction(Instruction(Measure(), 0)) + ) + assert circ == expected + + +def test_measure_multiple_targets(): + circ = Circuit().h(0).cnot(0, 1).cnot(1, 2).cnot(2, 3).measure([0, 1, 3]) + expected = ( + Circuit() + .add_instruction(Instruction(Gate.H(), 0)) + .add_instruction(Instruction(Gate.CNot(), [0, 1])) + .add_instruction(Instruction(Gate.CNot(), [1, 2])) + .add_instruction(Instruction(Gate.CNot(), [2, 3])) + .add_instruction(Instruction(Measure(), 0)) + .add_instruction(Instruction(Measure(), 1)) + .add_instruction(Instruction(Measure(), 3)) + ) + assert circ == expected + assert circ._measure_targets == [0, 1, 3] + + +def test_measure_with_noise(): + circ = Circuit().x(0).x(1).bit_flip(0, probability=0.1).measure(0) + expected = ( + Circuit() + .add_instruction(Instruction(Gate.X(), 0)) + .add_instruction(Instruction(Gate.X(), 1)) + .add_instruction(Instruction(BitFlip(probability=0.1), 0)) + .add_instruction(Instruction(Measure(), 0)) + ) + assert circ == expected + + +def test_measure_verbatim_box(): + circ = Circuit().add_verbatim_box(Circuit().x(0).x(1)).measure(0) + expected = ( + Circuit() + .add_instruction(Instruction(compiler_directives.StartVerbatimBox())) + .add_instruction(Instruction(Gate.X(), 0)) + .add_instruction(Instruction(Gate.X(), 1)) + .add_instruction(Instruction(compiler_directives.EndVerbatimBox())) + .add_instruction(Instruction(Measure(), 0)) + ) + expected_ir = OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + "bit[1] b;", + "qubit[2] q;", + "#pragma braket verbatim", + "box{", + "x q[0];", + "x q[1];", + "}", + "b[0] = measure q[0];", + ] + ), + inputs={}, + ) + assert circ == expected + assert circ.to_ir("OPENQASM") == expected_ir + + +def test_measure_in_verbatim_subcircuit(): + message = "cannot measure a subcircuit inside a verbatim box." + with pytest.raises(ValueError, match=message): + Circuit().add_verbatim_box(Circuit().x(0).x(1).measure(0)) + + +def test_measure_qubits_out_of_range(): + circ = Circuit().h(0).cnot(0, 1).measure(4) + expected = ( + Circuit() + .add_instruction(Instruction(Gate.H(), 0)) + .add_instruction(Instruction(Gate.CNot(), [0, 1])) + .add_instruction(Instruction(Measure(), 4)) + ) + assert circ == expected + + +def test_measure_empty_circuit(): + circ = Circuit().measure([0, 1, 2]) + expected = ( + Circuit() + .add_instruction(Instruction(Measure(), 0)) + .add_instruction(Instruction(Measure(), 1)) + .add_instruction(Instruction(Measure(), 2)) + ) + assert circ == expected + + +def test_measure_target_input(): + message = "Supplied qubit index, 1.1, must be an integer." + with pytest.raises(TypeError, match=message): + Circuit().h(0).cnot(0, 1).measure(1.1) + + message = "Supplied qubit index, a, must be an integer." + with pytest.raises(TypeError, match=message): + Circuit().h(0).cnot(0, 1).measure(FreeParameter("a")) + + +def test_measure_with_result_types(): + message = "a circuit cannot contain both measure instructions and result types." + with pytest.raises(ValueError, match=message): + Circuit().h(0).sample(observable=Observable.Z(), target=0).measure(0) + + +def test_result_type_with_measure(): + message = "cannot add a result type to a circuit which already contains a measure instruction." + with pytest.raises(ValueError, match=message): + Circuit().h(0).measure(0).sample(observable=Observable.Z(), target=0) + + +def test_measure_with_multiple_measures(): + circ = Circuit().h(0).cnot(0, 1).h(2).measure([0, 1]).measure(2) + expected = ( + Circuit() + .add_instruction(Instruction(Gate.H(), 0)) + .add_instruction(Instruction(Gate.CNot(), [0, 1])) + .add_instruction(Instruction(Gate.H(), 2)) + .add_instruction(Instruction(Measure(), 0)) + .add_instruction(Instruction(Measure(), 1)) + .add_instruction(Instruction(Measure(), 2)) + ) + assert circ == expected + + +def test_measure_same_qubit_twice(): + # message = "cannot measure the same qubit\\(s\\) Qubit\\(0\\) more than once." + message = "cannot apply instruction to measured qubits." + with pytest.raises(ValueError, match=message): + Circuit().h(0).cnot(0, 1).measure(0).measure(1).measure(0) + + +def test_measure_same_qubit_twice_with_list(): + # message = "cannot measure the same qubit\\(s\\) Qubit\\(0\\) more than once." + message = "cannot apply instruction to measured qubits." + with pytest.raises(ValueError, match=message): + Circuit().h(0).cnot(0, 1).measure(0).measure([0, 1]) + + +def test_measure_same_qubit_twice_with_one_measure(): + message = "cannot repeat qubit\\(s\\) 0 in the same measurement." + with pytest.raises(ValueError, match=message): + Circuit().h(0).cnot(0, 1).measure([0, 0, 0]) + + +def test_measure_gate_after(): + # message = "cannot add a gate or noise operation on a qubit after a measure instruction." + message = "cannot apply instruction to measured qubits." + with pytest.raises(ValueError, match=message): + Circuit().h(0).measure(0).h([0, 1]) + + # message = "cannot add a gate or noise operation on a qubit after a measure instruction." + message = "cannot apply instruction to measured qubits." + with pytest.raises(ValueError, match=message): + instr = Instruction(Gate.CNot(), [0, 1]) + Circuit().measure([0, 1]).add_instruction(instr, target_mapping={0: 0, 1: 1}) + + # message = "cannot add a gate or noise operation on a qubit after a measure instruction." + message = "cannot apply instruction to measured qubits." + with pytest.raises(ValueError, match=message): + instr = Instruction(Gate.CNot(), [0, 1]) + Circuit().h(0).measure(0).add_instruction(instr, target=[0, 1]) + + +def test_measure_noise_after(): + # message = "cannot add a gate or noise operation on a qubit after a measure instruction." + message = "cannot apply instruction to measured qubits." + with pytest.raises(ValueError, match=message): + Circuit().h(1).h(1).h(2).h(5).h(4).h(3).cnot(1, 2).measure([0, 1, 2, 3, 4]).kraus( + targets=[0], matrices=[np.array([[1, 0], [0, 1]])] + ) + + +def test_measure_with_readout_noise(): + circ = ( + Circuit() + .h(0) + .cnot(0, 1) + .apply_readout_noise(Noise.BitFlip(probability=0.1), target_qubits=1) + .measure([0, 1]) + ) + expected = ( + Circuit() + .add_instruction(Instruction(Gate.H(), 0)) + .add_instruction(Instruction(Gate.CNot(), [0, 1])) + .apply_readout_noise(Noise.BitFlip(probability=0.1), target_qubits=1) + .add_instruction(Instruction(Measure(), 0)) + .add_instruction(Instruction(Measure(), 1)) + ) + assert circ == expected + + +def test_measure_gate_after_with_target_mapping(): + # message = "cannot add a gate or noise operation on a qubit after a measure instruction." + message = "cannot apply instruction to measured qubits." + instr = Instruction(Gate.CNot(), [0, 1]) + with pytest.raises(ValueError, match=message): + Circuit().h(0).cnot(0, 1).cnot(1, 2).measure([0, 1]).add_instruction( + instr, target_mapping={0: 10, 1: 11} + ) + + +def test_measure_gate_after_with_target(): + # message = "cannot add a gate or noise operation on a qubit after a measure instruction." + message = "cannot apply instruction to measured qubits." + instr = Instruction(Gate.CNot(), [0, 1]) + with pytest.raises(ValueError, match=message): + Circuit().h(0).cnot(0, 1).cnot(1, 2).measure([0, 1]).add_instruction(instr, target=[10, 11]) + + +def test_measure_gate_after_measurement(): + circ = Circuit().h(0).cnot(0, 1).cnot(1, 2).measure(0).h(2) + expected = ( + Circuit() + .add_instruction(Instruction(Gate.H(), 0)) + .add_instruction(Instruction(Gate.CNot(), [0, 1])) + .add_instruction(Instruction(Gate.CNot(), [1, 2])) + .add_instruction(Instruction(Measure(), 0)) + .add_instruction(Instruction(Gate.H(), 2)) + ) + assert circ == expected + + +def test_to_ir_with_measure(): + circ = Circuit().h(0).cnot(0, 1).cnot(1, 2).measure([0, 2]) + expected_ir = OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + "bit[2] b;", + "qubit[3] q;", + "h q[0];", + "cnot q[0], q[1];", + "cnot q[1], q[2];", + "b[0] = measure q[0];", + "b[1] = measure q[2];", + ] + ), + inputs={}, + ) + assert circ.to_ir("OPENQASM") == expected_ir + + +def test_from_ir_with_measure(): + ir = OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + "bit[1] b;", + "qubit[3] q;", + "h q[0];", + "cnot q[0], q[1];", + "cnot q[1], q[2];", + "b[0] = measure q[0];", + "b[1] = measure q[2];", + ] + ), + inputs={}, + ) + expected_circ = Circuit().h(0).cnot(0, 1).cnot(1, 2).measure(0).measure(2) + assert Circuit.from_ir(source=ir.source, inputs=ir.inputs) == expected_circ + + +def test_from_ir_with_single_measure(): + ir = OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + "bit[2] b;", + "qubit[2] q;", + "h q[0];", + "cnot q[0], q[1];", + "b = measure q;", + ] + ), + inputs={}, + ) + expected_circ = Circuit().h(0).cnot(0, 1).measure(0).measure(1) + assert Circuit.from_ir(source=ir.source, inputs=ir.inputs) == expected_circ + + +def test_from_ir_round_trip_transformation(): + circuit = Circuit().h(0).cnot(0, 1).measure(0).measure(1) + ir = OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + "bit[2] b;", + "qubit[2] q;", + "h q[0];", + "cnot q[0], q[1];", + "b = measure q;", + ] + ), + inputs={}, + ) + + assert Circuit.from_ir(ir) == Circuit.from_ir(circuit.to_ir("OPENQASM")) + assert circuit.to_ir("OPENQASM") == Circuit.from_ir(ir).to_ir("OPENQASM") + + def test_add_with_instruction_with_default(cnot_instr): circ = Circuit().add(cnot_instr) assert circ == Circuit().add_instruction(cnot_instr) @@ -726,6 +1067,44 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): assert circ.to_ir() == expected +@pytest.mark.parametrize( + "circuit, serialization_properties, expected_ir", + [ + ( + Circuit() + .rx(0, 0.15) + .ry(1, FreeParameterExpression("0.3")) + .rx(2, 3 * FreeParameterExpression(1)), + OpenQASMSerializationProperties(QubitReferenceType.VIRTUAL), + OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + "bit[3] b;", + "qubit[3] q;", + "rx(0.15) q[0];", + "ry(0.3) q[1];", + "rx(3) q[2];", + "b[0] = measure q[0];", + "b[1] = measure q[1];", + "b[2] = measure q[2];", + ] + ), + inputs={}, + ), + ), + ], +) +def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir): + assert ( + circuit.to_ir( + ir_type=IRType.OPENQASM, + serialization_properties=serialization_properties, + ) + == expected_ir + ) + + @pytest.mark.parametrize( "circuit, serialization_properties, expected_ir", [ @@ -740,7 +1119,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "qubit[2] q;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -769,7 +1148,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "bit[2] b;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -800,7 +1179,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "OPENQASM 3.0;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -835,7 +1214,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "qubit[5] q;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -866,7 +1245,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "qubit[2] q;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -899,7 +1278,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "qubit[5] q;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -933,7 +1312,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "qubit[7] q;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -965,7 +1344,7 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): "qubit[2] q;", "cal {", " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + + "(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -1000,7 +1379,9 @@ def test_ir_non_empty_instructions_result_types_basis_rotation_instructions(): ), ], ) -def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir, gate_calibrations): +def test_circuit_to_ir_openqasm_with_gate_calibrations( + circuit, serialization_properties, expected_ir, gate_calibrations +): copy_of_gate_calibrations = gate_calibrations.copy() assert ( circuit.to_ir( @@ -1013,6 +1394,55 @@ def test_circuit_to_ir_openqasm(circuit, serialization_properties, expected_ir, assert copy_of_gate_calibrations.pulse_sequences == gate_calibrations.pulse_sequences +@pytest.mark.parametrize( + "circuit, calibration_key, expected_ir", + [ + ( + Circuit().rx(0, 0.2), + (Gate.Rx(FreeParameter("alpha")), QubitSet(0)), + OpenQasmProgram( + source="\n".join( + [ + "OPENQASM 3.0;", + "input float beta;", + "bit[1] b;", + "qubit[1] q;", + "cal {", + " waveform drag_gauss_wf = drag_gaussian(3.0ms," + " 400.0ms, 0.2, 1, false);", + "}", + "defcal rx(0.2) $0 {", + " shift_phase(predefined_frame_1, 0.2);", + " shift_phase(predefined_frame_1, beta);", + " play(predefined_frame_1, drag_gauss_wf);", + "}", + "rx(0.2) q[0];", + "b[0] = measure q[0];", + ] + ), + inputs={}, + ), + ), + ], +) +def test_circuit_with_parametric_defcal(circuit, calibration_key, expected_ir, pulse_sequence_3): + serialization_properties = OpenQASMSerializationProperties(QubitReferenceType.VIRTUAL) + gate_calibrations = GateCalibrations( + { + calibration_key: pulse_sequence_3, + } + ) + + assert ( + circuit.to_ir( + ir_type=IRType.OPENQASM, + serialization_properties=serialization_properties, + gate_definitions=gate_calibrations.pulse_sequences, + ) + == expected_ir + ) + + def test_parametric_circuit_with_fixed_argument_defcal(pulse_sequence): circ = Circuit().h(0, power=-2.5).h(0, power=0).rx(0, angle=FreeParameter("theta")) serialization_properties = OpenQASMSerializationProperties(QubitReferenceType.VIRTUAL) @@ -1033,8 +1463,7 @@ def test_parametric_circuit_with_fixed_argument_defcal(pulse_sequence): "bit[1] b;", "qubit[1] q;", "cal {", - " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal z $0, $1 {", " set_frequency(predefined_frame_1, 6000000.0);", @@ -1131,8 +1560,7 @@ def foo( "bit[1] b;", "qubit[1] q;", "cal {", - " waveform drag_gauss_wf = drag_gaussian" - + "(3000000.0ns, 400000000.0ns, 0.2, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", "}", "defcal foo(-0.2) $0 {", " shift_phase(predefined_frame_1, -0.1);", @@ -1161,7 +1589,7 @@ def foo( "expected_circuit, ir", [ ( - Circuit().h(0, control=1, control_state=0), + Circuit().h(0, control=1, control_state=0).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1177,7 +1605,7 @@ def foo( ), ), ( - Circuit().cnot(target=0, control=1), + Circuit().cnot(target=0, control=1).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1193,7 +1621,7 @@ def foo( ), ), ( - Circuit().x(0, control=[1], control_state=[0]), + Circuit().x(0, control=[1], control_state=[0]).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1209,7 +1637,7 @@ def foo( ), ), ( - Circuit().rx(0, 0.15, control=1, control_state=1), + Circuit().rx(0, 0.15, control=1, control_state=1).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1225,7 +1653,7 @@ def foo( ), ), ( - Circuit().ry(0, 0.2, control=1, control_state=1), + Circuit().ry(0, 0.2, control=1, control_state=1).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1241,7 +1669,7 @@ def foo( ), ), ( - Circuit().rz(0, 0.25, control=[1], control_state=[0]), + Circuit().rz(0, 0.25, control=[1], control_state=[0]).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1257,7 +1685,7 @@ def foo( ), ), ( - Circuit().s(target=0, control=[1], control_state=[0]), + Circuit().s(target=0, control=[1], control_state=[0]).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1273,7 +1701,7 @@ def foo( ), ), ( - Circuit().t(target=1, control=[0], control_state=[0]), + Circuit().t(target=1, control=[0], control_state=[0]).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1289,7 +1717,7 @@ def foo( ), ), ( - Circuit().cphaseshift(target=0, control=1, angle=0.15), + Circuit().cphaseshift(target=0, control=1, angle=0.15).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1305,7 +1733,7 @@ def foo( ), ), ( - Circuit().ccnot(*[0, 1], target=2), + Circuit().ccnot(*[0, 1], target=2).measure(0).measure(1).measure(2), OpenQasmProgram( source="\n".join( [ @@ -1392,7 +1820,7 @@ def foo( ), ), ( - Circuit().bit_flip(0, 0.1), + Circuit().bit_flip(0, 0.1).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1407,7 +1835,7 @@ def foo( ), ), ( - Circuit().generalized_amplitude_damping(0, 0.1, 0.1), + Circuit().generalized_amplitude_damping(0, 0.1, 0.1).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1422,7 +1850,7 @@ def foo( ), ), ( - Circuit().phase_flip(0, 0.2), + Circuit().phase_flip(0, 0.2).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1437,7 +1865,7 @@ def foo( ), ), ( - Circuit().depolarizing(0, 0.5), + Circuit().depolarizing(0, 0.5).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1452,7 +1880,7 @@ def foo( ), ), ( - Circuit().amplitude_damping(0, 0.8), + Circuit().amplitude_damping(0, 0.8).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1467,7 +1895,7 @@ def foo( ), ), ( - Circuit().phase_damping(0, 0.1), + Circuit().phase_damping(0, 0.1).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1499,7 +1927,12 @@ def foo( Circuit() .rx(0, 0.15, control=2, control_state=0) .rx(1, 0.3, control=[2, 3]) - .cnot(target=0, control=[2, 3, 4]), + .cnot(target=0, control=[2, 3, 4]) + .measure(0) + .measure(1) + .measure(2) + .measure(3) + .measure(4), OpenQasmProgram( source="\n".join( [ @@ -1520,7 +1953,17 @@ def foo( ), ), ( - Circuit().cnot(0, 1).cnot(target=2, control=3).cnot(target=4, control=[5, 6]), + Circuit() + .cnot(0, 1) + .cnot(target=2, control=3) + .cnot(target=4, control=[5, 6]) + .measure(0) + .measure(1) + .measure(2) + .measure(3) + .measure(4) + .measure(5) + .measure(6), OpenQasmProgram( source="\n".join( [ @@ -1543,7 +1986,7 @@ def foo( ), ), ( - Circuit().h(0, power=-2.5).h(0, power=0), + Circuit().h(0, power=-2.5).h(0, power=0).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1559,7 +2002,7 @@ def foo( ), ), ( - Circuit().unitary(matrix=np.array([[0, 1], [1, 0]]), targets=[0]), + Circuit().unitary(matrix=np.array([[0, 1], [1, 0]]), targets=[0]).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1574,7 +2017,7 @@ def foo( ), ), ( - Circuit().pauli_channel(0, probX=0.1, probY=0.2, probZ=0.3), + Circuit().pauli_channel(0, probX=0.1, probY=0.2, probZ=0.3).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1589,7 +2032,7 @@ def foo( ), ), ( - Circuit().two_qubit_depolarizing(0, 1, probability=0.1), + Circuit().two_qubit_depolarizing(0, 1, probability=0.1).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1605,7 +2048,7 @@ def foo( ), ), ( - Circuit().two_qubit_dephasing(0, 1, probability=0.1), + Circuit().two_qubit_dephasing(0, 1, probability=0.1).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1621,7 +2064,7 @@ def foo( ), ), ( - Circuit().two_qubit_dephasing(0, 1, probability=0.1), + Circuit().two_qubit_dephasing(0, 1, probability=0.1).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1680,13 +2123,15 @@ def foo( ), ), ( - Circuit().kraus( + Circuit() + .kraus( [0], matrices=[ np.array([[0.9486833j, 0], [0, 0.9486833j]]), np.array([[0, 0.31622777], [0.31622777, 0]]), ], - ), + ) + .measure(0), OpenQasmProgram( source="\n".join( [ @@ -1703,7 +2148,7 @@ def foo( ), ), ( - Circuit().rx(0, FreeParameter("theta")), + Circuit().rx(0, FreeParameter("theta")).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1719,7 +2164,7 @@ def foo( ), ), ( - Circuit().rx(0, np.pi), + Circuit().rx(0, np.pi).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1734,7 +2179,7 @@ def foo( ), ), ( - Circuit().rx(0, 2 * np.pi), + Circuit().rx(0, 2 * np.pi).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1749,7 +2194,7 @@ def foo( ), ), ( - Circuit().gphase(0.15).x(0), + Circuit().gphase(0.15).x(0).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1772,7 +2217,7 @@ def test_from_ir(expected_circuit, ir): def test_from_ir_inputs_updated(): - circuit = Circuit().rx(0, 0.2).ry(0, 0.1) + circuit = Circuit().rx(0, 0.2).ry(0, 0.1).measure(0) openqasm = OpenQasmProgram( source="\n".join( [ @@ -1795,7 +2240,7 @@ def test_from_ir_inputs_updated(): "expected_circuit, ir", [ ( - Circuit().h(0).cnot(0, 1), + Circuit().h(0).cnot(0, 1).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1815,7 +2260,7 @@ def test_from_ir_inputs_updated(): ), ), ( - Circuit().h(0).h(1), + Circuit().h(0).h(1).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1835,7 +2280,7 @@ def test_from_ir_inputs_updated(): ), ), ( - Circuit().h(0).h(1).cnot(0, 1), + Circuit().h(0).h(1).cnot(0, 1).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1854,7 +2299,7 @@ def test_from_ir_inputs_updated(): ), ), ( - Circuit().h(0).h(1).cnot(0, 1), + Circuit().h(0).h(1).cnot(0, 1).measure(0).measure(1), OpenQasmProgram( source="\n".join( [ @@ -1873,7 +2318,7 @@ def test_from_ir_inputs_updated(): ), ), ( - Circuit().x(0), + Circuit().x(0).measure(0), OpenQasmProgram( source="\n".join( [ @@ -1891,12 +2336,13 @@ def test_from_ir_inputs_updated(): ), ), ( - Circuit().rx(0, FreeParameter("theta")).rx(0, 2 * FreeParameter("theta")), + Circuit().rx(0, FreeParameter("theta")).rx(0, 2 * FreeParameter("theta")).measure(0), OpenQasmProgram( source="\n".join( [ "OPENQASM 3.0;", - "input float theta;" "bit[1] b;", + "input float theta;", + "bit[1] b;", "qubit[1] q;", "rx(theta) q[0];", "rx(2*theta) q[0];", @@ -2088,6 +2534,7 @@ def test_to_unitary_with_global_phase(): (Circuit().cphaseshift00(0, 1, 0.15), gates.CPhaseShift00(0.15).to_matrix()), (Circuit().cphaseshift01(0, 1, 0.15), gates.CPhaseShift01(0.15).to_matrix()), (Circuit().cphaseshift10(0, 1, 0.15), gates.CPhaseShift10(0.15).to_matrix()), + (Circuit().prx(0, 1, 0.15), gates.PRx(1, 0.15).to_matrix()), (Circuit().cy(0, 1), gates.CY().to_matrix()), (Circuit().cz(0, 1), gates.CZ().to_matrix()), (Circuit().xx(0, 1, 0.15), gates.XX(0.15).to_matrix()), @@ -2770,9 +3217,7 @@ def test_add_parameterized_check_true(): .ry(angle=theta, target=2) .ry(angle=theta, target=3) ) - expected = set() - expected.add(theta) - + expected = {theta} assert circ.parameters == expected @@ -2782,10 +3227,7 @@ def test_add_parameterized_instr_parameterized_circ_check_true(): alpha2 = FreeParameter("alpha") circ = Circuit().ry(angle=theta, target=0).ry(angle=alpha2, target=1).ry(angle=theta, target=2) circ.add_instruction(Instruction(Gate.Ry(alpha), 3)) - expected = set() - expected.add(theta) - expected.add(alpha) - + expected = {theta, alpha} assert circ.parameters == expected @@ -2793,9 +3235,7 @@ def test_add_non_parameterized_instr_parameterized_check_true(): theta = FreeParameter("theta") circ = Circuit().ry(angle=theta, target=0).ry(angle=theta, target=1).ry(angle=theta, target=2) circ.add_instruction(Instruction(Gate.Ry(0.1), 3)) - expected = set() - expected.add(theta) - + expected = {theta} assert circ.parameters == expected @@ -2803,9 +3243,7 @@ def test_add_circ_parameterized_check_true(): theta = FreeParameter("theta") circ = Circuit().ry(angle=1, target=0).add_circuit(Circuit().ry(angle=theta, target=0)) - expected = set() - expected.add(theta) - + expected = {theta} assert circ.parameters == expected @@ -2813,9 +3251,7 @@ def test_add_circ_not_parameterized_check_true(): theta = FreeParameter("theta") circ = Circuit().ry(angle=theta, target=0).add_circuit(Circuit().ry(angle=0.1, target=0)) - expected = set() - expected.add(theta) - + expected = {theta} assert circ.parameters == expected @@ -2836,9 +3272,7 @@ def test_parameterized_check_false(input_circ): def test_parameters(): theta = FreeParameter("theta") circ = Circuit().ry(angle=theta, target=0).ry(angle=theta, target=1).ry(angle=theta, target=2) - expected = set() - expected.add(theta) - + expected = {theta} assert circ.parameters == expected @@ -2896,9 +3330,7 @@ def test_make_bound_circuit_partial_bind(): expected_circ = ( Circuit().ry(angle=np.pi, target=0).ry(angle=np.pi, target=1).ry(angle=alpha, target=2) ) - expected_parameters = set() - expected_parameters.add(alpha) - + expected_parameters = {alpha} assert circ_new == expected_circ and circ_new.parameters == expected_parameters @@ -3003,11 +3435,9 @@ def test_pulse_circuit_to_openqasm(predefined_frame_1, user_defined_frame): "bit[2] b;", "cal {", " frame user_defined_frame_0 = newframe(device_port_x0, 10000000.0, 3.14);", - " waveform gauss_wf = gaussian(1000000.0ns, 700000000.0ns, 1, false);", - " waveform drag_gauss_wf = drag_gaussian(3000000.0ns, 400000000.0ns, 0.2, 1," - " false);", - " waveform drag_gauss_wf_2 = drag_gaussian(3000000.0ns, 400000000.0ns, " - "0.2, 1, false);", + " waveform gauss_wf = gaussian(1.0ms, 700.0ms, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1," " false);", + " waveform drag_gauss_wf_2 = drag_gaussian(3.0ms, 400.0ms, " "0.2, 1, false);", "}", "h $0;", "cal {", @@ -3105,7 +3535,7 @@ def test_parametrized_pulse_circuit(user_defined_frame): Circuit().rx(angle=theta, target=0).pulse_gate(pulse_sequence=pulse_sequence, targets=1) ) - assert circuit.parameters == set([frequency_parameter, length, theta]) + assert circuit.parameters == {frequency_parameter, length, theta} bound_half = circuit(theta=0.5, length=1e-5) assert bound_half.to_ir( @@ -3120,7 +3550,7 @@ def test_parametrized_pulse_circuit(user_defined_frame): "bit[2] b;", "cal {", " frame user_defined_frame_0 = newframe(device_port_x0, 10000000.0, 3.14);", - " waveform gauss_wf = gaussian(10000.0ns, 700000000.0ns, 1, false);", + " waveform gauss_wf = gaussian(10.0us, 700.0ms, 1, false);", "}", "rx(0.5) $0;", "cal {", @@ -3145,7 +3575,7 @@ def test_parametrized_pulse_circuit(user_defined_frame): "bit[2] b;", "cal {", " frame user_defined_frame_0 = newframe(device_port_x0, 10000000.0, 3.14);", - " waveform gauss_wf = gaussian(10000.0ns, 700000000.0ns, 1, false);", + " waveform gauss_wf = gaussian(10.0us, 700.0ms, 1, false);", "}", "rx(0.5) $0;", "cal {", diff --git a/test/unit_tests/braket/circuits/test_gate_calibration.py b/test/unit_tests/braket/circuits/test_gate_calibration.py index c95ce74a3..de3fa2fc6 100644 --- a/test/unit_tests/braket/circuits/test_gate_calibration.py +++ b/test/unit_tests/braket/circuits/test_gate_calibration.py @@ -91,7 +91,7 @@ def test_to_ir(pulse_sequence): "OPENQASM 3.0;", "defcal rx(1.0) $0, $1 {", " barrier test_frame_rf;", - " delay[1000000000000.0ns] test_frame_rf;", + " delay[1000s] test_frame_rf;", "}", ] ) @@ -111,7 +111,7 @@ def test_to_ir_with_bad_key(pulse_sequence): "OPENQASM 3.0;", "defcal z $0, $1 {", " barrier test_frame_rf;", - " delay[1000000000000.0ns] test_frame_rf;", + " delay[1000s] test_frame_rf;", "}", ] ) @@ -129,7 +129,7 @@ def test_to_ir_with_key(pulse_sequence): "OPENQASM 3.0;", "defcal z $0, $1 {", " barrier test_frame_rf;", - " delay[1000000000000.0ns] test_frame_rf;", + " delay[1000s] test_frame_rf;", "}", ] ) diff --git a/test/unit_tests/braket/circuits/test_gates.py b/test/unit_tests/braket/circuits/test_gates.py index fc8fe7787..0b9ce52d7 100644 --- a/test/unit_tests/braket/circuits/test_gates.py +++ b/test/unit_tests/braket/circuits/test_gates.py @@ -39,6 +39,10 @@ class NoTarget: pass +class DoubleAngle: + pass + + class TripleAngle: pass @@ -103,6 +107,7 @@ class SingleNegControlModifier: (Gate.ZZ, "zz", ir.ZZ, [DoubleTarget, Angle], {}), (Gate.GPi, "gpi", None, [SingleTarget, Angle], {}), (Gate.GPi2, "gpi2", None, [SingleTarget, Angle], {}), + (Gate.PRx, "prx", None, [SingleTarget, DoubleAngle], {}), (Gate.MS, "ms", None, [DoubleTarget, TripleAngle], {}), ( Gate.Unitary, @@ -145,9 +150,11 @@ class SingleNegControlModifier: Gate.CPhaseShift10, Gate.GPi, Gate.GPi2, + Gate.PRx, Gate.MS, ] + invalid_unitary_matrices = [ (np.array([[1]])), (np.array([1])), @@ -179,6 +186,10 @@ def angle_valid_input(**kwargs): return {"angle": 0.123} +def double_angle_valid_input(**kwargs): + return {"angle_1": 0.123, "angle_2": 3.567} + + def triple_angle_valid_input(**kwargs): return {"angle_1": 0.123, "angle_2": 4.567, "angle_3": 8.910} @@ -217,6 +228,7 @@ def two_dimensional_matrix_valid_input(**kwargs): "SingleTarget": single_target_valid_input, "DoubleTarget": double_target_valid_ir_input, "Angle": angle_valid_input, + "DoubleAngle": double_angle_valid_input, "TripleAngle": triple_angle_valid_input, "SingleControl": single_control_valid_input, "SingleNegControlModifier": single_neg_control_valid_input, @@ -238,21 +250,20 @@ def two_dimensional_matrix_valid_input(**kwargs): def create_valid_ir_input(irsubclasses): input = {} for subclass in irsubclasses: - input.update(valid_ir_switcher.get(subclass.__name__, lambda: "Invalid subclass")()) + input |= valid_ir_switcher.get(subclass.__name__, lambda: "Invalid subclass")() return input def create_valid_subroutine_input(irsubclasses, **kwargs): input = {} for subclass in irsubclasses: - input.update( - valid_subroutine_switcher.get(subclass.__name__, lambda: "Invalid subclass")(**kwargs) + input |= valid_subroutine_switcher.get(subclass.__name__, lambda: "Invalid subclass")( + **kwargs ) return input def create_valid_target_input(irsubclasses): - input = {} qubit_set = [] control_qubit_set = [] control_state = None @@ -273,11 +284,9 @@ def create_valid_target_input(irsubclasses): control_state = list(single_neg_control_valid_input()["control_state"]) elif subclass == DoubleControl: qubit_set = list(double_control_valid_ir_input().values()) + qubit_set - elif subclass in (Angle, TwoDimensionalMatrix, TripleAngle): - pass - else: + elif subclass not in (Angle, TwoDimensionalMatrix, DoubleAngle, TripleAngle): raise ValueError("Invalid subclass") - input["target"] = QubitSet(qubit_set) + input = {"target": QubitSet(qubit_set)} input["control"] = QubitSet(control_qubit_set) input["control_state"] = control_state return input @@ -287,6 +296,8 @@ def create_valid_gate_class_input(irsubclasses, **kwargs): input = {} if Angle in irsubclasses: input.update(angle_valid_input()) + if DoubleAngle in irsubclasses: + input.update(double_angle_valid_input()) if TripleAngle in irsubclasses: input.update(triple_angle_valid_input()) if TwoDimensionalMatrix in irsubclasses: @@ -313,9 +324,13 @@ def calculate_qubit_count(irsubclasses): qubit_count += 2 elif subclass == MultiTarget: qubit_count += 3 - elif subclass in (NoTarget, Angle, TwoDimensionalMatrix, TripleAngle): - pass - else: + elif subclass not in ( + NoTarget, + Angle, + TwoDimensionalMatrix, + DoubleAngle, + TripleAngle, + ): raise ValueError("Invalid subclass") return qubit_count @@ -847,6 +862,18 @@ def test_ir_gate_level(testclass, subroutine_name, irclass, irsubclasses, kwargs OpenQASMSerializationProperties(qubit_reference_type=QubitReferenceType.PHYSICAL), "gpi2(0.17) $4;", ), + ( + Gate.PRx(angle_1=0.17, angle_2=3.45), + [4], + OpenQASMSerializationProperties(qubit_reference_type=QubitReferenceType.VIRTUAL), + "prx(0.17, 3.45) q[4];", + ), + ( + Gate.PRx(angle_1=0.17, angle_2=3.45), + [4], + OpenQASMSerializationProperties(qubit_reference_type=QubitReferenceType.PHYSICAL), + "prx(0.17, 3.45) $4;", + ), ( Gate.MS(angle_1=0.17, angle_2=3.45), [4, 5], @@ -890,18 +917,19 @@ def test_gate_subroutine(testclass, subroutine_name, irclass, irsubclasses, kwar ) if qubit_count == 1: multi_targets = [0, 1, 2] - instruction_list = [] - for target in multi_targets: - instruction_list.append( - Instruction( - operator=testclass(**create_valid_gate_class_input(irsubclasses, **kwargs)), - target=target, - ) + instruction_list = [ + Instruction( + operator=testclass(**create_valid_gate_class_input(irsubclasses, **kwargs)), + target=target, ) + for target in multi_targets + ] subroutine = getattr(Circuit(), subroutine_name) subroutine_input = {"target": multi_targets} if Angle in irsubclasses: subroutine_input.update(angle_valid_input()) + if DoubleAngle in irsubclasses: + subroutine_input.update(double_angle_valid_input()) if TripleAngle in irsubclasses: subroutine_input.update(triple_angle_valid_input()) assert subroutine(**subroutine_input) == Circuit(instruction_list) @@ -1012,8 +1040,13 @@ def test_large_unitary(): @pytest.mark.parametrize("gate", parameterizable_gates) def test_bind_values(gate): + double_angled = gate.__name__ in ["PRx"] triple_angled = gate.__name__ in ("MS", "U") - num_params = 3 if triple_angled else 1 + num_params = 1 + if triple_angled: + num_params = 3 + elif double_angled: + num_params = 2 thetas = [FreeParameter(f"theta_{i}") for i in range(num_params)] mapping = {f"theta_{i}": i for i in range(num_params)} param_gate = gate(*thetas) @@ -1024,6 +1057,9 @@ def test_bind_values(gate): if triple_angled: for angle in new_gate.angle_1, new_gate.angle_2, new_gate.angle_3: assert isinstance(angle, float) + elif double_angled: + for angle in new_gate.angle_1, new_gate.angle_2: + assert isinstance(angle, float) else: assert isinstance(new_gate.angle, float) @@ -1048,8 +1084,8 @@ def to_ir(pulse_gate): assert a_bound_ir == "\n".join( [ "cal {", - " set_frequency(user_frame, b + 3);", - " delay[(1000000000.0*c)ns] user_frame;", + " set_frequency(user_frame, 3.0 + b);", + " delay[c * 1s] user_frame;", "}", ] ) diff --git a/test/unit_tests/braket/circuits/test_instruction.py b/test/unit_tests/braket/circuits/test_instruction.py index 212e65949..1d04627e9 100644 --- a/test/unit_tests/braket/circuits/test_instruction.py +++ b/test/unit_tests/braket/circuits/test_instruction.py @@ -107,14 +107,11 @@ def test_adjoint_unsupported(): def test_str(instr): expected = ( - "Instruction('operator': {}, 'target': {}, " - "'control': {}, 'control_state': {}, 'power': {})" - ).format( - instr.operator, - instr.target, - instr.control, - instr.control_state.as_tuple, - instr.power, + f"Instruction('operator': {instr.operator}, " + f"'target': {instr.target}, " + f"'control': {instr.control}, " + f"'control_state': {instr.control_state.as_tuple}, " + f"'power': {instr.power})" ) assert str(instr) == expected diff --git a/test/unit_tests/braket/circuits/test_measure.py b/test/unit_tests/braket/circuits/test_measure.py new file mode 100644 index 000000000..8911da5e4 --- /dev/null +++ b/test/unit_tests/braket/circuits/test_measure.py @@ -0,0 +1,100 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import pytest + +from braket.circuits.measure import Measure +from braket.circuits.quantum_operator import QuantumOperator +from braket.circuits.serialization import ( + IRType, + OpenQASMSerializationProperties, + QubitReferenceType, +) + + +@pytest.fixture +def measure(): + return Measure() + + +def test_is_operator(measure): + assert isinstance(measure, QuantumOperator) + + +def test_equality(): + measure1 = Measure() + measure2 = Measure() + non_measure = "non measure" + + assert measure1 == measure2 + assert measure1 is not measure2 + assert measure1 != non_measure + + +def test_ascii_symbols(measure): + assert measure.ascii_symbols == ("M",) + + +def test_str(measure): + assert str(measure) == measure.name + + +@pytest.mark.parametrize( + "ir_type, serialization_properties, expected_exception, expected_message", + [ + ( + IRType.JAQCD, + None, + NotImplementedError, + "measure instructions are not supported with JAQCD.", + ), + ("invalid-ir-type", None, ValueError, "supplied ir_type invalid-ir-type is not supported."), + ], +) +def test_measure_to_ir( + ir_type, serialization_properties, expected_exception, expected_message, measure +): + with pytest.raises(expected_exception) as exc: + measure.to_ir(ir_type=ir_type, serialization_properties=serialization_properties) + assert exc.value.args[0] == expected_message + + +@pytest.mark.parametrize( + "measure, target, serialization_properties, expected_ir", + [ + ( + Measure(), + [0], + OpenQASMSerializationProperties(qubit_reference_type=QubitReferenceType.VIRTUAL), + "b[0] = measure q[0];", + ), + ( + Measure(), + [1, 4], + OpenQASMSerializationProperties(qubit_reference_type=QubitReferenceType.PHYSICAL), + "\n".join( + [ + "b[0] = measure $1;", + "b[1] = measure $4;", + ] + ), + ), + ], +) +def test_measure_to_ir_openqasm(measure, target, serialization_properties, expected_ir): + assert ( + measure.to_ir( + target, ir_type=IRType.OPENQASM, serialization_properties=serialization_properties + ) + == expected_ir + ) diff --git a/test/unit_tests/braket/circuits/test_moments.py b/test/unit_tests/braket/circuits/test_moments.py index 982649f62..ed45b0aa3 100644 --- a/test/unit_tests/braket/circuits/test_moments.py +++ b/test/unit_tests/braket/circuits/test_moments.py @@ -153,7 +153,7 @@ def test_getitem(): def test_iter(moments): - assert [key for key in moments] == list(moments.keys()) + assert list(moments) == list(moments.keys()) def test_len(): diff --git a/test/unit_tests/braket/circuits/test_noises.py b/test/unit_tests/braket/circuits/test_noises.py index 2b55dfa4f..5fffba43f 100644 --- a/test/unit_tests/braket/circuits/test_noises.py +++ b/test/unit_tests/braket/circuits/test_noises.py @@ -225,28 +225,27 @@ def multi_probability_invalid_input(**kwargs): "TwoDimensionalMatrixList": two_dimensional_matrix_list_valid_input, "DoubleTarget": double_target_valid_input, "DoubleControl": double_control_valid_input, - } + }, ) def create_valid_ir_input(irsubclasses): input = {} for subclass in irsubclasses: - input.update(valid_ir_switcher.get(subclass.__name__, lambda: "Invalid subclass")()) + input |= valid_ir_switcher.get(subclass.__name__, lambda: "Invalid subclass")() return input def create_valid_subroutine_input(irsubclasses, **kwargs): input = {} for subclass in irsubclasses: - input.update( - valid_subroutine_switcher.get(subclass.__name__, lambda: "Invalid subclass")(**kwargs) + input |= valid_subroutine_switcher.get(subclass.__name__, lambda: "Invalid subclass")( + **kwargs ) return input def create_valid_target_input(irsubclasses): - input = {} qubit_set = [] # based on the concept that control goes first in target input for subclass in irsubclasses: @@ -260,8 +259,8 @@ def create_valid_target_input(irsubclasses): qubit_set = list(single_control_valid_input().values()) + qubit_set elif subclass == DoubleControl: qubit_set = list(double_control_valid_ir_input().values()) + qubit_set - elif any( - subclass == i + elif all( + subclass != i for i in [ SingleProbability, SingleProbability_34, @@ -273,17 +272,15 @@ def create_valid_target_input(irsubclasses): MultiProbability, ] ): - pass - else: raise ValueError("Invalid subclass") - input["target"] = QubitSet(qubit_set) + input = {"target": QubitSet(qubit_set)} return input def create_valid_noise_class_input(irsubclasses, **kwargs): input = {} if SingleProbability in irsubclasses: - input.update(single_probability_valid_input()) + input |= single_probability_valid_input() if SingleProbability_34 in irsubclasses: input.update(single_probability_34_valid_input()) if SingleProbability_1516 in irsubclasses: @@ -320,8 +317,8 @@ def calculate_qubit_count(irsubclasses): qubit_count += 2 elif subclass == MultiTarget: qubit_count += 3 - elif any( - subclass == i + elif all( + subclass != i for i in [ SingleProbability, SingleProbability_34, @@ -333,8 +330,6 @@ def calculate_qubit_count(irsubclasses): TwoDimensionalMatrixList, ] ): - pass - else: raise ValueError("Invalid subclass") return qubit_count @@ -365,18 +360,17 @@ def test_noise_subroutine(testclass, subroutine_name, irclass, irsubclasses, kwa ) if qubit_count == 1: multi_targets = [0, 1, 2] - instruction_list = [] - for target in multi_targets: - instruction_list.append( - Instruction( - operator=testclass(**create_valid_noise_class_input(irsubclasses, **kwargs)), - target=target, - ) + instruction_list = [ + Instruction( + operator=testclass(**create_valid_noise_class_input(irsubclasses, **kwargs)), + target=target, ) + for target in multi_targets + ] subroutine = getattr(Circuit(), subroutine_name) subroutine_input = {"target": multi_targets} if SingleProbability in irsubclasses: - subroutine_input.update(single_probability_valid_input()) + subroutine_input |= single_probability_valid_input() if SingleProbability_34 in irsubclasses: subroutine_input.update(single_probability_34_valid_input()) if SingleProbability_1516 in irsubclasses: diff --git a/test/unit_tests/braket/circuits/test_observable.py b/test/unit_tests/braket/circuits/test_observable.py index b7e9c201f..38689398a 100644 --- a/test/unit_tests/braket/circuits/test_observable.py +++ b/test/unit_tests/braket/circuits/test_observable.py @@ -130,7 +130,7 @@ def test_eigenvalue_not_implemented_by_default(observable): def test_str(observable): - expected = "{}('qubit_count': {})".format(observable.name, observable.qubit_count) + expected = f"{observable.name}('qubit_count': {observable.qubit_count})" assert str(observable) == expected assert observable.coefficient == 1 diff --git a/test/unit_tests/braket/circuits/test_observables.py b/test/unit_tests/braket/circuits/test_observables.py index 79f917af9..b6430d4d8 100644 --- a/test/unit_tests/braket/circuits/test_observables.py +++ b/test/unit_tests/braket/circuits/test_observables.py @@ -497,7 +497,7 @@ def test_flattened_tensor_product(): def test_hermitian_basis_rotation_gates(matrix, basis_rotation_matrix): expected_unitary = Gate.Unitary(matrix=basis_rotation_matrix) actual_rotation_gates = Observable.Hermitian(matrix=matrix).basis_rotation_gates - assert actual_rotation_gates == tuple([expected_unitary]) + assert actual_rotation_gates == (expected_unitary,) assert expected_unitary.matrix_equivalence(actual_rotation_gates[0]) @@ -596,16 +596,16 @@ def test_tensor_product_eigenvalues(observable, eigenvalues): @pytest.mark.parametrize( "observable,basis_rotation_gates", [ - (Observable.X() @ Observable.Y(), tuple([Gate.H(), Gate.Z(), Gate.S(), Gate.H()])), + (Observable.X() @ Observable.Y(), (Gate.H(), Gate.Z(), Gate.S(), Gate.H())), ( Observable.X() @ Observable.Y() @ Observable.Z(), - tuple([Gate.H(), Gate.Z(), Gate.S(), Gate.H()]), + (Gate.H(), Gate.Z(), Gate.S(), Gate.H()), ), ( Observable.X() @ Observable.Y() @ Observable.I(), - tuple([Gate.H(), Gate.Z(), Gate.S(), Gate.H()]), + (Gate.H(), Gate.Z(), Gate.S(), Gate.H()), ), - (Observable.X() @ Observable.H(), tuple([Gate.H(), Gate.Ry(-np.pi / 4)])), + (Observable.X() @ Observable.H(), (Gate.H(), Gate.Ry(-np.pi / 4))), ], ) def test_tensor_product_basis_rotation_gates(observable, basis_rotation_gates): @@ -642,9 +642,7 @@ def test_sum_not_allowed_in_tensor_product(): @pytest.mark.parametrize( "observable,basis_rotation_gates", - [ - (Observable.X() + Observable.Y(), tuple([Gate.H(), Gate.Z(), Gate.S(), Gate.H()])), - ], + [(Observable.X() + Observable.Y(), (Gate.H(), Gate.Z(), Gate.S(), Gate.H()))], ) def test_no_basis_rotation_support_for_sum(observable, basis_rotation_gates): no_basis_rotation_support_for_sum = "Basis rotation calculation not supported for Sum" diff --git a/test/unit_tests/braket/circuits/test_quantum_operator.py b/test/unit_tests/braket/circuits/test_quantum_operator.py index ed58d5d63..3a26e8c82 100644 --- a/test/unit_tests/braket/circuits/test_quantum_operator.py +++ b/test/unit_tests/braket/circuits/test_quantum_operator.py @@ -138,5 +138,5 @@ def test_matrix_equivalence_non_quantum_operator(): def test_str(quantum_operator): - expected = "{}('qubit_count': {})".format(quantum_operator.name, quantum_operator.qubit_count) + expected = f"{quantum_operator.name}('qubit_count': {quantum_operator.qubit_count})" assert str(quantum_operator) == expected diff --git a/test/unit_tests/braket/circuits/test_unicode_circuit_diagram.py b/test/unit_tests/braket/circuits/test_unicode_circuit_diagram.py new file mode 100644 index 000000000..268c53a96 --- /dev/null +++ b/test/unit_tests/braket/circuits/test_unicode_circuit_diagram.py @@ -0,0 +1,1121 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import numpy as np +import pytest + +from braket.circuits import ( + Circuit, + FreeParameter, + Gate, + Instruction, + Noise, + Observable, + Operator, + UnicodeCircuitDiagram, +) +from braket.pulse import Frame, Port, PulseSequence + + +def _assert_correct_diagram(circ, expected): + assert UnicodeCircuitDiagram.build_diagram(circ) == "\n".join(expected) + + +def test_empty_circuit(): + assert UnicodeCircuitDiagram.build_diagram(Circuit()) == "" + + +def test_only_gphase_circuit(): + assert UnicodeCircuitDiagram.build_diagram(Circuit().gphase(0.1)) == "Global phase: 0.1" + + +def test_one_gate_one_qubit(): + circ = Circuit().h(0) + expected = ( + "T : │ 0 │", + " ┌───┐ ", + "q0 : ─┤ H ├─", + " └───┘ ", + "T : │ 0 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_one_gate_one_qubit_rotation(): + circ = Circuit().rx(angle=3.14, target=0) + # Column formats to length of the gate plus the ascii representation for the angle. + expected = ( + "T : │ 0 │", + " ┌──────────┐ ", + "q0 : ─┤ Rx(3.14) ├─", + " └──────────┘ ", + "T : │ 0 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_one_gate_one_qubit_rotation_with_parameter(): + theta = FreeParameter("theta") + circ = Circuit().rx(angle=theta, target=0) + # Column formats to length of the gate plus the ascii representation for the angle. + expected = ( + "T : │ 0 │", + " ┌───────────┐ ", + "q0 : ─┤ Rx(theta) ├─", + " └───────────┘ ", + "T : │ 0 │", + "", + "Unassigned parameters: [theta].", + ) + _assert_correct_diagram(circ, expected) + + +@pytest.mark.parametrize("target", [0, 1]) +def test_one_gate_with_global_phase(target): + circ = Circuit().x(target=target).gphase(0.15) + expected = ( + "T : │ 0 │ 1 │", + "GP : │ 0 │0.15 │", + " ┌───┐ ", + f"q{target} : ─┤ X ├───────", + " └───┘ ", + "T : │ 0 │ 1 │", + "", + "Global phase: 0.15", + ) + _assert_correct_diagram(circ, expected) + + +def test_one_gate_with_zero_global_phase(): + circ = Circuit().gphase(-0.15).x(target=0).gphase(0.15) + expected = ( + "T : │ 0 │ 1 │", + "GP : │-0.15│0.00 │", + " ┌───┐ ", + "q0 : ─┤ X ├───────", + " └───┘ ", + "T : │ 0 │ 1 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_one_gate_one_qubit_rotation_with_unicode(): + theta = FreeParameter("\u03B8") + circ = Circuit().rx(angle=theta, target=0) + # Column formats to length of the gate plus the ascii representation for the angle. + expected = ( + "T : │ 0 │", + " ┌───────┐ ", + "q0 : ─┤ Rx(θ) ├─", + " └───────┘ ", + "T : │ 0 │", + "", + "Unassigned parameters: [θ].", + ) + _assert_correct_diagram(circ, expected) + + +def test_one_gate_with_parametric_expression_global_phase_(): + theta = FreeParameter("\u03B8") + circ = Circuit().x(target=0).gphase(2 * theta).x(0).gphase(1) + expected = ( + "T : │ 0 │ 1 │ 2 │", + "GP : │ 0 │ 2*θ │2*θ + 1.0│", + " ┌───┐ ┌───┐ ", + "q0 : ─┤ X ├─┤ X ├───────────", + " └───┘ └───┘ ", + "T : │ 0 │ 1 │ 2 │", + "", + "Global phase: 2*θ + 1.0", + "", + "Unassigned parameters: [θ].", + ) + _assert_correct_diagram(circ, expected) + + +def test_one_gate_one_qubit_rotation_with_parameter_assigned(): + theta = FreeParameter("theta") + circ = Circuit().rx(angle=theta, target=0) + new_circ = circ.make_bound_circuit({"theta": np.pi}) + # Column formats to length of the gate plus the ascii representation for the angle. + expected = ( + "T : │ 0 │", + " ┌──────────┐ ", + "q0 : ─┤ Rx(3.14) ├─", + " └──────────┘ ", + "T : │ 0 │", + ) + _assert_correct_diagram(new_circ, expected) + + +def test_qubit_width(): + circ = Circuit().h(0).h(100) + expected = ( + "T : │ 0 │", + " ┌───┐ ", + "q0 : ─┤ H ├─", + " └───┘ ", + " ┌───┐ ", + "q100 : ─┤ H ├─", + " └───┘ ", + "T : │ 0 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_different_size_boxes(): + circ = Circuit().cnot(0, 1).rx(2, 0.3) + expected = ( + "T : │ 0 │", + " ", + "q0 : ──────●───────", + " │ ", + " ┌─┴─┐ ", + "q1 : ────┤ X ├─────", + " └───┘ ", + " ┌──────────┐ ", + "q2 : ─┤ Rx(0.30) ├─", + " └──────────┘ ", + "T : │ 0 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_swap(): + circ = Circuit().swap(0, 2).x(1) + expected = ( + "T : │ 0 │", + " ", + "q0 : ────x───────────", + " │ ", + " │ ┌───┐ ", + "q1 : ────┼─────┤ X ├─", + " │ └───┘ ", + " │ ", + "q2 : ────x───────────", + " ", + "T : │ 0 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_gate_width(): + class Foo(Gate): + def __init__(self): + super().__init__(qubit_count=1, ascii_symbols=["FOO"]) + + def to_ir(self, target): + return "foo" + + circ = Circuit().h(0).h(1).add_instruction(Instruction(Foo(), 0)) + expected = ( + "T : │ 0 │ 1 │", + " ┌───┐ ┌─────┐ ", + "q0 : ─┤ H ├─┤ FOO ├─", + " └───┘ └─────┘ ", + " ┌───┐ ", + "q1 : ─┤ H ├─────────", + " └───┘ ", + "T : │ 0 │ 1 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_time_width(): + circ = Circuit() + num_qubits = 8 + for qubit in range(num_qubits): + if qubit == num_qubits - 1: + break + circ.cnot(qubit, qubit + 1) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │", + " ", + "q0 : ───●───────────────────────────────────────", + " │ ", + " ┌─┴─┐ ", + "q1 : ─┤ X ├───●─────────────────────────────────", + " └───┘ │ ", + " ┌─┴─┐ ", + "q2 : ───────┤ X ├───●───────────────────────────", + " └───┘ │ ", + " ┌─┴─┐ ", + "q3 : ─────────────┤ X ├───●─────────────────────", + " └───┘ │ ", + " ┌─┴─┐ ", + "q4 : ───────────────────┤ X ├───●───────────────", + " └───┘ │ ", + " ┌─┴─┐ ", + "q5 : ─────────────────────────┤ X ├───●─────────", + " └───┘ │ ", + " ┌─┴─┐ ", + "q6 : ───────────────────────────────┤ X ├───●───", + " └───┘ │ ", + " ┌─┴─┐ ", + "q7 : ─────────────────────────────────────┤ X ├─", + " └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_connector_across_two_qubits(): + circ = Circuit().cnot(4, 3).h(range(2, 6)) + expected = ( + "T : │ 0 │ 1 │", + " ┌───┐ ", + "q2 : ─┤ H ├───────", + " └───┘ ", + " ┌───┐ ┌───┐ ", + "q3 : ─┤ X ├─┤ H ├─", + " └─┬─┘ └───┘ ", + " │ ┌───┐ ", + "q4 : ───●───┤ H ├─", + " └───┘ ", + " ┌───┐ ", + "q5 : ─┤ H ├───────", + " └───┘ ", + "T : │ 0 │ 1 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_neg_control_qubits(): + circ = Circuit().x(1, control=[0, 2], control_state=[0, 1]) + expected = ( + "T : │ 0 │", + " ", + "q0 : ───◯───", + " │ ", + " ┌─┴─┐ ", + "q1 : ─┤ X ├─", + " └─┬─┘ ", + " │ ", + "q2 : ───●───", + " ", + "T : │ 0 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_only_neg_control_qubits(): + circ = Circuit().x(2, control=[0, 1], control_state=0) + expected = ( + "T : │ 0 │", + " ", + "q0 : ───◯───", + " │ ", + " │ ", + "q1 : ───◯───", + " │ ", + " ┌─┴─┐ ", + "q2 : ─┤ X ├─", + " └───┘ ", + "T : │ 0 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_connector_across_three_qubits(): + circ = Circuit().x(control=(3, 4), target=5).h(range(2, 6)) + expected = ( + "T : │ 0 │ 1 │", + " ┌───┐ ", + "q2 : ─┤ H ├───────", + " └───┘ ", + " ┌───┐ ", + "q3 : ───●───┤ H ├─", + " │ └───┘ ", + " │ ┌───┐ ", + "q4 : ───●───┤ H ├─", + " │ └───┘ ", + " ┌─┴─┐ ┌───┐ ", + "q5 : ─┤ X ├─┤ H ├─", + " └───┘ └───┘ ", + "T : │ 0 │ 1 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_overlapping_qubits(): + circ = Circuit().cnot(0, 2).x(control=1, target=3).h(0) + expected = ( + "T : │ 0 │ 1 │", + " ┌───┐ ", + "q0 : ───●─────────┤ H ├─", + " │ └───┘ ", + " │ ", + "q1 : ───┼─────●─────────", + " │ │ ", + " ┌─┴─┐ │ ", + "q2 : ─┤ X ├───┼─────────", + " └───┘ │ ", + " ┌─┴─┐ ", + "q3 : ───────┤ X ├───────", + " └───┘ ", + "T : │ 0 │ 1 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_overlapping_qubits_angled_gates(): + circ = Circuit().zz(0, 2, 0.15).x(control=1, target=3).h(0) + expected = ( + "T : │ 0 │ 1 │", + " ┌──────────┐ ┌───┐ ", + "q0 : ─┤ ZZ(0.15) ├───────┤ H ├─", + " └────┬─────┘ └───┘ ", + " │ ", + "q1 : ──────┼─────────●─────────", + " │ │ ", + " ┌────┴─────┐ │ ", + "q2 : ─┤ ZZ(0.15) ├───┼─────────", + " └──────────┘ │ ", + " ┌─┴─┐ ", + "q3 : ──────────────┤ X ├───────", + " └───┘ ", + "T : │ 0 │ 1 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_connector_across_gt_two_qubits(): + circ = Circuit().h(4).x(control=3, target=5).h(4).h(2) + expected = ( + "T : │ 0 │ 1 │", + " ┌───┐ ", + "q2 : ─┤ H ├─────────────", + " └───┘ ", + " ", + "q3 : ─────────●─────────", + " │ ", + " ┌───┐ │ ┌───┐ ", + "q4 : ─┤ H ├───┼───┤ H ├─", + " └───┘ │ └───┘ ", + " ┌─┴─┐ ", + "q5 : ───────┤ X ├───────", + " └───┘ ", + "T : │ 0 │ 1 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_connector_across_non_used_qubits(): + circ = Circuit().h(4).cnot(3, 100).h(4).h(101) + expected = ( + "T : │ 0 │ 1 │", + " ", + "q3 : ─────────●─────────", + " │ ", + " ┌───┐ │ ┌───┐ ", + "q4 : ─┤ H ├───┼───┤ H ├─", + " └───┘ │ └───┘ ", + " ┌─┴─┐ ", + "q100 : ───────┤ X ├───────", + " └───┘ ", + " ┌───┐ ", + "q101 : ─┤ H ├─────────────", + " └───┘ ", + "T : │ 0 │ 1 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_verbatim_1q_no_preceding(): + circ = Circuit().add_verbatim_box(Circuit().h(0)) + expected = ( + "T : │ 0 │ 1 │ 2 │", + " ┌───┐ ", + "q0 : ───StartVerbatim───┤ H ├───EndVerbatim───", + " └───┘ ", + "T : │ 0 │ 1 │ 2 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_verbatim_1q_preceding(): + circ = Circuit().h(0).add_verbatim_box(Circuit().h(0)) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │", + " ┌───┐ ┌───┐ ", + "q0 : ─┤ H ├───StartVerbatim───┤ H ├───EndVerbatim───", + " └───┘ └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_verbatim_1q_following(): + circ = Circuit().add_verbatim_box(Circuit().h(0)).h(0) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │", + " ┌───┐ ┌───┐ ", + "q0 : ───StartVerbatim───┤ H ├───EndVerbatim───┤ H ├─", + " └───┘ └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_verbatim_2q_no_preceding(): + circ = Circuit().add_verbatim_box(Circuit().h(0).cnot(0, 1)) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │", + " ┌───┐ ", + "q0 : ───StartVerbatim───┤ H ├───●─────EndVerbatim───", + " ║ └───┘ │ ║ ", + " ║ ┌─┴─┐ ║ ", + "q1 : ─────────╨───────────────┤ X ├────────╨────────", + " └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_verbatim_2q_preceding(): + circ = Circuit().h(0).add_verbatim_box(Circuit().h(0).cnot(0, 1)) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + " ┌───┐ ┌───┐ ", + "q0 : ─┤ H ├───StartVerbatim───┤ H ├───●─────EndVerbatim───", + " └───┘ ║ └───┘ │ ║ ", + " ║ ┌─┴─┐ ║ ", + "q1 : ───────────────╨───────────────┤ X ├────────╨────────", + " └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_verbatim_2q_following(): + circ = Circuit().add_verbatim_box(Circuit().h(0).cnot(0, 1)).h(0) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + " ┌───┐ ┌───┐ ", + "q0 : ───StartVerbatim───┤ H ├───●─────EndVerbatim───┤ H ├─", + " ║ └───┘ │ ║ └───┘ ", + " ║ ┌─┴─┐ ║ ", + "q1 : ─────────╨───────────────┤ X ├────────╨──────────────", + " └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_verbatim_3q_no_preceding(): + circ = Circuit().add_verbatim_box(Circuit().h(0).cnot(0, 1).cnot(1, 2)) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + " ┌───┐ ", + "q0 : ───StartVerbatim───┤ H ├───●───────────EndVerbatim───", + " ║ └───┘ │ ║ ", + " ║ ┌─┴─┐ ║ ", + "q1 : ─────────║───────────────┤ X ├───●──────────║────────", + " ║ └───┘ │ ║ ", + " ║ ┌─┴─┐ ║ ", + "q2 : ─────────╨─────────────────────┤ X ├────────╨────────", + " └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_verbatim_3q_preceding(): + circ = Circuit().h(0).add_verbatim_box(Circuit().h(0).cnot(0, 1).cnot(1, 2)) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │", + " ┌───┐ ┌───┐ ", + "q0 : ─┤ H ├───StartVerbatim───┤ H ├───●───────────EndVerbatim───", + " └───┘ ║ └───┘ │ ║ ", + " ║ ┌─┴─┐ ║ ", + "q1 : ───────────────║───────────────┤ X ├───●──────────║────────", + " ║ └───┘ │ ║ ", + " ║ ┌─┴─┐ ║ ", + "q2 : ───────────────╨─────────────────────┤ X ├────────╨────────", + " └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_verbatim_3q_following(): + circ = Circuit().add_verbatim_box(Circuit().h(0).cnot(0, 1).cnot(1, 2)).h(0) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │", + " ┌───┐ ┌───┐ ", + "q0 : ───StartVerbatim───┤ H ├───●───────────EndVerbatim───┤ H ├─", + " ║ └───┘ │ ║ └───┘ ", + " ║ ┌─┴─┐ ║ ", + "q1 : ─────────║───────────────┤ X ├───●──────────║──────────────", + " ║ └───┘ │ ║ ", + " ║ ┌─┴─┐ ║ ", + "q2 : ─────────╨─────────────────────┤ X ├────────╨──────────────", + " └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_verbatim_different_qubits(): + circ = Circuit().h(1).add_verbatim_box(Circuit().h(0)).cnot(3, 4) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + " ┌───┐ ", + "q0 : ─────────StartVerbatim───┤ H ├───EndVerbatim─────────", + " ║ └───┘ ║ ", + " ┌───┐ ║ ║ ", + "q1 : ─┤ H ├─────────║──────────────────────║──────────────", + " └───┘ ║ ║ ", + " ║ ║ ", + "q3 : ───────────────║──────────────────────║──────────●───", + " ║ ║ │ ", + " ║ ║ ┌─┴─┐ ", + "q4 : ───────────────╨──────────────────────╨────────┤ X ├─", + " └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_verbatim_qubset_qubits(): + circ = Circuit().h(1).cnot(0, 1).cnot(1, 2).add_verbatim_box(Circuit().h(1)).cnot(2, 3) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │", + " ", + "q0 : ─────────●───────────StartVerbatim───────────EndVerbatim─────────", + " │ ║ ║ ", + " ┌───┐ ┌─┴─┐ ║ ┌───┐ ║ ", + "q1 : ─┤ H ├─┤ X ├───●───────────║─────────┤ H ├────────║──────────────", + " └───┘ └───┘ │ ║ └───┘ ║ ", + " ┌─┴─┐ ║ ║ ", + "q2 : ─────────────┤ X ├─────────║──────────────────────║──────────●───", + " └───┘ ║ ║ │ ", + " ║ ║ ┌─┴─┐ ", + "q3 : ───────────────────────────╨──────────────────────╨────────┤ X ├─", + " └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_ignore_non_gates(): + class Foo(Operator): + @property + def name(self) -> str: + return "foo" + + def to_ir(self, target): + return "foo" + + circ = Circuit().h(0).h(1).cnot(1, 2).add_instruction(Instruction(Foo(), 0)) + expected = ( + "T : │ 0 │ 1 │", + " ┌───┐ ", + "q0 : ─┤ H ├───────", + " └───┘ ", + " ┌───┐ ", + "q1 : ─┤ H ├───●───", + " └───┘ │ ", + " ┌─┴─┐ ", + "q2 : ───────┤ X ├─", + " └───┘ ", + "T : │ 0 │ 1 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_single_qubit_result_types_target_none(): + circ = Circuit().h(0).probability() + expected = ( + "T : │ 0 │ Result Types │", + " ┌───┐ ┌─────────────┐ ", + "q0 : ─┤ H ├─┤ Probability ├─", + " └───┘ └─────────────┘ ", + "T : │ 0 │ Result Types │", + ) + _assert_correct_diagram(circ, expected) + + +def test_result_types_target_none(): + circ = Circuit().h(0).h(100).probability() + expected = ( + "T : │ 0 │ Result Types │", + " ┌───┐ ┌─────────────┐ ", + "q0 : ─┤ H ├─┤ Probability ├─", + " └───┘ └──────┬──────┘ ", + " ┌───┐ ┌──────┴──────┐ ", + "q100 : ─┤ H ├─┤ Probability ├─", + " └───┘ └─────────────┘ ", + "T : │ 0 │ Result Types │", + ) + _assert_correct_diagram(circ, expected) + + +def test_result_types_target_some(): + circ = ( + Circuit() + .h(0) + .h(1) + .h(100) + .expectation(observable=Observable.Y() @ Observable.Z(), target=[0, 100]) + ) + expected = ( + "T : │ 0 │ Result Types │", + " ┌───┐ ┌──────────────────┐ ", + "q0 : ─┤ H ├─┤ Expectation(Y@Z) ├─", + " └───┘ └────────┬─────────┘ ", + " ┌───┐ │ ", + "q1 : ─┤ H ├──────────┼───────────", + " └───┘ │ ", + " ┌───┐ ┌────────┴─────────┐ ", + "q100 : ─┤ H ├─┤ Expectation(Y@Z) ├─", + " └───┘ └──────────────────┘ ", + "T : │ 0 │ Result Types │", + ) + _assert_correct_diagram(circ, expected) + + +def test_additional_result_types(): + circ = Circuit().h(0).h(1).h(100).state_vector().amplitude(["110", "001"]) + expected = ( + "T : │ 0 │", + " ┌───┐ ", + "q0 : ─┤ H ├─", + " └───┘ ", + " ┌───┐ ", + "q1 : ─┤ H ├─", + " └───┘ ", + " ┌───┐ ", + "q100 : ─┤ H ├─", + " └───┘ ", + "T : │ 0 │", + "", + "Additional result types: StateVector, Amplitude(110,001)", + ) + _assert_correct_diagram(circ, expected) + + +def test_multiple_result_types(): + circ = ( + Circuit() + .cnot(0, 2) + .cnot(1, 3) + .h(0) + .variance(observable=Observable.Y(), target=0) + .expectation(observable=Observable.Y(), target=2) + .sample(observable=Observable.Y()) + ) + expected = ( + "T : │ 0 │ 1 │ Result Types │", + " ┌───┐ ┌─────────────┐ ┌───────────┐ ", + "q0 : ───●─────────┤ H ├──┤ Variance(Y) ├───┤ Sample(Y) ├─", + " │ └───┘ └─────────────┘ └─────┬─────┘ ", + " │ ┌─────┴─────┐ ", + "q1 : ───┼─────●────────────────────────────┤ Sample(Y) ├─", + " │ │ └─────┬─────┘ ", + " ┌─┴─┐ │ ┌────────────────┐ ┌─────┴─────┐ ", + "q2 : ─┤ X ├───┼─────────┤ Expectation(Y) ├─┤ Sample(Y) ├─", + " └───┘ │ └────────────────┘ └─────┬─────┘ ", + " ┌─┴─┐ ┌─────┴─────┐ ", + "q3 : ───────┤ X ├──────────────────────────┤ Sample(Y) ├─", + " └───┘ └───────────┘ ", + "T : │ 0 │ 1 │ Result Types │", + ) + _assert_correct_diagram(circ, expected) + + +def test_multiple_result_types_with_state_vector_amplitude(): + circ = ( + Circuit() + .cnot(0, 2) + .cnot(1, 3) + .h(0) + .variance(observable=Observable.Y(), target=0) + .expectation(observable=Observable.Y(), target=3) + .expectation(observable=Observable.Hermitian(np.array([[1.0, 0.0], [0.0, 1.0]])), target=1) + .amplitude(["0001"]) + .state_vector() + ) + expected = ( + "T : │ 0 │ 1 │ Result Types │", + " ┌───┐ ┌─────────────┐ ", + "q0 : ───●─────────┤ H ├──────┤ Variance(Y) ├───────", + " │ └───┘ └─────────────┘ ", + " │ ┌────────────────────────┐ ", + "q1 : ───┼─────●─────────┤ Expectation(Hermitian) ├─", + " │ │ └────────────────────────┘ ", + " ┌─┴─┐ │ ", + "q2 : ─┤ X ├───┼────────────────────────────────────", + " └───┘ │ ", + " ┌─┴─┐ ┌────────────────┐ ", + "q3 : ───────┤ X ├───────────┤ Expectation(Y) ├─────", + " └───┘ └────────────────┘ ", + "T : │ 0 │ 1 │ Result Types │", + "", + "Additional result types: Amplitude(0001), StateVector", + ) + _assert_correct_diagram(circ, expected) + + +def test_multiple_result_types_with_custom_hermitian_ascii_symbol(): + herm_matrix = (Observable.Y() @ Observable.Z()).to_matrix() + circ = ( + Circuit() + .cnot(0, 2) + .cnot(1, 3) + .h(0) + .variance(observable=Observable.Y(), target=0) + .expectation(observable=Observable.Y(), target=3) + .expectation( + observable=Observable.Hermitian( + matrix=herm_matrix, + display_name="MyHerm", + ), + target=[1, 2], + ) + ) + expected = ( + "T : │ 0 │ 1 │ Result Types │", + " ┌───┐ ┌─────────────┐ ", + "q0 : ───●─────────┤ H ├─────┤ Variance(Y) ├─────", + " │ └───┘ └─────────────┘ ", + " │ ┌─────────────────────┐ ", + "q1 : ───┼─────●─────────┤ Expectation(MyHerm) ├─", + " │ │ └──────────┬──────────┘ ", + " ┌─┴─┐ │ ┌──────────┴──────────┐ ", + "q2 : ─┤ X ├───┼─────────┤ Expectation(MyHerm) ├─", + " └───┘ │ └─────────────────────┘ ", + " ┌─┴─┐ ┌────────────────┐ ", + "q3 : ───────┤ X ├─────────┤ Expectation(Y) ├────", + " └───┘ └────────────────┘ ", + "T : │ 0 │ 1 │ Result Types │", + ) + _assert_correct_diagram(circ, expected) + + +def test_noise_1qubit(): + circ = Circuit().h(0).x(1).bit_flip(1, 0.1) + expected = ( + "T : │ 0 │", + " ┌───┐ ", + "q0 : ─┤ H ├─────────────", + " └───┘ ", + " ┌───┐ ┌─────────┐ ", + "q1 : ─┤ X ├─┤ BF(0.1) ├─", + " └───┘ └─────────┘ ", + "T : │ 0 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_noise_2qubit(): + circ = Circuit().h(1).kraus((0, 2), [np.eye(4)]) + expected = ( + "T : │ 0 │", + " ┌────┐ ", + "q0 : ───────┤ KR ├─", + " └─┬──┘ ", + " ┌───┐ │ ", + "q1 : ─┤ H ├───┼────", + " └───┘ │ ", + " ┌─┴──┐ ", + "q2 : ───────┤ KR ├─", + " └────┘ ", + "T : │ 0 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_noise_multi_probabilities(): + circ = Circuit().h(0).x(1).pauli_channel(1, 0.1, 0.2, 0.3) + expected = ( + "T : │ 0 │", + " ┌───┐ ", + "q0 : ─┤ H ├─────────────────────", + " └───┘ ", + " ┌───┐ ┌─────────────────┐ ", + "q1 : ─┤ X ├─┤ PC(0.1,0.2,0.3) ├─", + " └───┘ └─────────────────┘ ", + "T : │ 0 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_noise_multi_probabilities_with_parameter(): + a = FreeParameter("a") + b = FreeParameter("b") + c = FreeParameter("c") + circ = Circuit().h(0).x(1).pauli_channel(1, a, b, c) + expected = ( + "T : │ 0 │", + " ┌───┐ ", + "q0 : ─┤ H ├───────────────", + " └───┘ ", + " ┌───┐ ┌───────────┐ ", + "q1 : ─┤ X ├─┤ PC(a,b,c) ├─", + " └───┘ └───────────┘ ", + "T : │ 0 │", + "", + "Unassigned parameters: [a, b, c].", + ) + _assert_correct_diagram(circ, expected) + + +def test_pulse_gate_1_qubit_circuit(): + circ = ( + Circuit() + .h(0) + .pulse_gate(0, PulseSequence().set_phase(Frame("x", Port("px", 1e-9), 1e9, 0), 0)) + ) + expected = ( + "T : │ 0 │ 1 │", + " ┌───┐ ┌────┐ ", + "q0 : ─┤ H ├─┤ PG ├─", + " └───┘ └────┘ ", + "T : │ 0 │ 1 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_pulse_gate_multi_qubit_circuit(): + circ = ( + Circuit() + .h(0) + .pulse_gate([0, 1], PulseSequence().set_phase(Frame("x", Port("px", 1e-9), 1e9, 0), 0)) + ) + expected = ( + "T : │ 0 │ 1 │", + " ┌───┐ ┌────┐ ", + "q0 : ─┤ H ├─┤ PG ├─", + " └───┘ └─┬──┘ ", + " ┌─┴──┐ ", + "q1 : ───────┤ PG ├─", + " └────┘ ", + "T : │ 0 │ 1 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_circuit_with_nested_target_list(): + circ = ( + Circuit() + .h(0) + .h(1) + .expectation( + observable=(2 * Observable.Y()) @ (-3 * Observable.I()) + - 0.75 * Observable.Y() @ Observable.Z(), + target=[[0, 1], [0, 1]], + ) + ) + + expected = ( + "T : │ 0 │ Result Types │", + " ┌───┐ ┌──────────────────────────┐ ", + "q0 : ─┤ H ├─┤ Expectation(Hamiltonian) ├─", + " └───┘ └────────────┬─────────────┘ ", + " ┌───┐ ┌────────────┴─────────────┐ ", + "q1 : ─┤ H ├─┤ Expectation(Hamiltonian) ├─", + " └───┘ └──────────────────────────┘ ", + "T : │ 0 │ Result Types │", + ) + _assert_correct_diagram(circ, expected) + + +def test_hamiltonian(): + circ = ( + Circuit() + .h(0) + .cnot(0, 1) + .rx(0, FreeParameter("theta")) + .adjoint_gradient( + 4 * (2e-5 * Observable.Z() + 2 * (3 * Observable.X() @ (2 * Observable.Y()))), + [[0], [1, 2]], + ) + ) + expected = ( + "T : │ 0 │ 1 │ 2 │ Result Types │", + " ┌───┐ ┌───────────┐ ┌──────────────────────────────┐ ", + "q0 : ─┤ H ├───●───┤ Rx(theta) ├─┤ AdjointGradient(Hamiltonian) ├─", + " └───┘ │ └───────────┘ └──────────────┬───────────────┘ ", + " ┌─┴─┐ ┌──────────────┴───────────────┐ ", + "q1 : ───────┤ X ├───────────────┤ AdjointGradient(Hamiltonian) ├─", + " └───┘ └──────────────┬───────────────┘ ", + " ┌──────────────┴───────────────┐ ", + "q2 : ───────────────────────────┤ AdjointGradient(Hamiltonian) ├─", + " └──────────────────────────────┘ ", + "T : │ 0 │ 1 │ 2 │ Result Types │", + "", + "Unassigned parameters: [theta].", + ) + _assert_correct_diagram(circ, expected) + + +def test_power(): + class Foo(Gate): + def __init__(self): + super().__init__(qubit_count=1, ascii_symbols=["FOO"]) + + class CFoo(Gate): + def __init__(self): + super().__init__(qubit_count=2, ascii_symbols=["C", "FOO"]) + + class FooFoo(Gate): + def __init__(self): + super().__init__(qubit_count=2, ascii_symbols=["FOO", "FOO"]) + + circ = Circuit().h(0, power=1).h(1, power=0).h(2, power=-3.14) + circ.add_instruction(Instruction(Foo(), 0, power=-1)) + circ.add_instruction(Instruction(CFoo(), (0, 1), power=2)) + circ.add_instruction(Instruction(CFoo(), (1, 2), control=0, power=3)) + circ.add_instruction(Instruction(FooFoo(), (1, 3), control=[0, 2], power=4)) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + " ┌───┐ ┌────────┐ ", + "q0 : ────┤ H ├────┤ FOO^-1 ├─────●─────────●─────────●─────", + " └───┘ └────────┘ │ │ │ ", + " ┌─────┐ ┌───┴───┐ │ ┌───┴───┐ ", + "q1 : ───┤ H^0 ├──────────────┤ FOO^2 ├─────●─────┤ FOO^4 ├─", + " └─────┘ └───────┘ │ └───┬───┘ ", + " ┌─────────┐ ┌───┴───┐ │ ", + "q2 : ─┤ H^-3.14 ├──────────────────────┤ FOO^3 ├─────●─────", + " └─────────┘ └───────┘ │ ", + " ┌───┴───┐ ", + "q3 : ────────────────────────────────────────────┤ FOO^4 ├─", + " └───────┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_unbalanced_ascii_symbols(): + class FooFoo(Gate): + def __init__(self): + super().__init__(qubit_count=2, ascii_symbols=["FOOO", "FOO"]) + + circ = Circuit().add_instruction(Instruction(FooFoo(), (1, 3), control=[0, 2], power=4)) + expected = ( + "T : │ 0 │", + " ", + "q0 : ─────●──────", + " │ ", + " ┌───┴────┐ ", + "q1 : ─┤ FOOO^4 ├─", + " └───┬────┘ ", + " │ ", + "q2 : ─────●──────", + " │ ", + " ┌───┴───┐ ", + "q3 : ─┤ FOO^4 ├──", + " └───────┘ ", + "T : │ 0 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_measure(): + circ = Circuit().h(0).cnot(0, 1).cnot(1, 2).cnot(2, 3).measure([0, 2, 3]) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + " ┌───┐ ┌───┐ ", + "q0 : ─┤ H ├───●───────────────┤ M ├─", + " └───┘ │ └───┘ ", + " ┌─┴─┐ ", + "q1 : ───────┤ X ├───●───────────────", + " └───┘ │ ", + " ┌─┴─┐ ┌───┐ ", + "q2 : ─────────────┤ X ├───●───┤ M ├─", + " └───┘ │ └───┘ ", + " ┌─┴─┐ ┌───┐ ", + "q3 : ───────────────────┤ X ├─┤ M ├─", + " └───┘ └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_measure_with_multiple_measures(): + circ = Circuit().h(0).cnot(0, 1).cnot(1, 2).cnot(2, 3).measure([0, 2]).measure(3).measure(1) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + " ┌───┐ ┌───┐ ", + "q0 : ─┤ H ├───●───────────────┤ M ├─", + " └───┘ │ └───┘ ", + " ┌─┴─┐ ┌───┐ ", + "q1 : ───────┤ X ├───●─────────┤ M ├─", + " └───┘ │ └───┘ ", + " ┌─┴─┐ ┌───┐ ", + "q2 : ─────────────┤ X ├───●───┤ M ├─", + " └───┘ │ └───┘ ", + " ┌─┴─┐ ┌───┐ ", + "q3 : ───────────────────┤ X ├─┤ M ├─", + " └───┘ └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │", + ) + _assert_correct_diagram(circ, expected) + _assert_correct_diagram(circ, expected) + + +def test_measure_multiple_instructions_after(): + circ = ( + Circuit() + .h(0) + .cnot(0, 1) + .cnot(1, 2) + .cnot(2, 3) + .measure(0) + .measure(1) + .h(3) + .cnot(3, 4) + .measure([2, 3]) + ) + expected = ( + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │", + " ┌───┐ ┌───┐ ", + "q0 : ─┤ H ├───●───────────────┤ M ├─────────────", + " └───┘ │ └───┘ ", + " ┌─┴─┐ ┌───┐ ", + "q1 : ───────┤ X ├───●─────────┤ M ├─────────────", + " └───┘ │ └───┘ ", + " ┌─┴─┐ ┌───┐ ", + "q2 : ─────────────┤ X ├───●───────────────┤ M ├─", + " └───┘ │ └───┘ ", + " ┌─┴─┐ ┌───┐ ┌───┐ ", + "q3 : ───────────────────┤ X ├─┤ H ├───●───┤ M ├─", + " └───┘ └───┘ │ └───┘ ", + " ┌─┴─┐ ", + "q4 : ───────────────────────────────┤ X ├───────", + " └───┘ ", + "T : │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_measure_with_readout_noise(): + circ = ( + Circuit() + .h(0) + .cnot(0, 1) + .apply_readout_noise(Noise.BitFlip(probability=0.1), target_qubits=1) + .measure([0, 1]) + ) + expected = ( + "T : │ 0 │ 1 │ 2 │", + " ┌───┐ ┌───┐ ", + "q0 : ─┤ H ├───●───────────────┤ M ├─", + " └───┘ │ └───┘ ", + " ┌─┴─┐ ┌─────────┐ ┌───┐ ", + "q1 : ───────┤ X ├─┤ BF(0.1) ├─┤ M ├─", + " └───┘ └─────────┘ └───┘ ", + "T : │ 0 │ 1 │ 2 │", + ) + _assert_correct_diagram(circ, expected) diff --git a/test/unit_tests/braket/devices/test_local_simulator.py b/test/unit_tests/braket/devices/test_local_simulator.py index 8485dc5e5..216d161c7 100644 --- a/test/unit_tests/braket/devices/test_local_simulator.py +++ b/test/unit_tests/braket/devices/test_local_simulator.py @@ -11,19 +11,25 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from typing import Any, Dict, Optional -from unittest.mock import Mock +import json +import textwrap +import warnings +from typing import Any, Optional +from unittest.mock import Mock, patch import pytest -from pydantic import create_model # This is temporary for defining properties below +from pydantic.v1 import create_model # This is temporary for defining properties below import braket.ir as ir from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation from braket.ahs.atom_arrangement import AtomArrangement from braket.ahs.hamiltonian import Hamiltonian from braket.annealing import Problem, ProblemType -from braket.circuits import Circuit, FreeParameter -from braket.device_schema import DeviceCapabilities +from braket.circuits import Circuit, FreeParameter, Gate, Noise +from braket.circuits.noise_model import GateCriteria, NoiseModel, NoiseModelInstruction +from braket.circuits.serialization import IRType, SerializableProgram +from braket.device_schema import DeviceActionType, DeviceCapabilities +from braket.device_schema.openqasm_device_action_properties import OpenQASMDeviceActionProperties from braket.devices import LocalSimulator, local_simulator from braket.ir.openqasm import Program from braket.simulator import BraketSimulator @@ -111,10 +117,10 @@ def run( program: ir.jaqcd.Program, qubits: int, shots: Optional[int], - inputs: Optional[Dict[str, float]], + inputs: Optional[dict[str, float]], *args, - **kwargs - ) -> Dict[str, Any]: + **kwargs, + ) -> dict[str, Any]: self._shots = shots self._qubits = qubits return GATE_MODEL_RESULT @@ -151,7 +157,7 @@ def properties(self) -> DeviceCapabilities: class DummyJaqcdSimulator(BraketSimulator): def run( self, program: ir.jaqcd.Program, qubits: int, shots: Optional[int], *args, **kwargs - ) -> Dict[str, Any]: + ) -> dict[str, Any]: if not isinstance(program, ir.jaqcd.Program): raise TypeError("Not a Jaqcd program") self._shots = shots @@ -200,7 +206,7 @@ def run( @property def properties(self) -> DeviceCapabilities: - return DeviceCapabilities.parse_obj( + device_properties = DeviceCapabilities.parse_obj( { "service": { "executionWindows": [ @@ -221,6 +227,105 @@ def properties(self) -> DeviceCapabilities: "deviceParameters": {}, } ) + oq3_action = OpenQASMDeviceActionProperties.parse_raw( + json.dumps( + { + "actionType": "braket.ir.openqasm.program", + "version": ["1"], + "supportedOperations": ["rx", "ry", "h", "cy", "cnot", "unitary"], + "supportedResultTypes": [ + {"name": "StateVector", "observables": None, "minShots": 0, "maxShots": 0}, + ], + "supportedPragmas": [ + "braket_unitary_matrix", + "braket_result_type_sample", + "braket_result_type_expectation", + "braket_result_type_variance", + "braket_result_type_probability", + "braket_result_type_state_vector", + ], + } + ) + ) + device_properties.action[DeviceActionType.OPENQASM] = oq3_action + return device_properties + + +class DummySerializableProgram(SerializableProgram): + def __init__(self, source: str): + self.source = source + + def to_ir(self, ir_type: IRType = IRType.OPENQASM) -> str: + return self.source + + +class DummySerializableProgramSimulator(DummyProgramSimulator): + def run( + self, + program: SerializableProgram, + shots: int = 0, + batch_size: int = 1, + ) -> GateModelQuantumTaskResult: + return GateModelQuantumTaskResult.from_object(GATE_MODEL_RESULT) + + +class DummyProgramDensityMatrixSimulator(BraketSimulator): + def run( + self, program: ir.openqasm.Program, shots: Optional[int], *args, **kwargs + ) -> dict[str, Any]: + self._shots = shots + return GATE_MODEL_RESULT + + @property + def properties(self) -> DeviceCapabilities: + device_properties = DeviceCapabilities.parse_obj( + { + "service": { + "executionWindows": [ + { + "executionDay": "Everyday", + "windowStartHour": "11:00", + "windowEndHour": "12:00", + } + ], + "shotsRange": [1, 10], + }, + "action": {}, + "deviceParameters": {}, + } + ) + oq3_action = OpenQASMDeviceActionProperties.parse_raw( + json.dumps( + { + "actionType": "braket.ir.openqasm.program", + "version": ["1"], + "supportedOperations": ["rx", "ry", "h", "cy", "cnot", "unitary"], + "supportedResultTypes": [ + {"name": "StateVector", "observables": None, "minShots": 0, "maxShots": 0}, + ], + "supportedPragmas": [ + "braket_noise_bit_flip", + "braket_noise_depolarizing", + "braket_noise_kraus", + "braket_noise_pauli_channel", + "braket_noise_generalized_amplitude_damping", + "braket_noise_amplitude_damping", + "braket_noise_phase_flip", + "braket_noise_phase_damping", + "braket_noise_two_qubit_dephasing", + "braket_noise_two_qubit_depolarizing", + "braket_unitary_matrix", + "braket_result_type_sample", + "braket_result_type_expectation", + "braket_result_type_variance", + "braket_result_type_probability", + "braket_result_type_density_matrix", + ], + } + ) + ) + device_properties.action[DeviceActionType.OPENQASM] = oq3_action + return device_properties class DummyAnnealingSimulator(BraketSimulator): @@ -284,13 +389,16 @@ def properties(self) -> DeviceCapabilities: mock_circuit_entry = Mock() mock_program_entry = Mock() mock_jaqcd_entry = Mock() +mock_circuit_dm_entry = Mock() mock_circuit_entry.load.return_value = DummyCircuitSimulator mock_program_entry.load.return_value = DummyProgramSimulator mock_jaqcd_entry.load.return_value = DummyJaqcdSimulator +mock_circuit_dm_entry.load.return_value = DummyProgramDensityMatrixSimulator local_simulator._simulator_devices = { "dummy": mock_circuit_entry, "dummy_oq3": mock_program_entry, "dummy_jaqcd": mock_jaqcd_entry, + "dummy_oq3_dm": mock_circuit_dm_entry, } mock_ahs_program = AnalogHamiltonianSimulation( @@ -467,6 +575,25 @@ def test_run_program_model(): assert task.result() == GateModelQuantumTaskResult.from_object(GATE_MODEL_RESULT) +def test_run_serializable_program_model(): + dummy = DummySerializableProgramSimulator() + sim = LocalSimulator(dummy) + task = sim.run( + DummySerializableProgram( + source=""" +qubit[2] q; +bit[2] c; + +h q[0]; +cnot q[0], q[1]; + +c = measure q; +""" + ) + ) + assert task.result() == GateModelQuantumTaskResult.from_object(GATE_MODEL_RESULT) + + @pytest.mark.xfail(raises=ValueError) def test_run_gate_model_value_error(): dummy = DummyCircuitSimulator() @@ -490,7 +617,12 @@ def test_run_ahs(): def test_registered_backends(): - assert LocalSimulator.registered_backends() == {"dummy", "dummy_oq3", "dummy_jaqcd"} + assert LocalSimulator.registered_backends() == { + "dummy", + "dummy_oq3", + "dummy_jaqcd", + "dummy_oq3_dm", + } @pytest.mark.xfail(raises=TypeError) @@ -526,3 +658,129 @@ def test_properties(): sim = LocalSimulator(dummy) expected_properties = dummy.properties assert sim.properties == expected_properties + + +@pytest.fixture +def noise_model(): + return ( + NoiseModel() + .add_noise(Noise.BitFlip(0.05), GateCriteria(Gate.H)) + .add_noise(Noise.TwoQubitDepolarizing(0.10), GateCriteria(Gate.CNot)) + ) + + +@pytest.mark.parametrize("backend", ["dummy_oq3_dm"]) +def test_valid_local_device_for_noise_model(backend, noise_model): + device = LocalSimulator(backend, noise_model=noise_model) + assert device._noise_model.instructions == [ + NoiseModelInstruction(Noise.BitFlip(0.05), GateCriteria(Gate.H)), + NoiseModelInstruction(Noise.TwoQubitDepolarizing(0.10), GateCriteria(Gate.CNot)), + ] + + +@pytest.mark.parametrize("backend", ["dummy_oq3"]) +def test_invalid_local_device_for_noise_model(backend, noise_model): + with pytest.raises(ValueError): + _ = LocalSimulator(backend, noise_model=noise_model) + + +@pytest.mark.parametrize("backend", ["dummy_oq3_dm"]) +def test_local_device_with_invalid_noise_model(backend, noise_model): + with pytest.raises(TypeError): + _ = LocalSimulator(backend, noise_model=Mock()) + + +@patch.object(DummyProgramDensityMatrixSimulator, "run") +def test_run_with_noise_model(mock_run, noise_model): + mock_run.return_value = GATE_MODEL_RESULT + device = LocalSimulator("dummy_oq3_dm", noise_model=noise_model) + circuit = Circuit().h(0).cnot(0, 1) + _ = device.run(circuit, shots=4) + + expected_circuit = textwrap.dedent( + """ + OPENQASM 3.0; + bit[2] b; + qubit[2] q; + h q[0]; + #pragma braket noise bit_flip(0.05) q[0] + cnot q[0], q[1]; + #pragma braket noise two_qubit_depolarizing(0.1) q[0], q[1] + b[0] = measure q[0]; + b[1] = measure q[1]; + """ + ).strip() + + mock_run.assert_called_with( + Program(source=expected_circuit, inputs={}), + 4, + ) + + +@patch.object(LocalSimulator, "_apply_noise_model_to_circuit") +def test_run_batch_with_noise_model(mock_apply, noise_model): + device = LocalSimulator("dummy_oq3_dm", noise_model=noise_model) + circuit = Circuit().h(0).cnot(0, 1) + + mock_apply.return_value = noise_model.apply(circuit) + _ = device.run_batch([circuit] * 2, shots=4).results() + assert mock_apply.call_count == 2 + + +@patch.object(DummyProgramDensityMatrixSimulator, "run") +def test_run_noisy_circuit_with_noise_model(mock_run, noise_model): + mock_run.return_value = GATE_MODEL_RESULT + device = LocalSimulator("dummy_oq3_dm", noise_model=noise_model) + circuit = Circuit().h(0).depolarizing(0, 0.1) + with warnings.catch_warnings(record=True) as w: + _ = device.run(circuit, shots=4) + + expected_warning = ( + "The noise model of the device is applied to a circuit that already has noise " + "instructions." + ) + expected_circuit = textwrap.dedent( + """ + OPENQASM 3.0; + bit[1] b; + qubit[1] q; + h q[0]; + #pragma braket noise bit_flip(0.05) q[0] + #pragma braket noise depolarizing(0.1) q[0] + b[0] = measure q[0]; + """ + ).strip() + + mock_run.assert_called_with( + Program(source=expected_circuit, inputs={}), + 4, + ) + assert w[-1].message.__str__() == expected_warning + + +@patch.object(DummyProgramDensityMatrixSimulator, "run") +def test_run_openqasm_with_noise_model(mock_run, noise_model): + mock_run.return_value = GATE_MODEL_RESULT + device = LocalSimulator("dummy_oq3_dm", noise_model=noise_model) + expected_circuit = textwrap.dedent( + """ + OPENQASM 3.0; + bit[1] b; + qubit[1] q; + h q[0]; + b[0] = measure q[0]; + """ + ).strip() + expected_warning = ( + "Noise model is only applicable to circuits. The type of the task specification " + "is Program. The noise model of the device does not apply." + ) + circuit = Program(source=expected_circuit) + with warnings.catch_warnings(record=True) as w: + _ = device.run(circuit, shots=4) + + mock_run.assert_called_with( + Program(source=expected_circuit, inputs=None), + 4, + ) + assert w[-1].message.__str__() == expected_warning diff --git a/test/unit_tests/braket/jobs/metrics_data/test_cwl_insights_metrics_fetcher.py b/test/unit_tests/braket/jobs/metrics_data/test_cwl_insights_metrics_fetcher.py index 17027cec4..72f9c4db5 100644 --- a/test/unit_tests/braket/jobs/metrics_data/test_cwl_insights_metrics_fetcher.py +++ b/test/unit_tests/braket/jobs/metrics_data/test_cwl_insights_metrics_fetcher.py @@ -80,6 +80,40 @@ def test_get_all_metrics_complete_results(mock_add_metrics, mock_get_metrics, aw assert result == expected_result +@patch("braket.jobs.metrics_data.cwl_insights_metrics_fetcher.LogMetricsParser.get_parsed_metrics") +@patch("braket.jobs.metrics_data.cwl_insights_metrics_fetcher.LogMetricsParser.parse_log_message") +def test_get_all_metrics_complete_results_stream_prefix( + mock_add_metrics, mock_get_metrics, aws_session +): + logs_client_mock = Mock() + aws_session.logs_client = logs_client_mock + + logs_client_mock.start_query.return_value = {"queryId": "test"} + logs_client_mock.get_query_results.return_value = { + "status": "Complete", + "results": EXAMPLE_METRICS_LOG_LINES, + } + expected_result = {"Test": [0]} + mock_get_metrics.return_value = expected_result + + fetcher = CwlInsightsMetricsFetcher(aws_session) + + result = fetcher.get_metrics_for_job( + "test_job", job_start_time=1, job_end_time=2, stream_prefix="test_job/uuid" + ) + logs_client_mock.get_query_results.assert_called_with(queryId="test") + logs_client_mock.start_query.assert_called_with( + logGroupName="/aws/braket/jobs", + startTime=1, + endTime=2, + queryString="fields @timestamp, @message | filter @logStream like /^test_job\\/uuid\\//" + " | filter @message like /Metrics - /", + limit=10000, + ) + assert mock_add_metrics.call_args_list == EXPECTED_CALL_LIST + assert result == expected_result + + def test_get_all_metrics_timeout(aws_session): logs_client_mock = Mock() aws_session.logs_client = logs_client_mock diff --git a/test/unit_tests/braket/jobs/metrics_data/test_cwl_metrics_fetcher.py b/test/unit_tests/braket/jobs/metrics_data/test_cwl_metrics_fetcher.py index fdaff840b..247d1873b 100644 --- a/test/unit_tests/braket/jobs/metrics_data/test_cwl_metrics_fetcher.py +++ b/test/unit_tests/braket/jobs/metrics_data/test_cwl_metrics_fetcher.py @@ -128,8 +128,6 @@ def test_get_metrics_timeout(mock_add_metrics, mock_get_metrics, aws_session): def get_log_events_forever(*args, **kwargs): - next_token = "1" token = kwargs.get("nextToken") - if token and token == "1": - next_token = "2" + next_token = "2" if token and token == "1" else "1" return {"events": EXAMPLE_METRICS_LOG_LINES, "nextForwardToken": next_token} diff --git a/test/unit_tests/braket/jobs/test_data_persistence.py b/test/unit_tests/braket/jobs/test_data_persistence.py index 6a5e27283..a4ac78f26 100644 --- a/test/unit_tests/braket/jobs/test_data_persistence.py +++ b/test/unit_tests/braket/jobs/test_data_persistence.py @@ -83,7 +83,7 @@ def test_save_job_checkpoint( if file_suffix else f"{tmp_dir}/{job_name}.json" ) - with open(expected_file_location, "r") as expected_file: + with open(expected_file_location) as expected_file: assert expected_file.read() == expected_saved_data @@ -267,7 +267,7 @@ def test_save_job_result(data_format, result_data, expected_saved_data): save_job_result(result_data, data_format) expected_file_location = f"{tmp_dir}/results.json" - with open(expected_file_location, "r") as expected_file: + with open(expected_file_location) as expected_file: assert expected_file.read() == expected_saved_data diff --git a/test/unit_tests/braket/jobs/test_hybrid_job.py b/test/unit_tests/braket/jobs/test_hybrid_job.py index e757c6a69..b1739d879 100644 --- a/test/unit_tests/braket/jobs/test_hybrid_job.py +++ b/test/unit_tests/braket/jobs/test_hybrid_job.py @@ -5,7 +5,7 @@ import tempfile from logging import getLogger from pathlib import Path -from ssl import SSLContext +from ssl import PROTOCOL_TLS_CLIENT, SSLContext from unittest.mock import MagicMock, patch import job_module @@ -488,7 +488,7 @@ def my_entry(*args): def test_serialization_error(aws_session): - ssl_context = SSLContext() + ssl_context = SSLContext(protocol=PROTOCOL_TLS_CLIENT) @hybrid_job(device=None, aws_session=aws_session) def fails_serialization(): @@ -511,7 +511,7 @@ def my_entry(*args, **kwargs): args, kwargs = (1, "two"), {"three": 3} template = _serialize_entry_point(my_entry, args, kwargs) - pickled_str = re.search(r"(?s)cloudpickle.loads\((.*?)\)\ndef my_entry", template).group(1) + pickled_str = re.search(r"(?s)cloudpickle.loads\((.*?)\)\ndef my_entry", template)[1] byte_str = ast.literal_eval(pickled_str) recovered = cloudpickle.loads(byte_str) diff --git a/test/unit_tests/braket/jobs/test_quantum_job_creation.py b/test/unit_tests/braket/jobs/test_quantum_job_creation.py index 8cd1fbca9..d12a29b00 100644 --- a/test/unit_tests/braket/jobs/test_quantum_job_creation.py +++ b/test/unit_tests/braket/jobs/test_quantum_job_creation.py @@ -148,9 +148,7 @@ def data_parallel(request): @pytest.fixture def distribution(data_parallel): - if data_parallel: - return "data_parallel" - return None + return "data_parallel" if data_parallel else None @pytest.fixture @@ -255,8 +253,8 @@ def create_job_args( reservation_arn, ): if request.param == "fixtures": - return dict( - (key, value) + return { + key: value for key, value in { "device": device, "source_module": source_module, @@ -277,7 +275,7 @@ def create_job_args( "reservation_arn": reservation_arn, }.items() if value is not None - ) + } elif request.param == "defaults": return { "device": device, @@ -339,7 +337,7 @@ def _translate_creation_args(create_job_args): "sagemaker_distributed_dataparallel_enabled": "true", "sagemaker_instance_type": instance_config.instanceType, } - hyperparameters.update(distributed_hyperparams) + hyperparameters |= distributed_hyperparams output_data_config = create_job_args["output_data_config"] or OutputDataConfig( s3Path=AwsSession.construct_s3_uri(default_bucket, "jobs", job_name, timestamp, "data") ) @@ -379,16 +377,12 @@ def _translate_creation_args(create_job_args): } if reservation_arn: - test_kwargs.update( + test_kwargs["associations"] = [ { - "associations": [ - { - "arn": reservation_arn, - "type": "RESERVATION_TIME_WINDOW_ARN", - } - ] + "arn": reservation_arn, + "type": "RESERVATION_TIME_WINDOW_ARN", } - ) + ] return test_kwargs diff --git a/test/unit_tests/braket/parametric/test_free_parameter.py b/test/unit_tests/braket/parametric/test_free_parameter.py index 816bc0cee..b94c60a97 100644 --- a/test/unit_tests/braket/parametric/test_free_parameter.py +++ b/test/unit_tests/braket/parametric/test_free_parameter.py @@ -21,9 +21,15 @@ def free_parameter(): return FreeParameter("theta") -@pytest.mark.xfail(raises=TypeError) def test_bad_input(): - FreeParameter(6) + with pytest.raises(ValueError, match="FreeParameter names must be non empty"): + FreeParameter("") + with pytest.raises(TypeError, match="FreeParameter names must be strings"): + FreeParameter(6) + with pytest.raises( + ValueError, match="FreeParameter names must start with a letter or an underscore" + ): + FreeParameter(".2") def test_is_free_param(free_parameter): diff --git a/test/unit_tests/braket/parametric/test_free_parameter_expression.py b/test/unit_tests/braket/parametric/test_free_parameter_expression.py index 879706fe0..1bf6818bf 100644 --- a/test/unit_tests/braket/parametric/test_free_parameter_expression.py +++ b/test/unit_tests/braket/parametric/test_free_parameter_expression.py @@ -80,7 +80,6 @@ def test_commutativity(): def test_add(): add_expr = FreeParameter("theta") + FreeParameter("theta") expected = FreeParameterExpression(2 * FreeParameter("theta")) - assert add_expr == expected @@ -89,14 +88,12 @@ def test_sub(): expected = FreeParameterExpression(FreeParameter("theta")) - FreeParameterExpression( FreeParameter("alpha") ) - assert sub_expr == expected def test_r_sub(): r_sub_expr = 1 - FreeParameter("theta") expected = FreeParameterExpression(1 - FreeParameter("theta")) - assert r_sub_expr == expected @@ -106,6 +103,20 @@ def test_mul(): assert mul_expr == expected +def test_truediv(): + truediv_expr = FreeParameter("theta") / FreeParameter("alpha") + expected = FreeParameterExpression(FreeParameter("theta")) / FreeParameterExpression( + FreeParameter("alpha") + ) + assert truediv_expr == expected + + +def test_r_truediv(): + r_truediv_expr = 1 / FreeParameter("theta") + expected = FreeParameterExpression(1 / FreeParameter("theta")) + assert r_truediv_expr == expected + + def test_pow(): mul_expr = FreeParameter("theta") ** FreeParameter("alpha") * 2 expected = FreeParameterExpression(FreeParameter("theta") ** FreeParameter("alpha") * 2) @@ -157,6 +168,7 @@ def test_sub_return_expression(): (FreeParameter("a") + 2 * FreeParameter("b"), {"a": 0.1, "b": 0.3}, 0.7, float), (FreeParameter("x"), {"y": 1}, FreeParameter("x"), FreeParameter), (FreeParameter("y"), {"y": -0.1}, -0.1, float), + (2 * FreeParameter("i"), {"i": 1}, 2.0, float), ( FreeParameter("a") + 2 * FreeParameter("x"), {"a": 0.4, "b": 0.4}, diff --git a/test/unit_tests/braket/pulse/ast/test_approximation_parser.py b/test/unit_tests/braket/pulse/ast/test_approximation_parser.py index c4199e81b..56f02aa12 100644 --- a/test/unit_tests/braket/pulse/ast/test_approximation_parser.py +++ b/test/unit_tests/braket/pulse/ast/test_approximation_parser.py @@ -11,7 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from typing import List, Union +from typing import Union from unittest.mock import Mock import numpy as np @@ -85,7 +85,7 @@ def test_delay_multiple_frames(port): # Inst2 # Delay frame1 and frame2 by 10e-9 # frame2 is 0 from 0ns to 21ns - shift_time_frame1 = shift_time_frame1 + pulse_length + shift_time_frame1 += pulse_length pulse_length = 10e-9 expected_amplitudes["frame1"].put(shift_time_frame1, 0).put( @@ -104,7 +104,7 @@ def test_delay_multiple_frames(port): expected_phases["frame2"].put(0, 0).put(shift_time_frame1 + pulse_length - port.dt, 0) # Inst3 - shift_time_frame1 = shift_time_frame1 + pulse_length + shift_time_frame1 += pulse_length pulse_length = 16e-9 times = np.arange(0, pulse_length, port.dt) values = 1 * np.ones_like(times) @@ -156,7 +156,7 @@ def test_delay_qubits(port): # Inst2 # Delay frame1 and frame2 by 10e-9 # frame2 is 0 from 0ns to 21ns - shift_time_frame1 = shift_time_frame1 + pulse_length + shift_time_frame1 += pulse_length pulse_length = 10e-9 expected_amplitudes["q0_frame"].put(shift_time_frame1, 0).put( @@ -177,7 +177,7 @@ def test_delay_qubits(port): expected_phases["q0_q1_frame"].put(0, 0).put(shift_time_frame1 + pulse_length - port.dt, 0) # Inst3 - shift_time_frame1 = shift_time_frame1 + pulse_length + shift_time_frame1 += pulse_length pulse_length = 16e-9 times = np.arange(0, pulse_length, port.dt) values = 1 * np.ones_like(times) @@ -230,7 +230,7 @@ def test_delay_no_args(port): # Inst2 # Delay frame1 and frame2 by 10e-9 # frame2 is 0 from 0ns to 21ns - shift_time_frame1 = shift_time_frame1 + pulse_length + shift_time_frame1 += pulse_length pulse_length = 10e-9 expected_amplitudes["q0_frame"].put(shift_time_frame1, 0).put( @@ -251,7 +251,7 @@ def test_delay_no_args(port): expected_phases["q0_q1_frame"].put(0, 0).put(shift_time_frame1 + pulse_length - port.dt, 0) # Inst3 - shift_time_frame1 = shift_time_frame1 + pulse_length + shift_time_frame1 += pulse_length pulse_length = 16e-9 times = np.arange(0, pulse_length, port.dt) values = 1 * np.ones_like(times) @@ -636,7 +636,7 @@ def test_play_drag_gaussian_waveforms(port): dtype=np.complex128, ) - shift_time = shift_time + 20e-9 + shift_time += 20e-9 for t, v in zip(times, values): expected_amplitudes["frame1"].put(t + shift_time, v) expected_frequencies["frame1"].put(t + shift_time, 1e8) @@ -681,7 +681,7 @@ def test_barrier_same_dt(port): expected_phases["frame2"].put(0, 0).put(11e-9, 0) # Inst3 - shift_time_frame1 = shift_time_frame1 + pulse_length + shift_time_frame1 += pulse_length pulse_length = 16e-9 times = np.arange(0, pulse_length, port.dt) values = 1 * np.ones_like(times) @@ -739,7 +739,7 @@ def test_barrier_no_args(port): expected_phases["frame2"].put(0, 0).put(11e-9, 0) # Inst3 - shift_time_frame1 = shift_time_frame1 + pulse_length + shift_time_frame1 += pulse_length pulse_length = 16e-9 times = np.arange(0, pulse_length, port.dt) values = 1 * np.ones_like(times) @@ -797,7 +797,7 @@ def test_barrier_qubits(port): expected_phases["q0_q1_frame"].put(0, 0).put(11e-9, 0) # Inst3 - shift_time_frame1 = shift_time_frame1 + pulse_length + shift_time_frame1 += pulse_length pulse_length = 16e-9 times = np.arange(0, pulse_length, port.dt) values = 1 * np.ones_like(times) @@ -1001,8 +1001,8 @@ def verify_results(results, expected_amplitudes, expected_frequencies, expected_ assert _all_close(results.phases[frame_id], expected_phases[frame_id], 1e-10) -def to_dict(frames: Union[Frame, List]): - if not isinstance(frames, List): +def to_dict(frames: Union[Frame, list]): + if not isinstance(frames, list): frames = [frames] frame_dict = dict() for frame in frames: diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 57ba20fbd..da8e0a8a3 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -87,7 +87,7 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined .set_frequency(predefined_frame_1, param) .shift_frequency(predefined_frame_1, param) .set_phase(predefined_frame_1, param) - .shift_phase(predefined_frame_1, param) + .shift_phase(predefined_frame_1, -param) .set_scale(predefined_frame_1, param) .capture_v0(predefined_frame_1) .delay([predefined_frame_1, predefined_frame_2], param) @@ -125,25 +125,25 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined [ "OPENQASM 3.0;", "cal {", - " waveform gauss_wf = gaussian((1000000000.0*length_g)ns, (1000000000.0*sigma_g)ns, " - "1, false);", - " waveform drag_gauss_wf = " - "drag_gaussian((1000000000.0*length_dg)ns, (1000000000.0*sigma_dg)ns, 0.2, 1, false);", - " waveform constant_wf = constant((1000000000.0*length_c)ns, 2.0 + 0.3im);", - " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", - " set_frequency(predefined_frame_1, a + 2*b);", - " shift_frequency(predefined_frame_1, a + 2*b);", - " set_phase(predefined_frame_1, a + 2*b);", - " shift_phase(predefined_frame_1, a + 2*b);", - " set_scale(predefined_frame_1, a + 2*b);", + *[ + f" input float {parameter};" + for parameter in reversed(list(pulse_sequence.parameters)) + ], + " waveform gauss_wf = gaussian(length_g * 1s, sigma_g * 1s, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(length_dg * 1s," + " sigma_dg * 1s, 0.2, 1, false);", + " waveform constant_wf = constant(length_c * 1s, 2.0 + 0.3im);", + " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", + " set_frequency(predefined_frame_1, a + 2.0 * b);", + " shift_frequency(predefined_frame_1, a + 2.0 * b);", + " set_phase(predefined_frame_1, a + 2.0 * b);", + " shift_phase(predefined_frame_1, -1.0 * a + -2.0 * b);", + " set_scale(predefined_frame_1, a + 2.0 * b);", " psb[0] = capture_v0(predefined_frame_1);", - ( - " delay[(1000000000.0*a + 2000000000.0*b)ns]" - " predefined_frame_1, predefined_frame_2;" - ), - " delay[(1000000000.0*a + 2000000000.0*b)ns] predefined_frame_1;", - " delay[1000000.0ns] predefined_frame_1;", + " delay[(a + 2.0 * b) * 1s] predefined_frame_1, predefined_frame_2;", + " delay[(a + 2.0 * b) * 1s] predefined_frame_1;", + " delay[1.0ms] predefined_frame_1;", " barrier predefined_frame_1, predefined_frame_2;", " play(predefined_frame_1, gauss_wf);", " play(predefined_frame_2, drag_gauss_wf);", @@ -154,17 +154,15 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined ] ) assert pulse_sequence.to_ir() == expected_str_unbound - assert pulse_sequence.parameters == set( - [ - FreeParameter("a"), - FreeParameter("b"), - FreeParameter("length_g"), - FreeParameter("length_dg"), - FreeParameter("sigma_g"), - FreeParameter("sigma_dg"), - FreeParameter("length_c"), - ] - ) + assert pulse_sequence.parameters == { + FreeParameter("a"), + FreeParameter("b"), + FreeParameter("length_g"), + FreeParameter("length_dg"), + FreeParameter("sigma_g"), + FreeParameter("sigma_dg"), + FreeParameter("length_c"), + } b_bound = pulse_sequence.make_bound_pulse_sequence( {"b": 2, "length_g": 1e-3, "length_dg": 3e-3, "sigma_dg": 0.4, "length_c": 4e-3} ) @@ -173,21 +171,21 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined [ "OPENQASM 3.0;", "cal {", - " waveform gauss_wf = gaussian(1000000.0ns, (1000000000.0*sigma_g)ns, 1, false);", - " waveform drag_gauss_wf = drag_gaussian(3000000.0ns, 400000000.0ns, 0.2, 1," - " false);", - " waveform constant_wf = constant(4000000.0ns, 2.0 + 0.3im);", - " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", - " set_frequency(predefined_frame_1, a + 4);", - " shift_frequency(predefined_frame_1, a + 4);", - " set_phase(predefined_frame_1, a + 4);", - " shift_phase(predefined_frame_1, a + 4);", - " set_scale(predefined_frame_1, a + 4);", + *[f" input float {parameter};" for parameter in reversed(list(b_bound.parameters))], + " waveform gauss_wf = gaussian(1.0ms, sigma_g * 1s, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", + " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", + " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", + " set_frequency(predefined_frame_1, a + 4.0);", + " shift_frequency(predefined_frame_1, a + 4.0);", + " set_phase(predefined_frame_1, a + 4.0);", + " shift_phase(predefined_frame_1, -1.0 * a + -4.0);", + " set_scale(predefined_frame_1, a + 4.0);", " psb[0] = capture_v0(predefined_frame_1);", - " delay[(1000000000.0*a + 4000000000.0)ns] predefined_frame_1, predefined_frame_2;", - " delay[(1000000000.0*a + 4000000000.0)ns] predefined_frame_1;", - " delay[1000000.0ns] predefined_frame_1;", + " delay[(a + 4.0) * 1s] predefined_frame_1, predefined_frame_2;", + " delay[(a + 4.0) * 1s] predefined_frame_1;", + " delay[1.0ms] predefined_frame_1;", " barrier predefined_frame_1, predefined_frame_2;", " play(predefined_frame_1, gauss_wf);", " play(predefined_frame_2, drag_gauss_wf);", @@ -199,28 +197,27 @@ def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined ) assert b_bound.to_ir() == b_bound_call.to_ir() == expected_str_b_bound assert pulse_sequence.to_ir() == expected_str_unbound - assert b_bound.parameters == set([FreeParameter("sigma_g"), FreeParameter("a")]) + assert b_bound.parameters == {FreeParameter("sigma_g"), FreeParameter("a")} both_bound = b_bound.make_bound_pulse_sequence({"a": 1, "sigma_g": 0.7}) both_bound_call = b_bound_call(1, sigma_g=0.7) # use arg 1 for a expected_str_both_bound = "\n".join( [ "OPENQASM 3.0;", "cal {", - " waveform gauss_wf = gaussian(1000000.0ns, 700000000.0ns, 1, false);", - " waveform drag_gauss_wf = drag_gaussian(3000000.0ns, 400000000.0ns, 0.2, 1," - " false);", - " waveform constant_wf = constant(4000000.0ns, 2.0 + 0.3im);", - " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", - " set_frequency(predefined_frame_1, 5);", - " shift_frequency(predefined_frame_1, 5);", - " set_phase(predefined_frame_1, 5);", - " shift_phase(predefined_frame_1, 5);", - " set_scale(predefined_frame_1, 5);", + " waveform gauss_wf = gaussian(1.0ms, 700.0ms, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", + " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", + " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", + " set_frequency(predefined_frame_1, 5.0);", + " shift_frequency(predefined_frame_1, 5.0);", + " set_phase(predefined_frame_1, 5.0);", + " shift_phase(predefined_frame_1, -5.0);", + " set_scale(predefined_frame_1, 5.0);", " psb[0] = capture_v0(predefined_frame_1);", - " delay[5000000000.00000ns] predefined_frame_1, predefined_frame_2;", - " delay[5000000000.00000ns] predefined_frame_1;", - " delay[1000000.0ns] predefined_frame_1;", + " delay[5.0s] predefined_frame_1, predefined_frame_2;", + " delay[5.0s] predefined_frame_1;", + " delay[1.0ms] predefined_frame_1;", " barrier predefined_frame_1, predefined_frame_2;", " play(predefined_frame_1, gauss_wf);", " play(predefined_frame_2, drag_gauss_wf);", @@ -311,12 +308,11 @@ def test_pulse_sequence_to_ir(predefined_frame_1, predefined_frame_2): [ "OPENQASM 3.0;", "cal {", - " waveform gauss_wf = gaussian(1000000.0ns, 700000000.0ns, 1, false);", - " waveform drag_gauss_wf = drag_gaussian(3000000.0ns, 400000000.0ns, 0.2, 1," - " false);", - " waveform constant_wf = constant(4000000.0ns, 2.0 + 0.3im);", - " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " bit[2] psb;", + " waveform gauss_wf = gaussian(1.0ms, 700.0ms, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", + " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", + " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", " set_frequency(predefined_frame_1, 3000000000.0);", " shift_frequency(predefined_frame_1, 1000000000.0);", " set_phase(predefined_frame_1, -0.5);", @@ -324,8 +320,8 @@ def test_pulse_sequence_to_ir(predefined_frame_1, predefined_frame_2): " set_scale(predefined_frame_1, 0.25);", " psb[0] = capture_v0(predefined_frame_1);", " delay[2.0ns] predefined_frame_1, predefined_frame_2;", - " delay[1000.0ns] predefined_frame_1;", - " delay[1000000.0ns] $0;", + " delay[1.0us] predefined_frame_1;", + " delay[1.0ms] $0;", " barrier $0, $1;", " barrier predefined_frame_1, predefined_frame_2;", " play(predefined_frame_1, gauss_wf);", diff --git a/test/unit_tests/braket/pulse/test_waveforms.py b/test/unit_tests/braket/pulse/test_waveforms.py index b42eacc0b..34f989c0b 100644 --- a/test/unit_tests/braket/pulse/test_waveforms.py +++ b/test/unit_tests/braket/pulse/test_waveforms.py @@ -33,15 +33,23 @@ ], ) def test_arbitrary_waveform(amps): - id = "arb_wf_x" - wf = ArbitraryWaveform(amps, id) + waveform_id = "arb_wf_x" + wf = ArbitraryWaveform(amps, waveform_id) assert wf.amplitudes == list(amps) - assert wf.id == id + assert wf.id == waveform_id oq_exp = wf._to_oqpy_expression() assert oq_exp.init_expression == list(amps) assert oq_exp.name == wf.id +def test_arbitrary_waveform_repr(): + amps = [1, 4, 5] + waveform_id = "arb_wf_x" + wf = ArbitraryWaveform(amps, waveform_id) + expected = f"ArbitraryWaveform('id': {wf.id}, 'amplitudes': {wf.amplitudes})" + assert repr(wf) == expected + + def test_arbitrary_waveform_default_params(): amps = [1, 4, 5] wf = ArbitraryWaveform(amps) @@ -74,7 +82,16 @@ def test_constant_waveform(): assert wf.iq == iq assert wf.id == id - _assert_wf_qasm(wf, "waveform const_wf_x = constant(4000000.0ns, 4);") + _assert_wf_qasm(wf, "waveform const_wf_x = constant(4.0ms, 4);") + + +def test_constant_waveform_repr(): + length = 4e-3 + iq = 4 + id = "const_wf_x" + wf = ConstantWaveform(length, iq, id) + expected = f"ConstantWaveform('id': {wf.id}, 'length': {wf.length}, 'iq': {wf.iq})" + assert repr(wf) == expected def test_constant_waveform_default_params(): @@ -101,14 +118,13 @@ def test_constant_wf_free_params(): assert wf.parameters == [FreeParameter("length_v") + FreeParameter("length_w")] _assert_wf_qasm( wf, - "waveform const_wf = " - "constant((1000000000.0*length_v + 1000000000.0*length_w)ns, 2.0 - 3.0im);", + "waveform const_wf = constant((length_v + length_w) * 1s, 2.0 - 3.0im);", ) wf_2 = wf.bind_values(length_v=2e-6, length_w=4e-6) assert len(wf_2.parameters) == 1 assert math.isclose(wf_2.parameters[0], 6e-6) - _assert_wf_qasm(wf_2, "waveform const_wf = constant(6000.0ns, 2.0 - 3.0im);") + _assert_wf_qasm(wf_2, "waveform const_wf = constant(6.0us, 2.0 - 3.0im);") def test_drag_gaussian_waveform(): @@ -126,9 +142,22 @@ def test_drag_gaussian_waveform(): assert wf.sigma == sigma assert wf.length == length - _assert_wf_qasm( - wf, "waveform drag_gauss_wf = drag_gaussian(4.0ns, 300000000.0ns, 0.6, 0.4, false);" + _assert_wf_qasm(wf, "waveform drag_gauss_wf = drag_gaussian(4.0ns, 300.0ms, 0.6, 0.4, false);") + + +def test_drag_gaussian_waveform_repr(): + length = 4e-9 + sigma = 0.3 + beta = 0.6 + amplitude = 0.4 + zero_at_edges = False + id = "drag_gauss_wf" + wf = DragGaussianWaveform(length, sigma, beta, amplitude, zero_at_edges, id) + expected = ( + f"DragGaussianWaveform('id': {wf.id}, 'length': {wf.length}, 'sigma': {wf.sigma}, " + f"'beta': {wf.beta}, 'amplitude': {wf.amplitude}, 'zero_at_edges': {wf.zero_at_edges})" ) + assert repr(wf) == expected def test_drag_gaussian_waveform_default_params(): @@ -154,22 +183,6 @@ def test_drag_gaussian_wf_eq(): assert wf != wfc -def test_gaussian_waveform(): - length = 4e-9 - sigma = 0.3 - amplitude = 0.4 - zero_at_edges = False - id = "gauss_wf" - wf = GaussianWaveform(length, sigma, amplitude, zero_at_edges, id) - assert wf.id == id - assert wf.zero_at_edges == zero_at_edges - assert wf.amplitude == amplitude - assert wf.sigma == sigma - assert wf.length == length - - _assert_wf_qasm(wf, "waveform gauss_wf = gaussian(4.0ns, 300000000.0ns, 0.4, false);") - - def test_drag_gaussian_wf_free_params(): wf = DragGaussianWaveform( FreeParameter("length_v"), @@ -187,8 +200,7 @@ def test_drag_gaussian_wf_free_params(): _assert_wf_qasm( wf, "waveform d_gauss_wf = " - "drag_gaussian((1000000000.0*length_v)ns, (1000000000.0*sigma_a + " - "1000000000.0*sigma_b)ns, beta_y, amp_z, false);", + "drag_gaussian(length_v * 1s, (sigma_a + sigma_b) * 1s, beta_y, amp_z, false);", ) wf_2 = wf.bind_values(length_v=0.6, sigma_a=0.4) @@ -200,15 +212,42 @@ def test_drag_gaussian_wf_free_params(): ] _assert_wf_qasm( wf_2, - "waveform d_gauss_wf = drag_gaussian(600000000.0ns, (1000000000.0*sigma_b " - "+ 400000000.0)ns, beta_y, amp_z, false);", + "waveform d_gauss_wf = drag_gaussian(600.0ms, (0.4 + sigma_b) * 1s, beta_y, amp_z, false);", ) wf_3 = wf.bind_values(length_v=0.6, sigma_a=0.3, sigma_b=0.1, beta_y=0.2, amp_z=0.1) assert wf_3.parameters == [0.6, 0.4, 0.2, 0.1] - _assert_wf_qasm( - wf_3, "waveform d_gauss_wf = drag_gaussian(600000000.0ns, 400000000.0ns, 0.2, 0.1, false);" + _assert_wf_qasm(wf_3, "waveform d_gauss_wf = drag_gaussian(600.0ms, 400.0ms, 0.2, 0.1, false);") + + +def test_gaussian_waveform(): + length = 4e-9 + sigma = 0.3 + amplitude = 0.4 + zero_at_edges = False + id = "gauss_wf" + wf = GaussianWaveform(length, sigma, amplitude, zero_at_edges, id) + assert wf.id == id + assert wf.zero_at_edges == zero_at_edges + assert wf.amplitude == amplitude + assert wf.sigma == sigma + assert wf.length == length + + _assert_wf_qasm(wf, "waveform gauss_wf = gaussian(4.0ns, 300.0ms, 0.4, false);") + + +def test_gaussian_waveform_repr(): + length = 4e-9 + sigma = 0.3 + amplitude = 0.4 + zero_at_edges = False + id = "gauss_wf" + wf = GaussianWaveform(length, sigma, amplitude, zero_at_edges, id) + expected = ( + f"GaussianWaveform('id': {wf.id}, 'length': {wf.length}, 'sigma': {wf.sigma}, " + f"'amplitude': {wf.amplitude}, 'zero_at_edges': {wf.zero_at_edges})" ) + assert repr(wf) == expected def test_gaussian_waveform_default_params(): @@ -243,19 +282,16 @@ def test_gaussian_wf_free_params(): ] _assert_wf_qasm( wf, - "waveform gauss_wf = gaussian((1000000000.0*length_v)ns, (1000000000.0*sigma_x)ns, " - "amp_z, false);", + "waveform gauss_wf = gaussian(length_v * 1s, sigma_x * 1s, amp_z, false);", ) wf_2 = wf.bind_values(length_v=0.6, sigma_x=0.4) assert wf_2.parameters == [0.6, 0.4, FreeParameter("amp_z")] - _assert_wf_qasm( - wf_2, "waveform gauss_wf = gaussian(600000000.0ns, 400000000.0ns, amp_z, false);" - ) + _assert_wf_qasm(wf_2, "waveform gauss_wf = gaussian(600.0ms, 400.0ms, amp_z, false);") wf_3 = wf.bind_values(length_v=0.6, sigma_x=0.3, amp_z=0.1) assert wf_3.parameters == [0.6, 0.3, 0.1] - _assert_wf_qasm(wf_3, "waveform gauss_wf = gaussian(600000000.0ns, 300000000.0ns, 0.1, false);") + _assert_wf_qasm(wf_3, "waveform gauss_wf = gaussian(600.0ms, 300.0ms, 0.1, false);") def _assert_wf_qasm(waveform, expected_qasm): diff --git a/test/unit_tests/braket/quantum_information/test_pauli_string.py b/test/unit_tests/braket/quantum_information/test_pauli_string.py index 82e3b0698..c3959d305 100644 --- a/test/unit_tests/braket/quantum_information/test_pauli_string.py +++ b/test/unit_tests/braket/quantum_information/test_pauli_string.py @@ -20,7 +20,7 @@ from braket.circuits import gates from braket.circuits.circuit import Circuit -from braket.circuits.observables import X, Y, Z +from braket.circuits.observables import I, X, Y, Z from braket.quantum_information import PauliString ORDER = ["I", "X", "Y", "Z"] @@ -34,15 +34,16 @@ @pytest.mark.parametrize( - "pauli_string, string, phase, observable", + "pauli_string, string, phase, observable, obs_with_id", [ - ("+XZ", "+XZ", 1, X() @ Z()), - ("-ZXY", "-ZXY", -1, Z() @ X() @ Y()), - ("YIX", "+YIX", 1, Y() @ X()), - (PauliString("-ZYXI"), "-ZYXI", -1, Z() @ Y() @ X()), + ("+XZ", "+XZ", 1, X() @ Z(), X() @ Z()), + ("-ZXY", "-ZXY", -1, Z() @ X() @ Y(), Z() @ X() @ Y()), + ("YIX", "+YIX", 1, Y() @ X(), Y() @ I() @ X()), + (PauliString("-ZYXI"), "-ZYXI", -1, Z() @ Y() @ X(), Z() @ Y() @ X() @ I()), + ("IIXIIIYI", "+IIXIIIYI", 1, X() @ Y(), I() @ I() @ X() @ I() @ I() @ I() @ Y() @ I()), ], ) -def test_happy_case(pauli_string, string, phase, observable): +def test_happy_case(pauli_string, string, phase, observable, obs_with_id): instance = PauliString(pauli_string) assert str(instance) == string assert instance.phase == phase @@ -57,6 +58,7 @@ def test_happy_case(pauli_string, string, phase, observable): assert instance == PauliString(pauli_string) assert instance == PauliString(instance) assert instance.to_unsigned_observable() == observable + assert instance.to_unsigned_observable(include_trivial=True) == obs_with_id @pytest.mark.parametrize( diff --git a/test/unit_tests/braket/registers/test_qubit.py b/test/unit_tests/braket/registers/test_qubit.py index 98f89cf8d..0c04a5d95 100644 --- a/test/unit_tests/braket/registers/test_qubit.py +++ b/test/unit_tests/braket/registers/test_qubit.py @@ -39,7 +39,7 @@ def test_index_gte_zero(qubit_index): def test_str(qubit): - expected = "Qubit({})".format(int(qubit)) + expected = f"Qubit({int(qubit)})" assert str(qubit) == expected diff --git a/test/unit_tests/braket/registers/test_qubit_set.py b/test/unit_tests/braket/registers/test_qubit_set.py index 1fd8d7212..1d730b967 100644 --- a/test/unit_tests/braket/registers/test_qubit_set.py +++ b/test/unit_tests/braket/registers/test_qubit_set.py @@ -31,20 +31,20 @@ def test_default_input(): def test_with_single(): - assert QubitSet(0) == tuple([Qubit(0)]) + assert QubitSet(0) == (Qubit(0),) def test_with_iterable(): - assert QubitSet([0, 1]) == tuple([Qubit(0), Qubit(1)]) + assert QubitSet([0, 1]) == (Qubit(0), Qubit(1)) def test_with_nested_iterable(): - assert QubitSet([0, 1, [2, 3]]) == tuple([Qubit(0), Qubit(1), Qubit(2), Qubit(3)]) + assert QubitSet([0, 1, [2, 3]]) == (Qubit(0), Qubit(1), Qubit(2), Qubit(3)) def test_with_qubit_set(): qubits = QubitSet([0, 1]) - assert QubitSet([qubits, [2, 3]]) == tuple([Qubit(0), Qubit(1), Qubit(2), Qubit(3)]) + assert QubitSet([qubits, [2, 3]]) == (Qubit(0), Qubit(1), Qubit(2), Qubit(3)) def test_flattening_does_not_recurse_infinitely(): diff --git a/test/unit_tests/braket/tasks/test_analog_hamiltonian_simulation_task_result.py b/test/unit_tests/braket/tasks/test_analog_hamiltonian_simulation_task_result.py index 23cca78be..be2d6fdbe 100644 --- a/test/unit_tests/braket/tasks/test_analog_hamiltonian_simulation_task_result.py +++ b/test/unit_tests/braket/tasks/test_analog_hamiltonian_simulation_task_result.py @@ -75,7 +75,7 @@ def additional_metadata(): }, } ], - "shiftingFields": [ + "localDetuning": [ { "magnitude": { "time_series": { diff --git a/test/unit_tests/braket/tasks/test_annealing_quantum_task_result.py b/test/unit_tests/braket/tasks/test_annealing_quantum_task_result.py index 29d21e58f..03f68b7cb 100644 --- a/test/unit_tests/braket/tasks/test_annealing_quantum_task_result.py +++ b/test/unit_tests/braket/tasks/test_annealing_quantum_task_result.py @@ -221,7 +221,7 @@ def test_data_sort_by_none(annealing_result, solutions, values, solution_counts) def test_data_selected_fields(annealing_result, solutions, values, solution_counts): d = list(annealing_result.data(selected_fields=["value"])) for i in range(len(solutions)): - assert d[i] == tuple([values[i]]) + assert d[i] == (values[i],) def test_data_reverse(annealing_result, solutions, values, solution_counts): diff --git a/test/unit_tests/braket/tasks/test_gate_model_quantum_task_result.py b/test/unit_tests/braket/tasks/test_gate_model_quantum_task_result.py index 8c0f8d9a9..5447bd8b2 100644 --- a/test/unit_tests/braket/tasks/test_gate_model_quantum_task_result.py +++ b/test/unit_tests/braket/tasks/test_gate_model_quantum_task_result.py @@ -262,16 +262,6 @@ def malformatted_results_1(task_metadata_shots, additional_metadata): ).json() -@pytest.fixture -def malformatted_results_2(task_metadata_shots, additional_metadata): - return GateModelTaskResult( - measurementProbabilities={"011000": 0.9999999999999982}, - measuredQubits=[0], - taskMetadata=task_metadata_shots, - additionalMetadata=additional_metadata, - ).json() - - @pytest.fixture def openqasm_result_obj_shots(task_metadata_shots, additional_metadata_openqasm): return GateModelTaskResult.construct( @@ -415,7 +405,7 @@ def test_from_string_measurement_probabilities(result_str_3): measurement_list = [list("011000") for _ in range(shots)] expected_measurements = np.asarray(measurement_list, dtype=int) assert np.allclose(task_result.measurements, expected_measurements) - assert task_result.measurement_counts == Counter(["011000" for x in range(shots)]) + assert task_result.measurement_counts == Counter(["011000" for _ in range(shots)]) assert not task_result.measurement_counts_copied_from_device assert task_result.measurement_probabilities_copied_from_device assert not task_result.measurements_copied_from_device @@ -484,11 +474,6 @@ def test_shots_no_measurements_no_measurement_probs(malformatted_results_1): GateModelQuantumTaskResult.from_string(malformatted_results_1) -@pytest.mark.xfail(raises=ValueError) -def test_measurements_measured_qubits_mismatch(malformatted_results_2): - GateModelQuantumTaskResult.from_string(malformatted_results_2) - - @pytest.mark.parametrize("ir_result,expected_result", test_ir_results) def test_calculate_ir_results(ir_result, expected_result): ir_string = jaqcd.Program( diff --git a/test/unit_tests/braket/tasks/test_local_quantum_task.py b/test/unit_tests/braket/tasks/test_local_quantum_task.py index 6b583c608..aca0aa20d 100644 --- a/test/unit_tests/braket/tasks/test_local_quantum_task.py +++ b/test/unit_tests/braket/tasks/test_local_quantum_task.py @@ -57,5 +57,5 @@ def test_async(): def test_str(): - expected = "LocalQuantumTask('id':{})".format(TASK.id) + expected = f"LocalQuantumTask('id':{TASK.id})" assert str(TASK) == expected diff --git a/test/unit_tests/braket/timings/test_time_series.py b/test/unit_tests/braket/timings/test_time_series.py index 0f99ca650..729813aa0 100755 --- a/test/unit_tests/braket/timings/test_time_series.py +++ b/test/unit_tests/braket/timings/test_time_series.py @@ -20,7 +20,12 @@ @pytest.fixture def default_values(): - return [(2700, 25.1327), (300, 25.1327), (600, 15.1327), (Decimal(0.3), Decimal(0.4))] + return [ + (2700, Decimal("25.1327")), + (300, Decimal("25.1327")), + (600, Decimal("15.1327")), + (Decimal("0.3"), Decimal("0.4")), + ] @pytest.fixture @@ -265,11 +270,12 @@ def test_stitch_wrong_bndry_value(): @pytest.mark.parametrize( "time_res, expected_times", [ - # default_time_series: [(Decimal(0.3), Decimal(0.4), (300, 25.1327), (600, 15.1327), (2700, 25.1327))] # noqa - (Decimal(0.5), [Decimal("0.5"), Decimal("300"), Decimal("600"), Decimal("2700")]), - (Decimal(1), [Decimal("0"), Decimal("300"), Decimal("600"), Decimal("2700")]), - (Decimal(200), [Decimal("0"), Decimal("400"), Decimal("600"), Decimal("2800")]), - (Decimal(1000), [Decimal("0"), Decimal("1000"), Decimal("3000")]), + # default_time_series: [(Decimal(0.3), Decimal(0.4)), (300, 25.1327), (600, 15.1327), (2700, 25.1327)] # noqa + (None, [Decimal("0.3"), Decimal("300"), Decimal("600"), Decimal("2700")]), + (Decimal("0.5"), [Decimal("0.5"), Decimal("300"), Decimal("600"), Decimal("2700")]), + (Decimal("1"), [Decimal("0"), Decimal("300"), Decimal("600"), Decimal("2700")]), + (Decimal("200"), [Decimal("0"), Decimal("400"), Decimal("600"), Decimal("2800")]), + (Decimal("1000"), [Decimal("0"), Decimal("1000"), Decimal("3000")]), ], ) def test_discretize_times(default_time_series, time_res, expected_times): @@ -280,7 +286,8 @@ def test_discretize_times(default_time_series, time_res, expected_times): @pytest.mark.parametrize( "value_res, expected_values", [ - # default_time_series: [(Decimal(0.3), Decimal(0.4), (300, 25.1327), (600, 15.1327), (2700, 25.1327))] # noqa + # default_time_series: [(Decimal(0.3), Decimal(0.4)), (300, 25.1327), (600, 15.1327), (2700, 25.1327)] # noqa + (None, [Decimal("0.4"), Decimal("25.1327"), Decimal("15.1327"), Decimal("25.1327")]), (Decimal("0.1"), [Decimal("0.4"), Decimal("25.1"), Decimal("15.1"), Decimal("25.1")]), (Decimal(1), [Decimal("0"), Decimal("25"), Decimal("15"), Decimal("25")]), (Decimal(6), [Decimal("0"), Decimal("24"), Decimal("18"), Decimal("24")]), @@ -297,10 +304,10 @@ def test_discretize_values(default_time_series, value_res, expected_values): [ (TimeSeries(), TimeSeries(), True), (TimeSeries().put(0.1, 0.2), TimeSeries(), False), - (TimeSeries().put(float(0.1), float(0.2)), TimeSeries().put(float(0.1), float(0.2)), True), - (TimeSeries().put(float(1), float(0.2)), TimeSeries().put(int(1), float(0.2)), True), - (TimeSeries().put(float(0.1), float(0.2)), TimeSeries().put(float(0.2), float(0.2)), False), - (TimeSeries().put(float(0.1), float(0.3)), TimeSeries().put(float(0.1), float(0.2)), False), + (TimeSeries().put(0.1, 0.2), TimeSeries().put(0.1, 0.2), True), + (TimeSeries().put(float(1), 0.2), TimeSeries().put(1, 0.2), True), + (TimeSeries().put(0.1, 0.2), TimeSeries().put(0.2, 0.2), False), + (TimeSeries().put(0.1, 0.3), TimeSeries().put(0.1, 0.2), False), ], ) def test_all_close(first_series, second_series, expected_result): diff --git a/tox.ini b/tox.ini index 467b9ab84..95c862106 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,11 @@ [tox] envlist = clean,linters,docs,unit-tests +[testenv] +parallel_show_output = true +package = wheel +wheel_build_env = .pkg + [testenv:clean] deps = coverage skip_install = true @@ -18,11 +23,13 @@ extras = test [testenv:integ-tests] basepython = python3 +usedevelop=True # {posargs} contains additional arguments specified when invoking tox. e.g. tox -- -s -k test_foo.py deps = {[test-deps]deps} passenv = AWS_PROFILE + AWS_REGION BRAKET_ENDPOINT commands = pytest test/integ_tests {posargs} @@ -104,7 +111,7 @@ deps = sphinx-rtd-theme sphinxcontrib-apidoc commands = - sphinx-build -E -T -b html doc build/documentation/html + sphinx-build -E -T -b html doc build/documentation/html -j auto [testenv:serve-docs] basepython = python3