Insert a PackOp between the per-replica arg nodes and a dst node that is not assigned to a replica device. The dst node is responsible for unpacking the packed tensor.

PiperOrigin-RevId: 315320567
Change-Id: Ic7a94e33e8de3c9f98c735d72c0609486afc490e
This commit is contained in:
Yujing Zhang 2020-06-08 11:48:41 -07:00 committed by TensorFlower Gardener
parent e60c1ba960
commit a00daa2f37
3 changed files with 45 additions and 11 deletions

View File

@ -1361,6 +1361,7 @@ cc_library(
hdrs = ["replicate_per_replica_nodes.h"],
copts = tf_copts(),
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
namespace tensorflow {
namespace {
@ -115,12 +116,36 @@ class ReplicateHelper {
// This happens when the dst node runs on a host CPU and
// captures a function with an arg node assigned to the same
// composite device (e.g. ScanDataset).
// For this case, we only need to add an edge connecting the arg
// node in the outer function and the corresponding arg in the
// inner function, since the host CPU only needs one copy of the
// ResourceHandle.
graph->AddEdge(src_replicated_nodes.at(0), edge->src_output(), dst,
edge->dst_input());
// For this case, we insert a PackOp between replicated nodes and the
// dst node. The dst node is responsible for unpacking the packed
// tensor.
// Add '/Packed' as a substring to the name of the new node, which
// could be helpful when debugging the graph.
NodeDefBuilder pack_builder(
graph->NewName(absl::StrCat(edge->src()->name(), "/Packed")),
"Pack");
const int num_replicas = src_replicated_nodes.size();
pack_builder.Attr("N", num_replicas);
const DataType dtype = edge->src()->output_type(edge->src_output());
pack_builder.Attr("T", dtype);
std::vector<NodeDefBuilder::NodeOut> inputs;
inputs.reserve(src_replicated_nodes.size());
for (Node* replicated_node : src_replicated_nodes) {
inputs.emplace_back(NodeDefBuilder::NodeOut{
replicated_node->name(), edge->src_output(), dtype});
}
pack_builder.Input(inputs);
NodeDef pack_def;
TF_RETURN_IF_ERROR(pack_builder.Finalize(&pack_def));
Status status;
Node* pack_node = graph->AddNode(pack_def, &status);
TF_RETURN_IF_ERROR(status);
pack_node->set_assigned_device_name(dst->assigned_device_name());
for (int i = 0; i < src_replicated_nodes.size(); ++i) {
graph->AddEdge(src_replicated_nodes[i], edge->src_output(),
pack_node, i);
}
graph->AddEdge(pack_node, /*x=*/0, dst, edge->dst_input());
} else {
return errors::InvalidArgument(
"Dst node should be assigned to an allowed device. Found an "

View File

@ -258,16 +258,24 @@ TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));
{
// _Arg(TPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
EXPECT_EQ(graph.num_op_nodes(), 4);
// _Arg(TPU:0), _Arg(TPU:1) -> Pack(CPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
EXPECT_EQ(graph.num_op_nodes(), 5);
GraphHelper helper(graph);
helper.CheckAssignedDevice("arg/R0", "TPU:0");
helper.CheckAssignedDevice("arg/R1", "TPU:1");
helper.CheckAssignedDevice("arg/Packed", "CPU:0");
helper.CheckAssignedDevice("func", "CPU:0");
helper.CheckAssignedDevice("ret", "CPU:0");
const EdgeSet& in_edges = helper.GetNodeByName("func")->in_edges();
EXPECT_EQ(in_edges.size(), 1);
EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*in_edges.begin())->src());
const EdgeSet& packed_in_edges =
helper.GetNodeByName("arg/Packed")->in_edges();
EXPECT_EQ(packed_in_edges.size(), 2);
auto it = packed_in_edges.begin();
EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*it++)->src());
EXPECT_EQ(helper.GetNodeByName("arg/R1"), (*it)->src());
const EdgeSet& func_in_edges = helper.GetNodeByName("func")->in_edges();
EXPECT_EQ(func_in_edges.size(), 1);
EXPECT_EQ(helper.GetNodeByName("arg/Packed"),
(*func_in_edges.begin())->src());
}
}