Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BLAS] Simplify CublasScopedContextHandler #609

Merged
30 changes: 12 additions & 18 deletions src/blas/backends/cublas/cublas_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,30 @@
**************************************************************************/
#ifndef CUBLAS_HANDLE_HPP
#define CUBLAS_HANDLE_HPP
#include <atomic>
#include <unordered_map>
#include "cublas_helper.hpp"

namespace oneapi {
namespace mkl {
namespace blas {
namespace cublas {

template <typename T>
struct cublas_handle {
using handle_container_t = std::unordered_map<T, std::atomic<cublasHandle_t>*>;
using handle_container_t = std::unordered_map<CUdevice, cublasHandle_t>;
handle_container_t cublas_handle_mapper_{};
~cublas_handle() noexcept(false) {
CUresult err;
CUcontext original;
CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original);
for (auto& handle_pair : cublas_handle_mapper_) {
cublasStatus_t err;
if (handle_pair.second != nullptr) {
auto handle = handle_pair.second->exchange(nullptr);
if (handle != nullptr) {
CUBLAS_ERROR_FUNC(cublasDestroy, err, handle);
handle = nullptr;
}
else {
// if the handle is nullptr it means the handle was already
// destroyed by the ContextCallback and we're free to delete the
// atomic object.
delete handle_pair.second;
}

handle_pair.second = nullptr;
CUcontext desired;
CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, handle_pair.first);
if (original != desired) {
// Sets the desired context as the active one for the thread in order to destroy its corresponding cublasHandle_t.
CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired);
}
cublasStatus_t err;
CUBLAS_ERROR_FUNC(cublasDestroy, err, handle_pair.second);
}
cublas_handle_mapper_.clear();
}
Expand Down
113 changes: 17 additions & 96 deletions src/blas/backends/cublas/cublas_scope_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@
*
**************************************************************************/
#include "cublas_scope_handle.hpp"
#if __has_include(<sycl/detail/common.hpp>)
#include <sycl/detail/common.hpp>
#else
#include <CL/sycl/detail/common.hpp>
#endif

namespace oneapi {
namespace mkl {
Expand All @@ -35,108 +30,34 @@ namespace cublas {
* takes place if no other element in the container has a key equivalent to
* the one being emplaced (keys in a map container are unique).
*/
#ifdef ONEMKL_PI_INTERFACE_REMOVED
thread_local cublas_handle<ur_context_handle_t> CublasScopedContextHandler::handle_helper =
cublas_handle<ur_context_handle_t>{};
#else
thread_local cublas_handle<pi_context> CublasScopedContextHandler::handle_helper =
cublas_handle<pi_context>{};
#endif
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};

CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih)
: ih(ih),
needToRecover_(false) {
placedContext_ = new sycl::context(queue.get_context());
auto cudaDevice = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
CUresult err;
CUcontext desired;
CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_);
CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice);
if (original_ != desired) {
// Sets the desired context as the active one for the thread
CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired);
// No context is installed and the suggested context is primary
// This is the most common case. We can activate the context in the
// thread and leave it there until all the PI context referring to the
// same underlying CUDA primary context are destroyed. This emulates
// the behaviour of the CUDA runtime api, and avoids costly context
// switches. No action is required on this side of the if.
needToRecover_ = !(original_ == nullptr);
}
}

CublasScopedContextHandler::~CublasScopedContextHandler() noexcept(false) {
if (needToRecover_) {
CUresult err;
CUDA_ERROR_FUNC(cuCtxSetCurrent, err, original_);
}
delete placedContext_;
}

void ContextCallback(void* userData) {
auto* ptr = static_cast<std::atomic<cublasHandle_t>*>(userData);
if (!ptr) {
return;
}
auto handle = ptr->exchange(nullptr);
if (handle != nullptr) {
cublasStatus_t err1;
CUBLAS_ERROR_FUNC(cublasDestroy, err1, handle);
Rbiessy marked this conversation as resolved.
Show resolved Hide resolved
handle = nullptr;
}
else {
// if the handle is nullptr it means the handle was already destroyed by
// the cublas_handle destructor and we're free to delete the atomic
// object.
delete ptr;
}
}
CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {}

cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) {
auto cudaDevice = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
CUresult cuErr;
CUcontext desired;
CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desired, cudaDevice);
#ifdef ONEMKL_PI_INTERFACE_REMOVED
auto piPlacedContext_ = reinterpret_cast<ur_context_handle_t>(desired);
#else
auto piPlacedContext_ = reinterpret_cast<pi_context>(desired);
#endif
CUdevice device = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
CUstream streamId = get_stream(queue);
cublasStatus_t err;
auto it = handle_helper.cublas_handle_mapper_.find(piPlacedContext_);

