diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 339862a4b2f3..4ec85c7b96c6 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -105,9 +105,18 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOpT> {
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()));
 
-    auto binaryOp =
-        tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
-    rewriter.replaceOp(op, binaryOp.getResult());
+    Value binaryOp;
+
+    // TOSA ArithmeticRightShiftOp has a round parameter.
+    if constexpr (std::is_same<TosaOpT, tosa::ArithmeticRightShiftOp>()) {
+      binaryOp = rewriter.create<TosaOpT>(op->getLoc(), outTy, lhs, rhs,
+                                          /*round=*/false);
+    } else {
+      binaryOp =
+          tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
+    }
+
+    rewriter.replaceOp(op, binaryOp);
     return success();
   }
 };
@@ -354,6 +363,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
     // For bitwise operators, only integer datatype legalization is supported
     constexpr bool isBitwiseOp =
         std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() ||
+        std::is_same<AtenOpT, AtenBitwiseAndScalarOp>() ||
        std::is_same<AtenOpT, AtenBitwiseOrTensorOp>() ||
        std::is_same<AtenOpT, AtenBitwiseXorTensorOp>();
    if (isa<mlir::FloatType>(lhsElemTy) && isBitwiseOp) {
@@ -375,8 +385,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
    constexpr auto swapLhsRhs = (std::is_same() ||
                                 std::is_same() ||
                                 std::is_same() ||
-                                std::is_same() ||
-                                std::is_same());
+                                std::is_same());
 
    // Promote lhs and rhs dtypes for bitwise operators.
    TensorType resultTy = cast<TensorType>(
@@ -692,39 +701,30 @@ class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-template <>
-LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
-    AtenTanhOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value self = adaptor.getSelf();
-  auto selfTy = cast<TensorType>(self.getType());
-  if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
-    rewriter.replaceOpWithNewOp<tosa::TanhOp>(
-        op, getTypeConverter()->convertType(op.getType()), self);
-    return success();
-  }
-  // Sigmoid legalization in TOSA for quantized element-type uses specialized
-  // tosa.table construct.
-  return rewriter.notifyMatchFailure(
-      op, "Only floating-point datatype legalization currently supported");
-}
+template <typename AtenOpT, typename TosaOpT>
+class ConvertAtenActivationFunctionOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value self = adaptor.getSelf();
+    auto selfTy = cast<TensorType>(self.getType());
+
+    if (!selfTy)
+      return rewriter.notifyMatchFailure(op, "Only Tensor types supported");
+
+    if (!isa<mlir::FloatType>(selfTy.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point datatype legalization currently supported");
+
+    rewriter.replaceOpWithNewOp<TosaOpT>(
+        op, this->getTypeConverter()->convertType(op.getType()), self);
 
-template <>
-LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
-    AtenSigmoidOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value self = adaptor.getSelf();
-  auto selfTy = cast<TensorType>(self.getType());
-  if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
-    rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
-        op, getTypeConverter()->convertType(op.getType()), self);
     return success();
   }
-  // Sigmoid legalization in TOSA for quantized element-type uses
-  // specialized tosa.table construct.
-  return rewriter.notifyMatchFailure(
-      op, "Only floating-point datatype legalization currently supported");
-}
+};
 
 template <>
 LogicalResult ConvertAtenOp::matchAndRewrite(
@@ -1209,73 +1209,63 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
   }
 };
 
-template <>
-LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
-    AtenPowScalarOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-
-  Value exp = adaptor.getExponent();
-  auto expTy = dyn_cast<RankedTensorType>(exp.getType());
-
-  if (!expTy)
-    return rewriter.notifyMatchFailure(
-        op, "Only ranked tensor types supported in TOSA Pow");
-
-  if (!isa<mlir::FloatType>(expTy.getElementType()))
-    return rewriter.notifyMatchFailure(
-        op, "Only floating-point datatype legalization supported");
-
-  Value selfTensor;
-  Value selfScalar = op.getSelf();
-  if (failed(torchScalarToTosaTensor(rewriter, op, selfScalar, selfTensor,
-                                     expTy.getElementType(), {})))
-    return rewriter.notifyMatchFailure(
-        op, "Currently only scalar constants are supported for "
-            "conversion in TOSA Pow operation");
-
-  auto outType =
-      cast<TensorType>(getTypeConverter()->convertType(op.getType()));
-
-  auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
-                                                        selfTensor, exp);
-  rewriter.replaceOp(op, powOp.getResult());
-
-  return success();
-}
+template <typename AtenOpT>
+class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
-template <>
-LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
-    AtenPowTensorScalarOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
+    auto outType =
+        cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
 
