[DOCS] Correct autotune/heuristics docstrings (#5487)
Running the example given in the [autotune docstring](https://triton-lang.org/main/python-api/generated/triton.autotune.html), reproduced below, gives the following error:

```python
import triton
import torch


@triton.autotune(configs=[
    triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4),
    triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8),
    ],
    key=['x_size'] # the two above configs will be evaluated anytime
                    # the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
    BLOCK_SIZE = META['BLOCK_SIZE']

if __name__ == '__main__':
    x = torch.ones(8, device="cuda")
    kernel[lambda _: (1,)](x, x.numel())
```

```
Traceback (most recent call last):
  File "...", line 18, in <module>
    kernel[lambda _: (1,)](x, x.size)
  File ".../triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File ".../triton/runtime/autotuner.py", line 156, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File ".../triton/runtime/autotuner.py", line 156, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File ".../triton/runtime/autotuner.py", line 133, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
  File ".../triton/testing.py", line 106, in do_bench
    fn()
  File ".../triton/runtime/autotuner.py", line 114, in kernel_call
    self.fn.run(
  File ".../triton/runtime/jit.py", line 618, in run
    bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)
TypeError: dynamic_func() missing 1 required positional argument: 'META'
```

It seems the kernel's argument binder cannot handle the `**META` catch-all, so the meta-parameters have to be declared explicitly as kernel parameters instead.

Also, `BLOCK_SIZE` should probably be marked as `tl.constexpr`.

```python
@triton.autotune(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["x_size"],  # the two above configs will be evaluated anytime
    # the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
  ...
```
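
For completeness, the corrected snippet can be fleshed out into a self-contained script. This is only a sketch: the kernel body that zeroes the buffer and the launch at the bottom are illustrative additions, not part of the proposed docstring, and the body is deliberately idempotent because autotuning benchmarks the kernel several times per config.

```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["x_size"],  # the two above configs will be evaluated anytime
    # the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
    # Illustrative body (not part of the docstring): zero out the buffer.
    # It is idempotent on purpose, since autotuning runs the kernel
    # several times while benchmarking each config.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < x_size
    tl.store(x_ptr + offsets, tl.zeros((BLOCK_SIZE,), dtype=tl.float32), mask=mask)


if __name__ == "__main__":
    x = torch.ones(8, device="cuda")
    kernel[lambda _: (1,)](x, x.numel())
    print(x)  # all zeros after the launch
```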

Similarly, for the
[heuristics](https://triton-lang.org/main/python-api/generated/triton.heuristics.html)
example: first, the same `**META` issue applies; second, `args` is no
longer a list of positional argument values but a dictionary mapping
argument names to values; and third, `2 **
int(math.ceil(math.log2(args[1])))` is awkward, so
`triton.next_power_of_2(args['x_size'])` should be preferred.

```python
import math

import torch
import triton


@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
@triton.jit
def kernel(x_ptr, x_size, **META):
    BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size


if __name__ == "__main__":
    x = torch.ones(8, device="cuda")
    kernel[lambda _: (1,)](x, x.numel())
```

```
Traceback (most recent call last):
  File "...", line 15, in <module>
    kernel[lambda _: (1,)](x, x.numel())
  File ".../triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File ".../triton/runtime/autotuner.py", line 337, in run
    kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
  File "...", line 7, in <lambda>
    @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
KeyError: 1
```

Applying the suggested changes results in:

```python
# smallest power-of-two >= x_size
@triton.heuristics(values={'BLOCK_SIZE': lambda args: triton.next_power_of_2(args['x_size'])})
@triton.jit
def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
    ...
```
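
The heuristics version can be checked the same way. Again a sketch, with an illustrative kernel body and launch that are not part of the proposed docstring; the point is that the heuristic lambda now receives a dictionary keyed by argument name, so `args['x_size']` resolves correctly.

```python
import torch
import triton
import triton.language as tl


# smallest power-of-two >= x_size
@triton.heuristics(values={"BLOCK_SIZE": lambda args: triton.next_power_of_2(args["x_size"])})
@triton.jit
def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
    # Illustrative body (not part of the docstring): BLOCK_SIZE covers the
    # whole buffer because the heuristic rounds x_size up to a power of two.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < x_size
    tl.store(x_ptr + offsets, tl.zeros((BLOCK_SIZE,), dtype=tl.float32), mask=mask)


if __name__ == "__main__":
    x = torch.ones(8, device="cuda")
    kernel[lambda _: (1,)](x, x.numel())
    print(x)  # all zeros after the launch
```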
stephen-huan authored Dec 23, 2024
1 parent f27f6a7 commit 513a047
python/triton/runtime/autotuner.py (8 additions, 8 deletions):

```diff
@@ -317,8 +317,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_va
                          # the value of x_size changes
         )
         @triton.jit
-        def kernel(x_ptr, x_size, **META):
-            BLOCK_SIZE = META['BLOCK_SIZE']
+        def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
+            ...

     :note: When all the configurations are evaluated, the kernel will run multiple times.
            This means that whatever value the kernel updates will be updated multiple times.
            To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
@@ -382,18 +382,18 @@ def run(self, *args, **kwargs):
 def heuristics(values):
     """
     Decorator for specifying how the values of certain meta-parameters may be computed.
-    This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable.
+    This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.

     .. highlight:: python
     .. code-block:: python

-        @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
+        # smallest power-of-two >= x_size
+        @triton.heuristics(values={'BLOCK_SIZE': lambda args: triton.next_power_of_2(args['x_size'])})
         @triton.jit
-        def kernel(x_ptr, x_size, **META):
-            BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
+        def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
+            ...

     :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
                    each such function takes a list of positional arguments as input.
-    :type values: dict[str, Callable[[list[Any]], Any]]
+    :type values: dict[str, Callable[[dict[str, Any]], Any]]
     """

     def decorator(fn):
```

0 comments on commit 513a047

Please sign in to comment.