
[QST] Why do we only need the result of the last k-loop in cute::gemm dispatch-5? #1629

Open
sjfeng1999 opened this issue Jul 12, 2024 · 7 comments

Comments

@sjfeng1999
Contributor

The original code is as follows:

// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout,
          __CUTE_REQUIRES(DLayout::rank == 3 && is_rmem<TD>::value &&
                          ALayout::rank == 3 && is_rmem<TA>::value &&
                          BLayout::rank == 3 && is_rmem<TB>::value &&
                          CLayout::rank == 3 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA>       const& mma,
     Tensor<TD, DLayout>      & D,  // (V,M,N) Logical data
     Tensor<TA, ALayout> const& A,  // (V,M,K) Logical data
     Tensor<TB, BLayout> const& B,  // (V,N,K) Logical data
     Tensor<TC, CLayout> const& C)  // (V,M,N) Logical data
{
  CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C));  // AM == CM
  CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C));  // BN == CN
  CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B));  // AK == BK
  CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));
  auto K = size<2>(A);

  CUTE_UNROLL
  for (int k = 0; k < K; ++k) {
    gemm(mma, D, A(_,_,k), B(_,_,k), C);
  }
}

In the k-loop, each iteration computes D = A_k x B_k + C, so the result of the last iteration overwrites the partial results of all previous ones.
For example, the following code

  auto tA = make_tensor<int>(make_layout(make_shape(_1{}, _1{}, _2{}))); // V=1, M=1, K=2
  auto tB = make_tensor<int>(make_layout(make_shape(_1{}, _1{}, _2{}))); // V=1, N=1, K=2
  auto tC = make_tensor<int>(make_layout(make_shape(_1{}, _1{}, _1{}))); // V=1, M=1, N=1
  auto tD = make_tensor<int>(make_layout(make_shape(_1{}, _1{}, _1{}))); // V=1, M=1, N=1

  fill(tA, 1); // A = [1, 1]
  fill(tB, 1); // B = [1, 1]
  fill(tC, 10); // C = [10]

  gemm(tD, tA, tB, tC);
  print_tensor(tD); // should be 1 x 1 + 1 x 1 + 10 = 12

will get

ptr[32b](0x7fff402f0840) o (_1,_1,_1):(_0,_0,_0):
    11

instead of 12.

This is only correct when C and D alias the same registers, so that the result actually accumulates. Is aliasing C and D a requirement for calling this function (cute::gemm dispatch-5)?

@sjfeng1999
Contributor Author

#1618 @thakkarV

@thakkarV
Collaborator

@ccecka, can you please take a look at this one? We discussed it offline, and it does seem legitimate.

@ccecka

ccecka commented Jul 15, 2024

I agree with this MR and believe it has no performance cost. Let's reopen and approve it.

@sjfeng1999
Contributor Author

I believe there is no conditional branch in the final assembly after the loop is fully unrolled.

@sjfeng1999
Contributor Author

Would you mind merging PR #1618 to fix this problem? It seems I don't have permission to reopen it.


This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.


This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.
