diff --git a/include/ur_api.h b/include/ur_api.h index ca58a7ac66..41fa62b25f 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -1054,6 +1054,8 @@ typedef enum ur_adapter_backend_t { /// + `NULL == phAdapters` /// - ::UR_RESULT_ERROR_INVALID_SIZE /// + `NumEntries == 0 && phPlatforms != NULL` +/// - ::UR_RESULT_ERROR_INVALID_VALUE +/// + `pNumPlatforms == NULL && phPlatforms == NULL` UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet( ur_adapter_handle_t *phAdapters, ///< [in][range(0, NumAdapters)] array of adapters to query for platforms. diff --git a/scripts/core/platform.yml b/scripts/core/platform.yml index a1aa0dc7ca..997f4918ee 100644 --- a/scripts/core/platform.yml +++ b/scripts/core/platform.yml @@ -48,6 +48,8 @@ params: returns: - $X_RESULT_ERROR_INVALID_SIZE: - "`NumEntries == 0 && phPlatforms != NULL`" + - $X_RESULT_ERROR_INVALID_VALUE: + - "`pNumPlatforms == NULL && phPlatforms == NULL`" --- #-------------------------------------------------------------------------- type: enum desc: "Supported platform info" diff --git a/source/adapters/opencl/platform.cpp b/source/adapters/opencl/platform.cpp index 9fa5025196..218a5e7f00 100644 --- a/source/adapters/opencl/platform.cpp +++ b/source/adapters/opencl/platform.cpp @@ -96,6 +96,11 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, } } + /* INVALID_VALUE is returned when the size is invalid, special case it here */ + if (Result == CL_INVALID_VALUE && phPlatforms != nullptr && NumEntries == 0) { + return UR_RESULT_ERROR_INVALID_SIZE; + } + return mapCLErrorToUR(Result); } diff --git a/source/loader/layers/validation/ur_valddi.cpp b/source/loader/layers/validation/ur_valddi.cpp index 748f40638e..ef7bb019ea 100644 --- a/source/loader/layers/validation/ur_valddi.cpp +++ b/source/loader/layers/validation/ur_valddi.cpp @@ -234,6 +234,10 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGet( if (NumEntries == 0 && phPlatforms != NULL) { return UR_RESULT_ERROR_INVALID_SIZE; } + + if (pNumPlatforms == NULL && phPlatforms == NULL) { + return UR_RESULT_ERROR_INVALID_VALUE; + } } ur_result_t result = diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp index f1044ea3af..574e81103c 100644 --- a/source/loader/ur_libapi.cpp +++ b/source/loader/ur_libapi.cpp @@ -509,6 +509,8 @@ ur_result_t UR_APICALL urAdapterGetInfo( /// + `NULL == phAdapters` /// - ::UR_RESULT_ERROR_INVALID_SIZE /// + `NumEntries == 0 && phPlatforms != NULL` +/// - ::UR_RESULT_ERROR_INVALID_VALUE +/// + `pNumPlatforms == NULL && phPlatforms == NULL` ur_result_t UR_APICALL urPlatformGet( ur_adapter_handle_t * phAdapters, ///< [in][range(0, NumAdapters)] array of adapters to query for platforms. diff --git a/source/ur_api.cpp b/source/ur_api.cpp index 0b3d7f20bc..79aadc6090 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -460,6 +460,8 @@ ur_result_t UR_APICALL urAdapterGetInfo( /// + `NULL == phAdapters` /// - ::UR_RESULT_ERROR_INVALID_SIZE /// + `NumEntries == 0 && phPlatforms != NULL` +/// - ::UR_RESULT_ERROR_INVALID_VALUE +/// + `pNumPlatforms == NULL && phPlatforms == NULL` ur_result_t UR_APICALL urPlatformGet( ur_adapter_handle_t * phAdapters, ///< [in][range(0, NumAdapters)] array of adapters to query for platforms. diff --git a/test/conformance/platform/urPlatformGet.cpp b/test/conformance/platform/urPlatformGet.cpp index f3ac6318e9..20f12c16df 100644 --- a/test/conformance/platform/urPlatformGet.cpp +++ b/test/conformance/platform/urPlatformGet.cpp @@ -41,3 +41,10 @@ TEST_F(urPlatformGetTest, InvalidNullPointer) { static_cast(adapters.size()), 0, nullptr, &count)); } + +TEST_F(urPlatformGetTest, NullArgs) { + ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_VALUE, + urPlatformGet(adapters.data(), + static_cast(adapters.size()), 0, + nullptr, nullptr)); +}