From 58489faf7fdd3e3f20fb849fd89e7bfffe6540fe Mon Sep 17 00:00:00 2001
From: jinchen <49575973+jinchen62@users.noreply.github.com>
Date: Tue, 8 Oct 2024 10:37:31 -0700
Subject: [PATCH] torch.aten.squeeze.dim lowering with dynamic dims (#3749)

Address https://github.com/nod-ai/SHARK-ModelDev/issues/846

Assume the dynamic squeezed dim is 1.
---
 lib/Conversion/TorchToLinalg/DataMovement.cpp | 15 +++++++++++----
 test/Conversion/TorchToLinalg/squeeze.mlir    | 17 +++++++++++++++++
 2 files changed, 28 insertions(+), 4 deletions(-)
 create mode 100644 test/Conversion/TorchToLinalg/squeeze.mlir

diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index ac1707ec23a6..902daa1cb5ad 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1658,10 +1658,17 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
     if (!isValidDim(dim, inputRank))
       return rewriter.notifyMatchFailure(op, "dim is statically invalid");
 
-    // TODO: Handle the case where the dim(th) dimension is dynamic.
+    // assert dynamic squeeze dim size == 1
     if (inputType.isDynamicDim(dim)) {
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: dim(th) dimension is not expected to be dynamic");
+      Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
+      Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
+      Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
+      Value cmp = rewriter.create<arith::CmpIOp>(
+          op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
+      rewriter.create<cf::AssertOp>(
+          op.getLoc(), cmp,
+          rewriter.getStringAttr(
+              "Expected dynamic squeeze dim size to be statically 1"));
     }
 
     const TypeConverter *typeConverter = getTypeConverter();
@@ -1671,7 +1678,7 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
 
     // If the dim(th) dimension of operand tensor type is not statically unit,
     // `aten.squeeze` will behave as an identity operation.
-    if (inputType.getDimSize(dim) != 1) {
+    if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
       rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
       return success();
     }
diff --git a/test/Conversion/TorchToLinalg/squeeze.mlir b/test/Conversion/TorchToLinalg/squeeze.mlir
new file mode 100644
index 000000000000..a8922eed5a9d
--- /dev/null
+++ b/test/Conversion/TorchToLinalg/squeeze.mlir
@@ -0,0 +1,17 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @torch.aten.squeeze.dim$dynamic
+func.func @torch.aten.squeeze.dim$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} {
+  // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C0_1:.*]] = arith.constant 0 : index
+  // CHECK: %[[DIM:.*]] = tensor.dim %[[BUILTIN_TENSOR]], %[[C0_1]] : tensor<?x?x?xf32>
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[DIM]], %[[C1]] : index
+  // CHECK: cf.assert %[[CMPI]], "Expected dynamic squeeze dim size to be statically 1"
+  // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
+  // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+  %int0 = torch.constant.int 0
+  %1 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %1 : !torch.vtensor<[?,?],f32>
+}