Export the toco functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.

We are adding `toco::` to the exported namespaces for pywrap_tensorflow's shared object. A few downstream modules also require a previous import of pywrap tensorflow, because the wrapper depends on the shared library. See https://github.com/tensorflow/tensorflow/pull/31955 for additional information.

PiperOrigin-RevId: 276096778
Change-Id: I042f488c36b00818b2344fb39c36cad97cee6eb8
This commit is contained in:
Amit Patankar 2019-10-22 10:44:39 -07:00 committed by TensorFlower Gardener
parent 8991cf5efa
commit a07390fd17
13 changed files with 107 additions and 66 deletions

View File

@ -209,6 +209,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:_pywrap_toco_api",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:util",
],

View File

@ -17,7 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
# We need to import pywrap_tensorflow prior to the toco wrapper.
# pylint: disable=invalud-import-order,g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python import _pywrap_toco_api
# TODO(b/137402359): Remove lazy loading wrapper
@ -25,7 +29,7 @@ from tensorflow.python import pywrap_tensorflow
def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str,
debug_info_str, enable_mlir_converter):
"""Wraps TocoConvert with lazy loader."""
return pywrap_tensorflow.TocoConvert(
return _pywrap_toco_api.TocoConvert(
model_flags_str,
toco_flags_str,
input_data_str,
@ -36,4 +40,4 @@ def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str,
def wrapped_get_potentially_supported_ops():
"""Wraps TocoGetPotentiallySupportedOps with lazy loader."""
return pywrap_tensorflow.TocoGetPotentiallySupportedOps()
return _pywrap_toco_api.TocoGetPotentiallySupportedOps()

View File

@ -16,6 +16,16 @@ config_setting(
],
)
filegroup(
name = "toco_python_api_hdrs",
srcs = [
"toco_python_api.h",
],
visibility = [
"//tensorflow/python:__subpackages__",
],
)
cc_library(
name = "toco_python_api",
srcs = ["toco_python_api.cc"],
@ -49,6 +59,7 @@ cc_library(
if_false = [],
if_true = ["//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer"],
),
alwayslink = True,
)
# Compatibility stub. Remove when internal customers moved.
@ -61,7 +72,7 @@ py_library(
"//tensorflow/lite:__subpackages__",
],
deps = [
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:_pywrap_toco_api",
],
)
@ -71,6 +82,7 @@ py_binary(
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:_pywrap_toco_api",
"//tensorflow/python:platform",
"//tensorflow/python:pywrap_tensorflow",
],
@ -89,10 +101,3 @@ tf_py_test(
"no_pip",
],
)
exports_files(
["toco.i"],
visibility = [
"//tensorflow/python:__subpackages__",
],
)

View File

@ -20,5 +20,5 @@ from __future__ import print_function
# TODO(aselle): Remove once no clients internally depend on this.
# pylint: disable=unused-import
from tensorflow.python.pywrap_tensorflow import TocoConvert
from tensorflow.python._pywrap_toco_api import TocoConvert
# pylint: enable=unused-import

View File

@ -1,49 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "std_string.i"
%{
#include "tensorflow/lite/toco/python/toco_python_api.h"
%}
// The TensorFlow exception handler releases the GIL with
// Py_BEGIN_ALLOW_THREADS. Remove that because these function use the Python
// API to decode inputs.
%noexception toco::TocoConvert;
%noexception toco::TocoGetPotentiallySupportedOps;
namespace toco {
// Convert a model represented in `input_contents`. `model_flags_proto`
// describes model parameters. `toco_flags_proto` describes conversion
// parameters (see relevant .protos for more information). Returns a string
// representing the contents of the converted model. When extended_return
// flag is set to true returns a dictionary that contains string representation
// of the converted model and some statistics like arithmetic ops count.
// `debug_info_str` contains the `GraphDebugInfo` proto. When
// `enable_mlir_converter` is True, use MLIR-based conversion instead of
// TOCO conversion.
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
PyObject* toco_flags_proto_txt_raw,
PyObject* input_contents_txt_raw,
bool extended_return = false,
PyObject* debug_info_txt_raw = nullptr,
bool enable_mlir_converter = false);
// Returns a list of names of all ops potentially supported by tflite.
PyObject* TocoGetPotentiallySupportedOps();
} // namespace toco

View File

