Skip to content

Commit

Permalink
Refactor util
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-smnk committed Sep 4, 2024
1 parent 80695c9 commit 97f05a1
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 18 deletions.
3 changes: 3 additions & 0 deletions include/TPP/Transforms/Utils/TransformUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ bool isBlockedMatmul(Operation *op);
FailureOr<linalg::ContractionDimensions>
isContraction(linalg::LinalgOp linalgOp);

// Return constant range span or nullopt, otherwise.
std::optional<int64_t> getConstantRange(const Range &range);

// Validate a tile configuration for a linalgOp when we can statically do that.
// Specific dims can be passed using 'dims'. If dims is empty the validation
// will start from the outermost dimension, moving to innermost ones up to the
Expand Down
20 changes: 3 additions & 17 deletions lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,20 +274,6 @@ packConvolutions(RewriterBase &rewriter, OpTy convOp,
return replacementOp;
}

/// Return constant range span or nullopt, otherwise.
static std::optional<int64_t> getConstantRange(const Range &range) {
std::optional<int64_t> stride = getConstantIntValue(range.stride);
if (!stride || *stride != 1)
return std::nullopt;
std::optional<int64_t> offset = getConstantIntValue(range.offset);
if (!offset)
return std::nullopt;
std::optional<int64_t> size = getConstantIntValue(range.size);
if (!size)
return std::nullopt;
return (*size - *offset);
}

//===----------------------------------------------------------------------===//
// Conv2DNhwcHwcfOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -528,13 +514,13 @@ struct PackMatmul : public tpp::impl::PackMatmulBase<PackMatmul> {
SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);

if (std::optional<int64_t> dimM =
getConstantRange(iterationDomain[dims->m.back()]))
linalgx::utils::getConstantRange(iterationDomain[dims->m.back()]))
options.blockFactors[0] = std::min(*dimM, options.blockFactors[0]);
if (std::optional<int64_t> dimN =
getConstantRange(iterationDomain[dims->n.back()]))
linalgx::utils::getConstantRange(iterationDomain[dims->n.back()]))
options.blockFactors[1] = std::min(*dimN, options.blockFactors[1]);
if (std::optional<int64_t> dimK =
getConstantRange(iterationDomain[dims->k.back()]))
linalgx::utils::getConstantRange(iterationDomain[dims->k.back()]))
options.blockFactors[2] = std::min(*dimK, options.blockFactors[2]);

// Apply more restrictive packing validation.
Expand Down
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/TransformUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ isContraction(linalg::LinalgOp linalgOp) {
return dims;
}

static std::optional<int64_t> getConstantRange(const Range &range) {
std::optional<int64_t> getConstantRange(const Range &range) {
std::optional<int64_t> stride = getConstantIntValue(range.stride);
if (!stride || *stride != 1)
return std::nullopt;
Expand Down

0 comments on commit 97f05a1

Please sign in to comment.