Skip to content

Commit

Permalink
Merge branch 'main' into pp-med
Browse files Browse the repository at this point in the history
  • Loading branch information
jungpark-mlir authored Jan 3, 2025
2 parents beb3edc + f410f91 commit fb361c1
Show file tree
Hide file tree
Showing 33 changed files with 470 additions and 177 deletions.
12 changes: 12 additions & 0 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,18 @@ jobs:
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
- name: Run asan tests on HIP
  run: |
    cd third_party/amd/python/test/
    ulimit -s 1024
    # Prepend the directories that contain llvm-symbolizer and the ASAN
    # runtime libraries to the tool and library search paths.
    export PATH=$(find ~/.triton/llvm -name llvm-symbolizer -printf '%h\n'):$PATH
    export LD_LIBRARY_PATH=$(find /opt -name libclang_rt.asan-x86_64.so -printf '%h\n'):$LD_LIBRARY_PATH
    # The -wholename patterns are quoted so that the shell hands them to find
    # verbatim instead of glob-expanding them against the current directory.
    export LD_LIBRARY_PATH=$(find /opt -type d -wholename '*lib/llvm/lib/asan'):$LD_LIBRARY_PATH
    export LD_LIBRARY_PATH=$(find /usr -name libcaffe2_nvrtc.so -printf '%h\n'):$LD_LIBRARY_PATH
    export CLANG_ASAN_LIB=$(find /opt -name libclang_rt.asan-x86_64.so)
    export HIP_ASAN_LIB=$(find /opt -wholename '*lib/asan/libamdhip64.so')
    # Preload the clang ASAN runtime and the ASAN build of libamdhip64 for the
    # test run; leak and alloc/dealloc-mismatch detection are switched off.
    ASAN_OPTIONS=detect_leaks=0,alloc_dealloc_mismatch=0 \
    LD_PRELOAD=$CLANG_ASAN_LIB:$HIP_ASAN_LIB python3 -m pytest -s test_address_sanitizer.py
- name: Run regression tests
run: |
# Reenable test_functional_regression.py once it's fixed
Expand Down
14 changes: 12 additions & 2 deletions .github/workflows/integration-tests.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,6 @@ jobs:
python3 -m pytest -s hopper/test_flashattention.py
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
python3 -m pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

- name: Run interpreter tests
if: ${{ matrix.runner[0] == 'h100-runner-set' }}
env:
Expand Down Expand Up @@ -424,7 +423,18 @@ jobs:

# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py

- name: Run asan tests on HIP
  run: |
    cd third_party/amd/python/test/
    ulimit -s 1024
    # Prepend the directories that contain llvm-symbolizer and the ASAN
    # runtime libraries to the tool and library search paths.
    export PATH=$(find ~/.triton/llvm -name llvm-symbolizer -printf '%h\n'):$PATH
    export LD_LIBRARY_PATH=$(find /opt -name libclang_rt.asan-x86_64.so -printf '%h\n'):$LD_LIBRARY_PATH
    # The -wholename patterns are quoted so that the shell hands them to find
    # verbatim instead of glob-expanding them against the current directory.
    export LD_LIBRARY_PATH=$(find /opt -type d -wholename '*lib/llvm/lib/asan'):$LD_LIBRARY_PATH
    export LD_LIBRARY_PATH=$(find /usr -name libcaffe2_nvrtc.so -printf '%h\n'):$LD_LIBRARY_PATH
    export CLANG_ASAN_LIB=$(find /opt -name libclang_rt.asan-x86_64.so)
    export HIP_ASAN_LIB=$(find /opt -wholename '*lib/asan/libamdhip64.so')
    # Preload the clang ASAN runtime and the ASAN build of libamdhip64 for the
    # test run; leak and alloc/dealloc-mismatch detection are switched off.
    ASAN_OPTIONS=detect_leaks=0,alloc_dealloc_mismatch=0 \
    LD_PRELOAD=$CLANG_ASAN_LIB:$HIP_ASAN_LIB python3 -m pytest -s test_address_sanitizer.py
- name: Run regression tests
run: |
# Reenable test_functional_regression.py once it's fixed
Expand Down
47 changes: 47 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# This is not the build system, just a helper to run common development commands.
# Make sure to first initialize the build system with:
# make dev-install

