Skip to content

Commit

Permalink
Apply feedback from previous PR
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Oct 4, 2024
1 parent bc6bd6c commit af87004
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 580 deletions.
22 changes: 12 additions & 10 deletions docs/domains/sparse_linear_algebra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,21 @@ rocSPARSE backend
Currently known limitations:

- Using ``spmv`` with a ``type_view`` other than ``matrix_descr::general`` will
throw an ``oneapi::mkl::unimplemented`` exception.
- The COO format requires the indices to be sorted by row then by column. It is
not required to set the property
``oneapi::mkl::sparse::matrix_property::sorted`` to a sparse matrix handle.
See the `rocSPARSE COO documentation
throw a ``oneapi::mkl::unimplemented`` exception.
- The COO format requires the indices to be sorted by row then by column. See
the `rocSPARSE COO documentation
<https://rocm.docs.amd.com/projects/rocSPARSE/en/latest/how-to/basics.html#coo-storage-format>`_.
- The CSR format requires the column indices to be sorted within each row. It is
not required to set the property
``oneapi::mkl::sparse::matrix_property::sorted`` to a sparse matrix handle.
See the `rocSPARSE CSR documentation
Sparse operations using matrices with the COO format without the property
``matrix_property::sorted`` will throw a ``oneapi::mkl::unimplemented``
exception.
- The CSR format requires the column indices to be sorted within each row. See
the `rocSPARSE CSR documentation
<https://rocm.docs.amd.com/projects/rocSPARSE/en/latest/how-to/basics.html#csr-storage-format>`_.
Sparse operations using matrices with the CSR format without the property
``matrix_property::sorted`` will throw a ``oneapi::mkl::unimplemented``
exception.
- The same sparse matrix handle cannot be reused for multiple operations
``spmm``, ``spmv``, or ``spsv``. Doing so will throw an
``spmm``, ``spmv``, or ``spsv``. Doing so will throw a
``oneapi::mkl::unimplemented`` exception. See `#332
<https://github.com/ROCm/rocSPARSE/issues/332>`_.

Expand Down
169 changes: 93 additions & 76 deletions src/sparse_blas/backends/common_launch_task.hxx

Large diffs are not rendered by default.

385 changes: 8 additions & 377 deletions src/sparse_blas/backends/cusparse/cusparse_task.hpp

Large diffs are not rendered by default.

54 changes: 21 additions & 33 deletions src/sparse_blas/backends/rocsparse/operations/rocsparse_spmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ inline auto get_roc_spmm_alg(spmm_alg alg) {
}
}

