Skip to content

Commit

Permalink
Update type hints (#753)
Browse files Browse the repository at this point in the history
* change: Update type hints to use built-ins
  • Loading branch information
laurencap authored Oct 23, 2023
1 parent fb0b4ec commit a079c52
Show file tree
Hide file tree
Showing 15 changed files with 87 additions and 97 deletions.
56 changes: 23 additions & 33 deletions src/braket/experimental/autoqasm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
import copy
import functools
import inspect
from collections.abc import Callable
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union

import openqasm3.ast as qasm_ast
import oqpy.base
Expand Down Expand Up @@ -120,15 +121,15 @@ def gate_calibration(*args, implements: Callable, **kwargs) -> Callable[[], Gate


def _function_wrapper(
*args: Tuple[Any],
*args: tuple[Any],
converter_callback: Callable,
converter_args: Optional[Dict[str, Any]] = None,
converter_args: Optional[dict[str, Any]] = None,
) -> Callable[[Any], aq_program.Program]:
"""Wrapping and conversion logic around the user function `f`.
Args:
converter_callback (Callable): The function converter, e.g., _convert_main.
converter_args (Optional[Dict[str, Any]]): Extra arguments for the function converter.
converter_args (Optional[dict[str, Any]]): Extra arguments for the function converter.
Returns:
Callable[[Any], Program]: A callable which returns the converted
Expand Down Expand Up @@ -167,7 +168,7 @@ def _wrapper(*args, **kwargs) -> Callable:
return autograph_artifact(decorated_wrapper)


def _autograph_optional_features() -> Tuple[converter.Feature]:
def _autograph_optional_features() -> tuple[converter.Feature]:
# Exclude autograph features which are TensorFlow-specific
return converter.Feature.all_but(
(converter.Feature.NAME_SCOPES, converter.Feature.AUTO_CONTROL_DEPS)
Expand All @@ -177,8 +178,8 @@ def _autograph_optional_features() -> Tuple[converter.Feature]:
def _convert_main(
f: Callable,
options: converter.ConversionOptions,
args: List[Any],
kwargs: Dict[str, Any],
args: list[Any],
kwargs: dict[str, Any],
user_config: aq_program.UserConfig,
) -> None:
"""Convert the initial callable `f` into a full AutoQASM program `program`.
Expand All @@ -191,8 +192,8 @@ def _convert_main(
Args:
f (Callable): The function to be converted.
options (converter.ConversionOptions): Converter options.
args (List[Any]): Arguments passed to the program when called.
kwargs (Dict[str, Any]): Keyword arguments passed to the program when called.
args (list[Any]): Arguments passed to the program when called.
kwargs (dict[str, Any]): Keyword arguments passed to the program when called.
user_config (UserConfig): User-specified settings that influence program building.
"""
if aq_program.in_active_program_conversion_context():
Expand Down Expand Up @@ -260,8 +261,8 @@ def _add_qubit_declaration(program_conversion_context: aq_program.ProgramConvers
def _convert_subroutine(
f: Callable,
options: converter.ConversionOptions,
args: List[Any],
kwargs: Dict[str, Any],
args: list[Any],
kwargs: dict[str, Any],
) -> None:
"""Convert the initial callable `f` into a full AutoQASM program `program`.
The contents of `f` are converted into a subroutine in the program.
Expand All @@ -272,8 +273,8 @@ def _convert_subroutine(
Args:
f (Callable): The function to be converted.
options (converter.ConversionOptions): Converter options.
args (List[Any]): Arguments passed to the program when called.
kwargs (Dict[str, Any]): Keyword arguments passed to the program when called.
args (list[Any]): Arguments passed to the program when called.
kwargs (dict[str, Any]): Keyword arguments passed to the program when called.
"""
if not aq_program.in_active_program_conversion_context():
raise errors.AutoQasmTypeError(
Expand Down Expand Up @@ -426,18 +427,7 @@ def _make_return_instance_from_f_annotation(f: Callable) -> Any:
# TODO: Recursive functions should work even if the user's type hint is wrong
annotations = f.__annotations__
return_type = annotations["return"] if "return" in annotations else None

return_instance = None
if return_type and aq_types.is_qasm_type(return_type):
return_instance = return_type()
elif return_type:
if hasattr(return_type, "__origin__"):
# Types from python's typing module, such as `List`. origin gives us `list``
return_instance = return_type.__origin__()
else:
return_instance = return_type()

return return_instance
return return_type() if return_type else None


def _make_return_instance_from_oqpy_return_type(return_type: Any) -> Any:
Expand All @@ -461,8 +451,8 @@ def _get_bitvar_size(node: qasm_ast.BitType) -> Optional[int]:
def _convert_gate(
f: Callable,
options: converter.ConversionOptions,
args: List[Any],
kwargs: Dict[str, Any],
args: list[Any],
kwargs: dict[str, Any],
) -> Callable:
# We must be inside an active conversion context in order to invoke a gate
program_conversion_context = aq_program.get_program_conversion_context()
Expand Down Expand Up @@ -558,8 +548,8 @@ def _get_gate_args(f: Callable) -> aq_program.GateArgs:
def _convert_calibration(
f: Callable,
options: converter.ConversionOptions,
args: List[Any],
kwargs: Dict[str, Any],
args: list[Any],
kwargs: dict[str, Any],
gate_function: Callable,
**decorator_kwargs,
) -> GateCalibration:
Expand All @@ -569,8 +559,8 @@ def _convert_calibration(
Args:
f (Callable): The function to be converted.
options (converter.ConversionOptions): Converter options.
args (List[Any]): Arguments passed to the program when called.
kwargs (Dict[str, Any]): Keyword arguments passed to the program when called.
args (list[Any]): Arguments passed to the program when called.
kwargs (dict[str, Any]): Keyword arguments passed to the program when called.
gate_function (Callable): The gate function which calibration is being defined.
Returns:
Expand Down Expand Up @@ -624,14 +614,14 @@ def _convert_calibration(

def _validate_calibration_args(
gate_function: Callable,
decorator_args: Dict[str, Union[Qubit, float]],
decorator_args: dict[str, Union[Qubit, float]],
func_args: aq_program.GateArgs,
) -> None:
"""Validate the arguments passed to the calibration decorator and function.
Args:
gate_function (Callable): The gate function which calibration is being defined.
decorator_args (Dict[str, Union[Qubit, float]]): The calibration decorator arguments.
decorator_args (dict[str, Union[Qubit, float]]): The calibration decorator arguments.
func_args (aq_program.GateArgs): The gate function arguments.
"""
gate_args = _get_gate_args(gate_function)
Expand Down
4 changes: 2 additions & 2 deletions src/braket/experimental/autoqasm/instructions/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
"""Non-unitary instructions that apply to qubits.
"""

from typing import Any, List
from typing import Any

from braket.experimental.autoqasm import program as aq_program

from .qubits import QubitIdentifierType, _qubit


def _qubit_instruction(
name: str, qubits: List[QubitIdentifierType], *args: Any, is_unitary: bool = True
name: str, qubits: list[QubitIdentifierType], *args: Any, is_unitary: bool = True
) -> None:
program_conversion_context = aq_program.get_program_conversion_context()
program_conversion_context.validate_gate_targets(qubits, args)
Expand Down
8 changes: 4 additions & 4 deletions src/braket/experimental/autoqasm/instructions/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,25 @@ def my_program():
"""


from typing import List, Union
from typing import Union

from braket.experimental.autoqasm import program
from braket.experimental.autoqasm import types as aq_types
from braket.experimental.autoqasm.instructions.qubits import QubitIdentifierType, _qubit


def measure(qubits: Union[QubitIdentifierType, List[QubitIdentifierType]]) -> aq_types.BitVar:
def measure(qubits: Union[QubitIdentifierType, list[QubitIdentifierType]]) -> aq_types.BitVar:
"""Add qubit measurement statements to the program and assign the measurement
results to bit variables.
Args:
qubits (Union[QubitIdentifierType, List[QubitIdentifierType]]): The target qubits
qubits (Union[QubitIdentifierType, list[QubitIdentifierType]]): The target qubits
to measure.
Returns:
BitVar: Bit variable the measurement results are assigned to.
"""
if not isinstance(qubits, List):
if not isinstance(qubits, list):
qubits = [qubits]

oqpy_program = program.get_program_conversion_context().get_oqpy_program()
Expand Down
8 changes: 4 additions & 4 deletions src/braket/experimental/autoqasm/instructions/qubits.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import re
from functools import singledispatch
from typing import Any, List, Union
from typing import Any, Union

import oqpy.base
from openpulse.printer import dumps
Expand All @@ -38,14 +38,14 @@ def is_qubit_identifier_type(qubit: Any) -> bool:
return isinstance(qubit, QubitIdentifierType.__args__)


def _get_physical_qubit_indices(qids: List[str]) -> List[int]:
def _get_physical_qubit_indices(qids: list[str]) -> list[int]:
"""Convert physical qubit labels to the corresponding qubit indices.
Args:
qids (List[str]): Physical qubit labels.
qids (list[str]): Physical qubit labels.
Returns:
List[int]: Qubit indices corresponding to the input physical qubits.
list[int]: Qubit indices corresponding to the input physical qubits.
"""
braket_qubits = []
for qid in qids:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

"""Operators for conditional expressions (e.g. the ternary if statement)."""

from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any, Optional

import oqpy.base

Expand Down
3 changes: 2 additions & 1 deletion src/braket/experimental/autoqasm/operators/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

"""Operators for control flow constructs (e.g. if, for, while)."""

from typing import Any, Callable, Iterable, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any, Optional, Union

import oqpy.base

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import Callable, Iterable
from collections.abc import Callable, Iterable

from braket.experimental.autoqasm.instructions.qubits import QubitIdentifierType as Qubit
from braket.experimental.autoqasm.program import Program
Expand Down
31 changes: 16 additions & 15 deletions src/braket/experimental/autoqasm/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

import contextlib
import threading
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Iterable, List, Optional, Union
from typing import Any, Optional, Union

import oqpy.base

Expand Down Expand Up @@ -99,12 +100,12 @@ def __init__(
self._oqpy_program = oqpy_program
self._has_pulse_control = has_pulse_control

def with_calibrations(self, gate_calibrations: Union[Callable, List[Callable]]) -> Program:
def with_calibrations(self, gate_calibrations: Union[Callable, list[Callable]]) -> Program:
"""Add the gate calibrations to the program. The calibration added program is returned
as a new object. The original program is not modified.
Args:
gate_calibrations (Union[Callable, List[Callable]]): The gate calibrations to add to
gate_calibrations (Union[Callable, list[Callable]]): The gate calibrations to add to
the main program. Calibration are passed as callable without evaluation.
Returns:
Expand Down Expand Up @@ -156,7 +157,7 @@ class GateArgs:
"""Represents a list of qubit and angle arguments for a gate definition."""

def __init__(self):
self._args: List[Union[oqpy.Qubit, oqpy.AngleVar]] = []
self._args: list[Union[oqpy.Qubit, oqpy.AngleVar]] = []

def __len__(self):
return len(self._args)
Expand All @@ -174,19 +175,19 @@ def append(self, name: str, is_qubit: bool) -> None:
self._args.append(oqpy.AngleVar(name=name))

@property
def qubits(self) -> List[oqpy.Qubit]:
def qubits(self) -> list[oqpy.Qubit]:
return [self._args[i] for i in self.qubit_indices]

@property
def angles(self) -> List[oqpy.AngleVar]:
def angles(self) -> list[oqpy.AngleVar]:
return [self._args[i] for i in self.angle_indices]

@property
def qubit_indices(self) -> List[int]:
def qubit_indices(self) -> list[int]:
return [i for i, arg in enumerate(self._args) if isinstance(arg, oqpy.Qubit)]

@property
def angle_indices(self) -> List[int]:
def angle_indices(self) -> list[int]:
return [i for i, arg in enumerate(self._args) if isinstance(arg, oqpy.AngleVar)]


Expand Down Expand Up @@ -229,11 +230,11 @@ def make_program(self) -> Program:
return Program(self.get_oqpy_program(), has_pulse_control=self._has_pulse_control)

@property
def qubits(self) -> List[int]:
def qubits(self) -> list[int]:
"""Return a sorted list of virtual qubits used in this program.
Returns:
List[int]: The list of virtual qubits, e.g. [0, 1, 2]
list[int]: The list of virtual qubits, e.g. [0, 1, 2]
"""
# Can be memoized or otherwise made more performant
return sorted(list(self._virtual_qubits_used))
Expand Down Expand Up @@ -323,12 +324,12 @@ def is_var_name_used(self, var_name: str) -> bool:
or var_name in oqpy_program.undeclared_vars.keys()
)

def validate_gate_targets(self, qubits: List[Any], angles: List[Any]) -> None:
def validate_gate_targets(self, qubits: list[Any], angles: list[Any]) -> None:
"""Validate that the specified gate targets are valid at this point in the program.
Args:
qubits (List[Any]): The list of target qubits to validate.
angles (List[Any]): The list of target angles to validate.
qubits (list[Any]): The list of target qubits to validate.
angles (list[Any]): The list of target angles to validate.
Raises:
errors.InvalidTargetQubit: Target qubits are invalid in the current context.
Expand Down Expand Up @@ -359,10 +360,10 @@ def validate_gate_targets(self, qubits: List[Any], angles: List[Any]) -> None:
)

@staticmethod
def _normalize_gate_names(gate_names: Iterable[str]) -> List[str]:
def _normalize_gate_names(gate_names: Iterable[str]) -> list[str]:
return [gate_name.lower() for gate_name in gate_names]

def _validate_verbatim_target_qubits(self, qubits: List[Any]) -> None:
def _validate_verbatim_target_qubits(self, qubits: list[Any]) -> None:
# Only physical target qubits are allowed in a verbatim block:
for qubit in qubits:
if not isinstance(qubit, str):
Expand Down
Loading

0 comments on commit a079c52

Please sign in to comment.