Skip to content

Commit

Permalink
fix: discretize method now takes None as an arg
Browse files Browse the repository at this point in the history
  • Loading branch information
Coull committed Apr 11, 2024
1 parent 53aba5e commit 560bbfb
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 39 deletions.
16 changes: 4 additions & 12 deletions src/braket/ahs/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from decimal import Decimal
from typing import Optional

from braket.ahs.discretization_types import DiscretizationError
from braket.ahs.pattern import Pattern
from braket.timings.time_series import TimeSeries

Expand Down Expand Up @@ -44,33 +43,26 @@ def pattern(self) -> Optional[Pattern]:

def discretize(
self,
time_resolution: Decimal,
value_resolution: Decimal,
time_resolution: Optional[Decimal] = None,
value_resolution: Optional[Decimal] = None,
pattern_resolution: Optional[Decimal] = None,
) -> Field:
"""Creates a discretized version of the field,
where time, value and pattern are rounded to the
closest multiple of their corresponding resolutions.
Args:
time_resolution (Decimal): Time resolution
value_resolution (Decimal): Value resolution
time_resolution (Optional[Decimal]): Time resolution
value_resolution (Optional[Decimal]): Value resolution
pattern_resolution (Optional[Decimal]): Pattern resolution
Returns:
Field: A new discretized field.
Raises:
ValueError: if pattern_resolution is None, but there is a Pattern
"""
discretized_time_series = self.time_series.discretize(time_resolution, value_resolution)
if self.pattern is None:
discretized_pattern = None
else:
if pattern_resolution is None:
raise DiscretizationError(
f"{self.pattern} is defined but has no pattern_resolution defined"
)
discretized_pattern = self.pattern.discretize(pattern_resolution)
discretized_field = Field(time_series=discretized_time_series, pattern=discretized_pattern)
return discretized_field
15 changes: 11 additions & 4 deletions src/braket/ahs/local_detuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,17 @@ def discretize(self, properties: DiscretizationProperties) -> LocalDetuning:
Returns:
LocalDetuning: A new discretized LocalDetuning.
"""
shifting_parameters = properties.rydberg.rydbergLocal
local_detuning_parameters = properties.rydberg.rydbergLocal
time_resolution, value_resolution, pattern_resolution = (None, None, None)
if hasattr(local_detuning_parameters, "timeResolution"):
time_resolution = local_detuning_parameters.timeResolution
if hasattr(local_detuning_parameters, "commonDetuningResolution"):
value_resolution = local_detuning_parameters.commonDetuningResolution
if hasattr(local_detuning_parameters, "localDetuningResolution"):
pattern_resolution = local_detuning_parameters.localDetuningResolution
discretized_magnitude = self.magnitude.discretize(
time_resolution=shifting_parameters.timeResolution,
value_resolution=shifting_parameters.commonDetuningResolution,
pattern_resolution=shifting_parameters.localDetuningResolution,
time_resolution=time_resolution,
value_resolution=value_resolution,
pattern_resolution=pattern_resolution,
)
return LocalDetuning(discretized_magnitude)
12 changes: 9 additions & 3 deletions src/braket/ahs/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from decimal import Decimal
from numbers import Number
from typing import Optional


class Pattern:
Expand All @@ -34,16 +35,21 @@ def series(self) -> list[Number]:
"""
return self._series

def discretize(self, resolution: Decimal) -> Pattern:
def discretize(self, resolution: Optional[Decimal]) -> Pattern:
"""Creates a discretized version of the pattern,
where each value is rounded to the closest multiple
of the resolution.
Args:
resolution (Decimal): Resolution of the discretization
resolution (Optional[Decimal]): Resolution of the discretization
Returns:
Pattern: The new discretized pattern
"""
discretized_series = [round(Decimal(num) / resolution) * resolution for num in self.series]
if resolution is None:
discretized_series = [Decimal(num) for num in self.series]
else:
discretized_series = [
round(Decimal(num) / resolution) * resolution for num in self.series
]
return Pattern(series=discretized_series)
22 changes: 15 additions & 7 deletions src/braket/timings/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from decimal import Decimal
from enum import Enum
from numbers import Number
from typing import Optional


