Remove TPUPartitionedInput nodes from the src nodes of the TPUReplicatedInput node that does not have output edge.
PiperOrigin-RevId: 348054996 Change-Id: Ifa5b2996f987b7ff103895082a14353fb70305d3
This commit is contained in:
parent
13de383a2f
commit
9526aa6c03
@ -54,6 +54,7 @@ namespace {
|
||||
const char* const kTPUReplicatedInput = "TPUReplicatedInput";
|
||||
const char* const kTPUReplicatedOutput = "TPUReplicatedOutput";
|
||||
const char* const kPivotForClusterAttr = "_pivot_for_cluster";
|
||||
const char* const kTPUPartitionedInput = "TPUPartitionedInput";
|
||||
|
||||
// Finds the `index` of an _Arg or _Retval node.
|
||||
Status GetIndexAttr(const Node& n, int num_args, int* index) {
|
||||
@ -1586,7 +1587,18 @@ void RemoveUnusedTPUReplicatedInputs(Graph* graph) {
|
||||
}
|
||||
}
|
||||
if (!has_output) {
|
||||
// Remove any TPUPartitionedInput node from the src nodes of the
|
||||
// to-be-removed TPUReplicatedInput node
|
||||
std::vector<Node*> to_be_removed_src_nodes;
|
||||
for (const auto& e_in : n->in_edges()) {
|
||||
if (!e_in->IsControlEdge() &&
|
||||
e_in->src()->type_string() == kTPUPartitionedInput)
|
||||
to_be_removed_src_nodes.push_back(e_in->src());
|
||||
}
|
||||
graph->RemoveNode(n);
|
||||
for (Node* node : to_be_removed_src_nodes) {
|
||||
graph->RemoveNode(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user