@ -19,7 +19,11 @@ from __future__ import print_function
import argparse
import sys
from tensorflow.python import pywrap_tensorflow
# We need to import pywrap_tensorflow prior to the toco wrapper.
# pylint: disable=invalud-import-order,g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python import _pywrap_toco_api
from tensorflow.python.platform import app
FLAGS = None
@ -43,7 +47,7 @@ def execute(unused_args):
enable_mlir_converter = FLAGS.enable_mlir_converter
output_str = pywrap_tensorflow.TocoConvert(
output_str = _pywrap_toco_api.TocoConvert(
model_str,
toco_str,
input_str,

View File

@ -583,6 +583,20 @@ tf_python_pybind_extension(
],
)
tf_python_pybind_extension(
name = "_pywrap_toco_api",
srcs = [
"lite/toco_python_api_wrapper.cc",
],
hdrs = ["//tensorflow/lite/toco/python:toco_python_api_hdrs"],
module_name = "_pywrap_toco_api",
deps = [
"//tensorflow/python:pybind11_lib",
"//third_party/python_runtime:headers",
"@pybind11",
],
)
cc_library(
name = "cpp_python_util",
srcs = ["util/util.cc"],
@ -5188,7 +5202,6 @@ tf_py_wrap_cc(
"util/traceme.i",
"util/transform_graph.i",
"//tensorflow/compiler/mlir/python:mlir.i",
"//tensorflow/lite/toco/python:toco.i",
],
# add win_def_file for pywrap_tensorflow
win_def_file = select({
@ -5292,6 +5305,7 @@ genrule(
"//tensorflow/core:core_cpu_impl", # device_lib
":py_exception_registry", # py_exception_registry
":kernel_registry",
"//tensorflow/lite/toco/python:toco_python_api", # toco
],
outs = ["pybind_symbol_target_libs_file.txt"],
cmd = select({

View File

@ -0,0 +1,57 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "include/pybind11/pybind11.h"
#include "tensorflow/lite/toco/python/toco_python_api.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
namespace py = pybind11;
PYBIND11_MODULE(_pywrap_toco_api, m) {
m.def(
"TocoConvert",
[](py::object model_flags_proto_txt_raw,
py::object toco_flags_proto_txt_raw, py::object input_contents_txt_raw,
bool extended_return, py::object debug_info_txt_raw,
bool enable_mlir_converter) {
return tensorflow::pyo_or_throw(toco::TocoConvert(
model_flags_proto_txt_raw.ptr(), toco_flags_proto_txt_raw.ptr(),
input_contents_txt_raw.ptr(), extended_return,
debug_info_txt_raw.ptr(), enable_mlir_converter));
},
py::arg("model_flags_proto_txt_raw"), py::arg("toco_flags_proto_txt_raw"),
py::arg("input_contents_txt_raw"), py::arg("extended_return") = false,
py::arg("debug_info_txt_raw") = py::none(),
py::arg("enable_mlir_converter") = false,
R"pbdoc(
Convert a model represented in `input_contents`. `model_flags_proto`
describes model parameters. `toco_flags_proto` describes conversion
parameters (see relevant .protos for more information). Returns a string
representing the contents of the converted model. When extended_return
flag is set to true returns a dictionary that contains string representation
of the converted model and some statistics like arithmetic ops count.
`debug_info_str` contains the `GraphDebugInfo` proto. When
`enable_mlir_converter` is True, tuse MLIR-based conversion instead of
TOCO conversion.
)pbdoc");
m.def(
"TocoGetPotentiallySupportedOps",
[]() {
return tensorflow::pyo_or_throw(toco::TocoGetPotentiallySupportedOps());
},
R"pbdoc(
Returns a list of names of all ops potentially supported by tflite.
)pbdoc");
}

View File

@ -29,8 +29,6 @@ limitations under the License.
%include "tensorflow/python/lib/core/bfloat16.i"
%include "tensorflow/lite/toco/python/toco.i"
%include "tensorflow/python/lib/io/file_io.i"
%include "tensorflow/python/framework/python_op_gen.i"

View File

@ -1,4 +1,5 @@
*tensorflow*
*toco*
*perftools*gputools*
*tf_*
*TF_*

View File

@ -1,6 +1,7 @@
tensorflow {
global:
*tensorflow*;
*toco*;
*perftools*gputools*;
*TF_*;
*TFE_*;

View File

@ -70,6 +70,7 @@ INCLUDE_RE = re.compile(r"^(TF_\w*)$|"
r"^(TFE_\w*)$|"
r"nsync::|"
r"tensorflow::|"
r"toco::|"
r"functor::|"
r"perftools::gputools")

View File

@ -83,3 +83,7 @@ tensorflow::PyExceptionRegistry::Lookup
[kernel_registry] # kernel_registry
tensorflow::swig::TryFindKernelClass
[toco_python_api] # toco_python_api
toco::TocoConvert
toco::TocoGetPotentiallySupportedOps