Skip to content

Commit

Permalink
[rocsolver] Use enqueue_native_command ext when avail (#582)
Browse files Browse the repository at this point in the history
Signed-off-by: JackAKirk <[email protected]>
  • Loading branch information
JackAKirk authored Oct 8, 2024
1 parent 7adfbcc commit 09d4ab3
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 52 deletions.
4 changes: 2 additions & 2 deletions src/lapack/backends/rocsolver/rocsolver_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu
for (int64_t i = 0; i < group_count; i++) {
auto **a_ = reinterpret_cast<rocmDataType **>(a_dev);
auto *info_ = reinterpret_cast<rocblas_int *>(info);
ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo[i]),
rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo[i]),
(int)n[i], a_ + offset, (int)lda[i], info_ + offset,
(int)group_sizes[i]);
offset += group_sizes[i];
Expand Down Expand Up @@ -627,7 +627,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu
for (int64_t i = 0; i < group_count; i++) {
auto **a_ = reinterpret_cast<rocmDataType **>(a_dev);
auto **b_ = reinterpret_cast<rocmDataType **>(b_dev);
ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo[i]),
rocsolver_native_named_func(func_name, func, err, handle, get_rocblas_fill_mode(uplo[i]),
(int)n[i], (int)nrhs[i], a_ + offset, (int)lda[i],
b_ + offset, (int)ldb[i], (int)group_sizes[i]);
offset += group_sizes[i];
Expand Down
11 changes: 11 additions & 0 deletions src/lapack/backends/rocsolver/rocsolver_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,17 @@ class hip_error : virtual public std::runtime_error {
hipError_t hip_err; \
HIP_ERROR_FUNC(hipStreamSynchronize, hip_err, currentStreamId);

template <class Func, class... Types>
inline void rocsolver_native_named_func(const char *func_name, Func func,
rocsolver_status err,
rocsolver_handle handle, Types... args){
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, args...)
#else
ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, args...)
#endif
};

inline rocblas_eform get_rocsolver_itype(std::int64_t itype) {
switch (itype) {
case 1: return rocblas_eform_ax;
Expand Down
Loading

0 comments on commit 09d4ab3

Please sign in to comment.