Skip to content

Commit

Permalink
Set means to zero in online covariance GPU if assume_centered=True (#2850)
Browse files Browse the repository at this point in the history

* Set means to zero in online covariance GPU if assume_centered=True

* Refactor tests
  • Loading branch information
olegkkruglov authored Aug 28, 2024
1 parent 42fb399 commit 651fcaf
Show file tree
Hide file tree
Showing 8 changed files with 282 additions and 377 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,17 @@ result_t finalize_compute_kernel_dense_impl<Float>::operator()(const descriptor_
(homogen_table::wrap(corr.flatten(q, { corr_event }), column_count, column_count)));
}
if (desc.get_result_options().test(result_options::means)) {
auto [means, means_event] = compute_means(q, sums, rows_count_global);
result.set_means(homogen_table::wrap(means.flatten(q, { means_event }), 1, column_count));
if (!assume_centered) {
auto [means, means_event] = compute_means(q, sums, rows_count_global);
result.set_means(
homogen_table::wrap(means.flatten(q, { means_event }), 1, column_count));
}
else {
auto [zero_means, zeros_event] =
pr::ndarray<Float, 1>::zeros(q, { column_count }, sycl::usm::alloc::device);
result.set_means(
homogen_table::wrap(zero_means.flatten(q, { zeros_event }), 1, column_count));
}
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@ static partial_compute_result<Task> partial_compute(const context_gpu& ctx,
const std::int64_t column_count = data.get_column_count();
ONEDAL_ASSERT(column_count > 0);

auto assume_centered = desc.get_assume_centered();

dal::detail::check_mul_overflow(row_count, column_count);
dal::detail::check_mul_overflow(column_count, column_count);

const auto data_nd = pr::table2ndarray<Float>(q, data, sycl::usm::alloc::device);

auto [sums, sums_event] = compute_sums(q, data_nd);
auto [sums, sums_event] = compute_sums(q, data_nd, assume_centered, {});

auto [crossproduct, crossproduct_event] = compute_crossproduct(q, data_nd, { sums_event });

Expand Down
94 changes: 40 additions & 54 deletions cpp/oneapi/dal/algo/covariance/test/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,73 +23,59 @@ namespace la = te::linalg;
namespace cov = oneapi::dal::covariance;

template <typename TestType>
class covariance_batch_test : public covariance_test<TestType, covariance_batch_test<TestType>> {};
class covariance_batch_test : public covariance_test<TestType, covariance_batch_test<TestType>> {
public:
using base_t = covariance_test<TestType, covariance_batch_test<TestType>>;
using descriptor_t = typename base_t::descriptor_t;

TEMPLATE_LIST_TEST_M(covariance_batch_test,
"covariance fill_normal common flow",
"[covariance][integration][batch]",
covariance_types) {
SKIP_IF(this->not_float64_friendly());
void general_checks(const te::dataframe& input,
const te::table_id& input_table_id,
descriptor_t cov_desc) {
const table data = input.get_table(this->get_policy(), input_table_id);

const te::dataframe input =
GENERATE_DATAFRAME(te::dataframe_builder{ 4, 4 }.fill_normal(0, 1, 7777),
te::dataframe_builder{ 100, 100 }.fill_normal(0, 1, 7777),
te::dataframe_builder{ 250, 250 }.fill_normal(0, 1, 7777),
te::dataframe_builder{ 500, 100 }.fill_normal(0, 1, 7777));

// Homogen floating point type is the same as algorithm's floating point type
const auto input_data_table_id = this->get_homogen_table_id();
this->general_checks(input, input_data_table_id);
this->general_checks_assume_centered(input, input_data_table_id);
}
auto compute_result = this->compute(cov_desc, data);
this->check_compute_result(cov_desc, data, compute_result);
}
};

TEMPLATE_LIST_TEST_M(covariance_batch_test,
"covariance compute one element matrix",
"[covariance][integration][batch]",
"covariance common flow",
"[covariance][integration][online]",
covariance_types) {
SKIP_IF(this->not_float64_friendly());
const te::dataframe input =
GENERATE_DATAFRAME(te::dataframe_builder{ 1, 1 }.fill_normal(0, 1, 7777));

// Homogen floating point type is the same as algorithm's floating point type
const auto input_data_table_id = this->get_homogen_table_id();
this->general_checks(input, input_data_table_id);
}

TEMPLATE_LIST_TEST_M(covariance_batch_test,
"covariance fill_uniform common flow",
"[covariance][integration][batch]",
covariance_types) {
SKIP_IF(this->not_float64_friendly());
using Float = std::tuple_element_t<0, TestType>;
using Method = std::tuple_element_t<1, TestType>;

const bool assume_centered = GENERATE(true, false);
INFO("assume_centered=" << assume_centered);
const bool bias = GENERATE(true, false);
INFO("bias=" << bias);
const cov::result_option_id result_option =
GENERATE(covariance::result_options::means,
covariance::result_options::cov_matrix,
covariance::result_options::cor_matrix,
covariance::result_options::cor_matrix | covariance::result_options::cov_matrix,
covariance::result_options::cor_matrix | covariance::result_options::cov_matrix |
covariance::result_options::means);
INFO("result_option=" << result_option);

auto cov_desc = covariance::descriptor<Float, Method, covariance::task::compute>()
.set_result_options(result_option)
.set_assume_centered(assume_centered)
.set_bias(bias);

const te::dataframe input =
GENERATE_DATAFRAME(te::dataframe_builder{ 1000, 20 }.fill_uniform(-30, 30, 7777),
te::dataframe_builder{ 100, 10 }.fill_uniform(0, 1, 7777),
te::dataframe_builder{ 100, 10 }.fill_uniform(-10, 10, 7777),
te::dataframe_builder{ 500, 40 }.fill_uniform(-100, 100, 7777),
te::dataframe_builder{ 500, 250 }.fill_uniform(0, 1, 7777));

// Homogen floating point type is the same as algorithm's floating point type
const auto input_data_table_id = this->get_homogen_table_id();
this->general_checks(input, input_data_table_id);
this->general_checks_assume_centered(input, input_data_table_id);
}

TEMPLATE_LIST_TEST_M(covariance_batch_test,
"covariance fill_uniform nightly common flow",
"[covariance][integration][batch][nightly]",
covariance_types) {
SKIP_IF(this->not_float64_friendly());
GENERATE_DATAFRAME(te::dataframe_builder{ 100, 100 }.fill_normal(0, 1, 7777),
te::dataframe_builder{ 500, 100 }.fill_normal(0, 1, 7777),
te::dataframe_builder{ 10000, 200 }.fill_uniform(-30, 30, 7777));

const te::dataframe input =
GENERATE_DATAFRAME(te::dataframe_builder{ 5000, 20 }.fill_uniform(-30, 30, 7777),
te::dataframe_builder{ 10000, 200 }.fill_uniform(-30, 30, 7777),
te::dataframe_builder{ 1000000, 20 }.fill_uniform(-0.5, 0.5, 7777));
INFO("num_rows=" << input.get_row_count());
INFO("num_columns=" << input.get_column_count());

// Homogen floating point type is the same as algorithm's floating point type
const auto input_data_table_id = this->get_homogen_table_id();
this->general_checks(input, input_data_table_id);
this->general_checks_assume_centered(input, input_data_table_id);
this->general_checks(input, input_data_table_id, cov_desc);
}

} // namespace oneapi::dal::covariance::test
34 changes: 33 additions & 1 deletion cpp/oneapi/dal/algo/covariance/test/compute_parameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ class covariance_params_test : public covariance_test<TestType, covariance_param
}
}

