Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix VNNI affine maps #998

Merged
merged 26 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions include/TPP/IR/MatcherUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,11 @@ bool isTwoDFillOpWithZeros(linalg::LinalgOp linalgOp,
SmallVectorImpl<Value> *capturedOperands = nullptr);

// Return a pair where the first member is true if and only if the operation
// represents a brgemm in VNNI layout. The second member tells if the brgemm has
// the batch dimension; it has meaning only if the first field is valid.
// represents a matmul (GEMM or BRGEMM) in VNNI layout. The second member tells
// if the brgemm has the batch dimension; it has meaning only if the first field
// is valid.
std::pair<bool, bool>
isBrgemmVnniOp(linalg::GenericOp linalgOp,
isMatmulVnniOp(linalg::GenericOp linalgOp,
SmallVectorImpl<Value> *capturedOperands = nullptr);

} // namespace utils
Expand Down
11 changes: 5 additions & 6 deletions include/TPP/Transforms/Utils/VNNIUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class AffineMap;
class VectorType;

namespace linalg {
class GenericOp;
class LinalgOp;
} // namespace linalg

namespace vnni {
Expand All @@ -46,11 +46,10 @@ bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector);

bool isInVnniLayout(int64_t expectedRank, VectorType vector);

// Return the first AffineDimExpr in the map `affineMap`
// with a VNNI layout pattern (AffineDimExpr floordiv VNNI).
FailureOr<AffineDimExpr> isInVnniLayout(linalg::GenericOp linalgOp,
AffineMap affineMap,
int64_t blockingFactor);
// Return true if the operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(linalg::LinalgOp linalgOp,
std::optional<int64_t> blockingFactor = std::nullopt);

} // namespace utils
} // namespace vnni
Expand Down
81 changes: 56 additions & 25 deletions lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,21 @@ static void replaceOpWithGemmLikeOp(RewriterBase &rewriter,
}
auto flags = rewriter.getArrayAttr(gemmFlags);
SmallVector<Value> invokeOperands;
SmallVector<Value> inputs = {linalgOp->getOperands()};

// Collapse VNNI factor dimension for matrix A:
// A <32x8x2> -> A <32x16>
if (brgemmInfo.isVnni) {
auto rankA = cast<ShapedType>(inputs[0].getType()).getRank();
assert(rankA >= 3 && "Invalid A mat rank for VNNI");
SmallVector<ReassociationIndices> reassoc;
for (int64_t index = 0; index < rankA - 2; index++)
reassoc.push_back({index});
reassoc.push_back(ReassociationIndices{rankA - 2, rankA - 1});

inputs[0] =
rewriter.create<memref::CollapseShapeOp>(loc, inputs[0], reassoc);
}

