Skip to content

Commit

Permalink
[FRONTEND] add assertion that PRNG seed must be an int
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet committed Jan 3, 2025
1 parent f410f91 commit 268d7df
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
17 changes: 17 additions & 0 deletions python/test/unit/language/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,23 @@ def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr):
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01


def test_seed_is_int(device):

@triton.jit
def kernel(X, seed):
offset = tl.arange(0, 1)
rand = tl.rand(seed, offset)
tl.store(X + offset, rand)

x = torch.empty(1, dtype=torch.float32, device=device)
with pytest.raises(triton.compiler.errors.CompilationError):
seed0 = torch.zeros(1, dtype=torch.int32, device="cuda")
kernel[(1, )](x, seed0)
with pytest.raises(triton.compiler.errors.CompilationError):
seed1 = 2.3
kernel[(1, )](x, seed1)


# test normal PRNG


Expand Down
3 changes: 2 additions & 1 deletion python/triton/language/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAUL
@jit
def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
seed = tl.to_tensor(seed)
tl.static_assert(seed.dtype.is_int())
seed = seed.to(tl.uint64)
c0 = tl.to_tensor(c0)
c1 = tl.to_tensor(c1)
c2 = tl.to_tensor(c2)
c3 = tl.to_tensor(c3)
seed = seed.to(tl.uint64)
if tl.constexpr(c0.dtype.primitive_bitwidth) == 32:
int_dtype = tl.uint32
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
Expand Down

0 comments on commit 268d7df

Please sign in to comment.