From 00983e6b0700f9a34e0df980669172b6ea85dad0 Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Wed, 23 Oct 2019 15:56:09 -0700 Subject: [PATCH] Export the transform_graph functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information. PiperOrigin-RevId: 276369111 Change-Id: Ia5f3d23b9537ebae8edfd59c085cb63d095fc713 --- tensorflow/python/BUILD | 19 ++++- tensorflow/python/__init__.py | 1 + tensorflow/python/tensorflow.i | 2 - tensorflow/python/util/traceme.i | 1 + tensorflow/python/util/transform_graph.i | 85 ------------------- .../python/util/transform_graph_wrapper.cc | 74 ++++++++++++++++ .../tools/def_file_filter/symbols_pybind.txt | 3 + tensorflow/tools/graph_transforms/BUILD | 14 ++- tensorflow/tools/graph_transforms/__init__.py | 9 +- 9 files changed, 113 insertions(+), 95 deletions(-) delete mode 100644 tensorflow/python/util/transform_graph.i create mode 100644 tensorflow/python/util/transform_graph_wrapper.cc diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 07ebf8003c6..87a5a5b8442 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -105,6 +105,7 @@ py_library( ":_pywrap_scoped_annotation", ":_pywrap_stat_summarizer", ":_pywrap_tfprof", + ":_pywrap_transform_graph", ":_pywrap_util_port", ":_pywrap_utils", ":array_ops", @@ -577,6 +578,21 @@ tf_python_pybind_extension( ], ) +tf_python_pybind_extension( + name = "_pywrap_transform_graph", + srcs = ["util/transform_graph_wrapper.cc"], + hdrs = ["//tensorflow/tools/graph_transforms:transform_graph_hdrs"], + module_name = "_pywrap_transform_graph", + deps = [ + ":pybind11_status", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//third_party/python_runtime:headers", + "@pybind11", + ], +) + filegroup( name = "py_exception_registry_hdr", srcs = [ @@ -952,6 +968,7 @@ py_library( ":_pywrap_scoped_annotation", ":_pywrap_stat_summarizer", ":_pywrap_tfprof", + ":_pywrap_transform_graph", ":_pywrap_util_port", ":_pywrap_utils", ":composite_tensor", @@ -5217,7 +5234,6 @@ tf_py_wrap_cc( "pywrap_tfe.i", "util/py_checkpoint_reader.i", "util/traceme.i", - "util/transform_graph.i", "//tensorflow/compiler/mlir/python:mlir.i", ], # add win_def_file for pywrap_tensorflow @@ -5324,6 +5340,7 @@ genrule( ":py_exception_registry", # py_exception_registry ":kernel_registry", "//tensorflow/lite/toco/python:toco_python_api", # toco + "//tensorflow/tools/graph_transforms:transform_graph_lib", # transform_graph ], outs = ["pybind_symbol_target_libs_file.txt"], cmd = select({ diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 8fa3ac279a8..b593c2c280b 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -56,6 +56,7 @@ from tensorflow.python import _pywrap_py_exception_registry from tensorflow.python import _pywrap_kernel_registry from tensorflow.python import _pywrap_quantize_training from tensorflow.python import _pywrap_scoped_annotation +from tensorflow.python import _pywrap_transform_graph # Protocol buffers from tensorflow.core.framework.graph_pb2 import * diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index c3251b4e751..850df138f3b 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -35,8 +35,6 @@ limitations under the License. %include "tensorflow/python/platform/stacktrace_handler.i" -%include "tensorflow/python/util/transform_graph.i" - %include "tensorflow/python/grappler/cluster.i" %include "tensorflow/python/grappler/item.i" %include "tensorflow/python/grappler/tf_optimizer.i" diff --git a/tensorflow/python/util/traceme.i b/tensorflow/python/util/traceme.i index 7b1377cb645..519c270c7d3 100644 --- a/tensorflow/python/util/traceme.i +++ b/tensorflow/python/util/traceme.i @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +%include %include "tensorflow/python/lib/core/strings.i" %include "tensorflow/python/platform/base.i" diff --git a/tensorflow/python/util/transform_graph.i b/tensorflow/python/util/transform_graph.i deleted file mode 100644 index 6e979379e9c..00000000000 --- a/tensorflow/python/util/transform_graph.i +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -%include -%include "tensorflow/python/lib/core/strings.i" -%include "tensorflow/python/platform/base.i" - -%{ -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/util/stat_summarizer.h" - -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/step_stats.pb.h" -#include "tensorflow/tools/graph_transforms/transform_graph.h" -%} - -%ignoreall - -%unignore tensorflow; -%unignore TransformGraphWithStringInputs; - - -%{ -string TransformGraphWithStringInputs(string graph_def_string, - string inputs_string, - string outputs_string, - string transforms_string, - TF_Status* out_status) { - tensorflow::GraphDef graph_def; - if (!graph_def.ParseFromString(graph_def_string)) { - Set_TF_Status_from_Status(out_status, tensorflow::errors::InvalidArgument( - "Couldn't interpret input as a GraphDef")); - return ""; - } - - tensorflow::graph_transforms::TransformParameters params_list; - tensorflow::Status parse_status = - tensorflow::graph_transforms::ParseTransformParameters( - transforms_string, ¶ms_list); - if (!parse_status.ok()) { - tensorflow::Set_TF_Status_from_Status(out_status, parse_status); - return ""; - } - std::vector inputs = tensorflow::str_util::Split(inputs_string, ','); - std::vector outputs = - tensorflow::str_util::Split(outputs_string, ','); - - tensorflow::Status transform_status = - tensorflow::graph_transforms::TransformGraph( - inputs, outputs, params_list, &graph_def); - if (!transform_status.ok()) { - tensorflow::Set_TF_Status_from_Status(out_status, transform_status); - return ""; - } - string result; - if (!graph_def.SerializeToString(&result)) { - Set_TF_Status_from_Status(out_status, tensorflow::errors::InvalidArgument( - "Couldn't serialize output as a GraphDef")); - return ""; - } - Set_TF_Status_from_Status(out_status, tensorflow::Status::OK()); - return result; -} -%} - - -string TransformGraphWithStringInputs(string graph_def_string, - string inputs_string, - string outputs_string, - string transforms_string, - TF_Status* out_status); - -%unignoreall diff --git a/tensorflow/python/util/transform_graph_wrapper.cc b/tensorflow/python/util/transform_graph_wrapper.cc new file mode 100644 index 00000000000..1859f0a2b5b --- /dev/null +++ b/tensorflow/python/util/transform_graph_wrapper.cc @@ -0,0 +1,74 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "include/pybind11/pybind11.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/python/lib/core/pybind11_status.h" +#include "tensorflow/tools/graph_transforms/transform_graph.h" + +namespace py = pybind11; + +namespace tensorflow { + +string TransformGraphWithStringInputs(string graph_def_string, + string inputs_string, + string outputs_string, + string transforms_string) { + GraphDef graph_def; + if (!graph_def.ParseFromString(graph_def_string)) { + MaybeRaiseFromStatus( + errors::InvalidArgument("Couldn't interpret input as a GraphDef")); + } + + graph_transforms::TransformParameters params_list; + Status parse_status = graph_transforms::ParseTransformParameters( + transforms_string, ¶ms_list); + if (!parse_status.ok()) { + MaybeRaiseFromStatus(parse_status); + } + std::vector inputs = str_util::Split(inputs_string, ','); + std::vector outputs = str_util::Split(outputs_string, ','); + + Status transform_status = graph_transforms::TransformGraph( + inputs, outputs, params_list, &graph_def); + if (!transform_status.ok()) { + MaybeRaiseFromStatus(transform_status); + } + string result; + if (!graph_def.SerializeToString(&result)) { + MaybeRaiseFromStatus( + errors::InvalidArgument("Couldn't serialize output as a GraphDef")); + } + return result; +} + +} // namespace tensorflow + +PYBIND11_MODULE(_pywrap_transform_graph, m) { + m.def( + "TransformGraphWithStringInputs", + [](const py::object graph_def_string, const py::object inputs_string, + const py::object outputs_string, const py::object transforms_string) { + return py::bytes(tensorflow::TransformGraphWithStringInputs( + graph_def_string.cast(), + inputs_string.cast(), + outputs_string.cast(), + transforms_string.cast())); + }); +}; diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 70f9099f1be..b1bb14fbac8 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -93,3 +93,6 @@ tensorflow::swig::TryFindKernelClass toco::TocoConvert toco::TocoGetPotentiallySupportedOps +[transform_graph_lib] # transform_graph +tensorflow::graph_transforms::TransformGraph +tensorflow::graph_transforms::ParseTransformParameters diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index 0f5c298b48b..870edf5581a 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -207,6 +207,18 @@ tf_cc_test( ], ) +filegroup( + name = "transform_graph_hdrs", + srcs = [ + "transform_graph.h", + "transform_utils.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) + cc_library( name = "transform_graph_lib", srcs = ["transform_graph.cc"], @@ -316,8 +328,8 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/core:protos_all_py", + "//tensorflow/python:_pywrap_transform_graph", "//tensorflow/python:errors", - "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:util", ], ) diff --git a/tensorflow/tools/graph_transforms/__init__.py b/tensorflow/tools/graph_transforms/__init__.py index 943cd737ad1..8746567658e 100644 --- a/tensorflow/tools/graph_transforms/__init__.py +++ b/tensorflow/tools/graph_transforms/__init__.py @@ -19,8 +19,7 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import, line-too-long from tensorflow.core.framework import graph_pb2 -from tensorflow.python.framework import errors -from tensorflow.python.pywrap_tensorflow import TransformGraphWithStringInputs +from tensorflow.python._pywrap_transform_graph import TransformGraphWithStringInputs from tensorflow.python.util import compat @@ -45,10 +44,8 @@ def TransformGraph(input_graph_def, inputs, outputs, transforms): inputs_string = compat.as_bytes(",".join(inputs)) outputs_string = compat.as_bytes(",".join(outputs)) transforms_string = compat.as_bytes(" ".join(transforms)) - with errors.raise_exception_on_not_ok_status() as status: - output_graph_def_string = TransformGraphWithStringInputs( - input_graph_def_string, inputs_string, outputs_string, - transforms_string, status) + output_graph_def_string = TransformGraphWithStringInputs( + input_graph_def_string, inputs_string, outputs_string, transforms_string) output_graph_def = graph_pb2.GraphDef() output_graph_def.ParseFromString(output_graph_def_string) return output_graph_def