Add a flag to erase "_noinline" attribute to allow total inlining in Grappler.
PiperOrigin-RevId: 171722354
This commit is contained in:
parent
e74adb6709
commit
0ffb522f02
@ -74,7 +74,7 @@ void InitializeTensor(DataType type, Tensor* tensor) {
|
|||||||
// of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated in
|
// of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated in
|
||||||
// order to get the correct session options and environment, and performing the
|
// order to get the correct session options and environment, and performing the
|
||||||
// correct optimizations.
|
// correct optimizations.
|
||||||
Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
|
Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
|
||||||
const ItemConfig& cfg) {
|
const ItemConfig& cfg) {
|
||||||
if (!cfg.apply_optimizations && !cfg.inline_functions) {
|
if (!cfg.apply_optimizations && !cfg.inline_functions) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -83,8 +83,16 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
|
|||||||
// Create a session option for a single GPU device.
|
// Create a session option for a single GPU device.
|
||||||
SessionOptions options;
|
SessionOptions options;
|
||||||
|
|
||||||
// Inline all functions.
|
// Make a local copy of graph def, because we need to change some things.
|
||||||
GraphDef inlined_graph_def(graph_def);
|
GraphDef graph_def(graph_def_arg);
|
||||||
|
|
||||||
|
if (cfg.inline_functions && cfg.erase_noinline_attributes) {
|
||||||
|
// TF optimizer doesn't inline functions with "_noinline" attribute,
|
||||||
|
// so let's go over the function library and erase it.
|
||||||
|
for (auto& func : *graph_def.mutable_library()->mutable_function()) {
|
||||||
|
func.mutable_attr()->erase("_noinline");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Instantiate all variables for function library runtime creation.
|
// Instantiate all variables for function library runtime creation.
|
||||||
std::vector<Device*> devices;
|
std::vector<Device*> devices;
|
||||||
@ -92,7 +100,7 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
|
|||||||
options, "/job:localhost/replica:0/task:0", &devices));
|
options, "/job:localhost/replica:0/task:0", &devices));
|
||||||
std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(devices));
|
std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(devices));
|
||||||
FunctionLibraryDefinition function_library(OpRegistry::Global(),
|
FunctionLibraryDefinition function_library(OpRegistry::Global(),
|
||||||
inlined_graph_def.library());
|
graph_def.library());
|
||||||
Env* env = Env::Default();
|
Env* env = Env::Default();
|
||||||
|
|
||||||
// Optimizer options: L1 and inlining. L1 is default.
|
// Optimizer options: L1 and inlining. L1 is default.
|
||||||
@ -108,7 +116,7 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
|
|||||||
// Create the function library runtime.
|
// Create the function library runtime.
|
||||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||||
new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env,
|
new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env,
|
||||||
inlined_graph_def.versions().producer(),
|
graph_def.versions().producer(),
|
||||||
&function_library, *optimizer_opts));
|
&function_library, *optimizer_opts));
|
||||||
FunctionLibraryRuntime* flr = pflr->GetFLR(devices[0]->name());
|
FunctionLibraryRuntime* flr = pflr->GetFLR(devices[0]->name());
|
||||||
|
|
||||||
@ -118,11 +126,11 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
|
|||||||
graph_ctor_opts.expect_device_spec = false;
|
graph_ctor_opts.expect_device_spec = false;
|
||||||
std::unique_ptr<Graph> graphptr(new Graph(function_library));
|
std::unique_ptr<Graph> graphptr(new Graph(function_library));
|
||||||
// Populate default attrs to the NodeDefs in the GraphDef.
|
// Populate default attrs to the NodeDefs in the GraphDef.
|
||||||
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&inlined_graph_def,
|
TF_RETURN_IF_ERROR(
|
||||||
*graphptr->op_registry(), 0));
|
AddDefaultAttrsToGraphDef(&graph_def, *graphptr->op_registry(), 0));
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_ctor_opts, inlined_graph_def,
|
TF_RETURN_IF_ERROR(
|
||||||
graphptr.get()));
|
ConvertGraphDefToGraph(graph_ctor_opts, graph_def, graphptr.get()));
|
||||||
|
|
||||||
// Optimize the graph.
|
// Optimize the graph.
|
||||||
GraphOptimizer optimizer(*optimizer_opts);
|
GraphOptimizer optimizer(*optimizer_opts);
|
||||||
|
@ -27,24 +27,22 @@ class MetaGraphDef;
|
|||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
|
||||||
struct ItemConfig {
|
struct ItemConfig {
|
||||||
ItemConfig()
|
ItemConfig() {}
|
||||||
: ignore_user_placement(true),
|
|
||||||
ignore_colocation(true),
|
|
||||||
placeholder_unknown_output_shape_dim(-1),
|
|
||||||
apply_optimizations(false),
|
|
||||||
inline_functions(false) {}
|
|
||||||
|
|
||||||
// If true, ignore all user specified node placement.
|
// If true, ignore all user specified node placement.
|
||||||
bool ignore_user_placement;
|
bool ignore_user_placement = true;
|
||||||
// If true, ignore all user specified colocation attributes.
|
// If true, ignore all user specified colocation attributes.
|
||||||
bool ignore_colocation;
|
bool ignore_colocation = true;
|
||||||
// Dimension to use if a placeholder node has an _output_shapes attribute with
|
// Dimension to use if a placeholder node has an _output_shapes attribute with
|
||||||
// a dimension of -1.
|
// a dimension of -1.
|
||||||
int placeholder_unknown_output_shape_dim;
|
int placeholder_unknown_output_shape_dim = -1;
|
||||||
// If true, does L1 optimizations.
|
// If true, does L1 optimizations.
|
||||||
bool apply_optimizations;
|
bool apply_optimizations = false;
|
||||||
// If true, does inlining.
|
// If true, does inlining.
|
||||||
bool inline_functions;
|
bool inline_functions = false;
|
||||||
|
// If true, erases all "_noinline" attributes from user-defined functions.
|
||||||
|
// Has no effect if "inline_functions" is disabled.
|
||||||
|
bool erase_noinline_attributes = false;
|
||||||
// If non-empty, override the directory of asset paths.
|
// If non-empty, override the directory of asset paths.
|
||||||
string assets_directory_override;
|
string assets_directory_override;
|
||||||
};
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user