Skip to content

Commit

Permalink
Fixed a bug with bounded memory trees. The bug randomly caused a segf…
Browse files Browse the repository at this point in the history
…ault.
  • Loading branch information
mrucker committed Oct 9, 2023
1 parent ba83f62 commit 802450d
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions vowpalwabbit/core/src/reductions/eigen_memory_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -741,13 +741,14 @@ void node_split(emt_tree& b, emt_node& cn)
cn.examples.clear();
}

void node_insert(emt_node& cn, std::unique_ptr<emt_example> ex)
void node_insert(emt_tree& b, emt_node& cn, std::unique_ptr<emt_example> ex)
{
for (auto& cn_ex : cn.examples)
{
if (cn_ex->full == ex->full) { return; }
}
cn.examples.push_back(std::move(ex));
tree_bound(b, cn.examples.back().get());
}

emt_example* node_pick(emt_tree& b, learner& base, emt_node& cn, const emt_example& ex)
Expand Down Expand Up @@ -779,16 +780,18 @@ void node_predict(emt_tree& b, learner& base, emt_node& cn, emt_example& ex, VW:
auto* closest_ex = node_pick(b, base, cn, ex);
ec.pred.multiclass = (closest_ex != nullptr) ? closest_ex->label : 0;
ec.loss = (ec.l.multi.label != ec.pred.multiclass) ? ec.weight : 0;
if (closest_ex != nullptr) {
tree_bound(b, closest_ex);
}

}

void emt_predict(emt_tree& b, learner& base, VW::example& ec)
{
b.all->feature_tweaks_config.ignore_some_linear = false;
emt_example ex(*b.all, &ec);

emt_node& cn = *tree_route(b, ex);
node_predict(b, base, cn, ex, ec);
tree_bound(b, &ex);
}

void emt_learn(emt_tree& b, learner& base, VW::example& ec)
Expand All @@ -797,10 +800,9 @@ void emt_learn(emt_tree& b, learner& base, VW::example& ec)
auto ex = VW::make_unique<emt_example>(*b.all, &ec);

emt_node& cn = *tree_route(b, *ex);
scorer_learn(b, base, cn, *ex, ec.weight);
node_predict(b, base, cn, *ex, ec); // vw learners predict and emt_learn
tree_bound(b, ex.get());
node_insert(cn, std::move(ex));
scorer_learn(b, base, cn, *ex, ec.weight);
node_insert(b, cn, std::move(ex));
node_split(b, cn);
}

Expand Down

0 comments on commit 802450d

Please sign in to comment.