Add flag to add ops to flex delegate allowlist
Usage:

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    converter.target_spec.experimental_select_user_tf_ops = [
        "UserTFOpname1", "UserTFOpname2"
    ]

PiperOrigin-RevId: 344001319
Change-Id: If414841bd7456ba8856df3630e94d991fc050319
commit ab953ab8c5
parent bddfdecabf
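For the SavedModel path exercised by the new tests in this change, usage looks
like the following sketch (the path and the op name `Double` are illustrative,
taken from the tests below):

    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/saved_model')
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    # Names of user-defined TF ops to export as select (Flex) ops. The client
    # is responsible for linking these op kernels into the target runtime.
    converter.target_spec.experimental_select_user_tf_ops = ['Double']
    tflite_model = converter.convert()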
@@ -667,6 +667,7 @@ cc_library(
        "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:framework",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:logging",
@@ -387,19 +387,23 @@ class Translator {
   // internal error.
   static Optional<std::string> Translate(
       ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
-      bool emit_custom_ops, const std::unordered_set<std::string>& tags,
+      bool emit_custom_ops,
+      const std::unordered_set<std::string>& select_user_tf_ops,
+      const std::unordered_set<std::string>& tags,
       OpOrArgNameMapper* op_or_arg_name_mapper);

  private:
   enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
   explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
                       bool emit_select_tf_ops, bool emit_custom_ops,
+                      const std::unordered_set<std::string>& select_user_tf_ops,
                       const std::unordered_set<std::string>& saved_model_tags,
                       OpOrArgNameMapper* op_or_arg_name_mapper)
       : module_(module),
         name_mapper_(*op_or_arg_name_mapper),
         builder_(kInitialBufferSize),
-        saved_model_tags_(saved_model_tags) {
+        saved_model_tags_(saved_model_tags),
+        select_user_tf_ops_(select_user_tf_ops) {
     // The first buffer must be empty according to the schema definition.
     empty_buffer_ = tflite::CreateBuffer(builder_);
     buffers_.push_back(empty_buffer_);
@@ -575,6 +579,8 @@ class Translator {

   // Set of saved model tags, if any.
   const std::unordered_set<std::string> saved_model_tags_;
+  // User-defined ops allowed with the Flex delegate.
+  const std::unordered_set<std::string> select_user_tf_ops_;
 };

 std::string Translator::UniqueName(mlir::Value val) {
@@ -1104,12 +1110,15 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
       resource_ops_.insert(node_def->op());
     }

+    const bool is_allowed_flex_op =
+        IsAllowlistedFlexOp(node_def->op()) ||
+        ((select_user_tf_ops_.count(node_def->op()) != 0) &&
+         (tensorflow::OpRegistry::Global()->LookUp(node_def->op()) != nullptr));
     // Flex op case
     // Eventually, the allowlist will go away and we will rely on some TF op
     // trait (e.g. No side effect) to determine if it is a supported "Flex"
     // op or not.
-    if (enabled_op_types_.contains(OpType::kSelectTf) &&
-        IsAllowlistedFlexOp(node_def->op())) {
+    if (is_allowed_flex_op && enabled_op_types_.contains(OpType::kSelectTf)) {
       // Construct ops as flex op encoding TensorFlow node definition
       // as custom options.
       // Flex ops are named with the kFlexOpNamePrefix prefix to the actual
@@ -1160,7 +1169,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
     }

     // Insert failed op to `flex_ops` or `custom_ops`.
-    if (IsAllowlistedFlexOp(node_def->op())) {
+    if (is_allowed_flex_op) {
       failed_flex_ops_.insert(os.str());
     } else {
       failed_custom_ops_.insert(os.str());
@@ -1620,12 +1629,15 @@ bool UpdateEntryFunction(ModuleOp module) {

 Optional<std::string> Translator::Translate(
     ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
-    bool emit_custom_ops, const std::unordered_set<std::string>& tags,
+    bool emit_custom_ops,
+    const std::unordered_set<std::string>& select_user_tf_ops,
+    const std::unordered_set<std::string>& tags,
     OpOrArgNameMapper* op_or_arg_name_mapper) {
   if (!UpdateEntryFunction(module)) return llvm::None;
   if (!IsValidTFLiteMlirModule(module)) return llvm::None;
   Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
-                        emit_custom_ops, tags, op_or_arg_name_mapper);
+                        emit_custom_ops, select_user_tf_ops, tags,
+                        op_or_arg_name_mapper);
   return translator.TranslateInternal();
 }
@@ -1877,9 +1889,22 @@ bool tflite::MlirToFlatBufferTranslateFunction(
     bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
     const std::unordered_set<std::string>& saved_model_tags,
     OpOrArgNameMapper* op_or_arg_name_mapper) {
+  std::unordered_set<std::string> select_user_tf_ops;
+  return MlirToFlatBufferTranslateFunction(
+      module, serialized_flatbuffer, emit_builtin_tflite_ops,
+      emit_select_tf_ops, emit_custom_ops, select_user_tf_ops, saved_model_tags,
+      op_or_arg_name_mapper);
+}
+
+bool tflite::MlirToFlatBufferTranslateFunction(
+    ModuleOp module, std::string* serialized_flatbuffer,
+    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
+    const std::unordered_set<std::string>& select_user_tf_ops,
+    const std::unordered_set<std::string>& saved_model_tags,
+    tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper) {
   auto maybe_translated = Translator::Translate(
       module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops,
-      saved_model_tags, op_or_arg_name_mapper);
+      select_user_tf_ops, saved_model_tags, op_or_arg_name_mapper);
   if (!maybe_translated) return true;
   *serialized_flatbuffer = std::move(*maybe_translated);
   return false;
@@ -52,6 +52,14 @@ bool MlirToFlatBufferTranslateFunction(
     bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
     const std::unordered_set<std::string>& saved_model_tags,
     tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);

+// Same as above, but with a list of allowed user-defined ops.
+bool MlirToFlatBufferTranslateFunction(
+    mlir::ModuleOp module, std::string* serialized_flatbuffer,
+    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
+    const std::unordered_set<std::string>& select_user_tf_ops,
+    const std::unordered_set<std::string>& saved_model_tags,
+    tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
+
 }  // namespace tflite

 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"

 #include <ostream>
+#include <unordered_set>
 #include <utility>

 #include "llvm/Support/ToolOutputFile.h"
@@ -288,6 +289,10 @@ Status ConvertMLIRToTFLiteFlatBuffer(
   bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
   bool emit_custom_ops = toco_flags.allow_custom_ops();

+  const std::unordered_set<std::string> select_user_tf_ops(
+      toco_flags.select_user_tf_ops().begin(),
+      toco_flags.select_user_tf_ops().end());
+
   if (toco_flags.has_dump_graphviz_dir()) {
     TF_RETURN_IF_ERROR(DumpOpGraphToFile(
         module.get(),
@@ -307,8 +312,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(

   auto status = ConvertTFExecutorToTFLOrFlatbuffer(
       module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
-      emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs,
-      saved_model_tags, result, &pm);
+      emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
+      pass_config.quant_specs, saved_model_tags, result, &pm);
   if (toco_flags.has_dump_graphviz_dir()) {
     TF_RETURN_IF_ERROR(DumpOpGraphToFile(
         // rename once we enable the new converter feature flag.
@@ -242,7 +242,8 @@ int main(int argc, char **argv) {
   std::string result;
   auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
       module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops,
-      emit_select_tf_ops, emit_custom_ops, quant_specs, tags, &result, &pm);
+      emit_select_tf_ops, emit_custom_ops,
+      /*select_user_tf_ops=*/{}, quant_specs, tags, &result, &pm);
   if (!status.ok()) return kTrFailure;

   std::string error_msg;
@@ -137,6 +137,7 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
 Status ConvertTFExecutorToTFLOrFlatbuffer(
     mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
     bool emit_select_tf_ops, bool emit_custom_ops,
+    const std::unordered_set<std::string>& select_user_tf_ops,
     const mlir::TFL::QuantizationSpecs& quant_specs,
     const std::unordered_set<std::string>& saved_model_tags,
     std::string* result, mlir::PassManager* pass_manager) {
@@ -169,10 +170,12 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
   }

   // Write MLIR TFLite dialect into FlatBuffer
+  OpOrArgLocNameMapper op_or_arg_name_mapper;
   if (!quant_specs.RunWeightQuantization()) {
     if (tflite::MlirToFlatBufferTranslateFunction(
             module, result, emit_builtin_tflite_ops, emit_select_tf_ops,
-            emit_custom_ops, saved_model_tags)) {
+            emit_custom_ops, select_user_tf_ops, saved_model_tags,
+            &op_or_arg_name_mapper)) {
       return statusHandler.ConsumeStatus();
     }
   } else {
@@ -181,7 +184,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
     std::string pre_quantized_result;
     if (tflite::MlirToFlatBufferTranslateFunction(
             module, &pre_quantized_result, emit_builtin_tflite_ops,
-            emit_select_tf_ops, emit_custom_ops, saved_model_tags)) {
+            emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
+            saved_model_tags, &op_or_arg_name_mapper)) {
       return statusHandler.ConsumeStatus();
     }
     flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
@@ -63,6 +63,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
 Status ConvertTFExecutorToTFLOrFlatbuffer(
     mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
     bool emit_select_tf_ops, bool emit_custom_ops,
+    const std::unordered_set<std::string>& select_user_tf_ops,
     const mlir::TFL::QuantizationSpecs& quant_specs,
     const std::unordered_set<std::string>& saved_model_tags,
     std::string* result, mlir::PassManager* pass_manager);
@@ -91,6 +91,32 @@ py_library(
     ],
 )

+py_library(
+    name = "test_util",
+    testonly = 1,
+    srcs = ["test_util.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":lite",
+        ":schema_util",
+        "//tensorflow/lite/tools:visualize",
+        "//tensorflow/python:framework",
+    ],
+)
+
+py_test(
+    name = "test_util_test",
+    srcs = ["test_util_test.py"],
+    data = [
+        "//tensorflow/lite:testdata/add.bin",
+        "//tensorflow/lite:testdata/softplus_flex.bin",
+    ],
+    python_version = "PY3",
+    deps = [
+        ":test_util",
+    ],
+)
+
 py_test(
     name = "tflite_convert_test",
     srcs = ["tflite_convert_test.py"],
@@ -111,6 +137,7 @@ py_test(
     ],
     deps = [
         ":convert",
+        ":test_util",
         ":tflite_convert",
         "//tensorflow:tensorflow_py",
         "//tensorflow/python:array_ops",
@@ -219,6 +246,8 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":lite",
+        ":test_util",
+        "//tensorflow/lite/python/testdata:double_op",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_test_lib",
     ],
@@ -317,6 +317,7 @@ def build_toco_flags(inference_type=dtypes.float32,
                      dump_graphviz_video=False,
                      target_ops=None,
                      conversion_summary_dir=None,
+                     select_user_tf_ops=None,
                      **_):
   """Build the TOCO flags object from params."""
   toco = _toco_flags_pb2.TocoFlags()
@@ -333,6 +334,8 @@ def build_toco_flags(inference_type=dtypes.float32,
   toco.allow_custom_ops = allow_custom_ops
   if custom_opdefs:
     toco.custom_opdefs.extend(custom_opdefs)
+  if select_user_tf_ops:
+    toco.select_user_tf_ops.extend(select_user_tf_ops)
   toco.post_training_quantize = post_training_quantize
   toco.quantize_to_float16 = quantize_to_float16
   if default_ranges_stats:
@@ -376,7 +379,8 @@ def build_toco_convert_protos(input_tensors,
                               saved_model_dir=None,
                               saved_model_version=0,
                               saved_model_tags=None,
-                              saved_model_exported_names=None):
+                              saved_model_exported_names=None,
+                              select_user_tf_ops=None):
   """Builds protocol buffers describing a conversion of a model using TOCO.

   Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@@ -454,6 +458,9 @@ def build_toco_convert_protos(input_tensors,
     saved_model_exported_names: Names to be exported (default: export all) when
       the saved model import path is on. This value will be set only when the
       SavedModel import path will be used.
+    select_user_tf_ops: List of user-defined TensorFlow ops that need to be
+      supported in the TensorFlow Lite runtime. These ops will be supported as
+      select TensorFlow ops.

   Returns:
     model_flags, toco_flags, debug_info: three protocol buffers describing the
@@ -472,7 +479,7 @@ def build_toco_convert_protos(input_tensors,
                           allow_custom_ops, custom_opdefs,
                           post_training_quantize, quantize_to_float16,
                           dump_graphviz_dir, dump_graphviz_video, target_ops,
-                          conversion_summary_dir)
+                          conversion_summary_dir, select_user_tf_ops)
   model = _model_flags_pb2.ModelFlags()
   model.change_concat_input_ranges = change_concat_input_ranges
   for idx, input_tensor in enumerate(input_tensors):
@@ -165,15 +165,29 @@ class TargetSpec(object):
     supported_types: List of types for constant values on the target device.
       Frequently, an optimization choice is driven by the most compact
       (i.e. smallest) type in this list (default [tf.float32])
+    experimental_select_user_tf_ops: Experimental flag, subject to change. Set
+      of names of user-defined TensorFlow operators that are required in the
+      TensorFlow Lite runtime. These ops will be exported as select TensorFlow
+      ops in the model (in conjunction with the OpsSet.SELECT_TF_OPS flag).
+      This is an advanced feature that should only be used if the client is
+      using TF ops that may not be linked in by default with the TF ops that
+      are provided when using the SELECT_TF_OPS path. The client is
+      responsible for linking these ops into the target runtime.
+
   """

-  def __init__(self, supported_ops=None, supported_types=None):
+  def __init__(self,
+               supported_ops=None,
+               supported_types=None,
+               experimental_select_user_tf_ops=None):
     if supported_ops is None:
       supported_ops = set([OpsSet.TFLITE_BUILTINS])
     self.supported_ops = supported_ops
     if supported_types is None:
       supported_types = []
     self.supported_types = supported_types
+    if experimental_select_user_tf_ops is None:
+      experimental_select_user_tf_ops = []
+    self.experimental_select_user_tf_ops = experimental_select_user_tf_ops


 class QuantizationMode(object):
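With the new keyword argument, a spec carrying user-selected TF ops can be
constructed directly; a minimal sketch against the constructor above (the op
name is illustrative):

    spec = TargetSpec(
        supported_ops=set([OpsSet.SELECT_TF_OPS]),
        experimental_select_user_tf_ops=['Double'])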
@@ -482,6 +496,7 @@ class TFLiteConverterBase(object):
         "debug_info": self._debug_info,
         "target_ops": self.target_spec.supported_ops,
         "enable_mlir_converter": self.experimental_new_converter,
+        "select_user_tf_ops": self.target_spec.experimental_select_user_tf_ops,
     }

     if self.saved_model_dir:
@@ -722,8 +737,7 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
       # TODO(b/162537905): Clean these indirect dependencies.
       self.saved_model_dir = None
       return super(TFLiteSavedModelConverterV2,
-                   self).convert(graph_def, input_tensors,
-                                 output_tensors)
+                   self).convert(graph_def, input_tensors, output_tensors)

     if self._trackable_obj is None:
       self._debug_info = _get_debug_info(
@@ -966,8 +980,9 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
     dataset to evaluate different optimizations. Note that this is an optional
     attribute but it is necessary if INT8 is the only supported builtin op in
     target ops.
-  target_spec: Experimental flag, subject to change. Specification of target
-    device.
+  target_spec: Experimental flag, subject to change. Specifications of target
+    device, including supported ops set, supported types, and a set of
+    user-defined TensorFlow operators required in the TensorFlow Lite runtime.
   inference_input_type: Data type of the input layer. Note that integer types
     (tf.int8 and tf.uint8) are currently only supported for post training
     integer quantization and quantization aware training. (default tf.float32,
@@ -1686,8 +1701,9 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
     target_ops: Deprecated. Please specify `target_spec.supported_ops` instead.
       Set of OpsSet options indicating which converter to use. (default
       set([OpsSet.TFLITE_BUILTINS]))
-    target_spec: Experimental flag, subject to change. Specification of target
-      device.
+    target_spec: Experimental flag, subject to change. Specifications of target
+      device, including supported ops set, supported types, and a set of
+      user-defined TensorFlow operators required in the TensorFlow Lite runtime.
     optimizations: Experimental flag, subject to change. A list of optimizations
       to apply when converting the model. E.g. `[Optimize.DEFAULT]`
     representative_dataset: A representative dataset that can be used to
@@ -18,20 +18,28 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import os

 from absl.testing import parameterized
 import numpy as np

 from tensorflow.core.framework import graph_pb2
 from tensorflow.lite.python import lite
 from tensorflow.lite.python import test_util as tflite_test_util
 from tensorflow.lite.python.convert import register_custom_opdefs
 from tensorflow.lite.python.interpreter import Interpreter
 from tensorflow.lite.python.testdata import double_op
 from tensorflow.python.client import session
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.framework.importer import import_graph_def
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import saved_model
 from tensorflow.python.training.tracking import tracking

@@ -135,5 +143,87 @@ class FromConcreteFunctionTest(test_util.TensorFlowTestCase,
     self.assertTrue((expected_output == output_data).all())


+class WithCustomOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
+
+  def _createGraphWithCustomOp(self, opname='CustomAdd'):
+    custom_opdefs_str = (
+        'name: \'' + opname + '\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
+        'input_arg: {name: \'Input2\' type: DT_FLOAT} output_arg: {name: '
+        '\'Output\' type: DT_FLOAT}')
+
+    # Create a graph that has one add op.
+    new_graph = graph_pb2.GraphDef()
+    with ops.Graph().as_default():
+      with session.Session() as sess:
+        in_tensor = array_ops.placeholder(
+            shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
+        out_tensor = in_tensor + in_tensor
+        inputs = {'x': in_tensor}
+        outputs = {'z': out_tensor}
+
+      new_graph.CopyFrom(sess.graph_def)
+
+    # Rename the Add op to `opname`.
+    for node in new_graph.node:
+      if node.op.startswith('Add'):
+        node.op = opname
+        del node.attr['T']
+
+    # Register custom op defs to import modified graph def.
+    register_custom_opdefs([custom_opdefs_str])
+
+    return (new_graph, inputs, outputs)
+
+  def testFlexWithCustomOp(self):
+    new_graph, inputs, outputs = self._createGraphWithCustomOp(
+        opname='CustomAdd4')
+
+    # Import to load the custom opdef.
+    saved_model_dir = os.path.join(self.get_temp_dir(), 'model')
+    with ops.Graph().as_default():
+      with session.Session() as sess:
+        import_graph_def(new_graph, name='')
+        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
+
+    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
+    converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
+    converter.target_spec.experimental_select_user_tf_ops = ['CustomAdd4']
+    tflite_model = converter.convert()
+
+    self.assertIn('FlexCustomAdd4', tflite_test_util.get_ops_list(tflite_model))
+
+  def testFlexWithDoubleOp(self):
+    # Create a graph that has one double op.
+    saved_model_dir = os.path.join(self.get_temp_dir(), 'model2')
+    with ops.Graph().as_default():
+      with session.Session() as sess:
+        in_tensor = array_ops.placeholder(
+            shape=[1, 4], dtype=dtypes.int32, name='input')
+        out_tensor = double_op.double(in_tensor)
+        inputs = {'x': in_tensor}
+        outputs = {'z': out_tensor}
+        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
+
+    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
+    converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
+    converter.target_spec.experimental_select_user_tf_ops = ['Double']
+    tflite_model = converter.convert()
+    self.assertTrue(tflite_model)
+    self.assertIn('FlexDouble', tflite_test_util.get_ops_list(tflite_model))
+
+    # Check the model works with TensorFlow ops.
+    interpreter = Interpreter(model_content=tflite_model)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.int32)
+    interpreter.set_tensor(input_details[0]['index'], test_input)
+    interpreter.invoke()
+
+    output_details = interpreter.get_output_details()
+    expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.int32)
+    output_data = interpreter.get_tensor(output_details[0]['index'])
+    self.assertTrue((expected_output == output_data).all())
+
+
 if __name__ == '__main__':
   test.main()
tensorflow/lite/python/test_util.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions used by multiple tflite test files."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.lite.python import schema_py_generated as schema_fb
+from tensorflow.lite.python import schema_util
+from tensorflow.lite.tools import visualize
+
+
+def get_ops_list(model_data):
+  """Return a set of ops in the tflite model data."""
+  model = schema_fb.Model.GetRootAsModel(model_data, 0)
+  op_set = set()
+
+  for subgraph_idx in range(model.SubgraphsLength()):
+    subgraph = model.Subgraphs(subgraph_idx)
+    for op_idx in range(subgraph.OperatorsLength()):
+      op = subgraph.Operators(op_idx)
+      opcode = model.OperatorCodes(op.OpcodeIndex())
+      builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
+      if builtin_code == schema_fb.BuiltinOperator.CUSTOM:
+        opname = opcode.CustomCode().decode("utf-8")
+        op_set.add(opname)
+      else:
+        op_set.add(visualize.BuiltinCodeToName(builtin_code))
+  return op_set
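A usage sketch for the helper above (assuming `tflite_model` holds serialized
model bytes from a conversion like the one in the commit message):

    from tensorflow.lite.python import test_util as tflite_test_util

    ops = tflite_test_util.get_ops_list(tflite_model)
    # User-selected TF ops appear as custom ops named with the "Flex" prefix.
    assert 'FlexDouble' in ops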
tensorflow/lite/python/test_util_test.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for test_util.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.lite.python import test_util as tflite_test_util
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import test
+
+
+class TestUtilTest(test_util.TensorFlowTestCase):
+
+  def testBuiltinOp(self):
+    model_path = resource_loader.get_path_to_datafile('../testdata/add.bin')
+    op_set = tflite_test_util.get_ops_list(gfile.GFile(model_path, 'rb').read())
+    self.assertCountEqual(op_set, ['ADD'])
+
+  def testFlexOp(self):
+    model_path = resource_loader.get_path_to_datafile(
+        '../testdata/softplus_flex.bin')
+    op_set = tflite_test_util.get_ops_list(gfile.GFile(model_path, 'rb').read())
+    self.assertCountEqual(op_set, ['FlexSoftplus'])
+
+
+if __name__ == '__main__':
+  test.main()
tensorflow/lite/python/testdata/BUILD (36 lines changed)
@@ -1,5 +1,6 @@
 load("//tensorflow/lite:build_def.bzl", "tf_to_tflite")
-load("//tensorflow:tensorflow.bzl", "pybind_extension")
+load("//tensorflow:tensorflow.bzl", "pybind_extension", "tf_custom_op_py_library")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py")

 package(
     default_visibility = ["//tensorflow:internal"],
@@ -86,6 +87,39 @@ cc_binary(
     ],
 )

+cc_library(
+    name = "double_op_and_kernels",
+    srcs = ["double_op.cc"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+    alwayslink = 1,
+)
+
+tf_custom_op_library(
+    name = "_double_op.so",
+    srcs = ["double_op.cc"],
+)
+
+tf_gen_op_wrapper_py(
+    name = "gen_double_op_wrapper",
+    out = "double_op_wrapper.py",
+    deps = [":double_op_and_kernels"],
+)
+
+tf_custom_op_py_library(
+    name = "double_op",
+    srcs = ["double_op.py"],
+    dso = [":_double_op.so"],
+    kernels = [":double_op_and_kernels"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":gen_double_op_wrapper",
+        "//tensorflow/python:framework_for_generated_wrappers",
+    ],
+)
+
 cc_library(
     name = "test_registerer",
     srcs = ["test_registerer.cc"],
tensorflow/lite/python/testdata/double_op.cc (new file, 53 lines)
@@ -0,0 +1,53 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Double")
+    .Input("input: int32")
+    .Output("doubled: int32")
+    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+      c->set_output(0, c->input(0));
+      return Status::OK();
+    });
+
+class DoubleOp : public OpKernel {
+ public:
+  explicit DoubleOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    // Grab the input tensor.
+    const Tensor& input_tensor = context->input(0);
+    auto input_flat = input_tensor.flat<int32>();
+
+    // Create an output tensor.
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
+                                                     &output_tensor));
+    auto output_flat = output_tensor->flat<int32>();
+
+    // Double every element of the input tensor.
+    const int N = input_flat.size();
+    for (int i = 0; i < N; i++) {
+      output_flat(i) = 2 * input_flat(i);
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("Double").Device(DEVICE_CPU), DoubleOp);
+}  // namespace tensorflow
tensorflow/lite/python/testdata/double_op.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Double op is a user-defined op for testing purposes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.lite.python.testdata import double_op_wrapper
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import load_library
+from tensorflow.python.platform import resource_loader
+
+_double_op = load_library.load_op_library(
+    resource_loader.get_path_to_datafile('_double_op.so'))
+
+
+def double(input_tensor):
+  """Double op applies element-wise double to input data."""
+  if input_tensor.dtype != dtypes.int32:
+    raise ValueError('Double op only accepts int32 values.')
+  return double_op_wrapper.double(input_tensor)
@@ -192,6 +192,16 @@ def _convert_tf1_model(flags):
                        "{0}".format(",".join(ops_set_options)))
       converter.target_spec.supported_ops.add(lite.OpsSet(option))

+  if flags.experimental_select_user_tf_ops:
+    if lite.OpsSet.SELECT_TF_OPS not in converter.target_spec.supported_ops:
+      raise ValueError("--experimental_select_user_tf_ops can only be set if "
+                       "--target_ops contains SELECT_TF_OPS.")
+    user_op_set = set()
+    for op_name in six.ensure_str(
+        flags.experimental_select_user_tf_ops).split(","):
+      user_op_set.add(op_name)
+    converter.target_spec.experimental_select_user_tf_ops = list(user_op_set)
+
   if flags.post_training_quantize:
     converter.optimizations = [lite.Optimize.DEFAULT]
     if converter.inference_type != dtypes.float32:
@@ -313,6 +323,10 @@ def _check_tf1_flags(flags, unparsed):
                      "--experimental_new_converter")
   if flags.custom_opdefs and not flags.allow_custom_ops:
     raise ValueError("--custom_opdefs must be used with --allow_custom_ops")
+  if (flags.experimental_select_user_tf_ops and
+      not flags.experimental_new_converter):
+    raise ValueError("--experimental_select_user_tf_ops must be used with "
+                     "--experimental_new_converter")


 def _check_tf2_flags(flags):
@@ -491,6 +505,11 @@ def _get_tf1_flags(parser):
           "indicating which converter to use. Options: {0}. One or more "
          "option may be specified. (default set([OpsSet.TFLITE_BUILTINS]))"
          "".format(",".join(lite.OpsSet.get_options()))))
+  parser.add_argument(
+      "--experimental_select_user_tf_ops",
+      type=str,
+      help=("Experimental flag, subject to change. Comma-separated list of "
+            "user-defined TensorFlow operators required in the runtime."))

   # Logging flags.
   parser.add_argument(
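Combining the new CLI flag with the existing ones, an end-to-end invocation
would look like the following sketch (paths are illustrative; the flag
combination mirrors the tests later in this change):

    tflite_convert \
      --output_file=/tmp/model.tflite \
      --saved_model_dir=/tmp/saved_model \
      --experimental_new_converter \
      --target_ops=TFLITE_BUILTINS,SELECT_TF_OPS \
      --experimental_select_user_tf_ops=CustomAdd2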
@@ -24,6 +24,7 @@ import numpy as np
 from tensorflow import keras

 from tensorflow.core.framework import graph_pb2
+from tensorflow.lite.python import test_util as tflite_test_util
 from tensorflow.lite.python import tflite_convert
 from tensorflow.lite.python.convert import register_custom_opdefs
 from tensorflow.python import tf2
@@ -50,7 +51,10 @@ class TestModels(test_util.TensorFlowTestCase):
   def _getFilepath(self, filename):
     return os.path.join(self.get_temp_dir(), filename)

-  def _run(self, flags_str, should_succeed):
+  def _run(self,
+           flags_str,
+           should_succeed,
+           expected_ops_in_converted_model=None):
     output_file = os.path.join(self.get_temp_dir(), 'model.tflite')
     tflite_bin = resource_loader.get_path_to_datafile('tflite_convert')
     cmdline = '{0} --output_file={1} {2}'.format(tflite_bin, output_file,
@@ -61,6 +65,10 @@ class TestModels(test_util.TensorFlowTestCase):
       with gfile.Open(output_file, 'rb') as model_file:
         content = model_file.read()
       self.assertEqual(content is not None, should_succeed)
+      if expected_ops_in_converted_model:
+        op_set = tflite_test_util.get_ops_list(content)
+        for opname in expected_ops_in_converted_model:
+          self.assertIn(opname, op_set)
       os.remove(output_file)
     else:
       self.assertFalse(should_succeed)
@@ -83,10 +91,14 @@ class TestModels(test_util.TensorFlowTestCase):

 class TfLiteConvertV1Test(TestModels):

-  def _run(self, flags_str, should_succeed):
+  def _run(self,
+           flags_str,
+           should_succeed,
+           expected_ops_in_converted_model=None):
     if tf2.enabled():
       flags_str += ' --enable_v1_converter'
-    super(TfLiteConvertV1Test, self)._run(flags_str, should_succeed)
+    super(TfLiteConvertV1Test, self)._run(flags_str, should_succeed,
+                                          expected_ops_in_converted_model)

   def testFrozenGraphDef(self):
     with ops.Graph().as_default():
@@ -186,8 +198,8 @@ class TfLiteConvertV1Test(TestModels):
     # Define converter flags
     flags_str = ('--std_dev_values=128,128 --mean_values=128,128 '
                  '--graph_def_file={0} --input_arrays={1} '
-                 '--output_arrays={2}'.format(
-                     graph_def_file, 'inputA,inputB', 'output'))
+                 '--output_arrays={2}'.format(graph_def_file, 'inputA,inputB',
+                                              'output'))

     # Set inference_type UINT8 and (default) inference_input_type UINT8
     flags_str_1 = flags_str + ' --inference_type=UINT8'
@@ -213,9 +225,9 @@ class TfLiteConvertV1Test(TestModels):
     flags_str = '--saved_model_dir={}'.format(saved_model_dir)
     self._run(flags_str, should_succeed=True)

-  def _createSavedModelWithCustomOp(self):
+  def _createSavedModelWithCustomOp(self, opname='CustomAdd'):
     custom_opdefs_str = (
-        'name: \'CustomAdd\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
+        'name: \'' + opname + '\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
         'input_arg: {name: \'Input2\' type: DT_FLOAT} output_arg: {name: '
         '\'Output\' type: DT_FLOAT}')

@@ -231,10 +243,10 @@ class TfLiteConvertV1Test(TestModels):

       new_graph.CopyFrom(sess.graph_def)

-      # Rename Add op name to CustomAdd.
+      # Rename the Add op to `opname`.
       for node in new_graph.node:
         if node.op.startswith('Add'):
-          node.op = 'CustomAdd'
+          node.op = opname
           del node.attr['T']

       # Register custom op defs to import modified graph def.
@@ -264,7 +276,26 @@ class TfLiteConvertV1Test(TestModels):
         '--saved_model_dir={0} --custom_opdefs="{1}" --allow_custom_ops '
         '--experimental_new_converter'.format(saved_model_dir,
                                               custom_opdefs_str))
-    self._run(flags_str, should_succeed=True)
+    self._run(
+        flags_str,
+        should_succeed=True,
+        expected_ops_in_converted_model=['CustomAdd'])
+
+  def testSavedModelWithFlex(self):
+    saved_model_dir, custom_opdefs_str = self._createSavedModelWithCustomOp(
+        opname='CustomAdd2')
+
+    # Valid conversion. OpDef already registered.
+    flags_str = ('--saved_model_dir={0} --allow_custom_ops '
+                 '--custom_opdefs="{1}" '
+                 '--experimental_new_converter '
+                 '--experimental_select_user_tf_ops=CustomAdd2 '
+                 '--target_ops=TFLITE_BUILTINS,SELECT_TF_OPS'.format(
+                     saved_model_dir, custom_opdefs_str))
+    self._run(
+        flags_str,
+        should_succeed=True,
+        expected_ops_in_converted_model=['FlexCustomAdd2'])

   def testSavedModelWithInvalidCustomOpdefsFlag(self):
     saved_model_dir, _ = self._createSavedModelWithCustomOp()
@@ -393,7 +424,30 @@ class TfLiteConvertV1Test(TestModels):
     # Valid conversion.
     flags_str_final = ('{} --allow_custom_ops '
                        '--experimental_new_converter').format(flags_str)
-    self._run(flags_str_final, should_succeed=True)
+    self._run(
+        flags_str_final,
+        should_succeed=True,
+        expected_ops_in_converted_model=['TFLite_Detection_PostProcess'])
+
+  def testObjectDetectionMLIRWithFlex(self):
+    """Tests object detection model through MLIR converter."""
+    self._initObjectDetectionArgs()
+
+    flags_str = ('--graph_def_file={0} --input_arrays={1} '
+                 '--output_arrays={2} --input_shapes={3}'.format(
+                     self._graph_def_file, self._input_arrays,
+                     self._output_arrays, self._input_shapes))
+
+    # Valid conversion.
+    flags_str_final = (
+        '{} --allow_custom_ops '
+        '--experimental_new_converter '
+        '--experimental_select_user_tf_ops=TFLite_Detection_PostProcess '
+        '--target_ops=TFLITE_BUILTINS,SELECT_TF_OPS').format(flags_str)
+    self._run(
+        flags_str_final,
+        should_succeed=True,
+        expected_ops_in_converted_model=['FlexTFLite_Detection_PostProcess'])


 class TfLiteConvertV2Test(TestModels):
@@ -38,7 +38,7 @@ enum FileFormat {
 // of as properties of models, instead describing how models are to be
 // processed in the context of the present tooling job.
 //
-// Next ID to use: 33.
+// Next ID to use: 34.
 message TocoFlags {
   // Input file format
   optional FileFormat input_format = 1;
@@ -226,4 +226,8 @@ message TocoFlags {
   // String representing the custom ops OpDefs that are included in the
   // GraphDef.
   repeated string custom_opdefs = 32;
+
+  // Names of user-defined TensorFlow ops required in the TensorFlow Lite
+  // runtime. These ops will be supported as select TensorFlow ops.
+  repeated string select_user_tf_ops = 33;
 }
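For reference, the new proto field is populated by build_toco_flags above; a
minimal sketch of what the resulting flags proto carries (import path assumed
from the TF source tree, op names illustrative):

    from tensorflow.lite.toco import toco_flags_pb2

    toco = toco_flags_pb2.TocoFlags()
    toco.select_user_tf_ops.extend(['Double', 'CustomAdd2'])
    assert list(toco.select_user_tf_ops) == ['Double', 'CustomAdd2']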
@@ -4,6 +4,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'supported_ops\', \'supported_types\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'supported_ops\', \'supported_types\', \'experimental_select_user_tf_ops\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
 }

@@ -4,6 +4,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'supported_ops\', \'supported_types\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'supported_ops\', \'supported_types\', \'experimental_select_user_tf_ops\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
 }