Skip to content

Commit

Permalink
Use same *args, **kwargs across all tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 committed Jun 27, 2024
1 parent 0e578f9 commit 3262872
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 24 deletions.
36 changes: 14 additions & 22 deletions src/braket/devices/local_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def run(
if self._noise_model:
task_specification = self._apply_noise_model_to_circuit(task_specification)
payload = self._construct_payload(task_specification, inputs, shots)
result = self._delegate.run(*payload, shots=shots, *args, **kwargs)
result = self._delegate.run(payload, *args, shots=shots, **kwargs)
return LocalQuantumTask(self._to_result_object(result))

def run_batch( # noqa: C901
Expand Down Expand Up @@ -201,10 +201,6 @@ def run_batch( # noqa: C901
tasks_and_inputs = list(tasks_and_inputs)

payloads = []
payload_args = []
kwargs_with_shots = dict(**kwargs)
kwargs_with_shots["shots"] = shots
payload_kwargs = [kwargs_with_shots] * len(tasks_and_inputs)
for task_specification, input_map in tasks_and_inputs:
if isinstance(task_specification, Circuit):
param_names = {param.name for param in task_specification.parameters}
Expand All @@ -213,11 +209,11 @@ def run_batch( # noqa: C901
f"Cannot execute circuit with unbound parameters: "
f"{unbounded_parameters}"
)
payload = self._construct_payload(task_specification, input_map, shots)
payloads.append(payload[0])
payload_args.append(payload[1:] + args)
payloads.append(self._construct_payload(task_specification, input_map, shots))

results = self._delegate.run_multiple(payloads, payload_args, payload_kwargs, max_parallel)
results = self._delegate.run_multiple(
payloads, *args, shots=shots, max_parallel=max_parallel, **kwargs
)
return LocalQuantumTaskBatch([self._to_result_object(result) for result in results])

@property
Expand Down Expand Up @@ -263,7 +259,7 @@ def _construct_payload(
task_specification: Any,
inputs: Optional[dict[str, float]],
shots: Optional[int],
) -> tuple:
) -> Any:
raise NotImplementedError(f"Unsupported task type {type(task_specification)}")

@_construct_payload.register
Expand All @@ -273,10 +269,10 @@ def _(self, circuit: Circuit, inputs: Optional[dict[str, float]], shots: Optiona
validate_circuit_and_shots(circuit, shots)
program = circuit.to_ir(ir_type=IRType.OPENQASM)
program.inputs.update(inputs or {})
return (program,)
return program
elif DeviceActionType.JAQCD in simulator.properties.action:
validate_circuit_and_shots(circuit, shots)
return circuit.to_ir(ir_type=IRType.JAQCD), circuit.qubit_count
return circuit.to_ir(ir_type=IRType.JAQCD)
raise NotImplementedError(f"{type(simulator)} does not support qubit gate-based programs")

@_construct_payload.register
Expand All @@ -291,11 +287,11 @@ def _(self, program: OpenQASMProgram, inputs: Optional[dict[str, float]], _shots
source=program.source,
inputs=inputs_copy,
)
return (program,)
return program

@_construct_payload.register
def _(self, program: SerializableProgram, _inputs, _shots):
return (OpenQASMProgram(source=program.to_ir(ir_type=IRType.OPENQASM)),)
return OpenQASMProgram(source=program.to_ir(ir_type=IRType.OPENQASM))

@_construct_payload.register
def _(self, program: AnalogHamiltonianSimulation, _inputs, _shots):
Expand All @@ -304,7 +300,7 @@ def _(self, program: AnalogHamiltonianSimulation, _inputs, _shots):
raise NotImplementedError(
f"{type(simulator)} does not support analog Hamiltonian simulation programs"
)
return (program.to_ir(),)
return program.to_ir()

@_construct_payload.register
def _(self, program: AHSProgram, _inputs, _shots):
Expand All @@ -313,20 +309,16 @@ def _(self, program: AHSProgram, _inputs, _shots):
raise NotImplementedError(
f"{type(simulator)} does not support analog Hamiltonian simulation programs"
)
return (program,)
return program

@_construct_payload.register
def _(self, problem: Problem, _inputs, _shots) -> Union[
GateModelQuantumTaskResult,
AnalogHamiltonianSimulationQuantumTaskResult,
AnnealingQuantumTaskResult,
]:
def _(self, problem: Problem, _inputs, _shots):
simulator = self._delegate
if DeviceActionType.ANNEALING not in simulator.properties.action:
raise NotImplementedError(
f"{type(simulator)} does not support quantum annealing problems"
)
return (problem.to_ir(),)
return problem.to_ir()

@singledispatchmethod
def _to_result_object(self, result: Any) -> Any:
Expand Down
9 changes: 7 additions & 2 deletions test/unit_tests/braket/devices/test_local_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,12 @@ def properties(self) -> DeviceCapabilities:

class DummyJaqcdSimulator(BraketSimulator):
def run(
self, program: ir.jaqcd.Program, qubits: int, shots: Optional[int], *args, **kwargs
self,
program: ir.jaqcd.Program,
qubits: Optional[int] = None,
shots: Optional[int] = None,
*args,
**kwargs,
) -> dict[str, Any]:
if not isinstance(program, ir.jaqcd.Program):
raise TypeError("Not a Jaqcd program")
Expand Down Expand Up @@ -552,7 +557,7 @@ def test_run_jaqcd_only():
sim = LocalSimulator(dummy)
task = sim.run(Circuit().h(0).cnot(0, 1), 10)
dummy.assert_shots(10)
dummy.assert_qubits(2)
dummy.assert_qubits(None)
assert task.result() == GateModelQuantumTaskResult.from_object(GATE_MODEL_RESULT)


Expand Down

0 comments on commit 3262872

Please sign in to comment.