Reduce runtime dependency on torch
stephen-huan committed Jan 3, 2025
1 parent 781ae0b commit c3ccf21
Showing 3 changed files with 55 additions and 15 deletions.
56 changes: 45 additions & 11 deletions python/triton/testing.py
@@ -1,5 +1,7 @@
 import functools
+import math
 import os
+import statistics
 import subprocess
 import sys
 from contextlib import contextmanager
@@ -17,16 +19,38 @@ def nvsmi(attrs):
     return ret


+def _quantile(a, q):
+    n = len(a)
+    a = sorted(a)
+
+    def get_quantile(q):
+        if not (0 <= q <= 1):
+            raise ValueError("Quantiles must be in the range [0, 1]")
+        point = q * (n - 1)
+        lower = math.floor(point)
+        upper = math.ceil(point)
+        t = point - lower
+        return (1 - t) * a[lower] + t * a[upper]
+
+    return [get_quantile(q) for q in q]
+
+
 def _summarize_statistics(times, quantiles, return_mode):
-    import torch
     if quantiles is not None:
-        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
+        ret = _quantile(times, quantiles)
         if len(ret) == 1:
             ret = ret[0]
         return ret
     if return_mode == "all":
-        return times.tolist()
-    return getattr(torch, return_mode)(times).item()
+        return times
+    elif return_mode == "min":
+        return min(times)
+    elif return_mode == "max":
+        return max(times)
+    elif return_mode == "mean":
+        return statistics.mean(times)
+    elif return_mode == "median":
+        return statistics.median(times)


 def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
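The new _quantile helper reproduces the linear-interpolation rule that torch.quantile uses by default, so dropping torch should not change the reported percentiles. A minimal standalone check of that rule (sample data hypothetical):

import math

# the q-th quantile sits at fractional index q * (n - 1) of the sorted
# sample, blending the two neighboring elements linearly
a = sorted([7.0, 1.0, 5.0, 3.0])  # [1.0, 3.0, 5.0, 7.0]
q = 0.25
point = q * (len(a) - 1)          # 0.75
lower, upper = math.floor(point), math.ceil(point)
t = point - lower
print((1 - t) * a[lower] + t * a[upper])  # 2.5, matching torch.quantile's default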
@@ -39,7 +63,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
     :type rep: int
     :param grad_to_none: Reset the gradient of the provided tensor to None
     :type grad_to_none: torch.tensor, optional
-    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean".
+    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
     :type return_mode: str
     """
     import torch
@@ -89,7 +113,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
         end_event.record()
         torch.cuda.synchronize()
         ret += [start_event.elapsed_time(end_event) / n_repeat]
-    return _summarize_statistics(torch.tensor(ret), quantiles, return_mode)
+    return _summarize_statistics(ret, quantiles, return_mode)


 def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
@@ -107,10 +131,10 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
     :type grad_to_none: torch.tensor, optional
     :param quantiles: Performance percentile to return in addition to the median.
     :type quantiles: list[float], optional
-    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". :type return_mode: str
+    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
+    :type return_mode: str
     """
     assert return_mode in ["min", "max", "mean", "median", "all"]
-    import torch

     di = runtime.driver.active.get_device_interface()

@@ -124,7 +148,12 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
     end_event = di.Event(enable_timing=True)
     start_event.record()
     for _ in range(5):
-        cache.zero_()
+        if hasattr(cache, "zero_"):
+            cache.zero_()
+        elif isinstance(cache, bytearray):
+            cache.__init__(len(cache))
+        else:
+            cache[:] = 0
         fn()
     end_event.record()
     di.synchronize()
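The cache buffer used to flush the L2 is no longer guaranteed to be a torch tensor, so the clearing logic now handles torch-style tensors (zero_), plain bytearrays, and anything else sliceable. The bytearray branch leans on the fact that re-running __init__ with an integer length zero-fills the existing buffer in place; a quick sketch:

# re-invoking __init__ with a length refills an existing bytearray
# with zero bytes, without allocating a replacement object
cache = bytearray(b"\xff" * 4)
cache.__init__(len(cache))
assert cache == bytearray(4)  # four zero bytes, same object as before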
@@ -147,14 +176,19 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
             for x in grad_to_none:
                 x.grad = None
         # we clear the L2 cache before each run
-        cache.zero_()
+        if hasattr(cache, "zero_"):
+            cache.zero_()
+        elif isinstance(cache, bytearray):
+            cache.__init__(len(cache))
+        else:
+            cache[:] = 0
         # record time of `fn`
         start_event[i].record()
         fn()
         end_event[i].record()
     # Record clocks
     di.synchronize()
-    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
+    times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
     return _summarize_statistics(times, quantiles, return_mode)
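With times now a plain Python list, the summary statistics come from the standard library rather than tensor reductions; a short sketch of the stdlib equivalents (timings hypothetical):

import statistics

times = [1.2, 0.9, 1.1, 1.0]     # hypothetical timings in ms
print(min(times), max(times))    # "min" / "max" return modes
print(statistics.mean(times))    # "mean": 1.05
print(statistics.median(times))  # "median": 1.05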


7 changes: 5 additions & 2 deletions third_party/amd/backend/driver.py
@@ -500,8 +500,11 @@ def get_device_interface(self):

     @staticmethod
     def is_active():
-        import torch
-        return torch.version.hip is not None
+        try:
+            import torch
+            return torch.version.hip is not None
+        except ImportError:
+            return False

     def get_current_target(self):
         device = self.get_current_device()
7 changes: 5 additions & 2 deletions third_party/nvidia/backend/driver.py
@@ -547,8 +547,11 @@ def get_device_interface(self):

     @staticmethod
     def is_active():
-        import torch
-        return torch.cuda.is_available() and (torch.version.hip is None)
+        try:
+            import torch
+            return torch.cuda.is_available() and (torch.version.hip is None)
+        except ImportError:
+            return False

     def get_benchmarker(self):
         from triton.testing import do_bench
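Both driver files apply the same guarded-import pattern: a missing torch now makes the backend report itself inactive instead of raising ImportError during driver selection. A standalone sketch of the pattern (function name hypothetical):

def backend_is_active() -> bool:
    # absence of torch reads as "backend not active",
    # not as a hard failure at import time
    try:
        import torch
        return torch.cuda.is_available() and torch.version.hip is None
    except ImportError:
        return False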