diff --git a/include/cute/algorithm/gemm.hpp b/include/cute/algorithm/gemm.hpp index 27c322168a..e8dce656a3 100644 --- a/include/cute/algorithm/gemm.hpp +++ b/include/cute/algorithm/gemm.hpp @@ -411,7 +411,13 @@ gemm(MMA_Atom const& mma, CUTE_UNROLL for (int k = 0; k < K; ++k) { - gemm(mma, D, A(_,_,k), B(_,_,k), C); + if (k == 0) { + // D = Ak x Bk + C + gemm(mma, D, A(_,_,k), B(_,_,k), C); + } else { + // D = Ak x Bk + D + gemm(mma, D, A(_,_,k), B(_,_,k), D); + } } } @@ -493,7 +499,13 @@ gemm(MMA_Atom const& mma, copy(A(_,_,k), rA(_,_,k)); copy(B(_,_,k), rB(_,_,k)); // Thread-level register gemm for k - gemm(mma, D, rA(_,_,k), rB(_,_,k), C); + if (k == 0) { + // D = Ak x Bk + C + gemm(mma, D, rA(_,_,k), rB(_,_,k), C); + } else { + // D = Ak x Bk + D + gemm(mma, D, rA(_,_,k), rB(_,_,k), D); + } } }