[Backend] Implement layout conversion within warps with shuffle idx #5419
base: main
Conversation
Warp shuffles are affine transformations on the lane ID. This PR implements an approach based on the observation that we can represent affine transformations by augmenting the linear layout with a dummy dimension, and then decompose the mapping from the src to the dst layout into a series of affine transformations that represent warp shuffles. It turns out that for all distributed layouts, we can implement this transformation using a single index shuffle (mask + permute).
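As a rough, self-contained illustration of the dummy-dimension idea (my own sketch, not code from this PR): over F_2, XOR-ing the lane ID with a constant is affine rather than linear, but appending a constant-1 bit turns it into a linear map on the augmented vector. The constant `0b101` below is an arbitrary example.

```python
# Sketch: xor-with-constant is affine (it moves the origin), but becomes linear
# once the lane ID is augmented with a constant-1 "dummy" bit.
C = 0b101  # arbitrary example offset

def affine(lane):
    return lane ^ C                      # affine: affine(0) != 0

def linear(aug):                         # aug = (lane, dummy) with dummy in {0, 1}
    lane, dummy = aug
    return (lane ^ (C * dummy), dummy)   # linear over F_2: maps (0, 0) to (0, 0)

# Pinning the dummy bit to 1 recovers the affine map on every lane.
assert all(affine(l) == linear((l, 1))[0] for l in range(32))
# And `linear` respects xor of augmented vectors, which `affine` alone does not.
assert all(linear((a ^ b, 1 ^ 1)) == tuple(x ^ y for x, y in zip(linear((a, 1)), linear((b, 1))))
           for a in range(32) for b in range(32))
```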
FWIW, it might be a bit too much for a first PR, but something that'd be a nice north star would be to discover the neat algorithm in […]
I went a bit overboard with the implementation instead of keeping it simple (couldn't resist)... but the solution is much more general now. I'm not sure the PR is in a landable state however, given the complexity. I also really need to set this down to push on other things but I would love comments on the approach when you guys have time.
The approach I found doesn't quite do that yet, because it doesn't account for bytepacking the i32s for warp shuffles, and there isn't a TargetInfo interface to the byte permute instructions (the ROCDL dialect doesn't expose the AMD one, but we can probably do something similar as the […]). With a bit more determination, it should be possible to subsume those special patterns though!
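For reference, the packing mentioned here would amount to something like the following sketch (my own illustration of what "bytepacking the i32s" means, not the missing TargetInfo hook): several small elements share one 32-bit shuffle payload, and a byte-permute-style instruction would then rearrange bytes within the word.

```python
# Sketch: warp shuffles move 32-bit values, so four i8 elements (or two i16)
# would be packed into one i32 before the shuffle and unpacked afterwards.
def pack_i8x4(bytes4):
    assert len(bytes4) == 4 and all(0 <= b < 256 for b in bytes4)
    return bytes4[0] | (bytes4[1] << 8) | (bytes4[2] << 16) | (bytes4[3] << 24)

def unpack_i8x4(word):
    return [(word >> (8 * i)) & 0xFF for i in range(4)]

assert unpack_i8x4(pack_i8x4([1, 2, 3, 4])) == [1, 2, 3, 4]
```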
Will take a look soon. Thanks!
So cool! Will look into it tomorrow!
Finished the first round. I get the high-level idea, which is similar to the gather function. But I do need a second round of review to dig into the math details. Right now, it seems like there are a few symbols that I misunderstood.
@Jokeren thanks for the review! And yeah it did end up looking similar to gather in the way the selects are handled
For the most part, the implementation seems correct, but there are a few conversions of 4d+ tensors that are broken. I'm digging into those but it's not clear how the rank affects the code here.
I also noticed the time spent in Python tests has increased? Can you please investigate a bit?
I've verified that there aren't any performance regressions. Interestingly, I cannot seem to find any microbenchmarks where warp shuffles perform any different than shared memory conversions (I'm measuring differences under 0.5% for many different kernels...). I'm playing around with kernels that look like

@triton.jit
def kernel(a_ptr, b_ptr):
    idx = (tl.arange(0, 16)[None, :] + tl.arange(0, 16)[:, None])
    x = tl.load(a_ptr + idx)
    for i in range(10000000):
        x += tl.trans(x.reshape((2, 2, 2, 2, 2, 2, 2, 2)), (1, 3, 5, 7, 0, 2, 4, 6)).reshape(16, 16)
    tl.store(b_ptr + idx, x)

which AFAICT just runs layout conversion in a loop. The only observable difference in the SASS between the smem and warp shuffle versions (besides the implementation of the conversion) is that the SMEM version tends to get unrolled by a larger factor (4 or 8), probably since it tends to use fewer instructions. Maybe I'm measuring the wrong things! At the very least, this reduces the amount of shared memory that has to be allocated for kernels without changing the performance 🙃.
Thanks for the update. Will take a pass tomorrow or Sunday!
@Jokeren I double-checked my benchmark script and realized I made an incredibly silly mistake. Now I am actually measuring differences between the two implementations. In some cases, conversion using warp shuffles is up to 2x faster, and in some cases, it is 2x slower! I think this has to do with the number of selects that get emitted and the smem iterations / vector length. I'm trying to narrow down a basic heuristic that can be used to make a good choice.
This reverts commit b67c916.
@Jokeren Here's a hopefully relatively detailed description of the layout conversions where speedups were observed. First, here's the benchmark script:

import triton
import pathlib
import torch
import triton.language as tl
@triton.jit
def kernel1(a_ptr, b_ptr):
    idx = (tl.arange(0, 16)[None, :] + tl.arange(0, 16)[:, None])
    x = tl.load(a_ptr + idx)
    for i in range(1000):
        x += x
    tl.store(b_ptr + idx, x)
TSIZE=32
dtype='i64'
torch_dtype=torch.int64
ttgir = f"""
#blocked = #ttg.blocked<{{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 1], order = [1, 0]}}>
#test = #ttg.blocked<{{sizePerThread = [16, 2], threadsPerWarp = [2, 16], warpsPerCTA = [1, 1], order = [1, 0]}}>
module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<{dtype}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{dtype}> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
%c1_i32 = arith.constant 1 : i32
%c1000000_i32 = arith.constant 100000 : i32
%c0_i32 = arith.constant 0 : i32
%0 = tt.make_range {{end = {TSIZE} : i32, start = 0 : i32}} : tensor<{TSIZE}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>>
%1 = tt.expand_dims %0 {{axis = 0 : i32}} : tensor<{TSIZE}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{TSIZE}xi32, #blocked>
%2 = tt.make_range {{end = {TSIZE} : i32, start = 0 : i32}} : tensor<{TSIZE}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>>
%3 = tt.expand_dims %2 {{axis = 1 : i32}} : tensor<{TSIZE}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{TSIZE}x1xi32, #blocked>
%4 = tt.broadcast %1 : tensor<1x{TSIZE}xi32, #blocked> -> tensor<{TSIZE}x{TSIZE}xi32, #blocked>
%5 = tt.broadcast %3 : tensor<{TSIZE}x1xi32, #blocked> -> tensor<{TSIZE}x{TSIZE}xi32, #blocked>
%6 = arith.addi %4, %5 : tensor<{TSIZE}x{TSIZE}xi32, #blocked>
%7 = tt.splat %arg0 : !tt.ptr<{dtype}> -> tensor<{TSIZE}x{TSIZE}x!tt.ptr<{dtype}>, #blocked>
%8 = tt.addptr %7, %6 : tensor<{TSIZE}x{TSIZE}x!tt.ptr<{dtype}>, #blocked>, tensor<{TSIZE}x{TSIZE}xi32, #blocked>
%9 = tt.load %8 : tensor<{TSIZE}x{TSIZE}x!tt.ptr<{dtype}>, #blocked>
%10 = scf.for %arg2 = %c0_i32 to %c1000000_i32 step %c1_i32 iter_args(%arg3 = %9) -> (tensor<{TSIZE}x{TSIZE}x{dtype}, #blocked>) : i32 {{
%x = ttg.convert_layout %arg3 : tensor<{TSIZE}x{TSIZE}x{dtype}, #blocked> -> tensor<{TSIZE}x{TSIZE}x{dtype}, #test>
%res = arith.addi %x, %x : tensor<{TSIZE}x{TSIZE}x{dtype}, #test>
%13 = ttg.convert_layout %res : tensor<{TSIZE}x{TSIZE}x{dtype}, #test> -> tensor<{TSIZE}x{TSIZE}x{dtype}, #blocked>
scf.yield %13 : tensor<{TSIZE}x{TSIZE}x{dtype}, #blocked>
}}
%11 = tt.splat %arg1 : !tt.ptr<{dtype}> -> tensor<{TSIZE}x{TSIZE}x!tt.ptr<{dtype}>, #blocked>
%12 = tt.addptr %11, %6 : tensor<{TSIZE}x{TSIZE}x!tt.ptr<{dtype}>, #blocked>, tensor<{TSIZE}x{TSIZE}xi32, #blocked>
tt.store %12, %10 : tensor<{TSIZE}x{TSIZE}x!tt.ptr<{dtype}>, #blocked>
tt.return
}}
}}
"""
temp_file = pathlib.Path("test.ttgir")
temp_file.write_text(ttgir)
kernel = triton.compile(str(temp_file))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # Argument names to use as an x-axis for the plot.
        x_vals=[1],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['triton'],  # Possible values for `line_arg`.
        line_names=['Triton'],  # Label name for the lines.
        styles=[('blue', '-')],  # Line styles.
        ylabel='GB/s',  # Label name for the y-axis.
        plot_name='cvt',
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(size, provider):
    a = torch.arange(TSIZE * TSIZE, device="cuda", dtype=torch_dtype).reshape(TSIZE, TSIZE)
    b = torch.empty_like(a)
    #kernel[(1, )](a, b, num_warps=1)
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: kernel[(1, 1, 1)](a, b), quantiles=quantiles)
    return ms, min_ms, max_ms
a = torch.arange(TSIZE * TSIZE, device="cuda", dtype=torch_dtype).reshape(TSIZE, TSIZE)
b = torch.empty_like(a)
print(kernel.asm["llir"])
benchmark.run(print_data=True)
The closer […]. I also tested against vectorized shmem when […]. I also tested layouts where […]. For […]
@Mogball thanks for the results. IIUC, packing i8/i16 was not implemented for the evaluation? It's good enough to demonstrate the effectiveness already, but please let me know if packing is on your roadmap.
Packing wasn't implemented. I do plan to tackle it later though. It will be important for some internal use cases!
At last I found time to read through this! I got the main idea, but I have not reviewed the whole code, as I think we can do better. I'll leave it up to you to decide whether you want to re-work this PR, or whether you want to reimplement the generic algo. If it's the former, I am happy to do a detailed review of the whole thing.
Let's consider the running example from the comments to make things clearer
src = {register = [[1,0], [2,0]], lane = [[0,1], [0,2]]}
dst = {register = [[0,1], [2,0]], lane = [[1,0], [0,2]]}
So, we have two linear functions, S, T: H -> M (Source and Target from Hardware into a Matrix). In the example, M = F_2^2 x F_2^2 and H is "the same as M but with labels" (i.e. isomorphic to M). Let's assume there is no thread broadcasting in S or T for now for simplicity.
First, we want to find a subspace of M so that the elements it defines belong to different threads in S and T. A way to do this is:
- Find a basis of the intersection I = S[lane] \cap T[lane]. In our example it's {(0, 2)}
- Consider C = {x ^ y | (x, y) \in zip(S[lane]\I, T[lane]\I)}. In our example it's {(1, 1)}.
A basis of a subspace with the property above is given by B = C \cup I (check that indeed the elements {(0, 0), (0, 2), (1, 1), (1, 3)} all belong to different threads).
Then we complete this basis into a basis of the whole space. In our example, we choose R = {(0, 1), (2, 0)} (we could have chosen (1, 0) as the first element of the basis, it doesn't matter).
R gives us the iterations of our algorithm. In the first iteration we'll interchange the elements of B, in the next one we'll interchange the elements of B ^ (0, 1), then B ^ (2, 0) and finally B ^ (2, 1). In each iteration we'll want to get the target values for the elements of these sets, so we'll need to compute the relevant sources that own them.
Maths tells us that this split of M will have all the threads in every iteration, so it leads to feasible warp shuffles. Even more, we see that in this example S^{-1}(x) = T^{-1}(x) for x \in B and for x \in B ^ (2, 0), which means that we don't need to emit those shuffles because the elements are already in the right position (this can be computed looking at the kernel of src^{-1} o dst - id where src and dst are the matrices with the zeros removed (probably, I haven't fully checked)).
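To make this concrete, here is a small self-contained sketch (my own illustration, not the PR's code) that reproduces the running example: it computes I, C, and B = C ∪ I, and checks that the subspace spanned by B hits every lane exactly once under both layouts. Extracting the intersection basis is done by brute force, which only works because the example is tiny.

```python
from itertools import product

def xor(a, b):
    return (a[0] ^ b[0], a[1] ^ b[1])

def combine(bases, idx):
    # XOR together the basis vectors selected by the bits of `idx`.
    out = (0, 0)
    for bit, b in enumerate(bases):
        if (idx >> bit) & 1:
            out = xor(out, b)
    return out

def span(vectors):
    s = {(0, 0)}
    for v in vectors:
        s |= {xor(x, v) for x in s}
    return s

src = {"register": [(1, 0), (2, 0)], "lane": [(0, 1), (0, 2)]}
dst = {"register": [(0, 1), (2, 0)], "lane": [(1, 0), (0, 2)]}

# I: basis of the intersection of the two lane subspaces -> [(0, 2)]
I = sorted((span(src["lane"]) & span(dst["lane"])) - {(0, 0)})
# C: xor the remaining lane bases pairwise -> [(1, 1)]
src_rest = [b for b in src["lane"] if b not in I]
dst_rest = [b for b in dst["lane"] if b not in I]
C = sorted({xor(x, y) for x, y in zip(src_rest, dst_rest)})
B = C + I  # spans {(0,0), (0,2), (1,1), (1,3)}

def lane_of(layout, elem):
    # Brute-force inverse: which lane owns `elem` under this layout?
    for lane, reg in product(range(4), range(4)):
        if xor(combine(layout["lane"], lane), combine(layout["register"], reg)) == elem:
            return lane

for layout in (src, dst):
    lanes = {lane_of(layout, e) for e in span(B)}
    assert len(lanes) == 4  # every element of span(B) sits in a different lane

print("I =", I, "C =", C, "span(B) =", sorted(span(B)))
```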
@@ -58,6 +58,12 @@ struct ScratchConfig {
}
};

// For a layout conversion between `srcTy` and `dstTy`, return the vector length
nb. this will be removed once proper vectorisation is implemented in a follow-up PR.
Can you elaborate on what proper vectorization means?
//
// Warp shuffles allow indexing into another lane, but do not allow
// selecting the register. Suppose we decompose `C` into `C = P1 ∘ W ∘ P2`,
// where `W` is a warp shuffle and `P1` and `P2` are (lane-dependent) register
How do you define a warp shuffle, though? It's not clear (see the next point).
I defined a single warp shuffle as

    W = [ I  R' ]
        [ 0  L  ]

which is just one index shuffle.
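Spelling out the block structure (my reading of the notation, over F_2, with r the register bits and ℓ the lane bits):

$$
W \begin{pmatrix} r \\ \ell \end{pmatrix}
= \begin{pmatrix} I & R' \\ 0 & L \end{pmatrix}
  \begin{pmatrix} r \\ \ell \end{pmatrix}
= \begin{pmatrix} r \oplus R'\ell \\ L\ell \end{pmatrix}
$$

i.e. the lane part is remapped by `L` while the register index only picks up a lane-dependent offset `R'ℓ`; per the comment above, this is exactly the shape a single index shuffle can realize.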
// Note for each destination register, two lanes want two different registers
// in the same source lane (T0:0 -> T0:0, T1:0 -> T0:1). This is impossible to
// represent with a warp shuffle, because the source lane (e.g. T0) can only
// supply one of its registers as the shuffle value. |
It's not impossible tho. In fact, in that example you can do it in 2 shuffles! Sure, T0 wants its register 0, but it already has it. In fact thread `t` can put out register `(t&1)^1` and read from `t ^ 1` (and similarly for reg 2 and 3) and we've shared the data in two shuffles!
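For concreteness, here's a tiny toy simulation of that scheme (my own sketch; it assumes the 2-lane slice quoted above, i.e. lane `t` starts with elements `e_{2t}, e_{2t+1}` in registers 0 and 1 and wants `e_t, e_{t+2}`). One index shuffle following the rule above fixes registers 0 and 1; the "reg 2 and 3" case mentioned would take the second shuffle.

```python
# Toy model of the two-shuffle scheme for the assumed 2-lane slice.
def shfl_idx(values, src_lane):
    # Index shuffle: every lane publishes one value; lane t reads from src_lane[t].
    return [values[src_lane[t]] for t in range(len(values))]

lanes = 2
src = [[f"e{2 * t + r}" for r in range(2)] for t in range(lanes)]       # T0: [e0, e1], T1: [e2, e3]
dst_want = [[f"e{2 * r + t}" for r in range(2)] for t in range(lanes)]  # T0: [e0, e2], T1: [e1, e3]

# Lane t publishes register (t & 1) ^ 1 and reads from lane t ^ 1 ...
published = [src[t][(t & 1) ^ 1] for t in range(lanes)]
received = shfl_idx(published, [t ^ 1 for t in range(lanes)])
# ... then writes what it received into register (t & 1) ^ 1; the other register
# (t & 1) was already in the right place, as noted above.
result = [list(regs) for regs in src]
for t in range(lanes):
    result[t][(t & 1) ^ 1] = received[t]

assert result == dst_want
```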
Yes, but the algorithm I implemented actually finds a register permutation so that it can be done in just 1 warp shuffle. This algorithm breaks down the conversion transformation into register permute -> one index shuffle -> register permute.
Yes I misunderstood. Mario's algorithm does the whole thing in 2 shuffles total. I was referring to 1 shuffle "per tensor element in each thread". We chatted quickly about this offline.
That makes sense. IMO it's fine to merge this PR first and improve it later.
@lezcano Thanks for the detailed explanation! One of the sketchiest parts of the current implementation is the part where I manually scan the bases in the […]. The main difference between the two algorithms is that the one I implemented always generates 1 warp shuffle since it has a rigid structure, whereas your algorithm may generate more due to breaking down the transformation into iterations. The algorithm I implemented only fails when there is lane broadcasting, because it makes some of the sublayouts noninvertible; I've put off dealing with this for now. Otherwise, every layout conversion within a warp can be codegened with a single index shuffle.

I actually think that there are cases where generating more warp shuffles is better than trying to permute registers to generate 1, because the register permutation can have a significant runtime cost. I will try to think about how to find a decent composition of the two algorithms.
FYI based on this table I think a shuffle may cost the same as 4 moves, assuming mov and fma have similar throughput.
@lezcano After thinking about this more, I am not certain that either algorithm is more general, since they both (as written) handle the same cases, but generate different code! In brief, the approach I implemented always tries to decompose the conversion into register permute -> one index shuffle -> register permute. However, in order to handle broadcasting in the conversion, I am pretty sure we will need to codegen at least 2 warp shuffles instead of 1. If the number of shuffles required to handle broadcasting is variable (instead of just one index plus one butterfly shuffle), then that is where I think your approach might be better suited. But broadcasting isn't needed to support the current use cases that matter.

The other factor is that sometimes it is better to emit multiple warp shuffles rather than try to emit a single shuffle with register permutations. But for smaller data types, register permutations become cheaper because we can use byte permute instructions, so the benefits aren't clear at the moment. I think a possible next step would be to implement your algorithm separately and try to understand the cases in which it generates better code, and maybe figure out how to unify the approaches.
That's good to know. It would certainly help build a cost model for an algorithm that has to pick between generating register permutations and warp shuffles.
Update: I chatted with Mario and I did misunderstand a few of the details. In particular, for the example above: since some of the elements are already in the right spot (identity from dst->src), it's possible to shuffle just a subset of the data. In particular, one can see that only two shuffles are needed to put the data in the right places.