Add a flag to erase "_noinline" attribute to allow total inlining in Grappler.

PiperOrigin-RevId: 171722354
This commit is contained in:
Max Galkin 2017-10-10 13:06:16 -07:00 committed by TensorFlower Gardener
parent e74adb6709
commit 0ffb522f02
2 changed files with 26 additions and 20 deletions

View File

@ -74,7 +74,7 @@ void InitializeTensor(DataType type, Tensor* tensor) {
// of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated in
// order to get the correct session options and environment, and performing the
// correct optimizations.
Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
const ItemConfig& cfg) {
if (!cfg.apply_optimizations && !cfg.inline_functions) {
return Status::OK();
@ -83,8 +83,16 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
// Create a session option for a single GPU device.
SessionOptions options;
// Inline all functions.
GraphDef inlined_graph_def(graph_def);
// Make a local copy of graph def, because we need to change some things.
GraphDef graph_def(graph_def_arg);
if (cfg.inline_functions && cfg.erase_noinline_attributes) {
// TF optimizer doesn't inline functions with "_noinline" attribute,
// so let's go over the function library and erase it.
for (auto& func : *graph_def.mutable_library()->mutable_function()) {
func.mutable_attr()->erase("_noinline");
}
}
// Instantiate all variables for function library runtime creation.
std::vector<Device*> devices;
@ -92,7 +100,7 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
options, "/job:localhost/replica:0/task:0", &devices));
std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(devices));
FunctionLibraryDefinition function_library(OpRegistry::Global(),
inlined_graph_def.library());
graph_def.library());
Env* env = Env::Default();
// Optimizer options: L1 and inlining. L1 is default.
@ -108,7 +116,7 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
// Create the function library runtime.
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env,
inlined_graph_def.versions().producer(),
graph_def.versions().producer(),
&function_library, *optimizer_opts));
FunctionLibraryRuntime* flr = pflr->GetFLR(devices[0]->name());
@ -118,11 +126,11 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
graph_ctor_opts.expect_device_spec = false;
std::unique_ptr<Graph> graphptr(new Graph(function_library));
// Populate default attrs to the NodeDefs in the GraphDef.
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&inlined_graph_def,
*graphptr->op_registry(), 0));
TF_RETURN_IF_ERROR(
AddDefaultAttrsToGraphDef(&graph_def, *graphptr->op_registry(), 0));
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_ctor_opts, inlined_graph_def,
graphptr.get()));
TF_RETURN_IF_ERROR(
ConvertGraphDefToGraph(graph_ctor_opts, graph_def, graphptr.get()));
// Optimize the graph.
GraphOptimizer optimizer(*optimizer_opts);

View File

@ -27,24 +27,22 @@ class MetaGraphDef;
namespace grappler {
struct ItemConfig {
ItemConfig()
: ignore_user_placement(true),
ignore_colocation(true),
placeholder_unknown_output_shape_dim(-1),
apply_optimizations(false),
inline_functions(false) {}
ItemConfig() {}
// If true, ignore all user specified node placement.
bool ignore_user_placement;
bool ignore_user_placement = true;
// If true, ignore all user specified colocation attributes.
bool ignore_colocation;
bool ignore_colocation = true;
// Dimension to use if a placeholder node has an _output_shapes attribute with
// a dimension of -1.
int placeholder_unknown_output_shape_dim;
int placeholder_unknown_output_shape_dim = -1;
// If true, does L1 optimizations.
bool apply_optimizations;
bool apply_optimizations = false;
// If true, does inlining.
bool inline_functions;
bool inline_functions = false;
// If true, erases all "_noinline" attributes from user-defined functions.
// Has no effect if "inline_functions" is disabled.
bool erase_noinline_attributes = false;
// If non-empty, override the directory of asset paths.
string assets_directory_override;
};