Export the transform_graph functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 276369111 Change-Id: Ia5f3d23b9537ebae8edfd59c085cb63d095fc713
This commit is contained in:
parent
367e11c468
commit
00983e6b07
@ -105,6 +105,7 @@ py_library(
|
||||
":_pywrap_scoped_annotation",
|
||||
":_pywrap_stat_summarizer",
|
||||
":_pywrap_tfprof",
|
||||
":_pywrap_transform_graph",
|
||||
":_pywrap_util_port",
|
||||
":_pywrap_utils",
|
||||
":array_ops",
|
||||
@ -577,6 +578,21 @@ tf_python_pybind_extension(
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_transform_graph",
|
||||
srcs = ["util/transform_graph_wrapper.cc"],
|
||||
hdrs = ["//tensorflow/tools/graph_transforms:transform_graph_hdrs"],
|
||||
module_name = "_pywrap_transform_graph",
|
||||
deps = [
|
||||
":pybind11_status",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "py_exception_registry_hdr",
|
||||
srcs = [
|
||||
@ -952,6 +968,7 @@ py_library(
|
||||
":_pywrap_scoped_annotation",
|
||||
":_pywrap_stat_summarizer",
|
||||
":_pywrap_tfprof",
|
||||
":_pywrap_transform_graph",
|
||||
":_pywrap_util_port",
|
||||
":_pywrap_utils",
|
||||
":composite_tensor",
|
||||
@ -5217,7 +5234,6 @@ tf_py_wrap_cc(
|
||||
"pywrap_tfe.i",
|
||||
"util/py_checkpoint_reader.i",
|
||||
"util/traceme.i",
|
||||
"util/transform_graph.i",
|
||||
"//tensorflow/compiler/mlir/python:mlir.i",
|
||||
],
|
||||
# add win_def_file for pywrap_tensorflow
|
||||
@ -5324,6 +5340,7 @@ genrule(
|
||||
":py_exception_registry", # py_exception_registry
|
||||
":kernel_registry",
|
||||
"//tensorflow/lite/toco/python:toco_python_api", # toco
|
||||
"//tensorflow/tools/graph_transforms:transform_graph_lib", # transform_graph
|
||||
],
|
||||
outs = ["pybind_symbol_target_libs_file.txt"],
|
||||
cmd = select({
|
||||
|
@ -56,6 +56,7 @@ from tensorflow.python import _pywrap_py_exception_registry
|
||||
from tensorflow.python import _pywrap_kernel_registry
|
||||
from tensorflow.python import _pywrap_quantize_training
|
||||
from tensorflow.python import _pywrap_scoped_annotation
|
||||
from tensorflow.python import _pywrap_transform_graph
|
||||
|
||||
# Protocol buffers
|
||||
from tensorflow.core.framework.graph_pb2 import *
|
||||
|
@ -35,8 +35,6 @@ limitations under the License.
|
||||
|
||||
%include "tensorflow/python/platform/stacktrace_handler.i"
|
||||
|
||||
%include "tensorflow/python/util/transform_graph.i"
|
||||
|
||||
%include "tensorflow/python/grappler/cluster.i"
|
||||
%include "tensorflow/python/grappler/item.i"
|
||||
%include "tensorflow/python/grappler/tf_optimizer.i"
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include <std_string.i>
|
||||
%include "tensorflow/python/lib/core/strings.i"
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
|
@ -1,85 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include <std_string.i>
|
||||
%include "tensorflow/python/lib/core/strings.i"
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/stat_summarizer.h"
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_graph.h"
|
||||
%}
|
||||
|
||||
%ignoreall
|
||||
|
||||
%unignore tensorflow;
|
||||
%unignore TransformGraphWithStringInputs;
|
||||
|
||||
|
||||
%{
|
||||
string TransformGraphWithStringInputs(string graph_def_string,
|
||||
string inputs_string,
|
||||
string outputs_string,
|
||||
string transforms_string,
|
||||
TF_Status* out_status) {
|
||||
tensorflow::GraphDef graph_def;
|
||||
if (!graph_def.ParseFromString(graph_def_string)) {
|
||||
Set_TF_Status_from_Status(out_status, tensorflow::errors::InvalidArgument(
|
||||
"Couldn't interpret input as a GraphDef"));
|
||||
return "";
|
||||
}
|
||||
|
||||
tensorflow::graph_transforms::TransformParameters params_list;
|
||||
tensorflow::Status parse_status =
|
||||
tensorflow::graph_transforms::ParseTransformParameters(
|
||||
transforms_string, ¶ms_list);
|
||||
if (!parse_status.ok()) {
|
||||
tensorflow::Set_TF_Status_from_Status(out_status, parse_status);
|
||||
return "";
|
||||
}
|
||||
std::vector<string> inputs = tensorflow::str_util::Split(inputs_string, ',');
|
||||
std::vector<string> outputs =
|
||||
tensorflow::str_util::Split(outputs_string, ',');
|
||||
|
||||
tensorflow::Status transform_status =
|
||||
tensorflow::graph_transforms::TransformGraph(
|
||||
inputs, outputs, params_list, &graph_def);
|
||||
if (!transform_status.ok()) {
|
||||
tensorflow::Set_TF_Status_from_Status(out_status, transform_status);
|
||||
return "";
|
||||
}
|
||||
string result;
|
||||
if (!graph_def.SerializeToString(&result)) {
|
||||
Set_TF_Status_from_Status(out_status, tensorflow::errors::InvalidArgument(
|
||||
"Couldn't serialize output as a GraphDef"));
|
||||
return "";
|
||||
}
|
||||
Set_TF_Status_from_Status(out_status, tensorflow::Status::OK());
|
||||
return result;
|
||||
}
|
||||
%}
|
||||
|
||||
|
||||
string TransformGraphWithStringInputs(string graph_def_string,
|
||||
string inputs_string,
|
||||
string outputs_string,
|
||||
string transforms_string,
|
||||
TF_Status* out_status);
|
||||
|
||||
%unignoreall
|
74
tensorflow/python/util/transform_graph_wrapper.cc
Normal file
74
tensorflow/python/util/transform_graph_wrapper.cc
Normal file
@ -0,0 +1,74 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_graph.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
string TransformGraphWithStringInputs(string graph_def_string,
|
||||
string inputs_string,
|
||||
string outputs_string,
|
||||
string transforms_string) {
|
||||
GraphDef graph_def;
|
||||
if (!graph_def.ParseFromString(graph_def_string)) {
|
||||
MaybeRaiseFromStatus(
|
||||
errors::InvalidArgument("Couldn't interpret input as a GraphDef"));
|
||||
}
|
||||
|
||||
graph_transforms::TransformParameters params_list;
|
||||
Status parse_status = graph_transforms::ParseTransformParameters(
|
||||
transforms_string, ¶ms_list);
|
||||
if (!parse_status.ok()) {
|
||||
MaybeRaiseFromStatus(parse_status);
|
||||
}
|
||||
std::vector<string> inputs = str_util::Split(inputs_string, ',');
|
||||
std::vector<string> outputs = str_util::Split(outputs_string, ',');
|
||||
|
||||
Status transform_status = graph_transforms::TransformGraph(
|
||||
inputs, outputs, params_list, &graph_def);
|
||||
if (!transform_status.ok()) {
|
||||
MaybeRaiseFromStatus(transform_status);
|
||||
}
|
||||
string result;
|
||||
if (!graph_def.SerializeToString(&result)) {
|
||||
MaybeRaiseFromStatus(
|
||||
errors::InvalidArgument("Couldn't serialize output as a GraphDef"));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
PYBIND11_MODULE(_pywrap_transform_graph, m) {
|
||||
m.def(
|
||||
"TransformGraphWithStringInputs",
|
||||
[](const py::object graph_def_string, const py::object inputs_string,
|
||||
const py::object outputs_string, const py::object transforms_string) {
|
||||
return py::bytes(tensorflow::TransformGraphWithStringInputs(
|
||||
graph_def_string.cast<std::string>(),
|
||||
inputs_string.cast<std::string>(),
|
||||
outputs_string.cast<std::string>(),
|
||||
transforms_string.cast<std::string>()));
|
||||
});
|
||||
};
|
@ -93,3 +93,6 @@ tensorflow::swig::TryFindKernelClass
|
||||
toco::TocoConvert
|
||||
toco::TocoGetPotentiallySupportedOps
|
||||
|
||||
[transform_graph_lib] # transform_graph
|
||||
tensorflow::graph_transforms::TransformGraph
|
||||
tensorflow::graph_transforms::ParseTransformParameters
|
||||
|
@ -207,6 +207,18 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "transform_graph_hdrs",
|
||||
srcs = [
|
||||
"transform_graph.h",
|
||||
"transform_utils.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "transform_graph_lib",
|
||||
srcs = ["transform_graph.cc"],
|
||||
@ -316,8 +328,8 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:_pywrap_transform_graph",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
@ -19,8 +19,7 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import, line-too-long
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.pywrap_tensorflow import TransformGraphWithStringInputs
|
||||
from tensorflow.python._pywrap_transform_graph import TransformGraphWithStringInputs
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
@ -45,10 +44,8 @@ def TransformGraph(input_graph_def, inputs, outputs, transforms):
|
||||
inputs_string = compat.as_bytes(",".join(inputs))
|
||||
outputs_string = compat.as_bytes(",".join(outputs))
|
||||
transforms_string = compat.as_bytes(" ".join(transforms))
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
output_graph_def_string = TransformGraphWithStringInputs(
|
||||
input_graph_def_string, inputs_string, outputs_string,
|
||||
transforms_string, status)
|
||||
output_graph_def_string = TransformGraphWithStringInputs(
|
||||
input_graph_def_string, inputs_string, outputs_string, transforms_string)
|
||||
output_graph_def = graph_pb2.GraphDef()
|
||||
output_graph_def.ParseFromString(output_graph_def_string)
|
||||
return output_graph_def
|
||||
|
Loading…
Reference in New Issue
Block a user