Skip to content

Commit

Permalink
mlir-gen shape asserts
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-smnk committed Jan 10, 2025
1 parent 6513c14 commit 463b80b
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tools/mlir-gen/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,13 @@ Value MLIRGenerator::lowerMatmul(Value input, Value weight, Value output) {
SmallVector<int64_t> vnniShape{inputType.getShape()};
vnniShape.back() = vnniShape.back() / vnniFactor;
vnniShape.push_back(vnniFactor);

auto weightShape = cast<ShapedType>(weight.getType()).getShape();
assert(weightShape.size() >= 3 && "Expected VNNI weights");
assert(vnniShape.back() == weightShape.back() &&
vnniShape.end()[-2] == weightShape.end()[-3] &&
"Input and weights VNNI layout mismatch");

auto vnniType =
RankedTensorType::get(vnniShape, inputType.getElementType());

Expand Down

0 comments on commit 463b80b

Please sign in to comment.