Wipe out previous shape inference results when importing a grappler item

Run graph optimizations last: since they can be expensive, it's best to filter out invalid items first.

PiperOrigin-RevId: 157792834
Benoit Steiner, 2017-06-01 18:49:33 -07:00, committed by TensorFlower Gardener
parent 9ae941c4a8
commit 0503ce09c7
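
The second hunk below erases any "_output_shapes" attribute recorded by a previous run of shape inference. In isolation, its effect amounts to a helper like the one sketched here (a minimal sketch, not TensorFlow source; ClearCachedOutputShapes is a hypothetical name, and it assumes only the cached attribute needs to be dropped):

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Drop the shape annotations cached on each imported node so that shape
// inference over the new GrapplerItem starts again from scratch.
void ClearCachedOutputShapes(tensorflow::GraphDef* graph) {
  for (tensorflow::NodeDef& node : *graph->mutable_node()) {
    node.mutable_attr()->erase("_output_shapes");
  }
}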


@@ -134,14 +134,6 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
   new_item->id = id;
   new_item->graph = meta_graph.graph_def();
 
-  // Optimize the graph (function inlining, l1 optimizations, etc).
-  Status optimize_status =
-      OptimizeGraph(meta_graph.graph_def(), &new_item->graph, cfg);
-  if (!optimize_status.ok()) {
-    LOG(ERROR) << "Function optimization failed: " << optimize_status;
-    return nullptr;
-  }
-
   // Attempt to detect the fetch node(s).
   if (meta_graph.collection_def().count("train_op") > 0) {
     const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
@@ -250,6 +242,10 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
       *(node.mutable_attr()->at("shape").mutable_shape()) = shape_proto;
     }
 
+    // Erase the recorded result of any previous shape inference to start again
+    // from scratch.
+    node.mutable_attr()->erase("_output_shapes");
+
     // Delete user specified placement if requested.
     if (cfg.ignore_user_placement) {
       node.clear_device();
@@ -329,6 +325,14 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
     }
   }
 
+  // Optimize the graph (function inlining, l1 optimizations, etc).
+  Status optimize_status =
+      OptimizeGraph(new_item->graph, &new_item->graph, cfg);
+  if (!optimize_status.ok()) {
+    LOG(ERROR) << "Function optimization failed: " << optimize_status;
+    return nullptr;
+  }
+
   return new_item;
 }
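
Taken together, the hunks move the potentially expensive OptimizeGraph() pass from the start of the import path to its end, so it only runs on items that survive the cheaper checks. A rough outline of the resulting flow (a simplified sketch with the middle elided; the signature is abbreviated from the hunk headers and is not the exact source):

std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
    const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
  std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
  new_item->id = id;
  new_item->graph = meta_graph.graph_def();

  // Cheap work first: detect fetch nodes, normalize placeholder shapes,
  // erase cached "_output_shapes", apply the ItemConfig flags, and return
  // nullptr for anything invalid.
  // ...

  // Only surviving items pay for the expensive pass.
  Status optimize_status =
      OptimizeGraph(new_item->graph, &new_item->graph, cfg);
  if (!optimize_status.ok()) {
    LOG(ERROR) << "Function optimization failed: " << optimize_status;
    return nullptr;
  }
  return new_item;
}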