Insert a PackOp between per-replica arg nodes and a dst node that is not assigned to a replica device. The dst node is responsible for unpacking the packed tensor.
PiperOrigin-RevId: 315320567 Change-Id: Ic7a94e33e8de3c9f98c735d72c0609486afc490e
This commit is contained in:
parent
e60c1ba960
commit
a00daa2f37
@ -1361,6 +1361,7 @@ cc_library(
|
||||
hdrs = ["replicate_per_replica_nodes.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
|
||||
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -115,12 +116,36 @@ class ReplicateHelper {
|
||||
// This happens when the dst node runs on a host CPU and
|
||||
// captures a function with an arg node assigned to the same
|
||||
// composite device (e.g. ScanDataset).
|
||||
// For this case, we only need to add an edge connecting the arg
|
||||
// node in the outer function and the corresponding arg in the
|
||||
// inner function, since the host CPU only needs one copy of the
|
||||
// ResourceHandle.
|
||||
graph->AddEdge(src_replicated_nodes.at(0), edge->src_output(), dst,
|
||||
edge->dst_input());
|
||||
// For this case, we insert a PackOp between replicated nodes and the
|
||||
// dst node. The dst node is responsible for unpacking the packed
|
||||
// tensor.
|
||||
// Add '/Packed' as a substring to the name of the new node, which
|
||||
// could be helpful when debugging the graph.
|
||||
NodeDefBuilder pack_builder(
|
||||
graph->NewName(absl::StrCat(edge->src()->name(), "/Packed")),
|
||||
"Pack");
|
||||
const int num_replicas = src_replicated_nodes.size();
|
||||
pack_builder.Attr("N", num_replicas);
|
||||
const DataType dtype = edge->src()->output_type(edge->src_output());
|
||||
pack_builder.Attr("T", dtype);
|
||||
std::vector<NodeDefBuilder::NodeOut> inputs;
|
||||
inputs.reserve(src_replicated_nodes.size());
|
||||
for (Node* replicated_node : src_replicated_nodes) {
|
||||
inputs.emplace_back(NodeDefBuilder::NodeOut{
|
||||
replicated_node->name(), edge->src_output(), dtype});
|
||||
}
|
||||
pack_builder.Input(inputs);
|
||||
NodeDef pack_def;
|
||||
TF_RETURN_IF_ERROR(pack_builder.Finalize(&pack_def));
|
||||
Status status;
|
||||
Node* pack_node = graph->AddNode(pack_def, &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
pack_node->set_assigned_device_name(dst->assigned_device_name());
|
||||
for (int i = 0; i < src_replicated_nodes.size(); ++i) {
|
||||
graph->AddEdge(src_replicated_nodes[i], edge->src_output(),
|
||||
pack_node, i);
|
||||
}
|
||||
graph->AddEdge(pack_node, /*x=*/0, dst, edge->dst_input());
|
||||
} else {
|
||||
return errors::InvalidArgument(
|
||||
"Dst node should be assigned to an allowed device. Found an "
|
||||
|
@ -258,16 +258,24 @@ TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
|
||||
ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));
|
||||
|
||||
{
|
||||
// _Arg(TPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
|
||||
EXPECT_EQ(graph.num_op_nodes(), 4);
|
||||
// _Arg(TPU:0), _Arg(TPU:1) -> Pack(CPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
|
||||
EXPECT_EQ(graph.num_op_nodes(), 5);
|
||||
GraphHelper helper(graph);
|
||||
helper.CheckAssignedDevice("arg/R0", "TPU:0");
|
||||
helper.CheckAssignedDevice("arg/R1", "TPU:1");
|
||||
helper.CheckAssignedDevice("arg/Packed", "CPU:0");
|
||||
helper.CheckAssignedDevice("func", "CPU:0");
|
||||
helper.CheckAssignedDevice("ret", "CPU:0");
|
||||
const EdgeSet& in_edges = helper.GetNodeByName("func")->in_edges();
|
||||
EXPECT_EQ(in_edges.size(), 1);
|
||||
EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*in_edges.begin())->src());
|
||||
const EdgeSet& packed_in_edges =
|
||||
helper.GetNodeByName("arg/Packed")->in_edges();
|
||||
EXPECT_EQ(packed_in_edges.size(), 2);
|
||||
auto it = packed_in_edges.begin();
|
||||
EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*it++)->src());
|
||||
EXPECT_EQ(helper.GetNodeByName("arg/R1"), (*it)->src());
|
||||
const EdgeSet& func_in_edges = helper.GetNodeByName("func")->in_edges();
|
||||
EXPECT_EQ(func_in_edges.size(), 1);
|
||||
EXPECT_EQ(helper.GetNodeByName("arg/Packed"),
|
||||
(*func_in_edges.begin())->src());
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user