[Grappler] Sort output graph of GenericLayoutOptimizer topologically.

PiperOrigin-RevId: 256399071
This commit is contained in:
Andy Ly 2019-07-03 10:54:04 -07:00 committed by TensorFlower Gardener
parent c23db02fd9
commit 9f9b8a5101

View File

@ -219,8 +219,14 @@ Status GenericLayoutOptimizer::Optimize(Cluster* cluster,
item, cluster, src_format_, dst_format_, target_device_, &context));
TransposerFactory transposer_factory;
TF_RETURN_IF_ERROR(ExpandLayoutSensitiveOp(&context, &transposer_factory));
TF_RETURN_IF_ERROR(ExpandLayoutAgnosticOp(&context, &transposer_factory));
TF_RETURN_IF_ERROR(EraseCancellableNodes(&context));
if (context.graph.node_size() > context.num_nodes) {
TF_RETURN_IF_ERROR(ExpandLayoutAgnosticOp(&context, &transposer_factory));
TF_RETURN_IF_ERROR(EraseCancellableNodes(&context));
// TODO(lyandy): Remove sorting once other optimizers are migrated to using
// `utils::GraphView`.
TF_RETURN_IF_ERROR(
context.graph_view->SortTopologically(/*ignore_cycles=*/false, {}));
}
TF_RETURN_IF_ERROR(EraseOutputShapeAttrs(&context));
*output = context.graph;