Integrate llvm-project @9372a3b70cf3969dac2d1a14cf41358205944e60 (iree-org#17926)

Bumps llvm-project to
https://github.com/llvm/llvm-project/commits/266a5a9cb9daa96c1eeaebc18e10f5a37d638734

Still carrying revert:
iree-org/llvm-project@9372a3b

llvm/llvm-project#97903 updated type conversion argument materialization,
so this PR includes minor bug fixes in the Codegen and Stream conversions
to account for that change.
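
For context, a minimal illustrative sketch (not taken from this commit; the helper
names castToRequestedType and addExampleMaterializations are assumptions) of the
pattern the fixes below follow: after llvm/llvm-project#97903, a materialization
callback is expected to yield a value of exactly the type it is asked for, so a
conversion that lands on a different type is wrapped in a
builtin.unrealized_conversion_cast back to the requested type, which is the same
shape as the scalarToTensor and async-transfer fixes in this diff.

// Illustrative sketch only; mirrors the pattern used by the converter fixes
// in this PR, with assumed helper names.
#include <cassert>
#include <optional>

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static std::optional<Value> castToRequestedType(OpBuilder &builder, Type type,
                                                ValueRange inputs,
                                                Location loc) {
  assert(inputs.size() == 1 && "expected a single value to materialize");
  Value value = inputs.front();
  // Already the requested type; nothing to do.
  if (value.getType() == type)
    return value;
  // Otherwise insert a placeholder cast so the materialized value has
  // exactly the type the framework requested.
  return builder.create<UnrealizedConversionCastOp>(loc, type, value)
      .getResult(0);
}

static void addExampleMaterializations(TypeConverter &converter) {
  // Argument and source materializations must produce the requested
  // (original) type exactly, as in the converter fixes in this diff.
  converter.addArgumentMaterialization(castToRequestedType);
  converter.addSourceMaterialization(castToRequestedType);
}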

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
Co-authored-by: Matthias Springer <mspringer@nvidia.com>
Max191 and matthias-springer authored Jul 17, 2024
1 parent 7ce8c8e commit 37a3db2
Showing 6 changed files with 67 additions and 31 deletions.
17 changes: 10 additions & 7 deletions compiler/plugins/input/StableHLO/Conversion/TypeConversion.cpp
@@ -56,17 +56,20 @@ std::optional<Value> materializeCastToIllegal(OpBuilder &builder, Type type,
       ->getResult(0);
 }
 
-std::optional<Value> scalarToTensor(OpBuilder &builder, Type /*type*/,
+std::optional<Value> scalarToTensor(OpBuilder &builder, Type type,
                                     ValueRange inputs, Location loc) {
   assert(inputs.size() == 1);
   if (llvm::isa<ShapedType>(inputs.front().getType())) {
     return std::nullopt;
   }
-  return builder
-      .create<tensor::FromElementsOp>(
-          loc, RankedTensorType::get({}, inputs.front().getType()),
-          inputs.front())
-      .getResult();
+  auto tensor =
+      builder
+          .create<tensor::FromElementsOp>(
+              loc, RankedTensorType::get({}, inputs.front().getType()),
+              inputs.front())
+          .getResult();
+  return builder.create<UnrealizedConversionCastOp>(loc, type, tensor)
+      .getResult(0);
 }
 
 } // namespace
@@ -77,7 +80,7 @@ RemoveSignTypeConverter::RemoveSignTypeConverter() {
   addConversion(convertInteger);
   addConversion(convertShapedType);
 
-  addArgumentMaterialization(materializeCastFromIllegal);
+  addArgumentMaterialization(materializeCastToIllegal);
   addSourceMaterialization(materializeCastToIllegal);
   addTargetMaterialization(materializeCastFromIllegal);
 }
@@ -164,19 +164,22 @@ struct GenericTypeConversionPattern : public ConversionPattern {
     for (Region &r : op->getRegions()) {
       Region *newRegion = state.addRegion();
       rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
-      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+    }
+    Operation *newOp = rewriter.create(state);
+
+    for (Region &newRegion : newOp->getRegions()) {
+      TypeConverter::SignatureConversion result(newRegion.getNumArguments());
 
       if (failed(getTypeConverter()->convertSignatureArgs(
-              newRegion->getArgumentTypes(), result))) {
+              newRegion.getArgumentTypes(), result))) {
         return rewriter.notifyMatchFailure(op,
                                            "argument type conversion failed");
       }
 
-      rewriter.applySignatureConversion(&newRegion->front(), result,
+      rewriter.applySignatureConversion(&newRegion.front(), result,
                                         typeConverter);
     }
 
-    Operation *newOp = rewriter.create(state);
     rewriter.replaceOp(op, newOp->getResults());
     return success();
   }
@@ -413,11 +413,11 @@ struct ConvertAllGatherOp
         consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
 
     rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
-        op, collectiveAttr, adaptor.getTarget(),
+        op, collectiveAttr, newTargetCast.resource,
        /*target_size=*/newTargetCast.resourceSize,
        /*target_offset=*/zeroOffset,
        /*target_end=*/newTargetCast.resourceSize,
-        /*target_length=*/newTargetCast.resourceSize, adaptor.getSource(),
+        /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
        /*source_size=*/newSourceCast.resourceSize,
        /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
        /*source_length=*/newSourceCast.resourceSize, elementCount,
@@ -448,11 +448,11 @@ struct ConvertAllReduceOp
         consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
 
     rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
-        op, collectiveAttr, adaptor.getTarget(),
+        op, collectiveAttr, newTargetCast.resource,
        /*target_size=*/newTargetCast.resourceSize,
        /*target_offset=*/zeroOffset,
        /*target_end=*/newTargetCast.resourceSize,
-        /*target_length=*/newTargetCast.resourceSize, adaptor.getSource(),
+        /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
        /*source_size=*/newSourceCast.resourceSize,
        /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
        /*source_length=*/newSourceCast.resourceSize, elementCount,
@@ -483,11 +483,11 @@ struct ConvertAllToAllOp
         consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
 
     rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
-        op, collectiveAttr, adaptor.getTarget(),
+        op, collectiveAttr, newTargetCast.resource,
        /*target_size=*/newTargetCast.resourceSize,
        /*target_offset=*/zeroOffset,
        /*target_end=*/newTargetCast.resourceSize,
-        /*target_length=*/newTargetCast.resourceSize, adaptor.getSource(),
+        /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
        /*source_size=*/newSourceCast.resourceSize,
        /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
        /*source_length=*/newSourceCast.resourceSize, elementCount,
@@ -518,11 +518,11 @@ struct ConvertReduceScatterOp
         consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter);
 
     rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
-        op, collectiveAttr, adaptor.getTarget(),
+        op, collectiveAttr, newTargetCast.resource,
        /*target_size=*/newTargetCast.resourceSize,
        /*target_offset=*/zeroOffset,
        /*target_end=*/newTargetCast.resourceSize,
-        /*target_length=*/newTargetCast.resourceSize, adaptor.getSource(),
+        /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
        /*source_size=*/newSourceCast.resourceSize,
        /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
        /*source_length=*/newSourceCast.resourceSize, elementCount,
@@ -567,11 +567,11 @@ struct ConvertCollectiveSendRecvOp
     auto param = rewriter.create<arith::OrIOp>(op.getLoc(), hi, lo);
 
     rewriter.replaceOpWithNewOp<IREE::Stream::AsyncCollectiveOp>(
-        op, collectiveAttr, adaptor.getTarget(),
+        op, collectiveAttr, newTargetCast.resource,
        /*target_size=*/newTargetCast.resourceSize,
        /*target_offset=*/zeroOffset,
        /*target_end=*/newTargetCast.resourceSize,
-        /*target_length=*/newTargetCast.resourceSize, adaptor.getSource(),
+        /*target_length=*/newTargetCast.resourceSize, newSourceCast.resource,
        /*source_size=*/newSourceCast.resourceSize,
        /*source_offset=*/zeroOffset, /*source_end=*/newSourceCast.resourceSize,
        /*source_length=*/newSourceCast.resourceSize, elementCount,
@@ -2,9 +2,9 @@
 
 // CHECK-LABEL: @optimizationBarrier
 util.func public @optimizationBarrier(%arg0: tensor<i32>) -> tensor<i32> {
-  // CHECK: stream.async.transfer
-  // CHECK: %[[RESOURCE:.*]] = util.optimization_barrier %0
-  // CHECK: %[[SIZE:.*]] = stream.resource.size %1 : !stream.resource<*>
+  // CHECK-SAME: %[[ARG0:.+]]: !stream.resource<*>
+  // CHECK: %[[RESOURCE:.*]] = util.optimization_barrier %[[ARG0]]
+  // CHECK: %[[SIZE:.*]] = stream.resource.size %[[RESOURCE]] : !stream.resource<*>
   // CHECK: util.return %[[RESOURCE]], %[[SIZE]] : !stream.resource<*>, index
   %0 = util.optimization_barrier %arg0 : tensor<i32>
   util.return %0 : tensor<i32>
@@ -181,6 +181,31 @@ struct GenericResourcePattern : public ConversionPattern {
   }
 };
 
+namespace {
+struct OptimizationBarrierOpConversion
+    : public OpConversionPattern<IREE::Util::OptimizationBarrierOp> {
+  using OpConversionPattern<
+      IREE::Util::OptimizationBarrierOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(IREE::Util::OptimizationBarrierOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> newOperands;
+    for (Value v : adaptor.getOperands()) {
+      if (isa<TensorType>(v.getType())) {
+        newOperands.push_back(
+            consumeTensorOperand(op.getLoc(), v, rewriter).resource);
+      } else {
+        newOperands.push_back(v);
+      }
+    }
+    rewriter.replaceOpWithNewOp<IREE::Util::OptimizationBarrierOp>(op,
+                                                                   newOperands);
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // --iree-stream-conversion
 //===----------------------------------------------------------------------===//
@@ -228,13 +253,16 @@ struct ConvertToStreamPass final
           auto resourceSize = inputs[1];
           assert(inputs.size() == 2 &&
                  "expecting 2 operands (resource + size)");
+          Value cast = builder
+                           .create<IREE::Stream::AsyncTransferOp>(
+                               loc, resourceValue.getType(), resourceValue,
+                               resourceSize, resourceSize,
+                               /*source_affinity=*/nullptr,
+                               /*result_affinity=*/nullptr)
+                           .getResult();
           return builder
-              .create<IREE::Stream::AsyncTransferOp>(
-                  loc, resourceValue.getType(), resourceValue, resourceSize,
-                  resourceSize,
-                  /*source_affinity=*/nullptr,
-                  /*result_affinity=*/nullptr)
-              .getResult();
+              .create<UnrealizedConversionCastOp>(loc, resultType, cast)
+              .getResult(0);
         });
 
     populateUtilConversionPatterns(context, conversionTarget, typeConverter,
@@ -252,6 +280,8 @@ struct ConvertToStreamPass final
     conversionTarget.markUnknownOpDynamicallyLegal(
         [&](Operation *op) -> bool { return !doesOperationNeedWrapping(op); });
     patterns.insert<GenericResourcePattern>(context, typeConverter);
+    patterns.insert<OptimizationBarrierOpConversion>(typeConverter, context,
+                                                     /*benefit=*/2);
 
     // NOTE: we allow ops that we don't know about to allow custom dialects
     // that don't need anything Stream-specific to pass through.
2 changes: 1 addition & 1 deletion third_party/llvm-project
Submodule llvm-project updated 106 files
