diff --git a/ark/sched/sched_opgraph.cc b/ark/sched/sched_opgraph.cc
index 8c1bd17c3..114efaddf 100644
--- a/ark/sched/sched_opgraph.cc
+++ b/ark/sched/sched_opgraph.cc
@@ -238,35 +238,83 @@ void OpGraph::recursive_merge(std::list<std::unique_ptr<OpNode>> &nodes,
             }
             new_boundary_nodes.emplace_back(producer);
         }
+        OpNode *merge_candidate = nullptr;
         if (boundary_node->producers.size() > 1) {
-            // This node has multiple producers. It cannot be merged.
-            seen_nodes.insert(boundary_node);
-            OPGRAPH_DEBUG("  multiple producers");
-            continue;
+            // This node has multiple producers. We can merge only if one
+            // producer depends on all other producers.
+            for (auto &producer : boundary_node->producers) {
+                bool depends_on_all = true;
+                for (auto &other_producer : boundary_node->producers) {
+                    if (other_producer == producer) {
+                        continue;
+                    }
+                    if (!this->depends_on(producer, other_producer)) {
+                        depends_on_all = false;
+                        break;
+                    }
+                }
+                if (depends_on_all) {
+                    merge_candidate = producer;
+                    break;
+                }
+            }
+            if (merge_candidate == nullptr) {
+                // At least one producer does not depend on others.
+                // Cannot merge.
+                seen_nodes.insert(boundary_node);
+                OPGRAPH_DEBUG("  multiple producers");
+                continue;
+            }
+        } else {
+            // This node has only one producer.
+            merge_candidate = *(boundary_node->producers.begin());
         }
-        // This node has only one producer.
-        OpNode *producer = *(boundary_node->producers.begin());
-        if (producer->users.size() == 0) {
+        if (merge_candidate->users.size() == 0) {
             ERR(SchedulerError, "unexpected error: graph is incomplete");
         }
-        if (producer->users.size() > 1) {
-            // The producer has multiple users. It cannot be merged.
-            seen_nodes.insert(boundary_node);
-            OPGRAPH_DEBUG("  multiple users");
-            continue;
+        if (merge_candidate->users.size() > 1) {
+            // The candidate has multiple users. We can merge only if all
+            // other users depend on the current boundary_node.
+            bool depends_on_one = true;
+            for (auto &user : merge_candidate->users) {
+                if (user == boundary_node) {
+                    continue;
+                }
+                if (!this->depends_on(user, boundary_node)) {
+                    depends_on_one = false;
+                    break;
+                }
+            }
+            if (!depends_on_one) {
+                // At least one user does not depend on the boundary_node.
+                // Cannot merge.
+                seen_nodes.insert(boundary_node);
+                OPGRAPH_DEBUG("  multiple users");
+                continue;
+            }
         }
-        // The producer has only one user. Merge the two nodes.
+        // The merge condition is satisfied. Merge the two nodes.
 
-        // Merge `boundary_node` into `producer`.
-        OPGRAPH_DEBUG("  merge ops: ", producer->get_name(), " -> ",
+        // Merge `boundary_node` into `merge_candidate`.
+        OPGRAPH_DEBUG("  merge: ", merge_candidate->get_name(), " -> ",
                       boundary_node->get_name());
         auto &ops = boundary_node->ops;
-        producer->ops.insert(producer->ops.end(), ops.begin(), ops.end());
-        producer->users = boundary_node->users;
-        for (auto &user : producer->users) {
+        merge_candidate->ops.insert(merge_candidate->ops.end(), ops.begin(),
+                                    ops.end());
+        for (auto &user : boundary_node->users) {
             user->producers.erase(boundary_node);
-            user->producers.insert(producer);
+            user->producers.insert(merge_candidate);
+            merge_candidate->users.insert(user);
         }
+        for (auto &producer : boundary_node->producers) {
+            if (producer == merge_candidate) {
+                continue;
+            }
+            producer->users.erase(boundary_node);
+            producer->users.insert(merge_candidate);
+            merge_candidate->producers.insert(producer);
+        }
+        merge_candidate->users.erase(boundary_node);
 
         // Remove `boundary_node` from `nodes`.
         auto it =
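Note on the merge rule above: when `boundary_node` has multiple producers, it can be folded into one of them only if that producer transitively depends on every other producer; the remaining producers are then already ordered before the candidate, so the merge cannot reorder execution or create a cycle. A minimal self-contained sketch of that check, using a toy `Node` struct and invented helper names rather than ARK's `OpNode` API:

```cpp
#include <iostream>
#include <list>
#include <set>

struct Node {
    const char *name;
    std::set<Node *> producers;
};

// BFS up the producer edges: true iff `a` transitively depends on `b`.
static bool depends_on(Node *a, Node *b) {
    if (a == b) return false;
    std::set<Node *> seen;
    std::list<Node *> frontier{a};
    while (!frontier.empty()) {
        std::list<Node *> next;
        for (Node *n : frontier) {
            if (n == b) return true;
            seen.insert(n);
            for (Node *p : n->producers) {
                if (seen.count(p) == 0) next.push_back(p);
            }
        }
        frontier = next;
    }
    return false;
}

// Mirror of the new candidate selection: pick a producer that
// depends on all other producers, if one exists.
static Node *pick_merge_candidate(Node *node) {
    for (Node *p : node->producers) {
        bool dominates = true;
        for (Node *q : node->producers) {
            if (q != p && !depends_on(p, q)) {
                dominates = false;
                break;
            }
        }
        if (dominates) return p;
    }
    return nullptr;
}

int main() {
    // The diamond from the updated unit test: add_2 reads both add's
    // output and relu's output, and relu itself reads add's output.
    Node add{"add"}, relu{"relu"}, add_2{"add_2"};
    relu.producers = {&add};
    add_2.producers = {&add, &relu};
    Node *c = pick_merge_candidate(&add_2);
    std::cout << (c ? c->name : "none") << "\n";  // prints "relu"
}
```

On the test graph this selects `relu` (it depends on `add`, the only other producer), which is why the four ops collapse into a single node in the updated tests below.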
@@ -309,4 +357,35 @@ OpNode *OpGraph::break_node(OpNode *node, int op_idx) {
     return new_node;
 }
 
+/// Check dependencies between two @ref OpNode.
+///
+/// @param node1 The first @ref OpNode.
+/// @param node2 The second @ref OpNode.
+/// @return True if @p node1 depends on @p node2.
+bool OpGraph::depends_on(OpNode *node1, OpNode *node2) const {
+    if (node1 == node2) {
+        return false;
+    }
+    std::set<OpNode *> seen_nodes;
+    std::list<OpNode *> boundary_nodes;
+    boundary_nodes.emplace_back(node1);
+    while (boundary_nodes.size() > 0) {
+        std::list<OpNode *> new_boundary_nodes;
+        for (auto &boundary_node : boundary_nodes) {
+            if (boundary_node == node2) {
+                return true;
+            }
+            seen_nodes.insert(boundary_node);
+            for (auto &producer : boundary_node->producers) {
+                if (seen_nodes.find(producer) != seen_nodes.end()) {
+                    continue;
+                }
+                new_boundary_nodes.emplace_back(producer);
+            }
+        }
+        boundary_nodes = new_boundary_nodes;
+    }
+    return false;
+}
+
 }  // namespace ark
diff --git a/ark/sched/sched_opgraph.h b/ark/sched/sched_opgraph.h
index e828f8c31..b85436fc4 100644
--- a/ark/sched/sched_opgraph.h
+++ b/ark/sched/sched_opgraph.h
@@ -84,16 +84,23 @@ class OpGraph {
     /// @return The new @ref OpNode.
     OpNode *break_node(OpNode *node, int op_idx);
 
+    /// Check dependencies between two @ref OpNode.
+    ///
+    /// @param node1 The first @ref OpNode.
+    /// @param node2 The second @ref OpNode.
+    /// @return True if @p node1 depends on @p node2.
+    bool depends_on(OpNode *node1, OpNode *node2) const;
+
   private:
     std::list<std::unique_ptr<OpNode>> nodes_storage;
 
     void create_nodes(const Model &model);
-    static void recursive_rm_virt(std::list<std::unique_ptr<OpNode>> &nodes,
-                                  std::set<OpNode *> &seen_nodes,
-                                  const std::list<OpNode *> &boundary_nodes);
-    static void recursive_merge(std::list<std::unique_ptr<OpNode>> &nodes,
-                                std::set<OpNode *> &seen_nodes,
-                                const std::list<OpNode *> &boundary_nodes);
+    void recursive_rm_virt(std::list<std::unique_ptr<OpNode>> &nodes,
+                           std::set<OpNode *> &seen_nodes,
+                           const std::list<OpNode *> &boundary_nodes);
+    void recursive_merge(std::list<std::unique_ptr<OpNode>> &nodes,
+                         std::set<OpNode *> &seen_nodes,
+                         const std::list<OpNode *> &boundary_nodes);
 };
 
 }  // namespace ark
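A note on cost: `depends_on` is a breadth-first walk up the producer edges, so each query is O(V + E) in the worst case, and `recursive_merge` may issue one query per ordered pair of producers of a boundary node. If this ever shows up in profiles on large graphs, results could be memoized between queries. A hypothetical sketch of such a wrapper; `DepCache` is invented for illustration, is not part of this change, and its cache must be discarded after any merge, since merging mutates the node set:

```cpp
#include <map>
#include <utility>

#include "sched/sched_opgraph.h"  // assumed include path for OpGraph/OpNode

namespace ark {

// Hypothetical memoizing wrapper around OpGraph::depends_on().
// Only valid while the graph is not mutated between queries.
class DepCache {
   public:
    explicit DepCache(const OpGraph &graph) : graph_(graph) {}

    bool depends_on(OpNode *a, OpNode *b) {
        auto key = std::make_pair(a, b);
        auto it = cache_.find(key);
        if (it == cache_.end()) {
            it = cache_.emplace(key, graph_.depends_on(a, b)).first;
        }
        return it->second;
    }

   private:
    const OpGraph &graph_;
    std::map<std::pair<OpNode *, OpNode *>, bool> cache_;
};

}  // namespace ark
```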
diff --git a/ark/sched/sched_opgraph_test.cc b/ark/sched/sched_opgraph_test.cc
index 54299f2d0..4fbd669eb 100644
--- a/ark/sched/sched_opgraph_test.cc
+++ b/ark/sched/sched_opgraph_test.cc
@@ -130,22 +130,19 @@ ark::unittest::State test_sched_opgraph() {
     // OpNode graph (parentheses indicate a OpNode):
     //
-    //              +----------------------+
-    //              |                      |
-    //   (AddOp,) --+--> (AddOp,ReluOp,) --+--> (AddOp,)
+    //   (AddOp,AddOp,ReluOp,AddOp,)
     //
     graph = ark::OpGraph(model);
-    UNITTEST_EQ(graph.get_nodes().size(), 3UL);
+    UNITTEST_EQ(graph.get_nodes().size(), 1UL);
     auto nodes_iter = graph.get_nodes().begin();
     node = (nodes_iter++)->get();
+    UNITTEST_EQ(node->get_name(), "add;add_1;relu;add_2;");
     UNITTEST_EQ(node->ops[0]->outputs[0], t2);
-    node = (nodes_iter++)->get();
-    UNITTEST_EQ(node->ops[0]->outputs[0], t3);
-    UNITTEST_EQ(node->ops[1]->outputs[0], t4);
-    node = (nodes_iter++)->get();
-    UNITTEST_EQ(node->ops[0]->outputs[0], t5);
+    UNITTEST_EQ(node->ops[1]->outputs[0], t3);
+    UNITTEST_EQ(node->ops[2]->outputs[0], t4);
+    UNITTEST_EQ(node->ops[3]->outputs[0], t5);
 
     // Test an Op that uses outputs from multiple previous Ops.
 
     // Model graph (omit leftmost part):
@@ -175,27 +172,26 @@ ark::unittest::State test_sched_opgraph() {
     // OpNode graph (parentheses indicate a OpNode):
     //
-    //              +----------------------+
-    //              |                      |
-    //   (AddOp,) --+--> (AddOp,ReluOp,) --+--> (AddOp,) --+
-    //                                                     |
-    //   (AddOp,) --+--> (AddOp,)
+    //   (AddOp,AddOp,ReluOp,AddOp,) --+
+    //                                 |
+    //   (AddOp,) --+--> (AddOp,)
     //
     graph = ark::OpGraph(model);
-    UNITTEST_EQ(graph.get_nodes().size(), 5UL);
+    UNITTEST_EQ(graph.get_nodes().size(), 3UL);
     nodes_iter = graph.get_nodes().begin();
     node = (nodes_iter++)->get();
+    UNITTEST_EQ(node->get_name(), "add;add_1;relu;add_2;");
     UNITTEST_EQ(node->ops[0]->outputs[0], t2);
+    UNITTEST_EQ(node->ops[1]->outputs[0], t3);
+    UNITTEST_EQ(node->ops[2]->outputs[0], t4);
+    UNITTEST_EQ(node->ops[3]->outputs[0], t5);
     node = (nodes_iter++)->get();
-    UNITTEST_EQ(node->ops[0]->outputs[0], t3);
-    UNITTEST_EQ(node->ops[1]->outputs[0], t4);
-    node = (nodes_iter++)->get();
-    UNITTEST_EQ(node->ops[0]->outputs[0], t5);
-    node = (nodes_iter++)->get();
+    UNITTEST_EQ(node->get_name(), "add_3;");
     UNITTEST_EQ(node->ops[0]->outputs[0], t8);
     node = (nodes_iter++)->get();
+    UNITTEST_EQ(node->get_name(), "add_4;");
     UNITTEST_EQ(node->ops[0]->outputs[0], t9);
 
     // Test an Op that uses a single input tensor for multiple inputs.
 
@@ -231,31 +227,31 @@ ark::unittest::State test_sched_opgraph() {
     // OpNode graph (parentheses indicate a OpNode):
     //
-    //              +----------------------+
-    //              |                      |
-    //   (AddOp,) --+--> (AddOp,ReluOp,) --+--> (AddOp,) --+
-    //                                                     |
-    //   (AddOp,) --+--> (AddOp,)
+    //   (AddOp,AddOp,ReluOp,AddOp,) --+
+    //                                 |
+    //   (AddOp,) --+--> (AddOp,)
     //
-    //   (AddOp,)
+    //   (AddOp,)
     //
     graph = ark::OpGraph(model);
-    UNITTEST_EQ(graph.get_nodes().size(), 6UL);
+    UNITTEST_EQ(graph.get_nodes().size(), 4UL);
     nodes_iter = graph.get_nodes().begin();
     node = (nodes_iter++)->get();
+    UNITTEST_EQ(node->get_name(), "add;add_1;relu;add_2;");
     UNITTEST_EQ(node->ops[0]->outputs[0], t2);
+    UNITTEST_EQ(node->ops[1]->outputs[0], t3);
+    UNITTEST_EQ(node->ops[2]->outputs[0], t4);
+    UNITTEST_EQ(node->ops[3]->outputs[0], t5);
     node = (nodes_iter++)->get();
-    UNITTEST_EQ(node->ops[0]->outputs[0], t3);
-    UNITTEST_EQ(node->ops[1]->outputs[0], t4);
-    node = (nodes_iter++)->get();
-    UNITTEST_EQ(node->ops[0]->outputs[0], t5);
-    node = (nodes_iter++)->get();
+    UNITTEST_EQ(node->get_name(), "add_3;");
     UNITTEST_EQ(node->ops[0]->outputs[0], t8);
     node = (nodes_iter++)->get();
+    UNITTEST_EQ(node->get_name(), "add_4;");
     UNITTEST_EQ(node->ops[0]->outputs[0], t9);
     node = (nodes_iter++)->get();
+    UNITTEST_EQ(node->get_name(), "add_5;");
     UNITTEST_EQ(node->ops[0]->outputs[0], t11);
 
     // Test using previous Ops' outputs from multiple different Ops.
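The new `get_name()` assertions pin down both the merge result and the order of ops inside the merged node: the expected string "add;add_1;relu;add_2;" indicates a node's name is the concatenation of its ops' names, each followed by a semicolon. A hypothetical mirror of that convention, only to make the assertions easier to read (not ARK's actual implementation):

```cpp
#include <string>
#include <vector>

// Builds "add;add_1;relu;add_2;" from {"add", "add_1", "relu", "add_2"}.
static std::string node_name(const std::vector<std::string> &op_names) {
    std::string name;
    for (const auto &op : op_names) {
        name += op;
        name += ';';
    }
    return name;
}
```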
@@ -290,33 +286,34 @@ ark::unittest::State test_sched_opgraph() {
     // OpNode graph (parentheses indicate a OpNode):
     //
-    //              +----------------------+
-    //              |                      |
-    //   (AddOp,) --+--> (AddOp,ReluOp,) --+--> (AddOp,) --+--> (AddOp,)
-    //                                                     |
-    //   (AddOp,) --+--> (AddOp,)
+    //   (AddOp,AddOp,ReluOp,AddOp,) --+--> (AddOp,)
+    //                                 |
+    //   (AddOp,) --+--> (AddOp,)
     //
-    //   (AddOp,)
+    //   (AddOp,)
     //
     graph = ark::OpGraph(model);
-    UNITTEST_EQ(graph.get_nodes().size(), 7UL);
+    UNITTEST_EQ(graph.get_nodes().size(), 5UL);
     nodes_iter = graph.get_nodes().begin();
     node = (nodes_iter++)->get();
+    UNITTEST_EQ(node->get_name(), "add;add_1;relu;add_2;");
     UNITTEST_EQ(node->ops[0]->outputs[0], t2);
+    UNITTEST_EQ(node->ops[1]->outputs[0], t3);
+    UNITTEST_EQ(node->ops[2]->outputs[0], t4);
+    UNITTEST_EQ(node->ops[3]->outputs[0], t5);
     node = (nodes_iter++)->get();
-    UNITTEST_EQ(node->ops[0]->outputs[0], t3);
-    UNITTEST_EQ(node->ops[1]->outputs[0], t4);
-    node = (nodes_iter++)->get();
-    UNITTEST_EQ(node->ops[0]->outputs[0], t5);
-    node = (nodes_iter++)->get();
+    UNITTEST_EQ(node->get_name(), "add_3;");
     UNITTEST_EQ(node->ops[0]->outputs[0], t8);
     node = (nodes_iter++)->get();
+    UNITTEST_EQ(node->get_name(), "add_4;");
     UNITTEST_EQ(node->ops[0]->outputs[0], t9);
     node = (nodes_iter++)->get();
+    UNITTEST_EQ(node->get_name(), "add_5;");
     UNITTEST_EQ(node->ops[0]->outputs[0], t11);
     node = (nodes_iter++)->get();
+    UNITTEST_EQ(node->get_name(), "add_6;");
     UNITTEST_EQ(node->ops[0]->outputs[0], t12);
 
     return ark::unittest::SUCCESS;
diff --git a/ark/sched/sched_test.cc b/ark/sched/sched_test.cc
index 1a9d8c363..b6cd3cb00 100644
--- a/ark/sched/sched_test.cc
+++ b/ark/sched/sched_test.cc
@@ -9,230 +9,6 @@
 #include "ops/ops_test_common.h"
 #include "unittest/unittest_utils.h"
 
-using namespace std;
-using namespace ark;
-
-ark::unittest::State test_sched_mm_add() {
-    unittest::spawn_process([&]() {
-        DimType batch_size = 1;
-        DimType dim_input = 2048;
-        DimType dim_hidden = 12288;
-        TensorType dtype = FP16;
-
-        Model model;
-        Tensor *input =
-            model.tensor({batch_size, dim_input, dim_hidden}, dtype);
-        Tensor *weight = model.tensor({dim_hidden, dim_hidden}, dtype);
-
-        Tensor *mm = model.matmul(input, weight);
-        /* Tensor *mm_add = */ model.add(mm, input);
-
-        GpuMgr *mgr = get_gpu_mgr(0);
-        const GpuInfo &ginfo = mgr->get_gpu_info();
-        ark::DefaultScheduler sched{model, 0, 0, 1, 8};
-        GpuMgrCtx *ctx = sched.create_context("test_sched_mm_add");
-        sched.schedule();
-        auto codes = sched.gen_code();
-
-        GpuLoopKernel glk{"test_sched_mm_add",
-                          codes,
-                          (unsigned int)ginfo.num_sm,
-                          8,
-                          (unsigned int)ginfo.smem_block_total,
-                          "",
-                          ctx};
-        glk.compile(ginfo);
-        glk.load();
-        GpuStream stream = ctx->create_stream();
-        GpuState ret = glk.launch(stream, false);
-        UNITTEST_EQ(ret, 0);
-        int iter = 1000;
-        glk.run(iter);
-        glk.stop();
-
-        UNITTEST_LOG("test_sched_mm_add: batch_size ", batch_size,
-                     " dim_input ", dim_input, " dim_hidden ", dim_hidden,
-                     " dtype ", dtype, " elapsed ",
-                     glk.get_elapsed_msec() / (float)iter, " ms/iter");
-
-        return unittest::SUCCESS;
-    });
-    ark::unittest::wait_all_processes();
-    return unittest::SUCCESS;
-}
-
-ark::unittest::State test_scheduler_simple_mm() {
-    // Hidden dimension of the dense layer.
-    unsigned int units = 2048;
-    // Input dimension of the dense layer.
-    unsigned int in_dim = 2048;
-    // Extra dimension of the input. CHANNEL=1 for 2D inputs.
-    unsigned int channel = 2048;
-    // Batch size of the input.
-    unsigned int batch_size = 1;
-
-    Model m;
-    Tensor *input = m.tensor({batch_size, channel, in_dim}, FP16);
-    Tensor *weight = m.tensor({in_dim, units}, FP16);
-    m.matmul(input, weight);
-
-    GpuMgr *mgr = get_gpu_mgr(0);
-    const GpuInfo &ginfo = mgr->get_gpu_info();
-
-    DefaultScheduler sched{m, 0, 0, 1, 8};
-    GpuMgrCtx *ctx = sched.create_context("test_scheduler_simple_mm");
-    sched.schedule();
-    auto codes = sched.gen_code();
-
-    GpuLoopKernel glk{"test_scheduler_simple_mm",
-                      codes,
-                      (unsigned int)ginfo.num_sm,
-                      8,
-                      (unsigned int)ginfo.smem_block_total,
-                      "",
-                      ctx};
-    glk.compile(ginfo);
-    glk.load();
-
-    GpuStream stream = ctx->create_stream();
-    for (int i = 0; i < 10; ++i) {
-        GpuState ret = glk.launch(stream, false);
-        UNITTEST_EQ(ret, 0);
-        glk.run(100);
-        glk.stop();
-        UNITTEST_LOG(glk.get_elapsed_msec());
-    }
-
-    return unittest::SUCCESS;
-}
-
-Tensor *MultiheadAttention(Model *model, Tensor *input, DimType embed_dim,
-                           DimType num_heads, float dropout, TensorType dtype) {
-    // input: (batch_size, seq_len, embed_dim)
-    // output: (batch_size, seq_len, embed_dim)
-    Tensor *w_q_proj = model->tensor({embed_dim, embed_dim}, dtype);
-    Tensor *w_k_proj = model->tensor({embed_dim, embed_dim}, dtype);
-    Tensor *w_v_proj = model->tensor({embed_dim, embed_dim}, dtype);
-    Tensor *w_out_proj = model->tensor({embed_dim, embed_dim}, dtype);
-
-    Tensor *q_proj = model->matmul(input, w_q_proj);
-    Tensor *k_proj = model->matmul(input, w_k_proj);
-    Tensor *v_proj = model->matmul(input, w_v_proj);
-    Tensor *q_proj_r_t = model->reshape(
-        q_proj,
-        {input->shape[0], input->shape[1], num_heads, embed_dim / num_heads});
-    Tensor *k_proj_r_t = model->reshape(
-        k_proj,
-        {input->shape[0], input->shape[1], num_heads, embed_dim / num_heads});
-    Tensor *v_proj_r_t = model->reshape(
-        v_proj,
-        {input->shape[0], input->shape[1], num_heads, embed_dim / num_heads});
-    // Tensor *q_proj_r_t = model->transpose(q_proj_r, {0, 2, 1, 3});
-    // Tensor *k_proj_r_t = model->transpose(k_proj_r, {0, 2, 1, 3});
-    // Tensor *v_proj_r_t = model->transpose(v_proj_r, {0, 2, 1, 3});
-    q_proj_r_t = model->reshape(
-        q_proj_r_t,
-        {input->shape[0] * num_heads, input->shape[1], embed_dim / num_heads});
-    k_proj_r_t = model->reshape(
-        k_proj_r_t,
-        {input->shape[0] * num_heads, input->shape[1], embed_dim / num_heads});
-    v_proj_r_t = model->reshape(
-        v_proj_r_t,
-        {input->shape[0] * num_heads, input->shape[1], embed_dim / num_heads});
-
-    // scaled dot product
-    Tensor *attn_logits =
-        model->matmul(q_proj_r_t, k_proj_r_t, nullptr, 1, false, true);
-    Tensor *attn_logits_scaled =
-        model->scale(attn_logits, 1.0 / sqrt(embed_dim / num_heads));
-
-    // Tensor *attention = model->softmax(attn_logits_scaled, 2);
-    Tensor *attention = attn_logits_scaled;
-    Tensor *values = model->matmul(attention, v_proj_r_t);
-    // values = model->reshape(values, {input->shape[0], num_heads,
-    // input->shape[1], embed_dim / num_heads});
-
-    // Tensor *values_t = model->transpose(values, {0, 2, 1, 3});
-    Tensor *values_t_r =
-        model->reshape(values, {input->shape[0], input->shape[1], embed_dim});
-    Tensor *output = model->matmul(values_t_r, w_out_proj);
-
-    if (dropout > 0.0) {
-        // output = model->dropout(output, dropout);
-    }
-    return output;
-}
-
-Tensor *TransformerLayerForward(Model *model, Tensor *input, DimType embed_dim,
-                                DimType num_heads, DimType dim_ff,
-                                float dropout, TensorType dtype) {
-    Tensor *attn_out =
-        MultiheadAttention(model, input, embed_dim, num_heads, dropout, dtype);
-    Tensor *res = model->add(input, attn_out);
-    // res = model->layernorm(res, res);
-
-    Tensor *w_ff1 = model->tensor({embed_dim, dim_ff}, dtype);
-    Tensor *w_ff2 = model->tensor({dim_ff, embed_dim}, dtype);
-
-    Tensor *ff1 = model->matmul(res, w_ff1);
-    Tensor *ff2 = model->matmul(ff1, w_ff2);
-    Tensor *ret = model->add(res, ff2);
-    // ret = model->layernorm(ret, ret);
-    return ret;
-}
-
-Tensor *GPT3LayerForward(Model *model, Tensor *input, TensorType dtype) {
-    return TransformerLayerForward(model, input,
-                                   /*embed_dim=*/12288,
-                                   /*num_heads=*/96,
-                                   /*dim_ff=*/49152,
-                                   /*dropout=*/0.0, dtype);
-}
-
-ark::unittest::State test_sched_gpt3() {
-    Model model;
-    DimType batch_size = 1;
-    DimType seq_len = 2048;
-    DimType embed_dim = 12288;
-    TensorType dtype = FP16;
-    Tensor *input = model.tensor({batch_size, seq_len, embed_dim}, dtype);
-    GPT3LayerForward(&model, input, dtype);
-
-    unittest::spawn_process([&]() {
-        GpuMgr *mgr = get_gpu_mgr(0);
-        const GpuInfo &ginfo = mgr->get_gpu_info();
-        ark::DefaultScheduler sched{model, 0, 0, 1, 8};
-        GpuMgrCtx *ctx = sched.create_context("test_sched_gpt3");
-        sched.schedule();
-        auto codes = sched.gen_code();
-
-        GpuLoopKernel glk{"test_sched_gpt3",
-                          codes,
-                          (unsigned int)ginfo.num_sm,
-                          8,
-                          (unsigned int)ginfo.smem_block_total,
-                          "",
-                          ctx};
-        glk.compile(ginfo);
-        glk.load();
-        GpuStream stream = ctx->create_stream();
-        GpuState ret = glk.launch(stream, false);
-        UNITTEST_EQ(ret, 0);
-        int iter = 100;
-        glk.run(iter);
-        glk.stop();
-
-        UNITTEST_LOG("test_sched_gpt3: batch_size ", batch_size, " seq_len ",
-                     seq_len, " embed_dim ", embed_dim, " dtype ", dtype,
-                     " elapsed ", glk.get_elapsed_msec() / (float)iter,
-                     " ms/iter");
-
-        return unittest::SUCCESS;
-    });
-    ark::unittest::wait_all_processes();
-    return unittest::SUCCESS;
-}
-
 ark::unittest::State test_sched_many_comm_ops() {
     constexpr int num_gpus = 4;
     for (int gpu_id = 0; gpu_id < num_gpus; ++gpu_id) {
@@ -335,9 +111,6 @@ ark::unittest::State test_sched_graph_opt() {
 
 int main() {
     ark::init();
-    // UNITTEST(test_sched_mm_add);
-    // UNITTEST(test_scheduler_simple_mm);
-    // UNITTEST(test_sched_gpt3);
     UNITTEST(test_sched_many_comm_ops);
     UNITTEST(test_sched_mixed_precision);
     UNITTEST(test_sched_parallel_matmul);