inline void fallback_alg_if_needed(oneapi::mkl::sparse::spmm_alg& alg, oneapi::mkl::transpose opA,
oneapi::mkl::transpose opB) {
if (alg == oneapi::mkl::sparse::spmm_alg::csr_alg3 &&
(opA != oneapi::mkl::transpose::nontrans || opB == oneapi::mkl::transpose::conjtrans)) {
// Avoid warnings printed on std::cerr
alg = oneapi::mkl::sparse::spmm_alg::default_alg;
}
void check_valid_spmm(const std::string& function_name, matrix_view A_view,
matrix_handle_t A_handle, dense_matrix_handle_t B_handle,
dense_matrix_handle_t C_handle, bool is_alpha_host_accessible,
bool is_beta_host_accessible) {
detail::check_valid_spmm_common(function_name, A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible);
A_handle->check_valid_handle(function_name);
}

void spmm_buffer_size(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB,
Expand All @@ -88,10 +88,8 @@ void spmm_buffer_size(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mk
oneapi::mkl::sparse::spmm_descr_t spmm_descr, std::size_t& temp_buffer_size) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
A_handle->throw_if_already_used(__func__);
detail::check_valid_spmm_common(__func__, A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible);
fallback_alg_if_needed(alg, opA, opB);
check_valid_spmm(__func__, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible,
is_beta_host_accessible);
auto functor = [=, &temp_buffer_size](RocsparseScopedContextHandler& sc) {
auto [roc_handle, roc_stream] = sc.get_handle_and_stream(queue);
auto roc_a = A_handle->backend_handle;
Expand Down Expand Up @@ -120,9 +118,8 @@ inline void common_spmm_optimize(
oneapi::mkl::sparse::dense_matrix_handle_t B_handle, bool is_beta_host_accessible,
oneapi::mkl::sparse::dense_matrix_handle_t C_handle, oneapi::mkl::sparse::spmm_alg alg,
oneapi::mkl::sparse::spmm_descr_t spmm_descr) {
A_handle->throw_if_already_used("spmm_optimize");
detail::check_valid_spmm_common("spmm_optimize", A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible);
check_valid_spmm("spmm_optimize", A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible);
if (!spmm_descr->buffer_size_called) {
throw mkl::uninitialized("sparse_blas", "spmm_optimize",
"spmm_buffer_size must be called before spmm_optimize.");
Expand Down Expand Up @@ -156,7 +153,7 @@ void spmm_optimize_impl(rocsparse_handle roc_handle, oneapi::mkl::transpose opA,
auto status =
rocsparse_spmm(roc_handle, roc_op_a, roc_op_b, alpha, roc_a, roc_b, beta, roc_c, roc_type,
roc_alg, rocsparse_spmm_stage_preprocess, &buffer_size, workspace_ptr);
check_status(status, "optimize_spmm");
check_status(status, "spmm_optimize");
}

void spmm_optimize(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::transpose opB,
Expand All @@ -178,9 +175,9 @@ void spmm_optimize(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::
if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) {
return;
}
fallback_alg_if_needed(alg, opA, opB);
std::size_t buffer_size = spmm_descr->temp_buffer_size;

// The accessor can only be created if the buffer size is greater than 0
if (buffer_size > 0) {
auto functor = [=](RocsparseScopedContextHandler& sc,
sycl::accessor<std::uint8_t> workspace_acc) {
Expand All @@ -190,11 +187,7 @@ void spmm_optimize(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::
buffer_size, workspace_ptr, is_alpha_host_accessible);
};

// The accessor can only be bound to the cgh if the buffer size is
// greater than 0
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc(workspace);
dispatch_submit(__func__, queue, functor, A_handle, workspace_placeholder_acc, B_handle,
C_handle);
dispatch_submit(__func__, queue, functor, A_handle, workspace, B_handle, C_handle);
}
else {
auto functor = [=](RocsparseScopedContextHandler& sc) {
Expand Down Expand Up @@ -227,7 +220,6 @@ sycl::event spmm_optimize(sycl::queue& queue, oneapi::mkl::transpose opA,
if (alg == oneapi::mkl::sparse::spmm_alg::no_optimize_alg) {
return detail::collapse_dependencies(queue, dependencies);
}
fallback_alg_if_needed(alg, opA, opB);
std::size_t buffer_size = spmm_descr->temp_buffer_size;
auto functor = [=](RocsparseScopedContextHandler& sc) {
auto roc_handle = sc.get_handle(queue);
Expand All @@ -247,7 +239,6 @@ sycl::event spmm(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::tr
const std::vector<sycl::event>& dependencies) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
A_handle->throw_if_already_used(__func__);
if (A_handle->all_use_buffer() != spmm_descr->workspace.use_buffer()) {
detail::throw_incompatible_container(__func__);
}
Expand All @@ -263,10 +254,9 @@ sycl::event spmm(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::tr
CHECK_DESCR_MATCH(spmm_descr, B_handle, "spmm_optimize");
CHECK_DESCR_MATCH(spmm_descr, C_handle, "spmm_optimize");
CHECK_DESCR_MATCH(spmm_descr, alg, "spmm_optimize");
detail::check_valid_spmm_common(__func__, A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible);
check_valid_spmm(__func__, A_view, A_handle, B_handle, C_handle, is_alpha_host_accessible,
is_beta_host_accessible);
A_handle->mark_used();
fallback_alg_if_needed(alg, opA, opB);
auto& buffer_size = spmm_descr->temp_buffer_size;
auto compute_functor = [=, &buffer_size](RocsparseScopedContextHandler& sc,
void* workspace_ptr) {
Expand All @@ -287,23 +277,21 @@ sycl::event spmm(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::tr
HIP_ERROR_FUNC(hipStreamSynchronize, roc_stream);
#endif
};
// The accessor can only be created if the buffer size is greater than 0
if (A_handle->all_use_buffer() && buffer_size > 0) {
// The accessor can only be bound to the cgh if the buffer size is
// greater than 0
auto functor_buffer = [=](RocsparseScopedContextHandler& sc,
sycl::accessor<std::uint8_t> workspace_acc) {
auto workspace_ptr = sc.get_mem(workspace_acc);
compute_functor(sc, workspace_ptr);
};
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc(
spmm_descr->workspace.get_buffer<std::uint8_t>());
return dispatch_submit_native_ext(__func__, queue, functor_buffer, A_handle,
workspace_placeholder_acc, B_handle, C_handle);
spmm_descr->workspace.get_buffer<std::uint8_t>(),
B_handle, C_handle);
}
else {
// The same dispatch_submit can be used for USM or buffers if no
// workspace accessor is needed, workspace_ptr will be a nullptr in the
// latter case.
// workspace accessor is needed.
// workspace_ptr will be a nullptr in the latter case.
auto workspace_ptr = spmm_descr->workspace.usm_ptr;
auto functor_usm = [=](RocsparseScopedContextHandler& sc) {
compute_functor(sc, workspace_ptr);
Expand Down
23 changes: 9 additions & 14 deletions src/sparse_blas/backends/rocsparse/operations/rocsparse_spmv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void check_valid_spmv(const std::string &function_name, oneapi::mkl::transpose o
bool is_alpha_host_accessible, bool is_beta_host_accessible) {
detail::check_valid_spmv_common(function_name, opA, A_view, A_handle, x_handle, y_handle,
is_alpha_host_accessible, is_beta_host_accessible);
A_handle->throw_if_already_used(__func__);
A_handle->check_valid_handle(__func__);
if (A_view.type_view != oneapi::mkl::sparse::matrix_descr::general) {
throw mkl::unimplemented(
"sparse_blas", function_name,
Expand Down Expand Up @@ -157,7 +157,7 @@ void spmv_optimize_impl(rocsparse_handle roc_handle, oneapi::mkl::transpose opA,
auto status =
rocsparse_spmv(roc_handle, roc_op, alpha, roc_a, roc_x, beta, roc_y, roc_type, roc_alg,
rocsparse_spmv_stage_preprocess, &buffer_size, workspace_ptr);
check_status(status, "optimize_spmv");
check_status(status, "spmv_optimize");
}

void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha,
Expand All @@ -180,6 +180,7 @@ void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a
return;
}
std::size_t buffer_size = spmv_descr->temp_buffer_size;
// The accessor can only be created if the buffer size is greater than 0
if (buffer_size > 0) {
auto functor = [=](RocsparseScopedContextHandler &sc,
sycl::accessor<std::uint8_t> workspace_acc) {
Expand All @@ -189,11 +190,7 @@ void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a
buffer_size, workspace_ptr, is_alpha_host_accessible);
};

// The accessor can only be bound to the cgh if the buffer size is
// greater than 0
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc(workspace);
dispatch_submit(__func__, queue, functor, A_handle, workspace_placeholder_acc, x_handle,
y_handle);
dispatch_submit(__func__, queue, functor, A_handle, workspace, x_handle, y_handle);
}
else {
auto functor = [=](RocsparseScopedContextHandler &sc) {
Expand Down Expand Up @@ -282,23 +279,21 @@ sycl::event spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
HIP_ERROR_FUNC(hipStreamSynchronize, roc_stream);
#endif
};
// The accessor can only be created if the buffer size is greater than 0
if (A_handle->all_use_buffer() && buffer_size > 0) {
// The accessor can only be bound to the cgh if the buffer size is
// greater than 0
auto functor_buffer = [=](RocsparseScopedContextHandler &sc,
sycl::accessor<std::uint8_t> workspace_acc) {
auto workspace_ptr = sc.get_mem(workspace_acc);
compute_functor(sc, workspace_ptr);
};
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc(
spmv_descr->workspace.get_buffer<std::uint8_t>());
return dispatch_submit_native_ext(__func__, queue, functor_buffer, A_handle,
workspace_placeholder_acc, x_handle, y_handle);
spmv_descr->workspace.get_buffer<std::uint8_t>(),
x_handle, y_handle);
}
else {
// The same dispatch_submit can be used for USM or buffers if no
// workspace accessor is needed, workspace_ptr will be a nullptr in the
// latter case.
// workspace accessor is needed.
// workspace_ptr will be a nullptr in the latter case.
auto workspace_ptr = spmv_descr->workspace.usm_ptr;
auto functor_usm = [=](RocsparseScopedContextHandler &sc) {
compute_functor(sc, workspace_ptr);
Expand Down
Loading

0 comments on commit af87004

Please sign in to comment.