diff --git a/src/braket/experimental/autoqasm/api.py b/src/braket/experimental/autoqasm/api.py index bd19a1e43..04c1fe4d9 100644 --- a/src/braket/experimental/autoqasm/api.py +++ b/src/braket/experimental/autoqasm/api.py @@ -13,6 +13,8 @@ """This module implements the decorator API for generating programs using AutoQASM.""" +from __future__ import annotations + import copy import functools import inspect @@ -43,10 +45,11 @@ def main( - *args, + func: Optional[Callable] = None, + *, num_qubits: Optional[int] = None, device: Optional[Union[Device, str]] = None, -) -> Callable[[Any], aq_program.Program]: +) -> Callable[..., aq_program.Program]: """Decorator that converts a function into a callable that returns a Program object containing the quantum program. @@ -54,20 +57,22 @@ def main( function is called, and a new Program object is returned each time. Args: + func (Optional[Callable]): Decorated function. May be `None` in the case where decorator + is used with parentheses. num_qubits (Optional[int]): Configuration to set the total number of qubits to declare in the program. device (Optional[Union[Device, str]]): Configuration to set the target device for the program. Can be either an Device object or a valid Amazon Braket device ARN. Returns: - Callable[[Any], Program]: A callable which returns the converted + Callable[..., Program]: A callable which returns the converted quantum program when called. """ if isinstance(device, str): device = AwsDevice(device) return _function_wrapper( - *args, + func, converter_callback=_convert_main, converter_args={ "user_config": aq_program.UserConfig( @@ -78,28 +83,36 @@ def main( ) -def subroutine(*args) -> Callable[[Any], aq_program.Program]: +def subroutine(func: Optional[Callable] = None) -> Callable[..., aq_program.Program]: """Decorator that converts a function into a callable that will insert a subroutine into the quantum program. + Args: + func (Optional[Callable]): Decorated function. May be `None` in the case where decorator + is used with parentheses. + Returns: - Callable[[Any], Program]: A callable which returns the converted + Callable[..., Program]: A callable which returns the converted quantum program when called. """ - return _function_wrapper(*args, converter_callback=_convert_subroutine) + return _function_wrapper(func, converter_callback=_convert_subroutine) -def gate(*args) -> Callable[[Any], None]: +def gate(func: Optional[Callable] = None) -> Callable[..., None]: """Decorator that converts a function into a callable gate definition. + Args: + func (Optional[Callable]): Decorated function. May be `None` in the case where decorator + is used with parentheses. + Returns: - Callable[[Any], None]: A callable which can be used as a custom gate inside an + Callable[..., None]: A callable which can be used as a custom gate inside an aq.function or inside another aq.gate. """ - return _function_wrapper(*args, converter_callback=_convert_gate) + return _function_wrapper(func, converter_callback=_convert_gate) -def gate_calibration(*args, implements: Callable, **kwargs) -> Callable[[], GateCalibration]: +def gate_calibration(*, implements: Callable, **kwargs) -> Callable[[], GateCalibration]: """A decorator that register the decorated function as a gate calibration definition. The decorated function is added to a main program using `with_calibrations` method of the main program. The fixed values of qubits or angles that the calibration is implemented against @@ -114,40 +127,50 @@ def gate_calibration(*args, implements: Callable, **kwargs) -> Callable[[], Gate `with_calibrations` method of the main program. """ return _function_wrapper( - *args, + None, converter_callback=_convert_calibration, converter_args={"gate_function": implements, **kwargs}, ) def _function_wrapper( - *args: tuple[Any], + func: Optional[Callable], + *, converter_callback: Callable, converter_args: Optional[dict[str, Any]] = None, -) -> Callable[[Any], aq_program.Program]: +) -> Callable[..., Optional[Union[aq_program.Program, GateCalibration]]]: """Wrapping and conversion logic around the user function `f`. Args: + func (Optional[Callable]): Decorated function. May be `None` in the case where decorator + is used with parentheses. converter_callback (Callable): The function converter, e.g., _convert_main. converter_args (Optional[dict[str, Any]]): Extra arguments for the function converter. Returns: - Callable[[Any], Program]: A callable which returns the converted - quantum program when called. + Callable[..., Optional[Union[Program, GateCalibration]]]: A callable which + returns the converted construct, if any, when called. """ - if not args: - # This the case where a decorator is called with only keyword args, for example: - # @aq.main(num_qubits=4) + if not (func and callable(func)): + # This the case where a decorator is called either without a positional argument, + # or with a non-callable positional argument, which is as close of an approximation + # we can get to the case where a decorator is called with parentheses. + # + # There is still a false negative case, where we have something like: + # @aq.main(callable_pos_arg) # def my_function(): + # + # but this is known limitation in python (consider the valid non-decorator usage + # `aq.main(my_function)` for an example of why this ambiguity exists). + # # To make this work, here we simply return a partial application of this function - # which still expects a Callable as the first argument. + # which still expects a Callable as the single positional argument. return functools.partial( _function_wrapper, converter_callback=converter_callback, converter_args=converter_args ) - f = args[0] - if is_autograph_artifact(f): - return f + if is_autograph_artifact(func): + return func if not converter_args: converter_args = {} @@ -159,12 +182,12 @@ def _wrapper(*args, **kwargs) -> Callable: optional_features=_autograph_optional_features(), ) # Call the appropriate function converter - return converter_callback(f, options, args, kwargs, **converter_args) + return converter_callback(func, options, args, kwargs, **converter_args) - if inspect.isfunction(f) or inspect.ismethod(f): - _wrapper = functools.update_wrapper(_wrapper, f) + if inspect.isfunction(func) or inspect.ismethod(func): + _wrapper = functools.update_wrapper(_wrapper, func) - decorated_wrapper = tf_decorator.make_decorator(f, _wrapper) + decorated_wrapper = tf_decorator.make_decorator(func, _wrapper) return autograph_artifact(decorated_wrapper) @@ -178,7 +201,7 @@ def _autograph_optional_features() -> tuple[converter.Feature]: def _convert_main( f: Callable, options: converter.ConversionOptions, - args: list[Any], + args: tuple[Any], kwargs: dict[str, Any], user_config: aq_program.UserConfig, ) -> None: @@ -192,7 +215,7 @@ def _convert_main( Args: f (Callable): The function to be converted. options (converter.ConversionOptions): Converter options. - args (list[Any]): Arguments passed to the program when called. + args (tuple[Any]): Arguments passed to the program when called. kwargs (dict[str, Any]): Keyword arguments passed to the program when called. user_config (UserConfig): User-specified settings that influence program building. """