diff --git a/src/blas/backends/cublas/cublas_handle.hpp b/src/blas/backends/cublas/cublas_handle.hpp index 83a76c927..ce455925f 100644 --- a/src/blas/backends/cublas/cublas_handle.hpp +++ b/src/blas/backends/cublas/cublas_handle.hpp @@ -18,7 +18,6 @@ **************************************************************************/ #ifndef CUBLAS_HANDLE_HPP #define CUBLAS_HANDLE_HPP -#include #include namespace oneapi { @@ -28,26 +27,12 @@ namespace cublas { template struct cublas_handle { - using handle_container_t = std::unordered_map*>; + using handle_container_t = std::unordered_map; handle_container_t cublas_handle_mapper_{}; ~cublas_handle() noexcept(false) { for (auto& handle_pair : cublas_handle_mapper_) { cublasStatus_t err; - if (handle_pair.second != nullptr) { - auto handle = handle_pair.second->exchange(nullptr); - if (handle != nullptr) { - CUBLAS_ERROR_FUNC(cublasDestroy, err, handle); - handle = nullptr; - } - else { - // if the handle is nullptr it means the handle was already - // destroyed by the ContextCallback and we're free to delete the - // atomic object. 
- delete handle_pair.second; - } - - handle_pair.second = nullptr; - } + CUBLAS_ERROR_FUNC(cublasDestroy, err, handle_pair.second); } cublas_handle_mapper_.clear(); } diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp index 164c3c3aa..142c36217 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle.cpp @@ -33,32 +33,32 @@ namespace cublas { thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{}; -CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih) - : ih(ih) {} +CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {} cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) { CUdevice device = ih.get_native_device(); CUstream streamId = get_stream(queue); cublasStatus_t err; - if (handle_helper.cublas_handle_mapper_.count(device) > 0) { - cublasHandle_t handle = handle_helper.cublas_handle_mapper_[device]; - cudaStream_t currentStreamId; - CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); - if (currentStreamId != streamId) { - CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); - } - return handle; + auto it = handle_helper.cublas_handle_mapper_.find(device); + if (it != handle_helper.cublas_handle_mapper_.end()) { + cublasHandle_t nativeHandle = it->second; + cudaStream_t currentStreamId; + CUBLAS_ERROR_FUNC(cublasGetStream, err, nativeHandle, &currentStreamId); + if (currentStreamId != streamId) { + CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId); + } + return nativeHandle; } - cublasHandle_t handle; - CUBLAS_ERROR_FUNC(cublasCreate, err, &handle); - CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); + cublasHandle_t nativeHandle; + CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle); + CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId); - auto insert_iter = 
handle_helper.cublas_handle_mapper_.insert( - std::make_pair(device, handle)); + auto insert_iter = + handle_helper.cublas_handle_mapper_.insert(std::make_pair(device, nativeHandle)); - return handle; + return nativeHandle; } CUstream CublasScopedContextHandler::get_stream(const sycl::queue& queue) { diff --git a/src/blas/backends/cublas/cublas_scope_handle.hpp b/src/blas/backends/cublas/cublas_scope_handle.hpp index 803a98f32..28ca1f71a 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle.hpp @@ -24,10 +24,8 @@ #include #endif -#include #include #include -#include #include "cublas_helper.hpp" #include "cublas_handle.hpp" @@ -69,7 +67,7 @@ class CublasScopedContextHandler { sycl::context get_context(const sycl::queue& queue); public: - CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih); + CublasScopedContextHandler(sycl::interop_handle& ih); /** * @brief get_handle: creates the handle by implicitly impose the advice diff --git a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp index 03c282aed..908600d27 100644 --- a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp @@ -36,31 +36,21 @@ cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) cublasStatus_t err; auto it = handle_helper.cublas_handle_mapper_.find(current_device); if (it != handle_helper.cublas_handle_mapper_.end()) { - if (it->second == nullptr) { - handle_helper.cublas_handle_mapper_.erase(it); - } - else { - auto handle = it->second->load(); - if (handle != nullptr) { - cudaStream_t currentStreamId; - CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); - if (currentStreamId != streamId) { - CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); - } - return handle; - } - else { - handle_helper.cublas_handle_mapper_.erase(it); - } + cublasHandle_t 
handle = it->second; + cudaStream_t currentStreamId; + CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); + if (currentStreamId != streamId) { + CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); } + return handle; } cublasHandle_t handle; CUBLAS_ERROR_FUNC(cublasCreate, err, &handle); CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); - auto insert_iter = handle_helper.cublas_handle_mapper_.insert( - std::make_pair(current_device, new std::atomic(handle))); + auto insert_iter = + handle_helper.cublas_handle_mapper_.insert(std::make_pair(current_device, handle)); return handle; } @@ -71,4 +61,4 @@ CUstream CublasScopedContextHandler::get_stream(const sycl::queue& queue) { } // namespace cublas } // namespace blas } // namespace mkl -} // namespace oneapi \ No newline at end of file +} // namespace oneapi diff --git a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp index 9e1eb89e5..7d218e355 100644 --- a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp @@ -25,7 +25,6 @@ #endif #include #include -#include #include "cublas_helper.hpp" #include "cublas_handle.hpp" namespace oneapi { diff --git a/src/blas/backends/cublas/cublas_task.hpp b/src/blas/backends/cublas/cublas_task.hpp index f4b530ddd..ae95e6eb1 100644 --- a/src/blas/backends/cublas/cublas_task.hpp +++ b/src/blas/backends/cublas/cublas_task.hpp @@ -60,7 +60,7 @@ static inline void host_task_internal(H& cgh, sycl::queue queue, F f) { #else cgh.host_task([f, queue](sycl::interop_handle ih) { #endif - auto sc = CublasScopedContextHandler(queue, ih); + auto sc = CublasScopedContextHandler(ih); f(sc); }); }