Skip to content

Commit

Permalink
Bump LLVM (#910)
Browse files Browse the repository at this point in the history
Updates LLVM version, fixes more deprecated casts, and updates IR
syntax.
  • Loading branch information
adam-smnk authored May 14, 2024
1 parent cf0d840 commit 7ee8e7e
Show file tree
Hide file tree
Showing 28 changed files with 83 additions and 90 deletions.
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
fe47e8ff3ae7fc8975eaade6bfa6679737c28b93
61d4ca872215d3dfff0b3c92151dcbdc546a0aab
5 changes: 3 additions & 2 deletions include/TPP/Transforms/Utils/TransformUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef TPP_TRANSFORMS_UTILS_TRANSFORMUTILS_H
#define TPP_TRANSFORMS_UTILS_TRANSFORMUTILS_H

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

Expand Down Expand Up @@ -94,11 +95,11 @@ void populateScfForToForAllRewritePattern(RewritePatternSet &patterns);

// Given a value `val` expand its shape based on `reassociationMap`.
Value expand(OpBuilder &builder, Location loc, Value val, Type newType,
ArrayAttr reassociationMap);
ArrayRef<ReassociationIndices> reassociationMap);

// Given a value `val` collapse its shape based on `reassociationMap`.
Value collapse(OpBuilder &builder, Location loc, Value val, Type newType,
ArrayAttr reassociationMap);
ArrayRef<ReassociationIndices> reassociationMap);

} // namespace utils
} // namespace linalgx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,8 @@ static Value makeOperandShapeRowBroadCastable(RewriterBase &rewriter,
auto reassoc =
getReassociationIndicesForReshape(shapedOperand, newShapedOperand);
assert(reassoc.has_value());
return linalgx::utils::expand(
rewriter, loc, operand, newShapedOperand,
getReassociationIndicesAttribute(rewriter, *reassoc));
return linalgx::utils::expand(rewriter, loc, operand, newShapedOperand,
*reassoc);
}

