Don't rewrite AddN node if temporary variable already exists.

Don't create GraphTopologyView before we need it.

PiperOrigin-RevId: 251527328
This commit is contained in:
A. Unique TensorFlower 2019-06-04 15:20:35 -07:00 committed by TensorFlower Gardener
parent 9584e56988
commit 76832228f8

View File

@ -500,16 +500,6 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
// Look for AddN nodes (and equivalent) and record input names.
MutableGraphView view(&item->graph);
// It's ok to use immutable GraphTopologyView here, because we do not destroy
// any of the nodes in the underlying graph, we only add new nodes.
GraphTopologyView graph_topology;
Status initialized_topology = graph_topology.InitializeFromGraph(item->graph);
if (!initialized_topology.ok()) {
VLOG(1) << "Failed to initialize graph topology view: "
<< initialized_topology.error_message();
return false;
}
std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
for (NodeDef& node : *item->graph.mutable_node()) {
if (!IsAddN(node) && node.op() != "AccumulateNV2") {
@ -574,6 +564,16 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
return false;
}
// It's ok to use immutable GraphTopologyView here, because we do not destroy
// any of the nodes in the underlying graph, we only add new nodes.
GraphTopologyView graph_topology;
Status initialized_topology = graph_topology.InitializeFromGraph(item->graph);
if (!initialized_topology.ok()) {
VLOG(1) << "Failed to initialize graph topology view: "
<< initialized_topology.error_message();
return false;
}
bool updated_graph = false;
// Rewrite the AddN.
for (NodeDef* node : addn_to_rewrite) {
@ -637,10 +637,15 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
DataType dtype = node->attr().at("T").type();
const string& device = node->device();
const string tmp_var_name = strings::StrCat(node->name(), "/tmp_var");
if (view.GetNode(tmp_var_name) != nullptr) {
VLOG(1) << "Temporary variable already exists " << tmp_var_name;
return false;
}
// Create the temporary variable that will hold intermediate results
NodeDef* tmp_var = item->graph.add_node();
tmp_var->set_name(strings::StrCat(node->name(), "/tmp_var"));
tmp_var->set_name(tmp_var_name);
tmp_var->set_op("TemporaryVariable");
tmp_var->set_device(device);
(*tmp_var->mutable_attr())["dtype"].set_type(dtype);