Skip to content

Commit

Permalink
Merge branch 'main' into num_compute_units
Browse files Browse the repository at this point in the history
  • Loading branch information
dyniols authored Jan 7, 2025
2 parents e82835b + 3472b5b commit e6d65d7
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 29 deletions.
63 changes: 43 additions & 20 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -949,41 +949,53 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,

auto Platform = CommandBuffer->Context->getPlatform();
auto ZeDevice = CommandBuffer->Device->ZeDevice;
ze_command_list_handle_t ZeCommandList =
CommandBuffer->ZeComputeCommandListTranslated;
if (Platform->ZeMutableCmdListExt.LoaderExtension) {
ZeCommandList = CommandBuffer->ZeComputeCommandList;
}

if (NumKernelAlternatives > 0) {
ZeMutableCommandDesc.flags |=
ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;

std::vector<ze_kernel_handle_t> TranslatedKernelHandles(
NumKernelAlternatives + 1, nullptr);
std::vector<ze_kernel_handle_t> KernelHandles(NumKernelAlternatives + 1,
nullptr);

ze_kernel_handle_t ZeMainKernel{};
UR_CALL(getZeKernel(ZeDevice, Kernel, &ZeMainKernel));

// Translate main kernel first
ZE2UR_CALL(zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, ZeMainKernel,
(void **)&TranslatedKernelHandles[0]));
if (Platform->ZeMutableCmdListExt.LoaderExtension) {
KernelHandles[0] = ZeMainKernel;
} else {
// If the L0 loader is not aware of the MCL extension, the main kernel
// handle needs to be translated.
ZE2UR_CALL(zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, ZeMainKernel, (void **)&KernelHandles[0]));
}

for (size_t i = 0; i < NumKernelAlternatives; i++) {
ze_kernel_handle_t ZeAltKernel{};
UR_CALL(getZeKernel(ZeDevice, KernelAlternatives[i], &ZeAltKernel));

ZE2UR_CALL(zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, ZeAltKernel,
(void **)&TranslatedKernelHandles[i + 1]));
if (Platform->ZeMutableCmdListExt.LoaderExtension) {
KernelHandles[i + 1] = ZeAltKernel;
} else {
// If the L0 loader is not aware of the MCL extension, the kernel
// alternatives need to be translated.
ZE2UR_CALL(zelLoaderTranslateHandle, (ZEL_HANDLE_KERNEL, ZeAltKernel,
(void **)&KernelHandles[i + 1]));
}
}

ZE2UR_CALL(Platform->ZeMutableCmdListExt
.zexCommandListGetNextCommandIdWithKernelsExp,
(CommandBuffer->ZeComputeCommandListTranslated,
&ZeMutableCommandDesc, NumKernelAlternatives + 1,
TranslatedKernelHandles.data(), &CommandId));
(ZeCommandList, &ZeMutableCommandDesc, NumKernelAlternatives + 1,
KernelHandles.data(), &CommandId));

} else {
ZE2UR_CALL(Platform->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp,
(CommandBuffer->ZeComputeCommandListTranslated,
&ZeMutableCommandDesc, &CommandId));
(ZeCommandList, &ZeMutableCommandDesc, &CommandId));
}
DEBUG_LOG(CommandId);

Expand Down Expand Up @@ -1863,17 +1875,22 @@ ur_result_t updateKernelCommand(
ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel;

if (NewKernel && Command->Kernel != NewKernel) {
ze_kernel_handle_t KernelHandle{};
ze_kernel_handle_t ZeNewKernel{};
UR_CALL(getZeKernel(ZeDevice, NewKernel, &ZeNewKernel));

ze_kernel_handle_t ZeKernelTranslated = nullptr;
ZE2UR_CALL(zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&ZeKernelTranslated));
ze_command_list_handle_t ZeCommandList =
CommandBuffer->ZeComputeCommandList;
KernelHandle = ZeNewKernel;
if (!Platform->ZeMutableCmdListExt.LoaderExtension) {
ZeCommandList = CommandBuffer->ZeComputeCommandListTranslated;
ZE2UR_CALL(zelLoaderTranslateHandle,
(ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&KernelHandle));
}

ZE2UR_CALL(Platform->ZeMutableCmdListExt
.zexCommandListUpdateMutableCommandKernelsExp,
(CommandBuffer->ZeComputeCommandListTranslated, 1,
&Command->CommandId, &ZeKernelTranslated));
(ZeCommandList, 1, &Command->CommandId, &KernelHandle));
// Set current kernel to be the new kernel
Command->Kernel = NewKernel;
}
Expand Down Expand Up @@ -2079,9 +2096,15 @@ ur_result_t updateKernelCommand(
MutableCommandDesc.pNext = NextDesc;
MutableCommandDesc.flags = 0;

ze_command_list_handle_t ZeCommandList =
CommandBuffer->ZeComputeCommandListTranslated;
if (Platform->ZeMutableCmdListExt.LoaderExtension) {
ZeCommandList = CommandBuffer->ZeComputeCommandList;
}

ZE2UR_CALL(
Platform->ZeMutableCmdListExt.zexCommandListUpdateMutableCommandsExp,
(CommandBuffer->ZeComputeCommandListTranslated, &MutableCommandDesc));
(ZeCommandList, &MutableCommandDesc));

