Skip to content

Commit

Permalink
[BACKEND] Add Address Sanitizer Pass (#5127)
Browse files Browse the repository at this point in the history
Add address sanitizer pass to LLVM pass pipeline. 

Known limitations for this PR:
- Currently only the AMD backend is supported
- Source code line support not implemented here,
  coming in follow up patch
  • Loading branch information
CRobeck authored Jan 3, 2025
1 parent f5e949f commit dc261bf
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 12 deletions.
17 changes: 15 additions & 2 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ jobs:
echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
fi
cd python/test/unit
python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py --ignore=test_address_sanitizer.py
python3 -m pytest -s -n 8 language/test_subprocess.py
python3 -m pytest -s -n 8 test_debug.py --forked
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
Expand Down Expand Up @@ -429,14 +429,27 @@ jobs:
cd python/test/unit
pytest --capture=tee-sys -rfs -n 12 language runtime \
--ignore=language/test_line_info.py \
--ignore=test_debug.py
--ignore=test_debug.py \
--ignore=test_address_sanitizer.py
# TODO: uncomment
# pytest --capture=tee-sys -rfs test_debug.py
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
- name: Run asan tests on HIP
run: |
cd python/test/unit
ulimit -s 1024
export PATH=$(find ~/.triton/llvm -name llvm-symbolizer -printf '%h\n'):$PATH
export LD_LIBRARY_PATH=$(find /opt -name libclang_rt.asan-x86_64.so -printf '%h\n'):$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$(find /opt -type d -wholename *lib/llvm/lib/asan):$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$(find /usr -name libcaffe2_nvrtc.so -printf '%h\n'):$LD_LIBRARY_PATH
export CLANG_ASAN_LIB=$(find /opt -name libclang_rt.asan-x86_64.so)
export HIP_ASAN_LIB=$(find /opt -wholename *lib/asan/libamdhip64.so)
ASAN_OPTIONS=detect_leaks=0,alloc_dealloc_mismatch=0 \
LD_PRELOAD=$CLANG_ASAN_LIB:$HIP_ASAN_LIB python3 -m pytest -s test_address_sanitizer.py
- name: Run regression tests
run: |
# Reenable test_functional_regression.py once it's fixed
Expand Down
19 changes: 15 additions & 4 deletions .github/workflows/integration-tests.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ jobs:
echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
fi
cd python/test/unit
python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py --ignore=test_address_sanitizer.py
python3 -m pytest -s -n 8 language/test_subprocess.py
python3 -m pytest -s -n 8 test_debug.py --forked
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
Expand All @@ -309,7 +309,6 @@ jobs:
python3 -m pytest -s hopper/test_flashattention.py
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
python3 -m pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

- name: Run interpreter tests
if: ${{ matrix.runner[0] == 'h100-runner-set' }}
env:
Expand Down Expand Up @@ -416,15 +415,27 @@ jobs:
cd python/test/unit
pytest --capture=tee-sys -rfs -n 12 language runtime \
--ignore=language/test_line_info.py \
--ignore=test_debug.py
--ignore=test_debug.py \
--ignore=test_address_sanitizer.py
# TODO: uncomment
# pytest --capture=tee-sys -rfs test_debug.py
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py

- name: Run asan tests on HIP
run: |
cd python/test/unit
ulimit -s 1024
export PATH=$(find ~/.triton/llvm -name llvm-symbolizer -printf '%h\n'):$PATH
export LD_LIBRARY_PATH=$(find /opt -name libclang_rt.asan-x86_64.so -printf '%h\n'):$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$(find /opt -type d -wholename *lib/llvm/lib/asan):$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$(find /usr -name libcaffe2_nvrtc.so -printf '%h\n'):$LD_LIBRARY_PATH
export CLANG_ASAN_LIB=$(find /opt -name libclang_rt.asan-x86_64.so)
export HIP_ASAN_LIB=$(find /opt -wholename *lib/asan/libamdhip64.so)
ASAN_OPTIONS=detect_leaks=0,alloc_dealloc_mismatch=0 \
LD_PRELOAD=$CLANG_ASAN_LIB:$HIP_ASAN_LIB python3 -m pytest -s test_address_sanitizer.py
- name: Run regression tests
run: |
# Reenable test_functional_regression.py once it's fixed
Expand Down
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,15 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
separated values to be specified (eg
`TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions"` or
`TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions,regalloc"`).
- `TRITON_ENABLE_ASAN=1` invokes the LLVM address sanitizer for
memory-leak and out-of-bounds access detection. Currently only supported on the AMD
backend. Programs compiled with this option must be run using the ASAN libraries documented [here](https://rocm.docs.amd.com/projects/llvm-project/en/latest/conceptual/using-gpu-sanitizer.html).

When enabling the address sanitizer, it is recommended to disable the memory caching strategies
in both the ROCm stack and PyTorch. This gives the address sanitizer the best chance of finding a
memory fault where it originates; otherwise the fault can be masked by an access that runs past the end of a
cached block but still lands inside a larger allocation. To do this, set the `HSA_DISABLE_FRAGMENT_ALLOCATOR`,
`AMD_PYTORCH_NO_CUDA_MEMORY_CACHING`, and `PYTORCH_NO_HIP_MEMORY_CACHING` environment variables to `1`.

- `USE_IR_LOC={ttir,ttgir}` reparses the IR such that the location information
will be the line number of the IR file with that particular extension,
instead of line number of the python file. This can provide a direct mapping
Expand Down
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_HIP_STREAM_PREFETCH",
"TRITON_HIP_USE_BLOCK_PINGPONG",
"TRITON_LLVM_DEBUG_ONLY",
"TRITON_ENABLE_ASAN",
"USE_IR_LOC",
"NVPTX_ENABLE_DUMP",
// clang-format on
Expand Down
17 changes: 16 additions & 1 deletion python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Instrumentation/AddressSanitizer.h"
#include "llvm/Transforms/Instrumentation/AddressSanitizerOptions.h"
#include <csignal>
#include <memory>
#include <pybind11/pybind11.h>
Expand Down Expand Up @@ -217,7 +219,14 @@ void init_triton_llvm(py::module &&m) {
.def("set_calling_conv", &llvm::Function::setCallingConv)
.def("add_fn_attr", [](llvm::Function *fn, std::string &name,
std::string &val) { fn->addFnAttr(name, val); })

.def("add_fn_asan_attr",
[](llvm::Function *fn) {
fn->addFnAttr(llvm::Attribute::SanitizeAddress);
})
.def("add_fn_target_feature",
[](llvm::Function *fn, std::string &val) {
fn->addFnAttr("target-features", val);
})
// Sets the nvvm.maxreg property on the given function.
.def("set_nvvm_maxnreg",
[](llvm::Function *fn, int maxnreg) {
Expand Down Expand Up @@ -377,6 +386,12 @@ void init_triton_llvm(py::module &&m) {
fpm.addPass(BreakStructPhiNodesPass());
fpm.addPass(InstCombinePass());
});
bool enableAddressSanitizer =
mlir::triton::tools::getBoolEnv("TRITON_ENABLE_ASAN");
if (enableAddressSanitizer) {
AddressSanitizerOptions Opts;
mpm.addPass(AddressSanitizerPass(Opts));
}
mpm.addPass(pb.buildPerModuleDefaultPipeline(opt));
mpm.run(*mod, mam);
},
Expand Down
33 changes: 33 additions & 0 deletions python/test/unit/address_sanitizer_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
import triton
import triton.language as tl

# Problem setup: two random input vectors on the GPU and an output buffer
# with the same shape and dtype as the first input.
size = 4096
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output = torch.empty_like(x)
n_elements = output.numel()


def grid(meta):
    # Launch grid: one program per BLOCK_SIZE-sized chunk of the output.
    num_programs = triton.cdiv(n_elements, meta['BLOCK_SIZE'])
    return (num_programs, )


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Element-wise add of x and y into output, one BLOCK_SIZE chunk per program.
    # The loads and the store are intentionally unmasked and shifted by one
    # element so that the accesses run out of bounds for the ASAN test;
    # n_elements is accepted but not used inside the kernel.
    program_index = tl.program_id(axis=0)
    first_element = program_index * BLOCK_SIZE
    # Set access to go out of bounds for ASAN test
    element_ids = first_element + tl.arange(0, BLOCK_SIZE) + 1
    lhs = tl.load(x_ptr + element_ids)
    rhs = tl.load(y_ptr + element_ids)
    total = lhs + rhs
    tl.store(output_ptr + element_ids, total)


# Launch the kernel over the whole problem, then dump the generated AMDGCN
# assembly to stdout so that a parent process can inspect it.
launch = add_kernel[grid]
pgm = launch(x, y, output, n_elements, BLOCK_SIZE=1024)
amdgcn = pgm.asm['amdgcn']
print(amdgcn)
34 changes: 34 additions & 0 deletions python/test/unit/test_address_sanitizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
import subprocess

import triton


def is_hip():
    """Return True when the active Triton driver's current target uses the "hip" backend."""
    current_target = triton.runtime.driver.active.get_current_target()
    return current_target.backend == "hip"


def test_address_sanitizer():
    """Check that an out-of-bounds kernel is instrumented and reported by ASAN.

    Runs ``address_sanitizer_helper.py`` in a child process with the address
    sanitizer enabled, then verifies that the assembly printed on stdout
    contains an ``__asan_report`` function and that a ``heap-buffer-overflow``
    is reported on stderr.

    The ASAN-related settings are passed only to the child process, so they
    do not leak into other tests running in the same pytest worker.
    """
    import sys

    if not is_hip():
        return  #not supported on NV backend

    child_env = os.environ.copy()

    # It is recommended to disable various memory caching strategies both within the ROCm stack and PyTorch
    # This will give the address sanitizer the best chance at finding the memory fault where it originates,
    # otherwise it could be masked by writing past the end of a cached block within a larger allocation.
    child_env["HSA_DISABLE_FRAGMENT_ALLOCATOR"] = "1"
    child_env["AMD_PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
    child_env["PYTORCH_NO_HIP_MEMORY_CACHING"] = "1"
    child_env["TRITON_ENABLE_ASAN"] = "1"

    # HSA_XNACK here is required to set the xnack+ setting for the GPU at runtime.
    # If it is not set and the default xnack setting of the system is xnack-
    # a runtime error something like "No kernel image found" will occur. The system
    # xnack setting can be found through rocminfo. xnack+ is required for ASAN.
    # More information about xnack in general can be found here:
    # https://llvm.org/docs/AMDGPUUsage.html#target-features
    # https://rocm.docs.amd.com/en/docs-6.1.0/conceptual/gpu-memory.html
    child_env["HSA_XNACK"] = "1"

    # Resolve the helper next to this file so the test does not depend on the
    # current working directory, and reuse the interpreter running pytest.
    helper_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "address_sanitizer_helper.py")

    # subprocess.run drains stdout and stderr concurrently and waits for the
    # child to exit; reading one pipe to EOF before the other can deadlock
    # once the unread pipe's buffer fills up.
    result = subprocess.run(
        [sys.executable, helper_path],
        env=child_env,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    assert "Begin function __asan_report" in result.stdout.decode()
    assert "heap-buffer-overflow" in result.stderr.decode()
8 changes: 7 additions & 1 deletion python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,13 @@ def compile(src, target=None, options=None):
# This is needed to safely finalize threads pool inside context: if current process forks before
# python GC deletes context object, thread pool in child process will be invalid, which could
# lead to child crash or hang.
context.disable_multithreading()
#
# However disabling multithreading causes the code to hang if the ASAN pass is enabled
# this is likely due to the llvm-symbolizer forking a process
# TODO: Reconcile the difference here between the ASAN and non-ASAN path with enabling
# multithreading in the MLIR context
if not os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
context.disable_multithreading()
# return handle to compiled kernel
return CompiledKernel(src, metadata_group, hash)

Expand Down
23 changes: 20 additions & 3 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,10 @@ def make_llir(src, metadata, options):
context = llvm.context()
llvm_mod = llvm.to_module(mod, context)
amd.attach_target_triple(llvm_mod)
llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, '')
target_features = ''
if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
target_features = '+xnack'
llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, target_features)

# Set various control constants on the LLVM module so that device
# libraries can resolve references to them.
Expand All @@ -330,13 +333,24 @@ def make_llir(src, metadata, options):
fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}")
denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee"
fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode)
if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
fns[0].add_fn_target_feature("+xnack")
fns[0].add_fn_asan_attr()

