Move lapack_info_check inside of onemkl_cusolver_host_task #238

Open · wants to merge 12 commits into develop
105 changes: 25 additions & 80 deletions src/lapack/backends/cusolver/cusolver_batch.cpp
@@ -184,26 +184,25 @@ inline void getrf_batch(const char *func_name, Func func, sycl::queue &queue, st
// Create new buffer with 32-bit ints then copy over results
std::uint64_t ipiv_size = stride_ipiv * batch_size;
sycl::buffer<int> ipiv32(sycl::range<1>{ ipiv_size });
sycl::buffer<int> devInfo{ batch_size };

queue.submit([&](sycl::handler &cgh) {
auto a_acc = a.template get_access<sycl::access::mode::read_write>(cgh);
auto ipiv32_acc = ipiv32.template get_access<sycl::access::mode::write>(cgh);
auto devInfo_acc = devInfo.template get_access<sycl::access::mode::write>(cgh);
auto scratch_acc = scratchpad.template get_access<sycl::access::mode::write>(cgh);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = sc.get_mem<cuDataType *>(a_acc);
auto ipiv_ = sc.get_mem<int *>(ipiv32_acc);
auto devInfo_ = sc.get_mem<int *>(devInfo_acc);
auto scratch_ = sc.get_mem<cuDataType *>(scratch_acc);
cusolverStatus_t err;
int *dev_info_d = create_dev_info(batch_size);

// Uses scratch so sync between each cuSolver call
for (std::int64_t i = 0; i < batch_size; ++i) {
CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i,
lda, scratch_, ipiv_ + stride_ipiv * i, devInfo_ + i);
lda, scratch_, ipiv_ + stride_ipiv * i, dev_info_d + i);
}
lapack_info_check_and_free(dev_info_d, __func__, func_name, batch_size);
});
});

@@ -215,7 +214,6 @@ inline void getrf_batch(const char *func_name, Func func, sycl::queue &queue, st
[=](sycl::id<1> index) { ipiv_acc[index] = ipiv32_acc[index]; });
});

lapack_info_check(queue, devInfo, __func__, func_name, batch_size);
}
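
Net effect of these two hunks, condensed (a sketch, not the literal code; the cuSolver arguments are abbreviated): the `devInfo` buffer and the blocking post-submit check disappear, and the info array lives entirely inside the host task.

```cpp
// Before: devInfo was a sycl::buffer checked after submission;
// lapack_info_check() internally called queue.wait(), blocking the caller.
sycl::buffer<int> devInfo{ batch_size };
/* ... submit host task that writes devInfo_ + i per batch ... */
lapack_info_check(queue, devInfo, __func__, func_name, batch_size); // blocks

// After: raw CUDA device memory is created, written, checked, and freed on
// the host-task thread, so the caller stays asynchronous.
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
    int *dev_info_d = create_dev_info(batch_size);
    /* ... per-batch cuSolver calls write dev_info_d + i ... */
    lapack_info_check_and_free(dev_info_d, __func__, func_name, batch_size);
});
```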

#define GETRF_STRIDED_BATCH_LAUNCHER(TYPE, CUSOLVER_ROUTINE) \
@@ -459,10 +457,7 @@ inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(m, n, lda, stride_a, stride_tau, batch_size, scratchpad_size);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -513,10 +508,7 @@ inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(m[i], n[i], lda[i], group_sizes[i]);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType **>(a);
@@ -571,26 +563,22 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
// Allocate memory with 32-bit ints then copy over results
std::uint64_t ipiv_size = stride_ipiv * batch_size;
int *ipiv32 = (int *)malloc_device(sizeof(int) * ipiv_size, queue);
int *devInfo = (int *)malloc_device(sizeof(int) * batch_size, queue);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType *>(a);
auto devInfo_ = reinterpret_cast<int *>(devInfo);
auto scratchpad_ = reinterpret_cast<cuDataType *>(scratchpad);
auto ipiv_ = reinterpret_cast<int *>(ipiv32);
cusolverStatus_t err;
int *dev_info_d = create_dev_info(batch_size);

// Uses scratch so sync between each cuSolver call
for (int64_t i = 0; i < batch_size; ++i) {
CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i,
lda, scratchpad_, ipiv_ + stride_ipiv * i, devInfo_ + i);
lda, scratchpad_, ipiv32 + stride_ipiv * i, dev_info_d + i);
}
lapack_info_check_and_free(dev_info_d, __func__, func_name, batch_size);
});
});

@@ -607,10 +595,6 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
cgh.host_task([=](sycl::interop_handle ih) { sycl::free(ipiv32, queue); });
});

// lapack_info_check calls queue.wait()
lapack_info_check(queue, devInfo, __func__, func_name, batch_size);
sycl::free(devInfo, queue);

return done_casting;
}
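
For the USM paths the payoff is that `getrf_batch` now returns `done_casting` without any host-side wait. A hypothetical call-site sketch (the follow-up submit is illustrative, not from this PR; the signature matches the oneMKL strided USM API):

```cpp
// Previously the tail of getrf_batch ran lapack_info_check(), whose
// queue.wait() stalled the host here; now the event chain stays asynchronous.
sycl::event e = oneapi::mkl::lapack::getrf_batch(queue, m, n, a, lda, stride_a,
                                                 ipiv, stride_ipiv, batch_size,
                                                 scratchpad, scratchpad_size,
                                                 dependencies);
queue.submit([&](sycl::handler &cgh) {
    cgh.depends_on(e);                     // chain without a queue.wait()
    cgh.single_task([]() { /* follow-up work */ });
});
```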

@@ -656,29 +640,27 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
for (int64_t group_id = 0; group_id < group_count; ++group_id)
for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id)
ipiv32[global_id] = (int *)malloc_device(sizeof(int) * n[group_id], queue);
int *devInfo = (int *)malloc_device(sizeof(int) * batch_size, queue);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType **>(a);
auto scratch_ = reinterpret_cast<cuDataType *>(scratchpad);
int64_t global_id = 0;
cusolverStatus_t err;
int *dev_info_d = create_dev_info(batch_size);

// Uses scratch so sync between each cuSolver call
for (int64_t group_id = 0; group_id < group_count; ++group_id) {
for (int64_t local_id = 0; local_id < group_sizes[group_id];
++local_id, ++global_id) {
CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id],
n[group_id], a_[global_id], lda[group_id], scratch_,
ipiv32[global_id], devInfo + global_id);
ipiv32[global_id], dev_info_d + global_id);
}
}
lapack_info_check_and_free(dev_info_d, __func__, func_name, batch_size);
});
});

@@ -712,10 +694,6 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
});
});

// lapack_info_check calls queue.wait()
lapack_info_check(queue, devInfo, __func__, func_name, batch_size);
sycl::free(devInfo, queue);

return done_freeing;
}

@@ -814,22 +792,18 @@ inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &qu
});

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
cgh.depends_on(done_casting);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType *>(a);
auto ipiv_ = reinterpret_cast<int *>(ipiv32);
auto b_ = reinterpret_cast<cuDataType *>(b);
cusolverStatus_t err;

// Does not use scratch so call cuSolver asynchronously and sync at end
for (int64_t i = 0; i < batch_size; ++i) {
CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_operation(trans), n,
nrhs, a_ + stride_a * i, lda, ipiv_ + stride_ipiv * i,
nrhs, a_ + stride_a * i, lda, ipiv32 + stride_ipiv * i,
b_ + stride_b * i, ldb, nullptr);
}
CUSOLVER_SYNC(err, handle)
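
The two launch patterns used throughout this file, side by side (macro semantics inferred from their usage in this diff; argument lists elided):

```cpp
// Shared scratchpad (getrf, geqrf, ...): every iteration reuses scratch_,
// so each call is synchronized before the next one starts.
for (int64_t i = 0; i < batch_size; ++i)
    CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, /* args */);

// Independent outputs (getrs): launch all calls asynchronously on the
// handle's stream, then synchronize once at the end.
for (int64_t i = 0; i < batch_size; ++i)
    CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, /* args */);
CUSOLVER_SYNC(err, handle)
```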
@@ -902,13 +876,8 @@ inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &qu
}

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
for (int64_t i = 0; i < batch_size; i++) {
cgh.depends_on(casting_dependencies[i]);
}
depends_on_events(cgh, dependencies);
depends_on_events(cgh, casting_dependencies);

onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
@@ -967,10 +936,7 @@ inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(m, n, k, lda, stride_a, stride_tau, batch_size, scratchpad_size);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -1020,10 +986,7 @@ inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(m[i], n[i], k[i], lda[i], group_sizes[i]);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType **>(a);
@@ -1074,10 +1037,7 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(n, lda, stride_a, batch_size, scratchpad_size);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
CUdeviceptr a_dev;
@@ -1135,10 +1095,7 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu
}

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
int64_t offset = 0;
@@ -1199,10 +1156,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu
throw unimplemented("lapack", "potrs_batch", "cusolver potrs_batch only supports nrhs = 1");

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
CUresult cuda_result;
@@ -1283,10 +1237,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu
queue.submit([&](sycl::handler &h) { h.memcpy(b_dev, b, batch_size * sizeof(T *)); });

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
cgh.depends_on(done_cpy_a);
cgh.depends_on(done_cpy_b);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
@@ -1340,10 +1291,7 @@ inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(m, n, k, lda, stride_a, stride_tau, batch_size, scratchpad_size);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -1393,10 +1341,7 @@ inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &qu
overflow_check(m[i], n[i], k[i], lda[i], group_sizes[i]);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
cgh.depends_on(dependencies[i]);
}
depends_on_events(cgh, dependencies);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType **>(a);
56 changes: 36 additions & 20 deletions src/lapack/backends/cusolver/cusolver_helper.hpp
@@ -280,30 +280,46 @@ struct CudaEquivalentType<std::complex<double>> {

/* devinfo */

inline void get_cusolver_devinfo(sycl::queue &queue, sycl::buffer<int> &devInfo,
std::vector<int> &dev_info_) {
sycl::host_accessor<int, 1, sycl::access::mode::read> dev_info_acc{ devInfo };
for (unsigned int i = 0; i < dev_info_.size(); ++i)
dev_info_[i] = dev_info_acc[i];
// Accepts an int*, copies the memory from device to host,
// checks that no value indicates an error, then frees the device memory
inline void lapack_info_check_and_free(int *dev_info_d, const char *func_name,
const char *cufunc_name, int num_elements = 1) {
int *dev_info_h = (int *)malloc(sizeof(int) * num_elements);
cuMemcpyDtoH(dev_info_h, reinterpret_cast<CUdeviceptr>(dev_info_d), sizeof(int) * num_elements);
for (int i = 0; i < num_elements; ++i) {
if (dev_info_h[i] > 0)
throw oneapi::mkl::lapack::computation_error(
func_name,
std::string(cufunc_name) + " failed with info = " + std::to_string(dev_info_h[i]),
dev_info_h[i]);
}
cuMemFree(reinterpret_cast<CUdeviceptr>(dev_info_d));
}

inline void get_cusolver_devinfo(sycl::queue &queue, const int *devInfo,
std::vector<int> &dev_info_) {
queue.wait();
queue.memcpy(dev_info_.data(), devInfo, sizeof(int));
// Allocates and returns a CUDA device pointer for cuSolver dev_info
inline int *create_dev_info(int num_elements = 1) {
CUdeviceptr dev_info_d;
cuMemAlloc(&dev_info_d, sizeof(int) * num_elements);
return reinterpret_cast<int *>(dev_info_d);
}
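
Only `info > 0` throws, matching the check the old host-side `lapack_info_check` performed (for getrf-style routines a positive info conventionally reports a computational failure such as a zero pivot). Minimal usage of the pair on the host-task thread (a sketch; the cuSolver call is shown only as a comment and assumes a valid handle):

```cpp
int *dev_info_d = create_dev_info();  // one element by default
// cusolverDnDgetrf(handle, m, n, a, lda, scratch, ipiv, dev_info_d);
lapack_info_check_and_free(dev_info_d, __func__, "cusolverDnDgetrf");
// throws oneapi::mkl::lapack::computation_error if the info value is positive
```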

template <typename DEVINFO_T>
inline void lapack_info_check(sycl::queue &queue, DEVINFO_T devinfo, const char *func_name,
const char *cufunc_name, int dev_info_size = 1) {
std::vector<int> dev_info_(dev_info_size);
get_cusolver_devinfo(queue, devinfo, dev_info_);
for (const auto &val : dev_info_) {
if (val > 0)
throw oneapi::mkl::lapack::computation_error(
func_name, std::string(cufunc_name) + " failed with info = " + std::to_string(val),
val);
}
// Helper that adds a vector of SYCL events as dependencies of a command group
inline void depends_on_events(sycl::handler &cgh,
const std::vector<sycl::event> &dependencies = {}) {
for (auto &e : dependencies)
cgh.depends_on(e);
}
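
This helper replaces the hand-rolled loop that previously appeared at every submit site in cusolver_batch.cpp; the two forms are equivalent:

```cpp
// Before, repeated at each call site:
int64_t num_events = dependencies.size();
for (int64_t i = 0; i < num_events; i++) {
    cgh.depends_on(dependencies[i]);
}

// After:
depends_on_events(cgh, dependencies);
```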

// Asynchronously frees the SYCL USM allocation `ptr` once the events in `dependencies` complete
template <typename T>
inline sycl::event free_async(sycl::queue &queue, T *ptr,
const std::vector<sycl::event> &dependencies = {}) {
sycl::event done = queue.submit([&](sycl::handler &cgh) {
depends_on_events(cgh, dependencies);

cgh.host_task([=](sycl::interop_handle ih) { sycl::free(ptr, queue); });
});
return done;
}
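
A usage sketch (hypothetical call site; the same inline pattern appears in cusolver_batch.cpp for releasing `ipiv32` once the 32-to-64-bit pivot copy has finished):

```cpp
// Free the temporary 32-bit pivot array only after the casting kernel
// completes, without blocking the host thread.
sycl::event done_freeing = free_async(queue, ipiv32, { done_casting });
```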

/* batched helpers */