From 560bbfb97a1b13e2bd6a9a1a8172a9de3bd37713 Mon Sep 17 00:00:00 2001 From: Coull Date: Thu, 11 Apr 2024 11:33:44 -0700 Subject: [PATCH 1/5] fix: discretize method now takes None as an arg --- src/braket/ahs/field.py | 16 ++++------------ src/braket/ahs/local_detuning.py | 15 +++++++++++---- src/braket/ahs/pattern.py | 12 +++++++++--- src/braket/timings/time_series.py | 22 +++++++++++++++------- test/unit_tests/braket/ahs/test_field.py | 18 ++++++------------ test/unit_tests/braket/ahs/test_pattern.py | 22 +++++++++++++++++++++- 6 files changed, 66 insertions(+), 39 deletions(-) diff --git a/src/braket/ahs/field.py b/src/braket/ahs/field.py index 9a473fd99..1522b9d65 100644 --- a/src/braket/ahs/field.py +++ b/src/braket/ahs/field.py @@ -16,7 +16,6 @@ from decimal import Decimal from typing import Optional -from braket.ahs.discretization_types import DiscretizationError from braket.ahs.pattern import Pattern from braket.timings.time_series import TimeSeries @@ -44,8 +43,8 @@ def pattern(self) -> Optional[Pattern]: def discretize( self, - time_resolution: Decimal, - value_resolution: Decimal, + time_resolution: Optional[Decimal] = None, + value_resolution: Optional[Decimal] = None, pattern_resolution: Optional[Decimal] = None, ) -> Field: """Creates a discretized version of the field, @@ -53,24 +52,17 @@ def discretize( closest multiple of their corresponding resolutions. Args: - time_resolution (Decimal): Time resolution - value_resolution (Decimal): Value resolution + time_resolution (Optional[Decimal]): Time resolution + value_resolution (Optional[Decimal]): Value resolution pattern_resolution (Optional[Decimal]): Pattern resolution Returns: Field: A new discretized field. - - Raises: - ValueError: if pattern_resolution is None, but there is a Pattern """ discretized_time_series = self.time_series.discretize(time_resolution, value_resolution) if self.pattern is None: discretized_pattern = None else: - if pattern_resolution is None: - raise DiscretizationError( - f"{self.pattern} is defined but has no pattern_resolution defined" - ) discretized_pattern = self.pattern.discretize(pattern_resolution) discretized_field = Field(time_series=discretized_time_series, pattern=discretized_pattern) return discretized_field diff --git a/src/braket/ahs/local_detuning.py b/src/braket/ahs/local_detuning.py index 574b985d3..6e5122fcd 100644 --- a/src/braket/ahs/local_detuning.py +++ b/src/braket/ahs/local_detuning.py @@ -152,10 +152,17 @@ def discretize(self, properties: DiscretizationProperties) -> LocalDetuning: Returns: LocalDetuning: A new discretized LocalDetuning. """ - shifting_parameters = properties.rydberg.rydbergLocal + local_detuning_parameters = properties.rydberg.rydbergLocal + time_resolution, value_resolution, pattern_resolution = (None, None, None) + if hasattr(local_detuning_parameters, "timeResolution"): + time_resolution = local_detuning_parameters.timeResolution + if hasattr(local_detuning_parameters, "commonDetuningResolution"): + value_resolution = local_detuning_parameters.commonDetuningResolution + if hasattr(local_detuning_parameters, "localDetuningResolution"): + pattern_resolution = local_detuning_parameters.localDetuningResolution discretized_magnitude = self.magnitude.discretize( - time_resolution=shifting_parameters.timeResolution, - value_resolution=shifting_parameters.commonDetuningResolution, - pattern_resolution=shifting_parameters.localDetuningResolution, + time_resolution=time_resolution, + value_resolution=value_resolution, + pattern_resolution=pattern_resolution, ) return LocalDetuning(discretized_magnitude) diff --git a/src/braket/ahs/pattern.py b/src/braket/ahs/pattern.py index 462f0e369..92637fe0f 100644 --- a/src/braket/ahs/pattern.py +++ b/src/braket/ahs/pattern.py @@ -15,6 +15,7 @@ from decimal import Decimal from numbers import Number +from typing import Optional class Pattern: @@ -34,16 +35,21 @@ def series(self) -> list[Number]: """ return self._series - def discretize(self, resolution: Decimal) -> Pattern: + def discretize(self, resolution: Optional[Decimal]) -> Pattern: """Creates a discretized version of the pattern, where each value is rounded to the closest multiple of the resolution. Args: - resolution (Decimal): Resolution of the discretization + resolution (Optional[Decimal]): Resolution of the discretization Returns: Pattern: The new discretized pattern """ - discretized_series = [round(Decimal(num) / resolution) * resolution for num in self.series] + if resolution is None: + discretized_series = [Decimal(num) for num in self.series] + else: + discretized_series = [ + round(Decimal(num) / resolution) * resolution for num in self.series + ] return Pattern(series=discretized_series) diff --git a/src/braket/timings/time_series.py b/src/braket/timings/time_series.py index a558bec71..039292c92 100644 --- a/src/braket/timings/time_series.py +++ b/src/braket/timings/time_series.py @@ -19,6 +19,7 @@ from decimal import Decimal from enum import Enum from numbers import Number +from typing import Optional @dataclass @@ -267,24 +268,31 @@ def stitch( return new_time_series - def discretize(self, time_resolution: Decimal, value_resolution: Decimal) -> TimeSeries: + def discretize(self, time_resolution: Optional[Decimal], value_resolution: Optional[Decimal]) -> TimeSeries: """Creates a discretized version of the time series, rounding all times and values to the closest multiple of the corresponding resolution. Args: - time_resolution (Decimal): Time resolution - value_resolution (Decimal): Value resolution + time_resolution (Optional[Decimal]): Time resolution + value_resolution (Optional[Decimal]): Value resolution Returns: TimeSeries: A new discretized time series. """ discretized_ts = TimeSeries() for item in self: - discretized_ts.put( - time=round(Decimal(item.time) / time_resolution) * time_resolution, - value=round(Decimal(item.value) / value_resolution) * value_resolution, - ) + if time_resolution is None: + discretized_time = Decimal(item.time) + else: + discretized_time = round(Decimal(item.time) / time_resolution) * time_resolution + + if value_resolution is None: + discretized_value = Decimal(item.value) + else: + discretized_value = round(Decimal(item.value) / value_resolution) * value_resolution + + discretized_ts.put(time=discretized_time, value=discretized_value) return discretized_ts @staticmethod diff --git a/test/unit_tests/braket/ahs/test_field.py b/test/unit_tests/braket/ahs/test_field.py index 4212ba336..2ff6714ce 100644 --- a/test/unit_tests/braket/ahs/test_field.py +++ b/test/unit_tests/braket/ahs/test_field.py @@ -16,7 +16,6 @@ import pytest -from braket.ahs.discretization_types import DiscretizationError from braket.ahs.field import Field from braket.ahs.pattern import Pattern from braket.timings.time_series import TimeSeries @@ -80,6 +79,12 @@ def test_discretize( [ (Decimal("0.1"), Decimal("10"), Decimal("0.5")), (Decimal("10"), Decimal("20"), None), + (Decimal("0.1"), None, Decimal("0.5")), + (None, Decimal("10"), Decimal("0.5")), + (None, None, Decimal("0.5")), + (None, Decimal("10"), None), + (Decimal("0.1"), None, None), + (None, None, None), (Decimal("100"), Decimal("0.1"), Decimal("1")), ], ) @@ -93,14 +98,3 @@ def test_uniform_field( ) or expected.pattern.series == actual.pattern.series assert expected.time_series.times() == actual.time_series.times() assert expected.time_series.values() == actual.time_series.values() - - -@pytest.mark.parametrize( - "time_res, value_res, pattern_res", - [ - (Decimal("10"), Decimal("20"), None), - ], -) -@pytest.mark.xfail(raises=DiscretizationError) -def test_invalid_pattern_res(default_field, time_res, value_res, pattern_res): - default_field.discretize(time_res, value_res, pattern_res) diff --git a/test/unit_tests/braket/ahs/test_pattern.py b/test/unit_tests/braket/ahs/test_pattern.py index d84f3a925..920f2cc29 100644 --- a/test/unit_tests/braket/ahs/test_pattern.py +++ b/test/unit_tests/braket/ahs/test_pattern.py @@ -20,7 +20,15 @@ @pytest.fixture def default_values(): - return [0, 0.1, 1, 0.5, 0.2, 0.001, 1e-10] + return [ + Decimal(0), + Decimal("0.1"), + Decimal(1), + Decimal("0.5"), + Decimal("0.2"), + Decimal("0.001"), + Decimal("1e-10"), + ] @pytest.fixture @@ -38,6 +46,18 @@ def test_create(): "res, expected_series", [ # default pattern: [0, 0.1, 1, 0.5, 0.2, 0.001, 1e-10] + ( + None, + [ + Decimal("0"), + Decimal("0.1"), + Decimal("1"), + Decimal("0.5"), + Decimal("0.2"), + Decimal("0.001"), + Decimal("1e-10"), + ], + ), ( Decimal("0.001"), [ From 5e34974ce1eef5123099e3796ffdf87dc54f4735 Mon Sep 17 00:00:00 2001 From: Coull Date: Thu, 11 Apr 2024 11:34:24 -0700 Subject: [PATCH 2/5] tox fix --- src/braket/timings/time_series.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/braket/timings/time_series.py b/src/braket/timings/time_series.py index 039292c92..afdabd726 100644 --- a/src/braket/timings/time_series.py +++ b/src/braket/timings/time_series.py @@ -268,7 +268,9 @@ def stitch( return new_time_series - def discretize(self, time_resolution: Optional[Decimal], value_resolution: Optional[Decimal]) -> TimeSeries: + def discretize( + self, time_resolution: Optional[Decimal], value_resolution: Optional[Decimal] + ) -> TimeSeries: """Creates a discretized version of the time series, rounding all times and values to the closest multiple of the corresponding resolution. From 2b716214f603920cefb8def72a7a99c20d14bda8 Mon Sep 17 00:00:00 2001 From: Coull Date: Thu, 11 Apr 2024 11:40:27 -0700 Subject: [PATCH 3/5] coverage fix --- .../braket/timings/test_time_series.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/test/unit_tests/braket/timings/test_time_series.py b/test/unit_tests/braket/timings/test_time_series.py index 0f99ca650..c5a26334f 100755 --- a/test/unit_tests/braket/timings/test_time_series.py +++ b/test/unit_tests/braket/timings/test_time_series.py @@ -20,7 +20,12 @@ @pytest.fixture def default_values(): - return [(2700, 25.1327), (300, 25.1327), (600, 15.1327), (Decimal(0.3), Decimal(0.4))] + return [ + (2700, Decimal("25.1327")), + (300, Decimal("25.1327")), + (600, Decimal("15.1327")), + (Decimal("0.3"), Decimal("0.4")), + ] @pytest.fixture @@ -265,11 +270,12 @@ def test_stitch_wrong_bndry_value(): @pytest.mark.parametrize( "time_res, expected_times", [ - # default_time_series: [(Decimal(0.3), Decimal(0.4), (300, 25.1327), (600, 15.1327), (2700, 25.1327))] # noqa - (Decimal(0.5), [Decimal("0.5"), Decimal("300"), Decimal("600"), Decimal("2700")]), - (Decimal(1), [Decimal("0"), Decimal("300"), Decimal("600"), Decimal("2700")]), - (Decimal(200), [Decimal("0"), Decimal("400"), Decimal("600"), Decimal("2800")]), - (Decimal(1000), [Decimal("0"), Decimal("1000"), Decimal("3000")]), + # default_time_series: [(Decimal(0.3), Decimal(0.4)), (300, 25.1327), (600, 15.1327), (2700, 25.1327)] # noqa + (None, [Decimal("0.3"), Decimal("300"), Decimal("600"), Decimal("2700")]), + (Decimal("0.5"), [Decimal("0.5"), Decimal("300"), Decimal("600"), Decimal("2700")]), + (Decimal("1"), [Decimal("0"), Decimal("300"), Decimal("600"), Decimal("2700")]), + (Decimal("200"), [Decimal("0"), Decimal("400"), Decimal("600"), Decimal("2800")]), + (Decimal("1000"), [Decimal("0"), Decimal("1000"), Decimal("3000")]), ], ) def test_discretize_times(default_time_series, time_res, expected_times): @@ -280,7 +286,8 @@ def test_discretize_times(default_time_series, time_res, expected_times): @pytest.mark.parametrize( "value_res, expected_values", [ - # default_time_series: [(Decimal(0.3), Decimal(0.4), (300, 25.1327), (600, 15.1327), (2700, 25.1327))] # noqa + # default_time_series: [(Decimal(0.3), Decimal(0.4)), (300, 25.1327), (600, 15.1327), (2700, 25.1327)] # noqa + (None, [Decimal("0.4"), Decimal("25.1327"), Decimal("15.1327"), Decimal("25.1327")]), (Decimal("0.1"), [Decimal("0.4"), Decimal("25.1"), Decimal("15.1"), Decimal("25.1")]), (Decimal(1), [Decimal("0"), Decimal("25"), Decimal("15"), Decimal("25")]), (Decimal(6), [Decimal("0"), Decimal("24"), Decimal("18"), Decimal("24")]), From 4c1de15ac7d453a0025e2422a828a6e9205722c0 Mon Sep 17 00:00:00 2001 From: Coull Date: Fri, 12 Apr 2024 15:58:00 -0700 Subject: [PATCH 4/5] tox fix --- src/braket/ahs/driving_field.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/braket/ahs/driving_field.py b/src/braket/ahs/driving_field.py index f6c6430f2..c01632f07 100644 --- a/src/braket/ahs/driving_field.py +++ b/src/braket/ahs/driving_field.py @@ -121,18 +121,32 @@ def discretize(self, properties: DiscretizationProperties) -> DrivingField: DrivingField: A new discretized DrivingField. """ driving_parameters = properties.rydberg.rydbergGlobal - time_resolution = driving_parameters.timeResolution + time_resolution = None + if hasattr(driving_parameters, "timeResolution"): + time_resolution = driving_parameters.timeResolution + + amplitude_value_resolution = None + if hasattr(driving_parameters, "rabiFrequencyResolution"): + amplitude_value_resolution = driving_parameters.rabiFrequencyResolution discretized_amplitude = self.amplitude.discretize( time_resolution=time_resolution, - value_resolution=driving_parameters.rabiFrequencyResolution, + value_resolution=amplitude_value_resolution, ) + + phase_value_resolution = None + if hasattr(driving_parameters, "phaseResolution"): + phase_value_resolution = driving_parameters.phaseResolution discretized_phase = self.phase.discretize( time_resolution=time_resolution, - value_resolution=driving_parameters.phaseResolution, + value_resolution=phase_value_resolution, ) + + detuning_value_resolution = None + if hasattr(driving_parameters, "detuningResolution"): + detuning_value_resolution = driving_parameters.detuningResolution discretized_detuning = self.detuning.discretize( time_resolution=time_resolution, - value_resolution=driving_parameters.detuningResolution, + value_resolution=detuning_value_resolution, ) return DrivingField( amplitude=discretized_amplitude, phase=discretized_phase, detuning=discretized_detuning From 83c9081cb72d1213a8a2dbf5a2219ce642df58fa Mon Sep 17 00:00:00 2001 From: Coull Date: Fri, 12 Apr 2024 16:02:51 -0700 Subject: [PATCH 5/5] fix coverage --- src/braket/ahs/driving_field.py | 16 ++++------------ src/braket/ahs/local_detuning.py | 10 +++------- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/src/braket/ahs/driving_field.py b/src/braket/ahs/driving_field.py index c01632f07..cbf01838d 100644 --- a/src/braket/ahs/driving_field.py +++ b/src/braket/ahs/driving_field.py @@ -121,29 +121,21 @@ def discretize(self, properties: DiscretizationProperties) -> DrivingField: DrivingField: A new discretized DrivingField. """ driving_parameters = properties.rydberg.rydbergGlobal - time_resolution = None - if hasattr(driving_parameters, "timeResolution"): - time_resolution = driving_parameters.timeResolution + time_resolution = driving_parameters.timeResolution - amplitude_value_resolution = None - if hasattr(driving_parameters, "rabiFrequencyResolution"): - amplitude_value_resolution = driving_parameters.rabiFrequencyResolution + amplitude_value_resolution = driving_parameters.rabiFrequencyResolution discretized_amplitude = self.amplitude.discretize( time_resolution=time_resolution, value_resolution=amplitude_value_resolution, ) - phase_value_resolution = None - if hasattr(driving_parameters, "phaseResolution"): - phase_value_resolution = driving_parameters.phaseResolution + phase_value_resolution = driving_parameters.phaseResolution discretized_phase = self.phase.discretize( time_resolution=time_resolution, value_resolution=phase_value_resolution, ) - detuning_value_resolution = None - if hasattr(driving_parameters, "detuningResolution"): - detuning_value_resolution = driving_parameters.detuningResolution + detuning_value_resolution = driving_parameters.detuningResolution discretized_detuning = self.detuning.discretize( time_resolution=time_resolution, value_resolution=detuning_value_resolution, diff --git a/src/braket/ahs/local_detuning.py b/src/braket/ahs/local_detuning.py index 6e5122fcd..1906ce837 100644 --- a/src/braket/ahs/local_detuning.py +++ b/src/braket/ahs/local_detuning.py @@ -153,13 +153,9 @@ def discretize(self, properties: DiscretizationProperties) -> LocalDetuning: LocalDetuning: A new discretized LocalDetuning. """ local_detuning_parameters = properties.rydberg.rydbergLocal - time_resolution, value_resolution, pattern_resolution = (None, None, None) - if hasattr(local_detuning_parameters, "timeResolution"): - time_resolution = local_detuning_parameters.timeResolution - if hasattr(local_detuning_parameters, "commonDetuningResolution"): - value_resolution = local_detuning_parameters.commonDetuningResolution - if hasattr(local_detuning_parameters, "localDetuningResolution"): - pattern_resolution = local_detuning_parameters.localDetuningResolution + time_resolution = local_detuning_parameters.timeResolution + value_resolution = local_detuning_parameters.commonDetuningResolution + pattern_resolution = local_detuning_parameters.localDetuningResolution discretized_magnitude = self.magnitude.discretize( time_resolution=time_resolution, value_resolution=value_resolution,