# Interpreter used for every Python invocation in this file. It can be
# overridden on the command line, e.g. `make PYTHON=python3 test`.
PYTHON := python
# CMake build directory, as reported by get_cmake_dir() from the build_helpers
# module in python/. The commands are chained with `&&` rather than `;` so the
# helper only runs from inside python/: if the `cd` fails, nothing is executed
# instead of importing whatever build_helpers is reachable from the current
# directory.
BUILD_DIR := $(shell cd python && $(PYTHON) -c 'from build_helpers import get_cmake_dir; print(get_cmake_dir())')
# Location of the triton-opt binary inside the build directory.
TRITON_OPT := $(BUILD_DIR)/bin/triton-opt

# Every target in this group is a command, not a file.
.PHONY: all triton-opt test-lit test-cpp test-python test

# Build the default ninja target in the CMake build directory.
all:
	ninja -C $(BUILD_DIR)

# Build only the triton-opt target ($@ expands to this rule's own name).
triton-opt:
	ninja -C $(BUILD_DIR) $@

# Run the lit tests through the check-triton-lit-tests ninja target.
test-lit:
	ninja -C $(BUILD_DIR) check-triton-lit-tests

# Run the C++ unit tests through the check-triton-unit-tests ninja target.
test-cpp:
	ninja -C $(BUILD_DIR) check-triton-unit-tests

# Run the Python unit tests with pytest once `all` has been built.
test-python: all
	$(PYTHON) -m pytest python/test/unit

# Run the lit, C++ and Python test targets.
test: test-lit test-cpp test-python

# Packages installed before the editable install, which runs without build
# isolation.
BUILD_DEPS := ninja cmake wheel pybind11
# Packages installed for running the tests.
TEST_DEPS := scipy numpy torch pytest lit pandas matplotlib

# One-time developer setup: install the build and test dependencies with pip,
# then install the `python` directory in editable mode.
.PHONY: dev-install
dev-install:
	# build-time dependencies
	$(PYTHON) -m pip install $(BUILD_DEPS)
	# test dependencies
	$(PYTHON) -m pip install $(TEST_DEPS)
	$(PYTHON) -m pip install -e python --no-build-isolation -v

