Remove dead replicated Arg nodes.
PiperOrigin-RevId: 326930011 Change-Id: If0fcb8041af124497b7865b91691672233accba8
This commit is contained in:
parent
1cf462e362
commit
c164998f77
@ -159,6 +159,16 @@ class ReplicateHelper {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Deletes replicated `_Arg` nodes that ended up unused after replication.
//
// A per-replica copy of an `_Arg` is considered dead when it has no out
// edges at all (neither data nor control); such nodes are removed from
// `graph` so the resulting function graph stays minimal.
void RemoveDeadReplicatedArgs(Graph* graph) {
  for (const auto& entry : replicated_nodes_map_) {
    for (Node* replica : entry.second) {
      const bool is_dead_arg = replica->IsArg() && replica->out_edges().empty();
      if (is_dead_arg) {
        graph->RemoveNode(replica);
      }
    }
  }
}
|
||||
|
||||
private:
|
||||
// Map from original nodes to corresponding replicated nodes.
|
||||
absl::flat_hash_map<const Node*, std::vector<Node*>> replicated_nodes_map_;
|
||||
@ -256,6 +266,8 @@ Status ReplicatePerReplicaNodesInFunctionGraph(
|
||||
for (auto* n : cluster_nodes) {
|
||||
graph->RemoveNode(n);
|
||||
}
|
||||
|
||||
helper.RemoveDeadReplicatedArgs(graph);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ namespace {
|
||||
|
||||
class GraphHelper {
|
||||
public:
|
||||
explicit GraphHelper(const Graph& graph) {
|
||||
explicit GraphHelper(const Graph& graph) : graph_(graph) {
|
||||
for (Node* node : graph.nodes()) {
|
||||
nodes_by_name_[node->name()] = node;
|
||||
}
|
||||
@ -55,6 +55,16 @@ class GraphHelper {
|
||||
->set_assigned_device_name(device_name);
|
||||
}
|
||||
|
||||
void CheckArgNum(const int expected_num) {
|
||||
int arg_num = 0;
|
||||
for (Node* node : graph_.op_nodes()) {
|
||||
if (node->IsArg()) {
|
||||
arg_num++;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(arg_num, expected_num);
|
||||
}
|
||||
|
||||
void CheckAssignedDevice(const string& node_name,
|
||||
const string& expected_device_name) {
|
||||
EXPECT_EQ(expected_device_name,
|
||||
@ -62,6 +72,7 @@ class GraphHelper {
|
||||
}
|
||||
|
||||
private:
|
||||
const Graph& graph_;
|
||||
// Maps from a node name to a Node* in the graph.
|
||||
absl::flat_hash_map<string, Node*> nodes_by_name_;
|
||||
};
|
||||
@ -103,6 +114,7 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDevice) {
|
||||
// ReadVariableOp(TPU:0) -> _Retval(CPU:0)
|
||||
EXPECT_EQ(graph.num_op_nodes(), 7);
|
||||
GraphHelper helper(graph);
|
||||
helper.CheckArgNum(2);
|
||||
helper.CheckAssignedDevice("arg/R0", "TPU:0");
|
||||
helper.CheckAssignedDevice("arg/R1", "TPU:1");
|
||||
helper.CheckAssignedDevice("read", "TPU:0");
|
||||
@ -141,6 +153,7 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDeviceToSingleDevice) {
|
||||
// _Arg(TPU:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
|
||||
EXPECT_EQ(graph.num_op_nodes(), 3);
|
||||
GraphHelper helper(graph);
|
||||
helper.CheckArgNum(1);
|
||||
helper.CheckAssignedDevice("arg", "TPU:0");
|
||||
helper.CheckAssignedDevice("read", "TPU:0");
|
||||
helper.CheckAssignedDevice("ret", "CPU:0");
|
||||
@ -192,6 +205,7 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
|
||||
// TPU:3) -> Identity(TPU:1, TPU:3) -> Add(TPU:0)-> _Retval(CPU:0)
|
||||
EXPECT_EQ(graph.num_op_nodes(), 12);
|
||||
GraphHelper helper(graph);
|
||||
helper.CheckArgNum(4);
|
||||
helper.CheckAssignedDevice("arg0/R0", "TPU:0");
|
||||
helper.CheckAssignedDevice("arg0/R1", "TPU:1");
|
||||
helper.CheckAssignedDevice("arg1/R0", "TPU:2");
|
||||
@ -261,6 +275,7 @@ TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
|
||||
// _Arg(TPU:0), _Arg(TPU:1) -> Pack(CPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
|
||||
EXPECT_EQ(graph.num_op_nodes(), 5);
|
||||
GraphHelper helper(graph);
|
||||
helper.CheckArgNum(2);
|
||||
helper.CheckAssignedDevice("arg/R0", "TPU:0");
|
||||
helper.CheckAssignedDevice("arg/R1", "TPU:1");
|
||||
helper.CheckAssignedDevice("arg/Packed", "CPU:0");
|
||||
@ -279,5 +294,41 @@ TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
|
||||
}
|
||||
}
|
||||
|
||||
// Replicating an `_Arg` on a composite device must drop per-replica copies
// that end up with no consumers: only the replica on TPU:0 feeds `read`,
// so the TPU:1 copy ("arg/R1") should be deleted from the graph.
TEST(ReplicatePerReplicaNodesTest, DeadArgNodes) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output arg = ops::_Arg(root.WithOpName("arg"), DT_RESOURCE, 0);
  auto read = ops::ReadVariableOp(root.WithOpName("read"), arg, DT_INT32);
  auto ret = ops::_Retval(root.WithOpName("ret"), read, 0);

  // One composite device backed by two replicas.
  const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&graph));
  {
    // Before: _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read", "TPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // After: _Arg(TPU:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0).
    // "arg/R1" is a dead node, so gets removed.
    EXPECT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.CheckArgNum(1);
    helper.CheckAssignedDevice("arg/R0", "TPU:0");
    helper.CheckAssignedDevice("read", "TPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
  }
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user