From e4b0518b0f3bf186dda8b7c9336ce1e4917ab3f7 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 7 Jan 2025 20:24:22 -0500 Subject: [PATCH] Attempt bump --- deps/ReactantExtra/API.cpp | 2 + deps/ReactantExtra/BUILD | 4 +- deps/ReactantExtra/WORKSPACE | 69 +++++++++++++++-------------- deps/ReactantExtra/make-bindings.jl | 2 +- 4 files changed, 40 insertions(+), 37 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 81945aa2b..5b0c8eac3 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -91,6 +91,7 @@ #include "xla/python/pjrt_ifrt/pjrt_tuple.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "jax/jaxlib/mosaic/dialect/tpu/tpu_dialect.h" using namespace mlir; using namespace llvm; @@ -548,6 +549,7 @@ extern "C" void RegisterDialects(MlirContext cctx) { context.loadDialect(); context.loadDialect(); context.loadDialect(); + context.loadDialect(); context.loadDialect(); context.loadDialect(); context.loadDialect(); diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 4bdc49035..a86f5a3c9 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -685,10 +685,10 @@ gentbl_cc_library( ) gentbl_cc_library( - name = "MosaicTPUJLIncGen", + name = "TPUJLIncGen", tbl_outs = [( ["--generator=jl-op-defs", "--disable-module-wrap=0"], - "MosaicTPU.jl" + "TPU.jl" ) ], td_file = "@jax//jaxlib/mosaic:dialect/tpu/tpu.td", diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index a8693ca26..37f713e7e 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "bf8ad1cc5f42c0929ca6bf8de1b04507803721fd" +ENZYMEXLA_COMMIT = "d59e5fdcc457bad7c0ce36215e5b427bf0646284" ENZYMEXLA_SHA256 = "" http_archive( @@ -94,39 +94,40 @@ LLVM_TARGETS = select({ "//conditions:default": ["AMDGPU", "NVPTX"], }) + ["AArch64", "X86", "ARM"] -LLVM_COMMIT = "023dbbaa3eeddd537e2376aa7355e3bcef618908" -LLVM_SHA256 = "" -http_archive( - name = "llvm-raw", - build_file_content = "# empty", - sha256 = LLVM_SHA256, - strip_prefix = "llvm-project-" + LLVM_COMMIT, - urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], -) - - -load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") -maybe( - http_archive, - name = "llvm_zlib", - build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD", - sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731", - strip_prefix = "zlib-ng-2.0.7", - urls = [ - "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip", - ], -) - -maybe( - http_archive, - name = "llvm_zstd", - build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD", - sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0", - strip_prefix = "zstd-1.5.2", - urls = [ - "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz" - ], -) +# Uncomment these lines to use a custom LLVM commit +# LLVM_COMMIT = "023dbbaa3eeddd537e2376aa7355e3bcef618908" +# LLVM_SHA256 = "" +# http_archive( +# name = "llvm-raw", +# build_file_content = "# empty", +# sha256 = LLVM_SHA256, +# strip_prefix = "llvm-project-" + LLVM_COMMIT, +# urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], +# ) +# +# +# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") +# maybe( +# http_archive, +# name = "llvm_zlib", +# build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD", +# sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731", +# strip_prefix = "zlib-ng-2.0.7", +# urls = [ +# "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip", +# ], +# ) +# +# maybe( +# http_archive, +# name = "llvm_zstd", +# build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD", +# sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0", +# strip_prefix = "zstd-1.5.2", +# urls = [ +# "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz" +# ], +# ) http_archive( name = "jax", diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl index 26174517e..db13cffc9 100644 --- a/deps/ReactantExtra/make-bindings.jl +++ b/deps/ReactantExtra/make-bindings.jl @@ -27,7 +27,7 @@ for file in [ "Nvvm.jl", "Gpu.jl", "Affine.jl", - "MosaicTPU.jl", + "TPU.jl", "Triton.jl" ] build_file(joinpath(src_dir, "mlir", "Dialects", file))