Add flag to determine whether to do L1 optimizations and inline functions. Default is to do them. In tf_optimizer don't inline or do l1 optimizations.

PiperOrigin-RevId: 157673614
This commit is contained in:
A. Unique TensorFlower 2017-05-31 21:24:02 -07:00 committed by TensorFlower Gardener
parent 25bb504ccd
commit d9620cab82
4 changed files with 22 additions and 13 deletions

View File

@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include <unordered_map>
@ -70,7 +69,8 @@ void InitializeTensor(DataType type, Tensor* tensor) {
// of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated in
// order to get the correct session options and environment, and performing the
// correct optimizations.
Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def) {
Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
const ItemConfig& cfg) {
// Create a session option for a single GPU device.
SessionOptions options;
@ -94,7 +94,12 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def) {
// Optimizer options: L1 and inlining. L1 is default.
OptimizerOptions* optimizer_opts =
options.config.mutable_graph_options()->mutable_optimizer_options();
optimizer_opts->set_do_function_inlining(true);
if (cfg.apply_optimizations) {
optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions_Level_L1);
} else {
optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions_Level_L0);
}
optimizer_opts->set_do_function_inlining(cfg.inline_functions);
// Create the function library runtime.
std::unique_ptr<FunctionLibraryRuntime> flib(NewFunctionLibraryRuntime(
@ -130,13 +135,11 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
new_item->graph = meta_graph.graph_def();
// Optimize the graph (function inlining, l1 optimizations, etc).
if (cfg.apply_optimizations) {
Status optimize_status =
OptimizeGraph(meta_graph.graph_def(), &new_item->graph);
if (!optimize_status.ok()) {
LOG(ERROR) << "Function optimization failed: " << optimize_status;
return nullptr;
}
Status optimize_status =
OptimizeGraph(meta_graph.graph_def(), &new_item->graph, cfg);
if (!optimize_status.ok()) {
LOG(ERROR) << "Function optimization failed: " << optimize_status;
return nullptr;
}
// Attempt to detect the fetch node(s).

View File

@ -31,7 +31,8 @@ struct ItemConfig {
: ignore_user_placement(true),
ignore_colocation(true),
placeholder_unknown_output_shape_dim(-1),
apply_optimizations(true) {}
apply_optimizations(true),
inline_functions(true) {}
// If true, ignore all user specified node placement.
bool ignore_user_placement;
@ -40,8 +41,10 @@ struct ItemConfig {
// Dimension to use if a placeholder node has an _output_shapes attribute with
// a dimension of -1.
int placeholder_unknown_output_shape_dim;
// If true, does inlining and L1 optimizations.
// If true, does L1 optimizations.
bool apply_optimizations;
// If true, does inlining.
bool inline_functions;
};
// Factory method for creating a GrapplerItem from a MetaGraphDef.

View File

@ -70,6 +70,7 @@ std::unique_ptr<GrapplerItem> CreateGrapplerItem(const GraphDef &def,
const CollectionDef &fetches) {
MetaGraphDef meta_def;
ItemConfig cfg;
cfg.inline_functions = true;
*meta_def.mutable_graph_def() = def;
(*meta_def.mutable_collection_def())["train_op"] = fetches;
return GrapplerItemFromMetaGraphDef("0", meta_def, cfg);

View File

@ -67,7 +67,9 @@ PyObject* TF_OptimizeGraph(
const tensorflow::RewriterConfig& rewriter_config,
const tensorflow::MetaGraphDef& metagraph,
const string& graph_id, TF_Status* out_status) {
const tensorflow::grappler::ItemConfig item_config;
tensorflow::grappler::ItemConfig item_config;
item_config.inline_functions = false;
item_config.apply_optimizations = false;
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
std::unordered_map<string, tensorflow::DeviceProperties> device_map;