diff --git a/compiler/plugins/input/StableHLO/Conversion/TypeConversion.cpp b/compiler/plugins/input/StableHLO/Conversion/TypeConversion.cpp index 6b2ef75fc45e..947fcba2886b 100644 --- a/compiler/plugins/input/StableHLO/Conversion/TypeConversion.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/TypeConversion.cpp @@ -56,17 +56,20 @@ std::optional materializeCastToIllegal(OpBuilder &builder, Type type, ->getResult(0); } -std::optional scalarToTensor(OpBuilder &builder, Type /*type*/, +std::optional scalarToTensor(OpBuilder &builder, Type type, ValueRange inputs, Location loc) { assert(inputs.size() == 1); if (llvm::isa(inputs.front().getType())) { return std::nullopt; } - return builder - .create( - loc, RankedTensorType::get({}, inputs.front().getType()), - inputs.front()) - .getResult(); + auto tensor = + builder + .create( + loc, RankedTensorType::get({}, inputs.front().getType()), + inputs.front()) + .getResult(); + return builder.create(loc, type, tensor) + .getResult(0); } } // namespace @@ -77,7 +80,7 @@ RemoveSignTypeConverter::RemoveSignTypeConverter() { addConversion(convertInteger); addConversion(convertShapedType); - addArgumentMaterialization(materializeCastFromIllegal); + addArgumentMaterialization(materializeCastToIllegal); addSourceMaterialization(materializeCastToIllegal); addTargetMaterialization(materializeCastFromIllegal); } diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp index 13958d5fab58..cb8a32bb80fa 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp @@ -164,19 +164,22 @@ struct GenericTypeConversionPattern : public ConversionPattern { for (Region &r : op->getRegions()) { Region *newRegion = state.addRegion(); rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin()); - TypeConverter::SignatureConversion result(newRegion->getNumArguments()); + } + Operation *newOp = rewriter.create(state); + + for (Region &newRegion : newOp->getRegions()) { + TypeConverter::SignatureConversion result(newRegion.getNumArguments()); if (failed(getTypeConverter()->convertSignatureArgs( - newRegion->getArgumentTypes(), result))) { + newRegion.getArgumentTypes(), result))) { return rewriter.notifyMatchFailure(op, "argument type conversion failed"); } - rewriter.applySignatureConversion(&newRegion->front(), result, + rewriter.applySignatureConversion(&newRegion.front(), result, typeConverter); } - Operation *newOp = rewriter.create(state); rewriter.replaceOp(op, newOp->getResults()); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index 0942148c0e7c..060aeb897620 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp @@ -413,11 +413,11 @@ struct ConvertAllGatherOp consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); rewriter.replaceOpWithNewOp( - op, collectiveAttr, adaptor.getTarget(), + op, collectiveAttr, newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, adaptor.getSource(), + /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, /*source_length=*/newSourceCast.resourceSize, elementCount, @@ -448,11 +448,11 @@ struct ConvertAllReduceOp consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); rewriter.replaceOpWithNewOp( - op, collectiveAttr, adaptor.getTarget(), + op, collectiveAttr, newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, adaptor.getSource(), + /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, /*source_length=*/newSourceCast.resourceSize, elementCount, @@ -483,11 +483,11 @@ struct ConvertAllToAllOp consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); rewriter.replaceOpWithNewOp( - op, collectiveAttr, adaptor.getTarget(), + op, collectiveAttr, newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, adaptor.getSource(), + /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, /*source_length=*/newSourceCast.resourceSize, elementCount, @@ -518,11 +518,11 @@ struct ConvertReduceScatterOp consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); rewriter.replaceOpWithNewOp( - op, collectiveAttr, adaptor.getTarget(), + op, collectiveAttr, newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, adaptor.getSource(), + /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, /*source_length=*/newSourceCast.resourceSize, elementCount, @@ -567,11 +567,11 @@ struct ConvertCollectiveSendRecvOp auto param = rewriter.create(op.getLoc(), hi, lo); rewriter.replaceOpWithNewOp( - op, collectiveAttr, adaptor.getTarget(), + op, collectiveAttr, newTargetCast.resource, /*target_size=*/newTargetCast.resourceSize, /*target_offset=*/zeroOffset, /*target_end=*/newTargetCast.resourceSize, - /*target_length=*/newTargetCast.resourceSize, adaptor.getSource(), + /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource, /*source_size=*/newSourceCast.resourceSize, /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize, /*source_length=*/newSourceCast.resourceSize, elementCount, diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/compiler_hints.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/compiler_hints.mlir index f12a2ad45c12..c778fbf1e502 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/compiler_hints.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/compiler_hints.mlir @@ -2,9 +2,9 @@ // CHECK-LABEL: @optimizationBarrier util.func public @optimizationBarrier(%arg0: tensor) -> tensor { - // CHECK: stream.async.transfer - // CHECK: %[[RESOURCE:.*]] = util.optimization_barrier %0 - // CHECK: %[[SIZE:.*]] = stream.resource.size %1 : !stream.resource<*> + // CHECK-SAME: %[[ARG0:.+]]: !stream.resource<*> + // CHECK: %[[RESOURCE:.*]] = util.optimization_barrier %[[ARG0]] + // CHECK: %[[SIZE:.*]] = stream.resource.size %[[RESOURCE]] : !stream.resource<*> // CHECK: util.return %[[RESOURCE]], %[[SIZE]] : !stream.resource<*>, index %0 = util.optimization_barrier %arg0 : tensor util.return %0 : tensor diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp index 46c09fafad31..11873a2c73d7 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp @@ -181,6 +181,31 @@ struct GenericResourcePattern : public ConversionPattern { } }; +namespace { +struct OptimizationBarrierOpConversion + : public OpConversionPattern { + using OpConversionPattern< + IREE::Util::OptimizationBarrierOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(IREE::Util::OptimizationBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector newOperands; + for (Value v : adaptor.getOperands()) { + if (isa(v.getType())) { + newOperands.push_back( + consumeTensorOperand(op.getLoc(), v, rewriter).resource); + } else { + newOperands.push_back(v); + } + } + rewriter.replaceOpWithNewOp(op, + newOperands); + return success(); + } +}; +} // namespace + //===----------------------------------------------------------------------===// // --iree-stream-conversion //===----------------------------------------------------------------------===// @@ -228,13 +253,16 @@ struct ConvertToStreamPass final auto resourceSize = inputs[1]; assert(inputs.size() == 2 && "expecting 2 operands (resource + size)"); + Value cast = builder + .create( + loc, resourceValue.getType(), resourceValue, + resourceSize, resourceSize, + /*source_affinity=*/nullptr, + /*result_affinity=*/nullptr) + .getResult(); return builder - .create( - loc, resourceValue.getType(), resourceValue, resourceSize, - resourceSize, - /*source_affinity=*/nullptr, - /*result_affinity=*/nullptr) - .getResult(); + .create(loc, resultType, cast) + .getResult(0); }); populateUtilConversionPatterns(context, conversionTarget, typeConverter, @@ -252,6 +280,8 @@ struct ConvertToStreamPass final conversionTarget.markUnknownOpDynamicallyLegal( [&](Operation *op) -> bool { return !doesOperationNeedWrapping(op); }); patterns.insert(context, typeConverter); + patterns.insert(typeConverter, context, + /*benefit=*/2); // NOTE: we allow ops that we don't know about to allow custom dialects // that don't need anything Stream-specific to pass through. diff --git a/third_party/llvm-project b/third_party/llvm-project index 1f11b9fed233..9372a3b70cf3 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 1f11b9fed2337ea24d137ff82fec75bddcd85b3c +Subproject commit 9372a3b70cf3969dac2d1a14cf41358205944e60