From f736991fd3a7987665a6f9fcd26d464ea7f68e2b Mon Sep 17 00:00:00 2001 From: Pete Warden Date: Thu, 19 Jan 2017 14:21:59 -0800 Subject: [PATCH] Python wrapper for the Graph Transform Tool Change: 145009203 --- tensorflow/python/BUILD | 2 + tensorflow/python/tensorflow.i | 2 + tensorflow/python/util/transform_graph.i | 86 +++++++++++++++++++ tensorflow/tools/graph_transforms/BUILD | 23 +++++ tensorflow/tools/graph_transforms/__init__.py | 53 ++++++++++++ .../python/transform_graph_test.py | 85 ++++++++++++++++++ 6 files changed, 251 insertions(+) create mode 100644 tensorflow/python/util/transform_graph.i create mode 100644 tensorflow/tools/graph_transforms/__init__.py create mode 100644 tensorflow/tools/graph_transforms/python/transform_graph_test.py diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 626447d7833..e607cb5bc98 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2292,6 +2292,7 @@ tf_py_wrap_cc( "util/port.i", "util/py_checkpoint_reader.i", "util/stat_summarizer.i", + "util/transform_graph.i", ], deps = [ ":cpp_shape_inference", @@ -2310,6 +2311,7 @@ tf_py_wrap_cc( "//tensorflow/core:lib", "//tensorflow/core/debug", "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/tools/graph_transforms:transform_graph_lib", "//tensorflow/tools/tfprof/internal:print_model_analysis", "//util/python:python_headers", ] + tf_additional_lib_deps() + tf_additional_plugin_deps(), diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 0f7deb7827e..09b4b20bcd2 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -38,3 +38,5 @@ limitations under the License. %include "tensorflow/python/framework/cpp_shape_inference.i" %include "tensorflow/python/util/kernel_registry.i" + +%include "tensorflow/python/util/transform_graph.i" diff --git a/tensorflow/python/util/transform_graph.i b/tensorflow/python/util/transform_graph.i new file mode 100644 index 00000000000..1d26c57e99b --- /dev/null +++ b/tensorflow/python/util/transform_graph.i @@ -0,0 +1,86 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%include +%include "tensorflow/python/lib/core/strings.i" +%include "tensorflow/python/platform/base.i" + +%{ +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/stat_summarizer.h" +#include "tensorflow/python/lib/core/py_func.h" + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/tools/graph_transforms/transform_graph.h" +%} + +%ignoreall + +%unignore tensorflow; +%unignore TransformGraphWithStringInputs; + + +%{ +string TransformGraphWithStringInputs(string graph_def_string, + string inputs_string, + string outputs_string, + string transforms_string, + TF_Status* out_status) { + tensorflow::GraphDef graph_def; + if (!graph_def.ParseFromString(graph_def_string)) { + Set_TF_Status_from_Status(out_status, tensorflow::errors::InvalidArgument( + "Couldn't interpret input as a GraphDef")); + return ""; + } + + tensorflow::graph_transforms::TransformParameters params_list; + tensorflow::Status parse_status = + tensorflow::graph_transforms::ParseTransformParameters( + transforms_string, ¶ms_list); + if (!parse_status.ok()) { + tensorflow::Set_TF_Status_from_Status(out_status, parse_status); + return ""; + } + std::vector inputs = tensorflow::str_util::Split(inputs_string, ','); + std::vector outputs = + tensorflow::str_util::Split(outputs_string, ','); + + tensorflow::Status transform_status = + tensorflow::graph_transforms::TransformGraph( + inputs, outputs, params_list, &graph_def); + if (!transform_status.ok()) { + tensorflow::Set_TF_Status_from_Status(out_status, transform_status); + return ""; + } + string result; + if (!graph_def.SerializeToString(&result)) { + Set_TF_Status_from_Status(out_status, tensorflow::errors::InvalidArgument( + "Couldn't serialize output as a GraphDef")); + return ""; + } + Set_TF_Status_from_Status(out_status, tensorflow::Status::OK()); + return result; +} +%} + + +string TransformGraphWithStringInputs(string graph_def_string, + string inputs_string, + string outputs_string, + string transforms_string, + TF_Status* out_status); + +%unignoreall diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index 8bd18bec099..4b91da50798 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -9,6 +9,7 @@ load( "//tensorflow:tensorflow.bzl", "tf_copts", "tf_cc_test", + "tf_py_test", ) exports_files(["LICENSE"]) @@ -234,3 +235,25 @@ cc_binary( "//tensorflow/core:lib", ], ) + +py_library( + name = "transform_graph_py", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = ["//tensorflow/python:pywrap_tensorflow"], +) + +tf_py_test( + name = "transform_graph_py_test", + size = "small", + srcs = ["python/transform_graph_test.py"], + additional_deps = [ + ":transform_graph_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:variables", + ], + main = "python/transform_graph_test.py", +) diff --git a/tensorflow/tools/graph_transforms/__init__.py b/tensorflow/tools/graph_transforms/__init__.py new file mode 100644 index 00000000000..38567443e18 --- /dev/null +++ b/tensorflow/tools/graph_transforms/__init__.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Exposes the Python wrapper for graph transforms.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import, line-too-long +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import errors +from tensorflow.python.pywrap_tensorflow import TransformGraphWithStringInputs + + +def TransformGraph(input_graph_def, inputs, outputs, transforms): + """Python wrapper for the Graph Transform Tool. + + Gives access to all graph transforms available through the command line tool. + See documentation at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md + for full details of the options available. + + Args: + input_graph_def: GraphDef object containing a model to be transformed. + inputs: List of node names for the model inputs. + outputs: List of node names for the model outputs. + transforms: List of strings containing transform names and parameters. + + Returns: + New GraphDef with transforms applied. + """ + + input_graph_def_string = input_graph_def.SerializeToString() + inputs_string = ",".join(inputs) + outputs_string = ",".join(outputs) + transforms_string = " ".join(transforms) + with errors.raise_exception_on_not_ok_status() as status: + output_graph_def_string = TransformGraphWithStringInputs( + input_graph_def_string, inputs_string, outputs_string, + transforms_string, status) + output_graph_def = graph_pb2.GraphDef() + output_graph_def.ParseFromString(output_graph_def_string) + return output_graph_def diff --git a/tensorflow/tools/graph_transforms/python/transform_graph_test.py b/tensorflow/tools/graph_transforms/python/transform_graph_test.py new file mode 100644 index 00000000000..0c19df2b3f9 --- /dev/null +++ b/tensorflow/tools/graph_transforms/python/transform_graph_test.py @@ -0,0 +1,85 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for StatSummarizer Python wrapper.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.platform import test +from tensorflow.tools.graph_transforms import TransformGraph + + +class TransformGraphTest(test.TestCase): + + # This test constructs a graph with a relu op that's not used by the normal + # inference path, and then tests that the strip_unused transform removes it as + # expected. + def testTransformGraph(self): + input_graph_def = graph_pb2.GraphDef() + + const_op1 = input_graph_def.node.add() + const_op1.op = "Const" + const_op1.name = "const_op1" + const_op1.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue( + type=dtypes.float32.as_datatype_enum)) + const_op1.attr["value"].CopyFrom( + attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( + [1, 2], dtypes.float32, [1, 2]))) + + const_op2 = input_graph_def.node.add() + const_op2.op = "Const" + const_op2.name = "const_op2" + const_op2.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue( + type=dtypes.float32.as_datatype_enum)) + const_op2.attr["value"].CopyFrom( + attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( + [3, 4], dtypes.float32, [1, 2]))) + + # Create an add that has two constants as inputs. + add_op = input_graph_def.node.add() + add_op.op = "Add" + add_op.attr["T"].CopyFrom(attr_value_pb2.AttrValue( + type=dtypes.float32.as_datatype_enum)) + add_op.name = "add_op" + add_op.input.extend(["const_op1", "const_op2"]) + + # Create a relu that reads from the add. + relu_op = input_graph_def.node.add() + relu_op.op = "Relu" + relu_op.attr["T"].CopyFrom(attr_value_pb2.AttrValue( + type=dtypes.float32.as_datatype_enum)) + relu_op.name = "relu_op" + relu_op.input.extend(["add_op"]) + + # We're specifying that add_op is the final output, and so the relu isn't + # needed. + input_names = [] + output_names = ["add_op"] + transforms = ["strip_unused_nodes"] + transformed_graph_def = TransformGraph(input_graph_def, input_names, + output_names, transforms) + + # We expect that the relu is no longer present after running the transform. + for node in transformed_graph_def.node: + self.assertNotEqual("Relu", node.op) + + +if __name__ == "__main__": + test.main()