Skip to content

Commit

Permalink
Bump LLVM (#907)
Browse files Browse the repository at this point in the history
Also retires more method-style `cast` variants and updates the OpenMP tests to
use the updated loop constructs.
Fixes the LLVM build check after the refactor.
  • Loading branch information
adam-smnk authored Apr 25, 2024
1 parent ff22c3b commit d9fd677
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 40 deletions.
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
4e0b8eae4cb4328f98e6b748c31050a704d378f6
fe47e8ff3ae7fc8975eaade6bfa6679737c28b93
4 changes: 2 additions & 2 deletions include/TPP/IR/StructuredOpMatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ struct HasStaticShape {

bool operator()(OpOperand *operand, Operation *op) const {
auto operandType = operand->get().getType();
if (auto shapedType = operandType.dyn_cast_or_null<ShapedType>()) {
if (auto shapedType = dyn_cast_or_null<ShapedType>(operandType)) {
if (!shapedType.hasStaticShape())
return false;
if (shape) {
Expand All @@ -188,7 +188,7 @@ struct HasStaticStrides {
bool operator()(OpOperand *operand, Operation *op) const {
auto operandType = operand->get().getType();
SmallVector<int64_t> strides;
if (auto memRefType = operandType.dyn_cast_or_null<MemRefType>()) {
if (auto memRefType = dyn_cast_or_null<MemRefType>(operandType)) {
int64_t offset;
if (failed(getStridesAndOffset(memRefType, strides, offset)))
return false;
Expand Down
4 changes: 2 additions & 2 deletions lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ static SmallVector<Type> extractInvokeOperandTypes(OpBuilder &builder,
results.push_back(integer64);
for (Value operand : operands) {
Type operandType = operand.getType();
if (auto memrefType = operandType.dyn_cast<MemRefType>()) {
if (auto memrefType = dyn_cast<MemRefType>(operandType)) {
// TODO: non-POD will require an LLVMTypeConverter.
Type basePtrType = LLVM::LLVMPointerType::get(builder.getContext());
results.push_back(basePtrType);
Expand All @@ -65,7 +65,7 @@ static SmallVector<Value> getOperands(OpBuilder &builder, Location loc,
builder.create<arith::ConstantOp>(loc, integer64, dataTypeAttr));

for (Value operand : operands) {
auto memrefType = operand.getType().dyn_cast<MemRefType>();
auto memrefType = dyn_cast<MemRefType>(operand.getType());
if (!memrefType) {
res.push_back(operand);
continue;
Expand Down
4 changes: 2 additions & 2 deletions lib/TPP/GPU/GpuVulkanAbi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ static Type getVulkanTypeWrapper(Type type,
assert(!isa<TensorType>(type) && "Tensors are not supported by Vulkan");

// Buffers are already Vulkan compatible.
if (auto memrefType = type.dyn_cast<MemRefType>())
if (auto memrefType = dyn_cast<MemRefType>(type))
return FlattenMemrefType(memrefType);

// Index has to be converted to a fixed-size integer.
Expand Down Expand Up @@ -120,7 +120,7 @@ static Value FlattenMemrefOperand(Value operand, RewriterBase &rewriter) {
auto loc = operand.getLoc();

// Ignore non-memref types and 1D buffers.
auto memrefType = operand.getType().dyn_cast<MemRefType>();
auto memrefType = dyn_cast<MemRefType>(operand.getType());
if (!memrefType || memrefType.getRank() <= 1)
return operand;

Expand Down
6 changes: 3 additions & 3 deletions lib/TPP/Transforms/TransformUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,21 +410,21 @@ struct ConvertToForAll : public OpRewritePattern<scf::ForOp> {
Value destVal = mapping.lookup(insertSlice.getDest());
SmallVector<OpFoldResult> offsets;
for (OpFoldResult offset : insertSlice.getMixedOffsets()) {
if (auto valueOffset = offset.dyn_cast<Value>())
if (auto valueOffset = dyn_cast<Value>(offset))
offsets.push_back(mapping.lookupOrDefault(valueOffset));
else
offsets.push_back(offset);
}
SmallVector<OpFoldResult> sizes;
for (OpFoldResult size : insertSlice.getMixedSizes()) {
if (auto valueSize = size.dyn_cast<Value>())
if (auto valueSize = dyn_cast<Value>(size))
sizes.push_back(mapping.lookupOrDefault(valueSize));
else
sizes.push_back(size);
}
SmallVector<OpFoldResult> strides;
for (OpFoldResult stride : insertSlice.getMixedStrides()) {
if (auto valueStride = stride.dyn_cast<Value>())
if (auto valueStride = dyn_cast<Value>(stride))
strides.push_back(mapping.lookupOrDefault(valueStride));
else
strides.push_back(stride);
Expand Down
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/Utils/ValueUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ FailureOr<SmallVector<int64_t>> getStaticStrides(Value value) {

std::pair<Value, Value> getPtrAndOffset(OpBuilder &builder, Value operand,
Location loc) {
auto memrefType = operand.getType().dyn_cast<MemRefType>();
auto memrefType = dyn_cast<MemRefType>(operand.getType());
assert(memrefType && "Expect a memref value");
MemRefType baseMemrefType = MemRefType::get({}, memrefType.getElementType());
Type basePtrType = builder.getIndexType();
Expand Down
2 changes: 1 addition & 1 deletion scripts/buildkite/build_llvm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ fi

# Check LLVM
echo "--- CHECK"
echo_run ninja -C ${LLVM_BUILD_DIR} check-tpp
echo_run ninja -C ${LLVM_BUILD_DIR} check-all
if [ $? != 0 ]; then
exit 1
fi
Expand Down
13 changes: 7 additions & 6 deletions test/Passes/pass-convert-gemm-to-parallel-tile.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ module {
// CHECK: %[[c0_i64:.*]] = arith.constant 0 : i64
// CHECK: %[[temp0:.*]] = call @xsmm_brgemm_dispatch(%[[c1_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c1024_i64]], %[[c1024_i64]], %[[c0_i64]])
// CHECK: omp.parallel {
// CHECK: omp.wsloop for (%[[ARG3:.*]], %[[ARG4:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
// CHECK: memref.alloca_scope {
// CHECK: scf.for %[[ARG5:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
// CHECK: %[[temp1:.*]] = arith.addi %[[ARG5]], %[[ARG3]] : index
// CHECK: scf.for %[[ARG6:.*]] = %[[c0]] to %[[c8]] step %[[c1]] {
// CHECK: %[[temp2:.*]] = arith.addi %[[ARG6]], %[[ARG4]] : index
// CHECK: omp.wsloop {
// CHECK: omp.loop_nest (%[[ARG3:.*]], %[[ARG4:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) {
// CHECK: memref.alloca_scope {
// CHECK: scf.for %[[ARG5:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
// CHECK: %[[temp1:.*]] = arith.addi %[[ARG5]], %[[ARG3]] : index
// CHECK: scf.for %[[ARG6:.*]] = %[[c0]] to %[[c8]] step %[[c1]] {
// CHECK: %[[temp2:.*]] = arith.addi %[[ARG6]], %[[ARG4]] : index

41 changes: 21 additions & 20 deletions test/Passes/pass-convert-mlp-to-parallel-tile.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,27 @@ module {
//CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64
//CHECK: %[[temp0:.*]] = call @xsmm_fused_brgemm_dispatch(%[[c1_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c32_i64]], %[[c1024_i64]], %[[c1024_i64]], %[[c0_i64]], %[[c0_i64]], %[[c5_i64]], %[[c4_i64]], %[[c1_i64]])
//CHECK: omp.parallel {
//CHECK: omp.wsloop for (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
//CHECK: memref.alloca_scope {
//CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
//CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
//CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index
//CHECK: omp.wsloop {
//CHECK: omp.loop_nest (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
//CHECK: memref.alloca_scope {
//CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
//CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
//CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index
//CHECK: omp.parallel {
//CHECK: omp.wsloop for (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
//CHECK: memref.alloca_scope {
//CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
//CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
//CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index
//CHECK: omp.wsloop {
//CHECK: omp.loop_nest (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
//CHECK: memref.alloca_scope {
//CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
//CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
//CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index
//CHECK: omp.parallel {
//CHECK: omp.wsloop for (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
//CHECK: memref.alloca_scope {
//CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
//CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
//CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index


//CHECK: omp.wsloop {
//CHECK: omp.loop_nest (%[[ARG10:.*]], %[[ARG11:.*]]) : index = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c16]]) {
//CHECK: memref.alloca_scope {
//CHECK: scf.for %[[ARG12:.*]] = %[[c0]] to %[[c2]] step %[[c1]] {
//CHECK: %[[temp1:.*]] = arith.addi %[[ARG12]], %[[ARG10]] : index
//CHECK: scf.for %[[ARG13:.*]] = %[[c0]] to %[[c16]] step %[[c1]] {
//CHECK: %[[temp2:.*]] = arith.addi %[[ARG13]], %[[ARG11]] : index

4 changes: 2 additions & 2 deletions tools/tpp-run/MLIRBench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ LogicalResult MLIRBench::replaceSplatWithRandom() {
auto constant = dyn_cast<arith::ConstantOp>(op);
if (!constant)
continue;
auto cstType = constant.getType().dyn_cast<ShapedType>();
auto cstType = dyn_cast<ShapedType>(constant.getType());
if (!cstType)
continue;
auto newAttr = replaceSplat(cstType, constant.getValueAttr());
Expand Down Expand Up @@ -318,7 +318,7 @@ void MLIRBench::printMean(Value mean) {

void MLIRBench::printVector(Value vector) {
auto op = vector;
auto vectorValue = vector.getType().dyn_cast<VectorType>();
auto vectorValue = dyn_cast<VectorType>(vector.getType());
if (vectorValue.getElementType().isBF16()) {
VectorType vecType =
VectorType::get(vectorValue.getShape(), builder.getF32Type());
Expand Down

0 comments on commit d9fd677

Please sign in to comment.