Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKEND][DRAFT] Use linear layout for loading mmav2 dot operand tensors from shared memory #5154

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1143,8 +1143,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
std::optional<int32_t> maxVecElems, Value shmemBase,
ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
ArrayRef<Value> shmemStrides, ArrayRef<Value> shmemOffsets, Location loc,
RewriterBase &rewriter, const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);

inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
Expand Down Expand Up @@ -1318,7 +1318,8 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
void storeDistributedToShared(
MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
ArrayRef<Value> dstOffsets, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount = nullptr);

inline Value getStructFromSharedMemoryObject(Location loc,
Expand Down
35 changes: 19 additions & 16 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ void lowerDistributedToShared(

auto smemBase = smemObj.getBase();
auto dstStrides = smemObj.getStrides();
auto dstOffsets = smemObj.getOffsets();
auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
loc, rewriter, targetInfo, llvmOpCount);
dstOffsets, loc, rewriter, targetInfo, llvmOpCount);
}

struct GlobalScratchAllocOpConversion
Expand Down Expand Up @@ -138,18 +139,20 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {

// FIXME [Dot LL]
// Do for all DotOperandEncodingAttr once we have LLs for all of them
static bool isSupportedDotOpLayout(RankedTensorType type) {
auto layout = type.getEncoding();
auto bitwidth = type.getElementType().getIntOrFloatBitWidth();
if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
static bool isSupportedDotOpLayout(SharedEncodingAttr srcLayout,
RankedTensorType dstTy) {
auto dstLayout = dstTy.getEncoding();
auto bitwidth = dstTy.getElementType().getIntOrFloatBitWidth();
auto rank = dstTy.getRank();
if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
auto vecWidth = 32 / bitwidth;
auto kWidth = dot.getKWidth();
// Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
// - kWidth == 8
// - kWidth == 4, bitwidth = 32
auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2;
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
bool legacyLoweringIsBuggy =
kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
return legacyLoweringIsBuggy && mma.isAmpere();
auto needTrans = kOrder != srcLayout.getOrder()[0];
auto canUseLdmatrix =
(bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth);
return !canUseLdmatrix && mma.isAmpere();
}
if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
return true;
Expand All @@ -162,12 +165,11 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
ConversionPatternRewriter &rewriter) const override {
MemDescType srcTy = op.getSrc().getType();
RankedTensorType dstTy = op.getType();
Attribute srcLayout = srcTy.getEncoding();
auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
Attribute dstLayout = dstTy.getEncoding();
if (isa<SharedEncodingAttr>(srcLayout) &&
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
if ((isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
dstLayout) ||
isSupportedDotOpLayout(dstTy))) {
isSupportedDotOpLayout(srcLayout, dstTy))) {
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
rewriter);
}
Expand Down Expand Up @@ -207,7 +209,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
auto dstShape = dstTy.getShape();
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
auto dstLayout = dstTy.getEncoding();
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
assert((dstShape.size() <= 2 ||
isSupportedDotOpLayout(srcSharedLayout, dstTy)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");

auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
Expand Down
15 changes: 9 additions & 6 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
bool emitTransferBetweenRegistersAndShared(
RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
std::optional<int32_t> maxVecElems, Value shmemBase,
ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
ArrayRef<Value> shmemStrides, ArrayRef<Value> shmemOffsets, Location loc,
RewriterBase &rewriter, const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
MLIRContext *ctx = rewriter.getContext();

Expand All @@ -182,6 +182,7 @@ bool emitTransferBetweenRegistersAndShared(
return false;
}
auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
auto cSwizzleOffset = shmemOffsets[sharedOrder[0]];

// sharedLayout's in-dims are currently (offset, block). Reshape to
// (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
Expand Down Expand Up @@ -278,8 +279,8 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
const TargetInfoBase &target) {
SmallVector<Value> ret;
bool success = emitTransferBetweenRegistersAndShared(
dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.getBase(),
smemObj.getStrides(), loc, rewriter, target,
dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.base,
smemObj.getStrides(), smemObj.getOffsets(), loc, rewriter, target,
[&](VectorType vecTy, Value vecAddr) {
auto vecVal = load(vecTy, vecAddr);
vecVal.setAlignment(vecTy.getNumElements() *
Expand All @@ -298,12 +299,14 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
Type elemLlvmTy, ArrayRef<Value> srcVals,
Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter,
ArrayRef<Value> dstOffsets, Location loc,
RewriterBase &rewriter,
const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount) {
bool success = emitTransferBetweenRegistersAndShared(
srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
dstStrides, dstOffsets, loc, rewriter, target,
[&](VectorType vecTy, Value vecAddr) {
ArrayRef<Value> vals = srcVals.take_front(vecTy.getNumElements());
srcVals = srcVals.drop_front(vecTy.getNumElements());

Expand Down
11 changes: 7 additions & 4 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,24 +932,27 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
auto warpsPerCTAMma = mma.getWarpsPerCTA();
std::vector<std::vector<int32_t>> warps;
if (isA) {
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
for (int i = 1; i < warpsPerCTAMma[rank - 1]; i *= 2) {
warps.push_back({0, 0});
}
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
for (int i = 1; i < warpsPerCTAMma[rank - 2]; i *= 2) {
warps.push_back({0, i});
}
} else {
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
for (int i = 1; i < warpsPerCTAMma[rank - 1]; i *= 2) {
warps.push_back({0, i});
}
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
for (int i = 1; i < warpsPerCTAMma[rank - 2]; i *= 2) {
warps.push_back({0, 0});
}
}
if (rank == 3) {
for (auto &w : warps) {
w.push_back(0);
}
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
warps.push_back({0, 0, i});
}
}

ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -961,8 +961,9 @@ struct AsyncCopyGlobalToLocalOpConversion
VectorType vecTy;
SmallVector<Value> shmemAddrs;
bool ok = emitTransferBetweenRegistersAndShared(
srcTy, dstTy, resElemTy, maxVec, smemObj.base, smemObj.strides, loc,
rewriter, targetInfo, [&](VectorType vecTy_, Value shmemAddr) {
srcTy, dstTy, resElemTy, maxVec, smemObj.base, smemObj.strides,
smemObj.offsets, loc, rewriter, targetInfo,
[&](VectorType vecTy_, Value shmemAddr) {
vecTy = vecTy_;
shmemAddrs.push_back(shmemAddr);
});
Expand Down
20 changes: 20 additions & 0 deletions unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ class LinearLayoutConversionsTest : public ::testing::Test {
return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth);
}

// Builds a dot-operand encoding whose parent is a rank-3 (batched) MMAv2
// layout: version 2.0, instr shape {1, 16, 8}, `warps` warps along the
// leading (batch) dimension, order {2, 1, 0}.
// `idx` is the operand index and `kWidth` the k-contiguity of the operand.
DotOperandEncodingAttr dotMMAv2_3d(int idx, int kWidth, unsigned warps) {
  auto parent =
      mma(2, 0, {1, 16, 8}, {warps, 1, 1}, {1, 1, 1}, {1, 1, 1}, {2, 1, 0});
  return DotOperandEncodingAttr::get(&ctx, idx, parent, /*kWidth=*/kWidth);
}

AMDMfmaEncodingAttr mfma(ArrayRef<unsigned> warps, unsigned mDim,
unsigned nDim, bool isTransposed) {
SmallVector<unsigned> cpg(warps.size(), 1u);
Expand Down Expand Up @@ -535,6 +541,20 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) {
{S("dim0"), S("dim1")}));
}

// Linear-layout conversion of a rank-3 MMAv2 dot operand (opIdx = 0,
// kWidth = 2, 16 warps on the batch dim — see dotMMAv2_3d) over a
// 32x16x16 tensor. The expected "warp" bases {1,0,0}..{8,0,0} advance
// only dim0, matching warpsPerCTA = {16, 1, 1}; "register" and "lane"
// bases cover the per-thread tile within the trailing two dims.
TEST_F(LinearLayoutConversionsTest, DotMMAv2_3d) {
  EXPECT_EQ(
      toLinearLayout({32, 16, 16}, dotMMAv2_3d(0, 2, 16)),
      LinearLayout(
          {
              {S("register"), {{0, 0, 1}, {0, 8, 0}, {0, 0, 8}, {16, 0, 0}}},
              {S("lane"),
               {{0, 0, 2}, {0, 0, 4}, {0, 1, 0}, {0, 2, 0}, {0, 4, 0}}},
              {S("warp"), {{1, 0, 0}, {2, 0, 0}, {4, 0, 0}, {8, 0, 0}}},
              {S("block"), {}},
          },
          {S("dim0"), S("dim1"), S("dim2")}));
}

TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
EXPECT_EQ(
toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1})),
Expand Down
Loading