From 01aa5b25c98a95f1cff1b109785ccf7cdecef2e3 Mon Sep 17 00:00:00 2001
From: ravil-mobile
Date: Tue, 7 Jan 2025 21:52:58 +0300
Subject: [PATCH] [AMD] Support global load in local prefetch schedule (#5380)

This PR extends the `local-prefetch` instruction scheduling strategy for AMD
GPUs to handle `global_load` ops.
---
 test/TritonGPU/amd/amd-instruction-sched.mlir |  72 +++++++++++++++-
 third_party/amd/backend/compiler.py           |   4 +-
 .../TritonAMDGPUToLLVM/SchedInstructions.cpp  |  85 +++++++++++--------
 3 files changed, 121 insertions(+), 40 deletions(-)

diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir
index 8cf3bdcafde6..f24a2eab63ae 100644
--- a/test/TritonGPU/amd/amd-instruction-sched.mlir
+++ b/test/TritonGPU/amd/amd-instruction-sched.mlir
@@ -2,7 +2,7 @@
 // RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm_iglp_1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
-// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local_prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
+// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=16 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local_prefetch arch=gfx942 num_stages=2' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1
 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2
@@ -11,6 +11,7 @@ module {
   // INSERT_IGLP1-LABEL: @test_dot_op
   // INSTR_COUNT_NS1-LABEL: @test_dot_op
   // INSTR_COUNT_NS2-LABEL: @test_dot_op
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: @test_dot_op
   // LABELING_PS_1-LABEL: @test_dot_op
   // LABELING_PS_2-LABEL: @test_dot_op
   tt.func @test_dot_op(%lb : index, %ub : index, %step : index,
@@ -68,8 +69,73 @@ module {
   // INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>>
   // INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>>

-  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: [lower-insert-instruction-sched-hints]
-  // USE_LOCAL_PREFETCH_GLOBAL_LOAD-SAME: skipping `local-prefetch` scheduling given it needs `buffer_load` instructions
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.barrier [[SCHED_GUARD:.+]]
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE:512]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA:8]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ:32]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ:256]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
+  // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.barrier [[SCHED_GUARD]]
+
   // LABELING_PS_1: scf.for
   // LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>}
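The mask operands pinned down by the FileCheck captures above are the AMDGPU
sched.group.barrier instruction-class bits: DS write = 512, MFMA/WMMA = 8,
VMEM read = 32, DS read = 256. The following Python sketch, with the helper
name and the counts purely illustrative of this test configuration, shows the
schedule shape the new checks encode:

    # A sketch of the cluster pattern checked above, assuming the standard
    # llvm.amdgcn.sched.group.barrier mask encoding.
    SCHED_GROUP_MASKS = {
        "MFMA": 8,
        "VMEM_READ": 32,
        "DS_READ": 256,
        "DS_WRITE": 512,
    }

    def expected_local_prefetch_groups(num_global_loads, num_ds_read_clusters):
        groups = []
        # Stage 1: one ds_write / MFMA / global-load group per global load.
        for _ in range(num_global_loads):
            groups += [("DS_WRITE", 1), ("MFMA", 1), ("VMEM_READ", 1)]
        # Stage 2: ds_read groups of two interleaved with single MFMAs.
        for _ in range(num_ds_read_clusters):
            groups += [("DS_READ", 2), ("MFMA", 1)]
        return groups

    for kind, size in expected_local_prefetch_groups(8, 20):
        print(f"rocdl.sched.group.barrier {SCHED_GROUP_MASKS[kind]}, {size}, 0")

The whole sequence is bracketed by the two rocdl.sched.barrier ops (the
SCHED_GUARD captures), which act as a scheduling guard around the hinted
region.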
diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index fff0e1af6e77..c3ccded47a82 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -220,11 +220,10 @@ def make_ttgir(mod, metadata, options):
         passes.ttgpuir.add_optimize_dot_operands(pm, True)

         stream_prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1"
-        use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"

         # The `local-prefetch` scheduling variant requires turning on buffer ops.
         if options.instruction_sched_variant == "local-prefetch":
-            stream_prefetch = use_buffer_ops = True
+            stream_prefetch = True

         if amd.has_matrix_core_feature(options.arch):
             assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. "
@@ -244,6 +243,7 @@ def make_ttgir(mod, metadata, options):
         if use_block_pingpong and options.num_stages == 2:
             amd.passes.ttgpuir.add_block_pingpong(pm)

+        use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
         if use_buffer_ops:
             amd.passes.ttgpuir.add_canonicalize_pointers(pm)
             passes.common.add_canonicalizer(pm)
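The net effect of the compiler.py change is that buffer ops become an
independent, environment-driven opt-in and the `local-prefetch` variant only
forces stream prefetching, since the scheduling lowering can now handle plain
`global_load` ops. A small Python sketch of the resulting knob resolution
(the function name is illustrative, not the backend API):

    import os

    def resolve_amd_schedule_knobs(instruction_sched_variant):
        stream_prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1"
        # `local-prefetch` still implies stream prefetching, but it no longer
        # forces buffer ops.
        if instruction_sched_variant == "local-prefetch":
            stream_prefetch = True
        # Buffer ops are now purely an environment-driven opt-in.
        use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
        return stream_prefetch, use_buffer_ops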
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
index 333a78b4c526..dd5b655cfcf7 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
@@ -112,6 +112,33 @@ triton::DotOp getSingleDotOpIfExists(scf::ForOp forOp) {
   return (dotCounter == 1) ? dotOp : nullptr;
 }
+
+// The AMDGPU compiler backend can fold consecutive `ds_read/ds_write`
+// instructions into wider variants as a part of its load/store optimization
+// during the instruction selection pass. If that happens, it means that we
+// have overestimated these types of instructions at the current level of the
+// IR. In this scenario, the inserted `sched.group.barriers` will end up
+// "fooling" the scheduling solver, which can mess up the final assembly.
+// To avoid this, we switch off the backend load/store folding optimization,
+// which prevents instruction folding. In this case, the instruction widths of
+// `ds_read/ds_write` instructions are going to match their LLVM
+// representations. This is implemented as follows.
+// TODO: The current implementation disables `ds_read/ds_write` folding for
+// all basic blocks in the currently processed function. We should try to
+// avoid this. The compiler backend team proposed playing with the load/store
+// alignment values within the currently processed basic block as an
+// alternative solution.
+void disableInstructionFolding(triton::amdgpu::InstructionSchedHint schedHint) {
+  auto funcOp = schedHint->getParentOfType<LLVM::LLVMFuncOp>();
+  MLIRContext *ctx = schedHint->getContext();
+  llvm::SmallVector<StringAttr> targetFeatures;
+  if (auto attr = funcOp.getTargetFeatures()) {
+    llvm::copy(attr->getFeatures(), std::back_inserter(targetFeatures));
+  }
+  targetFeatures.push_back(str_attr("-load-store-opt"));
+  funcOp.setTargetFeaturesAttr(
+      ::mlir::LLVM::TargetFeaturesAttr::get(ctx, targetFeatures));
+}

 } // namespace mlir::triton

 namespace {
@@ -121,6 +148,8 @@ namespace {
 void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc,
                              mlir::amdgpu::sched_barrier_opt_enum maskValue,
                              int sizeValue, int groupIdValue) {
+  if (sizeValue < 1)
+    return;
   IntegerAttr mask =
       rewriter.getI32IntegerAttr(static_cast<int32_t>(maskValue));
   IntegerAttr size =
@@ -242,13 +271,6 @@ struct InstructionSchedHintsRewriter
                                    PatternRewriter &rewriter, Location loc,
                                    triton::amdgpu::InstructionSchedHint schedHint) const {
-    if (!(schedHint.getIsBufferLoadsAEnabled() &&
-          schedHint.getIsBufferLoadsBEnabled())) {
-      LDBG("skipping `local-prefetch` scheduling given it needs `buffer_load` "
-           "instructions");
-      return;
-    }
-
     if (!machineDescr) {
       schedHint.emitError("unknown target architecture detected");
       return;
@@ -266,12 +288,14 @@ struct InstructionSchedHintsRewriter
         schedHint.getNumGlobalLoadsB().getValue();

     if (numBufferLoadInstA == 0) {
-      schedHint.emitError("buffer load count for tile A must be initialized");
+      schedHint.emitError(
+          "global/buffer load count for tile A must be initialized");
       return;
     }

     if (numBufferLoadInstB == 0) {
-      schedHint.emitError("buffer load count for tile B must be initialized");
+      schedHint.emitError(
+          "global/buffer load count for tile B must be initialized");
       return;
     }
@@ -296,24 +320,39 @@ struct InstructionSchedHintsRewriter
     const uint32_t mmaIssueCycle = this->machineDescr->getMmaIssueCycle();
     const uint32_t numLdsDataPaths = this->machineDescr->getNumLdsDataPaths();

+    // Compute how many ds_reads from tile A we can put between two adjacent
+    // MFMAs
     const auto dsReadAMmaRate =
         (mmaExecCycle - mmaIssueCycle + numLdsDataPaths * dsReadAIssueCycle - 1) /
         (numLdsDataPaths * dsReadAIssueCycle);
+
+    // Compute how many ds_reads from tile B we can put between two adjacent
+    // MFMAs
     const auto dsReadBMmaRate =
         (mmaExecCycle - mmaIssueCycle + numLdsDataPaths * dsReadBIssueCycle - 1) /
        (numLdsDataPaths * dsReadBIssueCycle);

+    // Compute how many (MFMA [ds_read]+) clusters we can get from tile A
     const auto numDsreadAMma =
         (numDsReadInstA + dsReadAMmaRate - 1) / dsReadAMmaRate;
+
+    // Compute how many (MFMA [ds_read]+) clusters we can get from tile B
     const auto numDsreadBMma =
         (numDsReadInstB + dsReadBMmaRate - 1) / dsReadBMmaRate;

-    // stage 1
+    // Stage 1
+    // Compute how many MFMAs we have left for stage 1 - i.e., clusters with
+    // ds_writes, global/buffer_loads, MFMAs
     const auto numMmaStage1 = numMmaInst - (numDsreadAMma + numDsreadBMma);
     const auto numMmaPerIssue =
         numMmaStage1 / (numBufferLoadInstA + numBufferLoadInstB);
+
+    // Compute how many ds_writes we have per global/buffer load resulting from
+    // tile A
     const auto numDswritePerIssueA = numDsWriteInstA / numBufferLoadInstA;
+
+    // Compute how many ds_writes we have per global/buffer load resulting from
+    // tile B
     const auto numDswritePerIssueB = numDsWriteInstB / numBufferLoadInstB;

     for (size_t i = 0; i < numBufferLoadInstA; ++i) {
@@ -377,31 +416,7 @@ struct InstructionSchedHintsRewriter
                               rewriter, loc,
                               mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0);
     }

-    // The AMDGPU compiler backend can fold consecutive `ds_read/ds_write`
-    // instructions into wider variants as a part of its load/store optimization
-    // during the instruction selection pass. If it happens, then it means that
-    // we are overestimated these types of instructions at the current level of
-    // the IR. In this scenario, the inserted `sched.group.barriers` will result
-    // in "fooling" the scheduling solver which can mess up the final assembly.
-    // To avoid this, we switch off the backend load/store folding optimization
-    // which is going to prevent instructions folding. In this case, the
-    // instruction widths of `ds_read/ds_write` instructions are going to match
-    // their LLVM representations. This is implemented as follows.
-
-    // TODO: The current implementation disables `ds_read/ds_write` folding for
-    // all basic blocks in the currently processed function. We should try to
-    // avoid it. The compiler backend team proposed to play we the load/store
-    // alignment values within the currently processed basic block as an
-    // alternative solution.
-    auto funcOp = schedHint->getParentOfType<LLVM::LLVMFuncOp>();
-    MLIRContext *ctx = schedHint->getContext();
-    llvm::SmallVector<StringAttr> targetFeatures;
-    if (auto attr = funcOp.getTargetFeatures()) {
-      llvm::copy(attr->getFeatures(), std::back_inserter(targetFeatures));
-    }
-    targetFeatures.push_back(str_attr("-load-store-opt"));
-    funcOp.setTargetFeaturesAttr(
-        ::mlir::LLVM::TargetFeaturesAttr::get(ctx, targetFeatures));
+    disableInstructionFolding(schedHint);
   }

   LogicalResult
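For readers following the counters in the hunk above, the cluster math is
plain integer ceiling division. The sketch below is a Python transcription of
that arithmetic (the names mirror the C++ locals; it is not the compiler
code):

    def ceil_div(a, b):
        return (a + b - 1) // b

    def local_prefetch_cluster_counts(num_ds_read_a, num_ds_read_b,
                                      num_ds_write_a, num_ds_write_b,
                                      num_global_load_a, num_global_load_b,
                                      num_mma, mma_exec_cycle, mma_issue_cycle,
                                      ds_read_a_issue_cycle, ds_read_b_issue_cycle,
                                      num_lds_data_paths):
        # ds_reads that fit between two adjacent MFMAs, per tile.
        ds_read_a_mma_rate = ceil_div(mma_exec_cycle - mma_issue_cycle,
                                      num_lds_data_paths * ds_read_a_issue_cycle)
        ds_read_b_mma_rate = ceil_div(mma_exec_cycle - mma_issue_cycle,
                                      num_lds_data_paths * ds_read_b_issue_cycle)
        # Stage 2: (MFMA [ds_read]+) clusters per tile.
        num_dsread_a_mma = ceil_div(num_ds_read_a, ds_read_a_mma_rate)
        num_dsread_b_mma = ceil_div(num_ds_read_b, ds_read_b_mma_rate)
        # Stage 1: the remaining MFMAs are spread across the global/buffer
        # loads, together with the ds_writes produced by each load.
        num_mma_stage1 = num_mma - (num_dsread_a_mma + num_dsread_b_mma)
        num_mma_per_issue = num_mma_stage1 // (num_global_load_a + num_global_load_b)
        num_dswrite_per_issue_a = num_ds_write_a // num_global_load_a
        num_dswrite_per_issue_b = num_ds_write_b // num_global_load_b
        return (num_dsread_a_mma, num_dsread_b_mma, num_mma_per_issue,
                num_dswrite_per_issue_a, num_dswrite_per_issue_b)

The ceiling division guarantees that every ds_read lands in some stage-2
cluster even when the instruction counts do not divide evenly by the
per-MFMA rate.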