Skip to content

Commit

Permalink
Rolled back to the version with unordered map but without contextSetE…
Browse files Browse the repository at this point in the history
…xtendedDeleter

It seems the static thread_local unordered map needs to stay because of
all the thread shenanigans. But we're removing the use of detail
namespace in sycl since it's not necessary for correctness.
  • Loading branch information
konradkusiak97 committed Nov 8, 2024
1 parent 1961e14 commit 46a2661
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 57 deletions.
19 changes: 2 additions & 17 deletions src/blas/backends/cublas/cublas_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
**************************************************************************/
#ifndef CUBLAS_HANDLE_HPP
#define CUBLAS_HANDLE_HPP
#include <atomic>
#include <unordered_map>

namespace oneapi {
Expand All @@ -28,26 +27,12 @@ namespace cublas {

template <typename T>
struct cublas_handle {
using handle_container_t = std::unordered_map<T, std::atomic<cublasHandle_t>*>;
using handle_container_t = std::unordered_map<T, cublasHandle_t>;
handle_container_t cublas_handle_mapper_{};
~cublas_handle() noexcept(false) {
for (auto& handle_pair : cublas_handle_mapper_) {
cublasStatus_t err;
if (handle_pair.second != nullptr) {
auto handle = handle_pair.second->exchange(nullptr);
if (handle != nullptr) {
CUBLAS_ERROR_FUNC(cublasDestroy, err, handle);
handle = nullptr;
}
else {
// if the handle is nullptr it means the handle was already
// destroyed by the ContextCallback and we're free to delete the
// atomic object.
delete handle_pair.second;
}

handle_pair.second = nullptr;
}
CUBLAS_ERROR_FUNC(cublasDestroy, err, handle_pair.second);
}
cublas_handle_mapper_.clear();
}
Expand Down
32 changes: 16 additions & 16 deletions src/blas/backends/cublas/cublas_scope_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,32 +33,32 @@ namespace cublas {
thread_local cublas_handle<CUdevice> CublasScopedContextHandler::handle_helper =
cublas_handle<CUdevice>{};

CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih)
: ih(ih) {}
CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {}

cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) {
CUdevice device = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
CUstream streamId = get_stream(queue);
cublasStatus_t err;

if (handle_helper.cublas_handle_mapper_.count(device) > 0) {
cublasHandle_t handle = handle_helper.cublas_handle_mapper_[device];
cudaStream_t currentStreamId;
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId);
if (currentStreamId != streamId) {
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);
}
return handle;
auto it = handle_helper.cublas_handle_mapper_.find(device);
if (it != handle_helper.cublas_handle_mapper_.end()) {
cublasHandle_t nativeHandle = it->second;
cudaStream_t currentStreamId;
CUBLAS_ERROR_FUNC(cublasGetStream, err, nativeHandle, &currentStreamId);
if (currentStreamId != streamId) {
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
}
return nativeHandle;
}

cublasHandle_t handle;
CUBLAS_ERROR_FUNC(cublasCreate, err, &handle);
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);
cublasHandle_t nativeHandle;
CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle);
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);

auto insert_iter = handle_helper.cublas_handle_mapper_.insert(
std::make_pair(device, handle));
auto insert_iter =
handle_helper.cublas_handle_mapper_.insert(std::make_pair(device, nativeHandle));

return handle;
return nativeHandle;
}

CUstream CublasScopedContextHandler::get_stream(const sycl::queue& queue) {
Expand Down
4 changes: 1 addition & 3 deletions src/blas/backends/cublas/cublas_scope_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@
#include <CL/sycl.hpp>
#endif

#include <atomic>
#include <memory>
#include <thread>
#include <unordered_map>
#include "cublas_helper.hpp"
#include "cublas_handle.hpp"

Expand Down Expand Up @@ -69,7 +67,7 @@ class CublasScopedContextHandler {
sycl::context get_context(const sycl::queue& queue);

public:
CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih);
CublasScopedContextHandler(sycl::interop_handle& ih);

/**
* @brief get_handle: creates the handle by implicitly impose the advice
Expand Down
28 changes: 9 additions & 19 deletions src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,31 +36,21 @@ cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue)
cublasStatus_t err;
auto it = handle_helper.cublas_handle_mapper_.find(current_device);
if (it != handle_helper.cublas_handle_mapper_.end()) {
if (it->second == nullptr) {
handle_helper.cublas_handle_mapper_.erase(it);
}
else {
auto handle = it->second->load();
if (handle != nullptr) {
cudaStream_t currentStreamId;
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId);
if (currentStreamId != streamId) {
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);
}
return handle;
}
else {
handle_helper.cublas_handle_mapper_.erase(it);
}
cublasHandle_t handle = it->second;
cudaStream_t currentStreamId;
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId);
if (currentStreamId != streamId) {
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);
}
return handle;
}
cublasHandle_t handle;

CUBLAS_ERROR_FUNC(cublasCreate, err, &handle);
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);

auto insert_iter = handle_helper.cublas_handle_mapper_.insert(
std::make_pair(current_device, new std::atomic<cublasHandle_t>(handle)));
auto insert_iter =
handle_helper.cublas_handle_mapper_.insert(std::make_pair(current_device, handle));
return handle;
}

Expand All @@ -71,4 +61,4 @@ CUstream CublasScopedContextHandler::get_stream(const sycl::queue& queue) {
} // namespace cublas
} // namespace blas
} // namespace mkl
} // namespace oneapi
} // namespace oneapi
1 change: 0 additions & 1 deletion src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#endif
#include <memory>
#include <thread>
#include <unordered_map>
#include "cublas_helper.hpp"
#include "cublas_handle.hpp"
namespace oneapi {
Expand Down
2 changes: 1 addition & 1 deletion src/blas/backends/cublas/cublas_task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ static inline void host_task_internal(H& cgh, sycl::queue queue, F f) {
#else
cgh.host_task([f, queue](sycl::interop_handle ih) {
#endif
auto sc = CublasScopedContextHandler(queue, ih);
auto sc = CublasScopedContextHandler(ih);
f(sc);
});
}
Expand Down

0 comments on commit 46a2661

Please sign in to comment.