Skip to content

Commit

Permalink
[mlir][emitc] arith.cmpf to EmitC conversion (llvm#93671)
Browse files Browse the repository at this point in the history
Convert all arith.cmpf on floats (not vectors/tensors thereof) to EmitC.

---------

Co-authored-by: Matthias Gehre <matthias.gehre@amd.com>
Co-authored-by: Jose Lopes <jose.lopes@amd.com>
  • Loading branch information
3 people authored Jun 4, 2024
1 parent 4ab7354 commit 46672c1
Show file tree
Hide file tree
Showing 3 changed files with 402 additions and 1 deletion.
159 changes: 158 additions & 1 deletion mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Tools/PDLL/AST/Types.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
Expand Down Expand Up @@ -59,6 +61,160 @@ Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
}

class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (!isa<FloatType>(adaptor.getRhs().getType())) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cmpf currently only supported on "
"floats, not tensors/vectors thereof");
}

bool unordered = false;
emitc::CmpPredicate predicate;
switch (op.getPredicate()) {
case arith::CmpFPredicate::AlwaysFalse: {
auto constant = rewriter.create<emitc::ConstantOp>(
op.getLoc(), rewriter.getI1Type(),
rewriter.getBoolAttr(/*value=*/false));
rewriter.replaceOp(op, constant);
return success();
}
case arith::CmpFPredicate::OEQ:
unordered = false;
predicate = emitc::CmpPredicate::eq;
break;
case arith::CmpFPredicate::OGT:
unordered = false;
predicate = emitc::CmpPredicate::gt;
break;
case arith::CmpFPredicate::OGE:
unordered = false;
predicate = emitc::CmpPredicate::ge;
break;
case arith::CmpFPredicate::OLT:
unordered = false;
predicate = emitc::CmpPredicate::lt;
break;
case arith::CmpFPredicate::OLE:
unordered = false;
predicate = emitc::CmpPredicate::le;
break;
case arith::CmpFPredicate::ONE:
unordered = false;
predicate = emitc::CmpPredicate::ne;
break;
case arith::CmpFPredicate::ORD: {
// ordered, i.e. none of the operands is NaN
auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
adaptor.getRhs());
rewriter.replaceOp(op, cmp);
return success();
}
case arith::CmpFPredicate::UEQ:
unordered = true;
predicate = emitc::CmpPredicate::eq;
break;
case arith::CmpFPredicate::UGT:
unordered = true;
predicate = emitc::CmpPredicate::gt;
break;
case arith::CmpFPredicate::UGE:
unordered = true;
predicate = emitc::CmpPredicate::ge;
break;
case arith::CmpFPredicate::ULT:
unordered = true;
predicate = emitc::CmpPredicate::lt;
break;
case arith::CmpFPredicate::ULE:
unordered = true;
predicate = emitc::CmpPredicate::le;
break;
case arith::CmpFPredicate::UNE:
unordered = true;
predicate = emitc::CmpPredicate::ne;
break;
case arith::CmpFPredicate::UNO: {
// unordered, i.e. either operand is nan
auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
adaptor.getRhs());
rewriter.replaceOp(op, cmp);
return success();
}
case arith::CmpFPredicate::AlwaysTrue: {
auto constant = rewriter.create<emitc::ConstantOp>(
op.getLoc(), rewriter.getI1Type(),
rewriter.getBoolAttr(/*value=*/true));
rewriter.replaceOp(op, constant);
return success();
}
}

// Compare the values naively
auto cmpResult =
rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
adaptor.getLhs(), adaptor.getRhs());

// Adjust the results for unordered/ordered semantics
if (unordered) {
auto isUnordered = createCheckIsUnordered(
rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
isUnordered, cmpResult);
return success();
}

auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
isOrdered, cmpResult);
return success();
}

private:
/// Return a value that is true if \p operand is NaN.
Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
Value operand) const {
// A value is NaN exactly when it compares unequal to itself.
return rewriter.create<emitc::CmpOp>(
loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand);
}

/// Return a value that is true if \p operand is not NaN.
Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
Value operand) const {
// A value is not NaN exactly when it compares equal to itself.
return rewriter.create<emitc::CmpOp>(
loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand);
}

/// Return a value that is true if the operands \p first and \p second are
/// unordered (i.e., at least one of them is NaN).
Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
Location loc, Value first, Value second) const {
auto firstIsNaN = isNaN(rewriter, loc, first);
auto secondIsNaN = isNaN(rewriter, loc, second);
return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(),
firstIsNaN, secondIsNaN);
}

/// Return a value that is true if the operands \p first and \p second are
/// both ordered (i.e., none one of them is NaN).
Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
Value first, Value second) const {
auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(),
firstIsNotNaN, secondIsNotNaN);
}
};

class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -463,6 +619,7 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
CmpFOpConversion,
CmpIOpConversion,
SelectOpConversion,
// Truncation is guaranteed for unsigned types.
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,22 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {

// -----

func.func @arith_cmpf_vector(%arg0: vector<5xf32>, %arg1: vector<5xf32>) -> vector<5xi1> {
// expected-error @+1 {{failed to legalize operation 'arith.cmpf'}}
%t = arith.cmpf uno, %arg0, %arg1 : vector<5xf32>
return %t: vector<5xi1>
}

// -----

func.func @arith_cmpf_tensor(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<5xi1> {
// expected-error @+1 {{failed to legalize operation 'arith.cmpf'}}
%t = arith.cmpf uno, %arg0, %arg1 : tensor<5xf32>
return %t: tensor<5xi1>
}

// -----

func.func @arith_extsi_i1_to_i32(%arg0: i1) {
// expected-error @+1 {{failed to legalize operation 'arith.extsi'}}
%idx = arith.extsi %arg0 : i1 to i32
Expand Down
Loading

0 comments on commit 46672c1

Please sign in to comment.