tflite_convert: register custom op defs at the beginning of process

When the allow_custom_opdefs flag is given, the converter API should register
the given opdef to the TF global registry before loading graphs to make sure the
TF global registry is updated. After that, the corresponding fake op kernel will
be added since TF saved model importer logic does fetch kernels while loading.

PiperOrigin-RevId: 335504825
Change-Id: I8fc8881600605f8fbe3bfdac0f622f83ce99b9a8
This commit is contained in:
Jaesung Chung 2020-10-05 14:50:00 -07:00 committed by TensorFlower Gardener
parent d36cef3697
commit b34984f33a
10 changed files with 187 additions and 5 deletions

View File

@ -159,6 +159,19 @@ def mlir_sparsify(input_data_str):
return wrap_toco.wrapped_experimental_mlir_sparsify(input_data_str)
def register_custom_opdefs(custom_opdefs_list):
"""Register the given custom opdefs to the TensorFlow global op registry.
Args:
custom_opdefs_list: String representing the custom ops OpDefs that are
included in the GraphDef.
Returns:
True if the registration is successfully completed.
"""
return wrap_toco.wrapped_register_custom_opdefs(custom_opdefs_list)
def toco_convert_protos(model_flags_str,
toco_flags_str,
input_data_str,

View File

@ -22,6 +22,11 @@ namespace python_utils {
int ConvertFromPyString(PyObject* obj, char** data, Py_ssize_t* length) {
#if PY_MAJOR_VERSION >= 3
if (PyUnicode_Check(obj)) {
// const_cast<> is for CPython 3.7 finally adding const to the API.
*data = const_cast<char*>(PyUnicode_AsUTF8AndSize(obj, length));
return *data == nullptr ? -1 : 0;
}
return PyBytes_AsStringAndSize(obj, data, length);
#else
return PyString_AsStringAndSize(obj, data, length);

View File

@ -29,6 +29,7 @@ from six.moves import zip
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.lite.toco.logging import gen_html
from tensorflow.python import keras
@ -128,6 +129,10 @@ def _convert_tf1_model(flags):
Raises:
ValueError: Invalid flags.
"""
# Register custom opdefs before converter object creation.
if flags.custom_opdefs:
register_custom_opdefs(_parse_array(flags.custom_opdefs))
# Create converter.
converter = _get_tflite_converter(flags)
if flags.inference_type:
@ -176,8 +181,7 @@ def _convert_tf1_model(flags):
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
if flags.custom_opdefs:
converter._custom_opdefs = _parse_array(flags.custom_opdefs) # pylint: disable=protected-access
if flags.target_ops:
ops_set_options = lite.OpsSet.get_options()
converter.target_spec.supported_ops = set()

View File

@ -22,7 +22,9 @@ import os
import numpy as np
from tensorflow.core.framework import graph_pb2
from tensorflow.lite.python import tflite_convert
from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.client import session
@ -31,6 +33,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.importer import import_graph_def
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import gfile
@ -179,6 +182,73 @@ class TfLiteConvertV1Test(TestModels):
flags_str = '--saved_model_dir={}'.format(saved_model_dir)
self._run(flags_str, should_succeed=True)
def _createSavedModelWithCustomOp(self):
custom_opdefs_str = (
'name: \'CustomAdd\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
'input_arg: {name: \'Input2\' type: DT_FLOAT} output_arg: {name: '
'\'Output\' type: DT_FLOAT}')
# Create a graph that has one add op.
new_graph = graph_pb2.GraphDef()
with ops.Graph().as_default():
with session.Session() as sess:
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
out_tensor = in_tensor + in_tensor
inputs = {'x': in_tensor}
outputs = {'z': out_tensor}
new_graph.CopyFrom(sess.graph_def)
# Rename Add op name to CustomAdd.
for node in new_graph.node:
if node.op.startswith('Add'):
node.op = 'CustomAdd'
del node.attr['T']
# Register custom op defs to import modified graph def.
register_custom_opdefs([custom_opdefs_str])
# Store saved model.
saved_model_dir = self._getFilepath('model')
with ops.Graph().as_default():
with session.Session() as sess:
import_graph_def(new_graph, name='')
saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
return (saved_model_dir, custom_opdefs_str)
def testEnsureCustomOpdefsFlag(self):
saved_model_dir, _ = self._createSavedModelWithCustomOp()
# Ensure --custom_opdefs.
flags_str = ('--saved_model_dir={0} --allow_custom_ops '
'--experimental_new_converter'.format(saved_model_dir))
self._run(flags_str, should_succeed=False)
def testSavedModelWithCustomOpdefsFlag(self):
saved_model_dir, custom_opdefs_str = self._createSavedModelWithCustomOp()
# Valid conversion.
flags_str = (
'--saved_model_dir={0} --custom_opdefs="{1}" --allow_custom_ops '
'--experimental_new_converter'.format(saved_model_dir,
custom_opdefs_str))
self._run(flags_str, should_succeed=True)
def testSavedModelWithInvalidCustomOpdefsFlag(self):
saved_model_dir, _ = self._createSavedModelWithCustomOp()
invalid_custom_opdefs_str = (
'name: \'CustomAdd\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
'output_arg: {name: \'Output\' type: DT_FLOAT}')
# Valid conversion.
flags_str = (
'--saved_model_dir={0} --custom_opdefs="{1}" --allow_custom_ops '
'--experimental_new_converter'.format(saved_model_dir,
invalid_custom_opdefs_str))
self._run(flags_str, should_succeed=False)
def testKerasFile(self):
keras_file = self._getKerasModelFile()
@ -269,9 +339,9 @@ class TfLiteConvertV1Test(TestModels):
'attr : { name: \'nms_iou_threshold\' type: \'float\'} '
'attr : { name: \'nms_score_threshold\' type: \'float\'} '
'attr : { name: \'num_classes\' type: \'int\'} '
'attr : { name: \'w_scale\' type: \'int\'} '
'attr : { name: \'x_scale\' type: \'int\'} '
'attr : { name: \'y_scale\' type: \'int\'}')
'attr : { name: \'w_scale\' type: \'float\'} '
'attr : { name: \'x_scale\' type: \'float\'} '
'attr : { name: \'y_scale\' type: \'float\'}')
flags_str = ('--graph_def_file={0} --input_arrays={1} '
'--output_arrays={2} --input_shapes={3} '

View File

@ -55,3 +55,8 @@ def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel,
def wrapped_experimental_mlir_sparsify(input_data_str):
"""Wraps experimental mlir sparsify model."""
return _pywrap_toco_api.ExperimentalMlirSparsifyModel(input_data_str)
def wrapped_register_custom_opdefs(custom_opdefs_list):
"""Wraps RegisterCustomOpdefs with lazy loader."""
return _pywrap_toco_api.RegisterCustomOpdefs(custom_opdefs_list)

View File

@ -37,6 +37,9 @@ cc_library(
deps = [
"@com_google_protobuf//:protobuf_headers",
"//third_party/python_runtime:headers", # build_cleaner: keep; DNR: b/35864863
"//tensorflow/c:kernels",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/api",

View File

@ -20,10 +20,13 @@ limitations under the License.
#include <vector>
#include "google/protobuf/text_format.h"
#include "tensorflow/c/kernels.h"
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"
#include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
@ -317,4 +320,69 @@ PyObject* MlirSparsifyModel(PyObject* data) {
builder.GetSize());
}
PyObject* RegisterCustomOpdefs(PyObject* list) {
if (!PyList_Check(list)) {
PyErr_SetString(PyExc_TypeError, "Expected list in argument");
return nullptr;
}
int64 size = PyList_Size(list);
for (int i = 0; i < size; ++i) {
// Get character array from Python object.
char* tf_opdefs;
Py_ssize_t len;
if (tflite::python_utils::ConvertFromPyString(PyList_GetItem(list, i),
&tf_opdefs, &len) == -1) {
PyErr_Format(PyExc_ValueError,
"Failed to convert Python string at index %d of custom op "
"defs argument",
i);
return nullptr;
}
// Parse op def from character array.
tensorflow::OpDef opdef;
if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs, &opdef)) {
PyErr_Format(
PyExc_ValueError,
"Failed to parse opdefs at index %d of custom op defs argument: %s",
i, tf_opdefs);
return nullptr;
}
// Register extra opdefs to TensorFlow global op registry.
tensorflow::OpRegistry::Global()->Register(
[opdef](
tensorflow::OpRegistrationData* op_reg_data) -> tensorflow::Status {
*op_reg_data = tensorflow::OpRegistrationData(opdef);
return tensorflow::Status::OK();
});
// Register the corresponding fake op kernel.
const char* node_name = opdef.name().c_str();
const char* op_name = opdef.name().c_str();
const char* device_name = "CPU";
static auto fake_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
};
TF_KernelBuilder* builder =
TF_NewKernelBuilder(op_name, device_name, /*create_func=*/nullptr,
fake_compute_func, /*delete_func=*/nullptr);
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(node_name, builder, status);
if (TF_GetCode(status) != TF_OK) {
TF_DeleteStatus(status);
PyErr_Format(PyExc_ValueError,
"Failed to register fake op kernel at index %d of custom op "
"defs argument",
i);
return nullptr;
}
TF_DeleteStatus(status);
}
Py_RETURN_TRUE;
}
} // namespace toco

View File

@ -49,6 +49,9 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
// Sparsifies model to encode sparse tensors with proper format. Throws error if
// sparsification fails.
PyObject* MlirSparsifyModel(PyObject* data);
// Registers the given custom opdefs to TensorFlow global op registry.
PyObject* RegisterCustomOpdefs(PyObject* list);
} // namespace toco
#endif // TENSORFLOW_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_

View File

@ -77,4 +77,14 @@ PYBIND11_MODULE(_pywrap_toco_api, m) {
R"pbdoc(
Returns a sparsified model.
)pbdoc");
m.def(
"RegisterCustomOpdefs",
[](py::object custom_opdefs_txt_raw) {
return tensorflow::PyoOrThrow(
toco::RegisterCustomOpdefs(custom_opdefs_txt_raw.ptr()));
},
py::arg("custom_opdefs_txt_raw"),
R"pbdoc(
Registers the given custom opdefs to the TensorFlow global op registry.
)pbdoc");
}

View File

@ -123,6 +123,7 @@ toco::TocoConvert
toco::TocoGetPotentiallySupportedOps
toco::MlirQuantizeModel
toco::MlirSparsifyModel
toco::RegisterCustomOpdefs
[transform_graph_lib] # transform_graph
tensorflow::graph_transforms::TransformGraph