Skip to content

Commit

Permalink
Add tests for SerializableProgram handling
Browse files Browse the repository at this point in the history
  • Loading branch information
rmshaffer committed May 13, 2024
1 parent fe60109 commit f6cc106
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
30 changes: 30 additions & 0 deletions test/unit_tests/braket/aws/test_aws_quantum_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
IRType,
OpenQASMSerializationProperties,
QubitReferenceType,
SerializableProgram,
)
from braket.device_schema import GateModelParameters, error_mitigation
from braket.device_schema.dwave import (
Expand Down Expand Up @@ -123,6 +124,21 @@ def openqasm_program():
return OpenQASMProgram(source="OPENQASM 3.0; h $0;")


class DummySerializableProgram(SerializableProgram):
def __init__(self, source: str):
self.source = source

def to_ir(
self, ir_type: IRType = IRType.OPENQASM, allow_implicit_build: bool = False
) -> str:
return self.source


@pytest.fixture
def serializable_program():
return DummySerializableProgram(source="OPENQASM 3.0; h $0;")


@pytest.fixture
def blackbird_program():
return BlackbirdProgram(source="Vac | q[0]")
Expand Down Expand Up @@ -614,6 +630,20 @@ def test_create_openqasm_program_em_serialized(aws_session, arn, openqasm_progra
)


def test_create_serializable_program(aws_session, arn, serializable_program):
aws_session.create_quantum_task.return_value = arn
shots = 21
AwsQuantumTask.create(aws_session, SIMULATOR_ARN, serializable_program, S3_TARGET, shots)

_assert_create_quantum_task_called_with(
aws_session,
SIMULATOR_ARN,
OpenQASMProgram(source=serializable_program.to_ir()).json(),
S3_TARGET,
shots,
)


def test_create_blackbird_program(aws_session, arn, blackbird_program):
aws_session.create_quantum_task.return_value = arn
shots = 21
Expand Down
40 changes: 40 additions & 0 deletions test/unit_tests/braket/devices/test_local_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from braket.annealing import Problem, ProblemType
from braket.circuits import Circuit, FreeParameter, Gate, Noise
from braket.circuits.noise_model import GateCriteria, NoiseModel, NoiseModelInstruction
from braket.circuits.serialization import IRType, SerializableProgram
from braket.device_schema import DeviceActionType, DeviceCapabilities
from braket.device_schema.openqasm_device_action_properties import OpenQASMDeviceActionProperties
from braket.devices import LocalSimulator, local_simulator
Expand Down Expand Up @@ -250,6 +251,26 @@ def properties(self) -> DeviceCapabilities:
return device_properties


class DummySerializableProgram(SerializableProgram):
def __init__(self, source: str):
self.source = source

def to_ir(
self, ir_type: IRType = IRType.OPENQASM, allow_implicit_build: bool = False
) -> str:
return self.source


class DummySerializableProgramSimulator(DummyProgramSimulator):
def run(
self,
program: SerializableProgram,
shots: int = 0,
batch_size: int = 1,
) -> GateModelQuantumTaskResult:
return GateModelQuantumTaskResult.from_object(GATE_MODEL_RESULT)


class DummyProgramDensityMatrixSimulator(BraketSimulator):
def run(
self, program: ir.openqasm.Program, shots: Optional[int], *args, **kwargs
Expand Down Expand Up @@ -556,6 +577,25 @@ def test_run_program_model():
assert task.result() == GateModelQuantumTaskResult.from_object(GATE_MODEL_RESULT)


def test_run_serializable_program_model():
dummy = DummySerializableProgramSimulator()
sim = LocalSimulator(dummy)
task = sim.run(
DummySerializableProgram(
source="""
qubit[2] q;
bit[2] c;
h q[0];
cnot q[0], q[1];
c = measure q;
"""
)
)
assert task.result() == GateModelQuantumTaskResult.from_object(GATE_MODEL_RESULT)


@pytest.mark.xfail(raises=ValueError)
def test_run_gate_model_value_error():
dummy = DummyCircuitSimulator()
Expand Down

0 comments on commit f6cc106

Please sign in to comment.