Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Race reported between Write access and Read access in fusion using async copy #3428

Open
liqiangxl opened this issue Nov 17, 2024 · 2 comments · May be fixed by #3438
Open

Race reported between Write access and Read access in fusion using async copy #3428

liqiangxl opened this issue Nov 17, 2024 · 2 comments · May be fixed by #3438
Labels
bug Something isn't working

Comments

@liqiangxl
Copy link
Collaborator

Originally found in CombinedSchedulerTest.LayerNormBackward/dtype_double_batch_216_hidden_96 using

NVFUSER_DUMP=scheduler_params,cuda_to_file NVFUSER_ENABLE=kernel_debug PYTORCH_NO_CUDA_MEMORY_CACHING=1 compute-sanitizer --tool racecheck --racecheck-detect-level info  ./nvfuser_tests --gtest_filter='CombinedSchedulerTest.LayerNormBackward/dtype_double_batch_216_hidden_96'

Can be reproduced with a simple fusion:

// Parameterized fixture: the bool selects whether every tensor is inlined
// (true) or only a subset {tv2, tv4, tv5} is (false). racecheck reports the
// hazard for the false case (gtest param /0).
using NVFuserTestInlinedCpAsyncBool = NVFuserFixtureParamTest<bool>;
// Minimal repro: two cp.async copies into shared memory (tv2, tv3), where tv3
// is broadcast and therefore read by threads that did not issue its copy (the
// generated kernel guards tv3's cp.async with threadIdx.y == 0 — see dump).
TEST_P(NVFuserTestInlinedCpAsyncBool, FusionCpAsyncRaceBcastInlined) {
  // cp.async requires sm_80 (Ampere) or newer.
  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);
  Fusion fusion;
  FusionGuard fg(&fusion);
  int m = 64, n = 96;

  // Use the parameterized value for use_async
  const bool inlined_cp_async = GetParam();
  TensorView* tv0 = makeContigTensor(2);  // 2D input -> smem tv2
  TensorView* tv1 = makeContigTensor(1);  // 1D input -> smem tv3 (broadcast later)
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // copy tv0 to shared memory tv2 (async copy, no cache hint)
  auto tv2 = set(tv0);
  tv2->setMemoryType(MemoryType::Shared);
  tv2->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::CpAsync);
  tv2->definition()->as<LoadStoreOp>()->setCacheOp(CacheOp::Unspecified);

  // copy tv1 to shared memory tv3 (async copy, no cache hint)
  auto tv3 = set(tv1);
  tv3->setMemoryType(MemoryType::Shared);
  tv3->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::CpAsync);
  tv3->definition()->as<LoadStoreOp>()->setCacheOp(CacheOp::Unspecified);

  // Broadcast tv3 over the row dimension so all threads consume it.
  auto tv4 = broadcast(tv3, {true, false});
  auto tv5 = add(tv2, tv4);
  fusion.addOutput(tv5);

  // No Race if TIDy is not used
  // for (auto tv : {tv0, tv2, tv4, tv5}) {
  //   tv->split(0, 5);
  //   tv->axis(1)->parallelize(ParallelType::BIDy);
  // }
  // Row dimension parallelized with TIDy + BIDy. Note tv1/tv3 are 1D and are
  // not in this set, so they get no TIDy axis.
  for (auto tv : {tv0, tv2, tv4, tv5}) {
    tv->split(0, 2);
    tv->split(0, 5);
    tv->axis(2)->parallelize(ParallelType::TIDy);
    tv->axis(1)->parallelize(ParallelType::BIDy);
  }
  // Column dimension: outer split of size 1 -> TIDx + Unswitch on all tensors.
  for (auto tv : {tv0, tv1, tv2, tv3, tv4, tv5}) {
    tv->split(-1, 1, false);
    tv->axis(-1)->parallelize(ParallelType::TIDx);
    tv->axis(-2)->parallelize(ParallelType::Unswitch);
  }

  fusion.printMath();
  // No Race if all tvs are inlined
  if(inlined_cp_async){
    inlineMost();
  }else{
    // Only inline tv2/tv4/tv5; tv3's cp.async stays hoisted before the main
    // loop, which is the configuration racecheck flags.
    inlineMost(std::vector<TensorView*>{tv2, tv4, tv5});
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({m, n}, options);
  at::Tensor t1 = at::randn({n}, options);

  KernelExecutor ke;
  ke.compile(&fusion, {t0, t1});
  auto cg_outputs = ke.run({t0, t1});
  testValidate(&fusion, cg_outputs, {t0, t1}, __LINE__, __FILE__);
}

INSTANTIATE_TEST_SUITE_P(
    ,
    NVFuserTestInlinedCpAsyncBool,
    ::testing::Values(false, true));

If you run this fusion on the current main branch with `NVFUSER_DUMP=cuda_to_file NVFUSER_ENABLE=kernel_debug compute-sanitizer --tool racecheck ./nvfuser_tests --gtest_filter=*FusionCpAsyncRaceBcastInlined/0`, you will get:

