diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 424994125f7..b9b240e72cb 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -639,6 +639,8 @@ std::pair VirtualScheduler::CreateSendRecv( send->set_device(ChannelDeviceName(from, to)); auto& send_attr = *(send->mutable_attr()); send_attr[kAttrInputSrc].set_s(input_name); + // Use input_name as tensor_name. + send_attr[kAttrTensorName].set_s(input_name); send_attr[kAttrSrcDevice].set_s(DeviceName(from)); send_attr[kAttrDstDevice].set_s(DeviceName(to)); @@ -650,6 +652,8 @@ std::pair VirtualScheduler::CreateSendRecv( recv->set_device(DeviceName(to)); auto& recv_attr = *(recv->mutable_attr()); recv_attr[kAttrInputSrc].set_s(input_name); + // Use input_name as tensor_name. + recv_attr[kAttrTensorName].set_s(input_name); // NodeState for _Send op. auto& send_node_state = GetNodeStateOrCreateIt(send); diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 89dff9686d3..92e0a887822 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -308,8 +308,9 @@ class VirtualScheduler { private: // Constants. const string kAttrInputSrc = "input_source_"; - const string kAttrSrcDevice = "src_device_"; - const string kAttrDstDevice = "dst_device_"; + const string kAttrSrcDevice = "send_device"; + const string kAttrDstDevice = "recv_device"; + const string kAttrTensorName = "tensor_name"; const string kChannelDevice = "Channel"; // Methods called from Init(). Fails if initialize_ is set.