diff --git a/src/braket/devices/local_simulator.py b/src/braket/devices/local_simulator.py index 166a1ba19..bb1d63832 100644 --- a/src/braket/devices/local_simulator.py +++ b/src/braket/devices/local_simulator.py @@ -123,7 +123,7 @@ def run( if self._noise_model: task_specification = self._apply_noise_model_to_circuit(task_specification) payload = self._construct_payload(task_specification, inputs, shots) - result = self._delegate.run(*payload, shots=shots, *args, **kwargs) + result = self._delegate.run(payload, *args, shots=shots, **kwargs) return LocalQuantumTask(self._to_result_object(result)) def run_batch( # noqa: C901 @@ -201,10 +201,6 @@ def run_batch( # noqa: C901 tasks_and_inputs = list(tasks_and_inputs) payloads = [] - payload_args = [] - kwargs_with_shots = dict(**kwargs) - kwargs_with_shots["shots"] = shots - payload_kwargs = [kwargs_with_shots] * len(tasks_and_inputs) for task_specification, input_map in tasks_and_inputs: if isinstance(task_specification, Circuit): param_names = {param.name for param in task_specification.parameters} @@ -213,11 +209,11 @@ def run_batch( # noqa: C901 f"Cannot execute circuit with unbound parameters: " f"{unbounded_parameters}" ) - payload = self._construct_payload(task_specification, input_map, shots) - payloads.append(payload[0]) - payload_args.append(payload[1:] + args) + payloads.append(self._construct_payload(task_specification, input_map, shots)) - results = self._delegate.run_multiple(payloads, payload_args, payload_kwargs, max_parallel) + results = self._delegate.run_multiple( + payloads, *args, shots=shots, max_parallel=max_parallel, **kwargs + ) return LocalQuantumTaskBatch([self._to_result_object(result) for result in results]) @property @@ -263,7 +259,7 @@ def _construct_payload( task_specification: Any, inputs: Optional[dict[str, float]], shots: Optional[int], - ) -> tuple: + ) -> Any: raise NotImplementedError(f"Unsupported task type {type(task_specification)}") @_construct_payload.register @@ -273,10 +269,10 @@ def _(self, circuit: Circuit, inputs: Optional[dict[str, float]], shots: Optiona validate_circuit_and_shots(circuit, shots) program = circuit.to_ir(ir_type=IRType.OPENQASM) program.inputs.update(inputs or {}) - return (program,) + return program elif DeviceActionType.JAQCD in simulator.properties.action: validate_circuit_and_shots(circuit, shots) - return circuit.to_ir(ir_type=IRType.JAQCD), circuit.qubit_count + return circuit.to_ir(ir_type=IRType.JAQCD) raise NotImplementedError(f"{type(simulator)} does not support qubit gate-based programs") @_construct_payload.register @@ -291,11 +287,11 @@ def _(self, program: OpenQASMProgram, inputs: Optional[dict[str, float]], _shots source=program.source, inputs=inputs_copy, ) - return (program,) + return program @_construct_payload.register def _(self, program: SerializableProgram, _inputs, _shots): - return (OpenQASMProgram(source=program.to_ir(ir_type=IRType.OPENQASM)),) + return OpenQASMProgram(source=program.to_ir(ir_type=IRType.OPENQASM)) @_construct_payload.register def _(self, program: AnalogHamiltonianSimulation, _inputs, _shots): @@ -304,7 +300,7 @@ def _(self, program: AnalogHamiltonianSimulation, _inputs, _shots): raise NotImplementedError( f"{type(simulator)} does not support analog Hamiltonian simulation programs" ) - return (program.to_ir(),) + return program.to_ir() @_construct_payload.register def _(self, program: AHSProgram, _inputs, _shots): @@ -313,20 +309,16 @@ def _(self, program: AHSProgram, _inputs, _shots): raise NotImplementedError( f"{type(simulator)} does not support analog Hamiltonian simulation programs" ) - return (program,) + return program @_construct_payload.register - def _(self, problem: Problem, _inputs, _shots) -> Union[ - GateModelQuantumTaskResult, - AnalogHamiltonianSimulationQuantumTaskResult, - AnnealingQuantumTaskResult, - ]: + def _(self, problem: Problem, _inputs, _shots): simulator = self._delegate if DeviceActionType.ANNEALING not in simulator.properties.action: raise NotImplementedError( f"{type(simulator)} does not support quantum annealing problems" ) - return (problem.to_ir(),) + return problem.to_ir() @singledispatchmethod def _to_result_object(self, result: Any) -> Any: diff --git a/test/unit_tests/braket/devices/test_local_simulator.py b/test/unit_tests/braket/devices/test_local_simulator.py index 90542c8fe..451553f02 100644 --- a/test/unit_tests/braket/devices/test_local_simulator.py +++ b/test/unit_tests/braket/devices/test_local_simulator.py @@ -156,7 +156,12 @@ def properties(self) -> DeviceCapabilities: class DummyJaqcdSimulator(BraketSimulator): def run( - self, program: ir.jaqcd.Program, qubits: int, shots: Optional[int], *args, **kwargs + self, + program: ir.jaqcd.Program, + qubits: Optional[int] = None, + shots: Optional[int] = None, + *args, + **kwargs, ) -> dict[str, Any]: if not isinstance(program, ir.jaqcd.Program): raise TypeError("Not a Jaqcd program") @@ -552,7 +557,7 @@ def test_run_jaqcd_only(): sim = LocalSimulator(dummy) task = sim.run(Circuit().h(0).cnot(0, 1), 10) dummy.assert_shots(10) - dummy.assert_qubits(2) + dummy.assert_qubits(None) assert task.result() == GateModelQuantumTaskResult.from_object(GATE_MODEL_RESULT)