Skip to content

Commit

Permalink
Optimize OpGraph scheduling (#182)
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang authored Dec 16, 2023
1 parent 14381d7 commit 97cd329
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 296 deletions.
116 changes: 97 additions & 19 deletions ark/sched/sched_opgraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,35 +238,83 @@ void OpGraph::recursive_merge(std::list<std::unique_ptr<OpNode>> &nodes,
}
new_boundary_nodes.emplace_back(producer);
}
OpNode *merge_candidate = nullptr;
if (boundary_node->producers.size() > 1) {
// This node has multiple producers. It cannot be merged.
seen_nodes.insert(boundary_node);
OPGRAPH_DEBUG(" multiple producers");
continue;
// This node has multiple producers. We can merge only if one
// producer depends on all other producers.
for (auto &producer : boundary_node->producers) {
bool depends_on_all = true;
for (auto &other_producer : boundary_node->producers) {
if (other_producer == producer) {
continue;
}
if (!this->depends_on(producer, other_producer)) {
depends_on_all = false;
break;
}
}
if (depends_on_all) {
merge_candidate = producer;
break;
}
}
if (merge_candidate == nullptr) {
// At least one producer does not depend on others.
// Cannot merge.
seen_nodes.insert(boundary_node);
OPGRAPH_DEBUG(" multiple producers");
continue;
}
} else {
// This node has only one producer.
merge_candidate = *(boundary_node->producers.begin());
}
// This node has only one producer.
OpNode *producer = *(boundary_node->producers.begin());
if (producer->users.size() == 0) {
if (merge_candidate->users.size() == 0) {
ERR(SchedulerError, "unexpected error: graph is incomplete");
}
if (producer->users.size() > 1) {
// The producer has multiple users. It cannot be merged.
seen_nodes.insert(boundary_node);
OPGRAPH_DEBUG(" multiple users");
continue;
if (merge_candidate->users.size() > 1) {
// The candidate has multiple users. We can merge only if all
// other users depend on the current boundary_node.
bool depends_on_one = true;
for (auto &user : merge_candidate->users) {
if (user == boundary_node) {
continue;
}
if (!this->depends_on(user, boundary_node)) {
depends_on_one = false;
break;
}
}
if (!depends_on_one) {
// At least one user does not depend on the boundary_node.
// Cannot merge.
seen_nodes.insert(boundary_node);
OPGRAPH_DEBUG(" multiple users");
continue;
}
}
// The producer has only one user. Merge the two nodes.
// The candidate has only one user. Merge the two nodes.

// Merge `boundary_node` into `producer`.
OPGRAPH_DEBUG(" merge ops: ", producer->get_name(), " -> ",
// Merge `boundary_node` into `merge_candidate`.
OPGRAPH_DEBUG(" merge: ", merge_candidate->get_name(), " -> ",
boundary_node->get_name());
auto &ops = boundary_node->ops;
producer->ops.insert(producer->ops.end(), ops.begin(), ops.end());
producer->users = boundary_node->users;
for (auto &user : producer->users) {
merge_candidate->ops.insert(merge_candidate->ops.end(), ops.begin(),
ops.end());
for (auto &user : boundary_node->users) {
user->producers.erase(boundary_node);
user->producers.insert(producer);
user->producers.insert(merge_candidate);
merge_candidate->users.insert(user);
}
for (auto &producer : boundary_node->producers) {
if (producer == merge_candidate) {
continue;
}
producer->users.erase(boundary_node);
producer->users.insert(merge_candidate);
merge_candidate->producers.insert(producer);
}
merge_candidate->users = boundary_node->users;

// Remove `boundary_node` from `nodes`.
auto it =
Expand Down Expand Up @@ -309,4 +357,34 @@ OpNode *OpGraph::break_node(OpNode *node, int op_idx) {
return new_node;
}

/// Check dependencies between two @ref OpNode.
///
/// @param node1 The first @ref OpNode.
/// @param node2 The second @ref OpNode.
/// @return True if @p node1 depends on @p node2.
bool OpGraph::depends_on(OpNode *node1, OpNode *node2) const {
if (node1 == node2) {
return false;
}
std::set<OpNode *> seen_nodes;
std::list<OpNode *> boundary_nodes;
boundary_nodes.emplace_back(node1);
while (boundary_nodes.size() > 0) {
std::list<OpNode *> new_boundary_nodes;
for (auto &boundary_node : boundary_nodes) {
if (boundary_node == node2) {
return true;
}
for (auto &producer : boundary_node->producers) {
if (seen_nodes.find(producer) != seen_nodes.end()) {
continue;
}
new_boundary_nodes.emplace_back(producer);
}
}
boundary_nodes = new_boundary_nodes;
}
return false;
}

} // namespace ark
19 changes: 13 additions & 6 deletions ark/sched/sched_opgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,23 @@ class OpGraph {
/// @return The new @ref OpNode.
OpNode *break_node(OpNode *node, int op_idx);

/// Check dependencies between two @ref OpNode.
///
/// @param node1 The first @ref OpNode.
/// @param node2 The second @ref OpNode.
/// @return True if @p node1 depends on @p node2.
bool depends_on(OpNode *node1, OpNode *node2) const;

private:
std::list<std::unique_ptr<OpNode>> nodes_storage;

void create_nodes(const Model &model);
static void recursive_rm_virt(std::list<std::unique_ptr<OpNode>> &nodes,
std::set<OpNode *> &seen_nodes,
const std::list<OpNode *> &boundary_nodes);
static void recursive_merge(std::list<std::unique_ptr<OpNode>> &nodes,
std::set<OpNode *> &seen_nodes,
const std::list<OpNode *> &boundary_nodes);
void recursive_rm_virt(std::list<std::unique_ptr<OpNode>> &nodes,
std::set<OpNode *> &seen_nodes,
const std::list<OpNode *> &boundary_nodes);
void recursive_merge(std::list<std::unique_ptr<OpNode>> &nodes,
std::set<OpNode *> &seen_nodes,
const std::list<OpNode *> &boundary_nodes);
};

} // namespace ark
Expand Down
85 changes: 41 additions & 44 deletions ark/sched/sched_opgraph_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,22 +130,19 @@ ark::unittest::State test_sched_opgraph() {

// OpNode graph (parentheses indicate a OpNode):
//
// +----------------------+
// | |
// (AddOp,) --+--> (AddOp,ReluOp,) --+--> (AddOp,)
// (AddOp,AddOp,ReluOp,AddOp,)
//

graph = ark::OpGraph(model);
UNITTEST_EQ(graph.get_nodes().size(), 3UL);
UNITTEST_EQ(graph.get_nodes().size(), 1UL);

auto nodes_iter = graph.get_nodes().begin();
node = (nodes_iter++)->get();
UNITTEST_EQ(node->get_name(), "add;add_1;relu;add_2;");
UNITTEST_EQ(node->ops[0]->outputs[0], t2);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->ops[0]->outputs[0], t3);
UNITTEST_EQ(node->ops[1]->outputs[0], t4);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->ops[0]->outputs[0], t5);
UNITTEST_EQ(node->ops[1]->outputs[0], t3);
UNITTEST_EQ(node->ops[2]->outputs[0], t4);
UNITTEST_EQ(node->ops[3]->outputs[0], t5);

// Test an Op that uses outputs from multiple previous Ops.
// Model graph (omit leftmost part):
Expand Down Expand Up @@ -175,27 +172,26 @@ ark::unittest::State test_sched_opgraph() {

// OpNode graph (parentheses indicate a OpNode):
//
// +----------------------+
// | |
// (AddOp,) --+--> (AddOp,ReluOp,) --+--> (AddOp,) --+
// |
// (AddOp,) --+--> (AddOp,)
// (AddOp,AddOp,ReluOp,AddOp,) --+
// |
// (AddOp,) --+--> (AddOp,)
//

graph = ark::OpGraph(model);
UNITTEST_EQ(graph.get_nodes().size(), 5UL);
UNITTEST_EQ(graph.get_nodes().size(), 3UL);

nodes_iter = graph.get_nodes().begin();
node = (nodes_iter++)->get();
UNITTEST_EQ(node->get_name(), "add;add_1;relu;add_2;");
UNITTEST_EQ(node->ops[0]->outputs[0], t2);
UNITTEST_EQ(node->ops[1]->outputs[0], t3);
UNITTEST_EQ(node->ops[2]->outputs[0], t4);
UNITTEST_EQ(node->ops[3]->outputs[0], t5);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->ops[0]->outputs[0], t3);
UNITTEST_EQ(node->ops[1]->outputs[0], t4);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->ops[0]->outputs[0], t5);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->get_name(), "add_3;");
UNITTEST_EQ(node->ops[0]->outputs[0], t8);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->get_name(), "add_4;");
UNITTEST_EQ(node->ops[0]->outputs[0], t9);

