Added a python API to the meta graph optimizer
Change: 154232702
This commit is contained in:
parent
546befc408
commit
58fe576e20
tensorflow
contrib/cmake
python
tools/pip_package
@ -204,6 +204,7 @@ add_python_module("tensorflow/python/estimator/export")
|
||||
add_python_module("tensorflow/python/estimator/inputs")
|
||||
add_python_module("tensorflow/python/estimator/inputs/queues")
|
||||
add_python_module("tensorflow/python/framework")
|
||||
add_python_module("tensorflow/python/grappler")
|
||||
add_python_module("tensorflow/python/kernel_tests")
|
||||
add_python_module("tensorflow/python/layers")
|
||||
add_python_module("tensorflow/python/lib")
|
||||
|
@ -2576,6 +2576,7 @@ tf_py_wrap_cc(
|
||||
"client/tf_session.i",
|
||||
"framework/cpp_shape_inference.i",
|
||||
"framework/python_op_gen.i",
|
||||
"grappler/tf_optimizer.i",
|
||||
"lib/core/py_func.i",
|
||||
"lib/core/strings.i",
|
||||
"lib/io/file_io.i",
|
||||
@ -2604,6 +2605,9 @@ tf_py_wrap_cc(
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:grappler_item_builder",
|
||||
"//tensorflow/core/grappler/optimizers:meta_optimizer",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:reader_base",
|
||||
"//tensorflow/core/debug",
|
||||
@ -3539,3 +3543,28 @@ cuda_py_test(
|
||||
],
|
||||
main = "client/session_benchmark.py",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tf_optimizer",
|
||||
srcs = [
|
||||
"grappler/tf_optimizer.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [":pywrap_tensorflow_internal"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tf_optimizer_test",
|
||||
size = "small",
|
||||
srcs = ["grappler/tf_optimizer_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"], # tf_optimizer is not available in pip.
|
||||
deps = [
|
||||
":client_testlib",
|
||||
":framework_for_generated_wrappers",
|
||||
":math_ops",
|
||||
":tf_optimizer",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
91
tensorflow/python/grappler/tf_optimizer.i
Normal file
91
tensorflow/python/grappler/tf_optimizer.i
Normal file
@ -0,0 +1,91 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
|
||||
char* c_string;
|
||||
Py_ssize_t py_size;
|
||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||
SWIG_fail;
|
||||
}
|
||||
|
||||
if (!temp.ParseFromString(string(c_string, py_size))) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"The MetaGraphDef could not be parsed as a valid protocol buffer");
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
%typemap(in) const tensorflow::RewriterConfig& (
|
||||
tensorflow::RewriterConfig temp) {
|
||||
char* c_string;
|
||||
Py_ssize_t py_size;
|
||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||
SWIG_fail;
|
||||
}
|
||||
|
||||
if (!temp.ParseFromString(string(c_string, py_size))) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"The RewriterConfig could not be parsed as a valid protocol buffer");
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
%{
|
||||
#include <memory>
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/grappler_item_builder.h"
|
||||
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
|
||||
|
||||
PyObject* TF_OptimizeGraph(
|
||||
const tensorflow::RewriterConfig& rewriter_config,
|
||||
const tensorflow::MetaGraphDef& metagraph,
|
||||
const string& graph_id, TF_Status* out_status) {
|
||||
const tensorflow::grappler::ItemConfig item_config;
|
||||
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
|
||||
tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
|
||||
tensorflow::GraphDef out_graph;
|
||||
tensorflow::Status status = tensorflow::grappler::RunMetaOptimizer(
|
||||
*grappler_item, rewriter_config, &out_graph);
|
||||
tensorflow::Set_TF_Status_from_Status(out_status, status);
|
||||
string out_graph_str = out_graph.SerializeAsString();
|
||||
PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(),
|
||||
out_graph_str.size());
|
||||
return ret;
|
||||
}
|
||||
%}
|
||||
|
||||
|
||||
// Wrap this function
|
||||
PyObject* TF_OptimizeGraph(
|
||||
const tensorflow::RewriterConfig& rewriter_config,
|
||||
const tensorflow::MetaGraphDef& metagraph,
|
||||
const string& graph_id, TF_Status* out_status);
|
||||
|
||||
|
||||
|
35
tensorflow/python/grappler/tf_optimizer.py
Normal file
35
tensorflow/python/grappler/tf_optimizer.py
Normal file
@ -0,0 +1,35 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Provides a proper python API for the symbols exported through swig."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as tf_opt
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
|
||||
def OptimizeGraph(rewriter_config, metagraph, graph_id=b'graph_to_optimize'):
|
||||
"""Optimize the provided metagraph."""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
ret_from_swig = tf_opt.TF_OptimizeGraph(rewriter_config.SerializeToString(),
|
||||
metagraph.SerializeToString(),
|
||||
graph_id, status)
|
||||
if ret_from_swig is None:
|
||||
return None
|
||||
out_graph = graph_pb2.GraphDef().FromString(ret_from_swig)
|
||||
return out_graph
|
53
tensorflow/python/grappler/tf_optimizer_test.py
Normal file
53
tensorflow/python/grappler/tf_optimizer_test.py
Normal file
@ -0,0 +1,53 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the swig wrapper tf_optimizer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import meta_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.grappler import tf_optimizer
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class PyWrapOptimizeGraphTest(test.TestCase):
|
||||
|
||||
def testBasic(self):
|
||||
"""Make sure arguments can be passed correctly."""
|
||||
a = constant_op.constant(10, name='a')
|
||||
b = constant_op.constant(20, name='b')
|
||||
c = math_ops.add_n([a, b], name='c')
|
||||
d = math_ops.add_n([b, c], name='d')
|
||||
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
|
||||
train_op.append(d)
|
||||
mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
|
||||
|
||||
rewriter_config = rewriter_config_pb2.RewriterConfig()
|
||||
rewriter_config.optimizers.append('constfold')
|
||||
|
||||
graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
|
||||
|
||||
self.assertEqual(len(graph.node), 5)
|
||||
self.assertItemsEqual([node.name for node in graph.node],
|
||||
['a', 'b', 'c', 'd', 'ConstantFolding/c'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -40,3 +40,5 @@ limitations under the License.
|
||||
%include "tensorflow/python/util/kernel_registry.i"
|
||||
|
||||
%include "tensorflow/python/util/transform_graph.i"
|
||||
|
||||
%include "tensorflow/python/grappler/tf_optimizer.i"
|
||||
|
@ -41,6 +41,7 @@ BLACKLIST = [
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
"//tensorflow:no_tensorflow_py_deps",
|
||||
"//tensorflow/python:test_ops_2",
|
||||
"//tensorflow/python:tf_optimizer",
|
||||
"//tensorflow/python:compare_test_proto_py",
|
||||
"//tensorflow/core:image_testdata",
|
||||
"//tensorflow/core/kernels/cloud:bigquery_reader_ops",
|
||||
|
Loading…
Reference in New Issue
Block a user