Add flag --custom_opdefs to tflite_convert.
PiperOrigin-RevId: 280556264
Change-Id: Id9963930b26d55039c73993597ea0ab8ccc07d73
commit fab30f6349
parent f931708171
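
The hunks below wire the new flag through the CLI, the TF 1.x Python converter, and the TOCO/MLIR conversion path. For orientation only, here is a minimal usage sketch that is not part of the diff: it assumes the TF 1.x tf.lite.TFLiteConverter.from_frozen_graph API and the experimental_new_converter attribute that --experimental_new_converter maps to; the model path is a placeholder, the OpDef text is abbreviated, and _custom_opdefs is the protected attribute that tflite_convert populates from --custom_opdefs.

# Sketch only: mirrors what
#   tflite_convert --graph_def_file=... --custom_opdefs="..." \
#       --allow_custom_ops --experimental_new_converter
# does through the Python API. Paths and the OpDef text are placeholders.
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file='tflite_graph.pb',  # placeholder frozen graph
    input_arrays=['normalized_input_image_tensor'],
    output_arrays=['TFLite_Detection_PostProcess'])
converter.allow_custom_ops = True             # --custom_opdefs requires this
converter.experimental_new_converter = True   # ...and the MLIR-based converter
# tflite_convert sets this protected attribute from --custom_opdefs (see
# _convert_tf1_model below); each entry is one OpDef in text-proto form.
converter._custom_opdefs = [
    "name: 'TFLite_Detection_PostProcess' ..."  # abbreviated; full text in the test below
]
tflite_model = converter.convert()
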
@@ -141,10 +141,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
   return Status::OK();
 }
 
-Status RegisterCustomBuiltinOps() {
-  std::vector<std::string> extra_tf_opdefs;
-  extra_tf_opdefs.push_back(kDetectionPostProcessOp);
-
+Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
   for (const auto& tf_opdefs_string : extra_tf_opdefs) {
     tensorflow::OpDef opdef;
     if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,

@@ -253,7 +250,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
   specs.upgrade_legacy = true;
   WarningUnusedFlags(model_flags, toco_flags);
 
-  TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps());
+  // Register any custom OpDefs.
+  std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
+                                      toco_flags.custom_opdefs().end());
+  extra_tf_opdefs.push_back(kDetectionPostProcessOp);
+  TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
 
   TF_ASSIGN_OR_RETURN(
       auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));

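RegisterCustomBuiltinOps now receives the OpDef strings collected from toco_flags.custom_opdefs(), with kDetectionPostProcessOp still appended unconditionally, and parses each entry with TextFormat::ParseFromString before registering it. For illustration only, a rough Python equivalent of that parse step (standard protobuf text_format plus TensorFlow's generated op_def_pb2; the OpDef text is a trimmed-down example, not the real post-process op):

from google.protobuf import text_format
from tensorflow.core.framework import op_def_pb2

opdef_text = ("name: 'TFLite_Detection_PostProcess' "
              "input_arg: { name: 'anchors' type: DT_FLOAT } "
              "output_arg: { name: 'detections' type: DT_FLOAT } "
              "attr: { name: 'num_classes' type: 'int' }")

opdef = op_def_pb2.OpDef()
text_format.Parse(opdef_text, opdef)  # raises ParseError on malformed text
print(opdef.name)  # -> TFLite_Detection_PostProcess
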
@@ -80,7 +80,10 @@ py_library(
 py_test(
     name = "tflite_convert_test",
     srcs = ["tflite_convert_test.py"],
-    data = [":tflite_convert"],
+    data = [
+        ":tflite_convert.par",
+        "@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb",
+    ],
     python_version = "PY3",
     srcs_version = "PY2AND3",
     tags = [

@@ -223,6 +223,7 @@ def build_toco_convert_protos(input_tensors,
                               drop_control_dependency=True,
                               reorder_across_fake_quant=False,
                               allow_custom_ops=False,
+                              custom_opdefs=None,
                               change_concat_input_ranges=False,
                               post_training_quantize=False,
                               quantize_to_float16=False,

@@ -273,6 +274,9 @@ def build_toco_convert_protos(input_tensors,
       created for any op that is unknown. The developer will need to provide
       these to the TensorFlow Lite runtime with a custom resolver.
       (default False)
+    custom_opdefs: List of strings representing custom ops OpDefs that are
+      included in the GraphDef. Required when using custom operations with the
+      MLIR-based converter. (default None)
     change_concat_input_ranges: Boolean to change behavior of min/max ranges for
       inputs and outputs of the concat operator for quantized models. Changes
       the ranges of concat operator overlap when true. (default False)

@@ -320,6 +324,8 @@ def build_toco_convert_protos(input_tensors,
   toco.drop_control_dependency = drop_control_dependency
   toco.reorder_across_fake_quant = reorder_across_fake_quant
   toco.allow_custom_ops = allow_custom_ops
+  if custom_opdefs:
+    toco.custom_opdefs.extend(custom_opdefs)
   toco.post_training_quantize = post_training_quantize
   toco.quantize_to_float16 = quantize_to_float16
   if default_ranges_stats:

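The new custom_opdefs keyword simply forwards into the repeated custom_opdefs field added to TocoFlags at the bottom of this change (field 32). A small sketch of that plumbing, assuming the generated proto module is importable as tensorflow.lite.toco.toco_flags_pb2, the path convert.py itself uses:

from tensorflow.lite.toco import toco_flags_pb2

toco = toco_flags_pb2.TocoFlags()
toco.allow_custom_ops = True
# Each entry is one OpDef in text-proto form, exactly as passed via
# build_toco_convert_protos(custom_opdefs=[...]).
toco.custom_opdefs.extend([
    "name: 'TFLite_Detection_PostProcess' ...",  # abbreviated placeholder
])
print(len(toco.custom_opdefs))  # -> 1
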
@@ -635,6 +635,7 @@ class TFLiteConverter(TFLiteConverterBase):
     self.dump_graphviz_video = False
     self.conversion_summary_dir = None
     self._debug_info_func = experimental_debug_info_func
+    self._custom_opdefs = None
 
     # Attributes are used by models that cannot be loaded into TensorFlow.
     if not self._has_valid_tensors():

@@ -1005,7 +1006,8 @@ class TFLiteConverter(TFLiteConverterBase):
         "change_concat_input_ranges": self.change_concat_input_ranges,
         "dump_graphviz_dir": self.dump_graphviz_dir,
         "dump_graphviz_video": self.dump_graphviz_video,
-        "conversion_summary_dir": self.conversion_summary_dir
+        "conversion_summary_dir": self.conversion_summary_dir,
+        "custom_opdefs": self._custom_opdefs,
     })
 
     # Converts model.

@@ -174,6 +174,8 @@ def _convert_tf1_model(flags):
 
   if flags.allow_custom_ops:
     converter.allow_custom_ops = flags.allow_custom_ops
+  if flags.custom_opdefs:
+    converter._custom_opdefs = _parse_array(flags.custom_opdefs)  # pylint: disable=protected-access
   if flags.target_ops:
     ops_set_options = lite.OpsSet.get_options()
     converter.target_spec.supported_ops = set()

@@ -299,6 +301,12 @@ def _check_tf1_flags(flags, unparsed):
     raise ValueError("--dump_graphviz_video must be used with "
                      "--dump_graphviz_dir")
 
+  if flags.custom_opdefs and not flags.experimental_new_converter:
+    raise ValueError("--custom_opdefs must be used with "
+                     "--experimental_new_converter")
+  if flags.custom_opdefs and not flags.allow_custom_ops:
+    raise ValueError("--custom_opdefs must be used with --allow_custom_ops")
+
 
 def _check_tf2_flags(flags):
   """Checks the parsed and unparsed flags to ensure they are valid in 2.X.

@@ -462,6 +470,12 @@ def _get_tf1_flags(parser):
             "created for any op that is unknown. The developer will need to "
             "provide these to the TensorFlow Lite runtime with a custom "
             "resolver. (default False)"))
+  parser.add_argument(
+      "--custom_opdefs",
+      type=str,
+      help=("String representing a list of custom ops OpDefs delineated with "
+            "commas that are included in the GraphDef. Required when using "
+            "custom operations with --experimental_new_converter."))
   parser.add_argument(
       "--target_ops",
      type=str,

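As the help string notes, --custom_opdefs takes one comma-delimited string; _convert_tf1_model (earlier hunk) splits it with the existing _parse_array helper before handing the list to the converter. A hedged sketch of that split, with split_custom_opdefs standing in for _parse_array, whose exact behavior is not shown in this diff:

def split_custom_opdefs(flag_value):
  """Splits a comma-delimited --custom_opdefs string into OpDef strings."""
  return [piece.strip() for piece in flag_value.split(',') if piece.strip()]

example = ("name: 'OpA' output_arg: { name: 'out' type: DT_FLOAT },"
           "name: 'OpB' output_arg: { name: 'out' type: DT_FLOAT }")
print(split_custom_opdefs(example))  # -> two OpDef text strings
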
@@ -171,6 +171,82 @@ class TfLiteConvertV1Test(TestModels):
     num_items_conversion_summary = len(os.listdir(log_dir))
     self.assertEqual(num_items_conversion_summary, 0)
 
+  def _initObjectDetectionArgs(self):
+    # Initializes the arguments required for the object detection model.
+    # Looks for the model file which is saved in a different location internally
+    # and externally.
+    filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb')
+    if not os.path.exists(filename):
+      filename = os.path.join(
+          resource_loader.get_root_dir_with_all_resources(),
+          '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
+      if not os.path.exists(filename):
+        raise IOError("File '{0}' does not exist.".format(filename))
+
+    self._graph_def_file = filename
+    self._input_arrays = 'normalized_input_image_tensor'
+    self._output_arrays = (
+        'TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,'
+        'TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3')
+    self._input_shapes = '1,300,300,3'
+
+  def testObjectDetection(self):
+    """Tests object detection model through TOCO."""
+    self._initObjectDetectionArgs()
+    flags_str = ('--graph_def_file={0} --input_arrays={1} '
+                 '--output_arrays={2} --input_shapes={3} '
+                 '--allow_custom_ops'.format(self._graph_def_file,
+                                             self._input_arrays,
+                                             self._output_arrays,
+                                             self._input_shapes))
+    self._run(flags_str, should_succeed=True)
+
+  def testObjectDetectionMLIR(self):
+    """Tests object detection model through MLIR converter."""
+    self._initObjectDetectionArgs()
+    custom_opdefs_str = (
+        'name: \'TFLite_Detection_PostProcess\' '
+        'input_arg: { name: \'raw_outputs/box_encodings\' type: DT_FLOAT } '
+        'input_arg: { name: \'raw_outputs/class_predictions\' type: DT_FLOAT } '
+        'input_arg: { name: \'anchors\' type: DT_FLOAT } '
+        'output_arg: { name: \'TFLite_Detection_PostProcess\' type: DT_FLOAT } '
+        'output_arg: { name: \'TFLite_Detection_PostProcess:1\' '
+        'type: DT_FLOAT } '
+        'output_arg: { name: \'TFLite_Detection_PostProcess:2\' '
+        'type: DT_FLOAT } '
+        'output_arg: { name: \'TFLite_Detection_PostProcess:3\' '
+        'type: DT_FLOAT } '
+        'attr : { name: \'h_scale\' type: \'float\'} '
+        'attr : { name: \'max_classes_per_detection\' type: \'int\'} '
+        'attr : { name: \'max_detections\' type: \'int\'} '
+        'attr : { name: \'nms_iou_threshold\' type: \'float\'} '
+        'attr : { name: \'nms_score_threshold\' type: \'float\'} '
+        'attr : { name: \'num_classes\' type: \'int\'} '
+        'attr : { name: \'w_scale\' type: \'int\'} '
+        'attr : { name: \'x_scale\' type: \'int\'} '
+        'attr : { name: \'y_scale\' type: \'int\'}')
+
+    flags_str = ('--graph_def_file={0} --input_arrays={1} '
+                 '--output_arrays={2} --input_shapes={3} '
+                 '--custom_opdefs="{4}"'.format(self._graph_def_file,
+                                                self._input_arrays,
+                                                self._output_arrays,
+                                                self._input_shapes,
+                                                custom_opdefs_str))
+
+    # Ensure --experimental_new_converter.
+    flags_str_final = ('{} --allow_custom_ops').format(flags_str)
+    self._run(flags_str_final, should_succeed=False)
+
+    # Ensure --allow_custom_ops.
+    flags_str_final = ('{} --experimental_new_converter').format(flags_str)
+    self._run(flags_str_final, should_succeed=False)
+
+    # Valid conversion.
+    flags_str_final = ('{} --allow_custom_ops '
+                       '--experimental_new_converter').format(flags_str)
+    self._run(flags_str_final, should_succeed=True)
+
+
 class TfLiteConvertV2Test(TestModels):
 

@@ -173,6 +173,7 @@ struct ParsedTocoFlags {
   Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
   Arg<bool> allow_custom_ops = Arg<bool>(false);
   Arg<bool> allow_dynamic_tensors = Arg<bool>(true);
+  Arg<string> custom_opdefs;
   Arg<bool> post_training_quantize = Arg<bool>(false);
   Arg<bool> quantize_to_float16 = Arg<bool>(false);
   // Deprecated flags

@@ -124,6 +124,10 @@ bool ParseTocoFlagsFromCommandLineFlags(
            parsed_flags.allow_custom_ops.default_value(),
            "If true, allow TOCO to create TF Lite Custom operators for all the "
            "unsupported TensorFlow ops."),
+      Flag("custom_opdefs", parsed_flags.custom_opdefs.bind(),
+           parsed_flags.custom_opdefs.default_value(),
+           "List of strings representing custom ops OpDefs that are included "
+           "in the GraphDef."),
       Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(),
            parsed_flags.allow_dynamic_tensors.default_value(),
            "Boolean flag indicating whether the converter should allow models "

@@ -38,7 +38,7 @@ enum FileFormat {
 // of as properties of models, instead describing how models are to be
 // processed in the context of the present tooling job.
 //
-// Next ID to use: 32.
+// Next ID to use: 33.
 message TocoFlags {
   // Input file format
   optional FileFormat input_format = 1;

@@ -222,4 +222,8 @@ message TocoFlags {
   // Full filepath of the folder to dump conversion logs. This includes a global
   // view of the conversion process, and user can choose to submit those logs.
   optional string conversion_summary_dir = 31;
+
+  // String representing the custom ops OpDefs that are included in the
+  // GraphDef.
+  repeated string custom_opdefs = 32;
 }