Fix VNNI affine maps #998

Merged (26 commits) on Jan 10, 2025
7 changes: 3 additions & 4 deletions include/TPP/Transforms/Utils/VNNIUtils.h
@@ -22,7 +22,7 @@ class AffineMap;
class VectorType;

namespace linalg {
class GenericOp;
class LinalgOp;
} // namespace linalg

namespace vnni {
@@ -48,9 +48,8 @@ bool isInVnniLayout(int64_t expectedRank, VectorType vector);

// Returns true if the operation's indexing maps match the VNNI layout
// pattern (an AffineDimExpr floordiv the VNNI blocking factor), optionally
// checking against a specific `blockingFactor`.
FailureOr<AffineDimExpr> isInVnniLayout(linalg::GenericOp linalgOp,
AffineMap affineMap,
int64_t blockingFactor);
bool isInVnniLayout(linalg::LinalgOp linalgOp,
std::optional<int64_t> blockingFactor = std::nullopt);

} // namespace utils
} // namespace vnni
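For context on the layout this predicate checks: VNNI packs small groups of consecutive reduction-dimension (K) elements into a contiguous innermost dimension so that hardware dot-product instructions can consume a full group per load; the blocking factor is typically 2 for bf16. A minimal standalone sketch of the packing itself (illustration only, not part of this PR; uint16_t stands in for bf16):

#include <cstdint>
#include <vector>

// Pack a row-major K x N matrix into VNNI layout [K/vnni][N][vnni].
// Assumes K is divisible by the blocking factor `vnni`.
std::vector<uint16_t> packVnni(const std::vector<uint16_t> &b, int64_t K,
                               int64_t N, int64_t vnni) {
  std::vector<uint16_t> packed(K * N);
  for (int64_t k = 0; k < K; ++k)
    for (int64_t n = 0; n < N; ++n)
      // Element (k, n) lands at [k / vnni][n][k % vnni].
      packed[(k / vnni) * N * vnni + n * vnni + k % vnni] = b[k * N + n];
  return packed;
}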
49 changes: 36 additions & 13 deletions lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp
@@ -453,6 +453,20 @@ static void replaceOpWithGemmLikeOp(RewriterBase &rewriter,
}
auto flags = rewriter.getArrayAttr(gemmFlags);
SmallVector<Value> invokeOperands;
SmallVector<Value> inputs = {linalgOp->getOperands()};

// Collapse VNNI factor dimension for matrix A.
if (brgemmInfo.isVnni) {
auto rankA = cast<ShapedType>(inputs[0].getType()).getRank();
assert(rankA >= 3 && "Invalid A mat rank for VNNI");
SmallVector<ReassociationIndices> reassoc;
for (int64_t index = 0; index < rankA - 2; index++)
reassoc.push_back({index});
reassoc.push_back(ReassociationIndices{rankA - 2, rankA - 1});

inputs[0] =
rewriter.create<memref::CollapseShapeOp>(loc, inputs[0], reassoc);
}

if (batch != 0) {
DenseI64ArrayAttr dims = DenseI64ArrayAttr::get(
@@ -463,8 +477,7 @@
Value batchDim = rewriter.create<arith::ConstantOp>(
loc, integer64, rewriter.getIntegerAttr(integer64, batch));
invokeOperands.push_back(dispatched);
invokeOperands.append(linalgOp->getOperands().begin(),
linalgOp->getOperands().end());
invokeOperands.append(inputs);
invokeOperands.push_back(batchDim);
rewriter.replaceOpWithNewOp<xsmm::BrgemmOp>(linalgOp, dtype,
invokeOperands);
@@ -474,8 +487,7 @@
Value dispatched = rewriter.create<xsmm::GemmDispatchOp>(
loc, integer64, dims, flags, dtype);
invokeOperands.push_back(dispatched);
invokeOperands.append(linalgOp->getOperands().begin(),
linalgOp->getOperands().end());
invokeOperands.append(inputs);
rewriter.replaceOpWithNewOp<xsmm::GemmOp>(linalgOp, dtype, invokeOperands);
}
}
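The reassociation built above keeps each leading dimension of A in its own group and merges the last two, the split K and the VNNI factor, back into a single K as expected by the xsmm runtime. A hedged standalone restatement of that grouping logic:

#include <cstdint>
#include <vector>

// Mirror of the reassociation loop above: every leading dimension is
// its own group, the trailing two collapse into one.
std::vector<std::vector<int64_t>> vnniCollapseGroups(int64_t rank) {
  std::vector<std::vector<int64_t>> groups;
  for (int64_t i = 0; i < rank - 2; ++i)
    groups.push_back({i});
  groups.push_back({rank - 2, rank - 1});
  return groups; // rank 4 -> {{0}, {1}, {2, 3}}
}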
@@ -502,7 +514,7 @@ checkStructure(linalg::LinalgOp linalgOp) {
return failure();
}
if (contractionDims->m.size() != 1 || contractionDims->n.size() != 1 ||
(contractionDims->k.size() != 2 && contractionDims->k.size() != 1) ||
(contractionDims->k.size() > 3 || contractionDims->k.size() < 1) ||
contractionDims->batch.size() != 0) {
LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions\n");
return failure();
@@ -575,8 +587,10 @@ static FailureOr<BrgemmInfo> checkAccess(linalg::LinalgOp linalgOp, unsigned m,
auto loops = linalgOp.computeStaticLoopSizes();
int64_t batchVal = (batchPos) ? loops[batchPos.value()] : 0;

bool isVnni = vnni::utils::isInVnniLayout(linalgOp);

BrgemmInfo info{loops[m], loops[n], loops[k], batchVal, *lda,
*ldb, *ldc, strideA, strideB};
*ldb, *ldc, strideA, strideB, isVnni};
return info;
}
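The widened aggregate initializer above implies BrgemmInfo gained a trailing VNNI flag. A hedged reconstruction of the struct; only the field order {m, n, k, batch, lda, ldb, ldc, strideA, strideB, isVnni} is visible in this diff, and the exact declaration is an assumption:

#include <cstdint>

// Hedged reconstruction; types and names beyond the initializer order
// are assumptions.
struct BrgemmInfo {
  int64_t m, n, k, batch;
  int64_t lda, ldb, ldc;
  int64_t strideA, strideB;
  bool isVnni = false;
};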

@@ -601,7 +615,8 @@ static FailureOr<BrgemmInfo> isMappableToBrgemm(linalg::LinalgOp linalgOp) {
unsigned n = contractionDims->n[0];
unsigned k = contractionDims->k.back();
std::optional<unsigned> batch;
if (contractionDims->k.size() == 2)
unsigned extraVnniRed = vnni::utils::isInVnniLayout(linalgOp);
if (contractionDims->k.size() == (2 + extraVnniRed))
batch = contractionDims->k.front();

LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] Candidate dims: "
@@ -772,14 +787,18 @@ makeMinorDimensionsInnerMost(RewriterBase &rewriter, linalg::GenericOp linalgOp,
return linalgOp;
}

if (!isInnerMostDim(operandA, *minorKInCodomainOpA)) {
bool isVnni = vnni::utils::isInVnniLayout(linalgOp);

if (!isVnni && !isInnerMostDim(operandA, *minorKInCodomainOpA)) {
LLVM_DEBUG(llvm::dbgs()
<< "[makeMinorDimensionsInnerMost] emit transpose for A\n");
assert(isInnerMostDim(operandA, *minorMInCodomainOpA));
emitTransposeOnOperand(rewriter, linalgOp, operandA, *minorKInCodomainOpA,
*minorMInCodomainOpA);
}
if (!isInnerMostDim(operandB, *minorNInCodomainOpB)) {
// Do not inject transposes in case of VNNI format.
// Otherwise, it breaks later VNNI layout validation.
if (!isVnni && !isInnerMostDim(operandB, *minorNInCodomainOpB)) {
LLVM_DEBUG(llvm::dbgs()
<< "[makeMinorDimensionsInnerMost] emit transpose for B\n");
assert(isInnerMostDim(operandB, *minorKInCodomainOpB));
@@ -804,7 +823,7 @@ void ConvertLinalgToXsmm::runOnOperation() {
unsigned n = contractionDims->n[0];
unsigned k = contractionDims->k.back();
std::optional<unsigned> batch;
if (contractionDims->k.size() == 2)
if (contractionDims->k.size() == 3)
contractionDims->k.front();

if (failed(checkAccess(genericOp, m, n, k, batch))) {
@@ -1085,7 +1104,12 @@ struct ConvertGenericToVnniMatmulLikeOp
int64_t kPos = 1;
if (hasBatch)
kPos++;
int64_t k = cast<ShapedType>(bufferA.getType()).getShape()[kPos];
// Take the whole reduction dim size. Account for the VNNI factor that
// splits the K dim in the shape.
std::optional<int64_t> vnniFactor =
vnni::utils::getVnniBlockingFactor(bufferB.getType());
int64_t k =
cast<ShapedType>(bufferA.getType()).getShape()[kPos] * *vnniFactor;
int64_t batch = 0;
if (hasBatch)
batch = cast<ShapedType>(bufferA.getType()).getShape()[0];
@@ -1107,8 +1131,7 @@ struct ConvertGenericToVnniMatmulLikeOp
if (hasBatch)
leadingDimPosOnAandB++;
int64_t lda = (*stridesOnLhs)[leadingDimPosOnAandB];
int64_t ldb = (*stridesOnRhs)[leadingDimPosOnAandB] /
*vnni::utils::getVnniBlockingFactor(bufferB.getType());
int64_t ldb = (*stridesOnRhs)[leadingDimPosOnAandB] / *vnniFactor;
int64_t ldc = (*stridesOnOutput)[0];

BrgemmInfo brgemmInfo{m, n, k, batch, lda,
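To make the size arithmetic in ConvertGenericToVnniMatmulLikeOp concrete, a hedged worked example with assumed sizes: for bf16 (factor 2), A's shape entry at kPos holds the already-split K, so the full reduction size is recovered by multiplying with the VNNI factor, while B's leading stride spans packed rows and is divided back down:

#include <cassert>
#include <cstdint>

int main() {
  // All concrete sizes below are assumptions for illustration.
  int64_t vnniFactor = 2;                // bf16 blocking factor
  int64_t shapeAAtKPos = 32;             // packed K dimension of buffer A
  int64_t k = shapeAAtKPos * vnniFactor; // full reduction size: 64
  int64_t strideB = 128;                 // leading stride of buffer B
  int64_t ldb = strideB / vnniFactor;    // leading dim handed to xsmm: 64
  assert(k == 64 && ldb == 64);
  return 0;
}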
16 changes: 5 additions & 11 deletions lib/TPP/IR/MatcherUtils.cpp
@@ -56,7 +56,7 @@ std::pair<bool, bool> isBrgemmVnniOp(linalg::GenericOp linalgOp,
.operation(NumOfLoops(_OR(EqualsTo(5), EqualsTo(4))))
.input(MatchAll(), HasStaticShape())
.output(MatchAll(), HasStaticShape())
.input(MatchOne(0), HasMap(BroadcastableProjectedPermutation(), &mapOperandA))
.input(MatchOne(0), HasMap(Any(), &mapOperandA))
.input(MatchOne(1), HasMap(Any(), &mapOperandB))
.output(MatchOne(0), HasMap(BroadcastableProjectedPermutation(), &mapOperandC))
.region(MatchOne(0),
@@ -82,17 +82,13 @@ std::pair<bool, bool> isBrgemmVnniOp(linalg::GenericOp linalgOp,

llvm::SmallVector<int64_t> operandAPosIterRed = getIteratorPos(
linalgOp, mapOperandA, mlir::utils::IteratorType::reduction);
if (operandAPosIterRed.size() != 2 && operandAPosIterRed.size() != 1)
if (operandAPosIterRed.size() != 3 && operandAPosIterRed.size() != 2)
return std::make_pair(false, hasBatch);

int64_t batchRedIter = std::numeric_limits<int64_t>::max();
int64_t kRedIter = std::numeric_limits<int64_t>::max();
if (operandAPosIterRed.size() == 2) {
if (operandAPosIterRed.size() == 3) {
batchRedIter = operandAPosIterRed[0];
kRedIter = operandAPosIterRed[1];
hasBatch = true;
} else {
kRedIter = operandAPosIterRed[0];
}

// Operand B: One parallel iterator (j) and three reduction ones (batch,
@@ -112,10 +108,8 @@ std::pair<bool, bool> isBrgemmVnniOp(linalg::GenericOp linalgOp,
return std::make_pair(false, hasBatch);
}

auto vnniDim =
vnni::utils::isInVnniLayout(linalgOp, mapOperandB, *blockingFactor);
bool isBrgemmOp = succeeded(vnniDim) && vnniDim->getPosition() == kRedIter;
return std::make_pair(isBrgemmOp, hasBatch);
bool isBrgemmVnni = vnni::utils::isInVnniLayout(linalgOp, *blockingFactor);
return std::make_pair(isBrgemmVnni, hasBatch);
}

// Return true if all the operand have the same type, i.e., no implicit
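A hedged sketch of how a caller might consume the matcher's pair result; the helper below and isBrgemmVnniOp's second parameter are assumptions, not shown in this diff:

#include "mlir/Dialect/Linalg/IR/Linalg.h"

// Hypothetical caller; only the pair semantics come from the code above.
static bool lowerableAsBrgemmVnni(mlir::linalg::GenericOp genericOp) {
  auto [isBrgemmVnni, hasBatch] =
      isBrgemmVnniOp(genericOp, /*operands=*/nullptr);
  // hasBatch distinguishes a batch-reduce VNNI GEMM from a plain VNNI GEMM.
  return isBrgemmVnni && hasBatch;
}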
71 changes: 40 additions & 31 deletions lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
@@ -81,34 +81,33 @@ static Value toUnPackLayoutImpl(OpBuilder &builder, Location loc, Value input,
}

static Value handleLayout_VNNI(OpBuilder &builder, Location loc, Value input,
ArrayRef<OpFoldResult> tiles) {
ArrayRef<OpFoldResult> tiles, int64_t kDimPos) {
assert(tiles.size() == 1 && "expect 1 block for VNNI");
SmallVector<int64_t> innerDimPos = {
cast<ShapedType>(input.getType()).getRank() - 2};
return toPackLayoutImpl(builder, loc, input, tiles, innerDimPos,
return toPackLayoutImpl(builder, loc, input, tiles,
SmallVector<int64_t>{kDimPos},
/*outerDimsPerm=*/{});
}

static Value handleBRGemmLayout_VNNI(OpBuilder &builder, Location loc,
Value input,
ArrayRef<OpFoldResult> tiles) {
Value input, ArrayRef<OpFoldResult> tiles,
int64_t kDimPos) {
assert(tiles.size() == 1 && "expect 1 block for VNNI");
SmallVector<int64_t> innerDimPos = {1};
return toPackLayoutImpl(builder, loc, input, tiles, innerDimPos,
return toPackLayoutImpl(builder, loc, input, tiles,
SmallVector<int64_t>{kDimPos},
/*outerDimsPerm=*/{});
}

// Helper function to pack from NC to [N/2][C][2].
// Helper function to pack from [outer][K][inner] to [outer][K/2][inner][2].
static Value toPackLayout_VNNI(OpBuilder &builder, Location loc, Value input,
ArrayRef<OpFoldResult> tiles) {
return handleLayout_VNNI(builder, loc, input, tiles);
ArrayRef<OpFoldResult> tiles, int64_t kDimPos) {
return handleLayout_VNNI(builder, loc, input, tiles, kDimPos);
}

// Helper function to pack from [N][K][C] to [N][K/2][C][2].
// Helper function to pack from [outer][K][inner] to [outer][K/2][inner][2].
static Value toPackBRGemmLayout_VNNI(OpBuilder &builder, Location loc,
Value input,
ArrayRef<OpFoldResult> tiles) {
return handleBRGemmLayout_VNNI(builder, loc, input, tiles);
Value input, ArrayRef<OpFoldResult> tiles,
int64_t kDimPos) {
return handleBRGemmLayout_VNNI(builder, loc, input, tiles, kDimPos);
}
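A hedged standalone sketch of the shape change these helpers request via the pack operation, assuming the K size divides evenly by the blocking factor: dimension kDimPos is tiled by the factor and the tile is appended innermost, e.g. {4, 32, 16} with kDimPos = 1 and factor 2 becomes {4, 16, 16, 2}:

#include <cstdint>
#include <vector>

// Compute the VNNI-packed shape: [outer][K][inner] -> [outer][K/f][inner][f].
std::vector<int64_t> vnniPackedShape(std::vector<int64_t> shape,
                                     int64_t kDimPos, int64_t factor) {
  shape[kDimPos] /= factor; // assumes an even split
  shape.push_back(factor);
  return shape;
}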

static Value handleLayoutNCHW_NCHWc(OpBuilder &builder, Location loc,
@@ -328,7 +327,8 @@ mlir::linalgx::packVNNIMatmulOp(RewriterBase &rewriter,
if (matmulOp.hasPureBufferSemantics())
return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics");

if (failed(linalgx::utils::isContraction(matmulOp)))
auto dims = linalgx::utils::isContraction(matmulOp);
if (failed(dims))
return rewriter.notifyMatchFailure(matmulOp, "require matmul semantics");

OpOperand &operandB = matmulOp->getOpOperand(1);
@@ -339,30 +339,36 @@
"unsupported blocking factor for type");
}

AffineMap mapOperandB = matmulOp.getMatchingIndexingMap(&operandB);
if (succeeded(vnni::utils::isInVnniLayout(matmulOp, mapOperandB,
*blockingFactor))) {
if (vnni::utils::isInVnniLayout(matmulOp, *blockingFactor)) {
return rewriter.notifyMatchFailure(matmulOp, "already packed to VNNI");
}

Location loc = matmulOp.getLoc();
SmallVector<OpFoldResult> tilesOnSmallK = {
rewriter.getI64IntegerAttr(*blockingFactor)};
// reshape input B.
Value packedMatrixB =
toPackLayout_VNNI(rewriter, loc, operandB.get(), tilesOnSmallK);
SmallVector<std::pair<Value, unsigned>> kOperands;
matmulOp.mapIterationSpaceDimToAllOperandDims(dims->k.back(), kOperands);
if (kOperands.size() != 2)
return rewriter.notifyMatchFailure(matmulOp,
"Invalid reduction dim operands");
// Reshape input A.
Value packedMatrixA =
toPackLayout_VNNI(rewriter, loc, matmulOp.getInputs()[0], tilesOnSmallK,
kOperands[0].second);
// Reshape input B.
Value packedMatrixB = toPackLayout_VNNI(rewriter, loc, operandB.get(),
tilesOnSmallK, kOperands[1].second);

MLIRContext *ctx = matmulOp.getContext();
AffineExpr p1, p2, r1, p3, p4, r2, r3;
SmallVector<Value> packedInputs = {matmulOp.getInputs()[0], packedMatrixB};
SmallVector<Value> packedInputs = {packedMatrixA, packedMatrixB};
AffineMap mapA, mapB, mapC;
Value matrixC = matmulOp.getOutputs()[0];

// IB JB KB ib jb kb VNNI
bindDims(ctx, p1, p2, r1, p3, p4, r2, r3);
mapA = AffineMap::get(/*dims=*/7, /*symbols=*/0, {p1, r1, p3, r2}, ctx);
mapB = AffineMap::get(/*dims=*/7, /*symbols=*/0,
{p2, r1, r2.floorDiv(*blockingFactor), p4, r3}, ctx);
mapA = AffineMap::get(/*dims=*/7, /*symbols=*/0, {p1, r1, p3, r2, r3}, ctx);
mapB = AffineMap::get(/*dims=*/7, /*symbols=*/0, {p2, r1, r2, p4, r3}, ctx);
mapC = AffineMap::get(/*dims=*/7, /*symbols=*/0, {p1, p2, p3, p4}, ctx);
auto replacementOp = rewriter.create<linalg::GenericOp>(
loc, matrixC.getType(), packedInputs, ValueRange{matrixC},
@@ -411,22 +417,24 @@ mlir::linalgx::packVNNIBRGemmOp(RewriterBase &rewriter,
SmallVector<OpFoldResult> tilesOnK = {rewriter.getI64IntegerAttr(2)};

Location loc = brgemmOp.getLoc();
// Reshape input A.
Value packedMatrixA = toPackBRGemmLayout_VNNI(
rewriter, loc, brgemmOp.getInputs()[0], tilesOnK, 2);
// Reshape input B.
Value packedMatrixB =
toPackBRGemmLayout_VNNI(rewriter, loc, operandB, tilesOnK);
toPackBRGemmLayout_VNNI(rewriter, loc, operandB, tilesOnK, 1);

MLIRContext *ctx = brgemmOp.getContext();
AffineExpr r1, p1, p2, r3, r4;
AffineMap mapA, mapB, mapC;
bindDims(ctx, r1, p1, p2, r3, r4);
mapA = AffineMap::get(/*dims=*/5, /*symbols=*/0, {r1, p1, r3}, ctx);
mapB = AffineMap::get(/*dims=*/5, /*symbols=*/0,
{r1, r3.floorDiv(*blockingFactor), p2, r4}, ctx);
mapA = AffineMap::get(/*dims=*/5, /*symbols=*/0, {r1, p1, r3, r4}, ctx);
mapB = AffineMap::get(/*dims=*/5, /*symbols=*/0, {r1, r3, p2, r4}, ctx);
mapC = AffineMap::get(/*dims=*/5, /*symbols=*/0, {p1, p2}, ctx);

auto replacementOp = rewriter.create<linalg::GenericOp>(
loc, brgemmOp.getOutputs()[0].getType(),
ValueRange{brgemmOp.getInputs()[0], packedMatrixB},
ValueRange{packedMatrixA, packedMatrixB},
ValueRange{brgemmOp.getOutputs()[0]},
ArrayRef<AffineMap>{mapA, mapB, mapC},
ArrayRef<mlir::utils::IteratorType>{
@@ -664,6 +672,7 @@ struct PackVNNI : public tpp::impl::PackVNNIBase<PackVNNI> {
RewritePatternSet patterns(ctx);
linalg::populateLinalgDeGeneralizationPatterns(patterns);
patterns.add<VNNIOnMatmul, VNNIOnBRGemm>(ctx);
tensor::populateSimplifyPackAndUnpackPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
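Summarizing the core of the fix as seen in the matmul and BRGEMM map hunks above: previously only B was packed and its map reconstructed the packed K index with r2.floorDiv(blockingFactor) against an unpacked A; now A is packed as well, and both maps address the split K and the VNNI dimension as plain AffineDimExprs. A hedged restatement of the corrected matmul maps in isolation:

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Dims: IB, JB, KB, ib, jb, kb, vnni (as in the hunk above).
void buildVnniMatmulMaps(MLIRContext *ctx) {
  AffineExpr p1, p2, r1, p3, p4, r2, r3;
  bindDims(ctx, p1, p2, r1, p3, p4, r2, r3);
  // Both operands now index the VNNI dim r3 directly; no floordiv needed.
  AffineMap mapA =
      AffineMap::get(/*dims=*/7, /*symbols=*/0, {p1, r1, p3, r2, r3}, ctx);
  AffineMap mapB =
      AffineMap::get(/*dims=*/7, /*symbols=*/0, {p2, r1, r2, p4, r3}, ctx);
  AffineMap mapC =
      AffineMap::get(/*dims=*/7, /*symbols=*/0, {p1, p2, p3, p4}, ctx);
  (void)mapA;
  (void)mapB;
  (void)mapC;
}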