diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index c533dc6468f..089c24e18a2 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -498,7 +498,12 @@ Status SchedulerState::Init(const GrapplerItem* item, const string in_device = DeviceName(input_node); const auto input_node_port_num = NodePosition(input_node_name); - if (curr_node_device == in_device) { + // Control dependencies should be treated as high priority. Current + // Channel device doesn't model a separate virual channel for control v/s + // data transfers. So in the interim, it may be okay to let control + // dependencies magically flow across devices bypassing the channel + // device. + if (curr_node_device == in_device || IsControlInput(input_node_name)) { // Same device: connect input_node and curr_node directly. curr_node_state.inputs.push_back( std::make_pair(input_node, input_node_port_num)); diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index 2da15512181..c22a6ed9419 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -2982,9 +2982,10 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { // Same number of _Send and _Recv. EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv)); - // Expect 4 Send and Recvs each: port 0, 1, and, 2, and control dependency. - EXPECT_EQ(op_count.at(kRecv), 4); - EXPECT_EQ(op_count.at(kSend), 4); + // Expect 3 Send and Recvs each: port 0, 1, and, 2. + // Control dependency bypasses the channel. + EXPECT_EQ(op_count.at(kRecv), 3); + EXPECT_EQ(op_count.at(kSend), 3); // Helper lambda for extracting output Tensor size. auto get_output_size = [this, ops_executed](const string& name) -> int64 { @@ -3006,9 +3007,6 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_); EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_); EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_); - // Control dependency size is 4B. - EXPECT_EQ(get_output_size(recv_op_names[-1]), 4); - EXPECT_EQ(get_output_size(send_op_names[-1]), 4); } TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {