diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 9061e24fbebd..e29c7eaecc7f 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1143,8 +1143,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, [[nodiscard]] bool emitTransferBetweenRegistersAndShared( RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy, std::optional maxVecElems, Value shmemBase, - ArrayRef shmemStrides, Location loc, RewriterBase &rewriter, - const TargetInfoBase &target, + ArrayRef shmemStrides, ArrayRef shmemOffsets, Location loc, + RewriterBase &rewriter, const TargetInfoBase &target, std::function perVectorCallback); inline DenseMap getSwizzledSharedPtrs( @@ -1318,7 +1318,8 @@ SmallVector loadSharedToDistributed(RankedTensorType dstTy, void storeDistributedToShared( MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, - Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + ArrayRef dstOffsets, Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, std::pair *const llvmOpCount = nullptr); inline Value getStructFromSharedMemoryObject(Location loc, diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 4cea14f0957f..ab259880c5dc 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -30,9 +30,10 @@ void lowerDistributedToShared( auto smemBase = smemObj.getBase(); auto dstStrides = smemObj.getStrides(); + auto dstOffsets = smemObj.getOffsets(); auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides, - loc, rewriter, targetInfo, llvmOpCount); + dstOffsets, loc, rewriter, targetInfo, llvmOpCount); } struct GlobalScratchAllocOpConversion @@ -138,18 +139,20 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { // FIXME [Dot LL] // Do for all DotOperandEncodingAttr once we have LLs for all of them - static bool isSupportedDotOpLayout(RankedTensorType type) { - auto layout = type.getEncoding(); - auto bitwidth = type.getElementType().getIntOrFloatBitWidth(); - if (auto dot = dyn_cast(layout)) { + static bool isSupportedDotOpLayout(SharedEncodingAttr srcLayout, + RankedTensorType dstTy) { + auto dstLayout = dstTy.getEncoding(); + auto bitwidth = dstTy.getElementType().getIntOrFloatBitWidth(); + auto rank = dstTy.getRank(); + if (auto dot = dyn_cast(dstLayout)) { + auto vecWidth = 32 / bitwidth; auto kWidth = dot.getKWidth(); - // Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy: - // - kWidth == 8 - // - kWidth == 4, bitwidth = 32 + auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2; if (auto mma = dyn_cast(dot.getParent())) { - bool legacyLoweringIsBuggy = - kWidth >= 8 || (kWidth == 4 && bitwidth == 32); - return legacyLoweringIsBuggy && mma.isAmpere(); + auto needTrans = kOrder != srcLayout.getOrder()[0]; + auto canUseLdmatrix = + (bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth); + return !canUseLdmatrix && mma.isAmpere(); } if (isa(dot.getParent())) return true; @@ -162,12 +165,11 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { MemDescType srcTy = op.getSrc().getType(); RankedTensorType dstTy = op.getType(); - Attribute srcLayout = srcTy.getEncoding(); + auto srcLayout = cast(srcTy.getEncoding()); Attribute dstLayout = dstTy.getEncoding(); - if (isa(srcLayout) && - (isa( + if ((isa( dstLayout) || - isSupportedDotOpLayout(dstTy))) { + isSupportedDotOpLayout(srcLayout, dstTy))) { return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter); } @@ -207,7 +209,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { auto dstShape = dstTy.getShape(); auto srcSharedLayout = cast(srcTy.getEncoding()); auto dstLayout = dstTy.getEncoding(); - assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) && + assert((dstShape.size() <= 2 || + isSupportedDotOpLayout(srcSharedLayout, dstTy)) && "Unexpected rank of ConvertLayout(shared->distributed)"); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 57ba4f9ddf85..404d285988ec 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -161,8 +161,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, bool emitTransferBetweenRegistersAndShared( RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy, std::optional maxVecElems, Value shmemBase, - ArrayRef shmemStrides, Location loc, RewriterBase &rewriter, - const TargetInfoBase &target, + ArrayRef shmemStrides, ArrayRef shmemOffsets, Location loc, + RewriterBase &rewriter, const TargetInfoBase &target, std::function perVectorCallback) { MLIRContext *ctx = rewriter.getContext(); @@ -182,6 +182,7 @@ bool emitTransferBetweenRegistersAndShared( return false; } auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding()); + auto cSwizzleOffset = shmemOffsets[sharedOrder[0]]; // sharedLayout's in-dims are currently (offset, block). Reshape to // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional @@ -278,8 +279,8 @@ SmallVector loadSharedToDistributed(RankedTensorType dstTy, const TargetInfoBase &target) { SmallVector ret; bool success = emitTransferBetweenRegistersAndShared( - dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.getBase(), - smemObj.getStrides(), loc, rewriter, target, + dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.base, + smemObj.getStrides(), smemObj.getOffsets(), loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { auto vecVal = load(vecTy, vecAddr); vecVal.setAlignment(vecTy.getNumElements() * @@ -298,12 +299,14 @@ SmallVector loadSharedToDistributed(RankedTensorType dstTy, void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, - Location loc, RewriterBase &rewriter, + ArrayRef dstOffsets, Location loc, + RewriterBase &rewriter, const TargetInfoBase &target, std::pair *const llvmOpCount) { bool success = emitTransferBetweenRegistersAndShared( srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase, - dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { + dstStrides, dstOffsets, loc, rewriter, target, + [&](VectorType vecTy, Value vecAddr) { ArrayRef vals = srcVals.take_front(vecTy.getNumElements()); srcVals = srcVals.drop_front(vecTy.getNumElements()); diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 43c87af487a1..b4ffe1dd063e 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -932,17 +932,17 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, auto warpsPerCTAMma = mma.getWarpsPerCTA(); std::vector> warps; if (isA) { - for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) { + for (int i = 1; i < warpsPerCTAMma[rank - 1]; i *= 2) { warps.push_back({0, 0}); } - for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) { + for (int i = 1; i < warpsPerCTAMma[rank - 2]; i *= 2) { warps.push_back({0, i}); } } else { - for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) { + for (int i = 1; i < warpsPerCTAMma[rank - 1]; i *= 2) { warps.push_back({0, i}); } - for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) { + for (int i = 1; i < warpsPerCTAMma[rank - 2]; i *= 2) { warps.push_back({0, 0}); } } @@ -950,6 +950,9 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, for (auto &w : warps) { w.push_back(0); } + for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) { + warps.push_back({0, 0, i}); + } } ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index a439b89270a9..9189c125630e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -961,8 +961,9 @@ struct AsyncCopyGlobalToLocalOpConversion VectorType vecTy; SmallVector shmemAddrs; bool ok = emitTransferBetweenRegistersAndShared( - srcTy, dstTy, resElemTy, maxVec, smemObj.base, smemObj.strides, loc, - rewriter, targetInfo, [&](VectorType vecTy_, Value shmemAddr) { + srcTy, dstTy, resElemTy, maxVec, smemObj.base, smemObj.strides, + smemObj.offsets, loc, rewriter, targetInfo, + [&](VectorType vecTy_, Value shmemAddr) { vecTy = vecTy_; shmemAddrs.push_back(shmemAddr); }); diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index d662537ed72d..ffbd7c79648e 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -47,6 +47,12 @@ class LinearLayoutConversionsTest : public ::testing::Test { return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth); } + DotOperandEncodingAttr dotMMAv2_3d(int idx, int kWidth, unsigned warps) { + auto mmaLayout = + mma(2, 0, {1, 16, 8}, {warps, 1, 1}, {1, 1, 1}, {1, 1, 1}, {2, 1, 0}); + return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth); + } + AMDMfmaEncodingAttr mfma(ArrayRef warps, unsigned mDim, unsigned nDim, bool isTransposed) { SmallVector cpg(warps.size(), 1u); @@ -535,6 +541,20 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { {S("dim0"), S("dim1")})); } +TEST_F(LinearLayoutConversionsTest, DotMMAv2_3d) { + EXPECT_EQ( + toLinearLayout({32, 16, 16}, dotMMAv2_3d(0, 2, 16)), + LinearLayout( + { + {S("register"), {{0, 0, 1}, {0, 8, 0}, {0, 0, 8}, {16, 0, 0}}}, + {S("lane"), + {{0, 0, 2}, {0, 0, 4}, {0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("warp"), {{1, 0, 0}, {2, 0, 0}, {4, 0, 0}, {8, 0, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1"), S("dim2")})); +} + TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { EXPECT_EQ( toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1})),