Add flag to add ops to flex delegate allowlist

Usage:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]
converter.target_spec.select_user_tf_ops = ["UserTFOpname1", "UserTFOpname2"]
PiperOrigin-RevId: 344001319
Change-Id: If414841bd7456ba8856df3630e94d991fc050319
This commit is contained in:
Thai Nguyen 2020-11-24 00:08:45 -08:00 committed by TensorFlower Gardener
parent bddfdecabf
commit ab953ab8c5
21 changed files with 507 additions and 37 deletions

View File

@ -667,6 +667,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:logging",

View File

@ -387,19 +387,23 @@ class Translator {
// internal error.
static Optional<std::string> Translate(
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
bool emit_custom_ops, const std::unordered_set<std::string>& tags,
bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& tags,
OpOrArgNameMapper* op_or_arg_name_mapper);
private:
enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& saved_model_tags,
OpOrArgNameMapper* op_or_arg_name_mapper)
: module_(module),
name_mapper_(*op_or_arg_name_mapper),
builder_(kInitialBufferSize),
saved_model_tags_(saved_model_tags) {
saved_model_tags_(saved_model_tags),
select_user_tf_ops_(select_user_tf_ops) {
// The first buffer must be empty according to the schema definition.
empty_buffer_ = tflite::CreateBuffer(builder_);
buffers_.push_back(empty_buffer_);
@ -575,6 +579,8 @@ class Translator {
// Set of saved model tags, if any.
const std::unordered_set<std::string> saved_model_tags_;
// User's defined ops allowed with Flex.
const std::unordered_set<std::string> select_user_tf_ops_;
};
std::string Translator::UniqueName(mlir::Value val) {
@ -1104,12 +1110,15 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
resource_ops_.insert(node_def->op());
}
const bool is_allowed_flex_op =
IsAllowlistedFlexOp(node_def->op()) ||
((select_user_tf_ops_.count(node_def->op()) != 0) &&
(tensorflow::OpRegistry::Global()->LookUp(node_def->op()) != nullptr));
// Flex op case
// Eventually, the allowlist will go away and we will rely on some TF op
// trait (e.g. No side effect) to determine if it is a supported "Flex"
// op or not.
if (enabled_op_types_.contains(OpType::kSelectTf) &&
IsAllowlistedFlexOp(node_def->op())) {
if (is_allowed_flex_op && enabled_op_types_.contains(OpType::kSelectTf)) {
// Construct ops as flex op encoding TensorFlow node definition
// as custom options.
// Flex ops are named with the kFlexOpNamePrefix prefix to the actual
@ -1160,7 +1169,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
}
// Insert failed op to `flex_ops` or `custom_ops`.
if (IsAllowlistedFlexOp(node_def->op())) {
if (is_allowed_flex_op) {
failed_flex_ops_.insert(os.str());
} else {
failed_custom_ops_.insert(os.str());
@ -1620,12 +1629,15 @@ bool UpdateEntryFunction(ModuleOp module) {
Optional<std::string> Translator::Translate(
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
bool emit_custom_ops, const std::unordered_set<std::string>& tags,
bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& tags,
OpOrArgNameMapper* op_or_arg_name_mapper) {
if (!UpdateEntryFunction(module)) return llvm::None;
if (!IsValidTFLiteMlirModule(module)) return llvm::None;
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
emit_custom_ops, tags, op_or_arg_name_mapper);
emit_custom_ops, select_user_tf_ops, tags,
op_or_arg_name_mapper);
return translator.TranslateInternal();
}
@ -1877,9 +1889,22 @@ bool tflite::MlirToFlatBufferTranslateFunction(
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags,
OpOrArgNameMapper* op_or_arg_name_mapper) {
std::unordered_set<std::string> select_user_tf_ops;
return MlirToFlatBufferTranslateFunction(
module, serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops, saved_model_tags,
op_or_arg_name_mapper);
}
bool tflite::MlirToFlatBufferTranslateFunction(
ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& saved_model_tags,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper) {
auto maybe_translated = Translator::Translate(
module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops,
saved_model_tags, op_or_arg_name_mapper);
select_user_tf_ops, saved_model_tags, op_or_arg_name_mapper);
if (!maybe_translated) return true;
*serialized_flatbuffer = std::move(*maybe_translated);
return false;

View File

@ -52,6 +52,14 @@ bool MlirToFlatBufferTranslateFunction(
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
// Same as the above but with a list of allowed user's defined ops.
bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& saved_model_tags,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
} // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include <ostream>
#include <unordered_set>
#include <utility>
#include "llvm/Support/ToolOutputFile.h"
@ -288,6 +289,10 @@ Status ConvertMLIRToTFLiteFlatBuffer(
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
bool emit_custom_ops = toco_flags.allow_custom_ops();
const std::unordered_set<std::string> select_user_tf_ops(
toco_flags.select_user_tf_ops().begin(),
toco_flags.select_user_tf_ops().end());
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
module.get(),
@ -307,8 +312,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs,
saved_model_tags, result, &pm);
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
pass_config.quant_specs, saved_model_tags, result, &pm);
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.

View File

@ -242,7 +242,8 @@ int main(int argc, char **argv) {
std::string result;
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, quant_specs, tags, &result, &pm);
emit_select_tf_ops, emit_custom_ops,
/*select_user_tf_ops=*/{}, quant_specs, tags, &result, &pm);
if (!status.ok()) return kTrFailure;
std::string error_msg;

View File

@ -137,6 +137,7 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
Status ConvertTFExecutorToTFLOrFlatbuffer(
mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const mlir::TFL::QuantizationSpecs& quant_specs,
const std::unordered_set<std::string>& saved_model_tags,
std::string* result, mlir::PassManager* pass_manager) {
@ -169,10 +170,12 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
}
// Write MLIR TFLite dialect into FlatBuffer
OpOrArgLocNameMapper op_or_arg_name_mapper;
if (!quant_specs.RunWeightQuantization()) {
if (tflite::MlirToFlatBufferTranslateFunction(
module, result, emit_builtin_tflite_ops, emit_select_tf_ops,
emit_custom_ops, saved_model_tags)) {
emit_custom_ops, select_user_tf_ops, saved_model_tags,
&op_or_arg_name_mapper)) {
return statusHandler.ConsumeStatus();
}
} else {
@ -181,7 +184,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
std::string pre_quantized_result;
if (tflite::MlirToFlatBufferTranslateFunction(
module, &pre_quantized_result, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, saved_model_tags)) {
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
saved_model_tags, &op_or_arg_name_mapper)) {
return statusHandler.ConsumeStatus();
}
flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);

