From 4ea51ae99cdeeb7abb9bc8e5efff188c4729fca3 Mon Sep 17 00:00:00 2001 From: Andiry Xu Date: Thu, 15 Nov 2018 11:37:50 -0800 Subject: [PATCH] Change _Send/_Recv attrs in VirtualScheduler Change src_device_ to send_device and dst_device_ to recv_device. This complies with tensorflow naming, so that VirtualScheduler can handle graphs generated on inspectz with _Send/_Recv nodes and AutoGrappler does not need to remove them. PiperOrigin-RevId: 221660625 --- tensorflow/core/grappler/costs/virtual_scheduler.cc | 4 ++++ tensorflow/core/grappler/costs/virtual_scheduler.h | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 424994125f7..b9b240e72cb 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -639,6 +639,8 @@ std::pair VirtualScheduler::CreateSendRecv( send->set_device(ChannelDeviceName(from, to)); auto& send_attr = *(send->mutable_attr()); send_attr[kAttrInputSrc].set_s(input_name); + // Use input_name as tensor_name. + send_attr[kAttrTensorName].set_s(input_name); send_attr[kAttrSrcDevice].set_s(DeviceName(from)); send_attr[kAttrDstDevice].set_s(DeviceName(to)); @@ -650,6 +652,8 @@ std::pair VirtualScheduler::CreateSendRecv( recv->set_device(DeviceName(to)); auto& recv_attr = *(recv->mutable_attr()); recv_attr[kAttrInputSrc].set_s(input_name); + // Use input_name as tensor_name. + recv_attr[kAttrTensorName].set_s(input_name); // NodeState for _Send op. auto& send_node_state = GetNodeStateOrCreateIt(send); diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 89dff9686d3..92e0a887822 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -308,8 +308,9 @@ class VirtualScheduler { private: // Constants. const string kAttrInputSrc = "input_source_"; - const string kAttrSrcDevice = "src_device_"; - const string kAttrDstDevice = "dst_device_"; + const string kAttrSrcDevice = "send_device"; + const string kAttrDstDevice = "recv_device"; + const string kAttrTensorName = "tensor_name"; const string kChannelDevice = "Channel"; // Methods called from Init(). Fails if initialize_ is set.