[QST] make_tiled_copy_B generates incompatible layouts #1953
It does seem more reasonable, but that's not how it works. The solution is therefore some shared memory layout engineering. With the TiledCopy you show, we want each thread to access 128 consecutive bits, as you mention. So if I just follow that, something like

auto smem_layout_B_atom = Layout<Shape <Shape <_8, _4>, Shape <_2, _16>>,
                                 Stride<Stride<_2, _256>, Stride<_1, _16>>>{};   // 32N x 32K
auto smem_layout_B = tile_to_shape(smem_layout_B_atom, Shape<_128, _32>{});      // 128N x 32K

works, but I'm sure you can do better by also considering the stores from global memory and bank accesses.
Thank you so much for your reply! I didn't realize that I can actually manipulate the SmemAtomLayout. Before, I was simply and naively doing Shape<_32, _32>, Stride<_1, _32>. In my case, since B is actually transposed, i.e., row major, I use the following SmemAtomLayout:

auto smem_layout_B_atom = Layout<Shape <Shape <_8, _4>, Shape <_2, _16>>,
                                 Stride<Stride<_2, _16>, Stride<_1, _64>>>{};

For the swizzle, I wouldn't want to swizzle the 16 consecutive indices, so MBase = 4. I am not sure what would need to be modified or considered for the stores from global memory. I use a fairly regular, contiguous, non-swizzled GmemCopy, which uses SM80_CP_ASYNC_CACHEGLOBAL. It doesn't seem that the GmemCopy can affect the SmemCopy. Is there anything that I am missing here? Thanks!
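For reference, a minimal sketch of how a swizzle with MBase = 4 could be composed onto that atom; the BBits = 2 and SShift = 3 values are illustrative choices, not taken from this thread:

// Compose a swizzle onto the SmemAtomLayout.  MBase = 4 keeps each run of 16
// contiguous int8_t elements (128 bits) unswizzled; BBits = 2 and SShift = 3
// are illustrative, not values from this thread.
auto swizzled_atom_B = composition(Swizzle<2, 4, 3>{},
                                   Layout<Shape <Shape <_8, _4>, Shape <_2, _16>>,
                                          Stride<Stride<_2, _16>, Stride<_1, _64>>>{});
// Tile the swizzled atom up to the full shared-memory tile (128N x 32K here).
auto smem_layout_B = tile_to_shape(swizzled_atom_B, Shape<_128, _32>{});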
Hello, ccecka!
I modified my MMA's permutation such that the output of the MMA can be directly used as operand A of a subsequent MMA. Basically, the output register indices for thread i are exactly the input indices for thread i in the next MMA.
This works like a charm. However, I didn't expect this change to affect smem_tiled_copy_B.
But now, smem_tiled_copy_B becomes rather weird: I think the SmemCopy would just need 4 destination registers, and my MMA does provide 4 registers there for each thread. Considering U16 vs. Int8, I can understand that the source is partitioned into 8x2. However, this time the source is even more partitioned. Could you help me understand the reason behind this, and what I could do to perform a correct SmemCopy? It seems that I could manipulate the Smem layout a bit harder? Thank you so much!
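For context, a permuted TiledMMA of the kind described above might be built roughly like this; the atom tiling and the particular N-mode permutation are illustrative guesses, not the exact configuration from this thread:

// Hypothetical TiledMMA with a permutation on the N mode.  The permutation
// layout reorders which logical N columns each thread's registers hold, so
// that the accumulator of one MMA lines up with operand A of the next one.
// All concrete numbers are illustrative.
using PermutedTiledMma = decltype(make_tiled_mma(
    MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>{},
    Layout<Shape<_2, _2, _1>>{},                          // 2x2x1 atom tiling (128 threads)
    Tile<_32,
         Layout<Shape <_2, _4, _4>, Stride<_1, _8, _2>>,  // permuted 32-element N mode
         _32>{}));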
I'm still a bit unclear on changing the SmemLayout. Here's my current understanding: previously, the data was stored in Smem in row-major order. However, we now want the data to follow a different pattern in Smem, where two values belonging to consecutive columns are stored first, and then we proceed in a row-major fashion. What I'm not sure about is how the Gmem-to-Smem copy will handle this new storage pattern. Currently, I'm using SM80_CP_ASYNC_CACHEGLOBAL for the GmemCopy. This operation uses one source address and one destination address per transfer, and each copy moves 128 bits. It doesn't seem that this instruction can adapt to the new SmemLayout. Thank you so much!
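For reference, the kind of 128-bit cp.async gmem-to-smem TiledCopy being described usually looks roughly like the following; the 128-thread layout and 1x16 value layout are illustrative, not the exact ones from this kernel:

// Each thread issues 128-bit cp.async transfers, i.e. 16 contiguous int8_t
// values along K.  For this atom to be usable, those 16 elements must be
// contiguous in both gmem and smem.
using GmemTiledCopyB = decltype(make_tiled_copy(
    Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, int8_t>{},
    Layout<Shape<_64, _2>, Stride<_2, _1>>{},  // 64x2 threads (128 total), K-major
    Layout<Shape< _1, _16>>{}));               // each thread copies 1x16 int8_t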
Looks to me like you won't be able to use a 128b GMEM->SMEM copy, as you only have 16 contiguous bits in each. This is what I meant by engineering the SMEM layout to consider the loads from GMEM and the stores into RMEM. You can make the SmemLayout anything you like and it shouldn't affect correctness -- it will only affect where each logical value is stored and, therefore, the access patterns, including the number of bits possible in a vectorization and the banks accessed by each thread. Sacrifices can be made in the GMEM->SMEM copy, but LDSM is in general more strict about its granularity and requirements.

EDIT: For reference, your new configuration suggests an atom like

auto smem_layout_B_atom = Layout<Shape <Shape <_2, _2, _4>, _2>,
                                 Stride<Stride<_2, _16, _4>, _1>>{};   // 16x2 atom to account for contiguous T0 and T16
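As with the earlier suggestion, that 16x2 atom would then be tiled up to the full tile; a minimal sketch, assuming the same 128N x 32K shared-memory tile as before:

// Tile the 16x2 atom up to the full 128N x 32K tile and print it on the host
// to check where each (n, k) coordinate lands and which values stay contiguous.
auto smem_layout_B = tile_to_shape(smem_layout_B_atom, Shape<_128, _32>{});
cute::print_layout(smem_layout_B);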
Thank you so much for your reply! I guess I can also choose to use DefaultCopy from Smem to registers. What I'm currently struggling with is that ldmatrix does not seem to be compatible with changing the permutation of the TiledMMA: ldmatrix demands that the first row go to the 4 registers belonging to threads 0-3, the second row to the 4 registers belonging to threads 4-7, the third row to threads 8-11, the fourth row to threads 12-15, and so on. In my example, changing the TiledMMA permutation on the N dimension consequently changes the TiledCopy Src layout for B.

Here is the generated TiledMMA with and without PermutationsMNK: We use the TiledMMA and SmemCopyAtomTransposed to generate SmemTiledCopyB:

Here is the TiledCopyB with and without the permutation: Without modifying the permutation, everything works perfectly. On the source side, ldmatrix processes 8 rows, each assigned to a single thread. On the destination side, all 32 threads receive their corresponding portion of values in the 8x8 matrix. With the permutation on the N dimension, however, the destination cannot form an 8x8 mapping onto all 32 threads: In this case, make_tiled_copy_B does not produce an error. Instead, it maps a single ldmatrix instruction to both a strided input and a strided output. However, it seems that ldmatrix cannot inherently handle a strided input address for a single input row. I'm trying to understand the meaning of the make_tiled_copy_B output and what options I have. Does this imply that I need to choose between the following: It seems that I need to make trade-offs, optimizing certain aspects at the expense of others, and select the most performance-efficient solution. For example, if the Gmem-to-Smem copy happens only once, while the data in Smem is reused multiple times when copying to registers, then (2) would be more advantageous than (1). Thank you so much!
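For completeness, the DefaultCopy fallback mentioned at the start of this comment would look roughly like the sketch below; tiled_mma, sB, tCrB_mma, and k stand in for the corresponding objects in the kernel and are not names from this thread:

// Build smem_tiled_copy_B from a plain copy atom instead of SM75_U16x8_LDSM_T.
// DefaultCopy lets CuTe issue ordinary (possibly vectorized) loads, so it
// tolerates source layouts that ldmatrix cannot, at the cost of copy efficiency.
auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom<DefaultCopy, int8_t>{}, tiled_mma);
auto smem_thr_copy_B   = smem_tiled_copy_B.get_thread_slice(threadIdx.x);
auto tCsB = smem_thr_copy_B.partition_S(sB);        // shared-memory source for B
auto tCrB = smem_thr_copy_B.retile_D(tCrB_mma);     // MMA register destination for B
cute::copy(smem_tiled_copy_B, tCsB(_, _, k), tCrB(_, _, k));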
This is absolutely not true. The
This is correct :-D But whether you permute or not in (3) does not affect your ability to use LDSM.
What is your question?
Hello!
I am writing an int8 GEMM layer using cute.
I use MMA_Atom<SM80_16x8x32_S32S8S8S32_TN> as my MMA atom, and define my tiled MMA from it. For element B, my original layout is transposed, so I use a transposed copy atom (SmemCopyAtomTransposed, built from SM75_U16x8_LDSM_T and int8_t).
Then I define the tiled copy and use it to partition my tensor in shared memory.
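Since the original snippets did not survive here, a rough sketch of what this setup typically looks like follows; the 2x2x1 atom tiling and the 32x32x32 tile are illustrative, only the MMA atom and the copy atom are taken from the question:

// Illustrative reconstruction of the setup described above.
using TiledMma = decltype(make_tiled_mma(
    MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>{},
    Layout<Shape<_2, _2, _1>>{},   // 2x2x1 atom tiling -> 128 threads (illustrative)
    Tile<_32, _32, _32>{}));       // MNK tile sizes (illustrative)

// Transposed smem -> rmem copy atom for B (B is row major / K-major).
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, int8_t>;

// Build the tiled copy for B from the TiledMMA and partition the smem tensor.
auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomTransposed{}, TiledMma{});
auto smem_thr_copy_B   = smem_tiled_copy_B.get_thread_slice(threadIdx.x);
auto tCsB              = smem_thr_copy_B.partition_S(sB);  // sB: smem tensor holding B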
Here I plot the MMA and smem_tiled_copy_B using print_latex.
mma_int8.pdf
tiled_copy_B.pdf
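(With the names from the sketch above, the plots would be produced on the host by something like:)

print_latex(TiledMma{});           // -> mma_int8.pdf layout diagram
print_latex(smem_tiled_copy_B);    // -> tiled_copy_B.pdf layout diagram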
The good news is that the destination of smem_tiled_copy_B matches the MMA layout of B.
The bad news is that the source of smem_tiled_copy_B is arranged like ((2, 8), 2):((64, 1), 16) instead of something like (16, 2):(1, 16).
I am not sure why this configuration generates the (2, 8) partition. SmemCopyAtomTransposed is constructed using SM75_U16x8_LDSM_T and int8_t, which internally should use the ldmatrix instruction that takes one 128-bit input at a time. So it seems more reasonable for make_tiled_copy_B to put 16 contiguous int8_t values in the inner dimension.
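In CuTe layout notation, the two source layouts contrasted above are:

// Source layout actually produced by make_tiled_copy_B for one copy:
auto src_actual   = Layout<Shape <Shape <_2, _8>, _2>,
                           Stride<Stride<_64, _1>, _16>>{};     // ((2,8),2):((64,1),16)
// Source layout I expected: 16 contiguous int8_t (128 bits) per LDSM input:
auto src_expected = Layout<Shape<_16, _2>, Stride<_1, _16>>{};  // (16,2):(1,16)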
This generates errors when calling cute::copy(), as SM75_U16x8_LDSM_T for int8 is incompatible with the src layout:
Could you help me take a look at this issue? Thank you so much!