tflite_convert: register custom op defs at the beginning of process
When the allow_custom_opdefs flag is given, the converter API should register the given opdef to the TF global registry before loading graphs to make sure the TF global registry is updated. After that, the corresponding fake op kernel will be added since TF saved model importer logic does fetch kernels while loading. PiperOrigin-RevId: 335504825 Change-Id: I8fc8881600605f8fbe3bfdac0f622f83ce99b9a8
This commit is contained in:
parent
d36cef3697
commit
b34984f33a
tensorflow
lite
python
toco/python
python/lite
tools/def_file_filter
@ -159,6 +159,19 @@ def mlir_sparsify(input_data_str):
|
||||
return wrap_toco.wrapped_experimental_mlir_sparsify(input_data_str)
|
||||
|
||||
|
||||
def register_custom_opdefs(custom_opdefs_list):
|
||||
"""Register the given custom opdefs to the TensorFlow global op registry.
|
||||
|
||||
Args:
|
||||
custom_opdefs_list: String representing the custom ops OpDefs that are
|
||||
included in the GraphDef.
|
||||
|
||||
Returns:
|
||||
True if the registration is successfully completed.
|
||||
"""
|
||||
return wrap_toco.wrapped_register_custom_opdefs(custom_opdefs_list)
|
||||
|
||||
|
||||
def toco_convert_protos(model_flags_str,
|
||||
toco_flags_str,
|
||||
input_data_str,
|
||||
|
@ -22,6 +22,11 @@ namespace python_utils {
|
||||
|
||||
int ConvertFromPyString(PyObject* obj, char** data, Py_ssize_t* length) {
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
if (PyUnicode_Check(obj)) {
|
||||
// const_cast<> is for CPython 3.7 finally adding const to the API.
|
||||
*data = const_cast<char*>(PyUnicode_AsUTF8AndSize(obj, length));
|
||||
return *data == nullptr ? -1 : 0;
|
||||
}
|
||||
return PyBytes_AsStringAndSize(obj, data, length);
|
||||
#else
|
||||
return PyString_AsStringAndSize(obj, data, length);
|
||||
|
@ -29,6 +29,7 @@ from six.moves import zip
|
||||
|
||||
from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.python import lite_constants
|
||||
from tensorflow.lite.python.convert import register_custom_opdefs
|
||||
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
|
||||
from tensorflow.lite.toco.logging import gen_html
|
||||
from tensorflow.python import keras
|
||||
@ -128,6 +129,10 @@ def _convert_tf1_model(flags):
|
||||
Raises:
|
||||
ValueError: Invalid flags.
|
||||
"""
|
||||
# Register custom opdefs before converter object creation.
|
||||
if flags.custom_opdefs:
|
||||
register_custom_opdefs(_parse_array(flags.custom_opdefs))
|
||||
|
||||
# Create converter.
|
||||
converter = _get_tflite_converter(flags)
|
||||
if flags.inference_type:
|
||||
@ -176,8 +181,7 @@ def _convert_tf1_model(flags):
|
||||
|
||||
if flags.allow_custom_ops:
|
||||
converter.allow_custom_ops = flags.allow_custom_ops
|
||||
if flags.custom_opdefs:
|
||||
converter._custom_opdefs = _parse_array(flags.custom_opdefs) # pylint: disable=protected-access
|
||||
|
||||
if flags.target_ops:
|
||||
ops_set_options = lite.OpsSet.get_options()
|
||||
converter.target_spec.supported_ops = set()
|
||||
|
@ -22,7 +22,9 @@ import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.lite.python import tflite_convert
|
||||
from tensorflow.lite.python.convert import register_custom_opdefs
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import session
|
||||
@ -31,6 +33,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework.importer import import_graph_def
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import gfile
|
||||
@ -179,6 +182,73 @@ class TfLiteConvertV1Test(TestModels):
|
||||
flags_str = '--saved_model_dir={}'.format(saved_model_dir)
|
||||
self._run(flags_str, should_succeed=True)
|
||||
|
||||
def _createSavedModelWithCustomOp(self):
|
||||
custom_opdefs_str = (
|
||||
'name: \'CustomAdd\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
|
||||
'input_arg: {name: \'Input2\' type: DT_FLOAT} output_arg: {name: '
|
||||
'\'Output\' type: DT_FLOAT}')
|
||||
|
||||
# Create a graph that has one add op.
|
||||
new_graph = graph_pb2.GraphDef()
|
||||
with ops.Graph().as_default():
|
||||
with session.Session() as sess:
|
||||
in_tensor = array_ops.placeholder(
|
||||
shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
|
||||
out_tensor = in_tensor + in_tensor
|
||||
inputs = {'x': in_tensor}
|
||||
outputs = {'z': out_tensor}
|
||||
|
||||
new_graph.CopyFrom(sess.graph_def)
|
||||
|
||||
# Rename Add op name to CustomAdd.
|
||||
for node in new_graph.node:
|
||||
if node.op.startswith('Add'):
|
||||
node.op = 'CustomAdd'
|
||||
del node.attr['T']
|
||||
|
||||
# Register custom op defs to import modified graph def.
|
||||
register_custom_opdefs([custom_opdefs_str])
|
||||
|
||||
# Store saved model.
|
||||
saved_model_dir = self._getFilepath('model')
|
||||
with ops.Graph().as_default():
|
||||
with session.Session() as sess:
|
||||
import_graph_def(new_graph, name='')
|
||||
saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
|
||||
return (saved_model_dir, custom_opdefs_str)
|
||||
|
||||
def testEnsureCustomOpdefsFlag(self):
|
||||
saved_model_dir, _ = self._createSavedModelWithCustomOp()
|
||||
|
||||
# Ensure --custom_opdefs.
|
||||
flags_str = ('--saved_model_dir={0} --allow_custom_ops '
|
||||
'--experimental_new_converter'.format(saved_model_dir))
|
||||
self._run(flags_str, should_succeed=False)
|
||||
|
||||
def testSavedModelWithCustomOpdefsFlag(self):
|
||||
saved_model_dir, custom_opdefs_str = self._createSavedModelWithCustomOp()
|
||||
|
||||
# Valid conversion.
|
||||
flags_str = (
|
||||
'--saved_model_dir={0} --custom_opdefs="{1}" --allow_custom_ops '
|
||||
'--experimental_new_converter'.format(saved_model_dir,
|
||||
custom_opdefs_str))
|
||||
self._run(flags_str, should_succeed=True)
|
||||
|
||||
def testSavedModelWithInvalidCustomOpdefsFlag(self):
|
||||
saved_model_dir, _ = self._createSavedModelWithCustomOp()
|
||||
|
||||
invalid_custom_opdefs_str = (
|
||||
'name: \'CustomAdd\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
|
||||
'output_arg: {name: \'Output\' type: DT_FLOAT}')
|
||||
|
||||
# Valid conversion.
|
||||
flags_str = (
|
||||
'--saved_model_dir={0} --custom_opdefs="{1}" --allow_custom_ops '
|
||||
'--experimental_new_converter'.format(saved_model_dir,
|
||||
invalid_custom_opdefs_str))
|
||||
self._run(flags_str, should_succeed=False)
|
||||
|
||||
def testKerasFile(self):
|
||||
keras_file = self._getKerasModelFile()
|
||||
|
||||
@ -269,9 +339,9 @@ class TfLiteConvertV1Test(TestModels):
|
||||
'attr : { name: \'nms_iou_threshold\' type: \'float\'} '
|
||||
'attr : { name: \'nms_score_threshold\' type: \'float\'} '
|
||||
'attr : { name: \'num_classes\' type: \'int\'} '
|
||||
'attr : { name: \'w_scale\' type: \'int\'} '
|
||||
'attr : { name: \'x_scale\' type: \'int\'} '
|
||||
'attr : { name: \'y_scale\' type: \'int\'}')
|
||||
'attr : { name: \'w_scale\' type: \'float\'} '
|
||||
'attr : { name: \'x_scale\' type: \'float\'} '
|
||||
'attr : { name: \'y_scale\' type: \'float\'}')
|
||||
|
||||
flags_str = ('--graph_def_file={0} --input_arrays={1} '
|
||||
'--output_arrays={2} --input_shapes={3} '
|
||||
|
@ -55,3 +55,8 @@ def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel,
|
||||
def wrapped_experimental_mlir_sparsify(input_data_str):
|
||||
"""Wraps experimental mlir sparsify model."""
|
||||
return _pywrap_toco_api.ExperimentalMlirSparsifyModel(input_data_str)
|
||||
|
||||
|
||||
def wrapped_register_custom_opdefs(custom_opdefs_list):
|
||||
"""Wraps RegisterCustomOpdefs with lazy loader."""
|
||||
return _pywrap_toco_api.RegisterCustomOpdefs(custom_opdefs_list)
|
||||
|
@ -37,6 +37,9 @@ cc_library(
|
||||
deps = [
|
||||
"@com_google_protobuf//:protobuf_headers",
|
||||
"//third_party/python_runtime:headers", # build_cleaner: keep; DNR: b/35864863
|
||||
"//tensorflow/c:kernels",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/core/api",
|
||||
|
@ -20,10 +20,13 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "google/protobuf/text_format.h"
|
||||
#include "tensorflow/c/kernels.h"
|
||||
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
|
||||
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"
|
||||
#include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
@ -317,4 +320,69 @@ PyObject* MlirSparsifyModel(PyObject* data) {
|
||||
builder.GetSize());
|
||||
}
|
||||
|
||||
PyObject* RegisterCustomOpdefs(PyObject* list) {
|
||||
if (!PyList_Check(list)) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected list in argument");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int64 size = PyList_Size(list);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
// Get character array from Python object.
|
||||
char* tf_opdefs;
|
||||
Py_ssize_t len;
|
||||
if (tflite::python_utils::ConvertFromPyString(PyList_GetItem(list, i),
|
||||
&tf_opdefs, &len) == -1) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Failed to convert Python string at index %d of custom op "
|
||||
"defs argument",
|
||||
i);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Parse op def from character array.
|
||||
tensorflow::OpDef opdef;
|
||||
if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs, &opdef)) {
|
||||
PyErr_Format(
|
||||
PyExc_ValueError,
|
||||
"Failed to parse opdefs at index %d of custom op defs argument: %s",
|
||||
i, tf_opdefs);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Register extra opdefs to TensorFlow global op registry.
|
||||
tensorflow::OpRegistry::Global()->Register(
|
||||
[opdef](
|
||||
tensorflow::OpRegistrationData* op_reg_data) -> tensorflow::Status {
|
||||
*op_reg_data = tensorflow::OpRegistrationData(opdef);
|
||||
return tensorflow::Status::OK();
|
||||
});
|
||||
|
||||
// Register the corresponding fake op kernel.
|
||||
const char* node_name = opdef.name().c_str();
|
||||
const char* op_name = opdef.name().c_str();
|
||||
const char* device_name = "CPU";
|
||||
static auto fake_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
|
||||
};
|
||||
|
||||
TF_KernelBuilder* builder =
|
||||
TF_NewKernelBuilder(op_name, device_name, /*create_func=*/nullptr,
|
||||
fake_compute_func, /*delete_func=*/nullptr);
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterKernelBuilder(node_name, builder, status);
|
||||
if (TF_GetCode(status) != TF_OK) {
|
||||
TF_DeleteStatus(status);
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Failed to register fake op kernel at index %d of custom op "
|
||||
"defs argument",
|
||||
i);
|
||||
return nullptr;
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
|
||||
} // namespace toco
|
||||
|
@ -49,6 +49,9 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
|
||||
// Sparsifies model to encode sparse tensors with proper format. Throws error if
|
||||
// sparsification fails.
|
||||
PyObject* MlirSparsifyModel(PyObject* data);
|
||||
|
||||
// Registers the given custom opdefs to TensorFlow global op registry.
|
||||
PyObject* RegisterCustomOpdefs(PyObject* list);
|
||||
} // namespace toco
|
||||
|
||||
#endif // TENSORFLOW_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
|
||||
|
@ -77,4 +77,14 @@ PYBIND11_MODULE(_pywrap_toco_api, m) {
|
||||
R"pbdoc(
|
||||
Returns a sparsified model.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"RegisterCustomOpdefs",
|
||||
[](py::object custom_opdefs_txt_raw) {
|
||||
return tensorflow::PyoOrThrow(
|
||||
toco::RegisterCustomOpdefs(custom_opdefs_txt_raw.ptr()));
|
||||
},
|
||||
py::arg("custom_opdefs_txt_raw"),
|
||||
R"pbdoc(
|
||||
Registers the given custom opdefs to the TensorFlow global op registry.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@ -123,6 +123,7 @@ toco::TocoConvert
|
||||
toco::TocoGetPotentiallySupportedOps
|
||||
toco::MlirQuantizeModel
|
||||
toco::MlirSparsifyModel
|
||||
toco::RegisterCustomOpdefs
|
||||
|
||||
[transform_graph_lib] # transform_graph
|
||||
tensorflow::graph_transforms::TransformGraph
|
||||
|
Loading…
Reference in New Issue
Block a user