Remove hard coded parameters
muhammad-tanvir-1211 committed Nov 15, 2024
1 parent f2226c9 commit 4a88b2a
Showing 1 changed file with 14 additions and 13 deletions.
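
The commit replaces the hard coded fragment extents (VecA = 8, FragsM1 = 4, FragsN2 = 2) with constants derived from the MMA atom shape, the subgroup tile shape, and the subgroup size, so the kernel no longer assumes one particular tile configuration. The standalone sketch below (not part of the commit) shows that derivation; the concrete atom and tile sizes are assumptions chosen so the derived values reproduce the previously hard coded ones.

#include <cstdio>

int main() {
  // Assumed configuration for illustration only: an 8x16 MMA atom,
  // a 32x32 subgroup tile, and a subgroup size of 16 work-items.
  constexpr int AtomM = 8, AtomN = 16;
  constexpr int SgTileM = 32, SgTileN = 32;
  constexpr int SubgroupSize = 16;

  // Same formulas as the new static constexpr members in GemmUniversalAttention.
  constexpr int Vec    = (AtomM * AtomN) / SubgroupSize; // 8, matches the old VecA
  constexpr int FragsM = SgTileM / AtomM;                // 4, matches the old FragsM1
  constexpr int FragsN = SgTileN / AtomN;                // 2, matches the old FragsN2

  std::printf("Vec=%d FragsM=%d FragsN=%d\n", Vec, FragsM, FragsN);
  return 0;
}
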
@@ -127,6 +127,10 @@ class GemmUniversalAttention
static constexpr int SG_N = CollectiveMainloop::SG_N;
static constexpr int SG_K = CollectiveMainloop::SG_K;

+ static constexpr int Vec = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;
+ static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape());
+ static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape());
+
// Kernel level shared memory storage
struct SharedStorage {
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
@@ -257,12 +261,9 @@ class GemmUniversalAttention
TiledMma tiled_mma;

Tensor out_reg = partition_fragment_C(tiled_mma, take<0,2>(blk_shape));
- constexpr int VecA = 8;
- constexpr int FragsM1 = 4;
- constexpr int FragsN2 = 2;

- Tensor max_reg = make_tensor<ElementAccumulator>(Shape<Int<VecA>, Int<FragsM1>>{});
- Tensor sum_reg = make_tensor<ElementAccumulator>(Shape<Int<VecA>, Int<FragsM1>>{});
+ Tensor max_reg = make_tensor<ElementAccumulator>(Shape<Int<Vec>, Int<FragsM>>{});
+ Tensor sum_reg = make_tensor<ElementAccumulator>(Shape<Int<Vec>, Int<FragsM>>{});

fill(max_reg, -INFINITY);
clear(sum_reg);
@@ -285,7 +286,7 @@
// 1) Load K (performed inside mmaQK)
// 2) Create Tensor S
auto gK = local_tile(mK_nk, blk_shape, take<0, 3>(make_coord(0, 0, _, blk_l_coord)), Step< X, _1, _1>{});
- Tensor tSr = make_tensor<ElementAccumulator>(Shape<Int<VecA>, Int<FragsM1>, Int<FragsN2>>{});
+ Tensor tSr = make_tensor<ElementAccumulator>(Shape<Int<Vec>, Int<FragsM>, Int<FragsN>>{});
clear(tSr);
// 3) Perform GEMM S = Q*K
auto tile_coord_QK = make_coord(seq_coord, load_idx, _, blk_l_coord);
@@ -296,12 +297,12 @@
// mask the elements of each tile where j > i
int col_idx = item_id + load_idx;
CUTLASS_PRAGMA_UNROLL
- for(int n = 0; n < FragsN2; n++, col_idx += get<1>(MmaAtomShape())) {
+ for(int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) {
CUTLASS_PRAGMA_UNROLL
- for(int m = 0; m < FragsM1; m++) {
- int row_idx = m * VecA + seq_coord;
+ for(int m = 0; m < FragsM; m++) {
+ int row_idx = m * Vec + seq_coord;
CUTLASS_PRAGMA_UNROLL
- for(int row = 0; row < VecA; row++, row_idx++) {
+ for(int row = 0; row < Vec; row++, row_idx++) {
if(col_idx > row_idx)
tSr(row, m, n) = -INFINITY;
}
@@ -310,10 +311,10 @@
}

if (nblock == 0)
- flash::Softmax<ElementAccumulator>::template run<true, CausalMask, VecA, FragsM1, FragsN2>(tSr,
+ flash::Softmax<ElementAccumulator>::template run<true, CausalMask, Vec, FragsM, FragsN>(tSr,
max_reg, sum_reg, out_reg, params.softmax);
else
- flash::Softmax<ElementAccumulator>::template run<false, CausalMask, VecA, FragsM1, FragsN2>(tSr,
+ flash::Softmax<ElementAccumulator>::template run<false, CausalMask, Vec, FragsM, FragsN>(tSr,
max_reg, sum_reg, out_reg, params.softmax);
// 7) Convert S to P (FP32 -> BF16)
Tensor tPr = make_tensor<typename TiledMma::ValTypeA>(shape(tSr));
@@ -331,7 +332,7 @@

// Reduce the sum of exponents across the subgroup before scaling/normalizing output
flash::SumOp<ElementAccumulator> op;
- flash::Softmax<ElementAccumulator>::template subgroup_allreduce<false, VecA, FragsM1, FragsN2>(sum_reg, op);
+ flash::Softmax<ElementAccumulator>::template subgroup_allreduce<false, Vec, FragsM, FragsN>(sum_reg, op);

CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};

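For reference, a minimal standalone sketch (illustrative only, not repository code) of the causal masking loop changed above, using a plain array in place of the CuTe fragment; the fragment extents and the seq_coord, load_idx, and item_id values are assumptions matching the derivation sketch near the top of the page.

#include <cmath>
#include <cstdio>

// Assumed fragment extents and MMA atom width, for illustration only.
constexpr int Vec = 8, FragsM = 4, FragsN = 2, AtomN = 16;

int main() {
  // One work-item's fragment of S, indexed as tSr[row][m][n].
  float tSr[Vec][FragsM][FragsN] = {};

  const int seq_coord = 0; // first row handled by this subgroup tile (assumed)
  const int load_idx  = 0; // first column of the current K block (assumed)
  const int item_id   = 3; // lane id of this work-item within the subgroup (assumed)

  // Same traversal as the kernel: mask every element whose column index
  // exceeds its row index, i.e. a future position under the causal mask.
  int col_idx = item_id + load_idx;
  for (int n = 0; n < FragsN; n++, col_idx += AtomN) {
    for (int m = 0; m < FragsM; m++) {
      int row_idx = m * Vec + seq_coord;
      for (int row = 0; row < Vec; row++, row_idx++) {
        if (col_idx > row_idx)
          tSr[row][m][n] = -INFINITY;
      }
    }
  }

  // Show which elements this work-item masked ('x') versus kept ('.').
  for (int m = 0; m < FragsM; m++)
    for (int row = 0; row < Vec; row++) {
      for (int n = 0; n < FragsN; n++)
        std::printf("%c", std::isinf(tSr[row][m][n]) ? 'x' : '.');
      std::printf("  row %d\n", m * Vec + seq_coord + row);
    }
  return 0;
}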
