diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index 016896b36f4..ec9fe0ef688 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -1361,6 +1361,7 @@ cc_library( hdrs = ["replicate_per_replica_nodes.h"], copts = tf_copts(), deps = [ + "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc b/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc index fbae80aef55..610dc1b8835 100644 --- a/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc +++ b/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" namespace tensorflow { namespace { @@ -115,12 +116,36 @@ class ReplicateHelper { // This happens when the dst node runs on a host CPU and // captures a function with an arg node assigned to the same // composite device (e.g. ScanDataset). - // For this case, we only need to add an edge connecting the arg - // node in the outer function and the corresponding arg in the - // inner function, since the host CPU only needs one copy of the - // ResourceHandle. - graph->AddEdge(src_replicated_nodes.at(0), edge->src_output(), dst, - edge->dst_input()); + // For this case, we insert a PackOp between replicated nodes and the + // dst node. The dst node is responsible for unpacking the packed + // tensor. + // Add '/Packed' as a substring to the name of the new node, which + // could be helpful when debugging the graph. 
+ NodeDefBuilder pack_builder( + graph->NewName(absl::StrCat(edge->src()->name(), "/Packed")), + "Pack"); + const int num_replicas = src_replicated_nodes.size(); + pack_builder.Attr("N", num_replicas); + const DataType dtype = edge->src()->output_type(edge->src_output()); + pack_builder.Attr("T", dtype); + std::vector<NodeDefBuilder::NodeOut> inputs; + inputs.reserve(src_replicated_nodes.size()); + for (Node* replicated_node : src_replicated_nodes) { + inputs.emplace_back(NodeDefBuilder::NodeOut{ + replicated_node->name(), edge->src_output(), dtype}); + } + pack_builder.Input(inputs); + NodeDef pack_def; + TF_RETURN_IF_ERROR(pack_builder.Finalize(&pack_def)); + Status status; + Node* pack_node = graph->AddNode(pack_def, &status); + TF_RETURN_IF_ERROR(status); + pack_node->set_assigned_device_name(dst->assigned_device_name()); + for (int i = 0; i < src_replicated_nodes.size(); ++i) { + graph->AddEdge(src_replicated_nodes[i], edge->src_output(), + pack_node, i); + } + graph->AddEdge(pack_node, /*x=*/0, dst, edge->dst_input()); } else { return errors::InvalidArgument( "Dst node should be assigned to an allowed device. 
Found an " diff --git a/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc b/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc index db05907710c..0bf2001a955 100644 --- a/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc +++ b/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc @@ -258,16 +258,24 @@ TEST(ReplicatePerReplicaNodesTest, NestedFunctions) { ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph)); { - // _Arg(TPU:0) -> Func(CPU:0) -> _Retval(CPU:0) - EXPECT_EQ(graph.num_op_nodes(), 4); + // _Arg(TPU:0), _Arg(TPU:1) -> Pack(CPU:0) -> Func(CPU:0) -> _Retval(CPU:0) + EXPECT_EQ(graph.num_op_nodes(), 5); GraphHelper helper(graph); helper.CheckAssignedDevice("arg/R0", "TPU:0"); helper.CheckAssignedDevice("arg/R1", "TPU:1"); + helper.CheckAssignedDevice("arg/Packed", "CPU:0"); helper.CheckAssignedDevice("func", "CPU:0"); helper.CheckAssignedDevice("ret", "CPU:0"); - const EdgeSet& in_edges = helper.GetNodeByName("func")->in_edges(); - EXPECT_EQ(in_edges.size(), 1); - EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*in_edges.begin())->src()); + const EdgeSet& packed_in_edges = + helper.GetNodeByName("arg/Packed")->in_edges(); + EXPECT_EQ(packed_in_edges.size(), 2); + auto it = packed_in_edges.begin(); + EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*it++)->src()); + EXPECT_EQ(helper.GetNodeByName("arg/R1"), (*it)->src()); + const EdgeSet& func_in_edges = helper.GetNodeByName("func")->in_edges(); + EXPECT_EQ(func_in_edges.size(), 1); + EXPECT_EQ(helper.GetNodeByName("arg/Packed"), + (*func_in_edges.begin())->src()); } }