-  Value self = adaptor.getSelf();
-  auto selfTy = cast<RankedTensorType>(self.getType());
+    Value selfTensor;
+    if constexpr (std::is_same<AtenOpT, AtenPowScalarOp>()) {
+      Value selfScalar = op.getSelf();
+      if (failed(torchScalarToTosaTensor(rewriter, op, selfScalar, selfTensor,
+                                         outType.getElementType(), {})))
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA PowScalar operation");
+    } else {
+      selfTensor = adaptor.getSelf();
+      auto selfTy = cast<RankedTensorType>(selfTensor.getType());
 
-  if (!selfTy)
-    return rewriter.notifyMatchFailure(
-        op, "Only ranked tensor types supported in TOSA Pow");
+      if (!selfTy)
+        return rewriter.notifyMatchFailure(
+            op, "Only ranked tensor types supported in TOSA Pow");
 
-  if (!isa<mlir::FloatType>(selfTy.getElementType()))
-    return rewriter.notifyMatchFailure(
-        op, "Only floating-point datatype legalization supported");
+      if (!isa<mlir::FloatType>(selfTy.getElementType()))
+        return rewriter.notifyMatchFailure(
+            op, "Only floating-point datatype legalization supported");
+    }
 
-  auto outType =
-      cast<TensorType>(getTypeConverter()->convertType(op.getType()));
+    Value expTensor;
+    if constexpr (std::is_same<AtenOpT, AtenPowTensorScalarOp>()) {
+      Value expScalar = op.getExponent();
+      if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
+                                         outType.getElementType(), {})))
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA Pow operation");
+    } else {
+      expTensor = adaptor.getExponent();
+      auto expTy = cast<RankedTensorType>(expTensor.getType());
 
-  Value expTensor;
-  Value expScalar = op.getExponent();
-  if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
-                                     outType.getElementType(), {})))
-    return rewriter.notifyMatchFailure(
-        op, "Currently only scalar constants are supported for "
-            "conversion in TOSA Pow operation");
+      if (!expTy)
+        return rewriter.notifyMatchFailure(
+            op, "Only ranked tensor types supported in TOSA Pow");
+    }
 
-  auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
-                                                        self, expTensor);
-  rewriter.replaceOp(op, powOp.getResult());
+    auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(
+        rewriter, op, outType, selfTensor, expTensor);
+    rewriter.replaceOp(op, powOp.getResult());
 
-  return success();
-}
+    return success();
+  }
+};
 
 template <>
 LogicalResult ConvertAtenOp::matchAndRewrite(
@@ -6721,6 +6711,10 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp)
     INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
     INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
+    INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
+                          tosa::LogicalLeftShiftOp)
+    INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
+                          tosa::ArithmeticRightShiftOp)
 #undef INSERT_BINARY_PATTERN
 
 #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp)                          \
@@ -6744,11 +6738,14 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp)
 #undef INSERT_BINARY_COMPARE_PATTERN
@@ -6889,18 +6886,30 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
 #undef INSERT_MASKED_FILL_PATTERN
 
+#define INSERT_POW_OP_PATTERN(AtenOp)                                         \
+  target.addIllegalOp<AtenOp>();                                              \
+  patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context);
+    INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
+    INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
+    INSERT_POW_OP_PATTERN(AtenPowScalarOp);
+#undef INSERT_POW_OP_PATTERN
+
+#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp)                 \
+  target.addIllegalOp<AtenOp>();                                              \
+  patterns.add<ConvertAtenActivationFunctionOp<AtenOp, TosaOp>>(typeConverter,\
+                                                                context);
+    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
+    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp);
+    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
+#undef INSERT_ACTIVATION_FUNCTION_OP_PATTERN
+
 #define INSERT_ATENOP_PATTERN(AtenOp)                                         \
   target.addIllegalOp<AtenOp>();                                              \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
-    INSERT_ATENOP_PATTERN(AtenTanhOp);
     INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
-    INSERT_ATENOP_PATTERN(AtenSigmoidOp);
     INSERT_ATENOP_PATTERN(AtenReluOp);
     INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
     INSERT_ATENOP_PATTERN(AtenArgmaxOp);
-    INSERT_ATENOP_PATTERN(AtenPowScalarOp);
-    INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
-    INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp);
     INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
     INSERT_ATENOP_PATTERN(AtenConvolutionOp);
     INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index a9b95564e1d0..9736213fb381 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1753,6 +1753,35 @@
     "ElementwiseRemainderTensorModule_Int_basic",
     "TriuBroadcastModule_basic",
     "TriuModule_basic",
