diff --git a/include/TPP/Dialect/Xsmm/XsmmUtils.h b/include/TPP/Dialect/Xsmm/XsmmUtils.h index e935191c5..4a8ec2754 100644 --- a/include/TPP/Dialect/Xsmm/XsmmUtils.h +++ b/include/TPP/Dialect/Xsmm/XsmmUtils.h @@ -42,9 +42,10 @@ struct BinaryInfo { /// Represents a chain of XSMM ops that can be fused. All broadcast ops /// should have already been converted to flags. All stray allocations -/// should have already been converted to in-place reuse. Init zero -/// should have already been converted to Beta=0. +/// should have already been converted to in-place reuse. struct FusedMatch { + // This is the (optional) zero op that precedes the GEMM op + UnaryOp zeroOp; // This is the BRGEMM op BrgemmOp brgemmOp; // This is the (optional) binary op that follows the GEMM diff --git a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp index 6ac62a5fd..84d17f195 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp @@ -289,13 +289,11 @@ FailureOr getBinaryFlags(Type operandType, Type outputType, FailureOr getFusedBrgemmSequenceFromProducer(Operation *op) { // The loop is in reverse order, so we deduplicate the list making sure we // only have one type of each - SmallVector chain; + SmallVector chain; Operation *prev = nullptr; for (auto *user : op->getUsers()) { // Deduplicate, only take each operation once - if ((dyn_cast(user) && - dyn_cast(user).getCallee() == UnaryKind::ZERO) || - dyn_cast(user) || user == prev) + if (dyn_cast(user) || user == prev) continue; chain.push_back(user); prev = user; @@ -313,28 +311,44 @@ FailureOr getFusedBrgemmSequenceFromProducer(Operation *op) { int numUses = std::count(user->getOperands().begin(), user->getOperands().end(), op->getResult(0)); // At least one input and the last operand (output) is the same buffer - if (numUses < 2 || + if (((dyn_cast(user) && + dyn_cast(user).getCallee() != UnaryKind::ZERO) && + numUses < 2) || user->getOperands()[user->getOperands().size() - 1] != 
op->getResult(0)) return failure(); } - // We don't know how to fuse more than two tail ops after the BRGEMM - if (chain.size() > 3) + // We don't know how to fuse more than one zero op before and two tail ops + // after the BRGEMM + if (chain.size() > 4) return failure(); - if (!isa(chain[0])) - // List is in reverse order, put the brgemm at the top + if (!(isa(chain[0]) || + (dyn_cast(chain[0]) && + dyn_cast(chain[0]).getCallee() == UnaryKind::ZERO))) + // List is in reverse order, put the brgemm or zero at the top std::reverse(chain.begin(), chain.end()); - // If we haven't found a BRGEMM, this are not the droids we're looking for - assert(isa(chain[0]) && "First op must be brgemm"); + // If we haven't found a BRGEMM or zero, these are not the droids we're looking + // for + assert(isa(chain[0]) || + (dyn_cast(chain[0]) && + dyn_cast(chain[0]).getCallee() == UnaryKind::ZERO && + isa(chain[1])) && + "First op must be brgemm or zero"); // Now, we're sure we have a chain, but not yet if it has the right types - // and in the right order: BRGEMM -> BINARY -> UNARY + // and in the right order: (ZERO) -> BRGEMM -> BINARY -> UNARY // Allowed patterns are: - // - GEMM + BINARY - // - GEMM + UNARY - // - GEMM + BINARY + UNARY + // - (ZERO) + GEMM + BINARY + // - (ZERO) + GEMM + UNARY + // - (ZERO) + GEMM + BINARY + UNARY xsmm::FusedMatch fusedMatch; for (auto *user : chain) { + if (auto unaryOp = dyn_cast(user)) { + if (dyn_cast(user).getCallee() == UnaryKind::ZERO) { + fusedMatch.zeroOp = unaryOp; + continue; + } + } if (auto brgemmOp = (dyn_cast(user))) { // We only accept one of each if (fusedMatch.brgemmOp) diff --git a/lib/TPP/Transforms/CombineXsmmPass.cpp b/lib/TPP/Transforms/CombineXsmmPass.cpp index 327b40970..b8dbffe02 100644 --- a/lib/TPP/Transforms/CombineXsmmPass.cpp +++ b/lib/TPP/Transforms/CombineXsmmPass.cpp @@ -90,13 +90,22 @@ struct CombineXsmmOp : public OpRewritePattern { true); if (failed(brgemmFlags)) return failure(); + auto attributes = *brgemmFlags; + if 
(fusedMatch.zeroOp) { + if (attributes[0] == xsmm::GemmFlagsAttr::get(rewriter.getContext(), + xsmm::GemmFlags::NONE)) { + attributes.clear(); + } + attributes.push_back(xsmm::GemmFlagsAttr::get(rewriter.getContext(), + xsmm::GemmFlags::BETA_0)); + } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(fusedMatch.binaryOp); Value dispatched = rewriter.create( loc, integer64, dims, xsmm::BinaryKindAttr::get(rewriter.getContext(), fusedMatch.binaryKind), xsmm::UnaryKindAttr::get(rewriter.getContext(), fusedMatch.unaryKind), - rewriter.getArrayAttr(*brgemmFlags), + rewriter.getArrayAttr(attributes), rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get( rewriter.getContext(), xsmm::UnaryFlags::NONE)), rewriter.getArrayAttr( @@ -119,11 +128,18 @@ struct CombineXsmmOp : public OpRewritePattern { rewriter.create(loc, dtype, invokeOperands); rewriter.eraseOp(brgemmOp); rewriter.eraseOp(brgemmOp.getOperand(0).getDefiningOp()); - rewriter.eraseOp(fusedMatch.binaryOp); - rewriter.eraseOp(fusedMatch.binaryOp->getOperand(0).getDefiningOp()); - rewriter.eraseOp(fusedMatch.unaryOp); - rewriter.eraseOp(fusedMatch.unaryOp->getOperand(0).getDefiningOp()); - + if (fusedMatch.binaryOp) { + rewriter.eraseOp(fusedMatch.binaryOp); + rewriter.eraseOp(fusedMatch.binaryOp->getOperand(0).getDefiningOp()); + } + if (fusedMatch.unaryOp) { + rewriter.eraseOp(fusedMatch.unaryOp); + rewriter.eraseOp(fusedMatch.unaryOp->getOperand(0).getDefiningOp()); + } + if (fusedMatch.zeroOp) { + rewriter.eraseOp(fusedMatch.zeroOp); + rewriter.eraseOp(fusedMatch.zeroOp->getOperand(0).getDefiningOp()); + } return success(); } }; diff --git a/test/Passes/xsmm-combine.mlir b/test/Passes/xsmm-combine.mlir index 812c30757..8edf7a2bc 100644 --- a/test/Passes/xsmm-combine.mlir +++ b/test/Passes/xsmm-combine.mlir @@ -261,4 +261,48 @@ func.func @none_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x51 // CHECK-NOT: %[[DISPATCH:.*]] = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 
1024, 1024][add,relu] flags = (vnni_b, beta_0) binary_flags = (none) unary_flags = (none) data_type = bf16 // CHECK-NOT: xsmm.fused_brgemm(data_type = bf16, %[[DISPATCH]] , %{{.*}}, %{{.*}}, %{{.*}}, %[[BIAS]], %{{.*}}) +// ----- + memref.global "private" constant @__constant_32x32x32xf32_1 : memref<32x32x32xf32> = dense<1.600000e+00> {alignment = 64 : i64} + memref.global "private" constant @__constant_32xf32_1 : memref<32xf32> = dense<1.300000e+00> {alignment = 64 : i64} + memref.global "private" constant @__constant_32x32x32xf32_0 : memref<32x32x32xf32> = dense<1.500000e+00> {alignment = 64 : i64} + memref.global "private" constant @__constant_32xf32_0 : memref<32xf32> = dense<1.200000e+00> {alignment = 64 : i64} + memref.global "private" constant @__constant_32x32x32xf32 : memref<32x32x32xf32> = dense<1.400000e+00> {alignment = 64 : i64} + memref.global "private" constant @__constant_32xf32 : memref<32xf32> = dense<1.100000e+00> {alignment = 64 : i64} + +func.func @forward(%arg0: memref<256x1024xf32>) -> memref<256x1024xf32> { + %c32_i64 = arith.constant 32 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %0 = memref.get_global @__constant_32xf32 : memref<32xf32> + %1 = memref.get_global @__constant_32x32x32xf32 : memref<32x32x32xf32> + %2 = memref.get_global @__constant_32x32x32xf32_0 : memref<32x32x32xf32> + %3 = memref.get_global @__constant_32xf32_0 : memref<32xf32> + %4 = memref.get_global @__constant_32x32x32xf32_1 : memref<32x32x32xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<256x1024xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xf32> + scf.forall (%arg1, %arg2) in (8, 32) { + %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + %5 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32 + xsmm.unary 
zero(data_type = f32, %5, %cst, %subview) : (i64, f32, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> () + %subview_3 = memref.subview %alloc_1[%arg1, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %6 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32 + xsmm.brgemm(data_type = f32, %6, %subview_3, %4, %subview, %c32_i64) : (i64, memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> () + %7 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in0) data_type = f32 + xsmm.binary add(data_type = f32, %7, %3, %subview, %subview) : (i64, memref<32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> () + %8 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32 + xsmm.unary relu(data_type = f32, %8, %subview, %subview) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> () + } + return %alloc : memref<256x1024xf32> +} + +// CHECK-LABEL:func.func @forward( +// CHECK: %[[ARG0:.*]]: memref<256x1024xf32>) -> memref<256x1024xf32> { +// CHECK-DAG: %[[c32_i64:.*]] = arith.constant 32 : i64 +// CHECK: scf.forall (%[[arg1:.*]], %[[arg2:.*]]) in (8, 32) { +// CHECK: %[[subview:.*]] = memref.subview %{{.*}}[%[[arg1]], %[[arg2]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> +// CHECK: %[[subview_2:.*]] = memref.subview %{{.*}}[%[[arg1]], 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[temp2:.*]] = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024][add,relu] flags = (beta_0) binary_flags = (bcast_col_in0) unary_flags = (none) data_type = f32 +// CHECK: xsmm.fused_brgemm(data_type = 
f32, %[[temp2]], %[[subview_2]], %{{.*}}, %[[subview]], %{{.*}} %[[c32_i64]]) : (i64, memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32xf32>, i64) -> () +// CHECK: } +// CHECK: return %{{.*}} : memref<256x1024xf32>