diff --git a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
index 933ce6024872..88a090e64fec 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
@@ -2,6 +2,8 @@
 #define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_
 
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include <optional>
+#include <utility>
 #include <vector>
 
 namespace mlir {
@@ -35,6 +37,7 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
 // Return the minClusterId and maxClusterId for the given ForOp.
 std::pair<int, int> getMinMaxCluster(scf::ForOp &forOp);
 std::pair<int, int> getStageCluster(Operation *op);
+std::optional<std::pair<int, int>> maybeGetStageCluster(Operation *op);
 void setStageCluster(Operation *op, int stage, int cluster);
 } // namespace triton
 } // namespace mlir
diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h
new file mode 100644
index 000000000000..8f36b7732f8f
--- /dev/null
+++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h
@@ -0,0 +1,99 @@
+#pragma once
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir::triton::nvidia_gpu {
+
+constexpr inline int TMA_SIZE_BYTES = 128;
+constexpr inline int TMA_ALIGN = 128;
+
+template <typename BuilderT>
+mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
+                                  mlir::triton::MakeTensorDescOp op,
+                                  BuilderT &builder) {
+  using namespace mlir;
+  MLIRContext *ctx = op.getContext();
+  auto loc = op.getLoc();
+  auto mkI32Constant = [&](int32_t val) {
+    return builder.template create<arith::ConstantOp>(
+        loc, builder.getI32Type(), builder.getI32IntegerAttr(val));
+  };
+
+  auto elemType = op.getBase().getType().getPointeeType();
+  auto elemSize = elemType.getIntOrFloatBitWidth() / 8;
+
+  int32_t contig_dim_size = op.getTensorShape().back();
+  int32_t contig_dim_size_in_bytes = contig_dim_size * elemSize;
+  if (contig_dim_size_in_bytes > 128) {
+    contig_dim_size = 128 / elemSize;
+  }
+  llvm::SmallVector<Value> boxDim;
+  boxDim.push_back(mkI32Constant(contig_dim_size));
+  for (int k = op.getTensorShape().size() - 2; k >= 0; --k) {
+    boxDim.push_back(mkI32Constant(op.getTensorShape()[k]));
+  }
+
+  int32_t swizzle_mode;
+  if (contig_dim_size_in_bytes >= 128) {
+    swizzle_mode = 3;
+  } else if (contig_dim_size_in_bytes == 64) {
+    swizzle_mode = 2;
+  } else if (contig_dim_size_in_bytes == 32) {
+    swizzle_mode = 1;
+  } else {
+    op->emitError()
+        << "contiguous box dimension must be at least 32 bytes but got "
+        << contig_dim_size_in_bytes;
+    return failure();
+  }
+
+  Value elemSizeVal = builder.template create<arith::ConstantOp>(
+      loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize));
+  Value globalStride = builder.template create<arith::MulIOp>(
+      loc, op.getStrides()[0], elemSizeVal);
+  // TODO: Workaround for ptxas bug, remove when we update ptxas
+  Value four = builder.template create<arith::ConstantOp>(
+      loc, builder.getI64Type(), builder.getI64IntegerAttr(4));
+  globalStride =
+      builder.template create<arith::ShRUIOp>(loc, globalStride, four);
+
+  int elemTypeEnum;
+  switch (elemSize) {
+  case 1: {
+    elemTypeEnum = 0;
+    break;
+  }
+  case 2: {
+    elemTypeEnum = 1;
+    break;
+  }
+  case 4: {
+    elemTypeEnum = 2;
+    break;
+  }
+  default: {
+    op->emitError()
+        << "Tensor descriptor element type must have size 1, 2, or 4 but got "
+        << elemSize;
+    return failure();
+  }
+  }
+
+  auto one = mkI32Constant(1);
+  builder.template create<triton::ExperimentalTensormapCreateOp>(
+      loc,
+      /*desc_ptr=*/tmaPtr,
+      /*global_address=*/op.getBase(),
+      /*box_dim=*/boxDim,
+      /*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]},
+      /*global_stride=*/ValueRange{globalStride},
+      /*element_strides=*/ValueRange{one, one},
+      /*elem_type*/ builder.getI32IntegerAttr(elemTypeEnum),
+      /*interleave_layout*/ builder.getI32IntegerAttr(0),
+      /*swizzle_mode=*/builder.getI32IntegerAttr(swizzle_mode),
+      /*fill_mode=*/builder.getI32IntegerAttr(0));
+  return success();
+}
+
+} // namespace mlir::triton::nvidia_gpu
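For orientation, here is a standalone sketch (editorial, not part of the patch) that mirrors the swizzle-mode and element-type selection performed by createTMADesc above, worked through for an assumed fp16 tile whose innermost block dimension is 64 elements; a contiguous extent below 32 bytes is rejected, matching the emitError branch.

# Editorial sketch: reproduces createTMADesc's swizzle/element-type logic in
# plain Python for one assumed case (fp16, innermost block dim = 64).
elem_size = 2                          # fp16 -> 2 bytes per element
contig_dim = 64                        # innermost block dimension
contig_bytes = contig_dim * elem_size  # 128 bytes
if contig_bytes >= 128:
    swizzle_mode = 3
elif contig_bytes == 64:
    swizzle_mode = 2
elif contig_bytes == 32:
    swizzle_mode = 1
else:
    raise ValueError("contiguous box dimension must be at least 32 bytes")
elem_type_enum = {1: 0, 2: 1, 4: 2}[elem_size]
assert (swizzle_mode, elem_type_enum) == (3, 1)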
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
index b370704be6bc..f94da69df7a2 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -1,28 +1,24 @@
 #include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/IRMapping.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/AxisInfo.h"
 #include "triton/Analysis/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
-#include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
-#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h"
 #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
 #include "triton/Dialect/TritonGPU/Transforms/Schedule.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/Debug.h"
 
-#include
-
 #define DEBUG_TYPE "triton-matmul-loop-pipeline"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
@@ -70,6 +66,30 @@ class OpBuilderWithStage : public OpBuilder {
   using OpBuilder::create;
 };
 
+class OpBuilderForStage : public OpBuilder {
+  std::optional<int> stage_, cluster_;
+
+public:
+  explicit OpBuilderForStage(Operation *op, int stage, int cluster)
+      : OpBuilder(op, nullptr), stage_(stage), cluster_(cluster) {}
+  explicit OpBuilderForStage(Operation *op) : OpBuilder(op, nullptr) {
+    auto sc = tt::maybeGetStageCluster(op);
+    if (sc) {
+      stage_ = sc->first;
+      cluster_ = sc->second;
+    }
+  }
+
+  template <typename OpTy, typename... Args> OpTy create(Args &&...args) {
+    OpTy op = OpBuilder::create<OpTy>(std::forward<Args>(args)...);
+
+    if (stage_ && cluster_) {
+      tt::setStageCluster(op, *stage_, *cluster_);
+    }
+    return op;
+  }
+};
+
 static bool sameStageCluster(Operation *op1, Operation *op2) {
   auto [s1, c1] = tt::getStageCluster(op1);
   auto [s2, c2] = tt::getStageCluster(op2);
@@ -704,7 +724,126 @@ getFinalSchedule(scf::ForOp &forOp, int numStages) {
   return fSchedule;
 }
 
-// Convert load ops into their asyn version and apply multi-buffering based on
+LogicalResult
+allocTMABuffers(scf::ForOp forOp,
+                llvm::MapVector<Operation *, Value> &tmaBufferMapping,
+                int numBuffers) {
+  IRRewriter rewriter(forOp);
+
+  // Create a multi-buffered allocation for each MakeTensorDescOp call in the
+  // loop
+  forOp.walk([&](tt::MakeTensorDescOp op) {
+    // TODO peter: walk to loop yield to find the init value if this is a
+    // loop-carried value. That would save us from allocating another buffer
+    // just for the init value
+    auto loc = op.getLoc();
+    Value alloc = rewriter.create<ttg::GlobalScratchAllocOp>(
+        loc, triton::getPointerType(rewriter.getI8Type()),
+        numBuffers * ttng::TMA_SIZE_BYTES, ttng::TMA_ALIGN);
+    tmaBufferMapping[op.getOperation()] = alloc;
+  });
+  return success();
+}
+
+template <typename BuilderT>
+Value createIncrementModulo(BuilderT &builder, Location loc, Value counter,
+                            Value modulus, Value zero, Value one) {
+  Value addOne = builder.template create<arith::AddIOp>(loc, counter, one);
+  Value inRangeCond = builder.template create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::slt, addOne, modulus);
+  return builder.template create<arith::SelectOp>(loc, inRangeCond, addOne,
+                                                  zero);
+}
+
+template <typename BuilderT>
+Value subviewTMADescriptor(BuilderT &builder, Location loc, Value alloc,
+                           Value counter) {
+  Value tmaSizeVal = builder.template create<arith::ConstantIntOp>(
+      loc, ttng::TMA_SIZE_BYTES, 32);
+  Value offset =
+      builder.template create<arith::MulIOp>(loc, tmaSizeVal, counter);
+  return builder.template create<triton::AddPtrOp>(loc, alloc.getType(), alloc,
+                                                   offset);
+}
+
+LogicalResult rewriteTMABufferUpdates(
+    scf::ForOp forOp,
+    const llvm::MapVector<Operation *, Value> &tmaBufferMapping,
+    ArrayRef<BlockArgument> tmaCounters, Value numBuffers, Value one,
+    Value zero) {
+  assert(tmaBufferMapping.size() == tmaCounters.size());
+
+  for (auto [iOp, pair] : llvm::enumerate(tmaBufferMapping)) {
+    auto &[op, alloc] = pair;
+
+    // Rewrite MakeTensorDescOp as writing a TMA descriptor
+    auto makeDescOp = cast<tt::MakeTensorDescOp>(op);
+
+    OpBuilderForStage stageBuilder(makeDescOp);
+    auto loc = makeDescOp.getLoc();
+
+    BlockArgument counter = tmaCounters[iOp];
+    Value nextBuf = subviewTMADescriptor(stageBuilder, loc, alloc, counter);
+    if (failed(ttng::createTMADesc(nextBuf, makeDescOp, stageBuilder))) {
+      return failure();
+    }
+    stageBuilder.create<tt::ExperimentalTensormapFenceproxyAcquireOp>(
+        loc, nextBuf);
+    Value nextDesc = stageBuilder.create<tt::ReinterpretTensorDescOp>(
+        loc, makeDescOp.getType(), nextBuf);
+
+    makeDescOp.getResult().replaceAllUsesWith(nextDesc);
+
+    // Increment the buffer index counter
+    Value nextCounter = createIncrementModulo(stageBuilder, loc, counter,
+                                              numBuffers, zero, one);
+
+    // If we are in a (potentially nested) if region, propagate the counter
+    // up to the main for op body scope
+    Operation *curOp = op;
+    Operation *parent = op->getParentOp();
+    while (parent != forOp.getOperation()) {
+      auto ifOp = dyn_cast<scf::IfOp>(parent);
+      if (!ifOp) {
+        std::string msg;
+        llvm::raw_string_ostream ss(msg);
+        ss << "Cannot pipeline MakeTensorDescOp inside:\n";
+        parent->print(ss);
+        ss << "\nOnly scf.if regions are supported";
+        return makeDescOp->emitOpError(std::move(msg));
+      }
+
+      IRRewriter rewriter(parent);
+      auto newIfOp =
+          replaceIfOpWithNewSignature(rewriter, ifOp, {nextCounter.getType()});
+
+      auto yieldNewBlock = newIfOp.thenBlock();
+      auto yieldOldBlock = newIfOp.elseBlock();
+
+      if (yieldNewBlock != curOp->getBlock()) {
+        std::swap(yieldNewBlock, yieldOldBlock);
+      }
+      cast<scf::YieldOp>(yieldNewBlock->getTerminator())
+          .getResultsMutable()
+          .append(nextCounter);
+      cast<scf::YieldOp>(yieldOldBlock->getTerminator())
+          .getResultsMutable()
+          .append(counter);
+
+      ifOp.erase();
+      nextCounter = newIfOp.getResults().back();
+      curOp = newIfOp;
+      parent = newIfOp->getParentOp();
+    }
+
+    // Finally, rewrite the loop level yield
+    auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+    forYield.setOperand(counter.getArgNumber() - 1, nextCounter);
+  }
+  return success();
+}
+
+// Convert load ops into their async version and apply multi-buffering based on
 // the required number of buffers.
 static SmallVector<Value>
 createAsyncOps(scf::ForOp &forOp,
@@ -728,6 +867,11 @@ createAsyncOps(scf::ForOp &forOp,
     numBuffers++;
   };
 
+  llvm::MapVector<Operation *, Value> tmaBufferMapping;
+  if (failed(allocTMABuffers(forOp, tmaBufferMapping, numBuffers))) {
+    llvm_unreachable("TMA pipelining failed");
+  }
+
   SmallVector<AsyncLoad> asyncLoads;
   SmallVector<Value> allocs;
   bool hasTMALoad = false;
@@ -751,7 +895,10 @@ createAsyncOps(scf::ForOp &forOp,
   builder.setInsertionPoint(forOp);
   Location loc = forOp.getLoc();
-  // Create two new counters to index into the allocs.
+  // Create a counter to index into the allocations per loop iteration.
+  // NOTE: We create two duplicate values, insertIdx and extractIdx, so that the
+  // pipeliner will re-materialize the value in later stages of the pipeline
+  // instead of carrying it as a dependency across multiple iterations.
   Value minusOne = builder.create<arith::ConstantIntOp>(loc, -1, 32);
   Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
   Value one = builder.create<arith::ConstantIntOp>(loc, 1, 32);
@@ -764,9 +911,19 @@ createAsyncOps(scf::ForOp &forOp,
   newOperands.push_back(insertIdx);
   newOperands.push_back(extractIdx);
   if (hasTMALoad) {
+    // A single barrier arrival sequence is a "phase" and two phases can
+    // overlap, provided the phases are differentiated with an alternating
+    // boolean value.
     phase = builder.create<arith::ConstantIntOp>(loc, 0, 32);
     newOperands.push_back(phase);
   }
+  // Also create one counter per TMA buffer. This allows the descriptors to be
+  // updated independently without needing to write duplicates of existing TMA
+  // descriptors.
+  for (int i = 0; i < tmaBufferMapping.size(); ++i) {
+    newOperands.push_back(zero);
+  }
+
   unsigned newOperandIndex = forOp.getBody()->getNumArguments();
   // Patch the loop to add the new loop carried dependencies.
   scf::ForOp newForOp =
@@ -778,6 +935,20 @@ createAsyncOps(scf::ForOp &forOp,
   if (phase) {
     phase = newForOp.getBody()->getArgument(newOperandIndex + 2);
   }
+  auto tmaCounters = ArrayRef<BlockArgument>(newForOp.getBody()->getArguments())
+                         .slice(newOperandIndex + (phase ? 3 : 2));
+
+  // Update yield op with temporary yield values
+  auto forYield = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+  for (unsigned i = 0; i < newOperands.size(); ++i) {
+    forYield.getResultsMutable().append(newOperands[i]);
+  }
+
+  if (failed(rewriteTMABufferUpdates(newForOp, tmaBufferMapping, tmaCounters,
+                                     numBuffersVal, one, zero))) {
+    llvm_unreachable("Failed to rewrite TMA ops");
+  }
+  tmaBufferMapping.clear();
 
   // FIXME: loads can be in different (stage, cluster)
   // Create two counters for the insert and extract indices to avoid creating
@@ -811,11 +982,12 @@ createAsyncOps(scf::ForOp &forOp,
                         loadToInfo, numStages);
     }
   }
-  SmallVector<Value> newYieldOperands = {insertIdx, extractIdx};
-  if (phase)
-    newYieldOperands.push_back(phase);
   // Patch the yield with the updated counters.
-  appendToForOpYield(forOp, newYieldOperands);
+  forYield.setOperand(newOperandIndex - 1, insertIdx);
+  forYield.setOperand(newOperandIndex + 0, extractIdx);
+  if (phase) {
+    forYield.setOperand(newOperandIndex + 1, phase);
+  }
 
   tt::CoarseSchedule coarseSchedule(numStages);
   coarseSchedule.deSerialize(forOp);
@@ -953,12 +1125,11 @@ static int minNumInterleavedCommitOps(Operation *waitOp) {
       if (thisHistorySum >= minCommitNumber)
         return minCommitNumber;
 
-      // get the value value assigned to the argument coming from outside the
-      // loop
+      // get the value assigned to the argument coming from outside the loop
      Value incomingVal = forOp.getInitArgs()[arg.getArgNumber() - 1];
      int min1 = minOverHistories(incomingVal, forOp, thisHistorySum);
 
-      // get the value value assigned to the argument coming from the previous
+      // get the value assigned to the argument coming from the previous
       // iteration
       Operation *yieldOp = block->getTerminator();
       Value prevVal = yieldOp->getOperand(arg.getArgNumber() - 1);
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp
index 2305f30beb06..46892342a08d 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp
@@ -7,6 +7,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include "llvm/Support/Casting.h"
 
 using namespace mlir;
 namespace tt = mlir::triton;
@@ -177,15 +178,23 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder,
   op->erase();
 }
 
-std::pair<int, int> mlir::triton::getStageCluster(Operation *op) {
-  auto stage = cast<IntegerAttr>(op->getAttr(mlir::triton::kLoopStageAttrName))
-                   .getValue()
-                   .getSExtValue();
+std::optional<std::pair<int, int>>
+mlir::triton::maybeGetStageCluster(Operation *op) {
+  auto stage =
+      dyn_cast_if_present<IntegerAttr>(op->getAttr(tt::kLoopStageAttrName));
   auto clusterId =
-      cast<IntegerAttr>(op->getAttr(mlir::triton::kLoopClusterAttrName))
-          .getValue()
-          .getSExtValue();
-  return std::make_pair(stage, clusterId);
+      dyn_cast_if_present<IntegerAttr>(op->getAttr(tt::kLoopClusterAttrName));
+  if (!stage || !clusterId) {
+    return std::nullopt;
+  }
+
+  return {
+      {stage.getValue().getSExtValue(), clusterId.getValue().getSExtValue()}};
+}
+std::pair<int, int> mlir::triton::getStageCluster(Operation *op) {
+  auto res = maybeGetStageCluster(op);
+  assert(res.has_value() && "Operation is missing stage & cluster attribute");
+  return *res;
 }
 
 void mlir::triton::setStageCluster(Operation *op, int stage, int cluster) {
diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
index cb9ae9dd0f3c..f412755d55f7 100644
--- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
@@ -2,13 +2,13 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/Passes.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
 
 #include <memory>
 
@@ -115,87 +115,11 @@ class TMACreateDescLowering : public OpRewritePattern<MakeTensorDescOp> {
                                 PatternRewriter &rewriter) const override {
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
-    constexpr auto kTmaNbytes = 128;
-    constexpr auto kTmaAlignment = 128;
     auto alloc = rewriter.create<triton::gpu::GlobalScratchAllocOp>(
-        loc, getPointerType(rewriter.getI8Type()), kTmaNbytes, kTmaAlignment);
-    auto mkI32Constant = [&](int32_t val) {
-      return rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(val));
-    };
-
-    auto elemType = op.getBase().getType().getPointeeType();
-    auto elemSize = elemType.getIntOrFloatBitWidth() / 8;
-
-    int32_t contig_dim_size = op.getTensorShape().back();
-    int32_t contig_dim_size_in_bytes = contig_dim_size * elemSize;
-    if (contig_dim_size_in_bytes > 128) {
-      contig_dim_size = 128 / elemSize;
-    }
-    llvm::SmallVector<Value> boxDim;
-    boxDim.push_back(mkI32Constant(contig_dim_size));
-    for (int k = op.getTensorShape().size() - 2; k >= 0; --k) {
-      boxDim.push_back(mkI32Constant(op.getTensorShape()[k]));
-    }
-
-    int32_t swizzle_mode;
-    if (contig_dim_size_in_bytes >= 128) {
-      swizzle_mode = 3;
-    } else if (contig_dim_size_in_bytes == 64) {
-      swizzle_mode = 2;
-    } else if (contig_dim_size_in_bytes == 32) {
-      swizzle_mode = 1;
-    } else {
-      op->emitError()
-          << "contiguous box dimension must be at least 32 bytes but got "
-          << contig_dim_size_in_bytes;
-      return failure();
-    }
-
-    Value elemSizeVal = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(elemSize));
-    Value globalStride =
-        rewriter.create<arith::MulIOp>(loc, op.getStrides()[0], elemSizeVal);
-    // TODO: Workaround for ptxas bug, remove when we update ptxas
-    Value four = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(4));
-    globalStride = rewriter.create<arith::ShRUIOp>(loc, globalStride, four);
-
-    int elemTypeEnum;
-    switch (elemSize) {
-    case 1: {
-      elemTypeEnum = 0;
-      break;
-    }
-    case 2: {
-      elemTypeEnum = 1;
-      break;
-    }
-    case 4: {
-      elemTypeEnum = 2;
-      break;
-    }
-    default: {
-      op->emitError()
-          << "Tensor descriptor element type must have size 1, 2, or 4 but got "
-          << elemSize;
+        loc, getPointerType(rewriter.getI8Type()), TMA_SIZE_BYTES, TMA_ALIGN);
+    if (failed(createTMADesc(alloc, op, rewriter))) {
       return failure();
     }
-    }
-
-    auto one = mkI32Constant(1);
-    rewriter.create<triton::ExperimentalTensormapCreateOp>(
-        loc,
-        /*desc_ptr=*/alloc.getResult(),
-        /*global_address=*/op.getBase(),
-        /*box_dim=*/boxDim,
-        /*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]},
-        /*global_stride=*/ValueRange{globalStride},
-        /*element_strides=*/ValueRange{one, one},
-        /*elem_type*/ rewriter.getI32IntegerAttr(elemTypeEnum),
-        /*interleave_layout*/ rewriter.getI32IntegerAttr(0),
-        /*swizzle_mode=*/rewriter.getI32IntegerAttr(swizzle_mode),
-        /*fill_mode=*/rewriter.getI32IntegerAttr(0));
     rewriter.create<triton::ExperimentalTensormapFenceproxyAcquireOp>(
         loc, alloc.getResult());
     auto newDesc = rewriter.create<triton::ReinterpretTensorDescOp>(
diff --git a/python/src/passes.cc b/python/src/passes.cc
index 235eba4465cb..b0efc3cb884b 100644
--- a/python/src/passes.cc
+++ b/python/src/passes.cc
@@ -31,6 +31,7 @@ void init_triton_passes_common(py::module &&m) {
   ADD_PASS_WRAPPER_0("add_canonicalizer", createCanonicalizerPass);
   ADD_PASS_WRAPPER_0("add_cse", createCSEPass);
   ADD_PASS_WRAPPER_0("add_licm", createLoopInvariantCodeMotionPass);
+  ADD_PASS_WRAPPER_0("print_ir", createPrintIRPass);
 }
 
 void init_triton_passes_ttir(py::module &&m) {
diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py
index 49a8bb32c4f8..4b4a08857d7a 100644
--- a/python/tutorials/09-persistent-matmul.py
+++ b/python/tutorials/09-persistent-matmul.py
@@ -28,6 +28,8 @@ import triton.profiler as proton
 from contextlib import contextmanager
 
+from typing import Optional
+
 if torch.cuda.is_available():
     from triton._C.libtriton import nvidia
     cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
@@ -374,8 +376,7 @@ def matmul_tma_persistent(a, b):
 
 
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_kernel_device_tma_persistent(workspace_ptr,  #
-                                        tiles_per_update: tl.constexpr,  #
+def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr,  #
                                         a_ptr, b_ptr, c_ptr,  #
                                         M, N, K,  #
                                         BLOCK_SIZE_M: tl.constexpr,  #
@@ -391,24 +392,24 @@ def matmul_kernel_device_tma_persistent(workspace_ptr,  #
     k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
     num_tiles = num_pid_m * num_pid_n
 
-    TMA_SIZE: tl.constexpr = 128
-    workspace_base = workspace_ptr + start_pid * 3 * TMA_SIZE
-    a_desc_ptr = workspace_base
-    b_desc_ptr = workspace_base + TMA_SIZE
-    c_desc_ptr = workspace_base + 2 * TMA_SIZE
-
-    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
-                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
-                                                         element_ty=a_ptr.dtype.element_ty)
-    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
-                                                         load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
-                                                         element_ty=b_ptr.dtype.element_ty)
-    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
-                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
-                                                         element_ty=c_ptr.dtype.element_ty)
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
-    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+    a_desc = tl._experimental_make_tensor_descriptor(
+        a_ptr,
+        shape=[M, K],
+        strides=[K, 1],
+        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
+    )
+    b_desc = tl._experimental_make_tensor_descriptor(
+        b_ptr,
+        shape=[N, K],
+        strides=[K, 1],
+        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
+    )
+    c_desc = tl._experimental_make_tensor_descriptor(
+        c_ptr,
+        shape=[M, N],
+        strides=[N, 1],
+        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+    )
 
     tiles_per_SM = num_tiles // NUM_SMS
     if start_pid < num_tiles % NUM_SMS:
@@ -426,6 +427,9 @@ def matmul_kernel_device_tma_persistent(workspace_ptr,  #
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    # Create an opaque value to prevent the descriptor creation from being
+    # hoisted out of the loop
+    zero = tl.inline_asm_elementwise("mov.b32 $0, 0;", "=r", [], dtype=tl.int32, is_pure=True, pack=1)
 
     for _ in range(0, k_tiles * tiles_per_SM):
         ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
@@ -434,21 +438,24 @@ def matmul_kernel_device_tma_persistent(workspace_ptr,  #
 
         # Simulate a grouped gemm
         if ni == tiles_per_update:
-            tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
-                                                                 load_size=[BLOCK_SIZE_M,
-                                                                            BLOCK_SIZE_K], global_size=[M, K],
-                                                                 element_ty=a_ptr.dtype.element_ty)
-            tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
-                                                                 load_size=[BLOCK_SIZE_N,
-                                                                            BLOCK_SIZE_K], global_size=[N, K],
-                                                                 element_ty=b_ptr.dtype.element_ty)
-            tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
-                                                                 load_size=[BLOCK_SIZE_M,
-                                                                            BLOCK_SIZE_N], global_size=[M, N],
-                                                                 element_ty=c_ptr.dtype.element_ty)
-            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
-            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
-            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+            a_desc = tl._experimental_make_tensor_descriptor(
+                a_ptr + zero,
+                shape=[M, K],
+                strides=[K, 1],
+                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
+            )
+            b_desc = tl._experimental_make_tensor_descriptor(
+                b_ptr + zero,
+                shape=[N, K],
+                strides=[K, 1],
+                block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
+            )
+            c_desc = tl._experimental_make_tensor_descriptor(
+                c_ptr + zero,
+                shape=[M, N],
+                strides=[N, 1],
+                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+            )
 
             ni = 0
 
         tile_id += NUM_SMS
@@ -463,19 +470,19 @@ def matmul_kernel_device_tma_persistent(workspace_ptr,  #
 
         offs_k = ki * BLOCK_SIZE_K
 
-        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype)
-        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype)
+        a = a_desc.load([offs_am, offs_k])
+        b = b_desc.load([offs_bn, offs_k])
         accumulator = tl.dot(a, b.T, accumulator)
 
         if ki == k_tiles - 1:
             c = accumulator.to(dtype)
 
-            tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn])
+            c_desc.store([offs_am, offs_bn], c)
             accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
 
-def matmul_device_tma_persistent(a, b, tiles_per_update):
+def matmul_descriptor_persistent(a, b, tiles_per_update):
     # Autotuner does not work with TMA. Use manual config.
     configs = {
         torch.float8_e4m3fn: {
@@ -497,12 +504,15 @@ def matmul_device_tma_persistent(a, b, tiles_per_update):
     c = torch.empty((M, N), device=a.device, dtype=dtype)
     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
-    tma_size = 128
-    workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda")
+
+    # TMA descriptors require a global memory allocation
+    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
+        return torch.empty(size, device="cuda", dtype=torch.int8)
+
+    triton.set_allocator(alloc_fn)
 
     grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
-    matmul_kernel_device_tma_persistent[grid](
-        workspace,  #
+    matmul_kernel_descriptor_persistent[grid](
         tiles_per_update,  #
         a, b, c,  #
         M, N, K,  #
@@ -576,7 +586,7 @@ def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000):
         bench_fn(reps, warmup_reps, matmul_persistent, a, b.T)
     if supports_tma():
         bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b)
-        bench_fn(reps, warmup_reps, matmul_device_tma_persistent, a, b, tiles_per_update)
+        bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b, tiles_per_update)
 
 
 def validate(M, N, K, dtype, tiles_per_update):
@@ -589,7 +599,7 @@ def validate(M, N, K, dtype, tiles_per_update):
     naive_result = matmul(a, b.T)
     persistent_result = matmul_persistent(a, b.T)
     tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None
-    device_tma_persistent_result = matmul_device_tma_persistent(a, b, tiles_per_update) if supports_tma() else None
+    descriptor_persistent_result = matmul_descriptor_persistent(a, b, tiles_per_update) if supports_tma() else None
 
     if torch_result is not None:
         naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16),
@@ -602,9 +612,9 @@ def validate(M, N, K, dtype, tiles_per_update):
     if tma_persistent_result is not None:
         naive_vs_tma_persistent = "✅" if torch.allclose(cublas_result.to(torch.float16),
                                                          tma_persistent_result.to(torch.float16), atol=1.0) else "❌"
-    if device_tma_persistent_result is not None:
-        naive_vs_device_tma_persistent = "✅" if torch.allclose(cublas_result.to(
-            torch.float16), device_tma_persistent_result.to(torch.float16), atol=1.0) else "❌"
+    if descriptor_persistent_result is not None:
+        naive_vs_descriptor_persistent = "✅" if torch.allclose(cublas_result.to(
+            torch.float16), descriptor_persistent_result.to(torch.float16), atol=1.0) else "❌"
     print(f"M={M}, N={N}, K={K} verification naive vs: ", end="")
     if torch_result is not None:
         print(f"torch: {naive_vs_torch} ", end="")
@@ -613,8 +623,8 @@ def validate(M, N, K, dtype, tiles_per_update):
         print(f"persistent: {naive_vs_persistent} ", end="")
     if tma_persistent_result is not None:
         print(f"TMA persistent: {naive_vs_tma_persistent} ", end="")
-    if device_tma_persistent_result is not None:
-        print(f"Device TMA persistent: {naive_vs_device_tma_persistent} ", end="")
+    if descriptor_persistent_result is not None:
+        print(f"Device TMA persistent: {naive_vs_descriptor_persistent} ", end="")
     print()
 
 
@@ -639,7 +649,7 @@ def show_profile(precision, profile_name):
         type=int,
         default=1,
         help=
-        "Number of output tiles calculated for each update of the tma descriptor in matmul_device_tma_persistent_kernel",
+        "Number of output tiles calculated for each update of the tma descriptor in matmul_descriptor_persistent_kernel",
     )
     parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16")
     args = parser.parse_args()
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp
index 459a00c1a142..1bc6eb7cf0ab 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp
@@ -9,12 +9,13 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
 
 using namespace mlir;
 using namespace mlir::triton;
+using namespace mlir::triton::nvidia_gpu;
 
 namespace {
-constexpr int64_t TMA_SIZE_BYTES = 128;
 
 void tensormap_cp_fenceproxy(Location loc, MLIRContext *ctx,
                              ConversionPatternRewriter &rewriter, Value outPtr,
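As a usage note, the descriptor path exercised by the tutorial changes above boils down to the following condensed host-plus-kernel sketch. This is an editorial illustration assembled from the tutorial diff, not part of the patch: the kernel name, shapes, and grid are made up, and running it requires a TMA-capable GPU (Hopper or newer).

# Editorial sketch (not part of the patch): minimal use of the device-side
# tensor-descriptor API as shown in the tutorial. Names and sizes are illustrative.
from typing import Optional

import torch
import triton
import triton.language as tl


def alloc_fn(size: int, alignment: int, stream: Optional[int]):
    # Global-memory workspace backing the on-device TMA descriptors.
    return torch.empty(size, device="cuda", dtype=torch.int8)


@triton.jit
def tile_copy_kernel(src_ptr, dst_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    src = tl._experimental_make_tensor_descriptor(
        src_ptr, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N])
    dst = tl._experimental_make_tensor_descriptor(
        dst_ptr, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N])
    off_m = tl.program_id(0) * BLOCK_M
    off_n = tl.program_id(1) * BLOCK_N
    dst.store([off_m, off_n], src.load([off_m, off_n]))


triton.set_allocator(alloc_fn)
x = torch.randn(256, 256, device="cuda", dtype=torch.float16)
y = torch.empty_like(x)
tile_copy_kernel[(2, 2)](x, y, 256, 256, BLOCK_M=128, BLOCK_N=128)
assert torch.equal(x, y)

The host must register an allocator before launch because each on-device descriptor is written into a 128-byte, 128-byte-aligned global-memory slot (TMA_SIZE_BYTES / TMA_ALIGN above), which the pipeliner then multi-buffers with one counter per descriptor inside the loop.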