# Regenerate the sample .mlir files under test/TritonGPU/samples: each
# <name>.mlir.in is run through triton-opt (loop scheduling, pipelining and
# canonicalization) and the result is piped into utils/generate-test-checks.py,
# which writes <name>.mlir.
.PHONY: golden-samples
# The recipes rely on `set -o pipefail`, which plain POSIX sh may not provide,
# so this target runs its recipe lines with bash.
golden-samples: SHELL := bash
golden-samples: triton-opt
	# pipefail makes a triton-opt failure fail the rule; without it the exit
	# status of the pipeline is that of generate-test-checks.py alone.
	set -o pipefail; \
	$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
		$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \
		-o test/TritonGPU/samples/simulated-grouped-gemm.mlir
	set -o pipefail; \
	$(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
		$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="\bmodule" \
		-o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir
42 changes: 15 additions & 27 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,36 +130,15 @@ There currently isn't a turnkey way to run all the Triton tests, but you can
follow the recipe below.

```shell
# One-time setup. Note we have to reinstall local Triton because torch
# One-time setup. Note this will reinstall local Triton because torch
# overwrites it with the public version.
$ pip install scipy numpy torch pytest lit pandas matplotlib && pip install -e python
$ make dev-install

# Run Python tests using your local GPU.
$ python3 -m pytest python/test/unit
# To run all tests (requires a GPU)
$ make test

# Move to builddir. Fill in <...> with the full path, e.g.
# `cmake.linux-x86_64-cpython-3.11`.
$ cd python/build/cmake<...>

# Run C++ unit tests.
$ ctest -j32

# Run lit tests.
$ lit test
```

You may find it helpful to make a symlink to the builddir and tell your local
git to ignore it.

```shell
$ ln -s python/build/cmake<...> build
$ echo build >> .git/info/exclude
```

Then you can e.g. rebuild and run lit with the following command.

```shell
$ ninja -C build && ( cd build ; lit test )
# Or, to run only the tests that do not require a GPU
$ make test-cpp test-lit
```

# Tips for hacking
Expand Down Expand Up @@ -193,6 +172,15 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
separated values to be specified (eg
`TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions"` or
`TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions,regalloc"`).
- `TRITON_ENABLE_ASAN=1` invokes the LLVM address sanitizer for
memory leak and out-of-bounds access detection. Currently only supported on the AMD
backend. This must be run using the ASAN libraries described in the [ROCm documentation on using the GPU sanitizer](https://rocm.docs.amd.com/projects/llvm-project/en/latest/conceptual/using-gpu-sanitizer.html).

When enabling the address sanitizer, it is recommended to disable the memory caching strategies
in both the ROCm stack and PyTorch. This gives the address sanitizer the best chance of finding a
memory fault where it originates. This can be done through the `HSA_DISABLE_FRAGMENT_ALLOCATOR`, `AMD_PYTORCH_NO_CUDA_MEMORY_CACHING`,
and `PYTORCH_NO_HIP_MEMORY_CACHING` environment variables.

- `USE_IR_LOC={ttir,ttgir}` reparses the IR such that the location information
will be the line number of the IR file with that particular extension,
instead of line number of the python file. This can provide a direct mapping
Expand Down
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_HIP_STREAM_PREFETCH",
"TRITON_HIP_USE_BLOCK_PINGPONG",
"TRITON_LLVM_DEBUG_ONLY",
"TRITON_ENABLE_ASAN",
"USE_IR_LOC",
"NVPTX_ENABLE_DUMP",
// clang-format on
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
math::CtPopOp, math::ErfOp, math::ExpOp, math::Exp2Op,
math::FloorOp, math::ExpM1Op, math::FmaOp, math::LogOp,
math::Log10Op, math::Log1pOp, math::Log2Op, math::PowFOp,
math::RsqrtOp, math::SqrtOp, math::RsqrtOp, math::TanhOp>(op))
math::SqrtOp, math::RsqrtOp, math::TanhOp>(op))
return true;
if (llvm::isa<triton::IntToPtrOp, triton::PtrToIntOp, triton::BitcastOp,
triton::FpToFpOp, triton::AddPtrOp, triton::PreciseSqrtOp,
Expand Down
17 changes: 17 additions & 0 deletions python/build_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import os
import sysconfig
import sys
from pathlib import Path


def get_base_dir():
    """Return the absolute path of the parent of the directory holding this file."""
    this_dir = os.path.dirname(__file__)
    parent_dir = os.path.join(this_dir, os.pardir)
    return os.path.abspath(parent_dir)


def get_cmake_dir():
    """Return the CMake build directory as a ``Path``, creating it if needed.

    The directory is ``<base dir>/python/build/cmake.<platform>-<implementation>-<python version>``,
    where the base directory comes from ``get_base_dir()`` and the three name
    components come from ``sysconfig`` and ``sys.implementation``.
    """
    build_tag = "-".join((
        sysconfig.get_platform(),
        sys.implementation.name,
        sysconfig.get_python_version(),
    ))
    target = Path(get_base_dir(), "python", "build", f"cmake.{build_tag}")
    # Missing parents are created; an already existing directory is not an error.
    target.mkdir(parents=True, exist_ok=True)
    return target
15 changes: 2 additions & 13 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

import pybind11

from build_helpers import get_base_dir, get_cmake_dir


@dataclass
class Backend:
Expand Down Expand Up @@ -343,19 +345,6 @@ def download_and_copy(name, src_path, dst_path, variable, version, url_func):
# ---- cmake extension ----


def get_base_dir():
return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))


def get_cmake_dir():
plat_name = sysconfig.get_platform()
python_version = sysconfig.get_python_version()
dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}"
cmake_dir = Path(get_base_dir()) / "python" / "build" / dir_name
cmake_dir.mkdir(parents=True, exist_ok=True)
return cmake_dir


class CMakeClean(clean):

