NVLS support for msccl++ executor (#375)
- Support more datatypes for multicast operations
- Add new OP MULTI_LOAD_REDUCE_STORE to support NVLS
- Modify allocSharedPhysicalCuda to return std::shared_ptr<T>
  instead of std::shared_ptr<PhysicalCudaMemory> (see the usage sketch below)
- Add Python support for allocSharedPhysicalCuda

Test passed for `allreduce_nvls.json`
Binyang2014 authored Nov 20, 2024
1 parent 3e51e9b commit 28a57b0
Showing 26 changed files with 2,116 additions and 212 deletions.
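As a quick orientation for the API change in this commit, here is a minimal usage sketch of the revised allocator (not part of the diff; the element count is arbitrary, and it assumes an NVLS-capable GPU and driver, since the function throws otherwise):

```cpp
#include <mscclpp/gpu_utils.hpp>

int main() {
  // New signature: the allocator now hands back std::shared_ptr<T> directly
  // instead of std::shared_ptr<PhysicalCudaMemory<T>>. The granularity
  // argument may be omitted; the recommended multicast granularity is then
  // queried internally (see getMulticastGranularity in gpu_utils.hpp below).
  std::shared_ptr<float> buf = mscclpp::allocSharedPhysicalCuda<float>(1 << 20);

  // The underlying cuMem mapping is unmapped and released by
  // CudaPhysicalDeleter once the last shared_ptr reference goes away.
  return 0;
}
```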
2 changes: 1 addition & 1 deletion apps/nccl/src/allreduce.hpp
@@ -7,12 +7,12 @@
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu.hpp>
#include <mscclpp/gpu_data_types.hpp>
#include <mscclpp/packet_device.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>

#include "common.hpp"
#include "gpu_data_types.hpp"

template <typename To, typename From>
__forceinline__ __device__ To bit_cast(const From& src) {
2 changes: 1 addition & 1 deletion docker/build.sh
@@ -9,7 +9,7 @@ baseImageTable=(
["cuda12.2"]="nvidia/cuda:12.2.2-devel-ubuntu20.04"
["cuda12.3"]="nvidia/cuda:12.3.2-devel-ubuntu20.04"
["cuda12.4"]="nvidia/cuda:12.4.1-devel-ubuntu22.04"
["rocm6.2"]="rocm/rocm-terminal:6.2"
["rocm6.2"]="rocm/rocm-terminal:6.2.1"
)

declare -A extraLdPathTable
4 changes: 4 additions & 0 deletions docs/getting-started/quickstart.md
@@ -19,6 +19,10 @@
```
lsmod | grep nvidia_peermem
```
* For GPUs with NVLS support, the IMEX channels should be set up (refer to [cuMemCreate](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g899d69a862bba36449789c64b430dc7c)). You can set up the channels manually via:
```
sudo nvidia-modprobe -s -i <start:number of minors>
```

## Build with Docker Images

9 changes: 7 additions & 2 deletions include/mscclpp/gpu.hpp
@@ -22,6 +22,7 @@ using CUdeviceptr = hipDeviceptr_t;
using CUmemGenericAllocationHandle = hipMemGenericAllocationHandle_t;
using CUmemAllocationProp = hipMemAllocationProp;
using CUmemAccessDesc = hipMemAccessDesc;
using CUmemAllocationHandleType = hipMemAllocationHandleType;

constexpr auto cudaSuccess = hipSuccess;
constexpr auto cudaStreamNonBlocking = hipStreamNonBlocking;
@@ -86,6 +87,9 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri
#define cuMemSetAccess(...) hipMemSetAccess(__VA_ARGS__)
#define cuMemMap(...) hipMemMap(__VA_ARGS__)
#define cuMemUnmap(...) hipMemUnmap(__VA_ARGS__)
#define cuMemRetainAllocationHandle(...) hipMemRetainAllocationHandle(__VA_ARGS__)
#define cuMemExportToShareableHandle(...) hipMemExportToShareableHandle(__VA_ARGS__)
#define cuMemImportFromShareableHandle(...) hipMemImportFromShareableHandle(__VA_ARGS__)

#else

@@ -97,9 +101,10 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri
// NVLS
#if !defined(__HIP_PLATFORM_AMD__)
#include <linux/version.h>
#define USE_NVLS ((CUDART_VERSION >= 12010) && (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 6, 0)))
// We need CU_MEM_HANDLE_TYPE_FABRIC (introduced in cuda12.3) to support sharing handles across GPUs via sockets
#define CUDA_NVLS_SUPPORTED ((CUDART_VERSION >= 12030) && (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 6, 0)))
#else // !defined(__HIP_PLATFORM_AMD__)
#define USE_NVLS 0
#define CUDA_NVLS_SUPPORTED 0
#endif // !defined(__HIP_PLATFORM_AMD__)

// GPU sync threads
include/mscclpp/gpu_data_types.hpp
@@ -16,6 +16,7 @@ using __bfloat162 = __hip_bfloat162;
#else

#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#if (CUDART_VERSION >= 11000)
#include <cuda_bf16.h>
#endif
155 changes: 100 additions & 55 deletions include/mscclpp/gpu_utils.hpp
@@ -9,6 +9,7 @@

#include "errors.hpp"
#include "gpu.hpp"
#include "utils.hpp"

/// Throw @ref mscclpp::CudaError if @p cmd does not return cudaSuccess.
/// @param cmd The command to execute.
@@ -34,6 +35,19 @@

namespace mscclpp {

/// set memory access permission to read-write
/// @param base Base memory pointer.
/// @param size Size of the memory.
inline void setReadWriteMemoryAccess(void* base, size_t size) {
CUmemAccessDesc accessDesc = {};
int deviceId;
MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId));
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = deviceId;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)base, size, &accessDesc, 1));
}

/// A RAII guard that will cudaThreadExchangeStreamCaptureMode to cudaStreamCaptureModeRelaxed on construction and
/// restore the previous mode on destruction. This is helpful when we want to avoid CUDA graph capture.
struct AvoidCudaGraphCaptureGuard {
@@ -53,15 +67,6 @@ struct CudaStreamWithFlags {
template <class T>
struct CudaDeleter;

template <class T>
struct PhysicalCudaMemory {
CUmemGenericAllocationHandle memHandle_;
T* devicePtr_;
size_t size_;
PhysicalCudaMemory(CUmemGenericAllocationHandle memHandle, T* devicePtr, size_t size)
: memHandle_(memHandle), devicePtr_(devicePtr), size_(size) {}
};

namespace detail {

/// A wrapper of cudaMalloc that sets the allocated memory to zero.
@@ -79,46 +84,38 @@ T* cudaCalloc(size_t nelem) {
return ptr;
}

#if (CUDA_NVLS_SUPPORTED)
template <class T>
PhysicalCudaMemory<T>* cudaPhysicalCalloc(size_t nelem, size_t gran) {
T* cudaPhysicalCalloc(size_t nelems, size_t gran) {
AvoidCudaGraphCaptureGuard cgcGuard;

int deviceId = -1;
CUdevice currentDevice;
MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId));
MSCCLPP_CUTHROW(cuDeviceGet(&currentDevice, deviceId));

CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = deviceId;
#if defined(__HIP_PLATFORM_AMD__)
// TODO: revisit when HIP fixes this typo in the field name
prop.requestedHandleType = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
#else
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
#endif
prop.requestedHandleTypes =
(CUmemAllocationHandleType)(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR | CU_MEM_HANDLE_TYPE_FABRIC);
prop.location.id = currentDevice;

CUmemGenericAllocationHandle memHandle;
size_t bufferSize = sizeof(T) * nelem;
// allocate physical memory
MSCCLPP_CUTHROW(cuMemCreate(&memHandle, bufferSize, &prop, 0 /*flags*/));

CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = deviceId;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CUmemGenericAllocationHandle memHandle;
size_t nbytes = (nelems * sizeof(T) + gran - 1) / gran * gran;
MSCCLPP_CUTHROW(cuMemCreate(&memHandle, nbytes, &prop, 0 /*flags*/));

T* devicePtr = nullptr;
// Map the device pointer
MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&devicePtr, bufferSize, gran, 0U, 0));
MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)devicePtr, bufferSize, 0, memHandle, 0));
MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)devicePtr, bufferSize, &accessDesc, 1));
MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&devicePtr, nbytes, gran, 0U, 0));
MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)devicePtr, nbytes, 0, memHandle, 0));
setReadWriteMemoryAccess(devicePtr, nbytes);
CudaStreamWithFlags stream(cudaStreamNonBlocking);
MSCCLPP_CUDATHROW(cudaMemsetAsync(devicePtr, 0, bufferSize, stream));

MSCCLPP_CUDATHROW(cudaMemsetAsync(devicePtr, 0, nbytes, stream));
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));

return new PhysicalCudaMemory<T>(memHandle, devicePtr, bufferSize);
return devicePtr;
}
#endif

template <class T>
T* cudaExtCalloc(size_t nelem) {
@@ -206,11 +203,15 @@ struct CudaDeleter {
template <class T>
struct CudaPhysicalDeleter {
static_assert(!std::is_array_v<T>, "T must not be an array");
void operator()(PhysicalCudaMemory<T>* ptr) {
void operator()(T* ptr) {
AvoidCudaGraphCaptureGuard cgcGuard;
MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr->devicePtr_, ptr->size_));
MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr->devicePtr_, ptr->size_));
MSCCLPP_CUTHROW(cuMemRelease(ptr->memHandle_));
CUmemGenericAllocationHandle handle;
size_t size = 0;
MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, ptr));
MSCCLPP_CUTHROW(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));
MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr, size));
MSCCLPP_CUTHROW(cuMemRelease(handle));
MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr, size));
}
};

@@ -234,16 +235,46 @@ std::shared_ptr<T> allocSharedCuda(size_t count = 1) {
return detail::safeAlloc<T, detail::cudaCalloc<T>, CudaDeleter<T>, std::shared_ptr<T>>(count);
}

/// Allocated physical memory on the device and returns a memory handle along with a memory handle for it.
/// The deallocation only happens PhysicalCudaMemory goes out of scope.
#if (CUDA_NVLS_SUPPORTED)
static inline size_t getMulticastGranularity(size_t size, CUmulticastGranularity_flags granFlag) {
size_t gran = 0;
int numDevices = 0;
MSCCLPP_CUDATHROW(cudaGetDeviceCount(&numDevices));

CUmulticastObjectProp prop = {};
prop.size = size;
// This is a dummy value, it might affect the granularity in the future
prop.numDevices = numDevices;
prop.handleTypes = (CUmemAllocationHandleType)(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR | CU_MEM_HANDLE_TYPE_FABRIC);
prop.flags = 0;
MSCCLPP_CUTHROW(cuMulticastGetGranularity(&gran, &prop, granFlag));
return gran;
}
#endif

/// Allocates physical memory on the device and returns a std::shared_ptr to it. The memory is zeroed out.
/// @tparam T Type of each element in the allocated memory.
/// @param count Number of elements to allocate.
/// @param gran the granularity of the allocation.
/// @return A std::shared_ptr to the memory handle and a device pointer for that memory.
/// @return A std::shared_ptr to the allocated memory.
template <class T>
std::shared_ptr<PhysicalCudaMemory<T>> allocSharedPhysicalCuda(size_t count, size_t gran) {
return detail::safeAlloc<PhysicalCudaMemory<T>, detail::cudaPhysicalCalloc<T>, CudaPhysicalDeleter<T>,
std::shared_ptr<PhysicalCudaMemory<T>>>(count, gran);
std::shared_ptr<T> allocSharedPhysicalCuda([[maybe_unused]] size_t count, [[maybe_unused]] size_t gran = 0) {
#if (CUDA_NVLS_SUPPORTED)
if (!isNvlsSupported()) {
throw Error("Only support GPU with NVLS support", ErrorCode::InvalidUsage);
}
if (count == 0) {
return nullptr;
}

if (gran == 0) {
gran = getMulticastGranularity(count * sizeof(T), CU_MULTICAST_GRANULARITY_RECOMMENDED);
}
size_t nelems = ((count * sizeof(T) + gran - 1) / gran * gran) / sizeof(T);
return detail::safeAlloc<T, detail::cudaPhysicalCalloc<T>, CudaPhysicalDeleter<T>, std::shared_ptr<T>>(nelems, gran);
#else
throw Error("Only support GPU with Fabric support", ErrorCode::InvalidUsage);
#endif
}
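For the explicit-granularity path through the same function, a hedged sketch follows (illustrative only; the element count is made up, and choosing CU_MULTICAST_GRANULARITY_MINIMUM over the recommended default is an arbitrary choice for the example):

```cpp
#include <mscclpp/gpu_utils.hpp>

int main() {
  constexpr size_t count = 1024 * 1024;  // arbitrary element count
  // Query a granularity explicitly instead of relying on the default
  // (gran == 0 makes allocSharedPhysicalCuda use CU_MULTICAST_GRANULARITY_RECOMMENDED).
  size_t gran = mscclpp::getMulticastGranularity(count * sizeof(char), CU_MULTICAST_GRANULARITY_MINIMUM);
  // count is rounded up to a multiple of gran inside allocSharedPhysicalCuda,
  // so the mapped buffer may be larger than requested.
  std::shared_ptr<char> buf = mscclpp::allocSharedPhysicalCuda<char>(count, gran);
  return 0;
}
```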

/// Allocates memory on the device and returns a std::shared_ptr to it. The memory is zeroed out.
@@ -269,18 +300,6 @@ UniqueCudaPtr<T> allocUniqueCuda(size_t count = 1) {
return detail::safeAlloc<T, detail::cudaCalloc<T>, CudaDeleter<T>, UniqueCudaPtr<T>>(count);
}

/// Allocated physical memory on the device and returns a memory handle along with a virtual memory handle for it.
/// The memory is zeroed out.
/// @tparam T Type of each element in the allocated memory.
/// @param count Number of elements to allocate.
/// @param gran the granularity of the allocation.
/// @return A std::unique_ptr to the memory handle and a device pointer for that memory.
template <class T>
std::unique_ptr<PhysicalCudaMemory<T>> allocUniquePhysicalCuda(size_t count, size_t gran) {
return detail::safeAlloc<PhysicalCudaMemory<T>, detail::cudaPhysicalCalloc<T>, CudaPhysicalDeleter<T>,
std::unique_ptr<CudaPhysicalDeleter<T>, CudaDeleter<CudaPhysicalDeleter<T>>>>(count, gran);
}

/// Allocates memory on the device and returns a std::unique_ptr to it. The memory is zeroed out.
/// @tparam T Type of each element in the allocated memory.
/// @param count Number of elements to allocate.
@@ -349,6 +368,32 @@ UniqueCudaHostPtr<T> makeUniqueCudaHost(size_t count) {
return ptr;
}

/// Allocates physical memory on the device and returns a std::unique_ptr to it.
/// The memory is zeroed out.
/// @tparam T Type of each element in the allocated memory.
/// @param count Number of elements to allocate.
/// @param gran the granularity of the allocation.
/// @return A std::unique_ptr to the allocated memory.
template <class T>
std::unique_ptr<T> allocUniquePhysicalCuda([[maybe_unused]] size_t count, [[maybe_unused]] size_t gran = 0) {
#if (CUDA_NVLS_SUPPORTED)
if (!isNvlsSupported()) {
throw Error("Only suupport GPU with NVLS support", ErrorCode::InvalidUsage);
}
if (count == 0) {
return nullptr;
}

if (gran == 0) {
gran = getMulticastGranularity(count * sizeof(T), CU_MULTICAST_GRANULARITY_RECOMMENDED);
}
return detail::safeAlloc<T, detail::cudaPhysicalCalloc<T>, CudaPhysicalDeleter<T>,
std::unique_ptr<CudaPhysicalDeleter<T>, CudaDeleter<CudaPhysicalDeleter<T>>>>(count, gran);
#else
throw Error("Only support GPU with Fabric support", ErrorCode::InvalidUsage);
#endif
}

/// Asynchronous cudaMemcpy without capture into a CUDA graph.
/// @tparam T Type of each element in the allocated memory.
/// @param dst Destination pointer.
20 changes: 10 additions & 10 deletions include/mscclpp/nvls.hpp
@@ -25,26 +25,26 @@ class NvlsConnection {

struct DeviceMulticastPointer {
private:
std::shared_ptr<PhysicalCudaMemory<char>> deviceMem_;
void* devicePtr_;
std::shared_ptr<char> mcPtr_;
size_t bufferSize_;

public:
using DeviceHandle = DeviceMulticastPointerDeviceHandle;
DeviceMulticastPointer(std::shared_ptr<PhysicalCudaMemory<char>> deviceMem, std::shared_ptr<char> mcPtr,
size_t bufferSize)
: deviceMem_(deviceMem), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
DeviceMulticastPointer(void* devicePtr, std::shared_ptr<char> mcPtr, size_t bufferSize)
: devicePtr_(devicePtr), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
DeviceHandle deviceHandle();
char* getDevicePtr();
void* getDevicePtr();

friend class NvlsConnection;
};

std::shared_ptr<DeviceMulticastPointer> allocateAndBindCuda(size_t size);

/// The \p handle to the allocation (its lifetime is managed by the caller)
/// and the \p size of the allocation.
std::shared_ptr<char> bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size);
/// @brief bind the allocated memory via @ref mscclpp::allocSharedPhysicalCuda to the multicast handle. The behavior
/// is undefined if the devicePtr is not allocated by @ref mscclpp::allocSharedPhysicalCuda.
/// @param devicePtr
/// @param size
/// @return DeviceMulticastPointer with devicePtr, mcPtr and bufferSize
DeviceMulticastPointer bindAllocatedMemory(CUdeviceptr devicePtr, size_t size);

size_t getMultiCastMinGranularity();

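Tying the nvls.hpp change to the allocator change above, here is a hedged sketch of how the new bindAllocatedMemory entry point might be used (how the NvlsConnection is established is outside this excerpt and assumed here; bindNvlsBuffer is a hypothetical helper name):

```cpp
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/nvls.hpp>

// Sketch only: `conn` is an already-established NvlsConnection, and `buf` must
// come from allocSharedPhysicalCuda (binding any other pointer is undefined,
// per the doc comment in the diff). The caller keeps `buf` alive for as long
// as the returned multicast pointer is in use.
auto bindNvlsBuffer(mscclpp::NvlsConnection& conn, const std::shared_ptr<char>& buf, size_t bufferSize) {
  auto mcPointer = conn.bindAllocatedMemory((CUdeviceptr)buf.get(), bufferSize);
  // mcPointer.deviceHandle() yields the device-side handle that kernels,
  // e.g. the new MULTI_LOAD_REDUCE_STORE operation, can consume.
  return mcPointer;
}
```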