Skip to content

Commit

Permalink
Merge pull request #217 from Xilinx/bump_to_5325d3e6
Browse files Browse the repository at this point in the history
[AutoBump] Merge with 5325d3e (1)
  • Loading branch information
mgehre-amd authored Aug 9, 2024
2 parents 33c647d + c91810a commit cf56ca4
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions python/torch_mlir/extras/fx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Tuple,
TYPE_CHECKING,
Union,
Iterable,
)
import weakref

Expand Down Expand Up @@ -1173,7 +1174,7 @@ def return_node_values(self, loc, nodes: List[Node]):
func_dialect.ReturnOp(operands, loc=loc)

def import_nodes(
self, nodes: Sequence[Node], *, skip_placeholders_outputs: bool = False
self, nodes: Iterable[Node], *, skip_placeholders_outputs: bool = False
):
with InsertionPoint(self._b):
loc = Location.unknown()
Expand Down Expand Up @@ -1266,7 +1267,7 @@ def _import_symbolic_torch_op(
(arg.meta["val"].node.pytype if isinstance(arg, Node) else type(arg))
for arg in node.args
]
is_int = [item == int for item in arg_types]
is_int = [item is int for item in arg_types]
if all(is_int):
op_overload = "int"
elif any(is_int):
Expand Down Expand Up @@ -1546,15 +1547,15 @@ def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value:
).result

def _import_list_argument(
self, loc: Location, arg: NodeArgument, expected_jit_type
self, loc: Location, arg: Sequence[NodeArgument], expected_jit_type
) -> Value:
assert (
isinstance(expected_jit_type, torch.ListType)
or (
isinstance(expected_jit_type, torch.OptionalType)
and isinstance(expected_jit_type.getElementType(), torch.ListType)
)
or isinstance(expected_jit_type, NoneType)
or (expected_jit_type is None)
), f"Unexpected jit type as list argument: {arg} of type {expected_jit_type}"

# parse list type
Expand Down Expand Up @@ -1630,7 +1631,7 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value:
with loc:
return cvt(arg, self, self._cc)

def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema):
def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema) -> List[IrType]:
return_count = len(schema.returns)
if return_count == 1:
# Unary return directly maps a single meta["val"] and cannot be subscripted.
Expand All @@ -1649,7 +1650,6 @@ def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema)
result_types = []
for v in node.meta["val"]:
result_types.append(self._cc.value_info_to_type(v))
result_types = tuple(result_types)
return result_types


Expand Down

0 comments on commit cf56ca4

Please sign in to comment.