// Convert linalg.generic to xsmm unary relu or identity op.
Expand Down
4 changes: 2 additions & 2 deletions lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ void addKindOperand(RewriterBase &rewriter,
static int64_t getOredFlags(ArrayAttr flags) {
int64_t oredFlag = 0;
for (auto flag : flags) {
int64_t intAttr = flag.template dyn_cast<IntegerAttr>().getInt();
int64_t intAttr = dyn_cast<IntegerAttr>(flag).getInt();
// LIBXSMM is col-major, swap A and B flags.
if (auto gemmFlag = dyn_cast_or_null<xsmm::GemmFlagsAttr>(flag)) {
if (gemmFlag.getValue() == GemmFlags::VNNI_A)
Expand Down Expand Up @@ -412,7 +412,7 @@ struct ConvertFusedBrgemmOp : public OpRewritePattern<FusedBrgemmDispatchOp> {
auto isFusedAdd = dispatchOp.getBinaryKind() == xsmm::BinaryKind::ADD;
auto binaryFlags = dispatchOp.getBinaryFlags();
if (isFusedAdd && (binaryFlags.size() != 1 ||
binaryFlags[0].cast<BinaryFlagsAttr>().getValue() !=
cast<BinaryFlagsAttr>(binaryFlags[0]).getValue() !=
BinaryFlags::BCAST_COL_IN_0)) {
return failure();
}
Expand Down
2 changes: 1 addition & 1 deletion lib/TPP/Dialect/Perf/PerfOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ ParseResult BenchOp::parse(OpAsmParser &parser, OperationState &result) {
if (types.size() != 1)
return parser.emitError(locs[0], "expect one types for argument");
if (parser.resolveOperand(operands[0], types[0], result.operands) ||
!types[0].isa<IntegerType>())
!isa<IntegerType>(types[0]))
return parser.emitError(locs[0], "expect integer number of iterations");
operands.clear();
types.clear();
Expand Down
6 changes: 3 additions & 3 deletions lib/TPP/Dialect/Xsmm/XsmmOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ static void printerFlagsImpl(OpAsmPrinter &printer,
const std::string_view &flagsName) {
printer << " " << flagsName << " = (";
llvm::interleaveComma(fn(), printer, [&](auto &flag) {
printer << stringifyEnum(flag.template cast<AttrTy>().getValue());
printer << stringifyEnum(cast<AttrTy>(flag).getValue());
});
printer << ") ";
}
Expand Down Expand Up @@ -385,7 +385,7 @@ LogicalResult FusedBrgemmDispatchOp::verify() {
if (unaryKind == xsmm::UnaryKind::NONE) {
auto unaryFlags = getUnaryFlags();
if (unaryFlags.size() != 1 ||
unaryFlags[0].cast<xsmm::UnaryFlagsAttr>().getValue() !=
cast<xsmm::UnaryFlagsAttr>(unaryFlags[0]).getValue() !=
xsmm::UnaryFlags::NONE) {
return emitOpError() << "invalid unary flags for kind none";
}
Expand All @@ -394,7 +394,7 @@ LogicalResult FusedBrgemmDispatchOp::verify() {
if (binaryKind == xsmm::BinaryKind::NONE) {
auto binaryFlags = getBinaryFlags();
if (binaryFlags.size() != 1 ||
binaryFlags[0].cast<xsmm::BinaryFlagsAttr>().getValue() !=
cast<xsmm::BinaryFlagsAttr>(binaryFlags[0]).getValue() !=
xsmm::BinaryFlags::NONE) {
return emitOpError() << "invalid binary flags for kind none";
}
Expand Down
3 changes: 1 addition & 2 deletions lib/TPP/GPU/LinalgToGpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,8 +726,7 @@ struct ConvertGemmLikeToGpu : public OpRewritePattern<LinalgOpTy> {
}

// Ensure that reduction dimension tiling also works for smaller workloads.
auto aType =
gemmLikeOp.getDpsInputs()[0].getType().template cast<ShapedType>();
auto aType = cast<ShapedType>(gemmLikeOp.getDpsInputs()[0].getType());
auto kDim = aType.getShape().back();
auto kTile = kDim < options.kTile ? kDim : options.kTile;

Expand Down
5 changes: 2 additions & 3 deletions lib/TPP/Transforms/RewriteBatchMatmulToMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,10 @@ struct RankReducedExtractSliceOp
reassociation->size() == static_cast<size_t>(resultType.getRank())) {
return failure();
}
auto rankReducedType =
auto rankReducedType = cast<RankedTensorType>(
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
strides)
.cast<RankedTensorType>();
strides));

Location loc = sliceOp.getLoc();
Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
Expand Down
12 changes: 4 additions & 8 deletions lib/TPP/Transforms/RewriteConvsToMatmulOrBrgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,16 +439,13 @@ struct CollapseFilterAndImage : OpRewritePattern<linalg::GenericOp> {
return failure();

Value collapsedImage = linalgx::utils::collapse(
rewriter, loc, image->get(), newImageType,
getReassociationIndicesAttribute(rewriter, *reassociationImage));
rewriter, loc, image->get(), newImageType, *reassociationImage);

Value collapsedFilter = linalgx::utils::collapse(
rewriter, loc, filter->get(), newFilterType,
getReassociationIndicesAttribute(rewriter, *reassociationFilter));
rewriter, loc, filter->get(), newFilterType, *reassociationFilter);

Value collapsedOutput = linalgx::utils::collapse(
rewriter, loc, output.get(), newOutputType,
getReassociationIndicesAttribute(rewriter, *reassociationOutput));
rewriter, loc, output.get(), newOutputType, *reassociationOutput);

linalg::GenericOp replacementOp = rewriter.create<linalg::GenericOp>(
loc, newOutputType, ValueRange{collapsedImage, collapsedFilter},
Expand All @@ -464,8 +461,7 @@ struct CollapseFilterAndImage : OpRewritePattern<linalg::GenericOp> {
if (!reassociationOutput)
return failure();
Value resExpanded = linalgx::utils::expand(
rewriter, loc, res, output.get().getType(),
getReassociationIndicesAttribute(rewriter, *reassociationOutput));
rewriter, loc, res, output.get().getType(), *reassociationOutput);
rewriter.replaceOp(linalgOp, resExpanded);
return success();
}
Expand Down
15 changes: 7 additions & 8 deletions lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,9 +804,9 @@ struct PackAsReshape : public OpRewritePattern<tensor::PackOp> {
getReassociationIndicesForReshape(sourceType, packOp.getDestType());
if (!reassoc)
return failure();
Value expanded = linalgx::utils::expand(
rewriter, packOp.getLoc(), packOp.getSource(), packOp.getDestType(),
getReassociationIndicesAttribute(rewriter, *reassoc));
Value expanded =
linalgx::utils::expand(rewriter, packOp.getLoc(), packOp.getSource(),
packOp.getDestType(), *reassoc);
rewriter.replaceOp(packOp, expanded);
return success();
}
Expand Down Expand Up @@ -835,8 +835,7 @@ struct UnPackAsReshape : public OpRewritePattern<tensor::UnPackOp> {
if (!reassoc)
return failure();
Value collapse = linalgx::utils::collapse(
rewriter, unPackOp.getLoc(), unPackOp.getSource(), destType,
getReassociationIndicesAttribute(rewriter, *reassoc));
rewriter, unPackOp.getLoc(), unPackOp.getSource(), destType, *reassoc);
rewriter.replaceOp(unPackOp, collapse);
return success();
}
Expand Down Expand Up @@ -885,9 +884,9 @@ struct PackOfReshape : public OpRewritePattern<tensor::PackOp> {
packOp.getMixedTiles(), packOp.getPaddingValue(),
packOp.getOuterDimsPerm());

Value expanded = linalgx::utils::expand(
rewriter, packOp.getLoc(), packedVal, packOp.getDestType(),
getReassociationIndicesAttribute(rewriter, *reassocExpand));
Value expanded =
linalgx::utils::expand(rewriter, packOp.getLoc(), packedVal,
packOp.getDestType(), *reassocExpand);
rewriter.replaceOp(packOp, expanded);

return success();
Expand Down
4 changes: 2 additions & 2 deletions lib/TPP/Transforms/TransformUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace utils {

// Given a value `val` expand it's shape based on `reassociationMap`.
Value expand(OpBuilder &builder, Location loc, Value val, Type newType,
ArrayAttr reassociationMap) {
ArrayRef<ReassociationIndices> reassociationMap) {
OpBuilder::InsertionGuard guard(builder);
if (newType == val.getType())
return val;
Expand All @@ -44,7 +44,7 @@ Value expand(OpBuilder &builder, Location loc, Value val, Type newType,

// Given a value `val` collapse it's shape based on `reassociationMap`.
Value collapse(OpBuilder &builder, Location loc, Value val, Type newType,
ArrayAttr reassociationMap) {
ArrayRef<ReassociationIndices> reassociationMap) {
if (newType == val.getType())
return val;
if (isa<RankedTensorType>(newType)) {
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/LinalgToXsmm/linalg-to-binary.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ func.func @mul_bcast_row_in0(%arg0: memref<10xf32>, %arg1: memref<10x10xf32>) {

// CHECK-LABEL: mul_bcast_row_in0
// CHECK-SAME: %[[ARG0:.+]]: memref<10xf32>, %[[ARG1:.+]]: memref<10x10xf32>
// CHECK: %[[EXP:.+]] = memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] : memref<10xf32> into memref<10x1xf32>
// CHECK: %[[EXP:.+]] = memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [10, 1] : memref<10xf32> into memref<10x1xf32>
// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch mul [10, 10, 1, 10, 10] flags = (bcast_row_in0) data_type = f32
// CHECK: xsmm.binary mul(data_type = f32, %[[DIS]], %[[EXP]], %[[ARG1]], %[[ARG1]])

Expand All @@ -716,6 +716,6 @@ func.func @mul_bcast_row_in1(%arg0: memref<10xf32>, %arg1: memref<10x10xf32>) {

// CHECK-LABEL: mul_bcast_row_in1
// CHECK-SAME: %[[ARG0:.+]]: memref<10xf32>, %[[ARG1:.+]]: memref<10x10xf32>
// CHECK: %[[EXP:.+]] = memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] : memref<10xf32> into memref<10x1xf32>
// CHECK: %[[EXP:.+]] = memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [10, 1] : memref<10xf32> into memref<10x1xf32>
// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch mul [10, 10, 10, 1, 10] flags = (bcast_row_in1) data_type = f32
// CHECK: xsmm.binary mul(data_type = f32, %[[DIS]], %[[ARG1]], %[[EXP]], %[[ARG1]])
8 changes: 4 additions & 4 deletions test/Conversion/LinalgToXsmm/linalg-to-unary.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func.func @identity_3(%arg0: memref<128x1xf32>, %arg1: memref<128x512xf32>) {

func.func @vnni_packing(%arg0 : memref<32x32xbf16, strided<[512, 1], offset: ?>>,
%arg1: memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) {
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2]]
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape[16, 2, 32]
: memref<32x32xbf16, strided<[512, 1], offset: ?>>
into memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>>
linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>>)
Expand All @@ -315,7 +315,7 @@ func.func @vnni_packing(%arg0 : memref<32x32xbf16, strided<[512, 1], offset: ?>>

func.func @not_vnni_packing(%arg0 : memref<32x32xf32, strided<[512, 1], offset: ?>>,
%arg1: memref<16x32x2xf32, strided<[64, 2, 1], offset: ?>>) {
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2]]
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape[16, 2, 32]
: memref<32x32xf32, strided<[512, 1], offset: ?>>
into memref<16x2x32xf32, strided<[1024, 512, 1], offset: ?>>
linalg.transpose ins(%expand_shape : memref<16x2x32xf32, strided<[1024, 512, 1], offset: ?>>)
Expand Down Expand Up @@ -359,7 +359,7 @@ func.func @vnni_packing_1(%arg1: memref<128x128xbf16>, %arg2: memref<4x4x16x32x2
: memref<128x128xbf16> to memref<32x32xbf16, strided<[128, 1], offset: ?>>
%subview_1 = memref.subview %arg2[%arg3, %arg4, 0, 0, 0] [1, 1, 16, 32, 2] [1, 1, 1, 1, 1]
: memref<4x4x16x32x2xbf16> to memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>
%expand_shape = memref.expand_shape %subview [[0, 1], [2]]
%expand_shape = memref.expand_shape %subview [[0, 1], [2]] output_shape[16, 2, 32]
: memref<32x32xbf16, strided<[128, 1], offset: ?>> into memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>>
linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>>)
outs(%subview_1 : memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>)
Expand Down Expand Up @@ -417,7 +417,7 @@ func.func @identity_5(%arg0 : memref<10xf32>, %arg1 : memref<10x10xf32>) {

// CHECK-LABEL: identity_5
// CHECK-SAME: %[[ARG0:.+]]: memref<10xf32>, %[[ARG1:.+]]: memref<10x10xf32>
// CHECK: %[[EXP:.+]] = memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] : memref<10xf32> into memref<10x1xf32>
// CHECK: %[[EXP:.+]] = memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [10, 1] : memref<10xf32> into memref<10x1xf32>
// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch identity [10, 10, 1, 10] flags = (bcast_row) data_type = f32
// CHECK: xsmm.unary identity(data_type = f32, %[[DIS]], %[[EXP]], %[[ARG1]])

Expand Down
6 changes: 3 additions & 3 deletions test/GPU/CUDA/Integration/wmma/pack-brgemm-unpack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

func.func @entry(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf16>) -> memref<32x32xf16> {
%alloc = gpu.alloc() {alignment = 64 : i64} : memref<2x2x16x16xf16>
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<32x32xf16> into memref<2x16x2x16xf16>
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [2, 16, 2, 16] : memref<32x32xf16> into memref<2x16x2x16xf16>
%alloc_0 = gpu.alloc() {alignment = 64 : i64} : memref<2x2x16x16xf16>
linalg.transpose ins(%expand_shape : memref<2x16x2x16xf16>) outs(%alloc_0 : memref<2x2x16x16xf16>) permutation = [0, 2, 1, 3]
%expand_shape_1 = memref.expand_shape %arg1 [[0, 1], [2, 3]] : memref<32x32xf16> into memref<2x16x2x16xf16>
%expand_shape_1 = memref.expand_shape %arg1 [[0, 1], [2, 3]] output_shape [2, 16, 2, 16] : memref<32x32xf16> into memref<2x16x2x16xf16>
%alloc_2 = gpu.alloc() {alignment = 64 : i64} : memref<2x2x16x16xf16>
linalg.transpose ins(%expand_shape_1 : memref<2x16x2x16xf16>) outs(%alloc_2 : memref<2x2x16x16xf16>) permutation = [2, 0, 1, 3]
%expand_shape_3 = memref.expand_shape %arg2 [[0, 1], [2, 3]] : memref<32x32xf16> into memref<2x16x2x16xf16>
%expand_shape_3 = memref.expand_shape %arg2 [[0, 1], [2, 3]] output_shape [2, 16, 2, 16] : memref<32x32xf16> into memref<2x16x2x16xf16>
linalg.transpose ins(%expand_shape_3 : memref<2x16x2x16xf16>) outs(%alloc : memref<2x2x16x16xf16>) permutation = [0, 2, 1, 3]
scf.forall (%arg3, %arg4) in (2, 2) {
%subview = memref.subview %alloc_0[%arg3, 0, 0, 0] [1, 2, 16, 16] [1, 1, 1, 1] : memref<2x2x16x16xf16> to memref<2x16x16xf16, strided<[256, 16, 1], offset: ?>>
Expand Down
2 changes: 1 addition & 1 deletion test/Integration/relayout-more-interesting.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func.func @entry() {
%2 = bufferization.alloc_tensor() : tensor<6x16xf32>
%3 = linalg.copy ins(%d: tensor<6x16xf32>) outs(%2: tensor<6x16xf32>) -> tensor<6x16xf32>
%4 = tensor.collapse_shape %3 [[0, 1]] : tensor<6x16xf32> into tensor<96xf32>
%5 = tensor.expand_shape %4 [[0, 1, 2, 3]] : tensor<96xf32> into tensor<3x8x2x2xf32>
%5 = tensor.expand_shape %4 [[0, 1, 2, 3]] output_shape [3, 8, 2, 2] : tensor<96xf32> into tensor<3x8x2x2xf32>
%v1 = vector.transfer_read %5[%c0, %c0, %c0, %c0], %d1 : tensor<3x8x2x2xf32>, vector<3x8x2x2xf32>
//
// CHECK: ( ( ( ( 1.1, 2.1 ), ( 3.1, 4.1 ) ),
Expand Down
8 changes: 4 additions & 4 deletions test/Integration/xsmm-strided-brgemm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
#map3 = affine_map<(i, ii, j, jj) -> (i, ii, j, jj)>

func.func @matmul_static(%A : !A_tensor_t, %B : !B_tensor_t, %C : !C_tensor_t, %D : !D_tensor_t) {
%A_exp = tensor.expand_shape %A [[0, 1], [2, 3]] :
%A_exp = tensor.expand_shape %A [[0, 1], [2, 3]] output_shape[2, 2, 2, 4] :
!A_tensor_t into tensor<2x2x2x4xf32>
%B_exp = tensor.expand_shape %B [[0, 1], [2, 3]] :
%B_exp = tensor.expand_shape %B [[0, 1], [2, 3]] output_shape[2, 4, 8, 2] :
!B_tensor_t into tensor<2x4x8x2xf32>
%C_exp = tensor.expand_shape %C [[0, 1], [2, 3]] :
%C_exp = tensor.expand_shape %C [[0, 1], [2, 3]] output_shape[2, 2, 8, 2] :
!C_tensor_t into tensor<2x2x8x2xf32>
%D_exp = tensor.expand_shape %D [[0, 1], [2, 3]] :
%D_exp = tensor.expand_shape %D [[0, 1], [2, 3]] output_shape[2, 2, 8, 2] :
!D_tensor_t into tensor<2x2x8x2xf32>

// IR-DAG: %[[C1:.+]] = arith.constant 1 : i64
Expand Down
6 changes: 3 additions & 3 deletions test/Integration/xsmm-strided-brgemm1.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

// IR-LABEL: matmul_static
func.func @matmul_static(%A : !A_tensor_t, %B : !B_tensor_t, %C : !C_tensor_t) {
%A_exp = tensor.expand_shape %A [[0, 1], [2, 3]] :
%A_exp = tensor.expand_shape %A [[0, 1], [2, 3]] output_shape[2, 8, 2, 4] :
!A_tensor_t into tensor<2x8x2x4xf32>
%B_exp = tensor.expand_shape %B [[0, 1], [2, 3]] :
%B_exp = tensor.expand_shape %B [[0, 1], [2, 3]] output_shape[2, 4, 8, 2] :
!B_tensor_t into tensor<2x4x8x2xf32>
%C_exp = tensor.expand_shape %C [[0, 1], [2, 3]] :
%C_exp = tensor.expand_shape %C [[0, 1], [2, 3]] output_shape[2, 2, 8, 2] :
!C_tensor_t into tensor<2x2x8x2xf32>

%cst_fill = arith.constant 0.0 : f32
Expand Down
6 changes: 3 additions & 3 deletions test/Integration/xsmm-strided-brgemm2.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

// IR-LABEL: matmul_static
func.func @matmul_static(%A : !A_tensor_t, %B : !B_tensor_t, %C : !C_tensor_t) {
%A_exp = tensor.expand_shape %A [[0, 1], [2, 3]] :
%A_exp = tensor.expand_shape %A [[0, 1], [2, 3]] output_shape[2, 2, 2, 4] :
!A_tensor_t into tensor<2x2x2x4xf32>
%B_exp = tensor.expand_shape %B [[0, 1], [2, 3]] :
%B_exp = tensor.expand_shape %B [[0, 1], [2, 3]] output_shape[2, 4, 2, 8] :
!B_tensor_t into tensor<2x4x2x8xf32>
%C_exp = tensor.expand_shape %C [[0, 1], [2, 3]] :
%C_exp = tensor.expand_shape %C [[0, 1], [2, 3]] output_shape[2, 2, 2, 8] :
!C_tensor_t into tensor<2x2x2x8xf32>

%cst_fill = arith.constant 0.0 : f32
Expand Down
6 changes: 3 additions & 3 deletions test/Integration/xsmm-strided-brgemm3.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

// IR-LABEL: matmul_static
func.func @matmul_static(%A : !A_tensor_t, %B : !B_tensor_t, %C : !C_tensor_t) {
%A_exp = tensor.expand_shape %A [[0, 1], [2, 3]] :
%A_exp = tensor.expand_shape %A [[0, 1], [2, 3]] output_shape[2, 2, 2, 4] :
!A_tensor_t into tensor<2x2x2x4xf32>
%B_exp = tensor.expand_shape %B [[0, 1], [2, 3]] :
%B_exp = tensor.expand_shape %B [[0, 1], [2, 3]] output_shape[2, 8, 2, 4] :
!B_tensor_t into tensor<2x8x2x4xf32>
%C_exp = tensor.expand_shape %C [[0, 1], [2, 3]] :
%C_exp = tensor.expand_shape %C [[0, 1], [2, 3]] output_shape[2, 2, 8, 2] :
!C_tensor_t into tensor<2x2x8x2xf32>

%cst_fill = arith.constant 0.0 : f32
Expand Down
2 changes: 1 addition & 1 deletion test/Passes/dps-in-tile-and-fuse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func.func @dps_test(%arg0: tensor<8x48x32x32xbf16>,
%add = arith.addf %out, %mul : bf16
linalg.yield %add : bf16
} -> tensor<8x48x32x32xbf16>
%expanded = tensor.expand_shape %arg2 [[0, 1]] : tensor<1536xbf16> into tensor<48x32xbf16>
%expanded = tensor.expand_shape %arg2 [[0, 1]] output_shape [48, 32] : tensor<1536xbf16> into tensor<48x32xbf16>
%2 = linalg.generic {indexing_maps = [#map3, #map4, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%0, %expanded : tensor<8x48x32x32xbf16>, tensor<48x32xbf16>) outs(%arg3 : tensor<8x48x32x32xbf16>) {
^bb0(%in: bf16, %in_0: bf16, %out: bf16):
%add = arith.addf %in, %in_0 : bf16
Expand Down
Loading

0 comments on commit 7ee8e7e

Please sign in to comment.