auto it = handle_helper.cublas_handle_mapper_.find(device);
if (it != handle_helper.cublas_handle_mapper_.end()) {
if (it->second == nullptr) {
handle_helper.cublas_handle_mapper_.erase(it);
}
else {
auto handle = it->second->load();
if (handle != nullptr) {
cudaStream_t currentStreamId;
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId);
if (currentStreamId != streamId) {
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);
}
return handle;
}
else {
handle_helper.cublas_handle_mapper_.erase(it);
}
cublasHandle_t nativeHandle = it->second;
cudaStream_t currentStreamId;
CUBLAS_ERROR_FUNC(cublasGetStream, err, nativeHandle, &currentStreamId);
if (currentStreamId != streamId) {
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
}
return nativeHandle;
}

cublasHandle_t handle;

CUBLAS_ERROR_FUNC(cublasCreate, err, &handle);
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);

auto insert_iter = handle_helper.cublas_handle_mapper_.insert(
std::make_pair(piPlacedContext_, new std::atomic<cublasHandle_t>(handle)));
cublasHandle_t nativeHandle;
CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle);
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);

sycl::detail::pi::contextSetExtendedDeleter(*placedContext_, ContextCallback,
insert_iter.first->second);
auto insert_iter =
handle_helper.cublas_handle_mapper_.insert(std::make_pair(device, nativeHandle));

return handle;
return nativeHandle;
}