+    "AtenHannWindowPeriodicFalseModule_basic",
+    "AtenHannWindowPeriodicTrueModule_basic",
+    "ElementwiseAndScalarModule_basic",
+    "ElementwiseAndScalarStaticShapeModule_basic",
+    "ElementwiseAtenLogicalNotOpModule_basic",
+    "ElementwiseAtenLogicalXorOpModule_basic",
+    "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
+    "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
+    "ElementwiseBitwiseAndScalarInt32Module_basic",
+    "ElementwiseBitwiseAndScalarInt64Module_basic",
+    "ElementwiseBitwiseLeftShiftInt32Module_basic",
+    "ElementwiseBitwiseLeftShiftInt64Module_basic",
+    "ElementwiseBitwiseLeftShiftInt8Module_basic",
+    "ElementwiseBitwiseRightShiftInt32Module_basic",
+    "ElementwiseBitwiseRightShiftInt64Module_basic",
+    "ElementwiseBitwiseRightShiftInt8Module_basic",
+    "ElementwiseCosModule_basic",
+    "ElementwiseErfModule_basic",
+    "ElementwiseLeFloatIntScalarModule_basic",
+    "ElementwiseLeFloatScalarModule_basic",
+    "ElementwiseLeFloatTensorNanModule_basic",
+    "ElementwiseLeIntScalarModule_basic",
+    "ElementwiseLeMixedIntScalarModule_basic",
+    "ElementwisePowScalarModule_basic",
+    "ElementwisePowTensorBroadcastModule_basic",
+    "ElementwisePowTensorBroadcastStaticModule_basic",
+    "ElementwisePowTensorModule_basic",
+    "ElementwisePowTensorStaticModule_basic",
+    "ElementwiseSinModule_basic",
     "ArgminIntModule_basic",
     "ArgminIntModule_multiple_mins",
     "ArgminModule_basic",
@@ -3450,10 +3479,6 @@
     "MultinomialModule_basic",
     "RenormModuleFloat16_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
-    "ScaledDotProductAttentionBoolMaskModule_basic",
-    "ScaledDotProductAttentionDifferentCausalModule_basic",
-    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionSameCausalModule_basic",
     "ScatterAddStaticModule_basic",
     "TensorsConcatComplex128FloatModule_basic",
     "TensorsConcatComplex128IntModule_basic",
@@ -3504,8 +3529,6 @@
     "AtenEyeMModuleInt2D_basic",
     "AtenEyeModuleInt2D_basic",
     "AtenFloatScalarModule_basic",
-    "AtenHannWindowPeriodicTrueModule_basic",
-    "AtenHannWindowPeriodicFalseModule_basic",
     "AtenIntBoolOpConstFalseModule_basic",
     "AtenIntBoolOpConstTrueModule_basic",
     "AtenIntBoolOpModule_basic",
@@ -3623,8 +3646,6 @@
     "ElementwiseAcoshIntModule_basic",
     "ElementwiseAcoshModule_basic",
     "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
-    "ElementwiseAndScalarModule_basic",
-    "ElementwiseAndScalarStaticShapeModule_basic",
     "ElementwiseAsinIntModule_basic",
     "ElementwiseAsinModule_basic",
     "ElementwiseAsinhIntModule_basic",
@@ -3642,44 +3663,23 @@
     "ElementwiseAtenLogicalAndOpModule_basic",
     "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
     "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
-    "ElementwiseAtenLogicalNotOpModule_basic",
     "ElementwiseAtenLogicalNotOpPromoteModule_basic",
-    "ElementwiseAtenLogicalXorOpModule_basic",
-    "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
-    "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
-    "ElementwiseBitwiseAndScalarInt32Module_basic",
-    "ElementwiseBitwiseAndScalarInt64Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
-    "ElementwiseBitwiseLeftShiftInt32Module_basic",
-    "ElementwiseBitwiseLeftShiftInt64Module_basic",
-    "ElementwiseBitwiseLeftShiftInt8Module_basic",
-    "ElementwiseBitwiseRightShiftInt32Module_basic",
-    "ElementwiseBitwiseRightShiftInt64Module_basic",
-    "ElementwiseBitwiseRightShiftInt8Module_basic",
     "ElementwiseClampMinTensorFloatModule_basic",
     "ElementwiseClampMinTensorIntModule_basic",
     "ElementwiseClampTensorFloatModule_basic",
     "ElementwiseClampTensorIntModule_basic",
     "ElementwiseCosIntModule_basic",
