diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 7a6cae2e4eb6..4852a397e7fe 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2819,4 +2819,156 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.replaceOp(binder.op, {loss, logProb}); return success(); }); + patterns.onOp( + "Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + std::string mode, nearest_mode, coordTfMode; + Value noneVal = rewriter.create(binder.getLoc()); + + if (auto attr = binder.op->getAttr("torch.onnx.antialias")) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for antialias attribute"); + } + if (auto attr = binder.op->getAttr("torch.onnx.axes")) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for axes attribute"); + } + if (auto attr = binder.op->getAttr("torch.onnx.exclude_outside")) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "exclude_outside attribute"); + } + if (auto attr = binder.op->getAttr("torch.onnx.extrapolation_value")) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "extrapolation_value attribute"); + } + if (auto attr = + binder.op->getAttr("torch.onnx.keep_aspect_ratio_policy")) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "keep_aspect_ratio_policy attribute"); + } + + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || + binder.customOpNameStringAttr(mode, "mode", "nearest") || + binder.customOpNameStringAttr( + coordTfMode, "coordinate_transformation_mode", "half_pixel") || + binder.customOpNameStringAttr(nearest_mode, "nearest_mode", "")) + return failure(); + + if (mode == "nearest" && nearest_mode != "floor") { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for nearest_mode " + "except floor"); + } + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value cstTrue = + rewriter.create(binder.getLoc(), true); + Value modeStrValue; + + auto extract = [&rewriter, &binder](Value x, Value v) { + auto xTy = x.getType().cast(); + Type extractTy = rewriter.getType(); + if (isa(xTy.getDtype())) + extractTy = rewriter.getType(); + + return rewriter.create(binder.getLoc(), extractTy, + v); + }; + + auto getValueList = [&](Value operand) { + SmallVector itemList; + auto sizes = + dyn_cast(operand.getType()).getSizes(); + Torch::BaseTensorType operandType = + operand.getType().cast(); + + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = operandType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), operandType.getOptionalDtype()); + + MLIRContext *context = binder.op->getContext(); + for (int i = sizes[0] - 2; i < sizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value ext = rewriter.create( + binder.getLoc(), selectResultType, operand, zero, selectIndex); + Value item = extract(operand, ext); + itemList.push_back(item); + } + auto xTy = operand.getType().cast(); + Value ValueList; + if (isa(xTy.getDtype())) { + ValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(context)), itemList); + } else { + ValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::FloatType::get(context)), itemList); + } + return ValueList; + }; + + Value scalesValueList = noneVal; + Value sizesValueList = noneVal; + Value alignCorners = + coordTfMode == "align_corners" ? cstTrue : cstFalse; + + if (mode == "cubic") { + return rewriter.notifyMatchFailure(binder.op, + "unimplemented: bicubic mode"); + } + if (mode == "linear") { + modeStrValue = rewriter.create(binder.getLoc(), + "bilinear"); + if (operands.size() < 4) { + Value scaleOperand = operands[2]; + scalesValueList = getValueList(scaleOperand); + sizesValueList = noneVal; + } else { + Value sizeOperand = operands[3]; + scalesValueList = noneVal; + sizesValueList = getValueList(sizeOperand); + } + } + if (mode == "nearest") { + modeStrValue = + rewriter.create(binder.getLoc(), "nearest"); + if (operands.size() < 4) { + Value scaleOperand = operands[2]; + scalesValueList = getValueList(scaleOperand); + sizesValueList = noneVal; + } else { + Value sizesOperand = operands[3]; + scalesValueList = noneVal; + sizesValueList = getValueList(sizesOperand); + } + } + if (scalesValueList.getType().isa() && + sizesValueList.getType().isa()) { + return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode"); + } + rewriter + .replaceOpWithNewOp( + binder.op, resultType, operands[0], sizesValueList, + scalesValueList, modeStrValue, + /* AnyTorchOptionalBoolType:$align_corners */ alignCorners, + /* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal, + /*Torch_BoolType:$antialias*/ cstFalse); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 1498ea257048..c517ffe31dcf 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2040,3 +2040,24 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: %0:2 = torch.operator "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1) {torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32>) return %0#0, %0#1 : !torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_resize_sizes_nearest + func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_resize_sizes_linear + func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], +f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> + }