diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index cb7d7f7330c..d23facf81a4 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -74,7 +74,7 @@ void InitializeTensor(DataType type, Tensor* tensor) { // of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated in // order to get the correct session options and environment, and performing the // correct optimizations. -Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def, +Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def, const ItemConfig& cfg) { if (!cfg.apply_optimizations && !cfg.inline_functions) { return Status::OK(); @@ -83,8 +83,16 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def, // Create a session option for a single GPU device. SessionOptions options; - // Inline all functions. - GraphDef inlined_graph_def(graph_def); + // Make a local copy of graph def, because we need to change some things. + GraphDef graph_def(graph_def_arg); + + if (cfg.inline_functions && cfg.erase_noinline_attributes) { + // TF optimizer doesn't inline functions with "_noinline" attribute, + // so let's go over the function library and erase it. + for (auto& func : *graph_def.mutable_library()->mutable_function()) { + func.mutable_attr()->erase("_noinline"); + } + } // Instantiate all variables for function library runtime creation. std::vector devices; @@ -92,7 +100,7 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def, options, "/job:localhost/replica:0/task:0", &devices)); std::unique_ptr dvc_mgr(new DeviceMgr(devices)); FunctionLibraryDefinition function_library(OpRegistry::Global(), - inlined_graph_def.library()); + graph_def.library()); Env* env = Env::Default(); // Optimizer options: L1 and inlining. L1 is default. @@ -108,7 +116,7 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def, // Create the function library runtime. std::unique_ptr pflr( new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, - inlined_graph_def.versions().producer(), + graph_def.versions().producer(), &function_library, *optimizer_opts)); FunctionLibraryRuntime* flr = pflr->GetFLR(devices[0]->name()); @@ -118,11 +126,11 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def, graph_ctor_opts.expect_device_spec = false; std::unique_ptr graphptr(new Graph(function_library)); // Populate default attrs to the NodeDefs in the GraphDef. - TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&inlined_graph_def, - *graphptr->op_registry(), 0)); + TF_RETURN_IF_ERROR( + AddDefaultAttrsToGraphDef(&graph_def, *graphptr->op_registry(), 0)); - TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_ctor_opts, inlined_graph_def, - graphptr.get())); + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(graph_ctor_opts, graph_def, graphptr.get())); // Optimize the graph. GraphOptimizer optimizer(*optimizer_opts); diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h index 4ce5055e7a1..9a7f52228b9 100644 --- a/tensorflow/core/grappler/grappler_item_builder.h +++ b/tensorflow/core/grappler/grappler_item_builder.h @@ -27,24 +27,22 @@ class MetaGraphDef; namespace grappler { struct ItemConfig { - ItemConfig() - : ignore_user_placement(true), - ignore_colocation(true), - placeholder_unknown_output_shape_dim(-1), - apply_optimizations(false), - inline_functions(false) {} + ItemConfig() {} // If true, ignore all user specified node placement. - bool ignore_user_placement; + bool ignore_user_placement = true; // If true, ignore all user specified colocation attributes. - bool ignore_colocation; + bool ignore_colocation = true; // Dimension to use if a placeholder node has an _output_shapes attribute with // a dimension of -1. - int placeholder_unknown_output_shape_dim; + int placeholder_unknown_output_shape_dim = -1; // If true, does L1 optimizations. - bool apply_optimizations; + bool apply_optimizations = false; // If true, does inlining. - bool inline_functions; + bool inline_functions = false; + // If true, erases all "_noinline" attributes from user-defined functions. + // Has no effect if "inline_functions" is disabled. + bool erase_noinline_attributes = false; // If non-empty, override the directory of asset paths. string assets_directory_override; };