Add flag --custom_opdefs to tflite_convert.

PiperOrigin-RevId: 280556264
Change-Id: Id9963930b26d55039c73993597ea0ab8ccc07d73
This commit is contained in:
Nupur Garg 2019-11-14 17:57:52 -08:00 committed by TensorFlower Gardener
parent f931708171
commit fab30f6349
9 changed files with 119 additions and 8 deletions

View File

@ -141,10 +141,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
return Status::OK(); return Status::OK();
} }
Status RegisterCustomBuiltinOps() { Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
std::vector<std::string> extra_tf_opdefs;
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
for (const auto& tf_opdefs_string : extra_tf_opdefs) { for (const auto& tf_opdefs_string : extra_tf_opdefs) {
tensorflow::OpDef opdef; tensorflow::OpDef opdef;
if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string, if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
@ -253,7 +250,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
specs.upgrade_legacy = true; specs.upgrade_legacy = true;
WarningUnusedFlags(model_flags, toco_flags); WarningUnusedFlags(model_flags, toco_flags);
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps()); // Register any custom OpDefs.
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
toco_flags.custom_opdefs().end());
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context)); auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));

View File

@ -80,7 +80,10 @@ py_library(
py_test( py_test(
name = "tflite_convert_test", name = "tflite_convert_test",
srcs = ["tflite_convert_test.py"], srcs = ["tflite_convert_test.py"],
data = [":tflite_convert"], data = [
":tflite_convert.par",
"@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb",
],
python_version = "PY3", python_version = "PY3",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = [ tags = [

View File

@ -223,6 +223,7 @@ def build_toco_convert_protos(input_tensors,
drop_control_dependency=True, drop_control_dependency=True,
reorder_across_fake_quant=False, reorder_across_fake_quant=False,
allow_custom_ops=False, allow_custom_ops=False,
custom_opdefs=None,
change_concat_input_ranges=False, change_concat_input_ranges=False,
post_training_quantize=False, post_training_quantize=False,
quantize_to_float16=False, quantize_to_float16=False,
@ -273,6 +274,9 @@ def build_toco_convert_protos(input_tensors,
created for any op that is unknown. The developer will need to provide created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver. these to the TensorFlow Lite runtime with a custom resolver.
(default False) (default False)
custom_opdefs: List of strings representing custom ops OpDefs that are
included in the GraphDef. Required when using custom operations with the
MLIR-based converter. (default None)
change_concat_input_ranges: Boolean to change behavior of min/max ranges for change_concat_input_ranges: Boolean to change behavior of min/max ranges for
inputs and outputs of the concat operator for quantized models. Changes inputs and outputs of the concat operator for quantized models. Changes
the ranges of concat operator overlap when true. (default False) the ranges of concat operator overlap when true. (default False)
@ -320,6 +324,8 @@ def build_toco_convert_protos(input_tensors,
toco.drop_control_dependency = drop_control_dependency toco.drop_control_dependency = drop_control_dependency
toco.reorder_across_fake_quant = reorder_across_fake_quant toco.reorder_across_fake_quant = reorder_across_fake_quant
toco.allow_custom_ops = allow_custom_ops toco.allow_custom_ops = allow_custom_ops
if custom_opdefs:
toco.custom_opdefs.extend(custom_opdefs)
toco.post_training_quantize = post_training_quantize toco.post_training_quantize = post_training_quantize
toco.quantize_to_float16 = quantize_to_float16 toco.quantize_to_float16 = quantize_to_float16
if default_ranges_stats: if default_ranges_stats:

View File

@ -635,6 +635,7 @@ class TFLiteConverter(TFLiteConverterBase):
self.dump_graphviz_video = False self.dump_graphviz_video = False
self.conversion_summary_dir = None self.conversion_summary_dir = None
self._debug_info_func = experimental_debug_info_func self._debug_info_func = experimental_debug_info_func
self._custom_opdefs = None
# Attributes are used by models that cannot be loaded into TensorFlow. # Attributes are used by models that cannot be loaded into TensorFlow.
if not self._has_valid_tensors(): if not self._has_valid_tensors():
@ -1005,7 +1006,8 @@ class TFLiteConverter(TFLiteConverterBase):
"change_concat_input_ranges": self.change_concat_input_ranges, "change_concat_input_ranges": self.change_concat_input_ranges,
"dump_graphviz_dir": self.dump_graphviz_dir, "dump_graphviz_dir": self.dump_graphviz_dir,
"dump_graphviz_video": self.dump_graphviz_video, "dump_graphviz_video": self.dump_graphviz_video,
"conversion_summary_dir": self.conversion_summary_dir "conversion_summary_dir": self.conversion_summary_dir,
"custom_opdefs": self._custom_opdefs,
}) })
# Converts model. # Converts model.

View File

@ -174,6 +174,8 @@ def _convert_tf1_model(flags):
if flags.allow_custom_ops: if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops converter.allow_custom_ops = flags.allow_custom_ops
if flags.custom_opdefs:
converter._custom_opdefs = _parse_array(flags.custom_opdefs) # pylint: disable=protected-access
if flags.target_ops: if flags.target_ops:
ops_set_options = lite.OpsSet.get_options() ops_set_options = lite.OpsSet.get_options()
converter.target_spec.supported_ops = set() converter.target_spec.supported_ops = set()
@ -299,6 +301,12 @@ def _check_tf1_flags(flags, unparsed):
raise ValueError("--dump_graphviz_video must be used with " raise ValueError("--dump_graphviz_video must be used with "
"--dump_graphviz_dir") "--dump_graphviz_dir")
if flags.custom_opdefs and not flags.experimental_new_converter:
raise ValueError("--custom_opdefs must be used with "
"--experimental_new_converter")
if flags.custom_opdefs and not flags.allow_custom_ops:
raise ValueError("--custom_opdefs must be used with --allow_custom_ops")
def _check_tf2_flags(flags): def _check_tf2_flags(flags):
"""Checks the parsed and unparsed flags to ensure they are valid in 2.X. """Checks the parsed and unparsed flags to ensure they are valid in 2.X.
@ -462,6 +470,12 @@ def _get_tf1_flags(parser):
"created for any op that is unknown. The developer will need to " "created for any op that is unknown. The developer will need to "
"provide these to the TensorFlow Lite runtime with a custom " "provide these to the TensorFlow Lite runtime with a custom "
"resolver. (default False)")) "resolver. (default False)"))
parser.add_argument(
"--custom_opdefs",
type=str,
help=("String representing a list of custom ops OpDefs delineated with "
"commas that are included in the GraphDef. Required when using "
"custom operations with --experimental_new_converter."))
parser.add_argument( parser.add_argument(
"--target_ops", "--target_ops",
type=str, type=str,

View File

@ -171,6 +171,82 @@ class TfLiteConvertV1Test(TestModels):
num_items_conversion_summary = len(os.listdir(log_dir)) num_items_conversion_summary = len(os.listdir(log_dir))
self.assertEqual(num_items_conversion_summary, 0) self.assertEqual(num_items_conversion_summary, 0)
def _initObjectDetectionArgs(self):
# Initializes the arguments required for the object detection model.
# Looks for the model file which is saved in a different location internally
# and externally.
filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb')
if not os.path.exists(filename):
filename = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
'../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
if not os.path.exists(filename):
raise IOError("File '{0}' does not exist.".format(filename))
self._graph_def_file = filename
self._input_arrays = 'normalized_input_image_tensor'
self._output_arrays = (
'TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,'
'TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3')
self._input_shapes = '1,300,300,3'
def testObjectDetection(self):
"""Tests object detection model through TOCO."""
self._initObjectDetectionArgs()
flags_str = ('--graph_def_file={0} --input_arrays={1} '
'--output_arrays={2} --input_shapes={3} '
'--allow_custom_ops'.format(self._graph_def_file,
self._input_arrays,
self._output_arrays,
self._input_shapes))
self._run(flags_str, should_succeed=True)
def testObjectDetectionMLIR(self):
"""Tests object detection model through MLIR converter."""
self._initObjectDetectionArgs()
custom_opdefs_str = (
'name: \'TFLite_Detection_PostProcess\' '
'input_arg: { name: \'raw_outputs/box_encodings\' type: DT_FLOAT } '
'input_arg: { name: \'raw_outputs/class_predictions\' type: DT_FLOAT } '
'input_arg: { name: \'anchors\' type: DT_FLOAT } '
'output_arg: { name: \'TFLite_Detection_PostProcess\' type: DT_FLOAT } '
'output_arg: { name: \'TFLite_Detection_PostProcess:1\' '
'type: DT_FLOAT } '
'output_arg: { name: \'TFLite_Detection_PostProcess:2\' '
'type: DT_FLOAT } '
'output_arg: { name: \'TFLite_Detection_PostProcess:3\' '
'type: DT_FLOAT } '
'attr : { name: \'h_scale\' type: \'float\'} '
'attr : { name: \'max_classes_per_detection\' type: \'int\'} '
'attr : { name: \'max_detections\' type: \'int\'} '
'attr : { name: \'nms_iou_threshold\' type: \'float\'} '
'attr : { name: \'nms_score_threshold\' type: \'float\'} '
'attr : { name: \'num_classes\' type: \'int\'} '
'attr : { name: \'w_scale\' type: \'int\'} '
'attr : { name: \'x_scale\' type: \'int\'} '
'attr : { name: \'y_scale\' type: \'int\'}')
flags_str = ('--graph_def_file={0} --input_arrays={1} '
'--output_arrays={2} --input_shapes={3} '
'--custom_opdefs="{4}"'.format(self._graph_def_file,
self._input_arrays,
self._output_arrays,
self._input_shapes,
custom_opdefs_str))
# Ensure --experimental_new_converter.
flags_str_final = ('{} --allow_custom_ops').format(flags_str)
self._run(flags_str_final, should_succeed=False)
# Ensure --allow_custom_ops.
flags_str_final = ('{} --experimental_new_converter').format(flags_str)
self._run(flags_str_final, should_succeed=False)
# Valid conversion.
flags_str_final = ('{} --allow_custom_ops '
'--experimental_new_converter').format(flags_str)
self._run(flags_str_final, should_succeed=True)
class TfLiteConvertV2Test(TestModels): class TfLiteConvertV2Test(TestModels):

View File

@ -173,6 +173,7 @@ struct ParsedTocoFlags {
Arg<bool> reorder_across_fake_quant = Arg<bool>(false); Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
Arg<bool> allow_custom_ops = Arg<bool>(false); Arg<bool> allow_custom_ops = Arg<bool>(false);
Arg<bool> allow_dynamic_tensors = Arg<bool>(true); Arg<bool> allow_dynamic_tensors = Arg<bool>(true);
Arg<string> custom_opdefs;
Arg<bool> post_training_quantize = Arg<bool>(false); Arg<bool> post_training_quantize = Arg<bool>(false);
Arg<bool> quantize_to_float16 = Arg<bool>(false); Arg<bool> quantize_to_float16 = Arg<bool>(false);
// Deprecated flags // Deprecated flags

View File

@ -124,6 +124,10 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.allow_custom_ops.default_value(), parsed_flags.allow_custom_ops.default_value(),
"If true, allow TOCO to create TF Lite Custom operators for all the " "If true, allow TOCO to create TF Lite Custom operators for all the "
"unsupported TensorFlow ops."), "unsupported TensorFlow ops."),
Flag("custom_opdefs", parsed_flags.custom_opdefs.bind(),
parsed_flags.custom_opdefs.default_value(),
"List of strings representing custom ops OpDefs that are included "
"in the GraphDef."),
Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(), Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(),
parsed_flags.allow_dynamic_tensors.default_value(), parsed_flags.allow_dynamic_tensors.default_value(),
"Boolean flag indicating whether the converter should allow models " "Boolean flag indicating whether the converter should allow models "

View File

@ -38,7 +38,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be // of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job. // processed in the context of the present tooling job.
// //
// Next ID to use: 32. // Next ID to use: 33.
message TocoFlags { message TocoFlags {
// Input file format // Input file format
optional FileFormat input_format = 1; optional FileFormat input_format = 1;
@ -222,4 +222,8 @@ message TocoFlags {
// Full filepath of the folder to dump conversion logs. This includes a global // Full filepath of the folder to dump conversion logs. This includes a global
// view of the conversion process, and user can choose to submit those logs. // view of the conversion process, and user can choose to submit those logs.
optional string conversion_summary_dir = 31; optional string conversion_summary_dir = 31;
// String representing the custom ops OpDefs that are included in the
// GraphDef.
repeated string custom_opdefs = 32;
} }