This is a follow-up question from #1315. While I managed to make a program with SM90_TMA_LOAD_MULTICAST, I still have some questions about it. Below, I describe my small experiment with sample outputs. I put my questions at the end, and also include the full reproducing program. Hope to receive some help here. Thank you in advance.

I am trying a very simple task: copy a matrix of layout (16, 8) from gmem to smem, with each of the two CTAs in the cluster holding an (8, 8) smem tile. I created the cluster shape and the TMA copy like this:

using ClusterShape = Shape<_2, _1, _1>;
auto tma = make_tma_copy(
    SM90_TMA_LOAD_MULTICAST{}, gX, smem_layout, size(ClusterShape{}));

Then, inside the kernel, I have this block of code, which allows me to vary the mcast_mask:

cute::cluster_arrive_relaxed();
cute::cluster_wait();
if (warp_idx == 0 && lane_predicate) {
constexpr int k_tma_transaction_bytes = size(sX) * sizeof(T);
tma_mbar[0] = 0;
cute::initialize_barrier(tma_mbar[0], 1 /*numThreads*/);
cute::set_barrier_transaction_bytes(tma_mbar[0], k_tma_transaction_bytes);
uint16_t mcast_mask = 0b01; // I could use 0b00, 0b10, or 0b11
cute::copy(tma.with(tma_mbar[0], mcast_mask), tXgX, tXsX);
}
cute::cluster_arrive();
cute::cluster_wait();

Here's what I have seen:
So, two questions that stand out to me are:
Finally, here's the full (minimal, self-contained) program.

/**********
Self-contained example to study SM90_TMA_MULTICAST
Usage:
$ nvcc main.cu \
--expt-relaxed-constexpr \
--generate-code=arch=compute_90a,code=sm_90a \
-lcuda \
-w \
-Xcompiler=-Wconversion \
-Xcompiler=-fno-strict-aliasing \
-Xcompiler=-Wfatal-errors \
-Xcompiler=-Wno-abi \
-Xcompiler=-Wfatal-errors \
-std=c++17 \
-arch=sm_90 \
-I/usr/local/cuda/include \
-I"${CUTLASS_PATH}/include" \
-I"${CUTLASS_PATH}/tools/util/include"
$ ./a.out
**********/
#include <cstdio>
#include "thrust/device_vector.h"
#include "thrust/host_vector.h"
#include "cutlass/cutlass.h"
#include "cutlass/cluster_launch.hpp"
#include "cute/tensor.hpp"
#include "cute/arch/cluster_sm90.hpp"
template <
class T,
class TensorX,
class GmemLayout,
class SmemLayout,
class Tma
>
__global__ static void
tma_kernel(
TensorX tX,
GmemLayout gmem_layout,
SmemLayout smem_layout,
CUTE_GRID_CONSTANT Tma const tma
) {
using namespace cute;
__shared__ T smem[cosize_v<SmemLayout>];
__shared__ uint64_t tma_mbar[1];
auto sX = make_tensor(make_smem_ptr(smem), smem_layout);
auto mX = tma.get_tma_tensor(shape(gmem_layout));
auto gX = local_tile(mX, shape(sX), make_coord(blockIdx.x, blockIdx.y)); // (CTA_TILE_M,CTA_TILE_N)
auto block_rank_in_cluster = cute::block_rank_in_cluster();
auto cta_tma = tma.get_slice(block_rank_in_cluster);
auto tXgX = cta_tma.partition_S(gX); // (TMA,TMA_M,TMA_N,REST_M,REST_N)
auto tXsX = cta_tma.partition_D(sX); // (TMA,TMA_M,TMA_N)
auto warp_idx = cutlass::canonical_warp_idx_sync();
auto lane_predicate = cute::elect_one_sync();
cute::cluster_arrive_relaxed();
cute::cluster_wait();
if (warp_idx == 0 && lane_predicate) {
constexpr int k_tma_transaction_bytes = size(sX) * sizeof(T);
tma_mbar[0] = 0;
cute::initialize_barrier(tma_mbar[0], 1 /*numThreads*/);
cute::set_barrier_transaction_bytes(tma_mbar[0], k_tma_transaction_bytes);
uint16_t mcast_mask = 0b01;
cute::copy(tma.with(tma_mbar[0], mcast_mask), tXgX, tXsX);
}
cute::cluster_arrive();
cute::cluster_wait();
if (thread(0, 0)) {
printf("----------\n");
print("tma: "); print(tma); print("\n");
printf("----------\n");
print("smem_layout: "); print(smem_layout); print("\n");
printf("----------\n");
print("tX: "); print_tensor(tX); print("\n");
print("mX: "); print_tensor(mX); print("\n");
print("gX: "); print_tensor(gX); print("\n");
print("sX: "); print_tensor(sX); print("\n");
}
}
int main() {
using namespace cute;
using T = float;
constexpr int m = 16;
constexpr int n = 8;
// create data
thrust::host_vector<T> cpu_data(m * n);
for (int i = 0; i < m*n; ++i) {
cpu_data[i] = static_cast<T>(i / 10.f);
}
thrust::device_vector<T> gpu_data = cpu_data;
cudaDeviceSynchronize();
// cluster shape
using ClusterShape = Shape<_2, _1, _1>;
// create tensors
auto gmem_layout = Layout<Shape< Int<m>, Int<n>>>{};
auto smem_layout = Layout<Shape<Int<m/2>, Int<n>>>{};
auto pX = reinterpret_cast<const T*>(gpu_data.data().get());
auto gX = make_tensor(make_gmem_ptr(pX), gmem_layout);
// create the TMA object
auto tma = make_tma_copy(
SM90_TMA_LOAD_MULTICAST{}, gX, smem_layout, size(ClusterShape{}));
// launch the kernel
dim3 grid_dims{2, 1, 1};
dim3 block_dims{1, 1, 1};
dim3 cluster_dims{size<0>(ClusterShape{}),
size<1>(ClusterShape{}),
size<2>(ClusterShape{})};
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims};
void const* kernel_ptr = (void const*)tma_kernel<
T,
decltype(gX),
decltype(gmem_layout),
decltype(smem_layout),
decltype(tma)
>;
cutlass::launch_kernel_on_cluster(
launch_params, kernel_ptr, gX, gmem_layout, smem_layout, tma);
auto result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST("Kernel launch FAILED.\n");
cudaError_t error = cudaGetLastError();
std::cout << error << std::endl;
}
return 0;
}
Your concept of the parameters appears correct, but the Multicast TMAs would not be used to copy a 16x8 gmem tensor to two 8x8 smem tensors. In your example case, each copy appears to be completely independent.

Instead, the Multicast TMAs are used to copy a single 8x8 gmem tensor to two 8x8 smem tensors in a broadcasted fashion, where the broadcast is performed across all participating CTAs in the mcast_mask. This is useful in GEMMs, for example, because the A tiles can be broadcast across each "row" of CTAs and the B tiles can be broadcast across each "column" of CTAs.
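
To make that broadcast pattern concrete, here is a minimal sketch modeled on the SM90 GEMM collective mainloops in CUTLASS (e.g. cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp). The variable names (block_layout, cluster_local_block_id) and the exact shape of the loops are for illustration, and I am assuming cute::block_id_in_cluster() to get this CTA's (x, y, z) coordinate within the cluster; treat it as a sketch to compare against the CUTLASS source rather than verified library code.

// Hedged sketch: each CTA builds a mask whose set bits are the cluster ranks
// of every CTA that should receive the tile this CTA is about to load.
auto block_layout = Layout<ClusterShape>{};              // (m, n, l) -> CTA rank in cluster
dim3 cluster_local_block_id = cute::block_id_in_cluster();
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
// A tile: broadcast along the N-mode, i.e. to every CTA in this CTA's cluster "row".
for (int n = 0; n < size<1>(block_layout); ++n) {
  mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
}
// B tile: broadcast along the M-mode, i.e. to every CTA in this CTA's cluster "column".
for (int m = 0; m < size<0>(block_layout); ++m) {
  mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
}
// The masks are then passed to the copies, e.g.
// copy(tma_a.with(barrier, mcast_mask_a), tAgA, tAsA);

With the (2, 1, 1) cluster from your example, the B-style mask evaluates to 0b11 in both CTAs (the single tile is delivered to both, with each CTA issuing the load for its own slice of it), while the A-style mask degenerates to each CTA's own bit (0b01 or 0b10), i.e. no actual multicast in that direction.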