Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: discretize method now takes None as an arg #950

Merged
merged 7 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/braket/ahs/driving_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,23 @@ def discretize(self, properties: DiscretizationProperties) -> DrivingField:
"""
driving_parameters = properties.rydberg.rydbergGlobal
time_resolution = driving_parameters.timeResolution

amplitude_value_resolution = driving_parameters.rabiFrequencyResolution
discretized_amplitude = self.amplitude.discretize(
time_resolution=time_resolution,
value_resolution=driving_parameters.rabiFrequencyResolution,
value_resolution=amplitude_value_resolution,
)

phase_value_resolution = driving_parameters.phaseResolution
discretized_phase = self.phase.discretize(
time_resolution=time_resolution,
value_resolution=driving_parameters.phaseResolution,
value_resolution=phase_value_resolution,
)

detuning_value_resolution = driving_parameters.detuningResolution
discretized_detuning = self.detuning.discretize(
time_resolution=time_resolution,
value_resolution=driving_parameters.detuningResolution,
value_resolution=detuning_value_resolution,
)
return DrivingField(
amplitude=discretized_amplitude, phase=discretized_phase, detuning=discretized_detuning
Expand Down
16 changes: 4 additions & 12 deletions src/braket/ahs/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from decimal import Decimal
from typing import Optional

from braket.ahs.discretization_types import DiscretizationError
from braket.ahs.pattern import Pattern
from braket.timings.time_series import TimeSeries

Expand Down Expand Up @@ -44,33 +43,26 @@ def pattern(self) -> Optional[Pattern]:

def discretize(
self,
time_resolution: Decimal,
value_resolution: Decimal,
time_resolution: Optional[Decimal] = None,
value_resolution: Optional[Decimal] = None,
pattern_resolution: Optional[Decimal] = None,
) -> Field:
"""Creates a discretized version of the field,
where time, value and pattern are rounded to the
closest multiple of their corresponding resolutions.

Args:
time_resolution (Decimal): Time resolution
value_resolution (Decimal): Value resolution
time_resolution (Optional[Decimal]): Time resolution
value_resolution (Optional[Decimal]): Value resolution
pattern_resolution (Optional[Decimal]): Pattern resolution

Returns:
Field: A new discretized field.

Raises:
ValueError: if pattern_resolution is None, but there is a Pattern
"""
discretized_time_series = self.time_series.discretize(time_resolution, value_resolution)
if self.pattern is None:
discretized_pattern = None
else:
if pattern_resolution is None:
raise DiscretizationError(
f"{self.pattern} is defined but has no pattern_resolution defined"
)
discretized_pattern = self.pattern.discretize(pattern_resolution)
discretized_field = Field(time_series=discretized_time_series, pattern=discretized_pattern)
return discretized_field
11 changes: 7 additions & 4 deletions src/braket/ahs/local_detuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,13 @@ def discretize(self, properties: DiscretizationProperties) -> LocalDetuning:
Returns:
LocalDetuning: A new discretized LocalDetuning.
"""
shifting_parameters = properties.rydberg.rydbergLocal
local_detuning_parameters = properties.rydberg.rydbergLocal
time_resolution = local_detuning_parameters.timeResolution
value_resolution = local_detuning_parameters.commonDetuningResolution
pattern_resolution = local_detuning_parameters.localDetuningResolution
discretized_magnitude = self.magnitude.discretize(
time_resolution=shifting_parameters.timeResolution,
value_resolution=shifting_parameters.commonDetuningResolution,
pattern_resolution=shifting_parameters.localDetuningResolution,
time_resolution=time_resolution,
value_resolution=value_resolution,
pattern_resolution=pattern_resolution,
)
return LocalDetuning(discretized_magnitude)
12 changes: 9 additions & 3 deletions src/braket/ahs/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from decimal import Decimal
from numbers import Number
from typing import Optional


class Pattern:
Expand All @@ -34,16 +35,21 @@ def series(self) -> list[Number]:
"""
return self._series

def discretize(self, resolution: Decimal) -> Pattern:
def discretize(self, resolution: Optional[Decimal]) -> Pattern:
"""Creates a discretized version of the pattern,
where each value is rounded to the closest multiple
of the resolution.

Args:
resolution (Decimal): Resolution of the discretization
resolution (Optional[Decimal]): Resolution of the discretization

Returns:
Pattern: The new discretized pattern
"""
discretized_series = [round(Decimal(num) / resolution) * resolution for num in self.series]
if resolution is None:
discretized_series = [Decimal(num) for num in self.series]
else:
discretized_series = [
round(Decimal(num) / resolution) * resolution for num in self.series
]
return Pattern(series=discretized_series)
24 changes: 17 additions & 7 deletions src/braket/timings/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from decimal import Decimal
from enum import Enum
from numbers import Number
from typing import Optional


@dataclass
Expand Down Expand Up @@ -267,24 +268,33 @@ def stitch(

return new_time_series

def discretize(self, time_resolution: Decimal, value_resolution: Decimal) -> TimeSeries:
def discretize(
self, time_resolution: Optional[Decimal], value_resolution: Optional[Decimal]
) -> TimeSeries:
"""Creates a discretized version of the time series,
rounding all times and values to the closest multiple of the
corresponding resolution.

Args:
time_resolution (Decimal): Time resolution
value_resolution (Decimal): Value resolution
time_resolution (Optional[Decimal]): Time resolution
value_resolution (Optional[Decimal]): Value resolution