# Hint the compiler that we'd like the firmware to set the kernel arguments
# to user SGPRs so that the kernel does not need to s_load its arguments
# from memory.
amd.set_all_fn_arg_inreg(fns[0])

if options.extern_libs:
if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
default_libdir = Path(__file__).parent / 'lib'
paths = [
str(default_libdir / 'asanrtl.bc'),
str(default_libdir / "ocml.bc"),
str(default_libdir / "ockl.bc")
]
llvm.link_extern_libs(llvm_mod, paths)
elif options.extern_libs:
paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
llvm.link_extern_libs(llvm_mod, paths)

Expand Down Expand Up @@ -368,7 +382,10 @@ def make_amdgcn(src, metadata, options):

@staticmethod
def make_hsaco(src, metadata, options):
hsaco = amd.assemble_amdgcn(src, options.arch, '')
target_features = ''
if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
target_features = '+xnack'
hsaco = amd.assemble_amdgcn(src, options.arch, target_features)

rocm_path = HIPBackend.path_to_rocm_lld()
with tempfile.NamedTemporaryFile() as tmp_out:
Expand Down
Binary file added third_party/amd/backend/lib/asanrtl.bc
Binary file not shown.
4 changes: 3 additions & 1 deletion third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,9 @@ def make_llir(src, metadata, options, capability):
# LLVM-IR (MLIR) -> LLVM-IR (LLVM)
llvm.init_targets()
context = llvm.context()

if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
raise ASANError(
"Address Sanitizer Error: Address sanitizer is currently only supporteedd on the AMD backend")
llvm_mod = llvm.to_module(mod, context)
proc = 'sm_90a' if capability == 90 else f'sm_{capability}'
features = get_features(options)
Expand Down

0 comments on commit dc261bf

Please sign in to comment.