Remove TPUPartitionedInput nodes from the src nodes of the TPUReplicatedInput node that does not have output edge.

PiperOrigin-RevId: 348054996
Change-Id: Ifa5b2996f987b7ff103895082a14353fb70305d3
This commit is contained in:
A. Unique TensorFlower 2020-12-17 11:05:18 -08:00 committed by TensorFlower Gardener
parent 13de383a2f
commit 9526aa6c03

View File

@ -54,6 +54,7 @@ namespace {
const char* const kTPUReplicatedInput = "TPUReplicatedInput";
const char* const kTPUReplicatedOutput = "TPUReplicatedOutput";
const char* const kPivotForClusterAttr = "_pivot_for_cluster";
const char* const kTPUPartitionedInput = "TPUPartitionedInput";
// Finds the `index` of an _Arg or _Retval node.
Status GetIndexAttr(const Node& n, int num_args, int* index) {
@ -1586,7 +1587,18 @@ void RemoveUnusedTPUReplicatedInputs(Graph* graph) {
}
}
if (!has_output) {
// Remove any TPUPartitionedInput node from the src nodes of the
// to-be-removed TPUReplicatedInput node
std::vector<Node*> to_be_removed_src_nodes;
for (const auto& e_in : n->in_edges()) {
if (!e_in->IsControlEdge() &&
e_in->src()->type_string() == kTPUPartitionedInput)
to_be_removed_src_nodes.push_back(e_in->src());
}
graph->RemoveNode(n);
for (Node* node : to_be_removed_src_nodes) {
graph->RemoveNode(node);
}
}
}
}