From 5325d3e6e6e0722ba78e14725b93107e0915710a Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Tue, 2 Apr 2024 08:31:43 +0800 Subject: [PATCH] [fx] Fix type hint for fx importer (#3066) Co-authored-by: Stella Laurenzo --- python/torch_mlir/extras/fx_importer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 23ed415d5160..edcf62c69bbe 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -27,6 +27,7 @@ Tuple, TYPE_CHECKING, Union, + Iterable, ) import weakref @@ -1173,7 +1174,7 @@ def return_node_values(self, loc, nodes: List[Node]): func_dialect.ReturnOp(operands, loc=loc) def import_nodes( - self, nodes: Sequence[Node], *, skip_placeholders_outputs: bool = False + self, nodes: Iterable[Node], *, skip_placeholders_outputs: bool = False ): with InsertionPoint(self._b): loc = Location.unknown() @@ -1266,7 +1267,7 @@ def _import_symbolic_torch_op( (arg.meta["val"].node.pytype if isinstance(arg, Node) else type(arg)) for arg in node.args ] - is_int = [item == int for item in arg_types] + is_int = [item is int for item in arg_types] if all(is_int): op_overload = "int" elif any(is_int): @@ -1546,7 +1547,7 @@ def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value: ).result def _import_list_argument( - self, loc: Location, arg: NodeArgument, expected_jit_type + self, loc: Location, arg: Sequence[NodeArgument], expected_jit_type ) -> Value: assert ( isinstance(expected_jit_type, torch.ListType) @@ -1554,7 +1555,7 @@ def _import_list_argument( isinstance(expected_jit_type, torch.OptionalType) and isinstance(expected_jit_type.getElementType(), torch.ListType) ) - or isinstance(expected_jit_type, NoneType) + or (expected_jit_type is None) ), f"Unexpected jit type as list argument: {arg} of type {expected_jit_type}" # parse list type @@ -1630,7 +1631,7 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value: with loc: return cvt(arg, self, self._cc) - def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema): + def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema) -> List[IrType]: return_count = len(schema.returns) if return_count == 1: # Unary return directly maps a single meta["val"] and cannot be subscripted. @@ -1649,7 +1650,6 @@ def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema) result_types = [] for v in node.meta["val"]: result_types.append(self._cc.value_info_to_type(v)) - result_types = tuple(result_types) return result_types