Skip to content

Commit

Permalink
fix build
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 8, 2025
1 parent e4b0518 commit 165b2fd
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
#include "xla/python/pjrt_ifrt/pjrt_tuple.h"

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "jax/jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"

using namespace mlir;
using namespace llvm;
Expand Down Expand Up @@ -1150,9 +1150,10 @@ extern "C" const ifrt::Sharding *ifrt_array_sharding(ifrt::Array *array) {
return &(array->sharding());
}

extern "C" PjRtLayout *ifrt_array_layout(ifrt::Array *array) {
return MyValueOrThrow(array->layout()).release();
}
// @mofeng this is now a shared ptr, will let you fix
// extern "C" PjRtLayout *ifrt_array_layout(ifrt::Array *array) {
// return MyValueOrThrow(array->layout()).release();
// }

// TODO xla::ifrt::Array::DisassembleIntoSingleDeviceArrays
// TODO xla::ifrt::Array::FullyReplicatedShard
Expand Down Expand Up @@ -1400,15 +1401,16 @@ ifrt_executable_output_shardings(ifrt::Executable *executable) {
return std::make_tuple(shardings.value().size(), shardings.value().data());
}

extern "C" std::tuple<size_t, xla::PjRtLayout **>
ifrt_executable_parameter_layouts(ifrt::Executable *executable) {
auto layouts = MyValueOrThrow(executable->GetParameterLayouts());
auto layouts_ptr = new xla::PjRtLayout *[layouts.size()];
for (int i = 0; i < layouts.size(); i++) {
layouts_ptr[i] = layouts[i].release();
}
return std::make_tuple(layouts.size(), layouts_ptr);
}
// @mofeng this is now a shared ptr, will let you fix
// extern "C" std::tuple<size_t, xla::PjRtLayout **>
// ifrt_executable_parameter_layouts(ifrt::Executable *executable) {
// auto layouts = MyValueOrThrow(executable->GetParameterLayouts());
// auto layouts_ptr = new xla::PjRtLayout *[layouts.size()];
// for (int i = 0; i < layouts.size(); i++) {
// layouts_ptr[i] = layouts[i].release();
// }
// return std::make_tuple(layouts.size(), layouts_ptr);
// }

extern "C" std::tuple<size_t, xla::PjRtLayout **>
ifrt_executable_output_layouts(ifrt::Executable *executable) {
Expand Down Expand Up @@ -1511,15 +1513,16 @@ ifrt_loadedexecutable_output_shardings(ifrt::LoadedExecutable *executable) {
return std::make_tuple(shardings.value().size(), shardings.value().data());
}

extern "C" std::tuple<size_t, xla::PjRtLayout **>
ifrt_loadedexecutable_parameter_layouts(ifrt::LoadedExecutable *executable) {
auto layouts = MyValueOrThrow(executable->GetParameterLayouts());
auto layouts_ptr = new xla::PjRtLayout *[layouts.size()];
for (int i = 0; i < layouts.size(); i++) {
layouts_ptr[i] = layouts[i].release();
}
return std::make_tuple(layouts.size(), layouts_ptr);
}
// @mofeng this is now a shared ptr, will let you fix
// extern "C" std::tuple<size_t, xla::PjRtLayout **>
// ifrt_loadedexecutable_parameter_layouts(ifrt::LoadedExecutable *executable) {
// auto layouts = MyValueOrThrow(executable->GetParameterLayouts());
// auto layouts_ptr = new xla::PjRtLayout *[layouts.size()];
// for (int i = 0; i < layouts.size(); i++) {
// layouts_ptr[i] = layouts[i].release();
// }
// return std::make_tuple(layouts.size(), layouts_ptr);
// }

extern "C" std::tuple<size_t, xla::PjRtLayout **>
ifrt_loadedexecutable_output_layouts(ifrt::LoadedExecutable *executable) {
Expand Down

0 comments on commit 165b2fd

Please sign in to comment.