From 992dad34ac36e7f8c32e6ebbc9612264678c9aee Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Tue, 17 Dec 2024 09:39:38 +0100 Subject: [PATCH] EmitC: Allow casts between opaque and float types (#428) * EmitC: Allow casts between opaque and float types * Use EmitC cast compatibility check * Allow opaque types in casts (also for array types) --- mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 4 ++-- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 4 ++++ mlir/test/Dialect/EmitC/ops.mlir | 4 ++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 938bc73d439969..6a84ead33d5c2b 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -757,7 +757,7 @@ class TruncFConversion : public OpConversionPattern { return rewriter.notifyMatchFailure(castOp, "unsupported cast destination type"); - if (!castOp.areCastCompatible(operandType, dstType)) + if (!emitc::CastOp::areCastCompatible(operandType, dstType)) return rewriter.notifyMatchFailure(castOp, "cast-incompatible types"); rewriter.replaceOpWithNewOp(castOp, dstType, @@ -787,7 +787,7 @@ class ExtFConversion : public OpConversionPattern { return rewriter.notifyMatchFailure(castOp, "unsupported cast destination type"); - if (!castOp.areCastCompatible(operandType, dstType)) + if (!emitc::CastOp::areCastCompatible(operandType, dstType)) return rewriter.notifyMatchFailure(castOp, "cast-incompatible types"); rewriter.replaceOpWithNewOp(castOp, dstType, diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 8ed1d609b91818..66421c2f6fff66 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -313,6 +313,10 @@ LogicalResult emitc::AssignOp::verify() { bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { Type input = inputs.front(), output = outputs.front(); + // Opaque types are always allowed + if (isa(input) || isa(output)) + return true; + // Cast to array is only possible from an array if (isa(input) != isa(output)) return false; diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index ad70ea61cb2958..80a33b2b9621fe 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -36,11 +36,15 @@ emitc.func private @extern(i32) attributes {specifiers = ["extern"]} func.func @cast(%arg0: i32) { %1 = emitc.cast %arg0: i32 to f32 + %2 = emitc.cast %1: f32 to !emitc.opaque<"some type"> + %3 = emitc.cast %2: !emitc.opaque<"some type"> to !emitc.size_t return } func.func @cast_array(%arg : !emitc.array<4xf32>) { %1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32> ref + %2 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.opaque<"some type"> + %3 = emitc.cast %2: !emitc.opaque<"some type"> to !emitc.array<4xf32> ref return }