Remove dead replicated Arg nodes.

PiperOrigin-RevId: 326930011
Change-Id: If0fcb8041af124497b7865b91691672233accba8
This commit is contained in:
Yujing Zhang 2020-08-16 15:58:02 -07:00 committed by TensorFlower Gardener
parent 1cf462e362
commit c164998f77
2 changed files with 64 additions and 1 deletions

View File

@ -159,6 +159,16 @@ class ReplicateHelper {
return Status::OK();
}
// Removes dead replicated Arg nodes from `graph`.
//
// Visits every replica recorded in replicated_nodes_map_ and removes each one
// for which IsArg() is true and out_edges() is empty.
//
// Empty (nullptr) slots are skipped. After a node is removed, its slot in the
// map is reset to nullptr so the map does not keep a pointer to a node that is
// no longer in the graph; this also makes a repeated call harmless.
//
// NOTE(review): whether nullptr slots can actually occur depends on how the
// map is populated, which is not visible here -- the guard is defensive.
void RemoveDeadReplicatedArgs(Graph* graph) {
  for (auto& entry : replicated_nodes_map_) {
    for (Node*& replicated_node : entry.second) {
      if (replicated_node == nullptr) {
        continue;
      }
      if (replicated_node->IsArg() && replicated_node->out_edges().empty()) {
        graph->RemoveNode(replicated_node);
        // NOTE(review): presumably the node must not be touched once it has
        // been removed -- confirm against the Graph::RemoveNode contract.
        replicated_node = nullptr;
      }
    }
  }
}
private:
// Map from original nodes to corresponding replicated nodes.
absl::flat_hash_map<const Node*, std::vector<Node*>> replicated_nodes_map_;
@ -256,6 +266,8 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
for (auto* n : cluster_nodes) {
graph->RemoveNode(n);
}
helper.RemoveDeadReplicatedArgs(graph);
}
return Status::OK();
}

View File

@ -31,7 +31,7 @@ namespace {
class GraphHelper {
public:
explicit GraphHelper(const Graph& graph) {
explicit GraphHelper(const Graph& graph) : graph_(graph) {
for (Node* node : graph.nodes()) {
nodes_by_name_[node->name()] = node;
}
@ -55,6 +55,16 @@ class GraphHelper {
->set_assigned_device_name(device_name);
}
void CheckArgNum(const int expected_num) {
int arg_num = 0;
for (Node* node : graph_.op_nodes()) {
if (node->IsArg()) {
arg_num++;
}
}
EXPECT_EQ(arg_num, expected_num);
}
void CheckAssignedDevice(const string& node_name,
const string& expected_device_name) {
EXPECT_EQ(expected_device_name,
@ -62,6 +72,7 @@ class GraphHelper {
}
private:
const Graph& graph_;
// Maps from a node name to a Node* in the graph.
absl::flat_hash_map<string, Node*> nodes_by_name_;
};
@ -103,6 +114,7 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDevice) {
// ReadVariableOp(TPU:0) -> _Retval(CPU:0)
EXPECT_EQ(graph.num_op_nodes(), 7);
GraphHelper helper(graph);
helper.CheckArgNum(2);
helper.CheckAssignedDevice("arg/R0", "TPU:0");
helper.CheckAssignedDevice("arg/R1", "TPU:1");
helper.CheckAssignedDevice("read", "TPU:0");
@ -141,6 +153,7 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDeviceToSingleDevice) {
// _Arg(TPU:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
EXPECT_EQ(graph.num_op_nodes(), 3);
GraphHelper helper(graph);
helper.CheckArgNum(1);
helper.CheckAssignedDevice("arg", "TPU:0");
helper.CheckAssignedDevice("read", "TPU:0");
helper.CheckAssignedDevice("ret", "CPU:0");
@ -192,6 +205,7 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
// TPU:3) -> Identity(TPU:1, TPU:3) -> Add(TPU:0) -> _Retval(CPU:0)
EXPECT_EQ(graph.num_op_nodes(), 12);
GraphHelper helper(graph);
helper.CheckArgNum(4);
helper.CheckAssignedDevice("arg0/R0", "TPU:0");
helper.CheckAssignedDevice("arg0/R1", "TPU:1");
helper.CheckAssignedDevice("arg1/R0", "TPU:2");
@ -261,6 +275,7 @@ TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
// _Arg(TPU:0), _Arg(TPU:1) -> Pack(CPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
EXPECT_EQ(graph.num_op_nodes(), 5);
GraphHelper helper(graph);
helper.CheckArgNum(2);
helper.CheckAssignedDevice("arg/R0", "TPU:0");
helper.CheckAssignedDevice("arg/R1", "TPU:1");
helper.CheckAssignedDevice("arg/Packed", "CPU:0");
@ -279,5 +294,41 @@ TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
}
}
// Checks that an Arg replica left without any consumer after replication is
// dropped from the graph.
TEST(ReplicatePerReplicaNodesTest, DeadArgNodes) {
  // Graph under test: a resource _Arg feeding a ReadVariableOp whose value is
  // returned through a _Retval.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output resource_arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto read_value =
      ops::ReadVariableOp(scope.WithOpName("read"), resource_arg, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("ret"), read_value, 0);

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));

  // The composite device maps onto two underlying TPU devices.
  const std::vector<string> replica_devices = {"TPU:0", "TPU:1"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &replica_devices}};

  {
    // Before: _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 3);
    GraphHelper setup(graph);
    setup.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    setup.SetAssignedDevice("read", "TPU:0");
    setup.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // After: _Arg(TPU:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    // Nothing consumes "arg/R1", so it is expected to have been removed,
    // leaving a single Arg node.
    EXPECT_EQ(graph.num_op_nodes(), 3);
    GraphHelper verifier(graph);
    verifier.CheckArgNum(1);
    verifier.CheckAssignedDevice("arg/R0", "TPU:0");
    verifier.CheckAssignedDevice("read", "TPU:0");
    verifier.CheckAssignedDevice("ret", "CPU:0");
  }
}
} // namespace
} // namespace tensorflow