Export the toco functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
We are adding `toco::` to the exported namespaces for pywrap_tensorflow's shared object. A few downstream modules also require a previous import of pywrap tensorflow, because the wrapper depends on the shared library. See https://github.com/tensorflow/tensorflow/pull/31955 for additional information. PiperOrigin-RevId: 276096778 Change-Id: I042f488c36b00818b2344fb39c36cad97cee6eb8
This commit is contained in:
parent
8991cf5efa
commit
a07390fd17
@ -209,6 +209,7 @@ py_library(
|
|||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/python:_pywrap_toco_api",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tensorflow",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
|
@ -17,7 +17,11 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
# We need to import pywrap_tensorflow prior to the toco wrapper.
|
||||||
|
# pylint: disable=invalud-import-order,g-bad-import-order
|
||||||
|
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||||
|
from tensorflow.python import _pywrap_toco_api
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/137402359): Remove lazy loading wrapper
|
# TODO(b/137402359): Remove lazy loading wrapper
|
||||||
|
|
||||||
@ -25,7 +29,7 @@ from tensorflow.python import pywrap_tensorflow
|
|||||||
def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str,
|
def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str,
|
||||||
debug_info_str, enable_mlir_converter):
|
debug_info_str, enable_mlir_converter):
|
||||||
"""Wraps TocoConvert with lazy loader."""
|
"""Wraps TocoConvert with lazy loader."""
|
||||||
return pywrap_tensorflow.TocoConvert(
|
return _pywrap_toco_api.TocoConvert(
|
||||||
model_flags_str,
|
model_flags_str,
|
||||||
toco_flags_str,
|
toco_flags_str,
|
||||||
input_data_str,
|
input_data_str,
|
||||||
@ -36,4 +40,4 @@ def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str,
|
|||||||
|
|
||||||
def wrapped_get_potentially_supported_ops():
|
def wrapped_get_potentially_supported_ops():
|
||||||
"""Wraps TocoGetPotentiallySupportedOps with lazy loader."""
|
"""Wraps TocoGetPotentiallySupportedOps with lazy loader."""
|
||||||
return pywrap_tensorflow.TocoGetPotentiallySupportedOps()
|
return _pywrap_toco_api.TocoGetPotentiallySupportedOps()
|
||||||
|
@ -16,6 +16,16 @@ config_setting(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "toco_python_api_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"toco_python_api.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/python:__subpackages__",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "toco_python_api",
|
name = "toco_python_api",
|
||||||
srcs = ["toco_python_api.cc"],
|
srcs = ["toco_python_api.cc"],
|
||||||
@ -49,6 +59,7 @@ cc_library(
|
|||||||
if_false = [],
|
if_false = [],
|
||||||
if_true = ["//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer"],
|
if_true = ["//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer"],
|
||||||
),
|
),
|
||||||
|
alwayslink = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compatibility stub. Remove when internal customers moved.
|
# Compatibility stub. Remove when internal customers moved.
|
||||||
@ -61,7 +72,7 @@ py_library(
|
|||||||
"//tensorflow/lite:__subpackages__",
|
"//tensorflow/lite:__subpackages__",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:_pywrap_toco_api",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -71,6 +82,7 @@ py_binary(
|
|||||||
python_version = "PY2",
|
python_version = "PY2",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/python:_pywrap_toco_api",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tensorflow",
|
||||||
],
|
],
|
||||||
@ -89,10 +101,3 @@ tf_py_test(
|
|||||||
"no_pip",
|
"no_pip",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
exports_files(
|
|
||||||
["toco.i"],
|
|
||||||
visibility = [
|
|
||||||
"//tensorflow/python:__subpackages__",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
@ -20,5 +20,5 @@ from __future__ import print_function
|
|||||||
# TODO(aselle): Remove once no clients internally depend on this.
|
# TODO(aselle): Remove once no clients internally depend on this.
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
from tensorflow.python.pywrap_tensorflow import TocoConvert
|
from tensorflow.python._pywrap_toco_api import TocoConvert
|
||||||
# pylint: enable=unused-import
|
# pylint: enable=unused-import
|
||||||
|
@ -1,49 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
%include "std_string.i"
|
|
||||||
|
|
||||||
%{
|
|
||||||
#include "tensorflow/lite/toco/python/toco_python_api.h"
|
|
||||||
%}
|
|
||||||
|
|
||||||
// The TensorFlow exception handler releases the GIL with
|
|
||||||
// Py_BEGIN_ALLOW_THREADS. Remove that because these function use the Python
|
|
||||||
// API to decode inputs.
|
|
||||||
%noexception toco::TocoConvert;
|
|
||||||
%noexception toco::TocoGetPotentiallySupportedOps;
|
|
||||||
|
|
||||||
namespace toco {
|
|
||||||
|
|
||||||
// Convert a model represented in `input_contents`. `model_flags_proto`
|
|
||||||
// describes model parameters. `toco_flags_proto` describes conversion
|
|
||||||
// parameters (see relevant .protos for more information). Returns a string
|
|
||||||
// representing the contents of the converted model. When extended_return
|
|
||||||
// flag is set to true returns a dictionary that contains string representation
|
|
||||||
// of the converted model and some statistics like arithmetic ops count.
|
|
||||||
// `debug_info_str` contains the `GraphDebugInfo` proto. When
|
|
||||||
// `enable_mlir_converter` is True, use MLIR-based conversion instead of
|
|
||||||
// TOCO conversion.
|
|
||||||
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
|
||||||
PyObject* toco_flags_proto_txt_raw,
|
|
||||||
PyObject* input_contents_txt_raw,
|
|
||||||
bool extended_return = false,
|
|
||||||
PyObject* debug_info_txt_raw = nullptr,
|
|
||||||
bool enable_mlir_converter = false);
|
|
||||||
|
|
||||||
// Returns a list of names of all ops potentially supported by tflite.
|
|
||||||
PyObject* TocoGetPotentiallySupportedOps();
|
|
||||||
|
|
||||||
} // namespace toco
|
|
@ -19,7 +19,11 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
|
# We need to import pywrap_tensorflow prior to the toco wrapper.
|
||||||
|
# pylint: disable=invalud-import-order,g-bad-import-order
|
||||||
|
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||||
|
from tensorflow.python import _pywrap_toco_api
|
||||||
from tensorflow.python.platform import app
|
from tensorflow.python.platform import app
|
||||||
|
|
||||||
FLAGS = None
|
FLAGS = None
|
||||||
@ -43,7 +47,7 @@ def execute(unused_args):
|
|||||||
|
|
||||||
enable_mlir_converter = FLAGS.enable_mlir_converter
|
enable_mlir_converter = FLAGS.enable_mlir_converter
|
||||||
|
|
||||||
output_str = pywrap_tensorflow.TocoConvert(
|
output_str = _pywrap_toco_api.TocoConvert(
|
||||||
model_str,
|
model_str,
|
||||||
toco_str,
|
toco_str,
|
||||||
input_str,
|
input_str,
|
||||||
|
@ -583,6 +583,20 @@ tf_python_pybind_extension(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_toco_api",
|
||||||
|
srcs = [
|
||||||
|
"lite/toco_python_api_wrapper.cc",
|
||||||
|
],
|
||||||
|
hdrs = ["//tensorflow/lite/toco/python:toco_python_api_hdrs"],
|
||||||
|
module_name = "_pywrap_toco_api",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:pybind11_lib",
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "cpp_python_util",
|
name = "cpp_python_util",
|
||||||
srcs = ["util/util.cc"],
|
srcs = ["util/util.cc"],
|
||||||
@ -5188,7 +5202,6 @@ tf_py_wrap_cc(
|
|||||||
"util/traceme.i",
|
"util/traceme.i",
|
||||||
"util/transform_graph.i",
|
"util/transform_graph.i",
|
||||||
"//tensorflow/compiler/mlir/python:mlir.i",
|
"//tensorflow/compiler/mlir/python:mlir.i",
|
||||||
"//tensorflow/lite/toco/python:toco.i",
|
|
||||||
],
|
],
|
||||||
# add win_def_file for pywrap_tensorflow
|
# add win_def_file for pywrap_tensorflow
|
||||||
win_def_file = select({
|
win_def_file = select({
|
||||||
@ -5292,6 +5305,7 @@ genrule(
|
|||||||
"//tensorflow/core:core_cpu_impl", # device_lib
|
"//tensorflow/core:core_cpu_impl", # device_lib
|
||||||
":py_exception_registry", # py_exception_registry
|
":py_exception_registry", # py_exception_registry
|
||||||
":kernel_registry",
|
":kernel_registry",
|
||||||
|
"//tensorflow/lite/toco/python:toco_python_api", # toco
|
||||||
],
|
],
|
||||||
outs = ["pybind_symbol_target_libs_file.txt"],
|
outs = ["pybind_symbol_target_libs_file.txt"],
|
||||||
cmd = select({
|
cmd = select({
|
||||||
|
57
tensorflow/python/lite/toco_python_api_wrapper.cc
Normal file
57
tensorflow/python/lite/toco_python_api_wrapper.cc
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "include/pybind11/pybind11.h"
|
||||||
|
#include "tensorflow/lite/toco/python/toco_python_api.h"
|
||||||
|
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_pywrap_toco_api, m) {
|
||||||
|
m.def(
|
||||||
|
"TocoConvert",
|
||||||
|
[](py::object model_flags_proto_txt_raw,
|
||||||
|
py::object toco_flags_proto_txt_raw, py::object input_contents_txt_raw,
|
||||||
|
bool extended_return, py::object debug_info_txt_raw,
|
||||||
|
bool enable_mlir_converter) {
|
||||||
|
return tensorflow::pyo_or_throw(toco::TocoConvert(
|
||||||
|
model_flags_proto_txt_raw.ptr(), toco_flags_proto_txt_raw.ptr(),
|
||||||
|
input_contents_txt_raw.ptr(), extended_return,
|
||||||
|
debug_info_txt_raw.ptr(), enable_mlir_converter));
|
||||||
|
},
|
||||||
|
py::arg("model_flags_proto_txt_raw"), py::arg("toco_flags_proto_txt_raw"),
|
||||||
|
py::arg("input_contents_txt_raw"), py::arg("extended_return") = false,
|
||||||
|
py::arg("debug_info_txt_raw") = py::none(),
|
||||||
|
py::arg("enable_mlir_converter") = false,
|
||||||
|
R"pbdoc(
|
||||||
|
Convert a model represented in `input_contents`. `model_flags_proto`
|
||||||
|
describes model parameters. `toco_flags_proto` describes conversion
|
||||||
|
parameters (see relevant .protos for more information). Returns a string
|
||||||
|
representing the contents of the converted model. When extended_return
|
||||||
|
flag is set to true returns a dictionary that contains string representation
|
||||||
|
of the converted model and some statistics like arithmetic ops count.
|
||||||
|
`debug_info_str` contains the `GraphDebugInfo` proto. When
|
||||||
|
`enable_mlir_converter` is True, tuse MLIR-based conversion instead of
|
||||||
|
TOCO conversion.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"TocoGetPotentiallySupportedOps",
|
||||||
|
[]() {
|
||||||
|
return tensorflow::pyo_or_throw(toco::TocoGetPotentiallySupportedOps());
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns a list of names of all ops potentially supported by tflite.
|
||||||
|
)pbdoc");
|
||||||
|
}
|
@ -29,8 +29,6 @@ limitations under the License.
|
|||||||
|
|
||||||
%include "tensorflow/python/lib/core/bfloat16.i"
|
%include "tensorflow/python/lib/core/bfloat16.i"
|
||||||
|
|
||||||
%include "tensorflow/lite/toco/python/toco.i"
|
|
||||||
|
|
||||||
%include "tensorflow/python/lib/io/file_io.i"
|
%include "tensorflow/python/lib/io/file_io.i"
|
||||||
|
|
||||||
%include "tensorflow/python/framework/python_op_gen.i"
|
%include "tensorflow/python/framework/python_op_gen.i"
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
*tensorflow*
|
*tensorflow*
|
||||||
|
*toco*
|
||||||
*perftools*gputools*
|
*perftools*gputools*
|
||||||
*tf_*
|
*tf_*
|
||||||
*TF_*
|
*TF_*
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
tensorflow {
|
tensorflow {
|
||||||
global:
|
global:
|
||||||
*tensorflow*;
|
*tensorflow*;
|
||||||
|
*toco*;
|
||||||
*perftools*gputools*;
|
*perftools*gputools*;
|
||||||
*TF_*;
|
*TF_*;
|
||||||
*TFE_*;
|
*TFE_*;
|
||||||
|
@ -70,6 +70,7 @@ INCLUDE_RE = re.compile(r"^(TF_\w*)$|"
|
|||||||
r"^(TFE_\w*)$|"
|
r"^(TFE_\w*)$|"
|
||||||
r"nsync::|"
|
r"nsync::|"
|
||||||
r"tensorflow::|"
|
r"tensorflow::|"
|
||||||
|
r"toco::|"
|
||||||
r"functor::|"
|
r"functor::|"
|
||||||
r"perftools::gputools")
|
r"perftools::gputools")
|
||||||
|
|
||||||
|
@ -83,3 +83,7 @@ tensorflow::PyExceptionRegistry::Lookup
|
|||||||
[kernel_registry] # kernel_registry
|
[kernel_registry] # kernel_registry
|
||||||
tensorflow::swig::TryFindKernelClass
|
tensorflow::swig::TryFindKernelClass
|
||||||
|
|
||||||
|
[toco_python_api] # toco_python_api
|
||||||
|
toco::TocoConvert
|
||||||
|
toco::TocoGetPotentiallySupportedOps
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user