========= Error: Race reported between Write access at <unnamed>::nvfuser_none_f0_c0_r0_g0(<unnamed>::Tensor<float, (int)2, (int)2>, <unnamed>::Tensor<float, (int)1, (int)1>, <unnamed>::Tensor<float, (int)2, (int)2>)+0x2be0 in /opt/pytorch/nvfuser/build/__tmp_kernel_none_f0_c0_r0_g0.cu:10883
=========     and Read access at <unnamed>::nvfuser_none_f0_c0_r0_g0(<unnamed>::Tensor<float, (int)2, (int)2>, <unnamed>::Tensor<float, (int)1, (int)1>, <unnamed>::Tensor<float, (int)2, (int)2>)+0x34b0 in /opt/pytorch/nvfuser/build/__tmp_kernel_none_f0_c0_r0_g0.cu:10907 [1920 hazards]

The generated kernel is

// Generated kernel (racy configuration, /0). racecheck reports a hazard
// between the cp.async write into T3 below (line 10883 of the dump) and the
// read of T3 inside the main loop (line 10907).
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 1, 1> T1, Tensor<float, 2, 2> T5) {
  alignas(16) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  nvfuser_index_t i0;
  i0 = ceilDiv((ceilDiv(T0.logical_size[0LL], 2LL)), 5LL);
  nvfuser_index_t i1;
  i1 = 4LL * ((nvfuser_index_t)threadIdx.x);
  nvfuser_index_t i2;
  i2 = T0.logical_size[1LL] * ((nvfuser_index_t)threadIdx.y);
  nvfuser_index_t i3;
  i3 = (2LL * T0.logical_size[1LL]) * ((nvfuser_index_t)blockIdx.y);
  float* ptr4;
  ptr4 = ((T0.data + ((nvfuser_index_t)threadIdx.x)) + i2) + i3;
  nvfuser_index_t i5;
  i5 = 10LL * T0.logical_size[1LL];
  // T2 and T3 are both carved out of dynamic shared memory.
  float* T3 = reinterpret_cast<float*>(array + smem_offset + ((((2LL * T0.logical_size[1LL]) * 4LL) + 15LL) & -16LL));
  float* T2 = reinterpret_cast<float*>(array + smem_offset + 0LL);
  unsigned i6;
  i6 = (toSmem(T2) + i1) + ((4LL * T0.logical_size[1LL]) * ((nvfuser_index_t)threadIdx.y));
  nvfuser_index_t i7;
  i7 = ((nvfuser_index_t)threadIdx.x) + i2;
  nvfuser_index_t i8;
  i8 = i7 + i3;
  bool b9;
  // b9: only the threadIdx.y == 0 "row" of the block issues the copy into T3.
  b9 = ((nvfuser_index_t)threadIdx.y) == 0LL;
  nvfuser_index_t i10;
  i10 = ((nvfuser_index_t)threadIdx.y) + (2LL * ((nvfuser_index_t)blockIdx.y));
  // cp.async write into T3, guarded by threadIdx.y == 0. This is the write
  // side of the reported race.
  if (b9) {
    asm volatile(
      "{\n"
      "  .reg .pred p0; \n"
      "  setp.ne.b32 p0, %3, 0;\n"
      "  cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
      "}\n"
      :
      :"r"((uint32_t)((toSmem(T3) + i1))),
       "l"((T1.data + ((nvfuser_index_t)threadIdx.x))),
       "n"(4LL),
       "r"((uint32_t)((!b9)))
    );
  }
  // Barrier only: per the analysis in this issue, this does not wait for the
  // pending cp.async into T3 to complete.
  __syncthreads();
  #pragma unroll 1
  for(nvfuser_index_t i11 = 0LL; i11 < i0; ++i11) {
    nvfuser_index_t i12;
    i12 = i5 * i11;
    if (((i10 + (10LL * i11)) < T0.logical_size[0LL])) {
      // cp.async write into T2, issued by all active threads.
      asm volatile(
        "{\n"
        "  .reg .pred p0; \n"
        "  setp.ne.b32 p0, %3, 0;\n"
        "  cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
        "}\n"
        :
        :"r"((uint32_t)(i6)),
         "l"((ptr4 + i12)),
         "n"(4LL),
         "n"((uint32_t)(false))
      );
      float T4[1LL];
      // Waits only for copies issued by the calling thread; threads with
      // threadIdx.y != 0 issued no copy into T3 (see issue analysis).
      asm volatile("cp.async.wait_all;\n");
      // Read side of the reported race: every thread reads T3, including
      // threads that never waited on the T3 copy.
      T4[0LL]
         = T3[((nvfuser_index_t)threadIdx.x)];
      T5[(i8 + i12)]
        = T2[i7]
        + T4[0LL];
    }
  }
}

The race is between the cp.async write into T3 and the subsequent read of T3.

@liqiangxl liqiangxl added the bug Something isn't working label Nov 17, 2024
@liqiangxl
Copy link
Collaborator Author

