Export the transform_graph functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.

PiperOrigin-RevId: 276369111
Change-Id: Ia5f3d23b9537ebae8edfd59c085cb63d095fc713
This commit is contained in:
Amit Patankar 2019-10-23 15:56:09 -07:00 committed by TensorFlower Gardener
parent 367e11c468
commit 00983e6b07
9 changed files with 113 additions and 95 deletions

View File

@ -105,6 +105,7 @@ py_library(
":_pywrap_scoped_annotation",
":_pywrap_stat_summarizer",
":_pywrap_tfprof",
":_pywrap_transform_graph",
":_pywrap_util_port",
":_pywrap_utils",
":array_ops",
@ -577,6 +578,21 @@ tf_python_pybind_extension(
],
)
tf_python_pybind_extension(
name = "_pywrap_transform_graph",
srcs = ["util/transform_graph_wrapper.cc"],
hdrs = ["//tensorflow/tools/graph_transforms:transform_graph_hdrs"],
module_name = "_pywrap_transform_graph",
deps = [
":pybind11_status",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//third_party/python_runtime:headers",
"@pybind11",
],
)
filegroup(
name = "py_exception_registry_hdr",
srcs = [
@ -952,6 +968,7 @@ py_library(
":_pywrap_scoped_annotation",
":_pywrap_stat_summarizer",
":_pywrap_tfprof",
":_pywrap_transform_graph",
":_pywrap_util_port",
":_pywrap_utils",
":composite_tensor",
@ -5217,7 +5234,6 @@ tf_py_wrap_cc(
"pywrap_tfe.i",
"util/py_checkpoint_reader.i",
"util/traceme.i",
"util/transform_graph.i",
"//tensorflow/compiler/mlir/python:mlir.i",
],
# add win_def_file for pywrap_tensorflow
@ -5324,6 +5340,7 @@ genrule(
":py_exception_registry", # py_exception_registry
":kernel_registry",
"//tensorflow/lite/toco/python:toco_python_api", # toco
"//tensorflow/tools/graph_transforms:transform_graph_lib", # transform_graph
],
outs = ["pybind_symbol_target_libs_file.txt"],
cmd = select({

View File

@ -56,6 +56,7 @@ from tensorflow.python import _pywrap_py_exception_registry
from tensorflow.python import _pywrap_kernel_registry
from tensorflow.python import _pywrap_quantize_training
from tensorflow.python import _pywrap_scoped_annotation
from tensorflow.python import _pywrap_transform_graph
# Protocol buffers
from tensorflow.core.framework.graph_pb2 import *

View File

@ -35,8 +35,6 @@ limitations under the License.
%include "tensorflow/python/platform/stacktrace_handler.i"
%include "tensorflow/python/util/transform_graph.i"
%include "tensorflow/python/grappler/cluster.i"
%include "tensorflow/python/grappler/item.i"
%include "tensorflow/python/grappler/tf_optimizer.i"

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include <std_string.i>
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"

View File

@ -1,85 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include <std_string.i>
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/stat_summarizer.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/tools/graph_transforms/transform_graph.h"
%}
%ignoreall
%unignore tensorflow;
%unignore TransformGraphWithStringInputs;
%{
string TransformGraphWithStringInputs(string graph_def_string,
string inputs_string,
string outputs_string,
string transforms_string,
TF_Status* out_status) {
tensorflow::GraphDef graph_def;
if (!graph_def.ParseFromString(graph_def_string)) {
Set_TF_Status_from_Status(out_status, tensorflow::errors::InvalidArgument(
"Couldn't interpret input as a GraphDef"));
return "";
}
tensorflow::graph_transforms::TransformParameters params_list;
tensorflow::Status parse_status =
tensorflow::graph_transforms::ParseTransformParameters(
transforms_string, &params_list);
if (!parse_status.ok()) {
tensorflow::Set_TF_Status_from_Status(out_status, parse_status);
return "";
}
std::vector<string> inputs = tensorflow::str_util::Split(inputs_string, ',');
std::vector<string> outputs =
tensorflow::str_util::Split(outputs_string, ',');
tensorflow::Status transform_status =
tensorflow::graph_transforms::TransformGraph(
inputs, outputs, params_list, &graph_def);
if (!transform_status.ok()) {
tensorflow::Set_TF_Status_from_Status(out_status, transform_status);
return "";
}
string result;
if (!graph_def.SerializeToString(&result)) {
Set_TF_Status_from_Status(out_status, tensorflow::errors::InvalidArgument(
"Couldn't serialize output as a GraphDef"));
return "";
}
Set_TF_Status_from_Status(out_status, tensorflow::Status::OK());
return result;
}
%}
string TransformGraphWithStringInputs(string graph_def_string,
string inputs_string,
string outputs_string,
string transforms_string,
TF_Status* out_status);
%unignoreall

View File

@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <vector>
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/tools/graph_transforms/transform_graph.h"
namespace py = pybind11;
namespace tensorflow {
string TransformGraphWithStringInputs(string graph_def_string,
string inputs_string,
string outputs_string,
string transforms_string) {
GraphDef graph_def;
if (!graph_def.ParseFromString(graph_def_string)) {
MaybeRaiseFromStatus(
errors::InvalidArgument("Couldn't interpret input as a GraphDef"));
}
graph_transforms::TransformParameters params_list;
Status parse_status = graph_transforms::ParseTransformParameters(
transforms_string, &params_list);
if (!parse_status.ok()) {
MaybeRaiseFromStatus(parse_status);
}
std::vector<string> inputs = str_util::Split(inputs_string, ',');
std::vector<string> outputs = str_util::Split(outputs_string, ',');
Status transform_status = graph_transforms::TransformGraph(
inputs, outputs, params_list, &graph_def);
if (!transform_status.ok()) {
MaybeRaiseFromStatus(transform_status);
}
string result;
if (!graph_def.SerializeToString(&result)) {
MaybeRaiseFromStatus(
errors::InvalidArgument("Couldn't serialize output as a GraphDef"));
}
return result;
}
} // namespace tensorflow
PYBIND11_MODULE(_pywrap_transform_graph, m) {
m.def(
"TransformGraphWithStringInputs",
[](const py::object graph_def_string, const py::object inputs_string,
const py::object outputs_string, const py::object transforms_string) {
return py::bytes(tensorflow::TransformGraphWithStringInputs(
graph_def_string.cast<std::string>(),
inputs_string.cast<std::string>(),
outputs_string.cast<std::string>(),
transforms_string.cast<std::string>()));
});
};

View File

@ -93,3 +93,6 @@ tensorflow::swig::TryFindKernelClass
toco::TocoConvert
toco::TocoGetPotentiallySupportedOps
[transform_graph_lib] # transform_graph
tensorflow::graph_transforms::TransformGraph
tensorflow::graph_transforms::ParseTransformParameters

View File

@ -207,6 +207,18 @@ tf_cc_test(
],
)
filegroup(
name = "transform_graph_hdrs",
srcs = [
"transform_graph.h",
"transform_utils.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "transform_graph_lib",
srcs = ["transform_graph.cc"],
@ -316,8 +328,8 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:_pywrap_transform_graph",
"//tensorflow/python:errors",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:util",
],
)

View File

@ -19,8 +19,7 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import, line-too-long
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.pywrap_tensorflow import TransformGraphWithStringInputs
from tensorflow.python._pywrap_transform_graph import TransformGraphWithStringInputs
from tensorflow.python.util import compat
@ -45,10 +44,8 @@ def TransformGraph(input_graph_def, inputs, outputs, transforms):
inputs_string = compat.as_bytes(",".join(inputs))
outputs_string = compat.as_bytes(",".join(outputs))
transforms_string = compat.as_bytes(" ".join(transforms))
with errors.raise_exception_on_not_ok_status() as status:
output_graph_def_string = TransformGraphWithStringInputs(
input_graph_def_string, inputs_string, outputs_string,
transforms_string, status)
output_graph_def_string = TransformGraphWithStringInputs(
input_graph_def_string, inputs_string, outputs_string, transforms_string)
output_graph_def = graph_pb2.GraphDef()
output_graph_def.ParseFromString(output_graph_def_string)
return output_graph_def