return UR_RESULT_SUCCESS;
}
Expand Down
6 changes: 1 addition & 5 deletions source/adapters/level_zero/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,6 @@ ur_result_t urEnqueueEventsWait(
std::unique_lock<ur_shared_mutex> Lock(Queue->Mutex);
resetCommandLists(Queue);
}
if (OutEvent && (*OutEvent)->Completed) {
UR_CALL(CleanupCompletedEvent((*OutEvent), false, false));
UR_CALL(urEventReleaseInternal((*OutEvent)));
}

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -795,7 +791,7 @@ urEventWait(uint32_t NumEvents, ///< [in] number of events in the event list
//
ur_event_handle_t_ *Event = ur_cast<ur_event_handle_t_ *>(e);
if (!Event->hasExternalRefs())
die("urEventsWait must not be called for an internal event");
die("urEventWait must not be called for an internal event");

ze_event_handle_t ZeHostVisibleEvent;
if (auto Res = Event->getOrCreateHostVisibleEvent(ZeHostVisibleEvent))
Expand Down
1 change: 1 addition & 0 deletions source/adapters/level_zero/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ ur_result_t ur_platform_handle_t_::initialize() {
ZeMutableCmdListExt.Supported |=
ZeMutableCmdListExt.zexCommandListGetNextCommandIdWithKernelsExp !=
nullptr;
ZeMutableCmdListExt.LoaderExtension = true;
} else {
ZeMutableCmdListExt.Supported |=
(ZE_CALL_NOCHECK(
Expand Down
6 changes: 6 additions & 0 deletions source/adapters/level_zero/platform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ struct ur_platform_handle_t_ : public _ur_platform {
// associated with particular Level Zero driver, store this extension here.
struct ZeMutableCmdListExtension {
bool Supported = false;
// If LoaderExtension is true, the L0 loader is aware of the MCL extension.
// If it is false, the extension has to be loaded directly from the driver
// using zeDriverGetExtensionFunctionAddress. If it is loaded directly from
// the driver, any handles passed to it must be translated using
// zelLoaderTranslateHandle.
bool LoaderExtension = false;
ze_result_t (*zexCommandListGetNextCommandIdExp)(
ze_command_list_handle_t, const ze_mutable_command_id_exp_desc_t *,
uint64_t *) = nullptr;
Expand Down
7 changes: 3 additions & 4 deletions source/common/logger/ur_logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,15 @@ inline Logger create_logger(std::string logger_name, bool skip_prefix,
logger::Level default_log_level) {
std::transform(logger_name.begin(), logger_name.end(), logger_name.begin(),
::toupper);
std::stringstream env_var_name;
const auto default_flush_level = logger::Level::ERR;
const std::string default_output = "stderr";
auto level = default_log_level;
auto flush_level = default_flush_level;
std::unique_ptr<logger::Sink> sink;

env_var_name << "UR_LOG_" << logger_name;
auto env_var_name = "UR_LOG_" + logger_name;
try {
auto map = getenv_to_map(env_var_name.str().c_str());
auto map = getenv_to_map(env_var_name.c_str());
if (!map.has_value()) {
return Logger(
default_log_level,
Expand Down Expand Up @@ -173,7 +172,7 @@ inline Logger create_logger(std::string logger_name, bool skip_prefix,
skip_linebreak);
} catch (const std::invalid_argument &e) {
std::cerr << "Error when creating a logger instance from the '"
<< env_var_name.str() << "' environment variable:\n"
<< env_var_name << "' environment variable:\n"
<< e.what() << std::endl;
return Logger(default_log_level,
std::make_unique<logger::StderrSink>(
Expand Down

0 comments on commit e6d65d7

Please sign in to comment.