diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
index 3be488a8784..20acbab51cd 100644
--- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
@@ -141,10 +141,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
   return Status::OK();
 }
 
-Status RegisterCustomBuiltinOps() {
-  std::vector<string> extra_tf_opdefs;
-  extra_tf_opdefs.push_back(kDetectionPostProcessOp);
-
+Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
   for (const auto& tf_opdefs_string : extra_tf_opdefs) {
     tensorflow::OpDef opdef;
     if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
@@ -253,7 +250,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
   specs.upgrade_legacy = true;
   WarningUnusedFlags(model_flags, toco_flags);
 
-  TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps());
+  // Register any custom OpDefs.
+  std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
+                                      toco_flags.custom_opdefs().end());
+  extra_tf_opdefs.push_back(kDetectionPostProcessOp);
+  TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
 
   TF_ASSIGN_OR_RETURN(
       auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD
index a7f4c3e4804..902f69b79aa 100644
--- a/tensorflow/lite/python/BUILD
+++ b/tensorflow/lite/python/BUILD
@@ -80,7 +80,10 @@ py_library(
 py_test(
     name = "tflite_convert_test",
     srcs = ["tflite_convert_test.py"],
-    data = [":tflite_convert"],
+    data = [
+        ":tflite_convert.par",
+        "@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb",
+    ],
     python_version = "PY3",
     srcs_version = "PY2AND3",
     tags = [
diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py
index 5f41a1142c2..f719f62e2d8 100644
--- a/tensorflow/lite/python/convert.py
+++ b/tensorflow/lite/python/convert.py
@@ -223,6 +223,7 @@ def build_toco_convert_protos(input_tensors,
                                drop_control_dependency=True,
                                reorder_across_fake_quant=False,
                                allow_custom_ops=False,
+                               custom_opdefs=None,
                                change_concat_input_ranges=False,
                                post_training_quantize=False,
                                quantize_to_float16=False,
@@ -273,6 +274,9 @@ def build_toco_convert_protos(input_tensors,
       created for any op that is unknown. The developer will need to
       provide these to the TensorFlow Lite runtime with a custom
       resolver. (default False)
+    custom_opdefs: List of strings representing custom ops OpDefs that are
+      included in the GraphDef. Required when using custom operations with the
+      MLIR-based converter. (default None)
     change_concat_input_ranges: Boolean to change behavior of min/max ranges
       for inputs and outputs of the concat operator for quantized models.
      Changes the ranges of concat operator overlap when true. (default False)
@@ -320,6 +324,8 @@ def build_toco_convert_protos(input_tensors,
   toco.drop_control_dependency = drop_control_dependency
   toco.reorder_across_fake_quant = reorder_across_fake_quant
   toco.allow_custom_ops = allow_custom_ops
+  if custom_opdefs:
+    toco.custom_opdefs.extend(custom_opdefs)
   toco.post_training_quantize = post_training_quantize
   toco.quantize_to_float16 = quantize_to_float16
   if default_ranges_stats:
diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index 2ae37d892b3..57a0f21e72e 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -635,6 +635,7 @@ class TFLiteConverter(TFLiteConverterBase):
     self.dump_graphviz_video = False
     self.conversion_summary_dir = None
     self._debug_info_func = experimental_debug_info_func
+    self._custom_opdefs = None
 
     # Attributes are used by models that cannot be loaded into TensorFlow.
     if not self._has_valid_tensors():
@@ -1005,7 +1006,8 @@ class TFLiteConverter(TFLiteConverterBase):
         "change_concat_input_ranges": self.change_concat_input_ranges,
         "dump_graphviz_dir": self.dump_graphviz_dir,
         "dump_graphviz_video": self.dump_graphviz_video,
-        "conversion_summary_dir": self.conversion_summary_dir
+        "conversion_summary_dir": self.conversion_summary_dir,
+        "custom_opdefs": self._custom_opdefs,
     })
 
     # Converts model.
diff --git a/tensorflow/lite/python/tflite_convert.py b/tensorflow/lite/python/tflite_convert.py
index 1b779fb4673..d66fe0bb5a9 100644
--- a/tensorflow/lite/python/tflite_convert.py
+++ b/tensorflow/lite/python/tflite_convert.py
@@ -174,6 +174,8 @@ def _convert_tf1_model(flags):
 
   if flags.allow_custom_ops:
     converter.allow_custom_ops = flags.allow_custom_ops
+  if flags.custom_opdefs:
+    converter._custom_opdefs = _parse_array(flags.custom_opdefs)  # pylint: disable=protected-access
   if flags.target_ops:
     ops_set_options = lite.OpsSet.get_options()
     converter.target_spec.supported_ops = set()
@@ -299,6 +301,12 @@ def _check_tf1_flags(flags, unparsed):
     raise ValueError("--dump_graphviz_video must be used with "
                      "--dump_graphviz_dir")
 
+  if flags.custom_opdefs and not flags.experimental_new_converter:
+    raise ValueError("--custom_opdefs must be used with "
+                     "--experimental_new_converter")
+  if flags.custom_opdefs and not flags.allow_custom_ops:
+    raise ValueError("--custom_opdefs must be used with --allow_custom_ops")
+
 
 def _check_tf2_flags(flags):
   """Checks the parsed and unparsed flags to ensure they are valid in 2.X.
@@ -462,6 +470,12 @@ def _get_tf1_flags(parser):
            "created for any op that is unknown. The developer will need to "
            "provide these to the TensorFlow Lite runtime with a custom "
            "resolver. (default False)"))
+  parser.add_argument(
+      "--custom_opdefs",
+      type=str,
+      help=("String representing a list of custom ops OpDefs delineated with "
+            "commas that are included in the GraphDef. Required when using "
+            "custom operations with --experimental_new_converter."))
   parser.add_argument(
       "--target_ops",
       type=str,
diff --git a/tensorflow/lite/python/tflite_convert_test.py b/tensorflow/lite/python/tflite_convert_test.py
index 018a40f3214..610f5c5e98b 100644
--- a/tensorflow/lite/python/tflite_convert_test.py
+++ b/tensorflow/lite/python/tflite_convert_test.py
@@ -171,6 +171,82 @@ class TfLiteConvertV1Test(TestModels):
     num_items_conversion_summary = len(os.listdir(log_dir))
     self.assertEqual(num_items_conversion_summary, 0)
 
+  def _initObjectDetectionArgs(self):
+    # Initializes the arguments required for the object detection model.
+    # Looks for the model file which is saved in a different location internally
+    # and externally.
+    filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb')
+    if not os.path.exists(filename):
+      filename = os.path.join(
+          resource_loader.get_root_dir_with_all_resources(),
+          '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
+      if not os.path.exists(filename):
+        raise IOError("File '{0}' does not exist.".format(filename))
+
+    self._graph_def_file = filename
+    self._input_arrays = 'normalized_input_image_tensor'
+    self._output_arrays = (
+        'TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,'
+        'TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3')
+    self._input_shapes = '1,300,300,3'
+
+  def testObjectDetection(self):
+    """Tests object detection model through TOCO."""
+    self._initObjectDetectionArgs()
+    flags_str = ('--graph_def_file={0} --input_arrays={1} '
+                 '--output_arrays={2} --input_shapes={3} '
+                 '--allow_custom_ops'.format(self._graph_def_file,
+                                             self._input_arrays,
+                                             self._output_arrays,
+                                             self._input_shapes))
+    self._run(flags_str, should_succeed=True)
+
+  def testObjectDetectionMLIR(self):
+    """Tests object detection model through MLIR converter."""
+    self._initObjectDetectionArgs()
+    custom_opdefs_str = (
+        'name: \'TFLite_Detection_PostProcess\' '
+        'input_arg: { name: \'raw_outputs/box_encodings\' type: DT_FLOAT } '
+        'input_arg: { name: \'raw_outputs/class_predictions\' type: DT_FLOAT } '
+        'input_arg: { name: \'anchors\' type: DT_FLOAT } '
+        'output_arg: { name: \'TFLite_Detection_PostProcess\' type: DT_FLOAT } '
+        'output_arg: { name: \'TFLite_Detection_PostProcess:1\' '
+        'type: DT_FLOAT } '
+        'output_arg: { name: \'TFLite_Detection_PostProcess:2\' '
+        'type: DT_FLOAT } '
+        'output_arg: { name: \'TFLite_Detection_PostProcess:3\' '
+        'type: DT_FLOAT } '
+        'attr : { name: \'h_scale\' type: \'float\'} '
+        'attr : { name: \'max_classes_per_detection\' type: \'int\'} '
+        'attr : { name: \'max_detections\' type: \'int\'} '
+        'attr : { name: \'nms_iou_threshold\' type: \'float\'} '
+        'attr : { name: \'nms_score_threshold\' type: \'float\'} '
+        'attr : { name: \'num_classes\' type: \'int\'} '
+        'attr : { name: \'w_scale\' type: \'int\'} '
+        'attr : { name: \'x_scale\' type: \'int\'} '
+        'attr : { name: \'y_scale\' type: \'int\'}')
+
+    flags_str = ('--graph_def_file={0} --input_arrays={1} '
+                 '--output_arrays={2} --input_shapes={3} '
+                 '--custom_opdefs="{4}"'.format(self._graph_def_file,
+                                                self._input_arrays,
+                                                self._output_arrays,
+                                                self._input_shapes,
+                                                custom_opdefs_str))
+
+    # Ensure --experimental_new_converter.
+    flags_str_final = ('{} --allow_custom_ops').format(flags_str)
+    self._run(flags_str_final, should_succeed=False)
+
+    # Ensure --allow_custom_ops.
+    flags_str_final = ('{} --experimental_new_converter').format(flags_str)
+    self._run(flags_str_final, should_succeed=False)
+
+    # Valid conversion.
+    flags_str_final = ('{} --allow_custom_ops '
+                       '--experimental_new_converter').format(flags_str)
+    self._run(flags_str_final, should_succeed=True)
+
 
 class TfLiteConvertV2Test(TestModels):
 
diff --git a/tensorflow/lite/toco/args.h b/tensorflow/lite/toco/args.h
index c30ec316128..20fa5ecc20c 100644
--- a/tensorflow/lite/toco/args.h
+++ b/tensorflow/lite/toco/args.h
@@ -173,6 +173,7 @@ struct ParsedTocoFlags {
   Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
   Arg<bool> allow_custom_ops = Arg<bool>(false);
   Arg<bool> allow_dynamic_tensors = Arg<bool>(true);
+  Arg<string> custom_opdefs;
   Arg<bool> post_training_quantize = Arg<bool>(false);
   Arg<bool> quantize_to_float16 = Arg<bool>(false);
   // Deprecated flags
diff --git a/tensorflow/lite/toco/toco_cmdline_flags.cc b/tensorflow/lite/toco/toco_cmdline_flags.cc
index d21f8d14112..25a286ee76d 100644
--- a/tensorflow/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/lite/toco/toco_cmdline_flags.cc
@@ -124,6 +124,10 @@ bool ParseTocoFlagsFromCommandLineFlags(
            parsed_flags.allow_custom_ops.default_value(),
            "If true, allow TOCO to create TF Lite Custom operators for all the "
            "unsupported TensorFlow ops."),
+      Flag("custom_opdefs", parsed_flags.custom_opdefs.bind(),
+           parsed_flags.custom_opdefs.default_value(),
+           "List of strings representing custom ops OpDefs that are included "
+           "in the GraphDef."),
       Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(),
            parsed_flags.allow_dynamic_tensors.default_value(),
            "Boolean flag indicating whether the converter should allow models "
diff --git a/tensorflow/lite/toco/toco_flags.proto b/tensorflow/lite/toco/toco_flags.proto
index 422f5129412..83f1d7bd79e 100644
--- a/tensorflow/lite/toco/toco_flags.proto
+++ b/tensorflow/lite/toco/toco_flags.proto
@@ -38,7 +38,7 @@ enum FileFormat {
 // of as properties of models, instead describing how models are to be
 // processed in the context of the present tooling job.
 //
-// Next ID to use: 32.
+// Next ID to use: 33.
 message TocoFlags {
   // Input file format
   optional FileFormat input_format = 1;
@@ -222,4 +222,8 @@ message TocoFlags {
   // Full filepath of the folder to dump conversion logs. This includes a global
   // view of the conversion process, and user can choose to submit those logs.
   optional string conversion_summary_dir = 31;
+
+  // String representing the custom ops OpDefs that are included in the
+  // GraphDef.
+  repeated string custom_opdefs = 32;
 }
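
Usage note (illustrative sketch, not part of the change itself): --custom_opdefs is only consumed by the MLIR-based converter path and must be combined with --allow_custom_ops and --experimental_new_converter, as enforced in _check_tf1_flags and exercised by testObjectDetectionMLIR above. A minimal command line, assuming a frozen graph named tflite_graph.pb, an output file name chosen here for illustration, and the same OpDef text proto as in the test (truncated below for brevity):

  tflite_convert \
    --graph_def_file=tflite_graph.pb \
    --output_file=detect.tflite \
    --input_arrays=normalized_input_image_tensor \
    --input_shapes=1,300,300,3 \
    --output_arrays=TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3 \
    --allow_custom_ops \
    --experimental_new_converter \
    --custom_opdefs="name: 'TFLite_Detection_PostProcess' input_arg: { name: 'raw_outputs/box_encodings' type: DT_FLOAT } ..."

Through the Python API the same strings reach the converter via the private TFLiteConverter._custom_opdefs attribute, which lite.py forwards into toco_flags.custom_opdefs as wired up in the patch above.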