-    "ElementwiseCosModule_basic",
     "ElementwiseCoshIntModule_basic",
     "ElementwiseCoshModule_basic",
     "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
     "ElementwiseErfIntModule_basic",
-    "ElementwiseErfModule_basic",
     "ElementwiseExpIntModule_basic",
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
     "ElementwiseGeluApproximateTanhModule_basic",
-    "ElementwiseHardshrinkModule_basic",
-    "ElementwiseHardshrinkStaticModule_basic",
     "ElementwiseIntTensorLtFloatScalarModule_basic",
-    "ElementwiseLeFloatIntScalarModule_basic",
-    "ElementwiseLeFloatScalarModule_basic",
-    "ElementwiseLeFloatTensorNanModule_basic",
-    "ElementwiseLeIntScalarModule_basic",
-    "ElementwiseLeMixedIntScalarModule_basic",
     "ElementwiseLog10IntModule_basic",
     "ElementwiseLog10Module_basic",
     "ElementwiseLog1pModule_basic",
@@ -3690,18 +3690,12 @@
     "ElementwiseMishModule_basic",
     "ElementwiseMulTensorComplexDiffModule_basic",
     "ElementwiseMulTensorComplexModule_basic",
-    "ElementwisePowScalarModule_basic",
-    "ElementwisePowTensorBroadcastModule_basic",
-    "ElementwisePowTensorBroadcastStaticModule_basic",
-    "ElementwisePowTensorModule_basic",
-    "ElementwisePowTensorStaticModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "ElementwiseQuantizePerTensorUIntModule_basic",
     "ElementwiseReciprocalIntModule_basic",
     "ElementwiseRsqrtIntModule_basic",
     "ElementwiseSigmoidIntModule_basic",
     "ElementwiseSinIntModule_basic",
-    "ElementwiseSinModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
     "ElementwiseTanIntModule_basic",
@@ -4418,9 +4412,7 @@
     "ElementwiseAtenLogicalOrOpNegativeModule_basic",
     "ElementwiseAtenLogicalOrOpRandomFloatModule_basic",
     "ElementwiseAtenLogicalOrOpRandomModule_basic",
-    "ElementwiseAtenLogicalXorOpModule_basic",
     "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
-    "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
     "ElementwiseBitwiseAndModule_basic",
     "ElementwiseBitwiseLeftShiftInt32Module_basic",
     "ElementwiseBitwiseLeftShiftInt64Module_basic",
@@ -4439,7 +4431,6 @@
     "ElementwiseClampModule_basic",
     "ElementwiseClampTensorInt8Module_basic",
     "ElementwiseCosIntModule_basic",
-    "ElementwiseCosModule_basic",
     "ElementwiseCoshIntModule_basic",
     "ElementwiseCoshModule_basic",
     "ElementwiseDequantizePerChannelModule_basic",
@@ -4459,7 +4450,6 @@
     "ElementwiseEqBoolScalarModule_basic",
     "ElementwiseEqDiffWidthScalarModule_basic",
     "ElementwiseErfIntModule_basic",
-    "ElementwiseErfModule_basic",
     "ElementwiseExpIntModule_basic",
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
@@ -4471,7 +4461,6 @@
     "ElementwiseGtMixed2ScalarModule_basic",
     "ElementwiseIntTensorLtFloatScalarModule_basic",
     "ElementwiseIsinfModule_basic",
-    "ElementwiseLeFloatTensorNanModule_basic",
     "ElementwiseLeMixedIntScalarModule_basic",
     "ElementwiseLog10IntModule_basic",
     "ElementwiseLog2IntModule_basic",
@@ -4486,12 +4475,6 @@
     "ElementwiseNanToNumModule_Basic",
     "ElementwiseOrTensorModule_basic",
     "ElementwiseOrTensorStaticShapeModule_basic",
-    "ElementwisePowModule_basic",
-    "ElementwisePowScalarModule_basic",
-    "ElementwisePowTensorBroadcastModule_basic",
-    "ElementwisePowTensorBroadcastStaticModule_basic",
-    "ElementwisePowTensorModule_basic",
-    "ElementwisePowTensorStaticModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "ElementwiseQuantizePerTensorUIntModule_basic",
     "ElementwiseReciprocalIntModule_basic",
@@ -4504,7 +4487,6 @@
     "ElementwiseSgnModule_basic",
     "ElementwiseSigmoidIntModule_basic",
     "ElementwiseSinIntModule_basic",
-    "ElementwiseSinModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
     "ElementwiseSqrtIntModule_basic",
