Skip to content

Commit

Permalink
Tile config hoisting pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Kavitha Madhu committed Feb 27, 2024
1 parent 547e656 commit 951f7e5
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 11 deletions.
12 changes: 11 additions & 1 deletion include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,17 @@ def TileConfigInsertionPass : Pass<"tile-config-insertion-pass",
"func::FuncOp"> {
let summary = "Insert tile configuration xsmm calls";
let description = [{
Insert tile configuration xsmm calls and perform LICM on them.
Insert tile configuration xsmm calls.
}];

let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
}

def TileConfigHoistingPass : Pass<"tile-config-hoisting-pass",
"func::FuncOp"> {
let summary = "Hoist tile configuration invoke xsmm calls";
let description = [{
Run LICM on Tile configuration invoke calls.
}];

let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_mlir_library(TPPTransforms
CombineXsmmPass.cpp
SCFParallelLoopTiling.cpp
TileConfig.cpp
TileConfigHoisting.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
Expand Down
19 changes: 9 additions & 10 deletions lib/TPP/Transforms/TileConfig.cpp
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
//===- TileConfig.cpp -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop tiling on parallel loops.
// This file inserts tile configuration calls.
//
//===----------------------------------------------------------------------===//
#include "TPP/Dialect/Xsmm/XsmmOps.h"
#include "TPP/Dialect/Xsmm/XsmmUtils.h"
#include "TPP/Transforms/Utils/VNNIUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <list>
namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_TILECONFIGINSERTIONPASS
Expand Down Expand Up @@ -109,11 +108,11 @@ struct TileConfig : OpRewritePattern<InvokeOpTy> {

auto alloca = rewriter.create<memref::AllocaOp>(
op.getLoc(), MemRefType::get({64}, rewriter.getI8Type()));

ValueRange tileConfigInputs{alloca};
rewriter.create<mlir::xsmm::TileConfigOp>(
op.getLoc(), tileConfigSetup, tileConfigInputs);
rewriter.create<mlir::xsmm::TileConfigOp>(op.getLoc(), tileConfigSetup,
tileConfigInputs);

SmallVector<Value> invokeOperands;
invokeOperands.push_back(dispatch);
auto opItr = op->getOperands().begin();
Expand All @@ -125,10 +124,10 @@ struct TileConfig : OpRewritePattern<InvokeOpTy> {
invokeOperands);

ValueRange tileResetInputs{alloca};
rewriter.create<mlir::xsmm::TileConfigOp>(
op.getLoc(), tileConfigReset, tileResetInputs);
rewriter.create<mlir::xsmm::TileConfigOp>(op.getLoc(), tileConfigReset,
tileResetInputs);

//rewriter.create<memref::DeallocOp>(op.getLoc(), alloca);
// rewriter.create<memref::DeallocOp>(op.getLoc(), alloca);
rewriter.eraseOp(op);
rewriter.eraseOp(op.getOperand(0).getDefiningOp());
return success();
Expand Down
101 changes: 101 additions & 0 deletions lib/TPP/Transforms/TileConfigHoisting.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//===- TileConfigHoisting.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements tile configuration hoisting on parallel loops.
//
//===----------------------------------------------------------------------===//
#include "TPP/Dialect/Xsmm/XsmmOps.h"
#include "TPP/Dialect/Xsmm/XsmmUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_TILECONFIGHOISTINGPASS
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

using namespace mlir;
using namespace mlir::xsmm;

namespace mlir {
namespace tpp {

/// Rewrite pattern that hoists a tile-configuration pair out of a loop nest
/// and up into the body of the closest enclosing `scf.parallel`.
///
/// The pattern is rooted on a `memref.alloca` whose users are all
/// `xsmm::TileConfigOp`s. Among those users it picks:
///   * the config whose dispatch carries `NO_RESET_TILECONFIG`
///     (`firstTileConfig`), and
///   * the config whose dispatch carries `NO_SETUP_TILECONFIG`
///     (`secondTileConfig`).
/// The alloca and `firstTileConfig` are moved to the start of the parallel
/// body; `secondTileConfig` is moved right before the body's last operation.
///
/// The pattern fails (without touching the IR) when:
///   * any user of the alloca is not a `TileConfigOp`,
///   * a config's first operand is not produced by a `TileConfigDispatchOp`,
///   * either of the two flagged configs is missing,
///   * the alloca has no enclosing `scf.parallel`, or
///   * the alloca already lives directly in that `scf.parallel`'s region
///     (this is also what makes the greedy driver converge).
struct TileConfigHoisting : OpRewritePattern<memref::AllocaOp> {
  using OpRewritePattern<memref::AllocaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::AllocaOp alloca,
                                PatternRewriter &rewriter) const override {
    // Build the two flag attributes once instead of on every comparison.
    auto noResetFlag = xsmm::GemmFlagsAttr::get(
        rewriter.getContext(), mlir::xsmm::GemmFlags::NO_RESET_TILECONFIG);
    auto noSetupFlag = xsmm::GemmFlagsAttr::get(
        rewriter.getContext(), mlir::xsmm::GemmFlags::NO_SETUP_TILECONFIG);

    xsmm::TileConfigOp firstTileConfig, secondTileConfig;
    for (auto *user : alloca->getUsers()) {
      auto tileConfig = dyn_cast<xsmm::TileConfigOp>(user);
      if (!tileConfig)
        return failure();

      // Operand 0 may be a block argument or come from an unrelated op;
      // `getDefiningOp<T>()` yields a null handle in both cases, so bail out
      // instead of dereferencing it.
      auto dispatch = tileConfig->getOperand(0)
                          .getDefiningOp<xsmm::TileConfigDispatchOp>();
      if (!dispatch)
        return failure();

      for (auto flagItr : dispatch.getFlags()) {
        if (flagItr == noResetFlag)
          firstTileConfig = tileConfig;
        else if (flagItr == noSetupFlag)
          secondTileConfig = tileConfig;
      }
    }

    // Both halves of the pair are required; moving a null op would crash.
    if (!firstTileConfig || !secondTileConfig)
      return failure();

    // `getParentOfType` already searches every ancestor, so a single query
    // is enough to find the closest enclosing scf.parallel (if any).
    auto parallelOpParent = alloca->getParentOfType<scf::ParallelOp>();
    if (!parallelOpParent)
      return failure();

    // Already sitting directly in the parallel body: nothing left to hoist.
    if (&parallelOpParent.getRegion() == alloca->getParentRegion())
      return failure();

    // NOTE(review): the dispatch ops feeding the two configs are not moved
    // here; this presumably relies on them already dominating the start of
    // the parallel body -- confirm against the pass pipeline.
    rewriter.moveOpBefore(alloca, parallelOpParent.getBody(),
                          parallelOpParent.getBody()->begin());
    rewriter.moveOpAfter(firstTileConfig, alloca);
    // Insert in front of the block's last operation (its terminator).
    rewriter.moveOpBefore(secondTileConfig, parallelOpParent.getBody(),
                          std::prev(parallelOpParent.getBody()->end(), 1));
    return success();
  }
};

/// Function-level pass driver: greedily applies the `TileConfigHoisting`
/// pattern to the operation this pass is scheduled on.
struct TileConfigHoistingPass
    : public impl::TileConfigHoistingPassBase<TileConfigHoistingPass> {
  /// Registers the hoisting pattern into `patterns`, using the context the
  /// pattern set was created with.
  void populateCombinePatterns(RewritePatternSet &patterns) {
    auto *context = patterns.getContext();
    patterns.add<TileConfigHoisting>(context);
  }

  /// Collects the hoisting pattern and runs the greedy rewrite driver.
  void runOnOperation() override {
    RewritePatternSet hoistingPatterns(&getContext());
    populateCombinePatterns(hoistingPatterns);
    // The driver's LogicalResult is intentionally discarded, so a
    // non-converging rewrite does not fail the pass.
    (void)applyPatternsAndFoldGreedily(getOperation(),
                                       std::move(hoistingPatterns));
  }
};
} // namespace tpp
} // namespace mlir

0 comments on commit 951f7e5

Please sign in to comment.