Don't rewrite AddN node if temporary variable already exists.
Don't create GraphTopologyView before we need it. PiperOrigin-RevId: 251527328
This commit is contained in:
parent
9584e56988
commit
76832228f8
@ -500,16 +500,6 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
|
||||
// Look for AddN nodes (and equivalent) and record input names.
|
||||
MutableGraphView view(&item->graph);
|
||||
|
||||
// It's ok to use immutable GraphTopologyView here, because we do not destroy
|
||||
// any of the nodes in the underlying graph, we only add new nodes.
|
||||
GraphTopologyView graph_topology;
|
||||
Status initialized_topology = graph_topology.InitializeFromGraph(item->graph);
|
||||
if (!initialized_topology.ok()) {
|
||||
VLOG(1) << "Failed to initialize graph topology view: "
|
||||
<< initialized_topology.error_message();
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
|
||||
for (NodeDef& node : *item->graph.mutable_node()) {
|
||||
if (!IsAddN(node) && node.op() != "AccumulateNV2") {
|
||||
@ -574,6 +564,16 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// It's ok to use immutable GraphTopologyView here, because we do not destroy
|
||||
// any of the nodes in the underlying graph, we only add new nodes.
|
||||
GraphTopologyView graph_topology;
|
||||
Status initialized_topology = graph_topology.InitializeFromGraph(item->graph);
|
||||
if (!initialized_topology.ok()) {
|
||||
VLOG(1) << "Failed to initialize graph topology view: "
|
||||
<< initialized_topology.error_message();
|
||||
return false;
|
||||
}
|
||||
|
||||
bool updated_graph = false;
|
||||
// Rewrite the AddN.
|
||||
for (NodeDef* node : addn_to_rewrite) {
|
||||
@ -637,10 +637,15 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
|
||||
|
||||
DataType dtype = node->attr().at("T").type();
|
||||
const string& device = node->device();
|
||||
const string tmp_var_name = strings::StrCat(node->name(), "/tmp_var");
|
||||
if (view.GetNode(tmp_var_name) != nullptr) {
|
||||
VLOG(1) << "Temporary variable already exists " << tmp_var_name;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Create the temporary variable that will hold intermediate results
|
||||
NodeDef* tmp_var = item->graph.add_node();
|
||||
tmp_var->set_name(strings::StrCat(node->name(), "/tmp_var"));
|
||||
tmp_var->set_name(tmp_var_name);
|
||||
tmp_var->set_op("TemporaryVariable");
|
||||
tmp_var->set_device(device);
|
||||
(*tmp_var->mutable_attr())["dtype"].set_type(dtype);
|
||||
|
Loading…
Reference in New Issue
Block a user