From 6696e025b03bf77d8262f61e5c3d1d31cfc2d819 Mon Sep 17 00:00:00 2001 From: Tim Date: Thu, 22 Feb 2024 16:14:07 -0500 Subject: [PATCH] add set_noise_model method --- src/braket/aws/aws_device.py | 4 +--- src/braket/devices/device.py | 12 ++++++++++ src/braket/devices/local_simulator.py | 4 +--- .../braket/devices/test_local_simulator.py | 24 ++++++++++++++----- 4 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/braket/aws/aws_device.py b/src/braket/aws/aws_device.py index 43780aa38..b16490f0b 100644 --- a/src/braket/aws/aws_device.py +++ b/src/braket/aws/aws_device.py @@ -113,9 +113,7 @@ def __init__( self._aws_session = self._get_session_and_initialize(aws_session or AwsSession()) self._ports = None self._frames = None - if noise_model: - self._validate_device_noise_model_support(noise_model) - self._noise_model = noise_model + self.set_noise_model(noise_model) def run( self, diff --git a/src/braket/devices/device.py b/src/braket/devices/device.py index 3f2a28e41..1dd0630f1 100644 --- a/src/braket/devices/device.py +++ b/src/braket/devices/device.py @@ -116,6 +116,18 @@ def status(self) -> str: """ return self._status + def set_noise_model(self, noise_model: NoiseModel) -> None: + """Set the noise model of the device. + + Args: + noise_model (NoiseModel): The Braket noise model to apply to the circuit before + execution. Noise model can only be added to the devices that support noise + simulation. + """ + if noise_model: + self._validate_device_noise_model_support(noise_model) + self._noise_model = noise_model + def _validate_device_noise_model_support(self, noise_model: NoiseModel) -> None: supported_noises = set( SUPPORTED_NOISE_PRAGMA_TO_NOISE[pragma].__name__ diff --git a/src/braket/devices/local_simulator.py b/src/braket/devices/local_simulator.py index faee13fdf..b48ca7c20 100644 --- a/src/braket/devices/local_simulator.py +++ b/src/braket/devices/local_simulator.py @@ -72,9 +72,7 @@ def __init__( status="AVAILABLE", ) self._delegate = delegate - if noise_model: - self._validate_device_noise_model_support(noise_model) - self._noise_model = noise_model + self.set_noise_model(noise_model) def run( self, diff --git a/test/unit_tests/braket/devices/test_local_simulator.py b/test/unit_tests/braket/devices/test_local_simulator.py index 4877a0b8f..d4ddd34db 100644 --- a/test/unit_tests/braket/devices/test_local_simulator.py +++ b/test/unit_tests/braket/devices/test_local_simulator.py @@ -632,8 +632,9 @@ def noise_model(): @pytest.mark.parametrize("backend", ["dummy_oq3_dm"]) -def test_valid_local_device_for_noise_model(backend, noise_model): - device = LocalSimulator(backend, noise_model=noise_model) +def test_set_noise_model(backend, noise_model): + device = LocalSimulator(backend) + device.set_noise_model(noise_model) assert device._noise_model.instructions == [ NoiseModelInstruction(Noise.BitFlip(0.05), GateCriteria(Gate.H)), NoiseModelInstruction(Noise.TwoQubitDepolarizing(0.10), GateCriteria(Gate.CNot)), @@ -641,15 +642,26 @@ def test_valid_local_device_for_noise_model(backend, noise_model): @pytest.mark.parametrize("backend", ["dummy_oq3"]) -def test_invalid_local_device_for_noise_model(backend, noise_model): +def test_set_noise_model_invalid_device(backend, noise_model): with pytest.raises(ValueError): - _ = LocalSimulator(backend, noise_model=noise_model) + device = LocalSimulator(backend) + device.set_noise_model(noise_model) @pytest.mark.parametrize("backend", ["dummy_oq3_dm"]) -def test_local_device_with_invalid_noise_model(backend, noise_model): +def test_set_noise_model_invalid_noise_model(backend, noise_model): with pytest.raises(TypeError): - _ = LocalSimulator(backend, noise_model=Mock()) + device = LocalSimulator(backend) + device.set_noise_model(Mock()) + + +@pytest.mark.parametrize("backend", ["dummy_oq3_dm"]) +def test_valid_local_device_for_noise_model(backend, noise_model): + device = LocalSimulator(backend, noise_model=noise_model) + assert device._noise_model.instructions == [ + NoiseModelInstruction(Noise.BitFlip(0.05), GateCriteria(Gate.H)), + NoiseModelInstruction(Noise.TwoQubitDepolarizing(0.10), GateCriteria(Gate.CNot)), + ] @patch.object(DummyProgramDensityMatrixSimulator, "run")