@dataclass
Expand Down Expand Up @@ -267,24 +268,31 @@ def stitch(

return new_time_series

def discretize(self, time_resolution: Decimal, value_resolution: Decimal) -> TimeSeries:
def discretize(self, time_resolution: Optional[Decimal], value_resolution: Optional[Decimal]) -> TimeSeries:
"""Creates a discretized version of the time series,
rounding all times and values to the closest multiple of the
corresponding resolution.
Args:
time_resolution (Decimal): Time resolution
value_resolution (Decimal): Value resolution
time_resolution (Optional[Decimal]): Time resolution
value_resolution (Optional[Decimal]): Value resolution
Returns:
TimeSeries: A new discretized time series.
"""
discretized_ts = TimeSeries()
for item in self:
discretized_ts.put(
time=round(Decimal(item.time) / time_resolution) * time_resolution,
value=round(Decimal(item.value) / value_resolution) * value_resolution,
)
if time_resolution is None:
discretized_time = Decimal(item.time)
else:
discretized_time = round(Decimal(item.time) / time_resolution) * time_resolution

if value_resolution is None:
discretized_value = Decimal(item.value)
else:
discretized_value = round(Decimal(item.value) / value_resolution) * value_resolution

discretized_ts.put(time=discretized_time, value=discretized_value)
return discretized_ts

@staticmethod
Expand Down
18 changes: 6 additions & 12 deletions test/unit_tests/braket/ahs/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import pytest

from braket.ahs.discretization_types import DiscretizationError
from braket.ahs.field import Field
from braket.ahs.pattern import Pattern
from braket.timings.time_series import TimeSeries
Expand Down Expand Up @@ -80,6 +79,12 @@ def test_discretize(
[
(Decimal("0.1"), Decimal("10"), Decimal("0.5")),
(Decimal("10"), Decimal("20"), None),
(Decimal("0.1"), None, Decimal("0.5")),
(None, Decimal("10"), Decimal("0.5")),
(None, None, Decimal("0.5")),
(None, Decimal("10"), None),
(Decimal("0.1"), None, None),
(None, None, None),
(Decimal("100"), Decimal("0.1"), Decimal("1")),
],
)
Expand All @@ -93,14 +98,3 @@ def test_uniform_field(
) or expected.pattern.series == actual.pattern.series
assert expected.time_series.times() == actual.time_series.times()
assert expected.time_series.values() == actual.time_series.values()


@pytest.mark.parametrize(
"time_res, value_res, pattern_res",
[
(Decimal("10"), Decimal("20"), None),
],
)
@pytest.mark.xfail(raises=DiscretizationError)
def test_invalid_pattern_res(default_field, time_res, value_res, pattern_res):
default_field.discretize(time_res, value_res, pattern_res)
22 changes: 21 additions & 1 deletion test/unit_tests/braket/ahs/test_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@

@pytest.fixture
def default_values():
return [0, 0.1, 1, 0.5, 0.2, 0.001, 1e-10]
return [
Decimal(0),
Decimal("0.1"),
Decimal(1),
Decimal("0.5"),
Decimal("0.2"),
Decimal("0.001"),
Decimal("1e-10"),
]


@pytest.fixture
Expand All @@ -38,6 +46,18 @@ def test_create():
"res, expected_series",
[
# default pattern: [0, 0.1, 1, 0.5, 0.2, 0.001, 1e-10]
(
None,
[
Decimal("0"),
Decimal("0.1"),
Decimal("1"),
Decimal("0.5"),
Decimal("0.2"),
Decimal("0.001"),
Decimal("1e-10"),
],
),
(
Decimal("0.001"),
[
Expand Down

0 comments on commit 560bbfb

Please sign in to comment.