diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index cb3e3d292efa..df6029db0de2 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -212,11 +212,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy); bool atomicNeedsSharedMemory(Value result); -bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT); - -bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); - -bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); +bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); // Return true if the src and dst layout match. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, diff --git a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h index 22c8f9c8a330..8c7ab9831667 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -18,14 +18,6 @@ namespace gpu { SmallVector reorderValues(const SmallVector &values, Type inType, Type ouType); -SmallVector unpackI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter); - -SmallVector packI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter); - Type getElementType(Value value); class MultipleOperandsRange @@ -187,8 +179,8 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { for (auto operand : adaptor.getOperands()) { auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); - subOperands = unpackI32(subOperands, argTy, rewriter, loc, - this->getTypeConverter()); + subOperands = unpackI32s(subOperands, argTy, rewriter, loc, + this->getTypeConverter()); allOperands.resize(subOperands.size()); for (auto v : llvm::enumerate(subOperands)) allOperands[v.index()].push_back(v.value()); @@ -215,7 +207,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { } resultVals = maybeDeduplicate(op, resultVals); resultVals = - packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); + packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); Value view = packLLElements(loc, this->getTypeConverter(), resultVals, rewriter, resultTy); rewriter.replaceOp(op, view); diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 29b8865c03ae..56a82d7cc0fb 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1388,6 +1388,67 @@ inline Value getStructFromSharedMemoryObject(Location loc, return llvmStruct; } +// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer +// instructions to pack & unpack sub-word integers. A workaround is to +// store the results of tensors with dot operand encodings in i32 to +// facilitate instructions such as `ldmatrix`. +// +// TODO: Confirm if the problem is still there. 
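+//
+// Illustrative usage sketch (placeholder names such as `tensorTy`,
+// `typeConverter`, and `adaptor` are assumptions, not part of this patch):
+// for a dot-operand tensor whose parent is MMAv2 and whose element type is
+// f16, `packI32s` below packs each pair of f16 register values into one i32
+// through a <2 x f16> bitcast, and `unpackI32s` reverses that packing:
+//
+//   auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); // f16s
+//   auto i32s = packI32s(vals, tensorTy, rewriter, loc, typeConverter);
+//   auto back = unpackI32s(i32s, tensorTy, rewriter, loc, typeConverter);
+//   // `back` holds the original f16 values in their original order.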
+inline bool requiresI32Conversion(Type type) { + auto tensorTy = dyn_cast(type); + if (!tensorTy) + return false; + auto dotOpEnc = dyn_cast(tensorTy.getEncoding()); + if (!dotOpEnc) + return false; + auto parent = dyn_cast(dotOpEnc.getParent()); + if (!(parent && parent.getVersionMajor() < 3)) + return false; + return true; +} + +inline SmallVector packI32s(const SmallVector &inValues, + Type type, RewriterBase &rewriter, + Location loc, + const LLVMTypeConverter *typeConverter) { + if (!requiresI32Conversion(type)) + return inValues; + Type eltTy = + typeConverter->convertType(cast(type).getElementType()); + + SmallVector outValues; + int vecWidth = 32 / eltTy.getIntOrFloatBitWidth(); + auto vecTy = vec_ty(eltTy, vecWidth); + for (int i = 0; i < inValues.size(); i += vecWidth) { + Value vec = undef(vecTy); + for (int j = 0; j < vecWidth; j++) { + vec = insert_element(vec, inValues[i + j], i32_val(j)); + } + outValues.push_back(bitcast(vec, i32_ty)); + } + return outValues; +} + +inline SmallVector unpackI32s(const SmallVector &inValues, + Type type, RewriterBase &rewriter, + Location loc, + const LLVMTypeConverter *typeConverter) { + if (!requiresI32Conversion(type)) + return inValues; + Type eltTy = + typeConverter->convertType(cast(type).getElementType()); + + SmallVector outValues; + for (auto v : inValues) { + auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth()); + auto vec = bitcast(v, vecTy); + for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) { + outValues.push_back(extract_element(vec, i32_val(i))); + } + } + return outValues; +} + inline SmallVector unpackLLElements(Location loc, Value llvmStruct, RewriterBase &rewriter) { assert(bool(llvmStruct) && "can not unpack null values"); diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 283dd9165918..c39c408d9330 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -727,6 +727,10 @@ def TT_ReduceOp: TT_Op<"reduce", llvm::SmallVector getInputTypes(); llvm::SmallVector getElementTypes(); unsigned getNumOperands(); + + // Returns the CombineOp iff this ReduceOp's region contains only + // one CombineOp other than the return, or nullptr if not applicable. + ::mlir::Operation *getSingleCombiner(); }]; } diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index c728cfbb32cf..47e3fca79bb1 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -679,6 +679,13 @@ class LinearLayout { // (i.e. every input bit affects the output). llvm::MapVector getFreeVariableMasks() const; + // Increase an input dimension without affecting the output dimension. The + // added free variables are mapped to 0, ensuring that the new input + // dimensions correspond directly to the existing output space. The function + // errors out if `newInDimSize` is less than the current size or the new size + // is not a power of 2. 
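+  //
+  // Illustrative example (assumed layout, not taken from this patch): if the
+  // "register" input dimension currently has size 4 (two basis vectors), then
+  // resize("register", 8) appends one all-zero basis vector, so inputs 4..7
+  // map to the same outputs as inputs 0..3, i.e. one extra free variable.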
+ LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const; + std::string toString() const; friend bool operator==(LinearLayout lhs, LinearLayout rhs); diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 276a6e7004df..665b97aeebba 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -113,7 +113,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); - assert(!isMfmaToDotShortcut(srcTy, dstTy)); + assert(cvtNeedsSharedMemory(srcTy, dstTy)); // FIXME This is NOT entirely correct // This should be getElemOrder, but we don't have such a method diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 4915d7b1acda..aa9f8b01eae1 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -536,7 +536,7 @@ bool supportMMA(Value value, int version) { (elemTy.isInteger(8) && version >= 2); } -bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { +bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { auto blockedLayout = dyn_cast(srcTy.getEncoding()); auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); if (blockedLayout == nullptr || dotOperandLayout == nullptr) @@ -605,22 +605,6 @@ bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { return matrixDimsCompatible && bDimCompatible; } -bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { - auto mfmaLayout = dyn_cast(srcTy.getEncoding()); - auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); - if (mfmaLayout == nullptr || dotOperandLayout == nullptr) - return false; - // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is - // improved. In addition, we can enable this shortcut for regular MFMA - // layout when opIdx == 1. - return mfmaLayout.getWarpsPerCTA()[1] == 1 && - dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && - dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] && - dotOperandLayout.getParent() == mfmaLayout && - (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) && - (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); -} - // For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy) { @@ -655,8 +639,46 @@ std::optional minimalCvtLayout(RankedTensorType srcTy, toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); if (!(srcLayout.has_value() && dstLayout.has_value())) return std::nullopt; + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kBlock = StringAttr::get(ctx, "block"); + auto numSrcRegs = srcLayout->getInDimSize(kRegister); + auto numDstRegs = dstLayout->getInDimSize(kRegister); + // The `invertAndCompose` function will generate a layout that is injective + // by assigning new output dimensions to free variables. For instance, + // consider a scenario where `srcLayout` has a free variable in the lane + // dimension, while `dstLayout` has two free variables in the lane + // dimension and also a larger number of registers. + // The injective form of `srcLayout` will add only a single additional row + // to the transformation matrix, whereas the injective form of `dstLayout` + // will add two additional rows. 
This discrepancy causes misleading results + // because the matrices end up with a different number of rows. + // + // Take `dstLayout ⋅ srcLayout^-1` as an example: + // + // - `injective(dstLayout)`: [n, m] → [n + 2, m] + // - `injective(srcLayout)`: [n, m] → [n + 1, m] + // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1] + // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n + + // 1] → [n + 2, n + 1] + // + // Here, the `(n + 1)`-th row added by `dstLayout` represents the free + // variable in registers, and the `(n + 2)`-th row represents the free + // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout` + // represents the free variable in lanes. As a result, the `(n + 1)`-th row + // in two layouts do not correspond to the same free variable. + // + // To address this issue, we pad the free variables in `srcLayout` and + // `dstLayout` to ensure they have the same number of registers. This + // guarantees that the resulting matrices have the same number of rows, + // ensuring consistency in the composition process. + auto numRegs = std::max(numSrcRegs, numDstRegs); + auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs); + auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs); // comp describes the layout function to create dst from src. - LinearLayout comp = dstLayout->invertAndCompose(*srcLayout); + LinearLayout comp = + dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs); // We try to quotient by the largest subspace first auto dims = SmallVector{"block", "warp", "lane", "register"}; for (auto dim : dims) { @@ -693,15 +715,14 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) { } bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { - // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`, - // `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully - // subsumed by the linear-layout checks. + // TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and + // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout + // checks. // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not // supported yet in Triton's backend. 
return !cvtReordersRegisters(srcTy, dstTy) && !isBlockedToDotShortcut(srcTy, dstTy) && - !isMmaToDotShortcut(srcTy, dstTy) && - !isMfmaToDotShortcut(srcTy, dstTy); + !matchMmaV3AndDotOperandLayout(srcTy, dstTy); } bool atomicNeedsSharedMemory(Value value) { @@ -711,20 +732,6 @@ bool atomicNeedsSharedMemory(Value value) { return true; } -bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { - if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) - return true; - // dot_op = #mma - // when #mma = MmaEncoding - auto mmaLayout = dyn_cast(srcTy.getEncoding()); - auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); - return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 && - mmaLayout.getWarpsPerCTA()[1] == 1 && - dotOperandLayout.getOpIdx() == 0 && - dotOperandLayout.getParent() == mmaLayout && - !srcTy.getElementType().isF32(); -} - namespace { /// A data structure similar to SetVector but maintains diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index a18b2cbc308c..65ee8cc0023e 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -288,62 +288,90 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return rewriter.notifyMatchFailure( op, "NYI. srcTy and/or dstTy don't implement LLs yet"); } + LinearLayout srcLayout = + *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); assert(to_vector(conversion->getInDimNames()) == to_vector(conversion->getOutDimNames())); auto dims = conversion->getInDimNames(); - if (llvm::is_contained(dims, str_attr("block"))) { + if (llvm::is_contained(dims, kBlock)) { // Case 1: Transfer between values in different CTAs. // This requires moving values through distributed shared memory. return rewriter.notifyMatchFailure( op, "NYI: Transfer between different CTAs"); - } else if (llvm::is_contained(dims, str_attr("warp"))) { + } else if (llvm::is_contained(dims, kWarp)) { // Case 2: Transfer between values in the same CTA, in which case we move // values through shared memory. - LinearLayout srcLayout = - *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - LinearLayout dstLayout = - *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); - } else if (llvm::is_contained(dims, str_attr("lane"))) { + } else if (llvm::is_contained(dims, kLane)) { // Case 3. Transfer between values in the same warp, in which case we try // to move values using warp shuffles, though if the pattern is // complicated enough we may fall back to using shared memory // TODO(Keren): implement warp shuffle instead of using the general // approach that uses shared memory - LinearLayout srcLayout = - *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - LinearLayout dstLayout = - *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); - } else if (llvm::is_contained(dims, str_attr("register"))) { + } else if (llvm::is_contained(dims, kRegister) || + dstLayout.getInDimSize(kRegister) != + srcLayout.getInDimSize(kRegister)) { // Case 4. 
Transfer between values in the same thread, in which case we // simply reorder the elements of adaptor.getSrc(). - return transferWithinThread(op, *conversion, adaptor, rewriter); + return transferWithinThread( + op, dstLayout.getFreeVariableMasks()[kRegister], + dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter); } else { - // The two layouts are equivalent. We should probably remove these in - // RemoveLayoutConversion. - rewriter.replaceOp(op, adaptor.getSrc()); + // Cast 5. The two layouts are equivalent. We should probably remove + // these in RemoveLayoutConversion. + auto dstCvt = requiresI32Conversion(dstTy); + auto srcCvt = requiresI32Conversion(srcTy); + if (dstCvt || srcCvt) { + auto inVals = unpackLLElements(op.getLoc(), adaptor.getSrc(), rewriter); + inVals = unpackI32s(inVals, srcTy, rewriter, op.getLoc(), + getTypeConverter()); + inVals = + packI32s(inVals, dstTy, rewriter, op.getLoc(), getTypeConverter()); + auto res = packLLElements(op.getLoc(), getTypeConverter(), inVals, + rewriter, op.getType()); + rewriter.replaceOp(op, res); + } else { + rewriter.replaceOp(op, adaptor.getSrc()); + } return success(); } } LogicalResult - transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, - OpAdaptor adaptor, + transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs, + const LinearLayout &conversion, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); StringAttr kRegister = str_attr("register"); assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - SmallVector outVals; - outVals.resize(conversion.getInDimSize(kRegister)); - for (int i = 0; i < conversion.getInDimSize(kRegister); i++) { - auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; + inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter()); + SmallVector outVals(numRegs); + for (int i = 0; i < numRegs; i++) { + // Remove free masks from the register index + // For example, if idx = 0b00111, and masks = 0b00100, then we get + // 0b00011. It means that register 7 (0b111) has the same value as + // register 3 (0b011). + auto idx = i & (~regMasks); + auto srcIdx = conversion.hasInDim(kRegister) + ? 
conversion.apply({{kRegister, idx}}).begin()->second + : idx; outVals[i] = inVals[srcIdx]; } + outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter()); Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); rewriter.replaceOp(op, result); @@ -375,9 +403,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion if (auto dotOperand = dyn_cast(layout)) { if (auto nvidiaMma = dyn_cast(dotOperand.getParent())) { - if (product(getCTAsPerCGA(nvidiaMma)) > 1) { - return false; - } if (useLegacyMMAConversion) { return false; } @@ -387,6 +412,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64; return largeKWidth && nvidiaMma.isAmpere(); } + return false; } if (isa(layout)) { return true; @@ -428,6 +454,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion inVals[it.index()] = ptrtoint(llvmElemTy, it.value()); } } + inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter()); // Pretty sure this is the identity function ATM // It'd be better to simply call `quotient({kBlock})` and @@ -447,22 +474,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } } - // FIXME [Dot LL] - // We know it's just for largeKWidth case in Ampere - // In this case, we need to pack the outputs into i32 - if (isa(dstTy.getEncoding())) { - auto concat = [&](Value a, Value b) { - return or_(zext(i32_ty, bitcast(a, i16_ty)), - shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); - }; - - SmallVector outVals32(outVals.size() / 2); - for (int i = 0; i < outVals32.size(); ++i) { - outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); - } - outVals = outVals32; - } - + outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter()); Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); rewriter.replaceOp(op, result); diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 8ee166866974..470e8b32b540 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -103,51 +103,6 @@ SmallVector reorderValues(const SmallVector &values, Type inType, llvm_unreachable("unimplemented code path"); } -SmallVector unpackI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter) { - auto tensorTy = dyn_cast(srcTy); - if (!tensorTy) - return inValues; - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) - return inValues; - SmallVector outValues; - for (auto v : inValues) { - // cast i32 to appropriate eltType vector and extract elements - auto eltType = typeConverter->convertType(tensorTy.getElementType()); - auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth()); - auto vec = bitcast(v, vecType); - for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) { - outValues.push_back(extract_element(vec, i32_val(i))); - } - } - return outValues; -} - -SmallVector packI32(const SmallVector &inValues, Type srcTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter) { - auto tensorTy = dyn_cast(srcTy); - if (!tensorTy) - return inValues; - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) - return inValues; - SmallVector outValues; - auto eltType = typeConverter->convertType(tensorTy.getElementType()); - int 
vecWidth = 32 / eltType.getIntOrFloatBitWidth(); - auto vecType = vec_ty(eltType, vecWidth); - for (int i = 0; i < inValues.size(); i += vecWidth) { - Value vec = undef(vecType); - for (int j = 0; j < vecWidth; j++) { - vec = insert_element(vec, inValues[i + j], i32_val(j)); - } - outValues.push_back(bitcast(vec, i32_ty)); - } - return outValues; -} - int getNumElementsPerThreads(Type type, const LLVMTypeConverter *typeConverter) { int numElemsPerThread = 1; @@ -500,7 +455,7 @@ struct ElementwiseInlineAsmOpConversion auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); unpackedOperands.push_back( - unpackI32(subOperands, argTy, rewriter, loc, getTypeConverter())); + unpackI32s(subOperands, argTy, rewriter, loc, getTypeConverter())); } int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(), @@ -560,10 +515,11 @@ struct ElementwiseInlineAsmOpConversion unpackedResults[i], /*inType=*/op->getOperand(0).getType(), /*ouType=*/op->getResult(i).getType()); } - auto packed = packI32(unpackedResults[i], op->getResult(i).getType(), - rewriter, loc, getTypeConverter()); - outs.push_back(packLLElements(loc, getTypeConverter(), packed, rewriter, - op->getResult(i).getType())); + auto dstTy = op->getResult(i).getType(); + unpackedResults[i] = packI32s(unpackedResults[i], dstTy, rewriter, loc, + getTypeConverter()); + outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i], + rewriter, op->getResult(i).getType())); } rewriter.replaceOp(op, outs); diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 1a0c115a9ecf..e2ed0228de8d 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -184,42 +184,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { SmallVector outVals = loadSharedToDistributed( dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo); - // FIXME [Dot LL] - // Ampere case - // In this case, we need to pack the outputs into i32 - if (auto dotOp = dyn_cast(dstTy.getEncoding())) { - if (auto parent = dyn_cast(dotOp.getParent())) { - if (parent.isAmpere()) { - if (elemLlvmTy.isInteger(8)) { - auto concat = [&](Value a1, Value a2, Value a3, Value a4) { - return or_( - or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))), - or_(shl(zext(i32_ty, a3), i32_val(16)), - shl(zext(i32_ty, a4), i32_val(24)))); - }; - SmallVector outVals32(outVals.size() / 4); - for (int i = 0; i < outVals32.size(); ++i) { - outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1], - outVals[4 * i + 2], outVals[4 * i + 3]); - } - outVals = outVals32; - } else { - assert(elemLlvmTy.isBF16() && "Unexpected element type"); - auto concat = [&](Value a, Value b) { - return or_(zext(i32_ty, bitcast(a, i16_ty)), - shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); - }; - - SmallVector outVals32(outVals.size() / 2); - for (int i = 0; i < outVals32.size(); ++i) { - outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); - } - outVals = outVals32; - } - } - } - } - + outVals = packI32s(outVals, dstTy, rewriter, loc, typeConverter); Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); rewriter.replaceOp(op, result); diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index ffea5f3c67a6..e77e2d5c8691 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -503,6 +503,22 @@ llvm::SmallVector ReduceOp::getElementTypes() 
{ return getElementTypesImpl(this->getOperands()); } +::mlir::Operation *ReduceOp::getSingleCombiner() { + if (getNumOperands() != 1 || getNumResults() != 1) + return nullptr; + Block *block = &(*getCombineOp().begin()); + Operation *yield = block->getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + if (!reduceOp || reduceOp->getNumOperands() != 2 || + reduceOp->getNumResults() != 1) + return nullptr; + if (reduceOp->getOperand(0) != block->getArgument(0) || + reduceOp->getOperand(1) != block->getArgument(1)) + return nullptr; + + return reduceOp; +} + unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } //-- ScanOp -- diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 56af4eaef8b9..6978ccfb2553 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -551,8 +551,8 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { } std::optional -dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, - ArrayRef shape) { +mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape) { // Current linear layout conversion for dot operand is only necessary to // enable LDS bypass for operand B in the MFMA dot path. To achieve @@ -895,7 +895,7 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); auto order = dot.getCTAOrder(); - assert(order[0] == 1 && order[1] == 0); + assert(order[0] == rank - 1 && order[1] == rank - 2); ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames); return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); @@ -903,13 +903,11 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, std::optional DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { - if (auto mfmaLayout = llvm::dyn_cast(getParent())) { - return dotOperandMfmaToLinearLayout(*this, shape); - } else if (auto mma = mlir::dyn_cast(getParent())) { - // FIXME [Dot LL] - // Do this unconditionally - auto largeKWidth = getKWidth() == 8; - if (mma.isAmpere() && largeKWidth) { + auto parent = getParent(); + if (auto mfmaLayout = llvm::dyn_cast(parent)) { + return mfmaDotToLinearLayout(*this, shape); + } else if (auto mma = mlir::dyn_cast(parent)) { + if (mma.getVersionMajor() == 2 && mma.getVersionMinor() == 0) { return ampereDotToLinearLayout(shape, *this); } } diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 6d8279795209..9f3d8fff491b 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -290,7 +290,7 @@ struct MMAV3UseRegOperand dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0); auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), dotOperandEnc); - if (!isMmaToDotShortcut(srcTy, newTy)) + if (!matchMmaV3AndDotOperandLayout(srcTy, newTy)) return failure(); Value newOperand = diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index bf017f8c6463..4319d1f086dd 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -1016,6 +1016,21 @@ bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const { return true; } +LinearLayout LinearLayout::resize(StringAttr inDim, + int32_t newInDimSize) const { + BasesT bases = getBases(); + 
assert(bases.contains(inDim) && "inDim not in layout"); + assert(llvm::isPowerOf2_32(newInDimSize) && + "newInDimSize must be a power of 2"); + assert(newInDimSize >= getInDimSize(inDim) && + "newInDimSize must be >= old size"); + auto numFreeVariables = llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim); + for (int i = 0; i < numFreeVariables; i++) { + bases[inDim].push_back(std::vector(getNumOutDims(), 0)); + } + return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames())); +} + std::string LinearLayout::toString() const { // Start with a newline because we print out a bulleted list; it doesn't // make sense for the first line of this list to be on the same line as diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index a0719c974f9c..208f6b80bfe5 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -6,7 +6,7 @@ #A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> #B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 2054853b30c1..df4f5ab01feb 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -5,7 +5,7 @@ #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index ef6733845721..a2f713faaf18 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -62,3 +62,151 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_f16x2 + tt.func @atomic_add_f16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> + // CHECK: llvm.cond_br + // CHECK-NOT: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> + // CHECK-NOT: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, 
#blocked1>) -> tensor<256xf16, #blocked1> + tt.return + } +} + +// ----- + +#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_bf16x2 + tt.func @atomic_add_bf16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> + // CHECK: llvm.cond_br + // CHECK-NOT: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> + // CHECK-NOT: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> + tt.return + } +} + +// ----- + +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_f16_dpp + tt.func @atomic_add_f16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> + // CHECK: llvm.cond_br + // CHECK: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> + // CHECK: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> + tt.return + } +} + +// ----- + +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_bf16_dpp + tt.func @atomic_add_bf16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> + // CHECK: llvm.cond_br + // CHECK: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> + // CHECK: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> + tt.return + } +} + +// ----- + +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: reduce_dpp_max + tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) { + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 280, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp 
+ // CHECK-SAME: with 276, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 274, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 273, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 322, 10, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 323, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK: llvm.amdgcn.readlane + %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<64xf32, #blocked3>) -> f32 + tt.return + } +} + +// ----- + +#blocked4 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: reduce_xor_max + tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) { + // CHECK: rocdl.ds_swizzle + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 280, 15, 12, false : i32 + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 264, 15, 3, false : i32 + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 276, 15, 10, false : i32 + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 260, 15, 5, false : i32 + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 78, 15, 15, false : i32 + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 177, 15, 15, false : i32 + %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32xf32, #blocked4>) -> f32 + tt.return + } +} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 9705d4bf43d3..6e990248e56e 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -823,6 +823,110 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<1x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<1x16xf16, #mma> -> tensor<1x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, 
versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#slice = #triton_gpu.slice<{dim = 0, parent = #mma}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_slice_mmav2_blocked_reg + tt.func @convert_layout_slice_mmav2_blocked_reg(%arg0: tensor<1xf16, #slice>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<1xf16, #slice> -> tensor<1xf16, #blocked> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_0 + tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_1 + tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_2 + tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_3 + tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> + tt.return + } +} + +// ----- + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 16]}> module attributes {"triton_gpu.num-ctas" = 
1 : i32, "triton_gpu.num-warps" = 4 : i32} { @@ -847,6 +951,80 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_0 + tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_1 + tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_2 + tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_3 + tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> + tt.return + } +} + +// ----- + #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> module attributes 
{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index 686e5a24e8dd..5dfd0f2a5f4c 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -460,429 +460,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } } -// ----- -// This test ensures that loads will not be moved across `for` loops. - -// CHECK-LABEL: tt.func public @_attn_bwd -// CHECK: tt.load -// CHECK: tt.load -// CHECK: scf.for -// CHECK: } -// CHECK: scf.for -// CHECK: } -// Moved before the independent `tt.store` ops but not before the `for` ops. -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.load -// CHECK: tt.store -// CHECK: tt.store -// CHECK: scf.for -// CHECK: } -// CHECK: scf.for -// CHECK: } -// CHECK: tt.store - -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> -#mma1 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> -#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> -#shared2 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> -#shared3 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @_attn_bwd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %c-1_i32 = arith.constant -1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma> - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %c32_i32 = arith.constant 32 : i32 - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %c16_i32 = arith.constant 16 : i32 - %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_2 = arith.constant 
dense<0.693147182> : tensor<128x64xf32, #mma> - %0 = tt.get_program_id z : i32 - %1 = arith.muli %0, %arg14 : i32 - %2 = arith.extsi %1 : i32 to i64 - %3 = arith.remsi %0, %arg13 : i32 - %4 = arith.muli %arg11, %3 : i32 - %5 = arith.divsi %0, %arg13 : i32 - %6 = arith.muli %arg10, %5 : i32 - %7 = arith.addi %4, %6 : i32 - %8 = arith.extsi %7 : i32 to i64 - %9 = tt.get_program_id x : i32 - %10 = tt.addptr %arg0, %8 : !tt.ptr, i64 - %11 = tt.addptr %arg1, %8 : !tt.ptr, i64 - %12 = tt.addptr %arg2, %8 : !tt.ptr, i64 - %13 = tt.addptr %arg4, %8 : !tt.ptr, i64 - %14 = tt.addptr %arg5, %8 : !tt.ptr, i64 - %15 = tt.addptr %arg6, %8 : !tt.ptr, i64 - %16 = tt.addptr %arg7, %8 : !tt.ptr, i64 - %17 = tt.addptr %arg8, %2 : !tt.ptr, i64 - %18 = tt.addptr %arg9, %2 : !tt.ptr, i64 - %19 = arith.muli %9, %c128_i32 : i32 - %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %23 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %25 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %26 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %27 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %28 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %29 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %30 = arith.addi %25, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %31 = arith.addi %26, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %32 = arith.addi %27, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %33 = arith.addi %28, %23 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %34 = arith.addi %29, %24 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %35 = tt.expand_dims %30 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> - %36 = tt.expand_dims %31 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> - %37 = tt.expand_dims %32 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xi32, #mma1> - %38 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #mma> - %39 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #blocked> - %40 = arith.muli %35, %38 : tensor<128x1xi32, #mma> - %41 = arith.muli %36, %39 : tensor<128x1xi32, #blocked> - %42 = tt.splat %11 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %43 = tt.addptr %42, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %45 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %46 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %47 = tt.expand_dims %44 {axis = 0 : i32} : 
tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> - %48 = tt.expand_dims %45 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %49 = tt.expand_dims %46 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %50 = tt.broadcast %43 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %51 = tt.broadcast %47 : tensor<1x64xi32, #mma> -> tensor<128x64xi32, #mma> - %52 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> - %53 = tt.addptr %50, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %54 = tt.load %53 : tensor<128x64x!tt.ptr, #blocked> - %55 = tt.splat %12 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %56 = tt.addptr %55, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %57 = tt.broadcast %56 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %58 = tt.addptr %57, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %59 = tt.load %58 : tensor<128x64x!tt.ptr, #blocked> - %60 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %61 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %62 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %63 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %64 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %65 = arith.addi %63, %60 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %66 = arith.addi %64, %62 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %67 = tt.expand_dims %65 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16xi32, #blocked2> - %68 = tt.splat %arg12 : i32 -> tensor<1x16xi32, #blocked2> - %69 = arith.muli %67, %68 : tensor<1x16xi32, #blocked2> - %70 = tt.splat %10 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> - %71 = tt.addptr %70, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> - %72 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %73 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %74 = tt.expand_dims %72 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1xi32, #blocked2> - %75 = tt.expand_dims %73 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi32, #blocked3> - %76 = tt.broadcast %71 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> - %77 = tt.broadcast %74 : tensor<64x1xi32, #blocked2> -> tensor<64x16xi32, #blocked2> - %78 = tt.addptr %76, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %79 = tt.expand_dims %66 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> - %80 = tt.splat %arg12 : i32 -> tensor<16x1xi32, #blocked1> - %81 = arith.muli %79, %80 : tensor<16x1xi32, #blocked1> - %82 = tt.splat %13 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked1> - %83 = tt.addptr %82, %81 : tensor<16x1x!tt.ptr, #blocked1>, tensor<16x1xi32, #blocked1> 
- %84 = tt.broadcast %83 : tensor<16x1x!tt.ptr, #blocked1> -> tensor<16x64x!tt.ptr, #blocked1> - %85 = tt.broadcast %49 : tensor<1x64xi32, #blocked1> -> tensor<16x64xi32, #blocked1> - %86 = tt.addptr %84, %85 : tensor<16x64x!tt.ptr, #blocked1>, tensor<16x64xi32, #blocked1> - %87 = tt.splat %17 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %88 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> - %89 = tt.splat %18 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %90 = arith.muli %arg12, %c16_i32 : i32 - %91 = tt.splat %90 : i32 -> tensor<64x16xi32, #blocked2> - %92 = tt.splat %90 : i32 -> tensor<16x64xi32, #blocked1> - %93:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %cst_1, %arg18 = %19, %arg19 = %78, %arg20 = %86) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<16x64x!tt.ptr, #blocked1>) : i32 { - %206 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> - %207 = tt.splat %arg18 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %208 = arith.addi %207, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %209 = tt.addptr %87, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %210 = tt.load %209 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %213 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %215 = tt.dot %212, %214, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> - %217 = tt.broadcast %216 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> - %218 = arith.subf %215, %217 : tensor<128x16xf32, #mma1> - %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> - %220 = tt.expand_dims %208 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1> - %221 = tt.broadcast %220 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> - %222 = arith.cmpi sge, %221, %88 : tensor<128x16xi32, #mma1> - %223 = arith.select %222, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> - %224 = tt.load %arg20 : tensor<16x64x!tt.ptr, #blocked1> - %225 = arith.truncf %223 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %226 = triton_gpu.local_alloc %225 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> - %227 = triton_gpu.local_load %226 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %228 = triton_gpu.local_alloc %224 : 
(tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> - %229 = triton_gpu.local_load %228 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %230 = tt.dot %227, %229, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %231 = tt.addptr %89, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %232 = tt.load %231 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %233 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %234 = tt.trans %233 {order = array} : !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %235 = triton_gpu.local_load %234 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %236 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %237 = triton_gpu.local_load %236 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %238 = tt.dot %237, %235, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %239 = tt.expand_dims %232 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> - %240 = tt.broadcast %239 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> - %241 = arith.subf %238, %240 : tensor<128x16xf32, #mma1> - %242 = arith.mulf %223, %241 : tensor<128x16xf32, #mma1> - %243 = arith.truncf %242 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %244 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> - %245 = tt.trans %244 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> - %246 = triton_gpu.local_load %245 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %247 = triton_gpu.local_alloc %243 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> - %248 = triton_gpu.local_load %247 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %249 = tt.dot %248, %246, %arg17 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %250 = arith.addi %arg18, %c16_i32 : i32 - %251 = tt.addptr %arg19, %91 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %252 = tt.addptr %arg20, %92 : tensor<16x64x!tt.ptr, #blocked1>, tensor<16x64xi32, #blocked1> - scf.yield %230, %249, %250, %251, %252 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, 
#blocked2>, tensor<16x64x!tt.ptr, #blocked1> - } - %94 = arith.addi %19, %c128_i32 : i32 - %95 = arith.subi %arg14, %94 : i32 - %96 = arith.divsi %95, %c32_i32 : i32 - %97 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %98 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %99 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %100 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %101 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %102 = arith.addi %100, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %103 = arith.addi %101, %99 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %104 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> - %105 = tt.splat %arg12 : i32 -> tensor<1x32xi32, #blocked3> - %106 = arith.muli %104, %105 : tensor<1x32xi32, #blocked3> - %107 = tt.splat %10 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> - %108 = tt.addptr %107, %106 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> - %109 = tt.broadcast %108 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> - %110 = tt.broadcast %75 : tensor<64x1xi32, #blocked3> -> tensor<64x32xi32, #blocked3> - %111 = tt.addptr %109, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %112 = tt.expand_dims %103 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> - %113 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked> - %114 = arith.muli %112, %113 : tensor<32x1xi32, #blocked> - %115 = tt.splat %13 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> - %116 = tt.addptr %115, %114 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> - %117 = tt.broadcast %116 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x64x!tt.ptr, #blocked> - %118 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> - %119 = tt.addptr %117, %118 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> - %120 = tt.splat %17 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %121 = tt.splat %18 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %122 = arith.muli %arg12, %c32_i32 : i32 - %123 = tt.splat %122 : i32 -> tensor<64x32xi32, #blocked3> - %124 = tt.splat %122 : i32 -> tensor<32x64xi32, #blocked> - %125:5 = scf.for %arg15 = %c0_i32 to %96 step %c1_i32 iter_args(%arg16 = %93#0, %arg17 = %93#1, %arg18 = %94, %arg19 = %111, %arg20 = %119) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked>) : i32 { - %206 = tt.load %arg19 : tensor<64x32x!tt.ptr, #blocked3> - %207 = tt.splat %arg18 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %208 = arith.addi %207, %98 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %209 = tt.addptr %120, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %210 = tt.load %209 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, 
#triton_gpu.shared_memory> - %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %213 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %215 = tt.dot %212, %214, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> - %217 = tt.broadcast %216 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> - %218 = arith.subf %215, %217 : tensor<128x32xf32, #mma> - %219 = math.exp2 %218 : tensor<128x32xf32, #mma> - %220 = tt.load %arg20 : tensor<32x64x!tt.ptr, #blocked> - %221 = arith.truncf %219 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> - %222 = triton_gpu.convert_layout %221 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %223 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> - %224 = triton_gpu.local_load %223 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %225 = tt.dot %222, %224, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %226 = tt.addptr %121, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %227 = tt.load %226 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %228 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> - %229 = tt.trans %228 {order = array} : !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> - %230 = triton_gpu.local_load %229 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %231 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %232 = triton_gpu.local_load %231 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %233 = tt.dot %232, %230, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %234 = tt.expand_dims %227 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> - %235 = tt.broadcast %234 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> - %236 = arith.subf %233, %235 : tensor<128x32xf32, #mma> - %237 = arith.mulf %219, %236 : tensor<128x32xf32, #mma> - %238 = arith.truncf %237 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> - 
%239 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> - %240 = tt.trans %239 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> - %241 = triton_gpu.local_load %240 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %242 = triton_gpu.convert_layout %238 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %243 = tt.dot %242, %241, %arg17 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %244 = arith.addi %arg18, %c32_i32 : i32 - %245 = tt.addptr %arg19, %123 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %246 = tt.addptr %arg20, %124 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> - scf.yield %225, %243, %244, %245, %246 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked> - } - %126 = tt.splat %16 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> - %127 = tt.addptr %126, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> - %128 = tt.broadcast %127 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> - %129 = tt.addptr %128, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> - %130 = arith.truncf %125#0 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - tt.store %129, %130 : tensor<128x64x!tt.ptr, #mma> - %131 = tt.splat %arg3 : f32 -> tensor<128x64xf32, #mma> - %132 = arith.mulf %125#1, %131 : tensor<128x64xf32, #mma> - %133 = tt.splat %15 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> - %134 = tt.addptr %133, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> - %135 = tt.broadcast %134 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> - %136 = tt.addptr %135, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> - %137 = arith.truncf %132 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - tt.store %136, %137 : tensor<128x64x!tt.ptr, #mma> - %138 = tt.splat %10 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %139 = tt.addptr %138, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %140 = tt.broadcast %139 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %141 = tt.addptr %140, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %142 = tt.load %141 : tensor<128x64x!tt.ptr, #blocked> - %143 = tt.splat %13 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %144 = tt.addptr %143, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %145 = tt.broadcast %144 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %146 = tt.addptr %145, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %147 = tt.load %146 : tensor<128x64x!tt.ptr, #blocked> - %148 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %149 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %150 = tt.addptr %148, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %151 = tt.addptr %149, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = 
#mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %152 = tt.load %150 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %153 = tt.load %151 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %154 = tt.expand_dims %152 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> - %155 = tt.expand_dims %153 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> - %156 = tt.splat %11 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> - %157 = tt.addptr %156, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> - %158 = tt.broadcast %157 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> - %159 = tt.addptr %158, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %160 = tt.splat %12 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> - %161 = tt.addptr %160, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> - %162 = tt.broadcast %161 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> - %163 = tt.addptr %162, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %164 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %165 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %166 = tt.addptr %164, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %167 = tt.addptr %165, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %168 = tt.load %166 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %169 = tt.load %167 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %170 = tt.broadcast %154 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> - %171 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> - %172 = tt.expand_dims %168 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> - %173 = tt.broadcast %172 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> - %174 = arith.muli %arg12, %c16_i32 : i32 - %175 = tt.splat %174 : i32 -> tensor<64x16xi32, #blocked2> - %176 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - %177:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %19, %arg18 = %159, %arg19 = %163, %arg20 = %c-1_i32) -> (tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32) : i32 { - %206 = arith.addi %arg20, %c1_i32 : i32 - %207 = arith.cmpi slt, %206, %c1_i32 : i32 - %208 = arith.select %207, %206, %c0_i32 : i32 - %209 = tt.load %arg18 : tensor<64x16x!tt.ptr, #blocked2> - %210 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> - %211 = triton_gpu.memdesc_subview %176[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %210, %211 : tensor<64x16xf16, #blocked2> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - %212 = triton_gpu.local_load %211 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, 
parent = #mma1, kWidth = 4}>> - %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %215 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %216 = triton_gpu.local_load %215 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %217 = tt.dot %214, %216, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %218 = arith.subf %217, %170 : tensor<128x16xf32, #mma1> - %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> - %220 = tt.splat %arg17 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %221 = arith.addi %220, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %222 = tt.expand_dims %221 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1> - %223 = tt.broadcast %222 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> - %224 = arith.cmpi sge, %171, %223 : tensor<128x16xi32, #mma1> - %225 = arith.select %224, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> - %226 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %227 = triton_gpu.local_load %226 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %228 = tt.dot %227, %212, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %229 = arith.subf %228, %173 : tensor<128x16xf32, #mma1> - %230 = arith.mulf %225, %229 : tensor<128x16xf32, #mma1> - %231 = arith.truncf %230 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %232 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> - %233 = tt.trans %232 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> - %234 = triton_gpu.local_load %233 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %235 = triton_gpu.local_alloc %231 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> - %236 = triton_gpu.local_load %235 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %237 = tt.dot %236, %234, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %238 = arith.addi %arg17, %c16_i32 : i32 - %239 = tt.addptr %arg18, %175 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %240 = tt.addptr %arg19, %175 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - 
scf.yield %237, %238, %239, %240, %208 : tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32 - } - triton_gpu.local_dealloc %176 : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - %178 = arith.divsi %19, %c32_i32 : i32 - %179 = arith.muli %178, %c32_i32 : i32 - %180 = arith.subi %19, %179 : i32 - %181 = tt.splat %180 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %182 = arith.addi %181, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %183 = tt.expand_dims %182 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> - %184 = arith.muli %183, %105 : tensor<1x32xi32, #blocked3> - %185 = tt.splat %11 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> - %186 = tt.addptr %185, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> - %187 = tt.broadcast %186 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> - %188 = tt.addptr %187, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %189 = tt.splat %12 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> - %190 = tt.addptr %189, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> - %191 = tt.broadcast %190 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> - %192 = tt.addptr %191, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %193 = tt.broadcast %155 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> - %194 = tt.expand_dims %169 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> - %195 = tt.broadcast %194 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> - %196 = arith.muli %arg12, %c32_i32 : i32 - %197 = tt.splat %196 : i32 -> tensor<64x32xi32, #blocked3> - %198 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - %199:4 = scf.for %arg15 = %c0_i32 to %178 step %c1_i32 iter_args(%arg16 = %177#0, %arg17 = %188, %arg18 = %192, %arg19 = %c-1_i32) -> (tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32) : i32 { - %206 = arith.addi %arg19, %c1_i32 : i32 - %207 = arith.cmpi slt, %206, %c1_i32 : i32 - %208 = arith.select %207, %206, %c0_i32 : i32 - %209 = tt.load %arg17 : tensor<64x32x!tt.ptr, #blocked3> - %210 = tt.load %arg18 : tensor<64x32x!tt.ptr, #blocked3> - %211 = triton_gpu.memdesc_subview %198[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %210, %211 : tensor<64x32xf16, #blocked3> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - %212 = triton_gpu.local_load %211 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %215 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> - %216 = triton_gpu.local_load %215 : !tt.memdesc<64x32xf16, #shared1, 
#triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %217 = tt.dot %214, %216, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %218 = arith.subf %217, %193 : tensor<128x32xf32, #mma> - %219 = math.exp2 %218 : tensor<128x32xf32, #mma> - %220 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %221 = triton_gpu.local_load %220 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %222 = tt.dot %221, %212, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %223 = arith.subf %222, %195 : tensor<128x32xf32, #mma> - %224 = arith.mulf %219, %223 : tensor<128x32xf32, #mma> - %225 = arith.truncf %224 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> - %226 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> - %227 = tt.trans %226 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> - %228 = triton_gpu.local_load %227 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %229 = triton_gpu.convert_layout %225 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %230 = tt.dot %229, %228, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %231 = tt.addptr %arg17, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %232 = tt.addptr %arg18, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - scf.yield %230, %231, %232, %208 : tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32 - } - triton_gpu.local_dealloc %198 : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - %200 = tt.splat %14 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> - %201 = tt.addptr %200, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> - %202 = tt.broadcast %201 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> - %203 = tt.addptr %202, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> - %204 = arith.mulf %199#0, %cst_2 : tensor<128x64xf32, #mma> - %205 = arith.truncf %204 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - tt.store %203, %205 : tensor<128x64x!tt.ptr, #mma> - tt.return - } -} - // ----- // CHECK-LABEL: sink_convert_dealloc diff --git a/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/test/TritonGPU/amd/amd-sched-2nd-load.mlir new file mode 100644 index 000000000000..5c173ffb4858 --- /dev/null +++ b/test/TritonGPU/amd/amd-sched-2nd-load.mlir @@ -0,0 +1,211 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s + +// Check the logic of sched-2nd-load optimizations +// + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> 
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> + +// Category 1: Single dot with two loads, we make sure the optimization is applied when tile size is large enough +// The following tile sizes should apply the optimization +// 256x256x128 +// 256x256x64 +// The following tile sizes should NOT apply the optimization +// 256x64x128 +// 256x256x32 +// + +// Should apply: tile size 256x256x128 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x128 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// Should apply: tile size 256x256x64 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x64 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, 
mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// Should NOT apply: tile size 256x64x128 with single dot +// CHECK-LABEL: sink_2nd_load_256x64x128 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x64x!tt.ptr, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x64xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x64x!tt.ptr, #mma> + tt.return + } +} + +// Should NOT apply: tile size 256x256x32 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x32 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes 
{"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr, #blocked>, %B_ptr: tensor<32x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + + +// Category 2: single dot with two loads and tile size is large enough (128x128x128). +// We make sure the move is legal. +// Should NOT apply: the 2nd load has a user before the dot +// CHECK-LABEL: sink_2nd_load_128x128x128_user_before_dot +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.store +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr2: tensor<128x128x!tt.ptr, #blocked>, %C_ptr: tensor<128x128x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1> + %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked> + tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr, #blocked> + %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> + triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<128x128xf32, #mma> + } + tt.store %C_ptr, %0#0: 
tensor<128x128x!tt.ptr, #mma> + tt.return + } +} + + +// ----- + +// Category 3: two dots in the for loop. Make sure the optimization is not applied +// should NOT apply: two dots +// CHECK-LABEL: sink_2nd_load_256x256x64_two_dot +// CHECK: triton_gpu.local_load +// CHECK-NEXT: triton_gpu.local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: tt.dot +// CHECK-NEXT: tt.load +// CHECK-NEXT: tt.load +// CHECK-NEXT: triton_gpu.local_store +// CHECK-NEXT: triton_gpu.local_store +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x64_two_dot(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %6 = tt.dot %1, %2, %3 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h index a7395f86dc50..6dbb0435e20c 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -30,6 +30,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "triton/Dialect/Triton/IR/Traits.h" + // clang-format off #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc" // clang-format on diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h 
b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h index a49e442d3984..9e174d545dd9 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h @@ -19,6 +19,17 @@ enum class ISAFamily { // Deduces the corresponding ISA family for the given target gfx |arch|. ISAFamily deduceISAFamily(llvm::StringRef arch); +// Here is a partial definition of DppCtrl enums. For the complete definition, +// please check: +// https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939 +enum class DppCtrl : uint32_t { + QUAD_PERM_FIRST = 0, + ROW_SHL0 = 0x100, + ROW_SHR0 = 0x110, + BCAST15 = 0x142, + BCAST31 = 0x143 +}; + } // namespace mlir::triton::AMD #endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETUTILS_H diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index b7ee4efc72d0..d3ffaed2e8fc 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -115,56 +115,6 @@ struct LocalLoadOpConversion } }; -struct ConvertLayoutOpConversion - : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern< - triton::gpu::ConvertLayoutOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = op.getSrc(); - Value dst = op.getResult(); - auto srcTy = cast(src.getType()); - auto dstTy = cast(dst.getType()); - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - - if (isa(srcLayout) && - isa(dstLayout)) { - return lowerMfmaToDotOperand(op, adaptor, rewriter); - } - return failure(); - } - -private: - LogicalResult - lowerMfmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - RankedTensorType srcTy = op.getSrc().getType(); - RankedTensorType dstTy = op.getType(); - if (isMfmaToDotShortcut(srcTy, dstTy)) { - // vecSize is an number of sequential elements stored by one thread - // - For MFMA encoding (encoding of the result tensor of dot - // operation) it is 4 - // - For MFMA operand encoding it is - // dotOperandEncoding::kWidth, - // which is 4 in certain cases (e.g. fp16 and bfloat16 dtypes with kpack - // = 1) - // - // For cases where these two values are equal MFMA and MFMA operand - // layouts are the same. 
- auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - Value view = - packLLElements(loc, getTypeConverter(), vals, rewriter, dstTy); - rewriter.replaceOp(op, view); - return success(); - } - return failure(); - } -}; } // namespace namespace mlir::triton::AMD { @@ -172,7 +122,6 @@ void populateConvertLayoutOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { - patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp index cece47227ea0..bce126ea4d72 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -38,8 +38,10 @@ struct DecomposeUnsupportedAMDConversions triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); - triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, - isMfmaToDotShortcut); + auto isShortcut = + mlir::triton::gpu::ShortcutFn(std::not_fn(cvtNeedsSharedMemory)); + + triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, isShortcut); /* -------------------------------- */ // Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op` diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index a45efd4a7971..343dc7b3f37f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -694,6 +694,32 @@ struct AtomicCASOpConversion } }; +bool supportsGlobalAtomicF16PackedAndDpp(triton::AMD::ISAFamily isaFamily) { + return isaFamily == triton::AMD::ISAFamily::CDNA1 || + isaFamily == triton::AMD::ISAFamily::CDNA2 || + isaFamily == triton::AMD::ISAFamily::CDNA3; +} + +Value generateI32DppMove(PatternRewriter &rewriter, Value val, int dppCtrl) { + assert(val.getType().isInteger(32)); + auto loc = val.getLoc(); + Value old = i32_val(0); + int rowMask = 0b1111; // enable all rows + int bankMask = 0b1111; // enable all banks + bool boundCtrl = false; + auto dppMovOp = rewriter.create( + loc, i32_ty, old, val, dppCtrl, rowMask, bankMask, boundCtrl); + return dppMovOp.getResult(); +} + +Value shiftLeftI32ByDpp(PatternRewriter &rewriter, Value val) { + return generateI32DppMove(rewriter, val, 0x101); // shift left 1 lane +} + +Value shiftRightI32ByDpp(PatternRewriter &rewriter, Value val) { + return generateI32DppMove(rewriter, val, 0x111); // shift right 1 lane +} + struct AtomicRMWOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { @@ -765,10 +791,36 @@ struct AtomicRMWOpConversion // vec = 1, numElements = 1 for scalar auto vec = getVectorSize(ptr); int numElems = 1; + Type packF16Ty = vec_ty(valueElemTy, 2); + + // In the case of unpaired f16 elements, utilize DPP instructions to + // accelerate atomics. Here is the algorithm for lowering + // tt::atomicRmwOp(%ptr, %val, %mask): + // 0. Group threads by pairs. The master thread is (tid % 2 == 0); + // 1. All the threads send %val to the (tid - 1) thread via dppUpdateOp shl, so + // all the masters receive the value from their secondary threads; + // 2. Take into account parity in the %mask value and build control flow + // structures according to it; + // 3. Generate llvm::atomicRmwOp in the threads enabled by the %mask value; + // 4. All the threads send the result of the generated operation to the (tid + 1) thread + // via dppUpdateOp shr, so all secondary threads also receive their + // result. + // + // This approach lets only half of the active threads commit atomic + // requests, which avoids generating code for unified access to f16 + // elements and reduces contention. + bool useDppForPackedF16 = false; // tensor if (tensorTy) { auto valTy = cast(val.getType()); - vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); + bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16(); + unsigned availableVecSize = isF16Ty ? 2 : 1; + vec = std::min(vec, availableVecSize); + // Force F16 packing in case the value is not coming in packed, but the + // ISA can support packed atomic instructions. + useDppForPackedF16 = + supportsGlobalAtomicF16PackedAndDpp(targetInfo.getISAFamily()) && + vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD; // mask numElems = tensorTy.getNumElements(); } @@ -776,20 +828,49 @@ auto tid = tid_val(); mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems))); + if (useDppForPackedF16) + mask = and_(mask, icmp_eq(urem(tid, i32_val(2)), i32_val(0))); auto memOrdering = op.getSem(); auto atomicMemOrdering = getMemoryOrdering(memOrdering); auto vecTy = vec_ty(valueElemTy, vec); auto retType = vec == 1 ? valueElemTy : vecTy; + retType = useDppForPackedF16 ? packF16Ty : retType; SmallVector resultVals(elemsPerThread); - const bool f16v2 = vec == 2 && valueElemTy.isF16(); for (size_t i = 0; i < elemsPerThread; i += vec) { Value rmwPtr = ptrElements[i]; // TODO: in case llMask is zero we can create only one branch for all // elemsPerThread. Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; + Value operand; + if (useDppForPackedF16) { + // Move %val to the left neighbour to proceed with the packed atomic. + Value packedVal = null(packF16Ty); + packedVal = + insert_element(packF16Ty, packedVal, valElements[i], i32_val(0)); + // Pack to i32 type to simplify the transaction + packedVal = bitcast(packedVal, i32_ty); + Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal); + // Unpack results back + Value unpackedDppRes = bitcast(dppMoveRes, packF16Ty); + operand = undef(packF16Ty); + operand = + insert_element(packF16Ty, operand, valElements[i], i32_val(0)); + operand = insert_element( + packF16Ty, operand, + extract_element(valueElemTy, unpackedDppRes, i32_val(0)), + i32_val(1)); + } else if (vec == 1) { + operand = valElements[i]; + } else { + operand = undef(vecTy); + for (size_t ii = 0; ii < vec; ++ii) + operand = + insert_element(vecTy, operand, valElements[i + ii], i32_val(ii)); + } + Value undefVal = undef(retType); // Build blocks to bypass the atomic instruction for ~rmwMask. auto *curBlock = rewriter.getInsertionBlock(); @@ -806,25 +887,11 @@ struct AtomicRMWOpConversion auto maybeKind = matchAtomicOp(atomicRmwAttr); // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient // atomics for MI-* series of AMD GPU. - Value atom = rewriter - .create( - loc, *maybeKind, rmwPtr, valElements[i], - atomicMemOrdering, StringRef("agent")) - .getResult(); - - // NV for the f16v2 case generates one packed instruction. We have to - // create two separate instructions since LLVM::AtomicRMWOp doesn't - // support this. Can be optimized out with rocdl.raw.buffer.atomic.
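As an aside, the dppCtrl magic numbers used by the helpers above are plain offsets from the base opcodes in the DppCtrl enum added to TargetUtils.h earlier in this patch. Below is a minimal standalone sketch of that encoding, assuming only the enum values shown there; the rowShl/rowShr helpers are illustrative names for this example, not Triton APIs.

// Illustrative sketch only; mirrors the partial DppCtrl enum from
// TargetUtils.h, everything else is assumed for the example.
#include <cstdint>
#include <iostream>

enum class DppCtrl : uint32_t {
  QUAD_PERM_FIRST = 0,
  ROW_SHL0 = 0x100,
  ROW_SHR0 = 0x110,
  BCAST15 = 0x142,
  BCAST31 = 0x143
};

// row_shl:N / row_shr:N are encoded as the base opcode plus the lane distance N.
constexpr uint32_t rowShl(uint32_t n) {
  return static_cast<uint32_t>(DppCtrl::ROW_SHL0) + n;
}
constexpr uint32_t rowShr(uint32_t n) {
  return static_cast<uint32_t>(DppCtrl::ROW_SHR0) + n;
}

int main() {
  static_assert(rowShl(1) == 0x101, "dppCtrl used by shiftLeftI32ByDpp");
  static_assert(rowShr(1) == 0x111, "dppCtrl used by shiftRightI32ByDpp");
  static_assert(rowShr(8) == 0x118, "first row_shr step of the warp reduction");
  std::cout << "DPP control encodings match the constants in this patch\n";
  return 0;
}

Under this encoding, 0x101 and 0x111 are row_shl:1 and row_shr:1, the one-lane left/right moves the packed f16 atomic lowering relies on; the warp reduction added further down composes the same ROW_SHR0 base with shift distances 8, 4, 2, and 1.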
- if (f16v2) { - Value atom2 = - rewriter - .create( - loc, *maybeKind, ptrElements[i + 1], valElements[i + 1], - atomicMemOrdering, StringRef("agent")) - .getResult(); - auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0)); - atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult(); - } + Value atom = + rewriter + .create(loc, *maybeKind, rmwPtr, operand, + atomicMemOrdering, StringRef("agent")) + .getResult(); if (!tensorTy) { if (atomicNeedsSharedMemory(op.getResult())) { Value atomPtr = @@ -837,10 +904,25 @@ struct AtomicRMWOpConversion rewriter.setInsertionPointToStart(endBlock); Value retVal = endBlock->getArgument(0); if (tensorTy) { - for (int ii = 0; ii < vec; ++ii) { - resultVals[i + ii] = - vec == 1 ? retVal - : extract_element(valueElemTy, retVal, i32_val(ii)); + if (useDppForPackedF16) { + // Return packed to i32 result after atomic operation back from master + // lane. + auto packedRet = bitcast(retVal, i32_ty); + Value dppMovRes = shiftRightI32ByDpp(rewriter, packedRet); + // Unpack results back + Value unpackedDppRes = bitcast(dppMovRes, packF16Ty); + retVal = insert_element( + packF16Ty, retVal, + extract_element(valueElemTy, unpackedDppRes, i32_val(1)), + i32_val(1)); + resultVals[i] = + extract_element(valueElemTy, retVal, urem(tid, i32_val(2))); + } else { + for (int ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = + vec == 1 ? retVal + : extract_element(valueElemTy, retVal, i32_val(ii)); + } } } else { if (!atomicNeedsSharedMemory(op.getResult())) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 3a40d73c2a7c..525361fee603 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -5,6 +5,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +using mlir::triton::AMD::DppCtrl; namespace mlir::triton::AMD { namespace { @@ -103,22 +104,22 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleXor(loc, rewriter, val, i); + return LLVM::AMD::shuffleXor(loc, rewriter, val, i, getISAFamily()); } Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleUp(loc, rewriter, val, i); + return LLVM::AMD::shuffleUp(loc, rewriter, val, i, getISAFamily()); } Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); + return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily()); } Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, Value i) const { - return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); + return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily()); } Value TargetInfo::programId(RewriterBase &rewriter, Location loc, @@ -126,11 +127,184 @@ Value TargetInfo::programId(RewriterBase &rewriter, Location loc, return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis); } +// Cast and sext values into specific-length int to meet the requirements of +// instructions like UpdateDpp or readlane if necessary. 
+static inline Type castToAndSExtInt(RewriterBase &rewriter, Location loc, + Value &val, Type fromType, + unsigned toBits) { + unsigned originalBits = fromType.getIntOrFloatBitWidth(); + Type toType = fromType; + + if (!fromType.isIntOrIndex()) { + val = bitcast(val, int_ty(originalBits)); + toType = int_ty(originalBits); + } + + if (originalBits < toBits) { + val = sext(int_ty(toBits), val); + toType = int_ty(toBits); + } + + return toType; +} + +// Trunc the value to a specific length and then cast it to the given type if +// necessary. This function is typically used in conjunction with +// castToAndSExtInt. +static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc, + Value val, Type valType, + unsigned fromBits) { + unsigned originalBits = valType.getIntOrFloatBitWidth(); + Value toVal = val; + + if (originalBits < fromBits) { + toVal = trunc(int_ty(originalBits), toVal); + } + + if (!valType.isIntOrIndex()) { + toVal = bitcast(toVal, valType); + } + + return toVal; +} + bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const { - return false; + if (numLaneToReduce != 64) + return false; + + if (auto family = getISAFamily(); + family != ISAFamily::CDNA3 && family != ISAFamily::CDNA2) { + return false; + } + + Operation *reduxOp = op.getSingleCombiner(); + if (!reduxOp) + return false; + + auto createDppReduxOpWithBoundCtrl = [&](Type valType, Value &src, + uint32_t dppCtrl, int rowMask, + int bankMask) -> Value { + // DPP has limited support for data types, so here we need to + // cast non-integer types or integer types shorter than 32 bits + // to int32, except for fp32. + Type actualType = valType; + if (!valType.isF32()) { + actualType = castToAndSExtInt(rewriter, loc, src, valType, 32); + } + + Value dppResult = + rewriter + .create(loc, actualType, src, src, + rewriter.getI32IntegerAttr(dppCtrl), + rewriter.getI32IntegerAttr(rowMask), + rewriter.getI32IntegerAttr(bankMask), + rewriter.getBoolAttr(true)) + .getRes(); + + if (!valType.isF32()) { + src = truncAndCastFromInt(rewriter, loc, src, valType, 32); + dppResult = truncAndCastFromInt(rewriter, loc, dppResult, valType, 32); + } + + IRMapping mapping; + mapping.map(reduxOp->getOperand(0), src); + mapping.map(reduxOp->getOperand(1), dppResult); + return rewriter.clone(*reduxOp, mapping)->getResult(0); + }; + + for (int i = 0; i < acc.size(); i++) { + Value buf; + auto valType = acc[i].getType(); + + /* + Here's the implementation of full-wavefront reduction using dpp. + https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/ + + Each step has a v_mov_dpp instruction following the redux op. In + some cases, the lower-level compiler could merge them into a single + instruction. For example, v_mov_dpp + max => v_max_dpp. + + For gfx9, we have 64 threads per warp. These 64 threads are arranged + into 4 rows, with each row being 16 threads. Each row of 16 threads is + further arranged into 4 banks, with each bank being 4 threads. Overall it's + in a (row, bank, thread) structure. When shuffling, we use the row/bank + masks to indicate which rows/banks participate. Modifiers like row_shr and + row_bcast then specify the exact data movement scheme. In the following + instructions, taking row 0 as an example: + + Step 1: Right shift for 8 lanes. + lane 8-15 = redux(lane 0-7, lane 8-15) + + Step 2: Right shift for 4 lanes. + lane 12-15 = redux(lane 8-11, lane 12-15) + + Step 3: Right shift for 2 lanes. 
+ lane 14-15 = redux(lane 12-13, lane 14-15) + + Step 4: Right shift for 1 lane. + lane 15 = redux(lane 14, lane 15) + + Step 5: Broadcast lane 15 of each row to all the lanes of its next row. + lane 16-31 = redux(lane 15, lane 16-31) + + Step 6: Broadcast lane 31 to lane 32-63. + lane 32-63 = redux(lane 31, lane 32-63) + + Now the reduction result is stored in lane 63. + + Step 7: Read the reduction result from lane 63 and broadcast with + readlane. + */ + + const int allRows = 0xf; + const int allBanks = 0xf; + + const uint32_t dppCtrlRowShr = static_cast(DppCtrl::ROW_SHR0); + + // row_shr:8 + buf = createDppReduxOpWithBoundCtrl(valType, acc[i], 8 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:4 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 4 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:2 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 2 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:1 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr, + allRows, allBanks); + + // row_bcast:15 row_mask:0xa + buf = createDppReduxOpWithBoundCtrl( + valType, buf, static_cast(DppCtrl::BCAST15), 0xa, allBanks); + + // row_bcast:31 + buf = createDppReduxOpWithBoundCtrl(valType, buf, + static_cast(DppCtrl::BCAST31), + allRows, allBanks); + + // Similarly, we need to cast data types for readlane instruction. + Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 16); + + // Get reduction result from lane 63 + std::string intrinsic = "llvm.amdgcn.readlane"; + Value result = + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, actualType, + ValueRange{buf, i32_val(63)}) + ->getResult(0); + + result = truncAndCastFromInt(rewriter, loc, result, valType, 16); + + acc[i] = result; + } + + return true; } void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 542b1ecbb7fb..0bd401f1993a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -8,6 +8,8 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +using mlir::triton::AMD::DppCtrl; +using mlir::triton::AMD::ISAFamily; using mlir::triton::gpu::appendOrGetExternFuncOp; using mlir::triton::gpu::getFunctionType; @@ -71,8 +73,9 @@ Type castToVectorType(Type ty) { } // namespace namespace mlir::LLVM::AMD { -static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, - Value i, int strideInt, ShflKind mode, Value clamp) { +static Value shuffleCommon(Location loc, RewriterBase &rewriter, + ISAFamily isaFamily, Value val, Value i, + int strideInt, ShflKind mode, Value clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); // On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on @@ -84,7 +87,8 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, if (bits < 32) val = sext(i32_ty, val); - val = shuffleCommon(loc, rewriter, val, i, strideInt, mode, clamp); + val = + shuffleCommon(loc, rewriter, isaFamily, val, i, strideInt, mode, clamp); if (bits < 32) val = trunc(int_ty(bits), val); @@ -98,8 +102,10 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, Value vec = bitcast(val, vecTy); Value val0 = extract_element(f32_ty, vec, i32_val(0)); Value val1 = extract_element(f32_ty, vec, i32_val(1)); - val0 = shuffleCommon(loc, rewriter, val0, i, strideInt, mode, clamp); 
- val1 = shuffleCommon(loc, rewriter, val1, i, strideInt, mode, clamp); + val0 = shuffleCommon(loc, rewriter, isaFamily, val0, i, strideInt, mode, + clamp); + val1 = shuffleCommon(loc, rewriter, isaFamily, val1, i, strideInt, mode, + clamp); vec = undef(vecTy); vec = insert_element(vecTy, vec, val0, i32_val(0)); vec = insert_element(vecTy, vec, val1, i32_val(1)); @@ -134,13 +140,83 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, Value stride = i32_val(32); Value lineId = xor_(threadId, stride); return bpermute(lineId); - } else { - // This map facilates the butterfly shuffle pattern for a stride less - // than 16. The pattern stride is the key of the map. - DenseMap masks{ - {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}}; - Value offset = i32_val(masks[strideInt]); + } else if (strideInt == 16) { + Value offset = i32_val(0x401F); return rewriter.create(loc, valType, val, offset); + } else { + if (isaFamily != ISAFamily::CDNA2 && isaFamily != ISAFamily::CDNA3) { + // DPP is only supported for CDNA2 and CDNA3 right now, so we fall back + // to ds_swizzle for other archs. + // + // This map facilitates the butterfly shuffle pattern for a stride less + // than 16. The pattern stride is the key of the map. + DenseMap masks{ + {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}}; + Value offset = i32_val(masks[strideInt]); + return rewriter.create(loc, valType, val, offset); + } + + auto createDppOpWithoutBoundCtrl = [&](Value &old, Value &src, + uint32_t dppCtrl, uint32_t rowMask, + uint32_t bankMask) { + return rewriter.create( + loc, valType, old, src, rewriter.getI32IntegerAttr(dppCtrl), + rewriter.getI32IntegerAttr(rowMask), + rewriter.getI32IntegerAttr(bankMask), rewriter.getBoolAttr(false)); + }; + + const int allRows = 0xf; + const int allBanks = 0xf; + + switch (strideInt) { + case 1: { + // quad_perm: 1, 0, 3, 2 + uint32_t dppCtrl = static_cast(DppCtrl::QUAD_PERM_FIRST); + std::array mask = {1, 0, 3, 2}; + for (int i = 0; i < mask.size(); i++) { + dppCtrl |= mask[i] << (i * 2); + } + return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, + allBanks); + } + case 2: { + // quad_perm: 2, 3, 0, 1 + uint32_t dppCtrl = static_cast(DppCtrl::QUAD_PERM_FIRST); + std::array mask = {2, 3, 0, 1}; + for (int i = 0; i < mask.size(); i++) { + dppCtrl |= mask[i] << (i * 2); + } + return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, + allBanks); + } + case 4: { + // row_shr:4 bank_mask: 0xa + auto ret = createDppOpWithoutBoundCtrl( + val, val, 4 + static_cast(DppCtrl::ROW_SHR0), + allRows, 0xa) + .getRes(); + + // row_shl:4 bank_mask: 0x5 + return createDppOpWithoutBoundCtrl( + ret, val, 4 + static_cast(DppCtrl::ROW_SHL0), allRows, + 0x5); + } + case 8: { + // row_shr:8 bank_mask: 0xc + auto ret = createDppOpWithoutBoundCtrl( + val, val, 8 + static_cast(DppCtrl::ROW_SHR0), + allRows, 0xc) + .getRes(); + + // row_shl:8 bank_mask: 0x3 + return createDppOpWithoutBoundCtrl( + ret, val, 8 + static_cast(DppCtrl::ROW_SHL0), allRows, + 0x3); + } + default: + assert(false && + "bfly shfl with stride >= 16 should not be handled by dpp."); + } } break; case ShflKind::up: { @@ -158,22 +234,27 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, return Value(); } -Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { - return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::bfly, - i32_val(0x1f)); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, + ISAFamily 
isaFamily) { + return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i, + ShflKind::bfly, i32_val(0x1f)); } -Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { - return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::up, - i32_val(0x0)); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, + ISAFamily isaFamily) { + return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i, + ShflKind::up, i32_val(0x0)); } -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { - return shuffleIdx(loc, rewriter, val, i32_val(i)); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, + ISAFamily isaFamily) { + return shuffleIdx(loc, rewriter, val, i32_val(i), isaFamily); } -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { - return shuffleCommon(loc, rewriter, val, i, 0, ShflKind::idx, i32_val(0x1f)); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, + ISAFamily isaFamily) { + return shuffleCommon(loc, rewriter, isaFamily, val, i, 0, ShflKind::idx, + i32_val(0x1f)); } Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index 123234fd4824..d150531848e3 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -2,12 +2,14 @@ #define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_UTILITY_H #include "TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "TritonAMDGPUToLLVM/TargetUtils.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" + namespace mlir::LLVM::AMD { const char predicatedLoad[] = "__predicated_load"; @@ -19,10 +21,18 @@ const char predicatedStoreCG[] = "__predicated_store_CG"; const char predicatedStoreCS[] = "__predicated_store_CS"; const char predicatedStoreWT[] = "__predicated_store_WT"; -Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); -Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i); -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, int axis); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index 7da8083cfb92..c3a69a5f9a2a 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -8,6 +8,7 @@ add_triton_library(TritonAMDGPUTransforms MfmaGroup.cpp DEPENDS + TritonAMDGPUIR 
TritonAMDGPUTransformsIncGen TritonGPUIR ) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index e122f15fd901..9371c8b5f897 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -5,23 +5,30 @@ #include "mlir/IR/Verifier.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include - -#define GEN_PASS_CLASSES -#include "TritonAMDGPUTransforms/Passes.h" +#include "llvm/ADT/STLExtras.h" using namespace mlir; namespace ttg = mlir::triton::gpu; -namespace tt = mlir::triton; - -static bool isLocalLoadOrDotLayoutConversion(Operation *op) { - if (isa(op)) - return true; - if (auto cvt = dyn_cast(op)) - return isa(cvt.getType().getEncoding()); - return false; + +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +// Return true if the given moduleOp contains a pure matmul problem; i.e., +// single dot in the main loop. +static bool isPureMatmulProblem(ModuleOp moduleOp) { + bool isMatmul = true; + bool foundLoop = false; + moduleOp.walk([&](scf::ForOp forOp) -> void { + int counter = 0; + forOp.walk([&counter](triton::DotOp dotOp) { ++counter; }); + isMatmul = (isMatmul && (counter == 1)); + foundLoop = true; + }); + return foundLoop && isMatmul; } // Search through block to find earliest insertion point for move op. This can @@ -61,194 +68,311 @@ findEarlyInsertionPoint(Block *block, Operation *move) { return ipnt; } +// Return the first user in the same block of the given op. If the user is in a +// nested block then return the op owning the block. Return nullptr if not +// existing. +static Operation *getFirstUseInSameBlock(Operation *op) { + SmallVector usersInSameBlock; + for (auto user : op->getUsers()) { + if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) + usersInSameBlock.push_back(ancestor); + } + auto minOpIt = + llvm::min_element(usersInSameBlock, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + return minOpIt != usersInSameBlock.end() ? *minOpIt : nullptr; +} + // Check if the operation opInsideLoop is inside any scf::ForOp and // opOutsideLoop is not inside the same loop. -bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, - mlir::Operation *opOutsideLoop) { +static bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, + mlir::Operation *opOutsideLoop) { scf::ForOp parentForOp = opInsideLoop->getParentOfType(); return parentForOp && !parentForOp->isAncestor(opOutsideLoop); } -class TritonAMDGPUReorderInstructionsPass - : public TritonAMDGPUReorderInstructionsBase< - TritonAMDGPUReorderInstructionsPass> { -public: - TritonAMDGPUReorderInstructionsPass() = default; - - Operation *getFirstUse(Operation *op) { - std::vector users; - for (auto user : op->getUsers()) { - if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) - users.push_back(ancestor); - } - auto minOpIt = std::min_element(users.begin(), users.end(), - [](mlir::Operation *a, mlir::Operation *b) { - return a->isBeforeInBlock(b); - }); - return minOpIt != users.end() ? 
*minOpIt : nullptr; - } +//===----------------------------------------------------------------------===// +// Reorder mechanisms +//===----------------------------------------------------------------------===// - void runOnOperation() override { - ModuleOp m = getOperation(); +// Sink dot layout conversions into loops to decrease register pressure when +// possible. +static void sinkDotConversion(ModuleOp moduleOp) { + DenseMap opToMove; + moduleOp.walk([&](ttg::ConvertLayoutOp op) { + Attribute encoding = op.getType().getEncoding(); + if (!isa_and_nonnull(encoding)) + return; + if (!op->hasOneUse()) + return; + Operation *user = *op->getUsers().begin(); + if (user->getParentOfType() == + op->getParentOfType()) + return; + opToMove[op] = user; + }); - // Sink shared memory loads and layout conversions into loops to decrease - // register pressure when possible. - DenseMap opToMove; - m.walk([&](Operation *op) { - if (!isLocalLoadOrDotLayoutConversion(op)) - return; - if (!op->hasOneUse()) - return; - Operation *user = *op->getUsers().begin(); - if (user->getParentOfType() == - op->getParentOfType()) + for (auto &kv : opToMove) + kv.first->moveBefore(kv.second); +} + +// Adjust the placement of shared memory writes and reads to immediately follow +// the definition of their operands in case where shared memory write is in the +// loop but its operand is not. +// +// This is a heuristic driven by optimizing fused attention by hoisting Q tensor +// shared memory read/write operations outside of the loop, as Q is a loop +// invariant and can be loaded once before entering the loop. But it should be +// generally applicable. +// +// There are two possible patterns for this adjustment depending on whether the +// write to shared memory is performed using an optional `local_alloc` argument +// or a `local_store` instruction. +// +// 1) %1 = some_op ... (typically a load or an operation that scales the tensor +// after loading) +// %2 = local_alloc %1 +// %3 = local_load %2 +// +// 2) %1 = some_op ... +// %2 = local_alloc +// %3 = local_store %1, %2 +// %4 = local_load %2 +static void hoistLocalLoad(ModuleOp moduleOp) { + moduleOp.walk([&](ttg::LocalLoadOp localLoad) { + auto localAlloc = localLoad.getSrc().getDefiningOp(); + if (!localAlloc) + return; + + // Case when localAlloc has operands + if (localAlloc->getNumOperands() == 1) { + if (!localAlloc->hasOneUse()) return; - opToMove.insert({op, user}); - }); - for (auto &kv : opToMove) - kv.first->moveBefore(kv.second); - opToMove.clear(); - - // Adjust the placement of LDS writes and reads to immediately follow the - // definition of their operands in case where LDS write is in the - // loop but it's operand is not. This is a heuristic for optimizing fused - // attention by hoisting Q tensor LDS read/write operations outside of the - // loop, as Q is a loop invariant and can be loaded once before entering the - // loop. - // There are two possible patterns for this adjustment depending on - // whether the write to LDS is performed using an optional `local_alloc` - // argument or a `local_store` instruction. - // - // clang-format off - // - // 1) %1 = some_op ... (typically a load or an operation that scales the tensor after loading) - // %2 = local_alloc %1 - // %3 = local_load %2 - // - // 2) %1 = some_op ... 
- // %2 = local_alloc - // %3 = local_store %1, %2 - // %4 = local_load %2 - // - // clang-format on - m.walk([&](ttg::LocalLoadOp localLoad) { - auto localAlloc = localLoad.getSrc().getDefiningOp(); - if (!localAlloc) + + auto srcTensorOp = localAlloc.getSrc().getDefiningOp(); + // Check if localAlloc is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) return; - // Case when localAlloc has operands - if (localAlloc->getNumOperands() == 1) { - if (!localAlloc->hasOneUse()) - return; + localAlloc->moveAfter(srcTensorOp); + localLoad->moveAfter(localAlloc); + return; + } - auto srcTensorOp = localAlloc->getOperand(0).getDefiningOp(); - // Check if localAlloc is in the loop but it's src tensor defining op is - // outside of it. - if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) { - return; - } + // Case when localAlloc has no operands + assert(localAlloc->getNumOperands() < 1); + auto allocVal = localAlloc->getResult(0); - localAlloc->moveAfter(srcTensorOp); - localLoad->moveAfter(localAlloc); - return; - } + // Check if the localAlloc has exactly two uses (localStore and localLoad) + int numUses = std::distance(allocVal.use_begin(), allocVal.use_end()); + if (numUses != 2) + return; - // Case when localAlloc has no operands - assert(localAlloc->getNumOperands() < 1); - auto allocVal = localAlloc->getResult(0); + // localStore comes before localLoad in block. + Operation *localStore = getFirstUseInSameBlock(localAlloc); + if (!isa(localStore)) + return; - // Check if the localAlloc has exactly two uses (localStore and localLoad) - int numUses = std::distance(allocVal.use_begin(), allocVal.use_end()); - if (numUses != 2) - return; + auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); + // Check if localStore is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { + return; + } - // localStore comes before localLoad in block. - Operation *localStore = getFirstUse(localAlloc); - if (!isa(localStore)) - return; + localAlloc->moveAfter(srcTensorOp); + localStore->moveAfter(localAlloc); + localLoad->moveAfter(localStore); + }); +} - auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); - // Check if localStore is in the loop but it's src tensor defining op is - // outside of it. - if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { - return; - } +// Sink conversion after the last dealloc but before the first use in its block. +// This helps to avoid unnecessary shared memory allocation. +static void moveDownCoversion(ModuleOp moduleOp) { + SmallVector convertOps; + moduleOp.walk([&](ttg::ConvertLayoutOp op) { convertOps.push_back(op); }); - localAlloc->moveAfter(srcTensorOp); - localStore->moveAfter(localAlloc); - localLoad->moveAfter(localStore); - }); + for (auto op : convertOps) { + Operation *user = getFirstUseInSameBlock(op); + for (auto it = Block::iterator(op), ie = op->getBlock()->end(); + it != ie && &*it != user; ++it) + if (isa(&*it)) + op->moveAfter(&*it); + } +} - // Sink conversion after the last dealloc but before the first use ancestor - // in its block. This helps to avoid unnecessary shared memory allocation. - m.walk([&](triton::gpu::ConvertLayoutOp op) { - auto curr = mlir::Block::iterator(op); - for (; &*curr != getFirstUse(op); curr++) - if (isa(&*curr)) - op->moveAfter(&*curr); - }); +// Move transpositions just after their definition. 
+static void moveUpTranspose(ModuleOp moduleOp) { + SmallVector transOps; + moduleOp.walk([&](triton::TransOp op) { transOps.push_back(op); }); - // Move transpositions just after their definition. - m.walk([&](triton::TransOp op) { - if (Operation *argOp = op.getSrc().getDefiningOp()) - op->moveAfter(argOp); - }); + for (auto op : transOps) + if (Operation *argOp = op.getSrc().getDefiningOp()) + op->moveAfter(argOp); +} - SmallVector moveOps; - // Move global loads early to prefetch. This may increase register pressure - // but it enables issuing global loads early. - m.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); - // Move local_stores early if dependence distance greater than - // one iteration. - // Best perf on GEMM when these precede global loads. - m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); - - for (auto op : llvm::reverse(moveOps)) { - // Gather use-def chain in block. - Block *block = op->getBlock(); - bool leadsToLoad = false; - SetVector backwardSet; - - BackwardSliceOptions options; - options.omitBlockArguments = true; - options.inclusive = false; - options.filter = [&](Operation *defOp) -> bool { - Block *defBlock = defOp->getBlock(); - if (!block->findAncestorOpInBlock(*defOp)) - return false; - // Check for a `load` dependent path. - leadsToLoad |= isa(defOp); - // Only move ops residing in the same block. - return defBlock == block; - }; - mlir::getBackwardSlice(op, &backwardSet, options); - backwardSet.insert(op); - - // Don't move a local_store if its source is a load from - // the same iteration. - if (isa(op) && leadsToLoad) - continue; - - auto ipoint = findEarlyInsertionPoint(block, op); - // Remove ops that already precede the insertion point. This is done - // before moves happen to avoid `Operation::isBeforeInBlock` N^2 - // complexity. - - SmallVector dfg = backwardSet.takeVector(); - if (ipoint != block->end()) { - // Move ops to insertion point. - llvm::erase_if( - dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); }); - for (auto *dfgop : llvm::reverse(dfg)) - dfgop->moveAfter(block, ipoint); - } else { - // Move ops to block begin. - for (auto *dfgop : llvm::reverse(dfg)) - dfgop->moveBefore(block, block->begin()); - } +// Schedule global load and local store ops for better GEMM performance. +static void scheduleGlobalLoadLocalStore(ModuleOp m) { + SmallVector moveOps; + // Move global loads early to prefetch. This may increase register pressure + // but it enables issuing global loads early. + m.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); + // Move local_stores early if dependence distance greater than one iteration. + // Best perf on GEMM when these precede global loads. + m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); + + for (auto op : llvm::reverse(moveOps)) { + // Gather use-def chain in block. + Block *block = op->getBlock(); + bool leadsToLoad = false; + SetVector backwardSet; + + BackwardSliceOptions options; + options.omitBlockArguments = true; + options.inclusive = false; + options.filter = [&](Operation *defOp) -> bool { + Block *defBlock = defOp->getBlock(); + if (!block->findAncestorOpInBlock(*defOp)) + return false; + // Check for a `load` dependent path. + leadsToLoad |= isa(defOp); + // Only move ops residing in the same block. + return defBlock == block; + }; + mlir::getBackwardSlice(op, &backwardSet, options); + backwardSet.insert(op); + + // Don't move a local_store if its source is a load from + // the same iteration. 
+ if (isa(op) && leadsToLoad) + continue; + + auto ipoint = findEarlyInsertionPoint(block, op); + // Remove ops that already precede the insertion point. This is done + // before moves happen to avoid `Operation::isBeforeInBlock` N^2 + // complexity. + + SmallVector dfg = backwardSet.takeVector(); + if (ipoint != block->end()) { + // Move ops to insertion point. + llvm::erase_if( + dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); }); + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveAfter(block, ipoint); + } else { + // Move ops to block begin. + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveBefore(block, block->begin()); + } + } +} + +/** + * Sched-load optimization for matmul kernels with large tile sizes + * The basic idea of sched-load optimization is to sink the 2nd tt.load + * after local_load so that global_load instructions can be interleaved with + * mfma's. This can help hide the issue latency of global_load instructions + * and improve performance on MI300X. + * + * It's assumed that the IR before this optimization has the following + * structure: + * ```mlir + * scf.for .. + * { + * tileA = tt.load a_ptr + * tileB = tt.load b_ptr + * opA = local_load bufferA + * opB = local_load bufferB + * res = tt.dot opA, opB + * local_store tileA, bufferA + * local_store tileB, bufferB + * } + * ``` + * After this optimization, the IR is transformed to + * ```mlir + * scf.for .. + * { + * tileA = tt.load a_ptr + * opA = local_load bufferA + * opB = local_load bufferB + * tileB = tt.load b_ptr <-- 2nd tt.load is sunk here + * res = tt.dot opA, opB + * local_store tileA, bufferA + * local_store tileB, bufferB + * } + * ``` + * For now, we don't have a perfect heuristic for when this optimization + * should be applied. Therefore, we implement a simple heuristic that applies + * it when the tile sizes of A and B are large enough, i.e. + * nonKDim >= 128 and kDim >= 64. Also, this is only applied to typical + * matmul kernels, i.e. only two tt.load's and one dotOp inside the loop. We + * are experimenting with how to better control instruction scheduling and + * enable such optimizations. + */ +static void sinkSecondLoad(ModuleOp m) { + m.walk([&](scf::ForOp forOp) -> void { + SetVector loadOps; + triton::DotOp dotOp; + for (Operation &op : forOp) { + if (auto loadOp = dyn_cast(&op)) + loadOps.insert(loadOp); + if (auto curOp = dyn_cast(&op)) + dotOp = curOp; + } + // Only apply the optimization when there are 2 loads in the loop + if (loadOps.size() != 2) + return; + // Only apply the optimization when the tile size is large enough + // 1. nonKDim >= 128 + // 2. kDim >= 64 + auto ldAOp = loadOps[0]; + auto tileAShape = cast(ldAOp.getType()).getShape(); + auto ldBOp = loadOps[1]; + auto tileBShape = cast(ldBOp.getType()).getShape(); + if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128)) + return; + // Only apply the optimization when the move is legal + // 1. Make sure the 2nd loadOp is before the dot + // 2. Make sure the first user of the 2nd loadOp is after the dot. 
+ bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp); + auto firstUser = *ldBOp.getResult().getUsers().begin(); + bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser); + if (isBeforeDotOp && firstUserAfterDotOp) + // move ldBOp right before tt.dot + ldBOp->moveBefore(dotOp); + }); +} + +//===----------------------------------------------------------------------===// +// Pass definition +//===----------------------------------------------------------------------===// + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +namespace { +struct TritonAMDGPUReorderInstructionsPass + : public TritonAMDGPUReorderInstructionsBase< + TritonAMDGPUReorderInstructionsPass> { + void runOnOperation() override { + ModuleOp m = getOperation(); + + hoistLocalLoad(m); + + sinkDotConversion(m); + moveDownCoversion(m); + + moveUpTranspose(m); + + if (isPureMatmulProblem(m)) { + scheduleGlobalLoadLocalStore(m); + sinkSecondLoad(m); } } }; +} // namespace std::unique_ptr mlir::createTritonAMDGPUReorderInstructionsPass() { return std::make_unique(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 71fd3c0cd4e7..96289bbb2e47 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -631,46 +631,6 @@ struct ConvertLayoutOpConversion convertMMAV3To8BitsDotOperand(op, adaptor, rewriter); return success(); } - - if (isMmaToDotShortcut(srcTy, dstTy)) { - // get source values - auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - unsigned elems = getTotalElemsPerThread(srcTy); - Type elemTy = - this->getTypeConverter()->convertType(srcTy.getElementType()); - // for the destination type, we need to pack values together - // so they can be consumed by tensor core operations - SmallVector vecVals; - // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer - // instructions to pack & unpack sub-word integers. 
A workaround is to - // store the results of ldmatrix in i32 - auto elemSize = elemTy.getIntOrFloatBitWidth(); - if (auto intTy = dyn_cast(elemTy) && elemSize <= 16) { - auto fold = 32 / elemSize; - for (unsigned i = 0; i < elems; i += fold) { - Value val = i32_val(0); - for (unsigned j = 0; j < fold; j++) { - auto ext = - shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j)); - val = or_(i32_ty, val, ext); - } - vecVals.push_back(bitcast(val, i32_ty)); - } - } else { - unsigned vecSize = std::max(32 / elemSize, 1); - Type vecTy = vec_ty(elemTy, vecSize); - for (unsigned i = 0; i < elems; i += vecSize) { - Value packed = rewriter.create(loc, vecTy); - for (unsigned j = 0; j < vecSize; j++) - packed = insert_element(vecTy, packed, vals[i + j], i32_val(j)); - vecVals.push_back(bitcast(packed, i32_ty)); - } - } - Value view = - packLLElements(loc, getTypeConverter(), vecVals, rewriter, dstTy); - rewriter.replaceOp(op, view); - return success(); - } return failure(); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index cf0ddc248dd1..36b14e270b27 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -70,10 +70,18 @@ struct DecomposeUnsupportedConversions : public mlir::triton::impl::DecomposeUnsupportedNVIDIAConversionsBase< DecomposeUnsupportedConversions> { void runOnOperation() override { + // FIXME [Dot LL] + // Remove the decomposeTensorCoreToDotLayoutConversion class entirely after + // we have enabled the new layout conversion for all the cases. + auto nvidiaShortCutFn = [&](RankedTensorType srcTy, + RankedTensorType dstTy) { + return matchMmaV3AndDotOperandLayout(srcTy, dstTy) || + cvtReordersRegisters(srcTy, dstTy); + }; ModuleOp mod = getOperation(); triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, - isMmaToDotShortcut); + nvidiaShortCutFn); triton::gpu::decomposeBlockedToDotLayoutConversion(mod); mlir::RewritePatternSet patterns(&getContext()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 75f9354104b1..d1cef15a354e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -93,20 +93,12 @@ static std::optional matchReduxKind(triton::ReduceOp op, int computeCapability) { if (computeCapability < 80) return std::nullopt; - if (op.getNumOperands() != 1 || op.getNumResults() != 1) - return std::nullopt; - Block *block = &(*op.getCombineOp().begin()); - Operation *yield = block->getTerminator(); - Operation *reduceOp = yield->getOperand(0).getDefiningOp(); - if (!reduceOp || reduceOp->getNumOperands() != 2 || - reduceOp->getNumResults() != 1) + Operation *reduceOp = op.getSingleCombiner(); + if (!reduceOp) return std::nullopt; auto intType = dyn_cast(reduceOp->getResultTypes()[0]); if (!intType || intType.getWidth() > 32) return std::nullopt; - if (reduceOp->getOperand(0) != block->getArgument(0) || - reduceOp->getOperand(1) != block->getArgument(1)) - return std::nullopt; if (isa(reduceOp)) return NVVM::ReduxKind::ADD; if (isa(reduceOp)) diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index 
fd65233e5c6b..d4c15bbad03f 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -41,9 +41,9 @@ class LinearLayoutConversionsTest : public ::testing::Test { CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape); } - DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef warps, - ArrayRef order) { - auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order); + DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, + ArrayRef warps) { + auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, {1, 0}); return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth); } @@ -301,6 +301,19 @@ TEST_F(LinearLayoutConversionsTest, Blocked4D) { {S("dim0"), S("dim1"), S("dim2"), S("dim3")})); } +TEST_F(LinearLayoutConversionsTest, MMAv2_16x16) { + EXPECT_EQ(toLinearLayout({16, 16}, + mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, MMAv2_32x32) { EXPECT_EQ(toLinearLayout({32, 32}, mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), @@ -502,7 +515,7 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) { } TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { - EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})), + EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1})), LinearLayout( { {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, @@ -511,7 +524,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})), + EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1})), LinearLayout( { {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, @@ -524,7 +537,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { EXPECT_EQ( - toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})), + toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1})), LinearLayout( { {S("register"), @@ -534,7 +547,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})), + EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1})), LinearLayout( { {S("register"), @@ -554,7 +567,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); - EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})), + EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1})), LinearLayout( { {S("register"), diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp index f006447002ef..897172fd6d34 100644 --- a/unittest/Tools/LinearLayoutTest.cpp +++ b/unittest/Tools/LinearLayoutTest.cpp @@ -747,6 +747,39 @@ TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) { ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value()); } +TEST_F(LinearLayoutTest, Resize) { + auto init = LinearLayout( + { + {S("in0"), {{0, 1}, {0, 2}}}, + {S("in1"), {{1, 0}, {2, 0}}}, + {S("in2"), {}}, + }, + {S("dim0"), S("dim1")}); + EXPECT_EQ(init.resize(S("in0"), 8), + LinearLayout( + { + {S("in0"), {{0, 1}, {0, 2}, {0, 0}}}, + {S("in1"), {{1, 0}, {2, 0}}}, + 
{S("in2"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(init.resize(S("in0"), 4), LinearLayout( + { + {S("in0"), {{0, 1}, {0, 2}}}, + {S("in1"), {{1, 0}, {2, 0}}}, + {S("in2"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(init.resize(S("in1"), 8), + LinearLayout( + { + {S("in0"), {{0, 1}, {0, 2}}}, + {S("in1"), {{1, 0}, {2, 0}, {0, 0}}}, + {S("in2"), {}}, + }, + {S("dim0"), S("dim1")})); +} + } // anonymous namespace } // namespace mlir::triton