Add flag --custom_opdefs to tflite_convert.

PiperOrigin-RevId: 280556264
Change-Id: Id9963930b26d55039c73993597ea0ab8ccc07d73
This commit is contained in:
Nupur Garg 2019-11-14 17:57:52 -08:00 committed by TensorFlower Gardener
parent f931708171
commit fab30f6349
9 changed files with 119 additions and 8 deletions

View File

@ -141,10 +141,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
return Status::OK();
}
Status RegisterCustomBuiltinOps() {
std::vector<std::string> extra_tf_opdefs;
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
for (const auto& tf_opdefs_string : extra_tf_opdefs) {
tensorflow::OpDef opdef;
if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
@ -253,7 +250,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
specs.upgrade_legacy = true;
WarningUnusedFlags(model_flags, toco_flags);
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps());
// Register any custom OpDefs.
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
toco_flags.custom_opdefs().end());
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
TF_ASSIGN_OR_RETURN(
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));

View File

@ -80,7 +80,10 @@ py_library(
py_test(
name = "tflite_convert_test",
srcs = ["tflite_convert_test.py"],
data = [":tflite_convert"],
data = [
":tflite_convert.par",
"@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb",
],
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [

View File

@ -223,6 +223,7 @@ def build_toco_convert_protos(input_tensors,
drop_control_dependency=True,
reorder_across_fake_quant=False,
allow_custom_ops=False,
custom_opdefs=None,
change_concat_input_ranges=False,
post_training_quantize=False,
quantize_to_float16=False,
@ -273,6 +274,9 @@ def build_toco_convert_protos(input_tensors,
created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver.
(default False)
custom_opdefs: List of strings representing custom ops OpDefs that are
included in the GraphDef. Required when using custom operations with the
MLIR-based converter. (default None)
change_concat_input_ranges: Boolean to change behavior of min/max ranges for
inputs and outputs of the concat operator for quantized models. Changes
the ranges of concat operator overlap when true. (default False)
@ -320,6 +324,8 @@ def build_toco_convert_protos(input_tensors,
toco.drop_control_dependency = drop_control_dependency
toco.reorder_across_fake_quant = reorder_across_fake_quant
toco.allow_custom_ops = allow_custom_ops
if custom_opdefs:
toco.custom_opdefs.extend(custom_opdefs)
toco.post_training_quantize = post_training_quantize
toco.quantize_to_float16 = quantize_to_float16
if default_ranges_stats:

View File

@ -635,6 +635,7 @@ class TFLiteConverter(TFLiteConverterBase):
self.dump_graphviz_video = False
self.conversion_summary_dir = None
self._debug_info_func = experimental_debug_info_func
self._custom_opdefs = None
# Attributes are used by models that cannot be loaded into TensorFlow.
if not self._has_valid_tensors():
@ -1005,7 +1006,8 @@ class TFLiteConverter(TFLiteConverterBase):
"change_concat_input_ranges": self.change_concat_input_ranges,
"dump_graphviz_dir": self.dump_graphviz_dir,
"dump_graphviz_video": self.dump_graphviz_video,
"conversion_summary_dir": self.conversion_summary_dir
"conversion_summary_dir": self.conversion_summary_dir,
"custom_opdefs": self._custom_opdefs,
})
# Converts model.

View File

@ -174,6 +174,8 @@ def _convert_tf1_model(flags):
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
if flags.custom_opdefs:
converter._custom_opdefs = _parse_array(flags.custom_opdefs) # pylint: disable=protected-access
if flags.target_ops:
ops_set_options = lite.OpsSet.get_options()
converter.target_spec.supported_ops = set()
@ -299,6 +301,12 @@ def _check_tf1_flags(flags, unparsed):
raise ValueError("--dump_graphviz_video must be used with "
"--dump_graphviz_dir")
if flags.custom_opdefs and not flags.experimental_new_converter:
raise ValueError("--custom_opdefs must be used with "
"--experimental_new_converter")
if flags.custom_opdefs and not flags.allow_custom_ops:
raise ValueError("--custom_opdefs must be used with --allow_custom_ops")
def _check_tf2_flags(flags):
"""Checks the parsed and unparsed flags to ensure they are valid in 2.X.
@ -462,6 +470,12 @@ def _get_tf1_flags(parser):
"created for any op that is unknown. The developer will need to "
"provide these to the TensorFlow Lite runtime with a custom "
"resolver. (default False)"))
parser.add_argument(
"--custom_opdefs",
type=str,
help=("String representing a list of custom ops OpDefs delineated with "
"commas that are included in the GraphDef. Required when using "
"custom operations with --experimental_new_converter."))
parser.add_argument(
"--target_ops",
type=str,

View File

@ -171,6 +171,82 @@ class TfLiteConvertV1Test(TestModels):
num_items_conversion_summary = len(os.listdir(log_dir))
self.assertEqual(num_items_conversion_summary, 0)
def _initObjectDetectionArgs(self):
# Initializes the arguments required for the object detection model.
# Looks for the model file which is saved in a different location internally
# and externally.
filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb')
if not os.path.exists(filename):
filename = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
'../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
if not os.path.exists(filename):
raise IOError("File '{0}' does not exist.".format(filename))
self._graph_def_file = filename
self._input_arrays = 'normalized_input_image_tensor'
self._output_arrays = (
'TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,'
'TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3')
self._input_shapes = '1,300,300,3'
def testObjectDetection(self):
"""Tests object detection model through TOCO."""
self._initObjectDetectionArgs()
flags_str = ('--graph_def_file={0} --input_arrays={1} '
'--output_arrays={2} --input_shapes={3} '
'--allow_custom_ops'.format(self._graph_def_file,
self._input_arrays,
self._output_arrays,
self._input_shapes))
self._run(flags_str, should_succeed=True)
def testObjectDetectionMLIR(self):
"""Tests object detection model through MLIR converter."""
self._initObjectDetectionArgs()
custom_opdefs_str = (
'name: \'TFLite_Detection_PostProcess\' '
'input_arg: { name: \'raw_outputs/box_encodings\' type: DT_FLOAT } '
'input_arg: { name: \'raw_outputs/class_predictions\' type: DT_FLOAT } '
'input_arg: { name: \'anchors\' type: DT_FLOAT } '
'output_arg: { name: \'TFLite_Detection_PostProcess\' type: DT_FLOAT } '
'output_arg: { name: \'TFLite_Detection_PostProcess:1\' '
'type: DT_FLOAT } '
'output_arg: { name: \'TFLite_Detection_PostProcess:2\' '
'type: DT_FLOAT } '
'output_arg: { name: \'TFLite_Detection_PostProcess:3\' '
'type: DT_FLOAT } '
'attr : { name: \'h_scale\' type: \'float\'} '
'attr : { name: \'max_classes_per_detection\' type: \'int\'} '
'attr : { name: \'max_detections\' type: \'int\'} '
'attr : { name: \'nms_iou_threshold\' type: \'float\'} '
'attr : { name: \'nms_score_threshold\' type: \'float\'} '
'attr : { name: \'num_classes\' type: \'int\'} '
'attr : { name: \'w_scale\' type: \'int\'} '
'attr : { name: \'x_scale\' type: \'int\'} '
'attr : { name: \'y_scale\' type: \'int\'}')
flags_str = ('--graph_def_file={0} --input_arrays={1} '
'--output_arrays={2} --input_shapes={3} '
'--custom_opdefs="{4}"'.format(self._graph_def_file,
self._input_arrays,
self._output_arrays,
self._input_shapes,
custom_opdefs_str))
# Ensure --experimental_new_converter.
flags_str_final = ('{} --allow_custom_ops').format(flags_str)
self._run(flags_str_final, should_succeed=False)
# Ensure --allow_custom_ops.
flags_str_final = ('{} --experimental_new_converter').format(flags_str)
self._run(flags_str_final, should_succeed=False)
# Valid conversion.
flags_str_final = ('{} --allow_custom_ops '
'--experimental_new_converter').format(flags_str)
self._run(flags_str_final, should_succeed=True)
class TfLiteConvertV2Test(TestModels):

View File

@ -173,6 +173,7 @@ struct ParsedTocoFlags {
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
Arg<bool> allow_custom_ops = Arg<bool>(false);
Arg<bool> allow_dynamic_tensors = Arg<bool>(true);
Arg<string> custom_opdefs;
Arg<bool> post_training_quantize = Arg<bool>(false);
Arg<bool> quantize_to_float16 = Arg<bool>(false);
// Deprecated flags

View File

@ -124,6 +124,10 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.allow_custom_ops.default_value(),
"If true, allow TOCO to create TF Lite Custom operators for all the "
"unsupported TensorFlow ops."),
Flag("custom_opdefs", parsed_flags.custom_opdefs.bind(),
parsed_flags.custom_opdefs.default_value(),
"List of strings representing custom ops OpDefs that are included "
"in the GraphDef."),
Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(),
parsed_flags.allow_dynamic_tensors.default_value(),
"Boolean flag indicating whether the converter should allow models "

View File

@ -38,7 +38,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
// Next ID to use: 32.
// Next ID to use: 33.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@ -222,4 +222,8 @@ message TocoFlags {
// Full filepath of the folder to dump conversion logs. This includes a global
// view of the conversion process, and user can choose to submit those logs.
optional string conversion_summary_dir = 31;
// String representing the custom ops OpDefs that are included in the
// GraphDef.
repeated string custom_opdefs = 32;
}