if (batch != 0) {
DenseI64ArrayAttr dims = DenseI64ArrayAttr::get(
Expand All @@ -463,8 +478,7 @@ static void replaceOpWithGemmLikeOp(RewriterBase &rewriter,
Value batchDim = rewriter.create<arith::ConstantOp>(
loc, integer64, rewriter.getIntegerAttr(integer64, batch));
invokeOperands.push_back(dispatched);
invokeOperands.append(linalgOp->getOperands().begin(),
linalgOp->getOperands().end());
invokeOperands.append(inputs);
invokeOperands.push_back(batchDim);
rewriter.replaceOpWithNewOp<xsmm::BrgemmOp>(linalgOp, dtype,
invokeOperands);
Expand All @@ -474,8 +488,7 @@ static void replaceOpWithGemmLikeOp(RewriterBase &rewriter,
Value dispatched = rewriter.create<xsmm::GemmDispatchOp>(
loc, integer64, dims, flags, dtype);
invokeOperands.push_back(dispatched);
invokeOperands.append(linalgOp->getOperands().begin(),
linalgOp->getOperands().end());
invokeOperands.append(inputs);
rewriter.replaceOpWithNewOp<xsmm::GemmOp>(linalgOp, dtype, invokeOperands);
}
}
Expand All @@ -502,7 +515,7 @@ checkStructure(linalg::LinalgOp linalgOp) {
return failure();
}
if (contractionDims->m.size() != 1 || contractionDims->n.size() != 1 ||
(contractionDims->k.size() != 2 && contractionDims->k.size() != 1) ||
contractionDims->k.size() > 3 || contractionDims->k.size() < 1 ||
contractionDims->batch.size() != 0) {
LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions\n");
return failure();
Expand Down Expand Up @@ -575,14 +588,16 @@ static FailureOr<BrgemmInfo> checkAccess(linalg::LinalgOp linalgOp, unsigned m,
auto loops = linalgOp.computeStaticLoopSizes();
int64_t batchVal = (batchPos) ? loops[batchPos.value()] : 0;

bool isVnni = vnni::utils::isInVnniLayout(linalgOp);

BrgemmInfo info{loops[m], loops[n], loops[k], batchVal, *lda,
*ldb, *ldc, strideA, strideB};
*ldb, *ldc, strideA, strideB, isVnni};
return info;
}

// Check if the given generic is mappable to a brgemm xsmm op.
// - It is a contraction, with:
// -- 1 m and 1 n and 2 k dimensions.
// -- 1 m, 1 n, and 2 or 3 (VNNI) k dimensions.
// -- m appears on the LHS and OUT but not in RHS.
// -- n appears on the RHS and OUT but not in LHS.
// -- k and k' appear on the RHS and LHS but not OUT.
Expand All @@ -600,8 +615,15 @@ static FailureOr<BrgemmInfo> isMappableToBrgemm(linalg::LinalgOp linalgOp) {
unsigned m = contractionDims->m[0];
unsigned n = contractionDims->n[0];
unsigned k = contractionDims->k.back();

// Check if there is a batch reduce dimension.
// At least one K-dim is the GEMM reduction.
// In case of VNNI layout, there is additional reduction dimension
// representing VNNI blocking factor.
std::optional<unsigned> batch;
if (contractionDims->k.size() == 2)
unsigned numBrgemmReductionDims =
vnni::utils::isInVnniLayout(linalgOp) ? 3 : 2;
if (contractionDims->k.size() == numBrgemmReductionDims)
batch = contractionDims->k.front();

LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] Candidate dims: "
Expand Down Expand Up @@ -772,17 +794,23 @@ makeMinorDimensionsInnerMost(RewriterBase &rewriter, linalg::GenericOp linalgOp,
return linalgOp;
}

if (!isInnerMostDim(operandA, *minorKInCodomainOpA)) {
bool isVnni = vnni::utils::isInVnniLayout(linalgOp);

if (!isVnni && !isInnerMostDim(operandA, *minorKInCodomainOpA)) {
LLVM_DEBUG(llvm::dbgs()
<< "[makeMinorDimensionsInnerMost] emit transpose for A\n");
assert(isInnerMostDim(operandA, *minorMInCodomainOpA));
if (!isInnerMostDim(operandA, *minorMInCodomainOpA))
return failure();
emitTransposeOnOperand(rewriter, linalgOp, operandA, *minorKInCodomainOpA,
*minorMInCodomainOpA);
}
if (!isInnerMostDim(operandB, *minorNInCodomainOpB)) {
// Do not inject transposes in case of VNNI format.
// Otherwise, it breaks later VNNI layout validation.
if (!isVnni && !isInnerMostDim(operandB, *minorNInCodomainOpB)) {
LLVM_DEBUG(llvm::dbgs()
<< "[makeMinorDimensionsInnerMost] emit transpose for B\n");
assert(isInnerMostDim(operandB, *minorKInCodomainOpB));
if (!isInnerMostDim(operandB, *minorKInCodomainOpB))
return failure();
emitTransposeOnOperand(rewriter, linalgOp, operandB, *minorKInCodomainOpB,
*minorNInCodomainOpB);
}
Expand All @@ -795,7 +823,7 @@ void ConvertLinalgToXsmm::runOnOperation() {
IRRewriter rewriter(&getContext());

// Enable conversion for linalg.generic to XSMM Brgemm if possible.
auto res = getOperation()->walk([&](linalg::GenericOp genericOp) {
getOperation()->walk([&](linalg::GenericOp genericOp) {
auto contractionDims = checkStructure(genericOp);
// If the generic does not match the structure of a Brgemm op, skip it.
if (failed(contractionDims))
Expand All @@ -804,22 +832,18 @@ void ConvertLinalgToXsmm::runOnOperation() {
unsigned n = contractionDims->n[0];
unsigned k = contractionDims->k.back();
std::optional<unsigned> batch;
if (contractionDims->k.size() == 2)
if (contractionDims->k.size() == 3)
contractionDims->k.front();

if (failed(checkAccess(genericOp, m, n, k, batch))) {
// The generic is a Brgemm but the strides of the selected dims (m, n, k)
// are not unit strides. Inject transposes to bring them innermost.
if (failed(makeMinorDimensionsInnerMost(rewriter, genericOp, m, n, k))) {
return WalkResult::interrupt();
return WalkResult::skip();
}
}
return WalkResult::advance();
});
if (res.wasInterrupted()) {
LLVM_DEBUG(llvm::dbgs() << "pass failed!\n");
return signalPassFailure();
}
SmallVector<StringRef> skipPatterns(skipOperations.begin(),
skipOperations.end());
tpp::populateLinalgToXsmmPatterns(patterns, skipPatterns);
Expand Down Expand Up @@ -1069,11 +1093,11 @@ struct ConvertGenericToVnniMatmulLikeOp
return rewriter.notifyMatchFailure(genericOp, "expects buffer semantics");
}

auto [isBrgemmOp, hasBatch] = structured_match::utils::isBrgemmVnniOp(
auto [isMatmulVnni, hasBatch] = structured_match::utils::isMatmulVnniOp(
genericOp, /*operands=*/nullptr);
if (!isBrgemmOp) {
if (!isMatmulVnni) {
return rewriter.notifyMatchFailure(
genericOp, "expects an operation mappable to brgemm");
genericOp, "expects an operation mappable to VNNI contraction");
}

Value bufferA = genericOp.getDpsInputs()[0];
Expand All @@ -1085,7 +1109,15 @@ struct ConvertGenericToVnniMatmulLikeOp
int64_t kPos = 1;
if (hasBatch)
kPos++;
int64_t k = cast<ShapedType>(bufferA.getType()).getShape()[kPos];
// Take the whole reduction dim size. Account for the VNNI factor (ensured
// by the earlier check) that splits the K dim in the shape.
std::optional<int64_t> vnniFactor =
vnni::utils::getVnniBlockingFactor(bufferB.getType());
if (!vnniFactor)
return rewriter.notifyMatchFailure(genericOp,
"failed to determine VNNI factor");
int64_t k =
cast<ShapedType>(bufferA.getType()).getShape()[kPos] * *vnniFactor;
adam-smnk marked this conversation as resolved.
Show resolved Hide resolved
int64_t batch = 0;
if (hasBatch)
batch = cast<ShapedType>(bufferA.getType()).getShape()[0];
Expand All @@ -1107,8 +1139,7 @@ struct ConvertGenericToVnniMatmulLikeOp
if (hasBatch)
leadingDimPosOnAandB++;
int64_t lda = (*stridesOnLhs)[leadingDimPosOnAandB];
int64_t ldb = (*stridesOnRhs)[leadingDimPosOnAandB] /
*vnni::utils::getVnniBlockingFactor(bufferB.getType());
int64_t ldb = (*stridesOnRhs)[leadingDimPosOnAandB] / *vnniFactor;
int64_t ldc = (*stridesOnOutput)[0];

BrgemmInfo brgemmInfo{m, n, k, batch, lda,
Expand Down
31 changes: 16 additions & 15 deletions lib/TPP/IR/MatcherUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ getIteratorPos(linalg::LinalgOp linalgOp, AffineMap indexingMap,
return res;
}

// Return true if the linalg.generic can be mapped to a brgemm in VNNI
// format.
std::pair<bool, bool> isBrgemmVnniOp(linalg::GenericOp linalgOp,
// Return true if the linalg.generic can be mapped to a matmul (GEMM or BRGEMM)
// in VNNI format.
std::pair<bool, bool> isMatmulVnniOp(linalg::GenericOp linalgOp,
SmallVectorImpl<Value> *operands) {
bool hasBatch = false;
auto blockingFactor =
Expand All @@ -56,8 +56,8 @@ std::pair<bool, bool> isBrgemmVnniOp(linalg::GenericOp linalgOp,
.operation(NumOfLoops(_OR(EqualsTo(5), EqualsTo(4))))
.input(MatchAll(), HasStaticShape())
.output(MatchAll(), HasStaticShape())
.input(MatchOne(0), HasMap(BroadcastableProjectedPermutation(), &mapOperandA))
.input(MatchOne(1), HasMap(Any(), &mapOperandB))
.input(MatchOne(0), HasMap(ProjectedPermutation(), &mapOperandA))
.input(MatchOne(1), HasMap(ProjectedPermutation(), &mapOperandB))
.output(MatchOne(0), HasMap(BroadcastableProjectedPermutation(), &mapOperandC))
.region(MatchOne(0),
WithOpChain<arith::MulFOp, arith::AddFOp>(operands));
Expand All @@ -82,17 +82,18 @@ std::pair<bool, bool> isBrgemmVnniOp(linalg::GenericOp linalgOp,

llvm::SmallVector<int64_t> operandAPosIterRed = getIteratorPos(
linalgOp, mapOperandA, mlir::utils::IteratorType::reduction);
if (operandAPosIterRed.size() != 2 && operandAPosIterRed.size() != 1)
unsigned numRedItersA = operandAPosIterRed.size();
if (numRedItersA != 3 && numRedItersA != 2)
return std::make_pair(false, hasBatch);

// Check if there is an extra outer batch reduce K-dim.
// For VNNI format:
// - one inner K-dim is the GEMM reduction
// - one inner K-dim is the VNNI blocking factor
int64_t batchRedIter = std::numeric_limits<int64_t>::max();
int64_t kRedIter = std::numeric_limits<int64_t>::max();
if (operandAPosIterRed.size() == 2) {
if (numRedItersA == 3) {
batchRedIter = operandAPosIterRed[0];
kRedIter = operandAPosIterRed[1];
hasBatch = true;
} else {
kRedIter = operandAPosIterRed[0];
}

// Operand B: One parallel iterator (j) and three reduction ones (batch,
Expand All @@ -112,10 +113,10 @@ std::pair<bool, bool> isBrgemmVnniOp(linalg::GenericOp linalgOp,
return std::make_pair(false, hasBatch);
}

auto vnniDim =
vnni::utils::isInVnniLayout(linalgOp, mapOperandB, *blockingFactor);
bool isBrgemmOp = succeeded(vnniDim) && vnniDim->getPosition() == kRedIter;
return std::make_pair(isBrgemmOp, hasBatch);
// At this point, the operation is a valid matmul contraction.
// Finally, ensure that it is in VNNI layout.
bool isVnniMatmul = vnni::utils::isInVnniLayout(linalgOp, *blockingFactor);
return std::make_pair(isVnniMatmul, hasBatch);
}

// Return true if all the operand have the same type, i.e., no implicit
Expand Down
Loading