Wipe out previous shape inference result when importing a grappler item
Run graph optimizations last: since they can be expensive it's best to filter invalid items first. PiperOrigin-RevId: 157792834
This commit is contained in:
parent
9ae941c4a8
commit
0503ce09c7
@ -134,14 +134,6 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||||||
new_item->id = id;
|
new_item->id = id;
|
||||||
new_item->graph = meta_graph.graph_def();
|
new_item->graph = meta_graph.graph_def();
|
||||||
|
|
||||||
// Optimize the graph (function inlining, l1 optimizations, etc).
|
|
||||||
Status optimize_status =
|
|
||||||
OptimizeGraph(meta_graph.graph_def(), &new_item->graph, cfg);
|
|
||||||
if (!optimize_status.ok()) {
|
|
||||||
LOG(ERROR) << "Function optimization failed: " << optimize_status;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to detect the fetch node(s).
|
// Attempt to detect the fetch node(s).
|
||||||
if (meta_graph.collection_def().count("train_op") > 0) {
|
if (meta_graph.collection_def().count("train_op") > 0) {
|
||||||
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
|
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
|
||||||
@ -250,6 +242,10 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||||||
*(node.mutable_attr()->at("shape").mutable_shape()) = shape_proto;
|
*(node.mutable_attr()->at("shape").mutable_shape()) = shape_proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Erase the recorded result of any previous shape inference to start again
|
||||||
|
// from scratch.
|
||||||
|
node.mutable_attr()->erase("_output_shapes");
|
||||||
|
|
||||||
// Delete user specified placement if requested.
|
// Delete user specified placement if requested.
|
||||||
if (cfg.ignore_user_placement) {
|
if (cfg.ignore_user_placement) {
|
||||||
node.clear_device();
|
node.clear_device();
|
||||||
@ -329,6 +325,14 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Optimize the graph (function inlining, l1 optimizations, etc).
|
||||||
|
Status optimize_status =
|
||||||
|
OptimizeGraph(new_item->graph, &new_item->graph, cfg);
|
||||||
|
if (!optimize_status.ok()) {
|
||||||
|
LOG(ERROR) << "Function optimization failed: " << optimize_status;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
return new_item;
|
return new_item;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user