Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style: forgotten lint fix #4688

Merged
merged 4 commits into from
Mar 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 42 additions & 25 deletions vowpalwabbit/core/src/reductions/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,37 @@
namespace
{
// Computes the acceptance ("coin") bias for active-learning queries.
// Implementation follows
// https://web.archive.org/web/20120525164352/http://books.nips.cc/papers/files/nips23/NIPS2010_0363.pdf
//
// example_count:             weighted number of (unlabeled) examples seen so far
// avg_loss:                  running average loss, clamped into [0, 1] below
// alt_label_error_rate_diff: estimated error-rate increase if the alternative label were true
// mellowness:                exploration parameter (the --mellowness / c0 option)
// Returns 1 when the label disagreement is within the deviation bound (always query),
// otherwise the query probability derived from the quadratic-formula solution below.
float get_active_coin_bias(float example_count, float avg_loss, float alt_label_error_rate_diff, float mellowness)
{
  // mellowness * log(e_count) / e_count, with +0.0001f guards against log(1)=0 and division by zero.
  const float mellow_log_e_count_over_e_count =
      mellowness * (std::log(example_count + 1.f) + 0.0001f) / (example_count + 0.0001f);
  const float sqrt_mellow_lecoec = std::sqrt(mellow_log_e_count_over_e_count);
  // loss should be in [0,1]
  avg_loss = VW::math::clamp(avg_loss, 0.f, 1.f);

  const float sqrt_avg_loss_plus_sqrt_alt_loss =
      std::min(1.f,  // std::sqrt(avg_loss) + // commented out because two square roots appears too conservative.
          std::sqrt(avg_loss + alt_label_error_rate_diff));  // empirical variance deflater.

  if (alt_label_error_rate_diff <= sqrt_mellow_lecoec * sqrt_avg_loss_plus_sqrt_alt_loss  // deflater in use.
          + mellow_log_e_count_over_e_count)
  {
    return 1;
  }
  // old equation
  // const float rs = (sqrt_avg_loss_plus_sqrt_alt_loss + std::sqrt(sqrt_avg_loss_plus_sqrt_alt_loss *
  // sqrt_avg_loss_plus_sqrt_alt_loss + 4 * alt_label_error_rate_diff)) / (2 * alt_label_error_rate_diff); return
  // mellow_log_e_count_over_e_count * rs * rs;
  //
  // NOTE(review): the divisor below must be parenthesized. The previous code ended in
  // "/ 2 * alt_label_error_rate_diff", which C++ parses as "(... / 2) * alt_label_error_rate_diff"
  // — multiplying by the error-rate difference instead of dividing by it. The "old equation"
  // above divides by (2 * alt_label_error_rate_diff), matching the quadratic-formula
  // denominator, so that is the intended computation.
  const float sqrt_s = (sqrt_mellow_lecoec +
                           std::sqrt(mellow_log_e_count_over_e_count +
                               4 * alt_label_error_rate_diff * mellow_log_e_count_over_e_count)) /
      (2 * alt_label_error_rate_diff);
  return sqrt_s * sqrt_s;
}

float query_decision(const active& a, float updates_to_change_prediction, float example_count)
Expand All @@ -61,8 +73,10 @@
{
// const auto weighted_queries = static_cast<float>(a._shared_data->weighted_labeled_examples);
const float avg_loss = (static_cast<float>(a._shared_data->sum_loss) / example_count);
//+ std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not following why we need it from the theory.
// std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " << a._shared_data->sum_loss << " example_count = " << example_count << std::endl;
//+ std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not
// following why we need it from the theory.
// std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " <<
// a._shared_data->sum_loss << " example_count = " << example_count << std::endl;
bias = get_active_coin_bias(example_count, avg_loss, updates_to_change_prediction / example_count, a.active_c0);
}
// std::cout << "bias = " << bias << std::endl;
Expand Down Expand Up @@ -122,31 +136,32 @@

// Learns or predicts on `ec`, then — for unlabeled examples whose tag starts with
// "query?" — replaces the prediction scalar with the active-learning query decision
// (-1 for "don't query", otherwise the importance weight). Labeled examples instead
// update the running min/max seen labels used to probe prediction sensitivity.
template <bool is_learn>
void predict_or_learn_active_direct(active& a, learner& base, VW::example& ec)
{
  if (is_learn) { base.learn(ec); }
  else { base.predict(ec); }

  if (ec.l.simple.label == FLT_MAX)
  {
    // Guard the tag length before building the 6-character prefix: the previous code
    // unconditionally read up to ec.tag.begin() + 6, which walks past the end of the
    // tag (undefined behavior) whenever the tag is shorter than "query?".
    constexpr size_t QUERY_PREFIX_LEN = 6;  // length of "query?"
    if (ec.tag.size() >= QUERY_PREFIX_LEN &&
        std::string(ec.tag.begin(), ec.tag.begin() + QUERY_PREFIX_LEN) == "query?")
    {
      const float threshold = (a._shared_data->max_label + a._shared_data->min_label) * 0.5f;
      // We want to understand the change in prediction if the label were to be
      // the opposite of what was predicted. 0 and 1 are used for the expected min
      // and max labels to be coming in from the active interactor.
      ec.l.simple.label = (ec.pred.scalar >= threshold) ? a._min_seen_label : a._max_seen_label;
      ec.confidence = std::abs(ec.pred.scalar - threshold) / base.sensitivity(ec);
      // Restore the "unlabeled" marker after the sensitivity probe.
      ec.l.simple.label = FLT_MAX;
      ec.pred.scalar =
          query_decision(a, ec.confidence, static_cast<float>(a._shared_data->weighted_unlabeled_examples));
    }
  }
  else
  {
    // Update seen labels based on the current example's label.
    a._min_seen_label = std::min(ec.l.simple.label, a._min_seen_label);
    a._max_seen_label = std::max(ec.l.simple.label, a._max_seen_label);
  }
}

Check warning on line 164 in vowpalwabbit/core/src/reductions/active.cc

View check run for this annotation

Codecov / codecov/patch

vowpalwabbit/core/src/reductions/active.cc#L164

Added line #L164 was not covered by tests

void active_print_result(
VW::io::writer* f, float res, float weight, const VW::v_array<char>& tag, VW::io::logger& logger)
Expand Down Expand Up @@ -232,7 +247,9 @@
option_group_definition new_options("[Reduction] Active Learning");
new_options.add(make_option("active", active_option).keep().necessary().help("Enable active learning"))
.add(make_option("simulation", simulation).help("Active learning simulation mode"))
.add(make_option("direct", direct).help("Active learning via the tag and predictions interface. Tag should start with \"query?\" to get query decision. Returned prediction is either -1 for no or the importance weight for yes."))
.add(make_option("direct", direct)
.help("Active learning via the tag and predictions interface. Tag should start with \"query?\" to get "
"query decision. Returned prediction is either -1 for no or the importance weight for yes."))
.add(make_option("mellowness", active_c0)
.keep()
.default_value(1.f)
Expand Down
Loading