Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet committed Jan 1, 2025
1 parent 71d6673 commit 62ffe87
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions third_party/amd/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def __init__(self):

# -------------------- Launcher ----------------------------
def ty_to_cpp(ty):
if ty[0] == '*' or ty == "none":
if ty[0] == '*':
return "hipDeviceptr_t"
return {
"i1": "int32_t",
Expand All @@ -186,10 +186,12 @@ def ty_to_cpp(ty):
}[ty]


def make_launcher(constants, signature, ids, warp_size):
def make_launcher(constants, signature, warp_size):

def _extracted_type(ty):
if ty[0] == '*' or ty == "none":
if ty == "constexpr":
return "PyObject*"
if ty[0] == '*':
return "PyObject*"
if ty[0] == '[':
if ty == "[]":
Expand Down Expand Up @@ -223,7 +225,6 @@ def format_of(ty):
"uint64_t": "K",
}[ty]

signature = {k: v for k, v in signature.items() if v != 'constexpr'}
args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
format = "iiiKKOOOO" + args_format
signature = ','.join(signature.values()).replace('[', '').replace(']', '')
Expand All @@ -232,13 +233,13 @@ def format_of(ty):
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
# Record the end of regular arguments;
# subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")

libhip_path = _get_path_to_hip_runtime_dylib()

# generate glue code
params = list(range(len(signature)))
params = [f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none"]
params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
params.append("&global_scratch")
src = f"""
#define __HIP_PLATFORM_AMD__
Expand Down Expand Up @@ -470,11 +471,10 @@ def format_of(ty):
class HIPLauncher(object):

def __init__(self, src, metadata):
ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
constants = src.constants if hasattr(src, "constants") else dict()
constants = {idx: value for idx, value in constants.items()}
signature = {idx: value for idx, value in src.signature.items()}
src = make_launcher(constants, signature, ids, metadata.warp_size)
src = make_launcher(constants, signature, metadata.warp_size)
mod = compile_module_from_src(src, "__triton_launcher")
self.launch = mod.launch

Expand Down

0 comments on commit 62ffe87

Please sign in to comment.