Added a python API to the meta graph optimizer

Change: 154232702
This commit is contained in:
Benoit Steiner 2017-04-25 14:49:17 -08:00 committed by TensorFlower Gardener
parent 546befc408
commit 58fe576e20
7 changed files with 212 additions and 0 deletions

View File

@ -204,6 +204,7 @@ add_python_module("tensorflow/python/estimator/export")
add_python_module("tensorflow/python/estimator/inputs")
add_python_module("tensorflow/python/estimator/inputs/queues")
add_python_module("tensorflow/python/framework")
add_python_module("tensorflow/python/grappler")
add_python_module("tensorflow/python/kernel_tests")
add_python_module("tensorflow/python/layers")
add_python_module("tensorflow/python/lib")

View File

@ -2576,6 +2576,7 @@ tf_py_wrap_cc(
"client/tf_session.i",
"framework/cpp_shape_inference.i",
"framework/python_op_gen.i",
"grappler/tf_optimizer.i",
"lib/core/py_func.i",
"lib/core/strings.i",
"lib/io/file_io.i",
@ -2604,6 +2605,9 @@ tf_py_wrap_cc(
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:grappler_item_builder",
"//tensorflow/core/grappler/optimizers:meta_optimizer",
"//tensorflow/core:lib",
"//tensorflow/core:reader_base",
"//tensorflow/core/debug",
@ -3539,3 +3543,28 @@ cuda_py_test(
],
main = "client/session_benchmark.py",
)
py_library(
name = "tf_optimizer",
srcs = [
"grappler/tf_optimizer.py",
],
srcs_version = "PY2AND3",
deps = [":pywrap_tensorflow_internal"],
)
py_test(
name = "tf_optimizer_test",
size = "small",
srcs = ["grappler/tf_optimizer_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"], # tf_optimizer is not available in pip.
deps = [
":client_testlib",
":framework_for_generated_wrappers",
":math_ops",
":tf_optimizer",
"//tensorflow/core:protos_all_py",
"//third_party/py/numpy",
],
)

View File

@ -0,0 +1,91 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/platform/base.i"
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
if (!temp.ParseFromString(string(c_string, py_size))) {
PyErr_SetString(
PyExc_TypeError,
"The MetaGraphDef could not be parsed as a valid protocol buffer");
SWIG_fail;
}
$1 = &temp;
}
%typemap(in) const tensorflow::RewriterConfig& (
tensorflow::RewriterConfig temp) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
if (!temp.ParseFromString(string(c_string, py_size))) {
PyErr_SetString(
PyExc_TypeError,
"The RewriterConfig could not be parsed as a valid protocol buffer");
SWIG_fail;
}
$1 = &temp;
}
%{
#include <memory>
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
PyObject* TF_OptimizeGraph(
const tensorflow::RewriterConfig& rewriter_config,
const tensorflow::MetaGraphDef& metagraph,
const string& graph_id, TF_Status* out_status) {
const tensorflow::grappler::ItemConfig item_config;
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
tensorflow::GraphDef out_graph;
tensorflow::Status status = tensorflow::grappler::RunMetaOptimizer(
*grappler_item, rewriter_config, &out_graph);
tensorflow::Set_TF_Status_from_Status(out_status, status);
string out_graph_str = out_graph.SerializeAsString();
PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(),
out_graph_str.size());
return ret;
}
%}
// Wrap this function
PyObject* TF_OptimizeGraph(
const tensorflow::RewriterConfig& rewriter_config,
const tensorflow::MetaGraphDef& metagraph,
const string& graph_id, TF_Status* out_status);

View File

@ -0,0 +1,35 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Provides a proper python API for the symbols exported through swig."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import graph_pb2
from tensorflow.python import pywrap_tensorflow as tf_opt
from tensorflow.python.framework import errors
def OptimizeGraph(rewriter_config, metagraph, graph_id=b'graph_to_optimize'):
"""Optimize the provided metagraph."""
with errors.raise_exception_on_not_ok_status() as status:
ret_from_swig = tf_opt.TF_OptimizeGraph(rewriter_config.SerializeToString(),
metagraph.SerializeToString(),
graph_id, status)
if ret_from_swig is None:
return None
out_graph = graph_pb2.GraphDef().FromString(ret_from_swig)
return out_graph

View File

@ -0,0 +1,53 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the swig wrapper tf_optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class PyWrapOptimizeGraphTest(test.TestCase):
def testBasic(self):
"""Make sure arguments can be passed correctly."""
a = constant_op.constant(10, name='a')
b = constant_op.constant(20, name='b')
c = math_ops.add_n([a, b], name='c')
d = math_ops.add_n([b, c], name='d')
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(d)
mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
rewriter_config = rewriter_config_pb2.RewriterConfig()
rewriter_config.optimizers.append('constfold')
graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
self.assertEqual(len(graph.node), 5)
self.assertItemsEqual([node.name for node in graph.node],
['a', 'b', 'c', 'd', 'ConstantFolding/c'])
if __name__ == '__main__':
test.main()

View File

@ -40,3 +40,5 @@ limitations under the License.
%include "tensorflow/python/util/kernel_registry.i"
%include "tensorflow/python/util/transform_graph.i"
%include "tensorflow/python/grappler/tf_optimizer.i"

View File

@ -41,6 +41,7 @@ BLACKLIST = [
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
"//tensorflow:no_tensorflow_py_deps",
"//tensorflow/python:test_ops_2",
"//tensorflow/python:tf_optimizer",
"//tensorflow/python:compare_test_proto_py",
"//tensorflow/core:image_testdata",
"//tensorflow/core/kernels/cloud:bigquery_reader_ops",