From 7afcbd1e889becd96d578b2da32ad715b15a8068 Mon Sep 17 00:00:00 2001
From: Jokeren
Date: Thu, 14 Nov 2024 10:30:59 -0500
Subject: [PATCH 1/7] Update

---
 .../TritonGPUToLLVM/MemoryOpToLLVM.cpp        | 29 ++++++++++---------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
index 4cea14f0957f..3079cd03f3db 100644
--- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -138,18 +138,21 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
   // FIXME [Dot LL]
   // Do for all DotOperandEncodingAttr once we have LLs for all of them
-  static bool isSupportedDotOpLayout(RankedTensorType type) {
-    auto layout = type.getEncoding();
-    auto bitwidth = type.getElementType().getIntOrFloatBitWidth();
-    if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
+  static bool isSupportedDotOpLayout(RankedTensorType srcTy,
+                                     RankedTensorType dstTy) {
+    auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
+    auto dstLayout = dstTy.getEncoding();
+    auto bitwidth = dstTy.getElementType().getIntOrFloatBitWidth();
+    auto rank = dstTy.getRank();
+    if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
+      auto vecWidth = 32 / bitwidth;
       auto kWidth = dot.getKWidth();
-      // Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
-      // - kWidth == 8
-      // - kWidth == 4, bitwidth = 32
+      auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2;
       if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
-        bool legacyLoweringIsBuggy =
-            kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
-        return legacyLoweringIsBuggy && mma.isAmpere();
+        auto needTrans = kOrder != srcLayout.getOrder()[0];
+        auto canUseLdmatrix =
+            (bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth);
+        return !canUseLdmatrix && mma.isAmpere();
       }
       if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
         return true;
     }
@@ -164,10 +167,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     RankedTensorType dstTy = op.getType();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    if (isa<SharedEncodingAttr>(srcLayout) &&
-        (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
+    assert(isa<SharedEncodingAttr>(srcLayout) && "Unexpected src layout");
+    if ((isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
             dstLayout) ||
-         isSupportedDotOpLayout(dstTy))) {
+         isSupportedDotOpLayout(srcTy, dstTy))) {
       return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
     }

From 6e58e85da6671b17335e977f398cf38c840c545e Mon Sep 17 00:00:00 2001
From: Jokeren
Date: Thu, 14 Nov 2024 10:49:06 -0500
Subject: [PATCH 2/7] Update

---
 lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
index 3079cd03f3db..264be07c8374 100644
--- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -138,9 +138,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
   // FIXME [Dot LL]
   // Do for all DotOperandEncodingAttr once we have LLs for all of them
-  static bool isSupportedDotOpLayout(RankedTensorType srcTy,
+  static bool isSupportedDotOpLayout(SharedEncodingAttr srcLayout,
                                      RankedTensorType dstTy) {
-    auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
     auto dstLayout = dstTy.getEncoding();
     auto bitwidth = dstTy.getElementType().getIntOrFloatBitWidth();
     auto rank = dstTy.getRank();
     if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
@@ -165,12 +164,11 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
                   ConversionPatternRewriter &rewriter) const override {
     MemDescType srcTy = op.getSrc().getType();
     RankedTensorType dstTy = op.getType();
-    Attribute srcLayout = srcTy.getEncoding();
+    auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
     Attribute dstLayout = dstTy.getEncoding();
-    assert(isa<SharedEncodingAttr>(srcLayout) && "Unexpected src layout");
     if ((isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
             dstLayout) ||
-         isSupportedDotOpLayout(srcTy, dstTy))) {
+         isSupportedDotOpLayout(srcLayout, dstTy))) {
       return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
     }
@@ -210,7 +208,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto dstShape = dstTy.getShape();
     auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
     auto dstLayout = dstTy.getEncoding();
-    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
+    assert((dstShape.size() <= 2 ||
+            isSupportedDotOpLayout(srcSharedLayout, dstTy)) &&
            "Unexpected rank of ConvertLayout(shared->distributed)");

     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(

From 3194817fcd73aa4297a2e17f7756b8c317676d36 Mon Sep 17 00:00:00 2001
From: Jokeren
Date: Thu, 14 Nov 2024 16:53:05 -0500
Subject: [PATCH 3/7] Update

---
 lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index 43c87af487a1..aa05d0c758f2 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -932,17 +932,17 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
   auto warpsPerCTAMma = mma.getWarpsPerCTA();
   std::vector<std::vector<int32_t>> warps;
   if (isA) {
-    for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
+    for (int i = 1; i < warpsPerCTAMma[rank - 1]; i *= 2) {
       warps.push_back({0, 0});
     }
-    for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
+    for (int i = 1; i < warpsPerCTAMma[rank - 2]; i *= 2) {
       warps.push_back({0, i});
     }
   } else {
-    for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
+    for (int i = 1; i < warpsPerCTAMma[rank - 2]; i *= 2) {
       warps.push_back({0, i});
     }
-    for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
+    for (int i = 1; i < warpsPerCTAMma[rank - 1]; i *= 2) {
       warps.push_back({0, 0});
     }
   }
@@ -950,6 +950,9 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
     for (auto &w : warps) {
       w.push_back(0);
     }
+    for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
+      warps.push_back({0, 0, i});
+    }
   }

   ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims);

From 768104ee9975604c0a05a02308ac6021fc1cd545 Mon Sep 17 00:00:00 2001
From: Jokeren
Date: Thu, 14 Nov 2024 17:19:07 -0500
Subject: [PATCH 4/7] Update

---
 .../TritonGPU/IR/LinearLayoutConversions.cpp  |  4 ++--
 .../TritonGPU/LinearLayoutConversionsTest.cpp | 20 +++++++++++++++++++
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index aa05d0c758f2..b4ffe1dd063e 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -939,10 +939,10 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
       warps.push_back({0, i});
     }
   } else {
-    for (int i = 1; i < warpsPerCTAMma[rank - 2]; i *= 2) {
+    for (int i = 1; i < warpsPerCTAMma[rank - 1]; i *= 2) {
       warps.push_back({0, i});
     }
-    for (int i = 1; i < warpsPerCTAMma[rank - 1]; i *= 2) {
+    for (int i = 1; i < warpsPerCTAMma[rank - 2]; i *= 2) {
       warps.push_back({0, 0});
     }
   }
diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
index d662537ed72d..ffbd7c79648e 100644
--- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
+++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -47,6 +47,12 @@ class LinearLayoutConversionsTest : public ::testing::Test {
     return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth);
   }

+  DotOperandEncodingAttr dotMMAv2_3d(int idx, int kWidth, unsigned warps) {
+    auto mmaLayout =
+        mma(2, 0, {1, 16, 8}, {warps, 1, 1}, {1, 1, 1}, {1, 1, 1}, {2, 1, 0});
+    return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth);
+  }
+
   AMDMfmaEncodingAttr mfma(ArrayRef<unsigned> warps, unsigned mDim,
                            unsigned nDim, bool isTransposed) {
     SmallVector<unsigned> cpg(warps.size(), 1u);
@@ -535,6 +541,20 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) {
                 {S("dim0"), S("dim1")}));
 }

+TEST_F(LinearLayoutConversionsTest, DotMMAv2_3d) {
+  EXPECT_EQ(
+      toLinearLayout({32, 16, 16}, dotMMAv2_3d(0, 2, 16)),
+      LinearLayout(
+          {
+              {S("register"), {{0, 0, 1}, {0, 8, 0}, {0, 0, 8}, {16, 0, 0}}},
+              {S("lane"),
+               {{0, 0, 2}, {0, 0, 4}, {0, 1, 0}, {0, 2, 0}, {0, 4, 0}}},
+              {S("warp"), {{1, 0, 0}, {2, 0, 0}, {4, 0, 0}, {8, 0, 0}}},
+              {S("block"), {}},
+          },
+          {S("dim0"), S("dim1"), S("dim2")}));
+}
+
 TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
   EXPECT_EQ(
       toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1})),

From f1d861c67f9b0a8ebbfa046930209871bd53de7f Mon Sep 17 00:00:00 2001
From: Jokeren
Date: Thu, 14 Nov 2024 22:16:32 -0500
Subject: [PATCH 5/7] Update

---
 include/triton/Conversion/TritonGPUToLLVM/Utility.h  |  7 ++++---
 lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp    |  3 ++-
 lib/Conversion/TritonGPUToLLVM/Utility.cpp           | 10 ++++++----
 .../lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp  |  5 +++--
 4 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
index 9061e24fbebd..e29c7eaecc7f 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1143,8 +1143,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
 [[nodiscard]] bool emitTransferBetweenRegistersAndShared(
     RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
     std::optional<int32_t> maxVecElems, Value shmemBase,
-    ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
-    const TargetInfoBase &target,
+    ArrayRef<Value> shmemStrides, ArrayRef<Value> shmemOffsets, Location loc,
+    RewriterBase &rewriter, const TargetInfoBase &target,
     std::function<void(VectorType, Value)> perVectorCallback);

 inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
@@ -1318,7 +1318,8 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
 void storeDistributedToShared(
     MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
     ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
-    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
+    ArrayRef<Value> dstOffsets, Location loc, RewriterBase &rewriter,
+    const TargetInfoBase &target,
     std::pair<size_t, Type> *const llvmOpCount = nullptr);

 inline Value getStructFromSharedMemoryObject(Location loc,
diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
index 264be07c8374..ab259880c5dc 100644
--- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -30,9 +30,10 @@ void lowerDistributedToShared(
   auto smemBase = smemObj.getBase();
   auto dstStrides = smemObj.getStrides();
+  auto dstOffsets = smemObj.getOffsets();
   auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
   storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
-                           loc, rewriter, targetInfo, llvmOpCount);
+                           dstOffsets, loc, rewriter, targetInfo, llvmOpCount);
 }

 struct GlobalScratchAllocOpConversion
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
index 57ba4f9ddf85..c063b854088c 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -161,8 +161,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
 bool emitTransferBetweenRegistersAndShared(
     RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
     std::optional<int32_t> maxVecElems, Value shmemBase,
-    ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
-    const TargetInfoBase &target,
+    ArrayRef<Value> shmemStrides, ArrayRef<Value> shmemOffsets, Location loc,
+    RewriterBase &rewriter, const TargetInfoBase &target,
     std::function<void(VectorType, Value)> perVectorCallback) {
   MLIRContext *ctx = rewriter.getContext();
@@ -182,6 +182,7 @@ bool emitTransferBetweenRegistersAndShared(
     return false;
   }
   auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
+  auto cSwizzleOffset = shmemOffsets[sharedOrder[0]];

   // sharedLayout's in-dims are currently (offset, block). Reshape to
   // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
@@ -258,6 +259,7 @@ bool emitTransferBetweenRegistersAndShared(
                                     {kLane, laneId},
                                     {kWarp, warpId},
                                     {kBlock, zero}})))));
+    multiDimShmemOffset[0] = add(multiDimShmemOffset[0], cSwizzleOffset);

     // Reorder strides according to `order`. This way they match the
     // multi-dimensional offsets in regToSharedLayout.
@@ -278,8 +280,8 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
                                            const TargetInfoBase &target) {
   SmallVector<Value> ret;
   bool success = emitTransferBetweenRegistersAndShared(
-      dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.getBase(),
-      smemObj.getStrides(), loc, rewriter, target,
+      dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.base,
+      smemObj.getStrides(), smemObj.getOffsets(), loc, rewriter, target,
       [&](VectorType vecTy, Value vecAddr) {
         auto vecVal = load(vecTy, vecAddr);
         vecVal.setAlignment(vecTy.getNumElements() *
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
index a439b89270a9..9189c125630e 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -961,8 +961,9 @@ struct AsyncCopyGlobalToLocalOpConversion
     VectorType vecTy;
     SmallVector<Value> shmemAddrs;
     bool ok = emitTransferBetweenRegistersAndShared(
-        srcTy, dstTy, resElemTy, maxVec, smemObj.base, smemObj.strides, loc,
-        rewriter, targetInfo, [&](VectorType vecTy_, Value shmemAddr) {
+        srcTy, dstTy, resElemTy, maxVec, smemObj.base, smemObj.strides,
+        smemObj.offsets, loc, rewriter, targetInfo,
+        [&](VectorType vecTy_, Value shmemAddr) {
           vecTy = vecTy_;
           shmemAddrs.push_back(shmemAddr);
         });

From ba0ff0b438cc944628dbec56d5dbca3a4b713ed5 Mon Sep 17 00:00:00 2001
From: Jokeren
Date: Thu, 14 Nov 2024 22:19:15 -0500
Subject: [PATCH 6/7] Update

---
 lib/Conversion/TritonGPUToLLVM/Utility.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
index c063b854088c..5ecfe9130e61 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -300,12 +300,14 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
 void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
                               Type elemLlvmTy, ArrayRef<Value> srcVals,
                               Value smemBase, ArrayRef<Value> dstStrides,
-                              Location loc, RewriterBase &rewriter,
+                              ArrayRef<Value> dstOffsets, Location loc,
+                              RewriterBase &rewriter,
                               const TargetInfoBase &target,
                               std::pair<size_t, Type> *const llvmOpCount) {
   bool success = emitTransferBetweenRegistersAndShared(
       srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
-      dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
+      dstStrides, dstOffsets, loc, rewriter, target,
+      [&](VectorType vecTy, Value vecAddr) {
         ArrayRef<Value> vals = srcVals.take_front(vecTy.getNumElements());
         srcVals = srcVals.drop_front(vecTy.getNumElements());
         Value vecVal = packLLVector(loc, vals, rewriter);

From 34c21fc9ebac9750f9e6d379f824c10e50e1ca58 Mon Sep 17 00:00:00 2001
From: Jokeren
Date: Thu, 14 Nov 2024 22:54:55 -0500
Subject: [PATCH 7/7] Update

---
 lib/Conversion/TritonGPUToLLVM/Utility.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
index 5ecfe9130e61..404d285988ec 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -259,7 +259,6 @@ bool emitTransferBetweenRegistersAndShared(
                                     {kLane, laneId},
                                     {kWarp, warpId},
                                     {kBlock, zero}})))));
-    multiDimShmemOffset[0] = add(multiDimShmemOffset[0], cSwizzleOffset);

     // Reorder strides according to `order`. This way they match the
     // multi-dimensional offsets in regToSharedLayout.
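
Two of the changes above are easier to follow outside the diff context.

First, the predicate that patches 1-2 install in isSupportedDotOpLayout: a local_load into a dot-operand layout is routed to the linear-layout lowering exactly when ldmatrix cannot serve the load. The following is a minimal standalone sketch of that test, not Triton code: the DotOperandDesc struct and both function names are invented here for illustration, and the mma.isAmpere() guard from the real predicate is elided.

// Standalone sketch of the check from patches 1-2 (struct and function
// names invented; the real predicate also requires mma.isAmpere()).
#include <array>
#include <cstdio>

struct DotOperandDesc {
  int opIdx;                      // 0 = A operand, 1 = B operand
  int kWidth;                     // contiguous K elements per thread
  int bitwidth;                   // element size in bits
  int rank;                       // 2, or 3 with a leading batch dim
  std::array<int, 3> sharedOrder; // shared layout order, fastest dim first
};

// K is the last dim of A and the second-to-last dim of B. ldmatrix can
// serve the load when no transpose is needed (or elements are 16-bit,
// where the transposing form applies) and kWidth equals the 32-bit
// vector width; otherwise fall back to the linear-layout lowering.
bool useLinearLayoutLowering(const DotOperandDesc &d) {
  int vecWidth = 32 / d.bitwidth;
  int kOrder = d.opIdx == 0 ? d.rank - 1 : d.rank - 2;
  bool needTrans = kOrder != d.sharedOrder[0];
  bool canUseLdmatrix =
      (d.bitwidth == 16 || !needTrans) && (d.kWidth == vecWidth);
  return !canUseLdmatrix;
}

int main() {
  // fp16 A operand, K contiguous, kWidth == vecWidth == 2: ldmatrix works.
  std::printf("%d\n", useLinearLayoutLowering({0, 2, 16, 2, {1, 0, 0}}));
  // 8-bit elements with kWidth 8: vecWidth is 4, so take the LL path.
  std::printf("%d\n", useLinearLayoutLowering({0, 8, 8, 2, {1, 0, 0}}));
}

Second, the warp bases that patches 3-4 build in ampereDotToLinearLayout. Again a sketch under stated assumptions rather than Triton code: warpBases is a hypothetical name, plain vectors stand in for LinearLayout basis vectors, and the loop bodies copy the series' final state (after patch 4 swaps rank - 1 and rank - 2 in the else branch). The bases are written in the function's K-major coordinate order: coordinate 0 is K (warps never subdivide K, so it stays 0) and coordinate 1 is the operand's non-K dim.

// Standalone sketch of the warp-basis construction from patches 3-4.
#include <cstdio>
#include <vector>

std::vector<std::vector<int>> warpBases(bool isA, int rank,
                                        const std::vector<int> &warpsPerCTA) {
  std::vector<std::vector<int>> warps;
  if (isA) {
    // A is invariant along N, so N warps broadcast ({0, 0}) and M warps
    // shift the non-K coordinate ({0, i}).
    for (int i = 1; i < warpsPerCTA[rank - 1]; i *= 2)
      warps.push_back({0, 0});
    for (int i = 1; i < warpsPerCTA[rank - 2]; i *= 2)
      warps.push_back({0, i});
  } else {
    // B is invariant along M, so the roles swap.
    for (int i = 1; i < warpsPerCTA[rank - 1]; i *= 2)
      warps.push_back({0, i});
    for (int i = 1; i < warpsPerCTA[rank - 2]; i *= 2)
      warps.push_back({0, 0});
  }
  if (rank == 3) {
    for (auto &w : warps)
      w.push_back(0);               // pad rank-2 bases with a zero batch coord
    for (int i = 1; i < warpsPerCTA[0]; i *= 2)
      warps.push_back({0, 0, i});   // batch warps get their own bases
  }
  return warps;
}

int main() {
  // Configuration of the new DotMMAv2_3d test: A operand, rank 3,
  // warpsPerCTA = {16, 1, 1}, i.e. all 16 warps on the batch dimension.
  for (const auto &w : warpBases(/*isA=*/true, 3, {16, 1, 1}))
    std::printf("{%d, %d, %d}\n", w[0], w[1], w[2]);
}

This prints {0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8} in the sketch's K-major coordinates; the surrounding conversion later permutes these into dim0..dimN order, which is how the DotMMAv2_3d unit test arrives at warp bases {1, 0, 0}, {2, 0, 0}, {4, 0, 0}, {8, 0, 0} with dim0 as the batch dimension.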