Skip to content

Commit

Permalink
Zero initialization as beta=0 during fusion (#902)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kavitha authored Mar 27, 2024
1 parent b04f662 commit 42c6ac0
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 23 deletions.
5 changes: 3 additions & 2 deletions include/TPP/Dialect/Xsmm/XsmmUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ struct BinaryInfo {

/// Represents a chain of XSMM ops that can be fused. All broadcast ops
/// should have already been converted to flags. All stray allocations
/// should have already been converted to in-place reuse. Init zero
/// should have already been converted to Beta=0.
/// should have already been converted to in-place reuse.
struct FusedMatch {
// This is the (optional) zero op that precedes the GEMM op
UnaryOp zeroOp;
// This is the BRGEMM op
BrgemmOp brgemmOp;
// This is the (optional) binary op that follows the GEMM
Expand Down
44 changes: 29 additions & 15 deletions lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,13 +289,11 @@ FailureOr<BinaryFlags> getBinaryFlags(Type operandType, Type outputType,
FailureOr<FusedMatch> getFusedBrgemmSequenceFromProducer(Operation *op) {
// The loop is in reverse order, so we deduplicate the list making sure we
// only have one type of each
SmallVector<Operation *, 3> chain;
SmallVector<Operation *, 4> chain;
Operation *prev = nullptr;
for (auto *user : op->getUsers()) {
// Deduplicate, only take each operation once
if ((dyn_cast<xsmm::UnaryOp>(user) &&
dyn_cast<xsmm::UnaryOp>(user).getCallee() == UnaryKind::ZERO) ||
dyn_cast<func::ReturnOp>(user) || user == prev)
if (dyn_cast<func::ReturnOp>(user) || user == prev)
continue;
chain.push_back(user);
prev = user;
Expand All @@ -313,28 +311,44 @@ FailureOr<FusedMatch> getFusedBrgemmSequenceFromProducer(Operation *op) {
int numUses = std::count(user->getOperands().begin(),
user->getOperands().end(), op->getResult(0));
// At least one input and the last operand (output) is the same buffer
if (numUses < 2 ||
if (((dyn_cast<xsmm::UnaryOp>(user) &&
dyn_cast<xsmm::UnaryOp>(user).getCallee() != UnaryKind::ZERO) &&
numUses < 2) ||
user->getOperands()[user->getOperands().size() - 1] != op->getResult(0))
return failure();
}
// We don't know how to fuse more than two tail ops after the BRGEMM
if (chain.size() > 3)
// We don't know how to fuse more than two tail ops after and a zero op before
// BRGEMM
if (chain.size() > 4)
return failure();
if (!isa<xsmm::BrgemmOp>(chain[0]))
// List is in reverse order, put the brgemm at the top
if (!(isa<xsmm::BrgemmOp>(chain[0]) ||
(dyn_cast<xsmm::UnaryOp>(chain[0]) &&
dyn_cast<xsmm::UnaryOp>(chain[0]).getCallee() == UnaryKind::ZERO)))
// List is in reverse order, put the brgemm or zero at the top
std::reverse(chain.begin(), chain.end());

// If we haven't found a BRGEMM, this are not the droids we're looking for
assert(isa<xsmm::BrgemmOp>(chain[0]) && "First op must be brgemm");
// If we haven't found a BRGEMM or zero, this are not the droids we're looking
// for
assert(isa<xsmm::BrgemmOp>(chain[0]) ||
(dyn_cast<xsmm::UnaryOp>(chain[0]) &&
dyn_cast<xsmm::UnaryOp>(chain[0]).getCallee() == UnaryKind::ZERO &&
isa<xsmm::BrgemmOp>(chain[1])) &&
"First op must be brgemm or zero");

// Now, we're sure we have a chain, but not yet if it has the right types
// and in the right order: BRGEMM -> BINARY -> UNARY
// and in the right order: (ZER0) -> BRGEMM -> BINARY -> UNARY
// Allowed patterns are:
// - GEMM + BINARY
// - GEMM + UNARY
// - GEMM + BINARY + UNARY
// - (ZERO) + GEMM + BINARY
// - (ZERO)+ GEMM + UNARY
// - (ZERO) + GEMM + BINARY + UNARY
xsmm::FusedMatch fusedMatch;
for (auto *user : chain) {
if (auto unaryOp = dyn_cast<xsmm::UnaryOp>(user)) {
if (dyn_cast<xsmm::UnaryOp>(user).getCallee() == UnaryKind::ZERO) {
fusedMatch.zeroOp = unaryOp;
continue;
}
}
if (auto brgemmOp = (dyn_cast<xsmm::BrgemmOp>(user))) {
// We only accept one of each
if (fusedMatch.brgemmOp)
Expand Down
28 changes: 22 additions & 6 deletions lib/TPP/Transforms/CombineXsmmPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,22 @@ struct CombineXsmmOp : public OpRewritePattern<xsmm::BrgemmOp> {
true);
if (failed(brgemmFlags))
return failure();
auto attributes = *brgemmFlags;
if (fusedMatch.zeroOp) {
if (attributes[0] == xsmm::GemmFlagsAttr::get(rewriter.getContext(),
xsmm::GemmFlags::NONE)) {
attributes.clear();
}
attributes.push_back(xsmm::GemmFlagsAttr::get(rewriter.getContext(),
xsmm::GemmFlags::BETA_0));
}
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(fusedMatch.binaryOp);
Value dispatched = rewriter.create<xsmm::FusedBrgemmDispatchOp>(
loc, integer64, dims,
xsmm::BinaryKindAttr::get(rewriter.getContext(), fusedMatch.binaryKind),
xsmm::UnaryKindAttr::get(rewriter.getContext(), fusedMatch.unaryKind),
rewriter.getArrayAttr(*brgemmFlags),
rewriter.getArrayAttr(attributes),
rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get(
rewriter.getContext(), xsmm::UnaryFlags::NONE)),
rewriter.getArrayAttr(
Expand All @@ -119,11 +128,18 @@ struct CombineXsmmOp : public OpRewritePattern<xsmm::BrgemmOp> {
rewriter.create<xsmm::FusedBrgemmOp>(loc, dtype, invokeOperands);
rewriter.eraseOp(brgemmOp);
rewriter.eraseOp(brgemmOp.getOperand(0).getDefiningOp());
rewriter.eraseOp(fusedMatch.binaryOp);
rewriter.eraseOp(fusedMatch.binaryOp->getOperand(0).getDefiningOp());
rewriter.eraseOp(fusedMatch.unaryOp);
rewriter.eraseOp(fusedMatch.unaryOp->getOperand(0).getDefiningOp());

if (fusedMatch.binaryOp) {
rewriter.eraseOp(fusedMatch.binaryOp);
rewriter.eraseOp(fusedMatch.binaryOp->getOperand(0).getDefiningOp());
}
if (fusedMatch.unaryOp) {
rewriter.eraseOp(fusedMatch.unaryOp);
rewriter.eraseOp(fusedMatch.unaryOp->getOperand(0).getDefiningOp());
}
if (fusedMatch.zeroOp) {
rewriter.eraseOp(fusedMatch.zeroOp);
rewriter.eraseOp(fusedMatch.zeroOp->getOperand(0).getDefiningOp());
}
return success();
}
};
Expand Down
44 changes: 44 additions & 0 deletions test/Passes/xsmm-combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -261,4 +261,48 @@ func.func @none_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x51
// CHECK-NOT: %[[DISPATCH:.*]] = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024][add,relu] flags = (vnni_b, beta_0) binary_flags = (none) unary_flags = (none) data_type = bf16
// CHECK-NOT: xsmm.fused_brgemm(data_type = bf16, %[[DISPATCH]] , %{{.*}}, %{{.*}}, %{{.*}}, %[[BIAS]], %{{.*}})

// -----
memref.global "private" constant @__constant_32x32x32xf32_1 : memref<32x32x32xf32> = dense<1.600000e+00> {alignment = 64 : i64}
memref.global "private" constant @__constant_32xf32_1 : memref<32xf32> = dense<1.300000e+00> {alignment = 64 : i64}
memref.global "private" constant @__constant_32x32x32xf32_0 : memref<32x32x32xf32> = dense<1.500000e+00> {alignment = 64 : i64}
memref.global "private" constant @__constant_32xf32_0 : memref<32xf32> = dense<1.200000e+00> {alignment = 64 : i64}
memref.global "private" constant @__constant_32x32x32xf32 : memref<32x32x32xf32> = dense<1.400000e+00> {alignment = 64 : i64}
memref.global "private" constant @__constant_32xf32 : memref<32xf32> = dense<1.100000e+00> {alignment = 64 : i64}

func.func @forward(%arg0: memref<256x1024xf32>) -> memref<256x1024xf32> {
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.get_global @__constant_32xf32 : memref<32xf32>
%1 = memref.get_global @__constant_32x32x32xf32 : memref<32x32x32xf32>
%2 = memref.get_global @__constant_32x32x32xf32_0 : memref<32x32x32xf32>
%3 = memref.get_global @__constant_32xf32_0 : memref<32xf32>
%4 = memref.get_global @__constant_32x32x32xf32_1 : memref<32x32x32xf32>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<256x1024xf32>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xf32>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xf32>
scf.forall (%arg1, %arg2) in (8, 32) {
%subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
%5 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32
xsmm.unary zero(data_type = f32, %5, %cst, %subview) : (i64, f32, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
%subview_3 = memref.subview %alloc_1[%arg1, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
%6 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32
xsmm.brgemm(data_type = f32, %6, %subview_3, %4, %subview, %c32_i64) : (i64, memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> ()
%7 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in0) data_type = f32
xsmm.binary add(data_type = f32, %7, %3, %subview, %subview) : (i64, memref<32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
%8 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32
xsmm.unary relu(data_type = f32, %8, %subview, %subview) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
}
return %alloc : memref<256x1024xf32>
}

// CHECK-LABEL:func.func @forward(
// CHECK: %[[ARG0:.*]]: memref<256x1024xf32>) -> memref<256x1024xf32> {
// CHECK-DAG: %[[c32_i64:.*]] = arith.constant 32 : i64
// CHECK: scf.forall (%[[arg1:.*]], %[[arg2:.*]]) in (8, 32) {
// CHECK: %[[subview:.*]] = memref.subview %{{.*}}[%[[arg1]], %[[arg2]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
// CHECK: %[[subview_2:.*]] = memref.subview %{{.*}}[%[[arg1]], 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
// CHECK: %[[temp2:.*]] = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024][add,relu] flags = (beta_0) binary_flags = (bcast_col_in0) unary_flags = (none) data_type = f32
// CHECK: xsmm.fused_brgemm(data_type = f32, %[[temp2]], %[[subview_2]], %{{.*}}, %[[subview]], %{{.*}} %[[c32_i64]]) : (i64, memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32xf32>, i64) -> ()
// CHECK: }
// CHECK: return %{{.*}} : memref<256x1024xf32>

0 comments on commit 42c6ac0

Please sign in to comment.