diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp index 1278b268f85519..e2d929e9fa0e92 100644 --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -201,6 +201,7 @@ void mlir::populateMathToLibmConversionPatterns( populatePatternsForOp(patterns, ctx, "sinf", "sin"); populatePatternsForOp(patterns, ctx, "sinhf", "sinh"); populatePatternsForOp(patterns, ctx, "sqrtf", "sqrt"); + populatePatternsForOp(patterns, ctx, "rsqrtf", "rsqrt"); populatePatternsForOp(patterns, ctx, "tanf", "tan"); populatePatternsForOp(patterns, ctx, "tanhf", "tanh"); populatePatternsForOp(patterns, ctx, "truncf", "trunc"); diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp index 3d99f3033cf560..2e60fe455dcade 100644 --- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp +++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp @@ -65,10 +65,11 @@ void mlir::math::populateLegalizeToF32TypeConverter( void mlir::math::populateLegalizeToF32ConversionTarget( ConversionTarget &target, TypeConverter &typeConverter) { - target.addDynamicallyLegalDialect( - [&typeConverter](Operation *op) -> bool { - return typeConverter.isLegal(op); - }); + target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool { + if (isa(op->getDialect())) + return typeConverter.isLegal(op); + return true; + }); target.addLegalOp(); target.addLegalOp(); } diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir index 9d115c73b53d3b..fd5d8c322bde4f 100644 --- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir +++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir @@ -1,9 +1,18 @@ + // RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s // CHECK-DAG: @acos(f64) -> f64 attributes {libm, llvm.readnone} // CHECK-DAG: @acosf(f32) -> f32 attributes {libm, llvm.readnone} +// CHECK-DAG: @acosh(f64) -> f64 attributes {libm, llvm.readnone} +// CHECK-DAG: @acoshf(f32) -> f32 attributes {libm, llvm.readnone} +// CHECK-DAG: @asin(f64) -> f64 attributes {libm, llvm.readnone} +// CHECK-DAG: @asinf(f32) -> f32 attributes {libm, llvm.readnone} +// CHECK-DAG: @asinh(f64) -> f64 attributes {libm, llvm.readnone} +// CHECK-DAG: @asinhf(f32) -> f32 attributes {libm, llvm.readnone} // CHECK-DAG: @atan(f64) -> f64 attributes {libm, llvm.readnone} // CHECK-DAG: @atanf(f32) -> f32 attributes {libm, llvm.readnone} +// CHECK-DAG: @atanh(f64) -> f64 attributes {libm, llvm.readnone} +// CHECK-DAG: @atanhf(f32) -> f32 attributes {libm, llvm.readnone} // CHECK-DAG: @erf(f64) -> f64 attributes {libm, llvm.readnone} // CHECK-DAG: @erff(f32) -> f32 attributes {libm, llvm.readnone} // CHECK-DAG: @exp(f64) -> f64 attributes {libm, llvm.readnone} @@ -50,6 +59,8 @@ // CHECK-DAG: @ceilf(f32) -> f32 attributes {libm, llvm.readnone} // CHECK-DAG: @sqrt(f64) -> f64 attributes {libm, llvm.readnone} // CHECK-DAG: @sqrtf(f32) -> f32 attributes {libm, llvm.readnone} +// CHECK-DAG: @rsqrt(f64) -> f64 attributes {libm, llvm.readnone} +// CHECK-DAG: @rsqrtf(f32) -> f32 attributes {libm, llvm.readnone} // CHECK-DAG: @pow(f64, f64) -> f64 attributes {libm, llvm.readnone} // CHECK-DAG: @powf(f32, f32) -> f32 attributes {libm, llvm.readnone} @@ -991,6 +1002,43 @@ func.func @sqrt_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (ve // CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64> // CHECK: } +// CHECK-LABEL: func @rsqrt_caller +// CHECK-SAME: %[[FLOAT:.*]]: f32 +// CHECK-SAME: %[[DOUBLE:.*]]: f64 +func.func @rsqrt_caller(%float: f32, %double: f64) -> (f32, f64) { + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @rsqrtf(%[[FLOAT]]) : (f32) -> f32 + %float_result = math.rsqrt %float : f32 + // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @rsqrt(%[[DOUBLE]]) : (f64) -> f64 + %double_result = math.rsqrt %double : f64 + // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] + return %float_result, %double_result : f32, f64 +} + +func.func @rsqrt_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) { + %float_result = math.rsqrt %float : vector<2xf32> + %double_result = math.rsqrt %double : vector<2xf64> + return %float_result, %double_result : vector<2xf32>, vector<2xf64> +} +// CHECK-LABEL: func @rsqrt_vec_caller( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) { +// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> +// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64> +// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32> +// CHECK: %[[OUT0_F32:.*]] = call @rsqrtf(%[[IN0_F32]]) : (f32) -> f32 +// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32> +// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32> +// CHECK: %[[OUT1_F32:.*]] = call @rsqrtf(%[[IN1_F32]]) : (f32) -> f32 +// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32> +// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64> +// CHECK: %[[OUT0_F64:.*]] = call @rsqrt(%[[IN0_F64]]) : (f64) -> f64 +// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64> +// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64> +// CHECK: %[[OUT1_F64:.*]] = call @rsqrt(%[[IN1_F64]]) : (f64) -> f64 +// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64> +// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64> +// CHECK: } + // CHECK-LABEL: func @powf_caller( // CHECK-SAME: %[[FLOATA:.*]]: f32, %[[FLOATB:.*]]: f32 // CHECK-SAME: %[[DOUBLEA:.*]]: f64, %[[DOUBLEB:.*]]: f64 diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir index ae6ae7c5bc4b44..ebb0de9d2653e2 100644 --- a/mlir/test/Dialect/Math/legalize-to-f32.mlir +++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir @@ -83,3 +83,17 @@ func.func @sequences(%arg0: f16) -> f16 { %1 = math.sin %0 : f16 return %1 : f16 } + +// CHECK-LABEL: @promote_in_if_block +func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 { + // CHECK: [[EXTF0:%.+]] = arith.extf + // CHECK-NEXT: %[[RES:.*]] = scf.if + %0 = scf.if %arg2 -> bf16 { + %1 = math.absf %arg0 : bf16 + // CHECK: [[TRUNCF0:%.+]] = arith.truncf + scf.yield %1 : bf16 + } else { + scf.yield %arg1 : bf16 + } + return %0 : bf16 +}