// Test an Op that uses a single input tensor for multiple inputs.
Expand Down Expand Up @@ -231,31 +227,31 @@ ark::unittest::State test_sched_opgraph() {

// OpNode graph (parentheses indicate a OpNode):
//
// +----------------------+
// | |
// (AddOp,) --+--> (AddOp,ReluOp,) --+--> (AddOp,) --+
// |
// (AddOp,) --+--> (AddOp,)
// (AddOp,AddOp,ReluOp,AddOp,) --+
// |
// (AddOp,) --+--> (AddOp,)
//
// (AddOp,)
// (AddOp,)
//

graph = ark::OpGraph(model);
UNITTEST_EQ(graph.get_nodes().size(), 6UL);
UNITTEST_EQ(graph.get_nodes().size(), 4UL);

nodes_iter = graph.get_nodes().begin();
node = (nodes_iter++)->get();
UNITTEST_EQ(node->get_name(), "add;add_1;relu;add_2;");
UNITTEST_EQ(node->ops[0]->outputs[0], t2);
UNITTEST_EQ(node->ops[1]->outputs[0], t3);
UNITTEST_EQ(node->ops[2]->outputs[0], t4);
UNITTEST_EQ(node->ops[3]->outputs[0], t5);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->ops[0]->outputs[0], t3);
UNITTEST_EQ(node->ops[1]->outputs[0], t4);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->ops[0]->outputs[0], t5);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->get_name(), "add_3;");
UNITTEST_EQ(node->ops[0]->outputs[0], t8);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->get_name(), "add_4;");
UNITTEST_EQ(node->ops[0]->outputs[0], t9);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->get_name(), "add_5;");
UNITTEST_EQ(node->ops[0]->outputs[0], t11);

