Export the transform_graph functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 276369111 Change-Id: Ia5f3d23b9537ebae8edfd59c085cb63d095fc713
This commit is contained in:
parent
367e11c468
commit
00983e6b07
@ -105,6 +105,7 @@ py_library(
|
|||||||
":_pywrap_scoped_annotation",
|
":_pywrap_scoped_annotation",
|
||||||
":_pywrap_stat_summarizer",
|
":_pywrap_stat_summarizer",
|
||||||
":_pywrap_tfprof",
|
":_pywrap_tfprof",
|
||||||
|
":_pywrap_transform_graph",
|
||||||
":_pywrap_util_port",
|
":_pywrap_util_port",
|
||||||
":_pywrap_utils",
|
":_pywrap_utils",
|
||||||
":array_ops",
|
":array_ops",
|
||||||
@ -577,6 +578,21 @@ tf_python_pybind_extension(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_transform_graph",
|
||||||
|
srcs = ["util/transform_graph_wrapper.cc"],
|
||||||
|
hdrs = ["//tensorflow/tools/graph_transforms:transform_graph_hdrs"],
|
||||||
|
module_name = "_pywrap_transform_graph",
|
||||||
|
deps = [
|
||||||
|
":pybind11_status",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "py_exception_registry_hdr",
|
name = "py_exception_registry_hdr",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -952,6 +968,7 @@ py_library(
|
|||||||
":_pywrap_scoped_annotation",
|
":_pywrap_scoped_annotation",
|
||||||
":_pywrap_stat_summarizer",
|
":_pywrap_stat_summarizer",
|
||||||
":_pywrap_tfprof",
|
":_pywrap_tfprof",
|
||||||
|
":_pywrap_transform_graph",
|
||||||
":_pywrap_util_port",
|
":_pywrap_util_port",
|
||||||
":_pywrap_utils",
|
":_pywrap_utils",
|
||||||
":composite_tensor",
|
":composite_tensor",
|
||||||
@ -5217,7 +5234,6 @@ tf_py_wrap_cc(
|
|||||||
"pywrap_tfe.i",
|
"pywrap_tfe.i",
|
||||||
"util/py_checkpoint_reader.i",
|
"util/py_checkpoint_reader.i",
|
||||||
"util/traceme.i",
|
"util/traceme.i",
|
||||||
"util/transform_graph.i",
|
|
||||||
"//tensorflow/compiler/mlir/python:mlir.i",
|
"//tensorflow/compiler/mlir/python:mlir.i",
|
||||||
],
|
],
|
||||||
# add win_def_file for pywrap_tensorflow
|
# add win_def_file for pywrap_tensorflow
|
||||||
@ -5324,6 +5340,7 @@ genrule(
|
|||||||
":py_exception_registry", # py_exception_registry
|
":py_exception_registry", # py_exception_registry
|
||||||
":kernel_registry",
|
":kernel_registry",
|
||||||
"//tensorflow/lite/toco/python:toco_python_api", # toco
|
"//tensorflow/lite/toco/python:toco_python_api", # toco
|
||||||
|
"//tensorflow/tools/graph_transforms:transform_graph_lib", # transform_graph
|
||||||
],
|
],
|
||||||
outs = ["pybind_symbol_target_libs_file.txt"],
|
outs = ["pybind_symbol_target_libs_file.txt"],
|
||||||
cmd = select({
|
cmd = select({
|
||||||
|
@ -56,6 +56,7 @@ from tensorflow.python import _pywrap_py_exception_registry
|
|||||||
from tensorflow.python import _pywrap_kernel_registry
|
from tensorflow.python import _pywrap_kernel_registry
|
||||||
from tensorflow.python import _pywrap_quantize_training
|
from tensorflow.python import _pywrap_quantize_training
|
||||||
from tensorflow.python import _pywrap_scoped_annotation
|
from tensorflow.python import _pywrap_scoped_annotation
|
||||||
|
from tensorflow.python import _pywrap_transform_graph
|
||||||
|
|
||||||
# Protocol buffers
|
# Protocol buffers
|
||||||
from tensorflow.core.framework.graph_pb2 import *
|
from tensorflow.core.framework.graph_pb2 import *
|
||||||
|
@ -35,8 +35,6 @@ limitations under the License.
|
|||||||
|
|
||||||
%include "tensorflow/python/platform/stacktrace_handler.i"
|
%include "tensorflow/python/platform/stacktrace_handler.i"
|
||||||
|
|
||||||
%include "tensorflow/python/util/transform_graph.i"
|
|
||||||
|
|
||||||
%include "tensorflow/python/grappler/cluster.i"
|
%include "tensorflow/python/grappler/cluster.i"
|
||||||
%include "tensorflow/python/grappler/item.i"
|
%include "tensorflow/python/grappler/item.i"
|
||||||
%include "tensorflow/python/grappler/tf_optimizer.i"
|
%include "tensorflow/python/grappler/tf_optimizer.i"
|
||||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
%include <std_string.i>
|
||||||
%include "tensorflow/python/lib/core/strings.i"
|
%include "tensorflow/python/lib/core/strings.i"
|
||||||
%include "tensorflow/python/platform/base.i"
|
%include "tensorflow/python/platform/base.i"
|
||||||
|
|
||||||
|
@ -1,85 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
%include <std_string.i>
|
|
||||||
%include "tensorflow/python/lib/core/strings.i"
|
|
||||||
%include "tensorflow/python/platform/base.i"
|
|
||||||
|
|
||||||
%{
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
#include "tensorflow/core/util/stat_summarizer.h"
|
|
||||||
|
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
|
||||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
|
||||||
#include "tensorflow/tools/graph_transforms/transform_graph.h"
|
|
||||||
%}
|
|
||||||
|
|
||||||
%ignoreall
|
|
||||||
|
|
||||||
%unignore tensorflow;
|
|
||||||
%unignore TransformGraphWithStringInputs;
|
|
||||||
|
|
||||||
|
|
||||||
%{
|
|
||||||
string TransformGraphWithStringInputs(string graph_def_string,
|
|
||||||
string inputs_string,
|
|
||||||
string outputs_string,
|
|
||||||
string transforms_string,
|
|
||||||
TF_Status* out_status) {
|
|
||||||
tensorflow::GraphDef graph_def;
|
|
||||||
if (!graph_def.ParseFromString(graph_def_string)) {
|
|
||||||
Set_TF_Status_from_Status(out_status, tensorflow::errors::InvalidArgument(
|
|
||||||
"Couldn't interpret input as a GraphDef"));
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
tensorflow::graph_transforms::TransformParameters params_list;
|
|
||||||
tensorflow::Status parse_status =
|
|
||||||
tensorflow::graph_transforms::ParseTransformParameters(
|
|
||||||
transforms_string, ¶ms_list);
|
|
||||||
if (!parse_status.ok()) {
|
|
||||||
tensorflow::Set_TF_Status_from_Status(out_status, parse_status);
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
std::vector<string> inputs = tensorflow::str_util::Split(inputs_string, ',');
|
|
||||||
std::vector<string> outputs =
|
|
||||||
tensorflow::str_util::Split(outputs_string, ',');
|
|
||||||
|
|
||||||
tensorflow::Status transform_status =
|
|
||||||
tensorflow::graph_transforms::TransformGraph(
|
|
||||||
inputs, outputs, params_list, &graph_def);
|
|
||||||
if (!transform_status.ok()) {
|
|
||||||
tensorflow::Set_TF_Status_from_Status(out_status, transform_status);
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
string result;
|
|
||||||
if (!graph_def.SerializeToString(&result)) {
|
|
||||||
Set_TF_Status_from_Status(out_status, tensorflow::errors::InvalidArgument(
|
|
||||||
"Couldn't serialize output as a GraphDef"));
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
Set_TF_Status_from_Status(out_status, tensorflow::Status::OK());
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
%}
|
|
||||||
|
|
||||||
|
|
||||||
string TransformGraphWithStringInputs(string graph_def_string,
|
|
||||||
string inputs_string,
|
|
||||||
string outputs_string,
|
|
||||||
string transforms_string,
|
|
||||||
TF_Status* out_status);
|
|
||||||
|
|
||||||
%unignoreall
|
|
74
tensorflow/python/util/transform_graph_wrapper.cc
Normal file
74
tensorflow/python/util/transform_graph_wrapper.cc
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "include/pybind11/pybind11.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_graph.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
string TransformGraphWithStringInputs(string graph_def_string,
|
||||||
|
string inputs_string,
|
||||||
|
string outputs_string,
|
||||||
|
string transforms_string) {
|
||||||
|
GraphDef graph_def;
|
||||||
|
if (!graph_def.ParseFromString(graph_def_string)) {
|
||||||
|
MaybeRaiseFromStatus(
|
||||||
|
errors::InvalidArgument("Couldn't interpret input as a GraphDef"));
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_transforms::TransformParameters params_list;
|
||||||
|
Status parse_status = graph_transforms::ParseTransformParameters(
|
||||||
|
transforms_string, ¶ms_list);
|
||||||
|
if (!parse_status.ok()) {
|
||||||
|
MaybeRaiseFromStatus(parse_status);
|
||||||
|
}
|
||||||
|
std::vector<string> inputs = str_util::Split(inputs_string, ',');
|
||||||
|
std::vector<string> outputs = str_util::Split(outputs_string, ',');
|
||||||
|
|
||||||
|
Status transform_status = graph_transforms::TransformGraph(
|
||||||
|
inputs, outputs, params_list, &graph_def);
|
||||||
|
if (!transform_status.ok()) {
|
||||||
|
MaybeRaiseFromStatus(transform_status);
|
||||||
|
}
|
||||||
|
string result;
|
||||||
|
if (!graph_def.SerializeToString(&result)) {
|
||||||
|
MaybeRaiseFromStatus(
|
||||||
|
errors::InvalidArgument("Couldn't serialize output as a GraphDef"));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_pywrap_transform_graph, m) {
|
||||||
|
m.def(
|
||||||
|
"TransformGraphWithStringInputs",
|
||||||
|
[](const py::object graph_def_string, const py::object inputs_string,
|
||||||
|
const py::object outputs_string, const py::object transforms_string) {
|
||||||
|
return py::bytes(tensorflow::TransformGraphWithStringInputs(
|
||||||
|
graph_def_string.cast<std::string>(),
|
||||||
|
inputs_string.cast<std::string>(),
|
||||||
|
outputs_string.cast<std::string>(),
|
||||||
|
transforms_string.cast<std::string>()));
|
||||||
|
});
|
||||||
|
};
|
@ -93,3 +93,6 @@ tensorflow::swig::TryFindKernelClass
|
|||||||
toco::TocoConvert
|
toco::TocoConvert
|
||||||
toco::TocoGetPotentiallySupportedOps
|
toco::TocoGetPotentiallySupportedOps
|
||||||
|
|
||||||
|
[transform_graph_lib] # transform_graph
|
||||||
|
tensorflow::graph_transforms::TransformGraph
|
||||||
|
tensorflow::graph_transforms::ParseTransformParameters
|
||||||
|
@ -207,6 +207,18 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "transform_graph_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"transform_graph.h",
|
||||||
|
"transform_utils.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/core:__pkg__",
|
||||||
|
"//tensorflow/python:__pkg__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "transform_graph_lib",
|
name = "transform_graph_lib",
|
||||||
srcs = ["transform_graph.cc"],
|
srcs = ["transform_graph.cc"],
|
||||||
@ -316,8 +328,8 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
|
"//tensorflow/python:_pywrap_transform_graph",
|
||||||
"//tensorflow/python:errors",
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -19,8 +19,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
# pylint: disable=unused-import,wildcard-import, line-too-long
|
# pylint: disable=unused-import,wildcard-import, line-too-long
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python._pywrap_transform_graph import TransformGraphWithStringInputs
|
||||||
from tensorflow.python.pywrap_tensorflow import TransformGraphWithStringInputs
|
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
|
|
||||||
@ -45,10 +44,8 @@ def TransformGraph(input_graph_def, inputs, outputs, transforms):
|
|||||||
inputs_string = compat.as_bytes(",".join(inputs))
|
inputs_string = compat.as_bytes(",".join(inputs))
|
||||||
outputs_string = compat.as_bytes(",".join(outputs))
|
outputs_string = compat.as_bytes(",".join(outputs))
|
||||||
transforms_string = compat.as_bytes(" ".join(transforms))
|
transforms_string = compat.as_bytes(" ".join(transforms))
|
||||||
with errors.raise_exception_on_not_ok_status() as status:
|
output_graph_def_string = TransformGraphWithStringInputs(
|
||||||
output_graph_def_string = TransformGraphWithStringInputs(
|
input_graph_def_string, inputs_string, outputs_string, transforms_string)
|
||||||
input_graph_def_string, inputs_string, outputs_string,
|
|
||||||
transforms_string, status)
|
|
||||||
output_graph_def = graph_pb2.GraphDef()
|
output_graph_def = graph_pb2.GraphDef()
|
||||||
output_graph_def.ParseFromString(output_graph_def_string)
|
output_graph_def.ParseFromString(output_graph_def_string)
|
||||||
return output_graph_def
|
return output_graph_def
|
||||||
|
Loading…
Reference in New Issue
Block a user