CUstream CublasScopedContextHandler::get_stream(const sycl::queue& queue) {
Expand Down
35 changes: 2 additions & 33 deletions src/blas/backends/cublas/cublas_scope_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,9 @@
#else
#include <CL/sycl.hpp>
#endif
#if __has_include(<sycl/context.hpp>)
#if __SYCL_COMPILER_VERSION <= 20220930
#include <sycl/backend/cuda.hpp>
#endif
#include <sycl/context.hpp>
#else
#include <CL/sycl/backend/cuda.hpp>
#include <CL/sycl/context.hpp>
#endif

// After Plugin Interface removal in DPC++ ur.hpp is the new include
#if __has_include(<sycl/detail/ur.hpp>)
#include <sycl/detail/ur.hpp>
#ifndef ONEMKL_PI_INTERFACE_REMOVED
#define ONEMKL_PI_INTERFACE_REMOVED
#endif
#elif __has_include(<sycl/detail/pi.hpp>)
#include <sycl/detail/pi.hpp>
#else
#include <CL/sycl/detail/pi.hpp>
#endif
Rbiessy marked this conversation as resolved.
Show resolved Hide resolved

#include <atomic>
#include <memory>
#include <thread>
#include <unordered_map>
#include "cublas_helper.hpp"
#include "cublas_handle.hpp"

Expand Down Expand Up @@ -84,22 +61,14 @@ the handle must be destroyed when the context goes out of scope. This will bind
**/

class CublasScopedContextHandler {
CUcontext original_;
sycl::context* placedContext_;
bool needToRecover_;
sycl::interop_handle& ih;
#ifdef ONEMKL_PI_INTERFACE_REMOVED
static thread_local cublas_handle<ur_context_handle_t> handle_helper;
#else
static thread_local cublas_handle<pi_context> handle_helper;
#endif
static thread_local cublas_handle handle_helper;
CUstream get_stream(const sycl::queue& queue);
sycl::context get_context(const sycl::queue& queue);

public:
CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih);
CublasScopedContextHandler(sycl::interop_handle& ih);

~CublasScopedContextHandler() noexcept(false);
/**
* @brief get_handle: creates the handle by implicitly impose the advice
* given by nvidia for creating a cublas_handle. (e.g. one cuStream per device
Rbiessy marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
32 changes: 11 additions & 21 deletions src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,43 +24,33 @@ namespace mkl {
namespace blas {
namespace cublas {

thread_local cublas_handle<int> CublasScopedContextHandler::handle_helper = cublas_handle<int>{};
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};

CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih)
: interop_h(ih) {}

cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) {
sycl::device device = queue.get_device();
int current_device = interop_h.get_native_device<sycl::backend::cuda>();
CUdevice current_device = interop_h.get_native_device<sycl::backend::cuda>();
CUstream streamId = get_stream(queue);
cublasStatus_t err;
auto it = handle_helper.cublas_handle_mapper_.find(current_device);
if (it != handle_helper.cublas_handle_mapper_.end()) {
if (it->second == nullptr) {
handle_helper.cublas_handle_mapper_.erase(it);
}
else {
auto handle = it->second->load();
if (handle != nullptr) {
cudaStream_t currentStreamId;
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId);
if (currentStreamId != streamId) {
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);
}
return handle;
}
else {
handle_helper.cublas_handle_mapper_.erase(it);
}
cublasHandle_t handle = it->second;
cudaStream_t currentStreamId;
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId);
if (currentStreamId != streamId) {
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);
}
return handle;
}
cublasHandle_t handle;

CUBLAS_ERROR_FUNC(cublasCreate, err, &handle);
CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId);

auto insert_iter = handle_helper.cublas_handle_mapper_.insert(
std::make_pair(current_device, new std::atomic<cublasHandle_t>(handle)));
auto insert_iter =
handle_helper.cublas_handle_mapper_.insert(std::make_pair(current_device, handle));
return handle;
}

Expand All @@ -71,4 +61,4 @@ CUstream CublasScopedContextHandler::get_stream(const sycl::queue& queue) {
} // namespace cublas
} // namespace blas
} // namespace mkl
} // namespace oneapi
} // namespace oneapi
3 changes: 1 addition & 2 deletions src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#endif
#include <memory>
#include <thread>
#include <unordered_map>
#include "cublas_helper.hpp"
#include "cublas_handle.hpp"
namespace oneapi {
Expand Down Expand Up @@ -60,7 +59,7 @@ the handle must be destroyed when the context goes out of scope. This will bind

class CublasScopedContextHandler {
sycl::interop_handle interop_h;
static thread_local cublas_handle<int> handle_helper;
static thread_local cublas_handle handle_helper;
sycl::context get_context(const sycl::queue& queue);
CUstream get_stream(const sycl::queue& queue);

Expand Down
14 changes: 1 addition & 13 deletions src/blas/backends/cublas/cublas_task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,6 @@
#else
#include "cublas_scope_handle_hipsycl.hpp"

// After Plugin Interface removal in DPC++ ur.hpp is the new include
#if __has_include(<sycl/detail/ur.hpp>)
#include <sycl/detail/ur.hpp>
#ifndef ONEMKL_PI_INTERFACE_REMOVED
#define ONEMKL_PI_INTERFACE_REMOVED
#endif
#elif __has_include(<sycl/detail/pi.hpp>)
#include <sycl/detail/pi.hpp>
#else
#include <CL/sycl/detail/pi.hpp>
#endif

namespace sycl {
using interop_handler = sycl::interop_handle;
}
Expand All @@ -72,7 +60,7 @@ static inline void host_task_internal(H& cgh, sycl::queue queue, F f) {
#else
cgh.host_task([f, queue](sycl::interop_handle ih) {
#endif
auto sc = CublasScopedContextHandler(queue, ih);
auto sc = CublasScopedContextHandler(ih);
f(sc);
});
}
Expand Down