// Test using previous Ops' outputs from multiple different Ops.
Expand Down Expand Up @@ -290,33 +286,34 @@ ark::unittest::State test_sched_opgraph() {

// OpNode graph (parentheses indicate a OpNode):
//
// +----------------------+
// | |
// (AddOp,) --+--> (AddOp,ReluOp,) --+--> (AddOp,) --+--> (AddOp,)
// |
// (AddOp,) --+--> (AddOp,)
// (AddOp,AddOp,ReluOp,AddOp,) --+--> (AddOp,)
// |
// (AddOp,) --+--> (AddOp,)
//
// (AddOp,)
// (AddOp,)
//

graph = ark::OpGraph(model);
UNITTEST_EQ(graph.get_nodes().size(), 7UL);
UNITTEST_EQ(graph.get_nodes().size(), 5UL);

nodes_iter = graph.get_nodes().begin();
node = (nodes_iter++)->get();
UNITTEST_EQ(node->get_name(), "add;add_1;relu;add_2;");
UNITTEST_EQ(node->ops[0]->outputs[0], t2);
UNITTEST_EQ(node->ops[1]->outputs[0], t3);
UNITTEST_EQ(node->ops[2]->outputs[0], t4);
UNITTEST_EQ(node->ops[3]->outputs[0], t5);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->ops[0]->outputs[0], t3);
UNITTEST_EQ(node->ops[1]->outputs[0], t4);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->ops[0]->outputs[0], t5);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->get_name(), "add_3;");
UNITTEST_EQ(node->ops[0]->outputs[0], t8);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->get_name(), "add_4;");
UNITTEST_EQ(node->ops[0]->outputs[0], t9);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->get_name(), "add_5;");
UNITTEST_EQ(node->ops[0]->outputs[0], t11);
node = (nodes_iter++)->get();
UNITTEST_EQ(node->get_name(), "add_6;");
UNITTEST_EQ(node->ops[0]->outputs[0], t12);

return ark::unittest::SUCCESS;
Expand Down
Loading

0 comments on commit 97cd329

Please sign in to comment.