[DOCS] Correct autotune/heuristics docstrings (#5487)
Running the example given in the [autotune docstring](https://triton-lang.org/main/python-api/generated/triton.autotune.html)

```python
import triton
import torch

@triton.autotune(configs=[
    triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4),
    triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8),
  ],
  key=['x_size']  # the two above configs will be evaluated anytime
                  # the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
    BLOCK_SIZE = META['BLOCK_SIZE']

if __name__ == '__main__':
    x = torch.ones(8, device="cuda")
    kernel[lambda _: (1,)](x, x.numel())
```

fails with the following error:

```
Traceback (most recent call last):
  File "...", line 18, in <module>
    kernel[lambda _: (1,)](x, x.size)
  File ".../triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File ".../triton/runtime/autotuner.py", line 156, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File ".../triton/runtime/autotuner.py", line 156, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File ".../triton/runtime/autotuner.py", line 133, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
  File ".../triton/testing.py", line 106, in do_bench
    fn()
  File ".../triton/runtime/autotuner.py", line 114, in kernel_call
    self.fn.run(
  File ".../triton/runtime/jit.py", line 618, in run
    bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)
TypeError: dynamic_func() missing 1 required positional argument: 'META'
```

It seems the binder cannot handle the `**META` catch-all, so the keyword arguments supplied by the config must be listed explicitly in the kernel signature. Also, `BLOCK_SIZE` should probably be marked as `tl.constexpr`:

```python
@triton.autotune(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["x_size"],  # the two above configs will be evaluated anytime
                     # the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
    ...
```
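For reference, a minimal end-to-end sketch of the corrected autotune example. The output pointer `y_ptr`, the element-wise copy body, and the grid lambda are illustrative additions of mine, not part of the docstring:

```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["x_size"],  # re-tune whenever x_size changes
)
@triton.jit
def kernel(x_ptr, y_ptr, x_size, BLOCK_SIZE: tl.constexpr):
    # trivial copy kernel, just so there is something to benchmark
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < x_size
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x, mask=mask)


if __name__ == "__main__":
    x = torch.ones(8, device="cuda")
    y = torch.empty_like(x)
    # grid depends on the BLOCK_SIZE chosen by the autotuner
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
    kernel[grid](x, y, x.numel())
```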
Similarly, for the [heuristics](https://triton-lang.org/main/python-api/generated/triton.heuristics.html) example: first, the same `**META` issue applies; second, `args` is no longer a list of positional argument values but a dictionary mapping argument names to values; and third, `2 ** int(math.ceil(math.log2(args[1])))` is awkward and `triton.next_power_of_2(args['x_size'])` should be preferred.

```python
import math

import torch
import triton

@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
@triton.jit
def kernel(x_ptr, x_size, **META):
    BLOCK_SIZE = META['BLOCK_SIZE']  # smallest power-of-two >= x_size

if __name__ == "__main__":
    x = torch.ones(8, device="cuda")
    kernel[lambda _: (1,)](x, x.numel())
```

```
Traceback (most recent call last):
  File "...", line 15, in <module>
    kernel[lambda _: (1,)](x, x.numel())
  File ".../triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File ".../triton/runtime/autotuner.py", line 337, in run
    kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
  File "...", line 7, in <lambda>
    @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
KeyError: 1
```

Applying the suggested changes results in

```python
# smallest power-of-two >= x_size
@triton.heuristics(values={'BLOCK_SIZE': lambda args: triton.next_power_of_2(args['x_size'])})
@triton.jit
def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
    ...
```
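And an analogous runnable sketch for the corrected heuristics example. As above, the output pointer `y_ptr`, the copy body, and the single-program launch are illustrative assumptions, not part of the docstring:

```python
import torch
import triton
import triton.language as tl


# BLOCK_SIZE becomes the smallest power of two >= x_size
@triton.heuristics(values={"BLOCK_SIZE": lambda args: triton.next_power_of_2(args["x_size"])})
@triton.jit
def kernel(x_ptr, y_ptr, x_size, BLOCK_SIZE: tl.constexpr):
    # one program covers the whole tensor, masked to x_size
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < x_size
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x, mask=mask)


if __name__ == "__main__":
    x = torch.ones(8, device="cuda")
    y = torch.empty_like(x)
    kernel[(1,)](x, y, x.numel())
```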