From b9eb5da9a77cb6f27a94451ac81732a52db912f2 Mon Sep 17 00:00:00 2001 From: Ben Tracy Date: Tue, 29 Oct 2024 15:34:11 +0000 Subject: [PATCH] [CMDBUF] Fix incorrect handling of shared local mem args in CUDA/HIP - Fix handling of local mem args in CUDA/HIP - Add conformance tests which check updating local memory args and work size --- source/adapters/cuda/command_buffer.cpp | 7 +- source/adapters/hip/command_buffer.cpp | 7 +- test/conformance/device_code/CMakeLists.txt | 1 + .../device_code/saxpy_usm_local_mem.cpp | 30 ++ .../exp_command_buffer/CMakeLists.txt | 1 + ...xp_command_buffer_adapter_native_cpu.match | 4 + .../update/local_memory_update.cpp | 484 ++++++++++++++++++ 7 files changed, 532 insertions(+), 2 deletions(-) create mode 100644 test/conformance/device_code/saxpy_usm_local_mem.cpp create mode 100644 test/conformance/exp_command_buffer/update/local_memory_update.cpp diff --git a/source/adapters/cuda/command_buffer.cpp b/source/adapters/cuda/command_buffer.cpp index 2029903c92..527c339783 100644 --- a/source/adapters/cuda/command_buffer.cpp +++ b/source/adapters/cuda/command_buffer.cpp @@ -1304,7 +1304,12 @@ updateKernelArguments(kernel_command_handle *Command, ur_result_t Result = UR_RESULT_SUCCESS; try { - Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue); + // Local memory args are passed as value args with nullptr value + if (ArgValue) { + Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue); + } else { + Kernel->setKernelLocalArg(ArgIndex, ArgSize); + } } catch (ur_result_t Err) { Result = Err; return Result; diff --git a/source/adapters/hip/command_buffer.cpp b/source/adapters/hip/command_buffer.cpp index afd15c1bd4..9fed5db2f8 100644 --- a/source/adapters/hip/command_buffer.cpp +++ b/source/adapters/hip/command_buffer.cpp @@ -1013,7 +1013,12 @@ updateKernelArguments(ur_exp_command_buffer_command_handle_t Command, const void *ArgValue = ValueArgDesc.pNewValueArg; try { - Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue); + // Local memory args are passed as value args with nullptr value + if (ArgValue) { + Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue); + } else { + Kernel->setKernelLocalArg(ArgIndex, ArgSize); + } } catch (ur_result_t Err) { return Err; } diff --git a/test/conformance/device_code/CMakeLists.txt b/test/conformance/device_code/CMakeLists.txt index 2120d26bf3..1621b01544 100644 --- a/test/conformance/device_code/CMakeLists.txt +++ b/test/conformance/device_code/CMakeLists.txt @@ -162,6 +162,7 @@ add_device_binary(${CMAKE_CURRENT_SOURCE_DIR}/sequence.cpp) add_device_binary(${CMAKE_CURRENT_SOURCE_DIR}/standard_types.cpp) add_device_binary(${CMAKE_CURRENT_SOURCE_DIR}/subgroup.cpp) add_device_binary(${CMAKE_CURRENT_SOURCE_DIR}/linker_error.cpp) +add_device_binary(${CMAKE_CURRENT_SOURCE_DIR}/saxpy_usm_local_mem.cpp) set(KERNEL_HEADER ${UR_CONFORMANCE_DEVICE_BINARIES_DIR}/kernel_entry_points.h) add_custom_command(OUTPUT ${KERNEL_HEADER} diff --git a/test/conformance/device_code/saxpy_usm_local_mem.cpp b/test/conformance/device_code/saxpy_usm_local_mem.cpp new file mode 100644 index 0000000000..7ef17e59b5 --- /dev/null +++ b/test/conformance/device_code/saxpy_usm_local_mem.cpp @@ -0,0 +1,30 @@ +// Copyright (C) 2024 Intel Corporation +// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +// See LICENSE.TXT +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include + +int main() { + size_t array_size = 16; + size_t local_size = 4; + sycl::queue sycl_queue; + uint32_t *X = sycl::malloc_shared(array_size, sycl_queue); + uint32_t *Y = sycl::malloc_shared(array_size, sycl_queue); + uint32_t *Z = sycl::malloc_shared(array_size, sycl_queue); + uint32_t A = 42; + + sycl_queue.submit([&](sycl::handler &cgh) { + sycl::local_accessor local_mem(local_size, cgh); + cgh.parallel_for( + sycl::nd_range<1>{{array_size}, {local_size}}, + [=](sycl::nd_item<1> itemId) { + auto i = itemId.get_global_linear_id(); + auto local_id = itemId.get_local_linear_id(); + local_mem[local_id] = i; + Z[i] = A * X[i] + Y[i] + local_mem[local_id] + + itemId.get_local_range(0); + }); + }); + return 0; +} diff --git a/test/conformance/exp_command_buffer/CMakeLists.txt b/test/conformance/exp_command_buffer/CMakeLists.txt index 9845ba86b1..8b7aaa5a63 100644 --- a/test/conformance/exp_command_buffer/CMakeLists.txt +++ b/test/conformance/exp_command_buffer/CMakeLists.txt @@ -19,4 +19,5 @@ add_conformance_test_with_kernels_environment(exp_command_buffer update/usm_saxpy_kernel_update.cpp update/event_sync.cpp update/kernel_event_sync.cpp + update/local_memory_update.cpp ) diff --git a/test/conformance/exp_command_buffer/exp_command_buffer_adapter_native_cpu.match b/test/conformance/exp_command_buffer/exp_command_buffer_adapter_native_cpu.match index 2ccc267535..e6b8320def 100644 --- a/test/conformance/exp_command_buffer/exp_command_buffer_adapter_native_cpu.match +++ b/test/conformance/exp_command_buffer/exp_command_buffer_adapter_native_cpu.match @@ -36,3 +36,7 @@ {{OPT}}KernelCommandEventSyncUpdateTest.TwoWaitEvents/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}} {{OPT}}KernelCommandEventSyncUpdateTest.InvalidWaitUpdate/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}} {{OPT}}KernelCommandEventSyncUpdateTest.InvalidSignalUpdate/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}} +{{OPT}}LocalMemoryUpdateTest.UpdateParameters/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}} +{{OPT}}LocalMemoryUpdateTest.UpdateParametersAndLocalSize/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}} +{{OPT}}LocalMemoryMultiUpdateTest.UpdateParameters/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}} +{{OPT}}LocalMemoryMultiUpdateTest.UpdateWithoutBlocking/SYCL_NATIVE_CPU___SYCL_Native_CPU__{{.*}} diff --git a/test/conformance/exp_command_buffer/update/local_memory_update.cpp b/test/conformance/exp_command_buffer/update/local_memory_update.cpp new file mode 100644 index 0000000000..82e280e0f9 --- /dev/null +++ b/test/conformance/exp_command_buffer/update/local_memory_update.cpp @@ -0,0 +1,484 @@ +// Copyright (C) 2024 Intel Corporation +// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +// See LICENSE.TXT +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "../fixtures.h" +#include +#include + +// Test that updating a command-buffer with a single kernel command +// taking a local memory argument works correctly. + +struct LocalMemoryUpdateTestBase + : uur::command_buffer::urUpdatableCommandBufferExpExecutionTest { + virtual void SetUp() override { + program_name = "saxpy_usm_local_mem"; + UUR_RETURN_ON_FATAL_FAILURE( + urUpdatableCommandBufferExpExecutionTest::SetUp()); + + ur_device_usm_access_capability_flags_t shared_usm_flags; + ASSERT_SUCCESS( + uur::GetDeviceUSMSingleSharedSupport(device, shared_usm_flags)); + if (!(shared_usm_flags & UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS)) { + GTEST_SKIP() << "Shared USM is not supported."; + } + + const size_t allocation_size = + sizeof(uint32_t) * global_size * local_size; + for (auto &shared_ptr : shared_ptrs) { + ASSERT_SUCCESS(urUSMSharedAlloc(context, device, nullptr, nullptr, + allocation_size, &shared_ptr)); + ASSERT_NE(shared_ptr, nullptr); + + std::vector pattern(allocation_size); + uur::generateMemFillPattern(pattern); + std::memcpy(shared_ptr, pattern.data(), allocation_size); + } + + // Index 0 is local_mem arg + ASSERT_SUCCESS(urKernelSetArgLocal(kernel, 0, local_mem_size, nullptr)); + + // Index 1 is output + ASSERT_SUCCESS( + urKernelSetArgPointer(kernel, 1, nullptr, shared_ptrs[0])); + // Index 2 is A + ASSERT_SUCCESS(urKernelSetArgValue(kernel, 2, sizeof(A), nullptr, &A)); + // Index 3 is X + ASSERT_SUCCESS( + urKernelSetArgPointer(kernel, 3, nullptr, shared_ptrs[1])); + // Index 4 is Y + ASSERT_SUCCESS( + urKernelSetArgPointer(kernel, 4, nullptr, shared_ptrs[2])); + } + + void Validate(uint32_t *output, uint32_t *X, uint32_t *Y, uint32_t A, + size_t length, size_t local_size) { + for (size_t i = 0; i < length; i++) { + uint32_t result = A * X[i] + Y[i] + i + local_size; + ASSERT_EQ(result, output[i]); + } + } + + virtual void TearDown() override { + for (auto &shared_ptr : shared_ptrs) { + if (shared_ptr) { + EXPECT_SUCCESS(urUSMFree(context, shared_ptr)); + } + } + + UUR_RETURN_ON_FATAL_FAILURE( + urUpdatableCommandBufferExpExecutionTest::TearDown()); + } + + static constexpr size_t local_size = 4; + static constexpr size_t local_mem_size = local_size * sizeof(uint32_t); + static constexpr size_t global_size = 16; + static constexpr size_t global_offset = 0; + static constexpr size_t n_dimensions = 1; + static constexpr uint32_t A = 42; + std::array shared_ptrs = {nullptr, nullptr, nullptr, nullptr, + nullptr}; +}; + +struct LocalMemoryUpdateTest : LocalMemoryUpdateTestBase { + void SetUp() override { + UUR_RETURN_ON_FATAL_FAILURE(LocalMemoryUpdateTestBase::SetUp()); + + // Append kernel command to command-buffer and close command-buffer + ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp( + updatable_cmd_buf_handle, kernel, n_dimensions, &global_offset, + &global_size, &local_size, 0, nullptr, 0, nullptr, 0, nullptr, + nullptr, nullptr, &command_handle)); + ASSERT_NE(command_handle, nullptr); + + ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle)); + } + + void TearDown() override { + if (command_handle) { + EXPECT_SUCCESS(urCommandBufferReleaseCommandExp(command_handle)); + } + + UUR_RETURN_ON_FATAL_FAILURE(LocalMemoryUpdateTestBase::TearDown()); + } + + ur_exp_command_buffer_command_handle_t command_handle = nullptr; +}; + +UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(LocalMemoryUpdateTest); + +TEST_P(LocalMemoryUpdateTest, UpdateParameters) { + // Run command-buffer prior to update an verify output + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + uint32_t *output = (uint32_t *)shared_ptrs[0]; + uint32_t *X = (uint32_t *)shared_ptrs[1]; + uint32_t *Y = (uint32_t *)shared_ptrs[2]; + Validate(output, X, Y, A, global_size, local_size); + + // Update inputs + ur_exp_command_buffer_update_pointer_arg_desc_t new_input_descs[2]; + ur_exp_command_buffer_update_value_arg_desc_t new_value_descs[2]; + + // New local_mem at index 0 + new_value_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 0, // argIndex + local_mem_size, // argSize + nullptr, // pProperties + nullptr, // hArgValue + }; + + // New A at index 2 + uint32_t new_A = 33; + new_value_descs[1] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 2, // argIndex + sizeof(new_A), // argSize + nullptr, // pProperties + &new_A, // hArgValue + }; + + // New X at index 3 + new_input_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 3, // argIndex + nullptr, // pProperties + &shared_ptrs[3], // pArgValue + }; + + // New Y at index 4 + new_input_descs[1] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 4, // argIndex + nullptr, // pProperties + &shared_ptrs[4], // pArgValue + }; + + // Update kernel inputs + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + kernel, // hNewKernel + 0, // numNewMemObjArgs + 2, // numNewPointerArgs + 2, // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + new_input_descs, // pNewPointerArgList + new_value_descs, // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + + // Update kernel and enqueue command-buffer again + ASSERT_SUCCESS( + urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + // Verify that update occurred correctly + uint32_t *new_output = (uint32_t *)shared_ptrs[0]; + uint32_t *new_X = (uint32_t *)shared_ptrs[3]; + uint32_t *new_Y = (uint32_t *)shared_ptrs[4]; + Validate(new_output, new_X, new_Y, new_A, global_size, local_size); +} + +TEST_P(LocalMemoryUpdateTest, UpdateParametersAndLocalSize) { + // Run command-buffer prior to update an verify output + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + uint32_t *output = (uint32_t *)shared_ptrs[0]; + uint32_t *X = (uint32_t *)shared_ptrs[1]; + uint32_t *Y = (uint32_t *)shared_ptrs[2]; + Validate(output, X, Y, A, global_size, local_size); + + // Update inputs + ur_exp_command_buffer_update_pointer_arg_desc_t new_input_descs[2]; + ur_exp_command_buffer_update_value_arg_desc_t new_value_descs[2]; + + size_t new_local_size = local_size * 2; + size_t new_local_mem_size = new_local_size * sizeof(uint32_t); + // New local_mem at index 0 + new_value_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 0, // argIndex + static_cast(new_local_mem_size), // argSize + nullptr, // pProperties + nullptr, // hArgValue + }; + + // New A at index 2 + uint32_t new_A = 33; + new_value_descs[1] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 2, // argIndex + sizeof(new_A), // argSize + nullptr, // pProperties + &new_A, // hArgValue + }; + + // New X at index 3 + new_input_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 3, // argIndex + nullptr, // pProperties + &shared_ptrs[3], // pArgValue + }; + + // New Y at index 4 + new_input_descs[1] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 4, // argIndex + nullptr, // pProperties + &shared_ptrs[4], // pArgValue + }; + + // Update kernel inputs + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + kernel, // hNewKernel + 0, // numNewMemObjArgs + 2, // numNewPointerArgs + 2, // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + new_input_descs, // pNewPointerArgList + new_value_descs, // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + &new_local_size, // pNewLocalWorkSize + }; + + // Update kernel and enqueue command-buffer again + ASSERT_SUCCESS( + urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + // Verify that update occurred correctly + uint32_t *new_output = (uint32_t *)shared_ptrs[0]; + uint32_t *new_X = (uint32_t *)shared_ptrs[3]; + uint32_t *new_Y = (uint32_t *)shared_ptrs[4]; + Validate(new_output, new_X, new_Y, new_A, global_size, new_local_size); +} + +struct LocalMemoryMultiUpdateTest : LocalMemoryUpdateTestBase { + void SetUp() override { + UUR_RETURN_ON_FATAL_FAILURE(LocalMemoryUpdateTestBase::SetUp()); + + // Append kernel command to command-buffer and close command-buffer + for (unsigned node = 0; node < nodes; node++) { + // We need to set the local memory arg each time because it is + // cleared in the kernel handle after being used. + ASSERT_SUCCESS( + urKernelSetArgLocal(kernel, 0, local_mem_size, nullptr)); + ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp( + updatable_cmd_buf_handle, kernel, n_dimensions, &global_offset, + &global_size, &local_size, 0, nullptr, 0, nullptr, 0, nullptr, + nullptr, nullptr, &command_handles[node])); + ASSERT_NE(command_handles[node], nullptr); + } + + ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle)); + } + + void TearDown() override { + for (auto &handle : command_handles) { + if (handle) { + EXPECT_SUCCESS(urCommandBufferReleaseCommandExp(handle)); + } + } + UUR_RETURN_ON_FATAL_FAILURE(LocalMemoryUpdateTestBase::TearDown()); + } + + static constexpr size_t nodes = 1024; + static constexpr uint32_t A = 42; + std::array command_handles{}; +}; + +UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(LocalMemoryMultiUpdateTest); + +TEST_P(LocalMemoryMultiUpdateTest, UpdateParameters) { + // Run command-buffer prior to update an verify output + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + uint32_t *output = (uint32_t *)shared_ptrs[0]; + uint32_t *X = (uint32_t *)shared_ptrs[1]; + uint32_t *Y = (uint32_t *)shared_ptrs[2]; + Validate(output, X, Y, A, global_size, local_size); + + // Update inputs + ur_exp_command_buffer_update_pointer_arg_desc_t new_input_descs[2]; + ur_exp_command_buffer_update_value_arg_desc_t new_value_descs[2]; + + // New local_mem at index 0 + new_value_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 0, // argIndex + local_mem_size, // argSize + nullptr, // pProperties + nullptr, // hArgValue + }; + + // New A at index 2 + uint32_t new_A = 33; + new_value_descs[1] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 2, // argIndex + sizeof(new_A), // argSize + nullptr, // pProperties + &new_A, // hArgValue + }; + + // New X at index 3 + new_input_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 3, // argIndex + nullptr, // pProperties + &shared_ptrs[3], // pArgValue + }; + + // New Y at index 4 + new_input_descs[1] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 4, // argIndex + nullptr, // pProperties + &shared_ptrs[4], // pArgValue + }; + + // Update kernel inputs + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + kernel, // hNewKernel + 0, // numNewMemObjArgs + 2, // numNewPointerArgs + 2, // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + new_input_descs, // pNewPointerArgList + new_value_descs, // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + + // Update kernel and enqueue command-buffer again + for (auto &handle : command_handles) { + ASSERT_SUCCESS( + urCommandBufferUpdateKernelLaunchExp(handle, &update_desc)); + } + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + // Verify that update occurred correctly + uint32_t *new_output = (uint32_t *)shared_ptrs[0]; + uint32_t *new_X = (uint32_t *)shared_ptrs[3]; + uint32_t *new_Y = (uint32_t *)shared_ptrs[4]; + Validate(new_output, new_X, new_Y, new_A, global_size, local_size); +} + +TEST_P(LocalMemoryMultiUpdateTest, UpdateWithoutBlocking) { + // Update inputs + ur_exp_command_buffer_update_pointer_arg_desc_t new_input_descs[2]; + ur_exp_command_buffer_update_value_arg_desc_t new_value_descs[2]; + + // New local_mem at index 0 + new_value_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 0, // argIndex + local_mem_size, // argSize + nullptr, // pProperties + nullptr, // hArgValue + }; + + // New A at index 2 + uint32_t new_A = 33; + new_value_descs[1] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 2, // argIndex + sizeof(new_A), // argSize + nullptr, // pProperties + &new_A, // hArgValue + }; + + // New X at index 3 + new_input_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 3, // argIndex + nullptr, // pProperties + &shared_ptrs[3], // pArgValue + }; + + // New Y at index 4 + new_input_descs[1] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 4, // argIndex + nullptr, // pProperties + &shared_ptrs[4], // pArgValue + }; + + // Update kernel inputs + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + kernel, // hNewKernel + 0, // numNewMemObjArgs + 2, // numNewPointerArgs + 2, // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + new_input_descs, // pNewPointerArgList + new_value_descs, // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + // Enqueue without calling urQueueFinish after + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + + // Update kernel and enqueue command-buffer again + for (auto &handle : command_handles) { + ASSERT_SUCCESS( + urCommandBufferUpdateKernelLaunchExp(handle, &update_desc)); + } + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + // Verify that update occurred correctly + uint32_t *new_output = (uint32_t *)shared_ptrs[0]; + uint32_t *new_X = (uint32_t *)shared_ptrs[3]; + uint32_t *new_Y = (uint32_t *)shared_ptrs[4]; + Validate(new_output, new_X, new_Y, new_A, global_size, local_size); +}