Returns:
TimeSeries: A new discretized time series.
"""
discretized_ts = TimeSeries()
for item in self:
discretized_ts.put(
time=round(Decimal(item.time) / time_resolution) * time_resolution,
value=round(Decimal(item.value) / value_resolution) * value_resolution,
)
if time_resolution is None:
discretized_time = Decimal(item.time)
else:
discretized_time = round(Decimal(item.time) / time_resolution) * time_resolution

if value_resolution is None:
discretized_value = Decimal(item.value)
else:
discretized_value = round(Decimal(item.value) / value_resolution) * value_resolution

discretized_ts.put(time=discretized_time, value=discretized_value)
return discretized_ts

@staticmethod
Expand Down
18 changes: 6 additions & 12 deletions test/unit_tests/braket/ahs/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import pytest

from braket.ahs.discretization_types import DiscretizationError
from braket.ahs.field import Field
from braket.ahs.pattern import Pattern
from braket.timings.time_series import TimeSeries
Expand Down Expand Up @@ -80,6 +79,12 @@ def test_discretize(
[
(Decimal("0.1"), Decimal("10"), Decimal("0.5")),
(Decimal("10"), Decimal("20"), None),
(Decimal("0.1"), None, Decimal("0.5")),
(None, Decimal("10"), Decimal("0.5")),
(None, None, Decimal("0.5")),
(None, Decimal("10"), None),
(Decimal("0.1"), None, None),
(None, None, None),
(Decimal("100"), Decimal("0.1"), Decimal("1")),
],
)
Expand All @@ -93,14 +98,3 @@ def test_uniform_field(
) or expected.pattern.series == actual.pattern.series
assert expected.time_series.times() == actual.time_series.times()
assert expected.time_series.values() == actual.time_series.values()


@pytest.mark.parametrize(
"time_res, value_res, pattern_res",
[
(Decimal("10"), Decimal("20"), None),
],
)
@pytest.mark.xfail(raises=DiscretizationError)
def test_invalid_pattern_res(default_field, time_res, value_res, pattern_res):
default_field.discretize(time_res, value_res, pattern_res)
22 changes: 21 additions & 1 deletion test/unit_tests/braket/ahs/test_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@

@pytest.fixture
def default_values():
return [0, 0.1, 1, 0.5, 0.2, 0.001, 1e-10]
return [
Decimal(0),
Decimal("0.1"),
Decimal(1),
Decimal("0.5"),
Decimal("0.2"),
Decimal("0.001"),
Decimal("1e-10"),
]


@pytest.fixture
Expand All @@ -38,6 +46,18 @@ def test_create():
"res, expected_series",
[
# default pattern: [0, 0.1, 1, 0.5, 0.2, 0.001, 1e-10]
(
None,
[
Decimal("0"),
Decimal("0.1"),
Decimal("1"),
Decimal("0.5"),
Decimal("0.2"),
Decimal("0.001"),
Decimal("1e-10"),
],
),
(
Decimal("0.001"),
[
Expand Down
21 changes: 14 additions & 7 deletions test/unit_tests/braket/timings/test_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@

@pytest.fixture
def default_values():
return [(2700, 25.1327), (300, 25.1327), (600, 15.1327), (Decimal(0.3), Decimal(0.4))]
return [
(2700, Decimal("25.1327")),
(300, Decimal("25.1327")),
(600, Decimal("15.1327")),
(Decimal("0.3"), Decimal("0.4")),
]


@pytest.fixture
Expand Down Expand Up @@ -265,11 +270,12 @@ def test_stitch_wrong_bndry_value():
@pytest.mark.parametrize(
"time_res, expected_times",
[
# default_time_series: [(Decimal(0.3), Decimal(0.4), (300, 25.1327), (600, 15.1327), (2700, 25.1327))] # noqa
(Decimal(0.5), [Decimal("0.5"), Decimal("300"), Decimal("600"), Decimal("2700")]),
(Decimal(1), [Decimal("0"), Decimal("300"), Decimal("600"), Decimal("2700")]),
(Decimal(200), [Decimal("0"), Decimal("400"), Decimal("600"), Decimal("2800")]),
(Decimal(1000), [Decimal("0"), Decimal("1000"), Decimal("3000")]),
# default_time_series: [(Decimal(0.3), Decimal(0.4)), (300, 25.1327), (600, 15.1327), (2700, 25.1327)] # noqa
(None, [Decimal("0.3"), Decimal("300"), Decimal("600"), Decimal("2700")]),
(Decimal("0.5"), [Decimal("0.5"), Decimal("300"), Decimal("600"), Decimal("2700")]),
(Decimal("1"), [Decimal("0"), Decimal("300"), Decimal("600"), Decimal("2700")]),
(Decimal("200"), [Decimal("0"), Decimal("400"), Decimal("600"), Decimal("2800")]),
(Decimal("1000"), [Decimal("0"), Decimal("1000"), Decimal("3000")]),
],
)
def test_discretize_times(default_time_series, time_res, expected_times):
Expand All @@ -280,7 +286,8 @@ def test_discretize_times(default_time_series, time_res, expected_times):
@pytest.mark.parametrize(
"value_res, expected_values",
[
# default_time_series: [(Decimal(0.3), Decimal(0.4), (300, 25.1327), (600, 15.1327), (2700, 25.1327))] # noqa
# default_time_series: [(Decimal(0.3), Decimal(0.4)), (300, 25.1327), (600, 15.1327), (2700, 25.1327)] # noqa
(None, [Decimal("0.4"), Decimal("25.1327"), Decimal("15.1327"), Decimal("25.1327")]),
(Decimal("0.1"), [Decimal("0.4"), Decimal("25.1"), Decimal("15.1"), Decimal("25.1")]),
(Decimal(1), [Decimal("0"), Decimal("25"), Decimal("15"), Decimal("25")]),
(Decimal(6), [Decimal("0"), Decimal("24"), Decimal("18"), Decimal("24")]),
Expand Down