-
Weight-only quantization is very useful for decoder-only architectures! But I wonder why this interleaved layout is used; I have not seen it among the existing CUTLASS layouts. Are there optimizations behind it? I tried the best config from the cuBLASLt int8 TN layout, but it does not seem suitable for the interleaved layout...
-
I think the reordering is used to improve the performance of the type conversion: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L54
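To make the idea concrete, here is a minimal sketch of the "magic number" trick that this kind of fast int8-to-fp16 conversion relies on. This is my own illustration, not FT's actual converter, and it assumes the int8 weights were offset by +128 during preprocessing so they are stored as unsigned bytes:

```cpp
#include <cuda_fp16.h>
#include <cstdint>

// A half with bit pattern 0x6400 is 1024.0; OR-ing a byte u into its mantissa yields
// 1024 + u. Subtracting 1024 + 128 then recovers the original signed int8 value as
// fp16, without an expensive per-element cvt instruction.
__device__ inline half2 i8x2_to_half2(uint8_t lo, uint8_t hi) {
    uint32_t packed = 0x64006400u | lo | (static_cast<uint32_t>(hi) << 16);
    half2 h = *reinterpret_cast<half2*>(&packed);
    return __hsub2(h, __float2half2_rn(1152.0f));  // 1152 = 1024 + 128
}
```

As I understand it, the interleaved layout arranges the bytes so that a single 32-bit load delivers them in the order the packed conversion and the mma fragments expect, avoiding extra byte shuffles.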
-
Is this kernel supposed to work on CUDA 12.1? I'm getting strange results where all outputs are zero, even though I confirmed that the kernel is running and the inputs are sane. On the container
-
I just found that the FT kernel always uses fp32 accumulation: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h#L110. Is there any danger in enabling fp16 accumulation, at least optionally?
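For reference, in a plain CUTLASS 2.x f16 GEMM the accumulator is just a template parameter; a minimal sketch (an illustration, not the FT traits themselves) would be:

```cpp
#include "cutlass/gemm/device/gemm.h"

// Same A/B/C types; only the ElementAccumulator template parameter differs.
using GemmF32Acc = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,     // A
    cutlass::half_t, cutlass::layout::ColumnMajor,  // B
    cutlass::half_t, cutlass::layout::RowMajor,     // C/D
    float,                                          // accumulate in f32
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80>;

using GemmF16Acc = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t,                                // accumulate in f16
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80>;
```

The catch is numerical: fp16 accumulation can lose precision or saturate for large K, which is presumably why the FT traits pin the accumulator to float.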
-
I've extracted the FT kernel into https://github.com/tlc-pack/cutlass_fpA_intB_gemm to make it easier to integrate into third-party projects, and I've already made a first improvement over the original implementation: support for residual block fusion, tlc-pack/cutlass_fpA_intB_gemm#1. Nothing is documented yet and there are no tests either, but I hope it will nonetheless be useful for others.
-
It's being integrated into TVM in apache/tvm#15111.
-
Hi all, I am working on changes to upstream mixed-input support into NVIDIA/CUTLASS. Please review the drawings below on how I am planning to choreograph the mainloop with mixed input datatypes. It is slightly different from the approach discussed here: notably, we would like to maintain a canonical layout in global memory for the TN mixed-input (F16 * S8) GEMM.

[figure: planned mainloop choreography for the mixed-input datatype]

Consider the warp-level test that I am currently fleshing out with the figure in mind. Let me know if you see an issue with the above approach.

```cpp
TEST(SM80_warp_gemm_tensor_op_mixed_input_crosswise_f16_i8, 64x64x64_64x64x64_16x8x16) {
  // Warp-level tile and Tensor Core instruction shape.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Mixed operand types: f16 A, s8 B, f32 accumulation.
  using ElementA = cutlass::half_t;
  using ElementB = int8_t;
  using ElementC = float;

  // Canonical TN crosswise shared-memory layouts for both operands.
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  // Mixed-input math operator tag: upconvert the narrower operand, then run the f16 mma.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run(cutlass::Distribution::Identity, cutlass::Distribution::Sequential);
}
```

cc: @IonThruster, @hwu36, @thakkarV, @kerrmudgeon
-
My understanding is that the I2F conversion is done after loading from smem to rmem, right before the warp-level mma. Did you ever try doing the I2F conversion right after loading from gmem, before writing to smem? That might make the code simpler, but I don't know how much it would affect performance. Even though it would increase the amount of smem read/write, maybe this kind of GEMM is bottlenecked by gmem bandwidth or the fp16 mma anyway, and not by smem read/write?
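A back-of-the-envelope way to see the smem cost of that ordering; the tile sizes below are illustrative, not the ones FT actually uses:

```cpp
#include <cstdio>

int main() {
    // Example threadblock tile and pipeline depth (illustrative numbers only).
    constexpr int kTileN = 128, kTileK = 64, kStages = 3;

    // B kept as int8 in smem (convert after the smem->rmem load, as described above).
    constexpr int bytes_b_int8 = kTileN * kTileK * 1 * kStages;
    // B converted to f16 before the smem store (the proposed variant).
    constexpr int bytes_b_f16  = kTileN * kTileK * 2 * kStages;

    std::printf("B smem footprint: int8 = %d KiB, f16 = %d KiB\n",
                bytes_b_int8 / 1024, bytes_b_f16 / 1024);
    // Converting before the smem store doubles B's smem footprint and smem traffic,
    // and the data has to pass through registers, so cp.async's direct gmem->smem
    // path can no longer be used for B.
    return 0;
}
```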
-
Are there any specific plans to actually have any of this upstream in CUTLASS soon? I was not aware of anything above; as mentioned here, in the meantime I have been working on using it in PyTorch, starting from the CUTLASS extensions in the FasterTransformer project (mine is f16 x s8 only at the moment, but at least the FasterTransformer CUTLASS extensions are updated for CUTLASS 3.x in my version). Of course, it would be much better to have this functionality in CUTLASS itself.
-
Hi @rhenry-nv, I'm wondering if it is possible to run the FT int4/int8 GEMM kernel on multiple GPUs. The way https://github.com/mlc-ai/mlc-llm does multi-GPU for non-FT paths is to shard the quantized weight along the row or column dimension, run the GEMM on each device, and do an NCCL AllReduce to gather the results. This scheme doesn't seem to work for the FT kernel because of the weight preprocessing, which involves an element permutation and a transpose. Any thoughts?
-
> The result is incorrect if I do that.

Oh sorry, my mistake. The correct order is to split the weight first, then run the preprocessing on each device's weight shard.
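To spell out that ordering, here is a minimal sketch under my own assumptions: `shard_columns` is a hypothetical helper that slices a row-major int8 weight along N for tensor parallelism, and `ft_preprocess_weights` stands in for FT's weight preprocessing (permute + interleave + transpose), whose real entry point should be taken from the FT cutlass_kernels code:

```cpp
#include <cstdint>
#include <vector>

// Stand-in for FasterTransformer's weight preprocessing; provided by the FT
// library, not implemented here. Name is illustrative, not the real API.
std::vector<int8_t> ft_preprocess_weights(const std::vector<int8_t>& w, int k, int n);

// Hypothetical helper: take this rank's column slice of a row-major [k, n] int8 weight.
std::vector<int8_t> shard_columns(const std::vector<int8_t>& w, int k, int n,
                                  int rank, int world_size) {
    int n_local = n / world_size;
    std::vector<int8_t> shard(static_cast<size_t>(k) * n_local);
    for (int r = 0; r < k; ++r)
        for (int c = 0; c < n_local; ++c)
            shard[static_cast<size_t>(r) * n_local + c] =
                w[static_cast<size_t>(r) * n + rank * n_local + c];
    return shard;
}

std::vector<int8_t> prepare_weight_for_rank(const std::vector<int8_t>& w, int k, int n,
                                            int rank, int world_size) {
    // Correct order: shard the canonical row-major weight first ...
    auto shard = shard_columns(w, k, n, rank, world_size);
    // ... then preprocess each device's shard. Preprocessing the full weight and
    // slicing afterwards scatters elements across shards and gives wrong results.
    return ft_preprocess_weights(shard, k, n / world_size);
}
```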
-
Will CUTLASS support group-wise quantization for S4 in the F16 x S4 GEMM?
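For reference, this is what group-wise quantization means for the int4 weight; the sketch below is my own illustration (the names and the symmetric, zero-point-free scheme are assumptions, not a CUTLASS API):

```cpp
#include <cuda_fp16.h>
#include <cstdint>

// Dequantize one already-unpacked s4 weight element W[k][n] with per-group scales:
// every `group_size` consecutive k indices of a column share one f16 scale, so the
// scale tensor has shape [K / group_size, N].
__device__ inline half dequant_s4_groupwise(int8_t w_s4,        // unpacked, in [-8, 7]
                                            const half* scales, // [K / group_size, N]
                                            int k, int n, int N, int group_size) {
    half scale = scales[(k / group_size) * N + n];
    return __hmul(__int2half_rn(static_cast<int>(w_s4)), scale);
}
```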
-
@rhenry-nv Hi Rawn, I found that the official CUTLASS mixed-input GEMM uses a ColumnMajor weight layout and a FragmentShuffler to get the weights into the mma layout. Is there a performance gap between ColumnMajor and the interleaved layout? I would expect the interleaved layout to perform better, but it is harder to extend to other devices and dtypes, whereas Manish's mixed-input GEMM is easier to extend, e.g. to fp8 x int4.
-
Hi, what is the name of the GTC talk? The link is old and no longer works. Could you provide the title of the talk? Thank you so much!
-
FasterTransformer has kernels written in CUTLASS that support fp16 x int8/int4 GEMM.

Source code: https://github.com/NVIDIA/FasterTransformer/tree/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions
Instantiation: https://github.com/NVIDIA/FasterTransformer/tree/main/src/fastertransformer/kernels/cutlass_kernels/fpA_intB_gemm
Paper: https://arxiv.org/abs/2211.10017
GTC'23 talk: https://register.nvidia.com/flow/nvidia/gtcspring2023/attendeeportal/page/sessioncatalog/session/1666226207768001N4Fe
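A rough usage sketch of the instantiation linked above, written from memory; the class name, namespace, and method signatures should be verified against fpA_intB_gemm.h before relying on any of it. It assumes the int8 weight has already been run through FT's weight preprocessing and that per-column f16 scales are provided:

```cpp
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include "fpA_intB_gemm.h"  // FasterTransformer cutlass_kernels header (path may differ)

void run_fp16_int8_gemm(const half* A, const uint8_t* B_preprocessed,
                        const half* weight_scales, half* C,
                        int m, int n, int k, cudaStream_t stream) {
    // Template parameters: activation type and (quantized) weight storage type.
    fastertransformer::CutlassFpAIntBGemmRunner<half, uint8_t> runner;

    // Workspace is used internally (e.g. for split-k); size depends on the problem.
    size_t ws_bytes = runner.getWorkspaceSize(m, n, k);
    char* ws = nullptr;
    cudaMalloc(&ws, ws_bytes);

    // C = A (f16, m x k) * dequant(B, weight_scales) (k x n), accumulated in f32.
    runner.gemm(A, B_preprocessed, weight_scales, C, m, n, k, ws, ws_bytes, stream);

    cudaFree(ws);
}
```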