If TIDy is not used, no race detected. The kernel is:

// Generated kernel when TIDy is NOT used. No race is reported here: the
// cp.async into T3 below is unguarded, so every thread issues it itself and
// the cp.async.wait_all inside the loop is executed by the same thread that
// issued the copy it later reads.
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 1, 1> T1, Tensor<float, 2, 2> T5) {
  alignas(16) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  nvfuser_index_t i0;
  i0 = ceilDiv(T0.logical_size[0LL], 5LL);
  nvfuser_index_t i1;
  i1 = 4LL * ((nvfuser_index_t)threadIdx.x);
  nvfuser_index_t i2;
  i2 = T0.logical_size[1LL] * ((nvfuser_index_t)blockIdx.y);
  float* ptr3;
  ptr3 = (T0.data + ((nvfuser_index_t)threadIdx.x)) + i2;
  nvfuser_index_t i4;
  i4 = 5LL * T0.logical_size[1LL];
  float* T3 = reinterpret_cast<float*>(array + smem_offset + (((T0.logical_size[1LL] * 4LL) + 15LL) & -16LL));
  float* T2 = reinterpret_cast<float*>(array + smem_offset + 0LL);
  unsigned i5;
  i5 = toSmem(T2) + i1;
  nvfuser_index_t i6;
  i6 = ((nvfuser_index_t)threadIdx.x) + i2;
  // cp.async into T3, issued unconditionally by every thread — no
  // threadIdx.y guard, unlike the racy kernel above.
  asm volatile(
    "{\n"
    "  .reg .pred p0; \n"
    "  setp.ne.b32 p0, %3, 0;\n"
    "  cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
    "}\n"
    :
    :"r"((uint32_t)((toSmem(T3) + i1))),
     "l"((T1.data + ((nvfuser_index_t)threadIdx.x))),
     "n"(4LL),
     "n"((uint32_t)(false))
  );
  #pragma unroll 1
  for(nvfuser_index_t i7 = 0LL; i7 < i0; ++i7) {
    nvfuser_index_t i8;
    i8 = i4 * i7;
    if (((((nvfuser_index_t)blockIdx.y) + (5LL * i7)) < T0.logical_size[0LL])) {
      // cp.async into T2 for the current tile.
      asm volatile(
        "{\n"
        "  .reg .pred p0; \n"
        "  setp.ne.b32 p0, %3, 0;\n"
        "  cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
        "}\n"
        :
        :"r"((uint32_t)(i5)),
         "l"((ptr3 + i8)),
         "n"(4LL),
         "n"((uint32_t)(false))
      );
      float T4[1LL];
      // The waiting thread is the same thread that issued both copies, so
      // the subsequent reads of T3 and T2 are safe.
      asm volatile("cp.async.wait_all;\n");
      T4[0LL]
         = T3[((nvfuser_index_t)threadIdx.x)];
      T5[(i6 + i8)]
        = T2[((nvfuser_index_t)threadIdx.x)]
        + T4[0LL];
    }
  }
}

@liqiangxl
Copy link
Collaborator Author

liqiangxl commented Nov 18, 2024

A potential explanation:

// Simplified pseudocode of the racy kernel above.
__shared__ float T2[blockDim.y * blockDim.x];
__shared__ float T3[blockDim.x];

// (1) Perform an async copy to shared memory T3, only for threads with threadIdx.y == 0
if (threadIdx.y == 0) {
  cp.async from gmem to T3
}

// Only ensures all threads reach this sync; it does not mean the async copy
// to T3 has completed.
__syncthreads();

for (int i11 = 0; i11 < 2; ++i11) {
  // (2) Perform an async copy to shared memory T2, for all threads
  cp.async from gmem to T2

  // (3) Wait for all async copies issued by the calling thread to complete.
  // For threads not participating in a copy, there is no need to wait (my
  // guess). For example, threads with threadIdx.y != 0 don't participate in
  // the async copy to T3, so they don't wait for the copy to T3 to finish.
  asm volatile("cp.async.wait_all;\n");

  // (4) Read from T3 and T2
  float T4[1LL];
  T4[0] = T3[threadIdx.x];  // Potential race here
  T5[0] = T2[i7] + T4[0];
}

The asynchronous copy to T3 (step 1) is performed only by threads with threadIdx.y == 0, but T3 is accessed by all threads in the block (step 4). If threadIdx.y != 0 threads do not wait for the async copy to T3 to complete, they might read from T3 before the data has been fully written, leading to a race condition.

Adding an additional `asm volatile("cp.async.wait_all;\n");` before `__syncthreads();` removes the race. Why? Because all threads with `threadIdx.y == 0` then wait until their async copy completes before executing `__syncthreads()`, so after `__syncthreads()` all threads can read from shared memory free of the race.

@liqiangxl liqiangxl linked a pull request Nov 18, 2024 that will close this issue
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging a pull request may close this issue.

1 participant