diff --git a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_gemm_universal.hpp b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_gemm_universal.hpp
index 02245bb6a..9519ae7ad 100644
--- a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_gemm_universal.hpp
+++ b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_gemm_universal.hpp
@@ -127,6 +127,10 @@ class GemmUniversalAttention
   static constexpr int SG_N = CollectiveMainloop::SG_N;
   static constexpr int SG_K = CollectiveMainloop::SG_K;
 
+  static constexpr int Vec = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;
+  static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape());
+  static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape());
+
   // Kernel level shared memory storage
   struct SharedStorage {
     using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
@@ -257,12 +261,9 @@ class GemmUniversalAttention
     TiledMma tiled_mma;
     Tensor out_reg = partition_fragment_C(tiled_mma, take<0,2>(blk_shape));
 
-    constexpr int VecA = 8;
-    constexpr int FragsM1 = 4;
-    constexpr int FragsN2 = 2;
-    Tensor max_reg = make_tensor<ElementAccumulator>(Shape<Int<VecA>, Int<FragsM1>>{});
-    Tensor sum_reg = make_tensor<ElementAccumulator>(Shape<Int<VecA>, Int<FragsM1>>{});
+    Tensor max_reg = make_tensor<ElementAccumulator>(Shape<Int<Vec>, Int<FragsM>>{});
+    Tensor sum_reg = make_tensor<ElementAccumulator>(Shape<Int<Vec>, Int<FragsM>>{});
 
     fill(max_reg, -INFINITY);
     clear(sum_reg);
 
@@ -285,7 +286,7 @@ class GemmUniversalAttention
       // 1) Load K (performed inside mmaQK)
       // 2) Create Tensor S
       auto gK = local_tile(mK_nk, blk_shape, take<0, 3>(make_coord(0, 0, _, blk_l_coord)), Step< X, _1, _1>{});
-      Tensor tSr = make_tensor<ElementAccumulator>(Shape<Int<VecA>, Int<FragsM1>, Int<FragsN2>>{});
+      Tensor tSr = make_tensor<ElementAccumulator>(Shape<Int<Vec>, Int<FragsM>, Int<FragsN>>{});
       clear(tSr);
       // 3) Perform GEMM S = Q*K
       auto tile_coord_QK = make_coord(seq_coord, load_idx, _, blk_l_coord);
@@ -296,12 +297,12 @@ class GemmUniversalAttention
         // mask the elements of each tile where j > i
         int col_idx = item_id + load_idx;
         CUTLASS_PRAGMA_UNROLL
-        for(int n = 0; n < FragsN2; n++, col_idx += get<1>(MmaAtomShape())) {
+        for(int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) {
           CUTLASS_PRAGMA_UNROLL
-          for(int m = 0; m < FragsM1; m++) {
-            int row_idx = m * VecA + seq_coord;
+          for(int m = 0; m < FragsM; m++) {
+            int row_idx = m * Vec + seq_coord;
             CUTLASS_PRAGMA_UNROLL
-            for(int row = 0; row < VecA; row++, row_idx++) {
+            for(int row = 0; row < Vec; row++, row_idx++) {
               if(col_idx > row_idx)
                 tSr(row, m, n) = -INFINITY;
             }
@@ -310,10 +311,10 @@ class GemmUniversalAttention
       }
 
       if (nblock == 0)
-        flash::Softmax::template run<true, VecA, FragsM1, FragsN2>(tSr,
+        flash::Softmax::template run<true, Vec, FragsM, FragsN>(tSr,
                                      max_reg, sum_reg, out_reg, params.softmax);
       else
-        flash::Softmax::template run<false, VecA, FragsM1, FragsN2>(tSr,
+        flash::Softmax::template run<false, Vec, FragsM, FragsN>(tSr,
                                      max_reg, sum_reg, out_reg, params.softmax);
       // 7) Convert S to P (FP32 -> BF16)
       Tensor tPr = make_tensor(shape(tSr));
@@ -331,7 +332,7 @@ class GemmUniversalAttention
 
       // Reduce the sum of exponents across the subgroup before scaling/normalizing output
       flash::SumOp op;
-      flash::Softmax::template subgroup_allreduce<VecA, FragsM1>(sum_reg, op);
+      flash::Softmax::template subgroup_allreduce<Vec, FragsM>(sum_reg, op);
 
       CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
 
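A minimal standalone sketch of the arithmetic behind the new constants, assuming the PVC XMX bf16 MMA atom is 8x16x16 (M x N x K), the subgroup size is 16, and the subgroup tile is 32x32; AtomM, AtomN, SG_M, and SG_N below are illustrative stand-ins rather than the kernel's CuTe shape types. Under those assumptions the computed extents reproduce the previously hard-coded VecA = 8, FragsM1 = 4, FragsN2 = 2, so the patch removes the magic numbers without changing the fragment layout.

// sketch.cpp -- hypothetical values for illustration only, not part of the patch
constexpr int AtomM        = 8;   // assumed get<0>(MmaAtomShape()): rows per XMX atom
constexpr int AtomN        = 16;  // assumed get<1>(MmaAtomShape()): columns per XMX atom
constexpr int SubgroupSize = 16;  // assumed PVC subgroup (SIMD) width
constexpr int SG_M         = 32;  // assumed get<0>(SubgroupTileShape{})
constexpr int SG_N         = 32;  // assumed get<1>(SubgroupTileShape{})

constexpr int Vec    = (AtomM * AtomN) / SubgroupSize;  // 8  -> same as old VecA
constexpr int FragsM = SG_M / AtomM;                    // 4  -> same as old FragsM1
constexpr int FragsN = SG_N / AtomN;                    // 2  -> same as old FragsN2

static_assert(Vec == 8 && FragsM == 4 && FragsN == 2,
              "computed fragment extents match the removed hard-coded constants");

Deriving the extents from MmaAtomShape and SubgroupTileShape means the causal-mask loop bounds and the softmax template arguments stay correct if the tile configuration is ever changed, instead of silently relying on constants tuned for one shape.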