// Patch overview (three files under oneflow/core/framework/).
//
// random_generator.cpp
//   * DefaultCUDAGenerator and MakeCUDAGenerator: when device_index == -1 the
//     index is now taken from detail::GetCudaDeviceIndex() instead of
//     GlobalProcessCtx::LocalRank().
//
// random_generator_impl.cpp
//   * CUDASynchronize loses its device_index parameter and its cudaSetDevice
//     call; what remains is CPUSynchronize() followed by cudaDeviceSynchronize().
//   * The CUDAGeneratorImpl constructor, destructor, set_current_seed, GetState
//     and SetState each declare a local CudaCurrentDeviceGuard for the
//     generator's device in place of the removed cudaSetDevice /
//     CUDASynchronize(device_index) calls.
//   * New GetCudaDeviceIndex(): returns GlobalProcessCtx::LocalRank() when
//     CHECK_JUST(GlobalMultiClientEnv()) is true, otherwise the index written
//     by cudaGetDevice().
//   * GetCudaDeviceCount: the counter is now initialised to 0 and
//     cudaSetDevice(LocalRank) is replaced by a CudaCurrentDeviceGuard built
//     from GetCudaDeviceIndex().
//   * MakeDeviceKey: device_index == -1 is resolved through
//     GetCudaDeviceIndex().
//
// random_generator_impl.h
//   * Declares int GetCudaDeviceIndex() inside namespace detail.
//
// NOTE(review): CudaCurrentDeviceGuard is not defined anywhere in this patch;
//   presumably it selects the given device and restores the previous one when
//   it goes out of scope - confirm against its definition.
// NOTE(review): in GetCudaDeviceCount the guard only precedes
//   cudaGetDeviceCount. Check whether that call depends on the current device
//   at all; if it does not, the guard (and the call into GetCudaDeviceIndex)
//   could be dropped - TODO confirm against the CUDA runtime documentation.
// NOTE(review): the destructor still goes through CHECK_JUST / OF_CUDA_CHECK;
//   their failure behaviour is not visible here - verify it is acceptable
//   inside a destructor.
// NOTE(review): the unchanged context comment "to avoid state been modified"
//   looks like a typo for "being modified"; it is a context line, so correct
//   it in the source file rather than in this patch.
// NOTE(review): this copy of the patch has lost its line breaks and appears to
//   have lost angle-bracket template arguments (e.g. "std::vector>
//   default_cuda_generator", "std::make_shared("), so it is unlikely to apply
//   as-is - regenerate it from the repository before applying.
diff --git a/oneflow/core/framework/random_generator.cpp b/oneflow/core/framework/random_generator.cpp index d1c9ad27fff..90572fb680f 100644 --- a/oneflow/core/framework/random_generator.cpp +++ b/oneflow/core/framework/random_generator.cpp @@ -70,7 +70,7 @@ Maybe DefaultCUDAGenerator(int device_index) { static std::vector init_flags(device_count); static std::vector> default_cuda_generator(device_count); - if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); } + if (device_index == -1) { device_index = detail::GetCudaDeviceIndex(); } CHECK_OR_RETURN(device_index >= 0 && device_index < device_count) << "Invalid device index " << device_index; std::call_once(init_flags[device_index], [&]() { @@ -91,7 +91,7 @@ Maybe MakeCPUGenerator() { #ifdef WITH_CUDA Maybe MakeCUDAGenerator(int device_index) { - if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); } + if (device_index == -1) { device_index = detail::GetCudaDeviceIndex(); } CHECK_OR_RETURN(device_index >= 0 && device_index < detail::GetCudaDeviceCount()) << "Invalid device index " << device_index; return std::make_shared( diff --git a/oneflow/core/framework/random_generator_impl.cpp b/oneflow/core/framework/random_generator_impl.cpp index acd4933fd32..39255462c14 100644 --- a/oneflow/core/framework/random_generator_impl.cpp +++ b/oneflow/core/framework/random_generator_impl.cpp @@ -145,10 +145,9 @@ int GetThreadNum(const cudaDeviceProp& prop) { } } -Maybe CUDASynchronize(int device_index) { +Maybe CUDASynchronize() { // Synchronize cuda device to avoid state been modified in random kernels. 
JUST(CPUSynchronize()); - OF_CUDA_CHECK(cudaSetDevice(device_index)); OF_CUDA_CHECK(cudaDeviceSynchronize()); return Maybe::Ok(); } @@ -161,25 +160,29 @@ CUDAGeneratorImpl::CUDAGeneratorImpl(uint64_t seed, int device_index) OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_index)); max_block_num_ = prop.multiProcessorCount; max_thread_num_ = GetThreadNum(prop); - OF_CUDA_CHECK(cudaSetDevice(device_index)); + + CudaCurrentDeviceGuard dev_guard(device_index); OF_CUDA_CHECK( cudaMalloc(&curand_states_, max_block_num_ * max_thread_num_ * sizeof(curandState))); detail::InitCurandStates(seed, max_block_num_, max_thread_num_, curand_states_); } CUDAGeneratorImpl::~CUDAGeneratorImpl() { - CHECK_JUST(CUDASynchronize(this->device_index())); + CudaCurrentDeviceGuard dev_guard(this->device_index()); + CHECK_JUST(CUDASynchronize()); OF_CUDA_CHECK(cudaFree(curand_states_)); } void CUDAGeneratorImpl::set_current_seed(uint64_t seed) { - CHECK_JUST(CUDASynchronize(this->device_index())); + CudaCurrentDeviceGuard dev_guard(this->device_index()); + CHECK_JUST(CUDASynchronize()); seed_ = seed; detail::InitCurandStates(seed_, max_block_num_, max_thread_num_, curand_states_); } Maybe CUDAGeneratorImpl::GetState() const { - JUST(CUDASynchronize(this->device_index())); + CudaCurrentDeviceGuard dev_guard(this->device_index()); + JUST(CUDASynchronize()); int64_t state_size = max_block_num_ * max_thread_num_ * sizeof(curandState); int64_t total_size = state_size + sizeof(int64_t); const auto& device = JUST(Device::New("cpu")); @@ -207,7 +210,8 @@ Maybe CUDAGeneratorImpl::SetState(const std::shared_ptr& tensor_st << total_size << ", but got " << tensor_state->shape()->elem_cnt(); } - JUST(CUDASynchronize(this->device_index())); + CudaCurrentDeviceGuard dev_guard(this->device_index()); + JUST(CUDASynchronize()); const auto& callback = std::make_shared>([&](uint64_t of_blob_ptr) { auto* of_blob = reinterpret_cast(of_blob_ptr); const int8_t* data = of_blob->blob().dptr(); @@ -398,16 +402,27 @@ 
Maybe MakeGeneratorImpl(uint64_t seed, int d } #ifdef WITH_CUDA + +int GetCudaDeviceIndex() { + int cuda_device_index = 0; + if (CHECK_JUST(GlobalMultiClientEnv())) { + cuda_device_index = GlobalProcessCtx::LocalRank(); + } else { + OF_CUDA_CHECK(cudaGetDevice(&cuda_device_index)); + } + return cuda_device_index; +} + int GetCudaDeviceCount() { - /* static */ int cuda_device_count; - OF_CUDA_CHECK(cudaSetDevice(GlobalProcessCtx::LocalRank())); + /* static */ int cuda_device_count = 0; + CudaCurrentDeviceGuard dev_guard(detail::GetCudaDeviceIndex()); OF_CUDA_CHECK(cudaGetDeviceCount(&cuda_device_count)); return cuda_device_count; } template<> DeviceKey MakeDeviceKey(int device_index) { - if (device_index == -1) { device_index = GlobalProcessCtx::LocalRank(); } + if (device_index == -1) { device_index = detail::GetCudaDeviceIndex(); } DeviceKey device_key; device_key.device_type = DeviceType::kGPU; device_key.device_index = device_index; diff --git a/oneflow/core/framework/random_generator_impl.h b/oneflow/core/framework/random_generator_impl.h index 18f8dc72dca..90c85050c6f 100644 --- a/oneflow/core/framework/random_generator_impl.h +++ b/oneflow/core/framework/random_generator_impl.h @@ -137,6 +137,7 @@ class CUDAGeneratorImpl : public DeviceGeneratorImpl { namespace detail { +int GetCudaDeviceIndex(); int GetCudaDeviceCount(); void InitCurandStates(uint64_t seed, int32_t block_num, int32_t thread_num, curandState* states);