From 37817d7773e419e89a955cdee17296d685df79b0 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Thu, 2 Jan 2025 17:54:44 -0800
Subject: [PATCH] [FRONTEND] enable construction of named tuples inside triton
 functions (#5519)

---
 python/test/unit/language/test_tuple.py  | 22 +++++++++++++---------
 python/triton/compiler/code_generator.py |  9 +++++++++
 2 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/python/test/unit/language/test_tuple.py b/python/test/unit/language/test_tuple.py
index 8f2729976315..e83938098bfe 100644
--- a/python/test/unit/language/test_tuple.py
+++ b/python/test/unit/language/test_tuple.py
@@ -114,19 +114,23 @@ class Tensor(NamedTuple):
 
 
 @triton.jit
-def _namedtuple_kernel(closure, X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
+def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     offs_m = tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N)
-    # load x
-    mask_x = (offs_m[:, None] < X.shape[0]) & (offs_n[None, :] < X.shape[1])
+    mask = (offs_m[:, None] < Tensor.shape[0]) & (offs_n[None, :] < Tensor.shape[1])
+    return mask
+
+
+@triton.jit
+def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
+    offs_m = tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    X = Tensor(shape=_X.shape, ptr=_X.ptr, stride=_X.stride)
     Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1]
-    x = tl.load(Xs, mask=mask_x, other=0)
-    # compute y
-    y = closure.fn(x, *closure.captured)
-    # store y
-    mask_y = (offs_m[:, None] < Y.shape[0]) & (offs_n[None, :] < Y.shape[1])
     Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1]
-    tl.store(Ys, y, mask=mask_y)
+    x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0)
+    y = closure.fn(x, *closure.captured)
+    tl.store(Ys, y, mask=_namedtuple_mask_func(Y, BLOCK_M, BLOCK_N))
 
 
 def test_namedtuple(device="cuda"):
diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py
index 2e06103b50a8..e5cc0132f5a8 100644
--- a/python/triton/compiler/code_generator.py
+++ b/python/triton/compiler/code_generator.py
@@ -315,6 +315,9 @@ def _is_constexpr_global(self, name):
 
         return False
 
+    def _is_namedtuple(self, val):
+        return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")
+
     def _define_name_lookup(self):
 
         def local_lookup(name: str, absent):
@@ -333,6 +336,7 @@ def global_lookup(name: str, absent):
                 getattr(val, "__triton_builtin__", False),  #
                 getattr(val, "__module__", "").startswith("triton.language"),  #
                 isinstance(val, language.dtype),  #
+                self._is_namedtuple(val),
                 self._is_constexpr_global(name),  #
                 # Allow accesses to globals while visiting an ast.arg
                 # because you should be able to do
@@ -535,6 +539,11 @@ def assignTarget(self, target, value):
     def visit_Assign(self, node):
         # construct values to assign
         def _sanitize_value(value):
+            if self._is_namedtuple(type(value)):
+                vals = [_sanitize_value(v) for v in value]
+                types = [v.type for v in vals]
+                fields = type(value)._fields
+                return language.tuple(vals, language.tuple_type(types, fields))
             if isinstance(value, language.tuple):
                 return language.tuple([_sanitize_value(v) for v in value.values])
             native_nontensor_types = (language.dtype, language.tuple)
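
Usage note, not part of the patch: a minimal sketch of the behavior this
change enables, namely constructing a NamedTuple inside a @triton.jit
function. The `Point` class and `_add_points_kernel` names below are
hypothetical illustrations, not code from this PR; the updated test above
exercises the same code path through its `Tensor` tuple.

    from typing import Any, NamedTuple

    import torch
    import triton
    import triton.language as tl


    class Point(NamedTuple):
        x: Any
        y: Any


    @triton.jit
    def _add_points_kernel(X, Y, Out, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        # Construct the named tuple *inside* the jit'ted function; before
        # this patch the code generator could not resolve `Point` here.
        p = Point(x=tl.load(X + offs), y=tl.load(Y + offs))
        # Named-field access works because the resulting tuple type
        # carries the class's `_fields`.
        tl.store(Out + offs, p.x + p.y)


    def main():
        BLOCK = 128
        x = torch.randn(BLOCK, device="cuda")
        y = torch.randn(BLOCK, device="cuda")
        out = torch.empty_like(x)
        _add_points_kernel[(1, )](x, y, out, BLOCK=BLOCK)
        torch.testing.assert_close(out, x + y)


    if __name__ == "__main__":
        main()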