From 58af37eca4a4f2967e9b5216320397e8df46ea48 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 13 Jan 2021 21:35:30 -0800 Subject: [PATCH] Control dependency flowing between devices should have a higher priority virtual channel than data flows, otherwise there is a risk of control transfer getting stuck behind a large data transfer. Current channel device implementation is simple, it doesn't model such prioritization. Instead of adding complexity to channel device, tf-sim could let control deps bypass channel device and (magically) flow across devices in zero time. Control deps are small, latency dominated transfers. Channel devices are only good at modeling BW bound transfers, so bypassing channel device for control deps should not be so bad. PiperOrigin-RevId: 351727352 Change-Id: Ibe20b9018427e30cffe9e7fc1cb3713ebe47510b --- tensorflow/core/grappler/costs/virtual_scheduler.cc | 7 ++++++- .../core/grappler/costs/virtual_scheduler_test.cc | 10 ++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index c533dc6468f..089c24e18a2 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -498,7 +498,12 @@ Status SchedulerState::Init(const GrapplerItem* item, const string in_device = DeviceName(input_node); const auto input_node_port_num = NodePosition(input_node_name); - if (curr_node_device == in_device) { + // Control dependencies should be treated as high priority. Current + // Channel device doesn't model a separate virual channel for control v/s + // data transfers. So in the interim, it may be okay to let control + // dependencies magically flow across devices bypassing the channel + // device. + if (curr_node_device == in_device || IsControlInput(input_node_name)) { // Same device: connect input_node and curr_node directly. curr_node_state.inputs.push_back( std::make_pair(input_node, input_node_port_num)); diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index 2da15512181..c22a6ed9419 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -2982,9 +2982,10 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { // Same number of _Send and _Recv. EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv)); - // Expect 4 Send and Recvs each: port 0, 1, and, 2, and control dependency. - EXPECT_EQ(op_count.at(kRecv), 4); - EXPECT_EQ(op_count.at(kSend), 4); + // Expect 3 Send and Recvs each: port 0, 1, and, 2. + // Control dependency bypasses the channel. + EXPECT_EQ(op_count.at(kRecv), 3); + EXPECT_EQ(op_count.at(kSend), 3); // Helper lambda for extracting output Tensor size. auto get_output_size = [this, ops_executed](const string& name) -> int64 { @@ -3006,9 +3007,6 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_); EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_); EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_); - // Control dependency size is 4B. - EXPECT_EQ(get_output_size(recv_op_names[-1]), 4); - EXPECT_EQ(get_output_size(send_op_names[-1]), 4); } TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {