diff --git a/src/braket/aws/aws_quantum_task.py b/src/braket/aws/aws_quantum_task.py index c6ad36b27..54f544159 100644 --- a/src/braket/aws/aws_quantum_task.py +++ b/src/braket/aws/aws_quantum_task.py @@ -105,7 +105,7 @@ def create( disable_qubit_rewiring: bool = False, tags: dict[str, str] | None = None, inputs: dict[str, float] | None = None, - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None = None, + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence] | None = None, quiet: bool = False, reservation_arn: str | None = None, *args, @@ -148,10 +148,9 @@ def create( IR. If the IR supports inputs, the inputs will be updated with this value. Default: {}. - gate_definitions (Optional[dict[tuple[Gate, QubitSet], PulseSequence]] | None): - A `Dict` for user defined gate calibration. The calibration is defined for - for a particular `Gate` on a particular `QubitSet` and is represented by - a `PulseSequence`. + gate_definitions (dict[tuple[Gate, QubitSet], PulseSequence] | None): A `dict` + of user defined gate calibrations. Each calibration is defined for a particular + `Gate` on a particular `QubitSet` and is represented by a `PulseSequence`. Default: None. quiet (bool): Sets the verbosity of the logger to low and does not report queue @@ -190,6 +189,7 @@ def create( if tags is not None: create_task_kwargs.update({"tags": tags}) inputs = inputs or {} + gate_definitions = gate_definitions or {} if reservation_arn: create_task_kwargs.update( @@ -561,7 +561,7 @@ def _create_internal( device_parameters: Union[dict, BraketSchemaBase], disable_qubit_rewiring: bool, inputs: dict[str, float], - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], *args, **kwargs, ) -> AwsQuantumTask: @@ -577,7 +577,7 @@ def _( _device_parameters: Union[dict, BraketSchemaBase], # Not currently used for OpenQasmProgram _disable_qubit_rewiring: bool, inputs: dict[str, float], - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], *args, **kwargs, ) -> AwsQuantumTask: @@ -600,7 +600,7 @@ def _( device_parameters: Union[dict, BraketSchemaBase], _disable_qubit_rewiring: bool, inputs: dict[str, float], - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], *args, **kwargs, ) -> AwsQuantumTask: @@ -639,7 +639,7 @@ def _( _device_parameters: Union[dict, BraketSchemaBase], _disable_qubit_rewiring: bool, inputs: dict[str, float], - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], *args, **kwargs, ) -> AwsQuantumTask: @@ -657,7 +657,7 @@ def _( device_parameters: Union[dict, BraketSchemaBase], disable_qubit_rewiring: bool, inputs: dict[str, float], - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], *args, **kwargs, ) -> AwsQuantumTask: @@ -678,7 +678,7 @@ def _( if ( disable_qubit_rewiring or Instruction(StartVerbatimBox()) in circuit.instructions - or gate_definitions is not None + or gate_definitions or any(isinstance(instruction.operator, PulseGate) for instruction in circuit.instructions) ): qubit_reference_type = QubitReferenceType.PHYSICAL diff --git a/src/braket/aws/aws_quantum_task_batch.py b/src/braket/aws/aws_quantum_task_batch.py index a02dfa6d6..ed3430274 100644 --- a/src/braket/aws/aws_quantum_task_batch.py +++ b/src/braket/aws/aws_quantum_task_batch.py @@ -23,8 +23,11 @@ from braket.aws.aws_quantum_task import AwsQuantumTask from braket.aws.aws_session import AwsSession from braket.circuits import Circuit +from braket.circuits.gate import Gate from braket.ir.blackbird import Program as BlackbirdProgram from braket.ir.openqasm import Program as OpenQasmProgram +from braket.pulse.pulse_sequence import PulseSequence +from braket.registers.qubit_set import QubitSet from braket.tasks.quantum_task_batch import QuantumTaskBatch @@ -61,6 +64,13 @@ def __init__( poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, inputs: Union[dict[str, float], list[dict[str, float]]] | None = None, + gate_definitions: ( + Union[ + dict[tuple[Gate, QubitSet], PulseSequence], + list[dict[tuple[Gate, QubitSet], PulseSequence]], + ] + | None + ) = None, reservation_arn: str | None = None, *aws_quantum_task_args: Any, **aws_quantum_task_kwargs: Any, @@ -92,6 +102,9 @@ def __init__( inputs (Union[dict[str, float], list[dict[str, float]]] | None): Inputs to be passed along with the IR. If the IR supports inputs, the inputs will be updated with this value. Default: {}. + gate_definitions (Union[dict[tuple[Gate, QubitSet], PulseSequence], list[dict[tuple[Gate, QubitSet], PulseSequence]]] | None): # noqa: E501 + User-defined gate calibration. The calibration is defined for a particular `Gate` on a + particular `QubitSet` and is represented by a `PulseSequence`. Default: None. reservation_arn (str | None): The reservation ARN provided by Braket Direct to reserve exclusive usage for the device to run the quantum task on. Note: If you are creating tasks in a job that itself was created reservation ARN, @@ -111,6 +124,7 @@ def __init__( poll_timeout_seconds, poll_interval_seconds, inputs, + gate_definitions, reservation_arn, *aws_quantum_task_args, **aws_quantum_task_kwargs, @@ -134,7 +148,7 @@ def __init__( self._aws_quantum_task_kwargs = aws_quantum_task_kwargs @staticmethod - def _tasks_and_inputs( + def _tasks_inputs_gatedefs( task_specifications: Union[ Union[Circuit, Problem, OpenQasmProgram, BlackbirdProgram, AnalogHamiltonianSimulation], list[ @@ -144,45 +158,55 @@ def _tasks_and_inputs( ], ], inputs: Union[dict[str, float], list[dict[str, float]]] = None, + gate_definitions: Union[ + dict[tuple[Gate, QubitSet], PulseSequence], + list[dict[tuple[Gate, QubitSet], PulseSequence]], + ] = None, ) -> list[ tuple[ Union[Circuit, Problem, OpenQasmProgram, BlackbirdProgram, AnalogHamiltonianSimulation], dict[str, float], + dict[tuple[Gate, QubitSet], PulseSequence], ] ]: inputs = inputs or {} - - max_inputs_tasks = 1 - single_task = isinstance( - task_specifications, - (Circuit, Problem, OpenQasmProgram, BlackbirdProgram, AnalogHamiltonianSimulation), - ) - single_input = isinstance(inputs, dict) - - max_inputs_tasks = ( - max(max_inputs_tasks, len(task_specifications)) if not single_task else max_inputs_tasks - ) - max_inputs_tasks = ( - max(max_inputs_tasks, len(inputs)) if not single_input else max_inputs_tasks + gate_definitions = gate_definitions or {} + + single_task_type = ( + Circuit, + Problem, + OpenQasmProgram, + BlackbirdProgram, + AnalogHamiltonianSimulation, ) + single_input_type = dict + single_gate_definitions_type = dict - if not single_task and not single_input: - if len(task_specifications) != len(inputs): - raise ValueError("Multiple inputs and task specifications must be equal in number.") - if single_task: - task_specifications = repeat(task_specifications, times=max_inputs_tasks) + args = [task_specifications, inputs, gate_definitions] + single_arg_types = [single_task_type, single_input_type, single_gate_definitions_type] - if single_input: - inputs = repeat(inputs, times=max_inputs_tasks) + batch_length = 1 + arg_lengths = [] + for arg, single_arg_type in zip(args, single_arg_types): + arg_length = 1 if isinstance(arg, single_arg_type) else len(arg) + arg_lengths.append(arg_length) - tasks_and_inputs = zip(task_specifications, inputs) + if arg_length != 1: + if batch_length != 1 and arg_length != batch_length: + raise ValueError( + "Multiple inputs, task specifications and gate definitions must " + "be equal in length." + ) + else: + batch_length = arg_length - if single_task and single_input: - tasks_and_inputs = list(tasks_and_inputs) + for i, arg_length in enumerate(arg_lengths): + if arg_length == 1: + args[i] = repeat(args[i], batch_length) - tasks_and_inputs = list(tasks_and_inputs) + tasks_inputs_definitions = list(zip(*args)) - for task_specification, input_map in tasks_and_inputs: + for task_specification, input_map, _gate_definitions in tasks_inputs_definitions: if isinstance(task_specification, Circuit): param_names = {param.name for param in task_specification.parameters} unbounded_parameters = param_names - set(input_map.keys()) @@ -192,7 +216,7 @@ def _tasks_and_inputs( f"{unbounded_parameters}" ) - return tasks_and_inputs + return tasks_inputs_definitions @staticmethod def _execute( @@ -213,13 +237,22 @@ def _execute( poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, inputs: Union[dict[str, float], list[dict[str, float]]] = None, + gate_definitions: ( + Union[ + dict[tuple[Gate, QubitSet], PulseSequence], + list[dict[tuple[Gate, QubitSet], PulseSequence]], + ] + | None + ) = None, reservation_arn: str | None = None, *args, **kwargs, ) -> list[AwsQuantumTask]: - tasks_and_inputs = AwsQuantumTaskBatch._tasks_and_inputs(task_specifications, inputs) + tasks_inputs_gatedefs = AwsQuantumTaskBatch._tasks_inputs_gatedefs( + task_specifications, inputs, gate_definitions + ) max_threads = min(max_parallel, max_workers) - remaining = [0 for _ in tasks_and_inputs] + remaining = [0 for _ in tasks_inputs_gatedefs] try: with ThreadPoolExecutor(max_workers=max_threads) as executor: task_futures = [ @@ -234,11 +267,12 @@ def _execute( poll_timeout_seconds=poll_timeout_seconds, poll_interval_seconds=poll_interval_seconds, inputs=input_map, + gate_definitions=gatedefs, reservation_arn=reservation_arn, *args, **kwargs, ) - for task, input_map in tasks_and_inputs + for task, input_map, gatedefs in tasks_inputs_gatedefs ] except KeyboardInterrupt: # If an exception is thrown before the thread pool has finished, @@ -266,6 +300,7 @@ def _create_task( shots: int, poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, inputs: dict[str, float] = None, + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence] | None = None, reservation_arn: str | None = None, *args, **kwargs, @@ -278,6 +313,7 @@ def _create_task( shots, poll_interval_seconds=poll_interval_seconds, inputs=inputs, + gate_definitions=gate_definitions, reservation_arn=reservation_arn, *args, **kwargs, diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index 36f0e68fb..3f4918a1f 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -1125,6 +1125,7 @@ def to_ir( ValueError: If the supplied `ir_type` is not supported, or if the supplied serialization properties don't correspond to the `ir_type`. """ + gate_definitions = gate_definitions or {} if ir_type == IRType.JAQCD: return self._to_jaqcd() elif ir_type == IRType.OPENQASM: @@ -1137,7 +1138,7 @@ def to_ir( ) return self._to_openqasm( serialization_properties or OpenQASMSerializationProperties(), - gate_definitions.copy() if gate_definitions is not None else None, + gate_definitions.copy(), ) else: raise ValueError(f"Supplied ir_type {ir_type} is not supported.") @@ -1185,7 +1186,7 @@ def _to_jaqcd(self) -> JaqcdProgram: def _to_openqasm( self, serialization_properties: OpenQASMSerializationProperties, - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], ) -> OpenQasmProgram: ir_instructions = self._create_openqasm_header(serialization_properties, gate_definitions) openqasm_ir_type = IRType.OPENQASM @@ -1222,7 +1223,7 @@ def _to_openqasm( def _create_openqasm_header( self, serialization_properties: OpenQASMSerializationProperties, - gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]], + gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], ) -> list[str]: ir_instructions = ["OPENQASM 3.0;"] frame_wf_declarations = self._generate_frame_wf_defcal_declarations(gate_definitions) @@ -1244,7 +1245,7 @@ def _create_openqasm_header( ir_instructions.append(frame_wf_declarations) return ir_instructions - def _validate_gate_calbrations_uniqueness( + def _validate_gate_calibrations_uniqueness( self, gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence], frames: dict[str, Frame], @@ -1277,43 +1278,41 @@ def _generate_frame_wf_defcal_declarations( frames, waveforms = self._get_frames_waveforms_from_instrs(gate_definitions) - if gate_definitions is not None: - self._validate_gate_calbrations_uniqueness(gate_definitions, frames, waveforms) + self._validate_gate_calibrations_uniqueness(gate_definitions, frames, waveforms) # Declare the frames and waveforms across all pulse sequences declarable_frames = [f for f in frames.values() if not f.is_predefined] - if declarable_frames or waveforms or gate_definitions is not None: + if declarable_frames or waveforms or gate_definitions: frame_wf_to_declare = [f._to_oqpy_expression() for f in declarable_frames] frame_wf_to_declare += [wf._to_oqpy_expression() for wf in waveforms.values()] program.declare(frame_wf_to_declare, encal=True) - if gate_definitions is not None: - for key, calibration in gate_definitions.items(): - gate, qubits = key - - # Ignoring parametric gates - # Corresponding defcals with fixed arguments have been added - # in _get_frames_waveforms_from_instrs - if isinstance(gate, Parameterizable) and any( - not isinstance(parameter, (float, int, complex)) - for parameter in gate.parameters - ): - continue - - gate_name = gate._qasm_name - arguments = gate.parameters if isinstance(gate, Parameterizable) else [] - - for param in calibration.parameters: - self._parameters.add(param) - arguments = [ - param._to_oqpy_expression() if isinstance(param, FreeParameter) else param - for param in arguments - ] - - with oqpy.defcal( - program, [oqpy.PhysicalQubits[int(k)] for k in qubits], gate_name, arguments - ): - program += calibration._program + for key, calibration in gate_definitions.items(): + gate, qubits = key + + # Ignoring parametric gates + # Corresponding defcals with fixed arguments have been added + # in _get_frames_waveforms_from_instrs + if isinstance(gate, Parameterizable) and any( + not isinstance(parameter, (float, int, complex)) + for parameter in gate.parameters + ): + continue + + gate_name = gate._qasm_name + arguments = gate.parameters if isinstance(gate, Parameterizable) else [] + + for param in calibration.parameters: + self._parameters.add(param) + arguments = [ + param._to_oqpy_expression() if isinstance(param, FreeParameter) else param + for param in arguments + ] + + with oqpy.defcal( + program, [oqpy.PhysicalQubits[int(k)] for k in qubits], gate_name, arguments + ): + program += calibration._program ast = program.to_ast(encal=False, include_externs=False) return ast_to_qasm(ast) @@ -1321,7 +1320,7 @@ def _generate_frame_wf_defcal_declarations( return None def _get_frames_waveforms_from_instrs( - self, gate_definitions: Optional[dict[tuple[Gate, QubitSet], PulseSequence]] + self, gate_definitions: dict[tuple[Gate, QubitSet], PulseSequence] ) -> tuple[dict[str, Frame], dict[str, Waveform]]: from braket.circuits.gates import PulseGate @@ -1336,7 +1335,7 @@ def _get_frames_waveforms_from_instrs( _validate_uniqueness(waveforms, waveform) waveforms[waveform.id] = waveform # this will change with full parametric calibration support - elif isinstance(instruction.operator, Parameterizable) and gate_definitions is not None: + elif isinstance(instruction.operator, Parameterizable): fixed_argument_calibrations = self._add_fixed_argument_calibrations( gate_definitions, instruction ) diff --git a/test/unit_tests/braket/aws/test_aws_device.py b/test/unit_tests/braket/aws/test_aws_device.py index cac1c1536..a85ca6eb9 100644 --- a/test/unit_tests/braket/aws/test_aws_device.py +++ b/test/unit_tests/braket/aws/test_aws_device.py @@ -1138,7 +1138,7 @@ def test_run_param_circuit_with_reservation_arn_batch_task( 43200, 0.25, inputs, - None, + {}, reservation_arn="arn:aws:braket:us-west-2:123456789123:reservation/a1b123cd-45e6-789f-gh01-i234567jk8l9", ) @@ -1170,6 +1170,7 @@ def test_run_param_circuit_with_inputs_batch_task( 43200, 0.25, inputs, + {}, ) @@ -1303,7 +1304,9 @@ def test_batch_circuit_with_task_and_input_mismatch( inputs = [{"beta": 0.2}, {"gamma": 0.1}, {"theta": 0.2}] circ_1 = Circuit().ry(angle=3, target=0) task_specifications = [[circ_1, single_circuit_input], openqasm_program] - wrong_number_of_inputs = "Multiple inputs and task specifications must " "be equal in number." + wrong_number_of_inputs = ( + "Multiple inputs, task specifications and gate definitions must be equal in length." + ) with pytest.raises(ValueError, match=wrong_number_of_inputs): _run_batch_and_assert( @@ -1318,6 +1321,7 @@ def test_batch_circuit_with_task_and_input_mismatch( 43200, 0.25, inputs, + {}, ) @@ -1494,7 +1498,7 @@ def test_run_with_positional_args_and_kwargs( 86400, 0.25, {}, - ["foo"], + {}, "arn:aws:braket:us-west-2:123456789123:reservation/a1b123cd-45e6-789f-gh01-i234567jk8l9", None, {"bar": 1, "baz": 2}, @@ -1534,6 +1538,7 @@ def test_run_batch_no_extra( 43200, 0.25, {}, + {}, ) @@ -1560,6 +1565,7 @@ def test_run_batch_with_shots( 43200, 0.25, {}, + {}, ) @@ -1586,6 +1592,7 @@ def test_run_batch_with_max_parallel_and_kwargs( 43200, 0.25, inputs={"theta": 0.2}, + gate_definitions={}, extra_kwargs={"bar": 1, "baz": 2}, )