/// Runs the covariance computation for a caller-supplied descriptor and
/// hands the outcome to `check_compute_result`.
///
/// @param input           Dataframe the input table is taken from.
/// @param input_table_id  Identifies which table representation to request from `input`.
/// @param cov_desc        Descriptor passed both to `compute` and to the result check.
void general_checks(const te::dataframe& input,
                    const te::table_id& input_table_id,
                    descriptor_t cov_desc) {
    // Obtain the table using this fixture's policy.
    const table input_table = input.get_table(this->get_policy(), input_table_id);

    // The same descriptor drives both the computation and its check.
    auto result = this->compute(cov_desc, input_table);
    this->check_compute_result(cov_desc, input_table, result);
}

private:
std::int64_t block_;
bool pack_as_struct_;
Expand All @@ -74,18 +83,41 @@ TEMPLATE_LIST_TEST_M(covariance_params_test,
"[covariance][params]",
covariance_types) {
SKIP_IF(this->not_float64_friendly());
using Float = std::tuple_element_t<0, TestType>;
using Method = std::tuple_element_t<1, TestType>;
const bool assume_centered = GENERATE(true, false);
INFO("assume_centered=" << assume_centered);
const bool bias = GENERATE(true, false);
INFO("bias=" << bias);
const cov::result_option_id result_option =
GENERATE(covariance::result_options::means,
covariance::result_options::cov_matrix,
covariance::result_options::cor_matrix,
covariance::result_options::cor_matrix | covariance::result_options::cov_matrix,
covariance::result_options::cor_matrix | covariance::result_options::cov_matrix |
covariance::result_options::means);
INFO("result_option=" << result_option);

auto cov_desc = covariance::descriptor<Float, Method, covariance::task::compute>()
.set_result_options(result_option)
.set_assume_centered(assume_centered)
.set_bias(bias);

const te::dataframe input =
GENERATE_DATAFRAME(te::dataframe_builder{ 500, 40 }.fill_uniform(-100, 100, 7777),
te::dataframe_builder{ 1000, 20 }.fill_uniform(-30, 30, 7777),
te::dataframe_builder{ 10000, 100 }.fill_uniform(-30, 30, 7777),
te::dataframe_builder{ 100000, 20 }.fill_uniform(1, 10, 7777));

INFO("num_rows=" << input.get_row_count());
INFO("num_columns=" << input.get_column_count());

// Homogen floating point type is the same as algorithm's floating point type
const auto input_data_table_id = this->get_homogen_table_id();

this->generate_parameters();

this->general_checks(input, input_data_table_id);
this->general_checks(input, input_data_table_id, cov_desc);
}

TEST("can dump system-related parameters") {
Expand Down
Loading

0 comments on commit 651fcaf

Please sign in to comment.