View File

@ -63,6 +63,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
Status ConvertTFExecutorToTFLOrFlatbuffer(
mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const mlir::TFL::QuantizationSpecs& quant_specs,
const std::unordered_set<std::string>& saved_model_tags,
std::string* result, mlir::PassManager* pass_manager);

View File

@ -91,6 +91,32 @@ py_library(
],
)
py_library(
name = "test_util",
testonly = 1,
srcs = ["test_util.py"],
srcs_version = "PY2AND3",
deps = [
":lite",
":schema_util",
"//tensorflow/lite/tools:visualize",
"//tensorflow/python:framework",
],
)
py_test(
name = "test_util_test",
srcs = ["test_util_test.py"],
data = [
"//tensorflow/lite:testdata/add.bin",
"//tensorflow/lite:testdata/softplus_flex.bin",
],
python_version = "PY3",
deps = [
":test_util",
],
)
py_test(
name = "tflite_convert_test",
srcs = ["tflite_convert_test.py"],
@ -111,6 +137,7 @@ py_test(
],
deps = [
":convert",
":test_util",
":tflite_convert",
"//tensorflow:tensorflow_py",
"//tensorflow/python:array_ops",
@ -219,6 +246,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":lite",
":test_util",
"//tensorflow/lite/python/testdata:double_op",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
],

View File

@ -317,6 +317,7 @@ def build_toco_flags(inference_type=dtypes.float32,
dump_graphviz_video=False,
target_ops=None,
conversion_summary_dir=None,
select_user_tf_ops=None,
**_):
"""Build the TOCO flags object from params."""
toco = _toco_flags_pb2.TocoFlags()
@ -333,6 +334,8 @@ def build_toco_flags(inference_type=dtypes.float32,
toco.allow_custom_ops = allow_custom_ops
if custom_opdefs:
toco.custom_opdefs.extend(custom_opdefs)
if select_user_tf_ops:
toco.select_user_tf_ops.extend(select_user_tf_ops)
toco.post_training_quantize = post_training_quantize
toco.quantize_to_float16 = quantize_to_float16
if default_ranges_stats:
@ -376,7 +379,8 @@ def build_toco_convert_protos(input_tensors,
saved_model_dir=None,
saved_model_version=0,
saved_model_tags=None,
saved_model_exported_names=None):
saved_model_exported_names=None,
select_user_tf_ops=None):
"""Builds protocol buffers describing a conversion of a model using TOCO.
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@ -454,6 +458,9 @@ def build_toco_convert_protos(input_tensors,
saved_model_exported_names: Names to be exported (default: export all) when
the saved model import path is on. This value will be set only when the
SavedModel import path will be used.
select_user_tf_ops: List of user's defined TensorFlow ops need to be
supported in the TensorFlow Lite runtime. These ops will be supported as
select TensorFlow ops.
Returns:
model_flags, toco_flags, debug_info: three protocol buffers describing the
@ -472,7 +479,7 @@ def build_toco_convert_protos(input_tensors,
allow_custom_ops, custom_opdefs,
post_training_quantize, quantize_to_float16,
dump_graphviz_dir, dump_graphviz_video, target_ops,
conversion_summary_dir)
conversion_summary_dir, select_user_tf_ops)
model = _model_flags_pb2.ModelFlags()
model.change_concat_input_ranges = change_concat_input_ranges
for idx, input_tensor in enumerate(input_tensors):

View File

@ -165,15 +165,29 @@ class TargetSpec(object):
supported_types: List of types for constant values on the target device.
Frequently, an optimization choice is driven by the most compact
(i.e. smallest) type in this list (default [tf.float32])
experimental_select_user_tf_ops: Experimental flag, subject to change. Set
of user's TensorFlow operators' names that are required in the TensorFlow
Lite runtime. These ops will be exported as select TensorFlow ops in the
model (in conjunction with the OpsSet.SELECT_TF_OPS flag). This is an
advanced feature that should only be used if the client is using TF ops
that may not be linked in by default with the TF ops that are provided
when using the SELECT_TF_OPS path. The client is responsible for linking
these ops into the target runtime.
"""
def __init__(self, supported_ops=None, supported_types=None):
def __init__(self,
supported_ops=None,
supported_types=None,
experimental_select_user_tf_ops=None):
if supported_ops is None:
supported_ops = set([OpsSet.TFLITE_BUILTINS])
self.supported_ops = supported_ops
if supported_types is None:
supported_types = []
self.supported_types = supported_types
if experimental_select_user_tf_ops is None:
self.experimental_select_user_tf_ops = []
class QuantizationMode(object):
@ -482,6 +496,7 @@ class TFLiteConverterBase(object):
"debug_info": self._debug_info,
"target_ops": self.target_spec.supported_ops,
"enable_mlir_converter": self.experimental_new_converter,
"select_user_tf_ops": self.target_spec.experimental_select_user_tf_ops,
}
if self.saved_model_dir:
@ -722,8 +737,7 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
# TODO(b/162537905): Clean these indirect dependencies.
self.saved_model_dir = None
return super(TFLiteSavedModelConverterV2,
self).convert(graph_def, input_tensors,
output_tensors)
self).convert(graph_def, input_tensors, output_tensors)
if self._trackable_obj is None:
self._debug_info = _get_debug_info(
@ -966,8 +980,9 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
dataset to evaluate different optimizations. Note that this is an optional
attribute but it is necessary if INT8 is the only support builtin ops in
target ops.
target_spec: Experimental flag, subject to change. Specification of target
device.
target_spec: Experimental flag, subject to change. Specifications of target
device, including supported ops set, supported types and a set of user's
defined TensorFlow operators required in the TensorFlow Lite runtime.
inference_input_type: Data type of the input layer. Note that integer types
(tf.int8 and tf.uint8) are currently only supported for post training
integer quantization and quantization aware training. (default tf.float32,
@ -1686,8 +1701,9 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
target_ops: Deprecated. Please specify `target_spec.supported_ops` instead.
Set of OpsSet options indicating which converter to use. (default
set([OpsSet.TFLITE_BUILTINS]))
target_spec: Experimental flag, subject to change. Specification of target
device.
target_spec: Experimental flag, subject to change. Specifications of target
device, including supported ops set, supported types and a set of user's
defined TensorFlow operators required in the TensorFlow Lite runtime.
optimizations: Experimental flag, subject to change. A list of optimizations
to apply when converting the model. E.g. `[Optimize.DEFAULT]`
representative_dataset: A representative dataset that can be used to

View File

@ -18,20 +18,28 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
import numpy as np
from tensorflow.core.framework import graph_pb2
from tensorflow.lite.python import lite
from tensorflow.lite.python import test_util as tflite_test_util
from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.lite.python.testdata import double_op
from tensorflow.python.client import session
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.importer import import_graph_def
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.tracking import tracking
@ -135,5 +143,87 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase,
self.assertTrue((expected_output == output_data).all())
class WithCustomOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def _createGraphWithCustomOp(self, opname='CustomAdd'):
custom_opdefs_str = (
'name: \'' + opname + '\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
'input_arg: {name: \'Input2\' type: DT_FLOAT} output_arg: {name: '
'\'Output\' type: DT_FLOAT}')
# Create a graph that has one add op.
new_graph = graph_pb2.GraphDef()
with ops.Graph().as_default():
with session.Session() as sess:
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
out_tensor = in_tensor + in_tensor
inputs = {'x': in_tensor}
outputs = {'z': out_tensor}
new_graph.CopyFrom(sess.graph_def)
# Rename Add op name to opname.
for node in new_graph.node:
if node.op.startswith('Add'):
node.op = opname
del node.attr['T']
# Register custom op defs to import modified graph def.
register_custom_opdefs([custom_opdefs_str])
return (new_graph, inputs, outputs)
def testFlexWithCustomOp(self):
new_graph, inputs, outputs = self._createGraphWithCustomOp(
opname='CustomAdd4')
# Import to load the custom opdef.
saved_model_dir = os.path.join(self.get_temp_dir(), 'model')
with ops.Graph().as_default():
with session.Session() as sess:
import_graph_def(new_graph, name='')
saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
converter.target_spec.experimental_select_user_tf_ops = ['CustomAdd4']
tflite_model = converter.convert()
self.assertIn('FlexCustomAdd4', tflite_test_util.get_ops_list(tflite_model))
def testFlexWithDoubleOp(self):
# Create a graph that has one double op.
saved_model_dir = os.path.join(self.get_temp_dir(), 'model2')
with ops.Graph().as_default():
with session.Session() as sess:
in_tensor = array_ops.placeholder(
shape=[1, 4], dtype=dtypes.int32, name='input')
out_tensor = double_op.double(in_tensor)
inputs = {'x': in_tensor}
outputs = {'z': out_tensor}
saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
converter.target_spec.experimental_select_user_tf_ops = ['Double']
tflite_model = converter.convert()
self.assertTrue(tflite_model)
self.assertIn('FlexDouble', tflite_test_util.get_ops_list(tflite_model))
# Check the model works with TensorFlow ops.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.int32)
interpreter.set_tensor(input_details[0]['index'], test_input)
interpreter.invoke()
output_details = interpreter.get_output_details()
expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.int32)
output_data = interpreter.get_tensor(output_details[0]['index'])
self.assertTrue((expected_output == output_data).all())
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,42 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions used by multiple tflite test files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import schema_util
from tensorflow.lite.tools import visualize
def get_ops_list(model_data):
"""Return a set of ops in the tflite model data."""
model = schema_fb.Model.GetRootAsModel(model_data, 0)
op_set = set()
for subgraph_idx in range(model.SubgraphsLength()):
subgraph = model.Subgraphs(subgraph_idx)
for op_idx in range(subgraph.OperatorsLength()):
op = subgraph.Operators(op_idx)
opcode = model.OperatorCodes(op.OpcodeIndex())
builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
if builtin_code == schema_fb.BuiltinOperator.CUSTOM:
opname = opcode.CustomCode().decode("utf-8")
op_set.add(opname)
else:
op_set.add(visualize.BuiltinCodeToName(builtin_code))
return op_set

View File

@ -0,0 +1,43 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for test_util.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.lite.python import test_util as tflite_test_util
from tensorflow.python.framework import test_util
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
class TestUtilTest(test_util.TensorFlowTestCase):
def testBuiltinOp(self):
model_path = resource_loader.get_path_to_datafile('../testdata/add.bin')
op_set = tflite_test_util.get_ops_list(gfile.GFile(model_path, 'rb').read())
self.assertCountEqual(op_set, ['ADD'])
def testFlexOp(self):
model_path = resource_loader.get_path_to_datafile(
'../testdata/softplus_flex.bin')
op_set = tflite_test_util.get_ops_list(gfile.GFile(model_path, 'rb').read())
self.assertCountEqual(op_set, ['FlexSoftplus'])
if __name__ == '__main__':
test.main()

View File

@ -1,5 +1,6 @@
load("//tensorflow/lite:build_def.bzl", "tf_to_tflite")
load("//tensorflow:tensorflow.bzl", "pybind_extension")
load("//tensorflow:tensorflow.bzl", "pybind_extension", "tf_custom_op_py_library")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py")
package(
default_visibility = ["//tensorflow:internal"],
@ -86,6 +87,39 @@ cc_binary(
],
)
cc_library(
name = "double_op_and_kernels",
srcs = ["double_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
alwayslink = 1,
)
tf_custom_op_library(
name = "_double_op.so",
srcs = ["double_op.cc"],
)
tf_gen_op_wrapper_py(
name = "gen_double_op_wrapper",
out = "double_op_wrapper.py",
deps = [":double_op_and_kernels"],
)
tf_custom_op_py_library(
name = "double_op",
srcs = ["double_op.py"],
dso = [":_double_op.so"],
kernels = [":double_op_and_kernels"],
srcs_version = "PY2AND3",
deps = [
":gen_double_op_wrapper",
"//tensorflow/python:framework_for_generated_wrappers",
],
)
cc_library(
name = "test_registerer",
srcs = ["test_registerer.cc"],

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
REGISTER_OP("Double")
.Input("input: int32")
.Output("doubled: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
class DoubleOp : public OpKernel {
public:
explicit DoubleOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input_flat = input_tensor.flat<int32>();
// Create an output tensor
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output_flat = output_tensor->flat<int32>();
// Set all but the first element of the output tensor to 0.
const int N = input_flat.size();
for (int i = 0; i < N; i++) {
output_flat(i) = 2 * input_flat(i);
}
}
};
REGISTER_KERNEL_BUILDER(Name("Double").Device(DEVICE_CPU), DoubleOp);
} // namespace tensorflow

View File

@ -0,0 +1,34 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Double op is a user's defined op for testing purpose."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.lite.python.testdata import double_op_wrapper
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
_double_op = load_library.load_op_library(
resource_loader.get_path_to_datafile('_double_op.so'))
def double(input_tensor):
"""Double op applies element-wise double to input data."""
if input_tensor.dtype != dtypes.int32:
raise ValueError('Double op only accept int32 values.')
return double_op_wrapper.double(input_tensor)

View File

@ -192,6 +192,16 @@ def _convert_tf1_model(flags):
"{0}".format(",".join(ops_set_options)))
converter.target_spec.supported_ops.add(lite.OpsSet(option))
if flags.experimental_select_user_tf_ops:
if lite.OpsSet.SELECT_TF_OPS not in converter.target_spec.supported_ops:
raise ValueError("--experimental_select_user_tf_ops can only be set if "
"--target_ops contains SELECT_TF_OPS.")
user_op_set = set()
for op_name in six.ensure_str(
flags.experimental_select_user_tf_ops).split(","):
user_op_set.add(op_name)
converter.target_spec.experimental_select_user_tf_ops = list(user_op_set)
if flags.post_training_quantize:
converter.optimizations = [lite.Optimize.DEFAULT]
if converter.inference_type != dtypes.float32:
@ -313,6 +323,10 @@ def _check_tf1_flags(flags, unparsed):
"--experimental_new_converter")
if flags.custom_opdefs and not flags.allow_custom_ops:
raise ValueError("--custom_opdefs must be used with --allow_custom_ops")
if (flags.experimental_select_user_tf_ops and
not flags.experimental_new_converter):
raise ValueError("--experimental_select_user_tf_ops must be used with "
"--experimental_new_converter")
def _check_tf2_flags(flags):
@ -491,6 +505,11 @@ def _get_tf1_flags(parser):
"indicating which converter to use. Options: {0}. One or more "
"option may be specified. (default set([OpsSet.TFLITE_BUILTINS]))"
"".format(",".join(lite.OpsSet.get_options()))))
parser.add_argument(
"--experimental_select_user_tf_ops",
type=str,
help=("Experimental flag, subject to change. Comma separated list of "
"user's defined TensorFlow operators required in the runtime."))
# Logging flags.
parser.add_argument(

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow import keras
from tensorflow.core.framework import graph_pb2
from tensorflow.lite.python import test_util as tflite_test_util
from tensorflow.lite.python import tflite_convert
from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.python import tf2
@ -50,7 +51,10 @@ class TestModels(test_util.TensorFlowTestCase):
def _getFilepath(self, filename):
return os.path.join(self.get_temp_dir(), filename)
def _run(self, flags_str, should_succeed):
def _run(self,
flags_str,
should_succeed,
expected_ops_in_converted_model=None):
output_file = os.path.join(self.get_temp_dir(), 'model.tflite')
tflite_bin = resource_loader.get_path_to_datafile('tflite_convert')
cmdline = '{0} --output_file={1} {2}'.format(tflite_bin, output_file,
@ -61,6 +65,10 @@ class TestModels(test_util.TensorFlowTestCase):
with gfile.Open(output_file, 'rb') as model_file:
content = model_file.read()
self.assertEqual(content is not None, should_succeed)
if expected_ops_in_converted_model:
op_set = tflite_test_util.get_ops_list(content)
for opname in expected_ops_in_converted_model:
self.assertIn(opname, op_set)
os.remove(output_file)
else:
self.assertFalse(should_succeed)
@ -83,10 +91,14 @@ class TestModels(test_util.TensorFlowTestCase):
class TfLiteConvertV1Test(TestModels):
def _run(self, flags_str, should_succeed):
def _run(self,
flags_str,
should_succeed,
expected_ops_in_converted_model=None):
if tf2.enabled():
flags_str += ' --enable_v1_converter'
super(TfLiteConvertV1Test, self)._run(flags_str, should_succeed)
super(TfLiteConvertV1Test, self)._run(flags_str, should_succeed,
expected_ops_in_converted_model)
def testFrozenGraphDef(self):
with ops.Graph().as_default():
@ -186,8 +198,8 @@ class TfLiteConvertV1Test(TestModels):
# Define converter flags
flags_str = ('--std_dev_values=128,128 --mean_values=128,128 '
'--graph_def_file={0} --input_arrays={1} '
'--output_arrays={2}'.format(
graph_def_file, 'inputA,inputB', 'output'))
'--output_arrays={2}'.format(graph_def_file, 'inputA,inputB',
'output'))
# Set inference_type UINT8 and (default) inference_input_type UINT8
flags_str_1 = flags_str + ' --inference_type=UINT8'
@ -213,9 +225,9 @@ class TfLiteConvertV1Test(TestModels):
flags_str = '--saved_model_dir={}'.format(saved_model_dir)
self._run(flags_str, should_succeed=True)
def _createSavedModelWithCustomOp(self):
def _createSavedModelWithCustomOp(self, opname='CustomAdd'):
custom_opdefs_str = (
'name: \'CustomAdd\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
'name: \'' + opname + '\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
'input_arg: {name: \'Input2\' type: DT_FLOAT} output_arg: {name: '
'\'Output\' type: DT_FLOAT}')
@ -231,10 +243,10 @@ class TfLiteConvertV1Test(TestModels):
new_graph.CopyFrom(sess.graph_def)
# Rename Add op name to CustomAdd.
# Rename Add op name to opname.
for node in new_graph.node:
if node.op.startswith('Add'):
node.op = 'CustomAdd'
node.op = opname
del node.attr['T']
# Register custom op defs to import modified graph def.
@ -264,7 +276,26 @@ class TfLiteConvertV1Test(TestModels):
'--saved_model_dir={0} --custom_opdefs="{1}" --allow_custom_ops '
'--experimental_new_converter'.format(saved_model_dir,
custom_opdefs_str))
self._run(flags_str, should_succeed=True)
self._run(
flags_str,
should_succeed=True,
expected_ops_in_converted_model=['CustomAdd'])
def testSavedModelWithFlex(self):
saved_model_dir, custom_opdefs_str = self._createSavedModelWithCustomOp(
opname='CustomAdd2')
# Valid conversion. OpDef already registered.
flags_str = ('--saved_model_dir={0} --allow_custom_ops '
'--custom_opdefs="{1}" '
'--experimental_new_converter '
'--experimental_select_user_tf_ops=CustomAdd2 '
'--target_ops=TFLITE_BUILTINS,SELECT_TF_OPS'.format(
saved_model_dir, custom_opdefs_str))
self._run(
flags_str,
should_succeed=True,
expected_ops_in_converted_model=['FlexCustomAdd2'])
def testSavedModelWithInvalidCustomOpdefsFlag(self):
saved_model_dir, _ = self._createSavedModelWithCustomOp()
@ -393,7 +424,30 @@ class TfLiteConvertV1Test(TestModels):
# Valid conversion.
flags_str_final = ('{} --allow_custom_ops '
'--experimental_new_converter').format(flags_str)
self._run(flags_str_final, should_succeed=True)
self._run(
flags_str_final,
should_succeed=True,
expected_ops_in_converted_model=['TFLite_Detection_PostProcess'])
def testObjectDetectionMLIRWithFlex(self):
"""Tests object detection model through MLIR converter."""
self._initObjectDetectionArgs()
flags_str = ('--graph_def_file={0} --input_arrays={1} '
'--output_arrays={2} --input_shapes={3}'.format(
self._graph_def_file, self._input_arrays,
self._output_arrays, self._input_shapes))
# Valid conversion.
flags_str_final = (
'{} --allow_custom_ops '
'--experimental_new_converter '
'--experimental_select_user_tf_ops=TFLite_Detection_PostProcess '
'--target_ops=TFLITE_BUILTINS,SELECT_TF_OPS').format(flags_str)
self._run(
flags_str_final,
should_succeed=True,
expected_ops_in_converted_model=['FlexTFLite_Detection_PostProcess'])
class TfLiteConvertV2Test(TestModels):

View File

@ -38,7 +38,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
// Next ID to use: 33.
// Next ID to use: 34.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@ -226,4 +226,8 @@ message TocoFlags {
// String representing the custom ops OpDefs that are included in the
// GraphDef.
repeated string custom_opdefs = 32;
// Name of user's defined Tensorflow ops required in the TensorFlow Lite
// runtime. These ops will be supported as select TensorFlow ops.
repeated string select_user_tf_ops = 33;
}

View File

@ -4,6 +4,6 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'supported_ops\', \'supported_types\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'self\', \'supported_ops\', \'supported_types\', \'experimental_select_user_tf_ops\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
}

View File

@ -4,6 +4,6 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'supported_ops\', \'supported_types\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'self\', \'supported_ops\', \'supported_types\', \'experimental_select_user_tf_ops\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
}