def initialize_options(self):
Expand Down
17 changes: 16 additions & 1 deletion python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Instrumentation/AddressSanitizer.h"
#include "llvm/Transforms/Instrumentation/AddressSanitizerOptions.h"
#include <csignal>
#include <memory>
#include <pybind11/pybind11.h>
Expand Down Expand Up @@ -217,7 +219,14 @@ void init_triton_llvm(py::module &&m) {
.def("set_calling_conv", &llvm::Function::setCallingConv)
.def("add_fn_attr", [](llvm::Function *fn, std::string &name,
std::string &val) { fn->addFnAttr(name, val); })

.def("add_fn_asan_attr",
[](llvm::Function *fn) {
fn->addFnAttr(llvm::Attribute::SanitizeAddress);
})
.def("add_fn_target_feature",
[](llvm::Function *fn, std::string &val) {
fn->addFnAttr("target-features", val);
})
// Sets the nvvm.maxreg property on the given function.
.def("set_nvvm_maxnreg",
[](llvm::Function *fn, int maxnreg) {
Expand Down Expand Up @@ -377,6 +386,12 @@ void init_triton_llvm(py::module &&m) {
fpm.addPass(BreakStructPhiNodesPass());
fpm.addPass(InstCombinePass());
});
bool enableAddressSanitizer =
mlir::triton::tools::getBoolEnv("TRITON_ENABLE_ASAN");
if (enableAddressSanitizer) {
AddressSanitizerOptions Opts;
mpm.addPass(AddressSanitizerPass(Opts));
}
mpm.addPass(pb.buildPerModuleDefaultPipeline(opt));
mpm.run(*mod, mam);
},
Expand Down
63 changes: 60 additions & 3 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5625,6 +5625,10 @@ def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path:
assert "stmatrix" in kernel.asm["ptx"]


def filter_layout_pairs(layout_pairs):
    """Return the pairs from ``layout_pairs`` whose first two layouts both pass ``is_layout_applicable``."""
    applicable = []
    for candidate in layout_pairs:
        if not is_layout_applicable(candidate[0]):
            continue
        if not is_layout_applicable(candidate[1]):
            continue
        applicable.append(candidate)
    return applicable


mma_pairs = [
[
MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
Expand Down Expand Up @@ -5666,15 +5670,68 @@ def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path:
MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 16]),
MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]),
],
[
WmmaLayout(1, [4, 4]),
WmmaLayout(1, [16, 1]),
],
[
WmmaLayout(1, [16, 1]),
WmmaLayout(1, [4, 4]),
],
[
WmmaLayout(2, [4, 4]),
WmmaLayout(2, [16, 1]),
],
[
WmmaLayout(2, [16, 1]),
WmmaLayout(2, [4, 4]),
],
[
MfmaLayout([2, 0], [2, 2], [32, 32], False),
MfmaLayout([2, 0], [4, 1], [32, 32], False),
],
[
MfmaLayout([2, 0], [4, 1], [32, 32], False),
MfmaLayout([2, 0], [2, 2], [32, 32], False),
],
[
MfmaLayout([2, 0], [2, 2], [32, 32], False),
MfmaLayout([2, 0], [4, 1], [32, 32], True),
],
[
MfmaLayout([2, 0], [4, 1], [32, 32], False),
MfmaLayout([2, 0], [2, 2], [32, 32], True),
],
[
MfmaLayout([2, 0], [4, 4], [16, 16], False),
MfmaLayout([2, 0], [16, 1], [16, 16], False),
],
[
MfmaLayout([2, 0], [16, 1], [16, 16], False),
MfmaLayout([2, 0], [4, 4], [16, 16], False),
],
[
MfmaLayout([2, 0], [4, 4], [16, 16], False),
MfmaLayout([2, 0], [16, 1], [16, 16], True),
],
[
MfmaLayout([2, 0], [16, 1], [16, 16], False),
MfmaLayout([2, 0], [4, 4], [16, 16], True),
],
]


@pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]])
@pytest.mark.parametrize("M, N", [[16, 16], [64, 1], [1, 64], [64, 64], [128, 128], [256, 256]])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("mma_pair", filter_layouts(mma_pairs))
@pytest.mark.parametrize("mma_pair", filter_layout_pairs(mma_pairs))
def test_convert_mma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
if is_hip():
if isinstance(mma_pair[1], MfmaLayout) and (mma_pair[1].instr_shape[1] > M or mma_pair[1].instr_shape[1] > N):
pytest.skip("HIP do not fully support skinny tensor store")

src_layout, _ = mma_pair
num_warps = np.prod(src_layout.warps_per_cta)
warp_size = THREADS_PER_WARP

def do_test(src_layout, dst_layout):
layouts = f"""
Expand All @@ -5683,7 +5740,7 @@ def do_test(src_layout, dst_layout):
"""

ir = layouts + f"""
module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32}} {{
module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {warp_size} : i32}} {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>>
Expand Down
Loading

0 comments on commit fb361c1

Please sign in to comment.