@@ -4663,8 +4645,6 @@
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgNormModule_basic",
     "LinalgVectorNormComplexModule_basic",
-    "LinalgVectorNormKeepDimModule_basic",
-    "LinalgVectorNormModule_basic",
     "LogSoftmaxBackwardModule_basic",
     "LogSoftmaxIntModule_basic",
     "MaskedFillTensorFloatValueModule_basic",
@@ -4752,8 +4732,6 @@
     "NativeGroupNormBackwardModule_basic",
     "NativeGroupNormModule_basic",
     "NativeLayerNormDynamicModule_basic",
-    "NativeLayerNormModule4D_basic",
-    "NativeLayerNormModule_basic",
     "NeFloatIntModule_basic",
     "NeIntModule_basic",
     "NewEmptyStridedModuleDefaultDtype_basic",
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index e2de2db7cd7b..dbdb10312b24 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -1914,3 +1914,190 @@ func.func @torch.aten.fmod.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !tor
   %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32>
   return %0 : !torch.vtensor<[2, 4],f32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.logical_not(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1>
+// CHECK:           %[[VAL_2:.*]] = tosa.logical_not %[[VAL_1]] : (tensor<4x5xi1>) -> tensor<4x5xi1>
+// CHECK:           %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<4x5xi1> -> !torch.vtensor<[4,5],i1>
+// CHECK:           return %[[VAL_3]] : !torch.vtensor<[4,5],i1>
+// CHECK:         }
+func.func @torch.aten.logical_not(%arg0: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
+  %0 = torch.aten.logical_not %arg0 : !torch.vtensor<[4,5],i1> -> !torch.vtensor<[4,5],i1>
+  return %0 : !torch.vtensor<[4,5],i1>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.cos(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
+// CHECK:           %[[VAL_2:.*]] = tosa.cos %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK:           return %[[VAL_3]] : !torch.vtensor<[3,4],f32>
+// CHECK:         }
+func.func @torch.aten.cos(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+  %0 = torch.aten.cos %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.sin(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
+// CHECK:           %[[VAL_2:.*]] = tosa.sin %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK:           return %[[VAL_3]] : !torch.vtensor<[3,4],f32>
+// CHECK:         }
+func.func @torch.aten.sin(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+  %0 = torch.aten.sin %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.pow.Scalar(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.float 2.000000e+00
+// CHECK:           %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK:           %[[VAL_4:.*]] = tosa.pow %[[VAL_3]], %[[VAL_1]] : (tensor<f32>, tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
+// CHECK:         }
+func.func @torch.aten.pow.Scalar(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+  %float2.000000e00 = torch.constant.float 2.000000e+00
+  %0 = torch.aten.pow.Scalar %float2.000000e00, %arg0 : !torch.float, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.pow.Tensor_Tensor$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = tosa.pow %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+// CHECK:         }
+func.func @torch.aten.pow.Tensor_Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.erf$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_2:.*]] = tosa.erf %[[VAL_1]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:           return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+// CHECK:         }
+func.func @torch.aten.erf$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.erf %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.bitwise_and.Scalar$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.int 2
+// CHECK:           %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:           %[[VAL_4:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_3]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
+// CHECK:         }
+func.func @torch.aten.bitwise_and.Scalar$basic(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.bitwise_and.Scalar %arg0, %int2 : !torch.vtensor<[?,?],si32>, !torch.int -> !torch.vtensor<[?,?],si32>
+  return %0 : !torch.vtensor<[?,?],si32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.le.Tensor$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK:         }
+func.func @torch.aten.le.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.le.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.le.Scalar$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.int 2
+// CHECK:           %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK:           %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_3]], %[[VAL_1]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK:         }
+func.func @torch.aten.le.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.le.Scalar %arg0, %int2 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.logical_xor$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK:           %[[VAL_4:.*]] = tosa.logical_xor %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK:         }
+func.func @torch.aten.logical_xor$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.bitwise_left_shift.Tensor$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:           %[[VAL_4:.*]] = tosa.logical_left_shift %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
+// CHECK:         }
+func.func @torch.aten.bitwise_left_shift.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+  %0 = torch.aten.bitwise_left_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
+  return %0: !torch.vtensor<[?,?],si32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.bitwise_right_shift.Tensor$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:           %[[VAL_4:.*]] = tosa.arithmetic_right_shift %[[VAL_3]], %[[VAL_2]] {round = false} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
+// CHECK:         }
+func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+  %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
+  return %0: !torch.vtensor<[?,?],si32>
+}