[AMD] Support global load in local prefetch schedule (#5380)
The PR extends the `local-prefetch` instruction scheduling strategy for
AMD GPUs to handle `global_load` ops.
ravil-mobile authored Jan 7, 2025
1 parent 82625a5 commit 01aa5b2
Showing 3 changed files with 121 additions and 40 deletions.
72 changes: 69 additions & 3 deletions test/TritonGPU/amd/amd-instruction-sched.mlir
@@ -2,7 +2,7 @@
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm_iglp_1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local_prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=16 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local_prefetch arch=gfx942 num_stages=2' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2

@@ -11,6 +11,7 @@ module {
// INSERT_IGLP1-LABEL: @test_dot_op
// INSTR_COUNT_NS1-LABEL: @test_dot_op
// INSTR_COUNT_NS2-LABEL: @test_dot_op
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: @test_dot_op
// LABELING_PS_1-LABEL: @test_dot_op
// LABELING_PS_2-LABEL: @test_dot_op
tt.func @test_dot_op(%lb : index, %ub : index, %step : index,
@@ -68,8 +69,73 @@
// INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>>

// USE_LOCAL_PREFETCH_GLOBAL_LOAD: [lower-insert-instruction-sched-hints]
// USE_LOCAL_PREFETCH_GLOBAL_LOAD-SAME: skipping `local-prefetch` scheduling given it needs `buffer_load` instructions
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.barrier [[SCHED_GUARD:.+]]
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE:512]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA:8]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ:32]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ:256]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.barrier [[SCHED_GUARD]]
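// NOTE: the mask values captured above correspond to the scheduling-barrier
// instruction categories (sched_barrier_opt_enum): 512 (0x200) = ds_write,
// 256 (0x100) = ds_read, 32 (0x20) = vmem_read, and 8 (0x8) = mfma/wmma,
// hence the capture names used in the CHECK lines.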


// LABELING_PS_1: scf.for
// LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>}
4 changes: 2 additions & 2 deletions third_party/amd/backend/compiler.py
@@ -220,11 +220,10 @@ def make_ttgir(mod, metadata, options):
passes.ttgpuir.add_optimize_dot_operands(pm, True)

stream_prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1"
use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"

# The `local-prefetch` scheduling variant requires stream prefetching.
if options.instruction_sched_variant == "local-prefetch":
stream_prefetch = use_buffer_ops = True
stream_prefetch = True

if amd.has_matrix_core_feature(options.arch):
assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. "
@@ -244,6 +243,7 @@ def make_ttgir(mod, metadata, options):
if use_block_pingpong and options.num_stages == 2:
amd.passes.ttgpuir.add_block_pingpong(pm)

use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
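# Buffer ops remain opt-in via AMDGCN_USE_BUFFER_OPS; the `local-prefetch`
# schedule no longer forces them on and also handles plain global_load.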
if use_buffer_ops:
amd.passes.ttgpuir.add_canonicalize_pointers(pm)
passes.common.add_canonicalizer(pm)
85 changes: 50 additions & 35 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
@@ -112,6 +112,33 @@ triton::DotOp getSingleDotOpIfExists(scf::ForOp forOp) {

return (dotCounter == 1) ? dotOp : nullptr;
}

// The AMDGPU compiler backend can fold consecutive `ds_read/ds_write`
// instructions into wider variants as part of its load/store optimization
// during the instruction selection pass. If that happens, we have
// overestimated these instruction counts at the current level of the IR, and
// the inserted `sched.group.barrier`s end up "fooling" the scheduling solver,
// which can mess up the final assembly. To avoid this, we switch off the
// backend load/store folding optimization so that the widths of
// `ds_read/ds_write` instructions match their LLVM representations.
// TODO: The current implementation disables `ds_read/ds_write` folding for
// all basic blocks in the currently processed function. We should try to
// avoid that. As an alternative, the compiler backend team proposed playing
// with the load/store alignment values within the currently processed basic
// block.
void disableInstructionFolding(triton::amdgpu::InstructionSchedHint schedHint) {
auto funcOp = schedHint->getParentOfType<LLVM::LLVMFuncOp>();
MLIRContext *ctx = schedHint->getContext();
llvm::SmallVector<StringAttr> targetFeatures;
if (auto attr = funcOp.getTargetFeatures()) {
llvm::copy(attr->getFeatures(), std::back_inserter(targetFeatures));
}
targetFeatures.push_back(str_attr("-load-store-opt"));
funcOp.setTargetFeaturesAttr(
::mlir::LLVM::TargetFeaturesAttr::get(ctx, targetFeatures));
}
} // namespace mlir::triton

namespace {
@@ -121,6 +148,8 @@ namespace {
void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc,
mlir::amdgpu::sched_barrier_opt_enum maskValue,
int sizeValue, int groupIdValue) {
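// Skip degenerate requests: a per-issue count (e.g., ds_writes per
// global/buffer load) can round down to zero, in which case there is no
// scheduling group to emit.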
if (sizeValue < 1)
return;
IntegerAttr mask =
rewriter.getI32IntegerAttr(static_cast<int32_t>(maskValue));
IntegerAttr size =
Expand Down Expand Up @@ -242,13 +271,6 @@ struct InstructionSchedHintsRewriter
PatternRewriter &rewriter, Location loc,
triton::amdgpu::InstructionSchedHint schedHint) const {

if (!(schedHint.getIsBufferLoadsAEnabled() &&
schedHint.getIsBufferLoadsBEnabled())) {
LDBG("skipping `local-prefetch` scheduling given it needs `buffer_load` "
"instructions");
return;
}

if (!machineDescr) {
schedHint.emitError("unknown target architecture detected");
return;
@@ -266,12 +288,14 @@
schedHint.getNumGlobalLoadsB().getValue();

if (numBufferLoadInstA == 0) {
schedHint.emitError("buffer load count for tile A must be initialized");
schedHint.emitError(
"global/buffer load count for tile A must be initialized");
return;
}

if (numBufferLoadInstB == 0) {
schedHint.emitError("buffer load count for tile B must be initialized");
schedHint.emitError(
"global/buffer load count for tile B must be initialized");
return;
}

@@ -296,24 +320,39 @@
const uint32_t mmaIssueCycle = this->machineDescr->getMmaIssueCycle();
const uint32_t numLdsDataPaths = this->machineDescr->getNumLdsDataPaths();

// Compute how many ds_reads from tile A we can put between two adjacent
// MFMAs
const auto dsReadAMmaRate = (mmaExecCycle - mmaIssueCycle +
numLdsDataPaths * dsReadAIssueCycle - 1) /
(numLdsDataPaths * dsReadAIssueCycle);

// Compute how many ds_reads from tile B we can put between two adjacent
// MFMAs
const auto dsReadBMmaRate = (mmaExecCycle - mmaIssueCycle +
numLdsDataPaths * dsReadBIssueCycle - 1) /
(numLdsDataPaths * dsReadBIssueCycle);

// Compute how many (MFMA [ds_read]+) clusters we can get from tile A
const auto numDsreadAMma =
(numDsReadInstA + dsReadAMmaRate - 1) / dsReadAMmaRate;

// Compute how many (MFMA [ds_read]+) clusters we can get from tile B
const auto numDsreadBMma =
(numDsReadInstB + dsReadBMmaRate - 1) / dsReadBMmaRate;
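
// Illustration with hypothetical machine parameters (for exposition only,
// not taken from a real target description): with mmaExecCycle = 32,
// mmaIssueCycle = 4, numLdsDataPaths = 2 and dsReadAIssueCycle = 4, the
// ceiling division above gives dsReadAMmaRate = ceil(28 / 8) = 4, i.e. up to
// four ds_reads from tile A fit between two adjacent MFMAs; with
// numDsReadInstA = 16 this results in numDsreadAMma = ceil(16 / 4) = 4
// clusters.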

// stage 1
// Stage 1
// Compute how many MFMAs are left for stage 1, i.e., clusters with
// ds_writes, global/buffer_loads, and MFMAs
const auto numMmaStage1 = numMmaInst - (numDsreadAMma + numDsreadBMma);
const auto numMmaPerIssue =
numMmaStage1 / (numBufferLoadInstA + numBufferLoadInstB);

// Compute how many ds_writes we have per global/buffer load resulting from
// tile A
const auto numDswritePerIssueA = numDsWriteInstA / numBufferLoadInstA;

// Compute how many ds_writes we have per global/buffer load resulting from
// tile B
const auto numDswritePerIssueB = numDsWriteInstB / numBufferLoadInstB;

for (size_t i = 0; i < numBufferLoadInstA; ++i) {
@@ -377,31 +416,7 @@
rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0);
}

// The AMDGPU compiler backend can fold consecutive `ds_read/ds_write`
// instructions into wider variants as part of its load/store optimization
// during the instruction selection pass. If that happens, we have
// overestimated these instruction counts at the current level of the IR, and
// the inserted `sched.group.barrier`s end up "fooling" the scheduling solver,
// which can mess up the final assembly. To avoid this, we switch off the
// backend load/store folding optimization so that the widths of
// `ds_read/ds_write` instructions match their LLVM representations. This is
// implemented as follows.

// TODO: The current implementation disables `ds_read/ds_write` folding for
// all basic blocks in the currently processed function. We should try to
// avoid that. As an alternative, the compiler backend team proposed playing
// with the load/store alignment values within the currently processed basic
// block.
auto funcOp = schedHint->getParentOfType<LLVM::LLVMFuncOp>();
MLIRContext *ctx = schedHint->getContext();
llvm::SmallVector<StringAttr> targetFeatures;
if (auto attr = funcOp.getTargetFeatures()) {
llvm::copy(attr->getFeatures(), std::back_inserter(targetFeatures));
}
targetFeatures.push_back(str_attr("-load-store-opt"));
funcOp.setTargetFeaturesAttr(
::mlir::LLVM::TargetFeaturesAttr::get(ctx, targetFeatures));
disableInstructionFolding(schedHint);
}

LogicalResult
