Add flag --custom_opdefs to tflite_convert.
PiperOrigin-RevId: 280556264 Change-Id: Id9963930b26d55039c73993597ea0ab8ccc07d73
This commit is contained in:
parent
f931708171
commit
fab30f6349
@ -141,10 +141,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RegisterCustomBuiltinOps() {
|
||||
std::vector<std::string> extra_tf_opdefs;
|
||||
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
|
||||
|
||||
Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
|
||||
for (const auto& tf_opdefs_string : extra_tf_opdefs) {
|
||||
tensorflow::OpDef opdef;
|
||||
if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
|
||||
@ -253,7 +250,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
specs.upgrade_legacy = true;
|
||||
WarningUnusedFlags(model_flags, toco_flags);
|
||||
|
||||
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps());
|
||||
// Register any custom OpDefs.
|
||||
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
|
||||
toco_flags.custom_opdefs().end());
|
||||
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
|
||||
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
|
||||
|
@ -80,7 +80,10 @@ py_library(
|
||||
py_test(
|
||||
name = "tflite_convert_test",
|
||||
srcs = ["tflite_convert_test.py"],
|
||||
data = [":tflite_convert"],
|
||||
data = [
|
||||
":tflite_convert.par",
|
||||
"@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb",
|
||||
],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
|
@ -223,6 +223,7 @@ def build_toco_convert_protos(input_tensors,
|
||||
drop_control_dependency=True,
|
||||
reorder_across_fake_quant=False,
|
||||
allow_custom_ops=False,
|
||||
custom_opdefs=None,
|
||||
change_concat_input_ranges=False,
|
||||
post_training_quantize=False,
|
||||
quantize_to_float16=False,
|
||||
@ -273,6 +274,9 @@ def build_toco_convert_protos(input_tensors,
|
||||
created for any op that is unknown. The developer will need to provide
|
||||
these to the TensorFlow Lite runtime with a custom resolver.
|
||||
(default False)
|
||||
custom_opdefs: List of strings representing custom ops OpDefs that are
|
||||
included in the GraphDef. Required when using custom operations with the
|
||||
MLIR-based converter. (default None)
|
||||
change_concat_input_ranges: Boolean to change behavior of min/max ranges for
|
||||
inputs and outputs of the concat operator for quantized models. Changes
|
||||
the ranges of concat operator overlap when true. (default False)
|
||||
@ -320,6 +324,8 @@ def build_toco_convert_protos(input_tensors,
|
||||
toco.drop_control_dependency = drop_control_dependency
|
||||
toco.reorder_across_fake_quant = reorder_across_fake_quant
|
||||
toco.allow_custom_ops = allow_custom_ops
|
||||
if custom_opdefs:
|
||||
toco.custom_opdefs.extend(custom_opdefs)
|
||||
toco.post_training_quantize = post_training_quantize
|
||||
toco.quantize_to_float16 = quantize_to_float16
|
||||
if default_ranges_stats:
|
||||
|
@ -635,6 +635,7 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
self.dump_graphviz_video = False
|
||||
self.conversion_summary_dir = None
|
||||
self._debug_info_func = experimental_debug_info_func
|
||||
self._custom_opdefs = None
|
||||
|
||||
# Attributes are used by models that cannot be loaded into TensorFlow.
|
||||
if not self._has_valid_tensors():
|
||||
@ -1005,7 +1006,8 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
"change_concat_input_ranges": self.change_concat_input_ranges,
|
||||
"dump_graphviz_dir": self.dump_graphviz_dir,
|
||||
"dump_graphviz_video": self.dump_graphviz_video,
|
||||
"conversion_summary_dir": self.conversion_summary_dir
|
||||
"conversion_summary_dir": self.conversion_summary_dir,
|
||||
"custom_opdefs": self._custom_opdefs,
|
||||
})
|
||||
|
||||
# Converts model.
|
||||
|
@ -174,6 +174,8 @@ def _convert_tf1_model(flags):
|
||||
|
||||
if flags.allow_custom_ops:
|
||||
converter.allow_custom_ops = flags.allow_custom_ops
|
||||
if flags.custom_opdefs:
|
||||
converter._custom_opdefs = _parse_array(flags.custom_opdefs) # pylint: disable=protected-access
|
||||
if flags.target_ops:
|
||||
ops_set_options = lite.OpsSet.get_options()
|
||||
converter.target_spec.supported_ops = set()
|
||||
@ -299,6 +301,12 @@ def _check_tf1_flags(flags, unparsed):
|
||||
raise ValueError("--dump_graphviz_video must be used with "
|
||||
"--dump_graphviz_dir")
|
||||
|
||||
if flags.custom_opdefs and not flags.experimental_new_converter:
|
||||
raise ValueError("--custom_opdefs must be used with "
|
||||
"--experimental_new_converter")
|
||||
if flags.custom_opdefs and not flags.allow_custom_ops:
|
||||
raise ValueError("--custom_opdefs must be used with --allow_custom_ops")
|
||||
|
||||
|
||||
def _check_tf2_flags(flags):
|
||||
"""Checks the parsed and unparsed flags to ensure they are valid in 2.X.
|
||||
@ -462,6 +470,12 @@ def _get_tf1_flags(parser):
|
||||
"created for any op that is unknown. The developer will need to "
|
||||
"provide these to the TensorFlow Lite runtime with a custom "
|
||||
"resolver. (default False)"))
|
||||
parser.add_argument(
|
||||
"--custom_opdefs",
|
||||
type=str,
|
||||
help=("String representing a list of custom ops OpDefs delineated with "
|
||||
"commas that are included in the GraphDef. Required when using "
|
||||
"custom operations with --experimental_new_converter."))
|
||||
parser.add_argument(
|
||||
"--target_ops",
|
||||
type=str,
|
||||
|
@ -171,6 +171,82 @@ class TfLiteConvertV1Test(TestModels):
|
||||
num_items_conversion_summary = len(os.listdir(log_dir))
|
||||
self.assertEqual(num_items_conversion_summary, 0)
|
||||
|
||||
def _initObjectDetectionArgs(self):
|
||||
# Initializes the arguments required for the object detection model.
|
||||
# Looks for the model file which is saved in a different location internally
|
||||
# and externally.
|
||||
filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb')
|
||||
if not os.path.exists(filename):
|
||||
filename = os.path.join(
|
||||
resource_loader.get_root_dir_with_all_resources(),
|
||||
'../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
|
||||
if not os.path.exists(filename):
|
||||
raise IOError("File '{0}' does not exist.".format(filename))
|
||||
|
||||
self._graph_def_file = filename
|
||||
self._input_arrays = 'normalized_input_image_tensor'
|
||||
self._output_arrays = (
|
||||
'TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,'
|
||||
'TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3')
|
||||
self._input_shapes = '1,300,300,3'
|
||||
|
||||
def testObjectDetection(self):
|
||||
"""Tests object detection model through TOCO."""
|
||||
self._initObjectDetectionArgs()
|
||||
flags_str = ('--graph_def_file={0} --input_arrays={1} '
|
||||
'--output_arrays={2} --input_shapes={3} '
|
||||
'--allow_custom_ops'.format(self._graph_def_file,
|
||||
self._input_arrays,
|
||||
self._output_arrays,
|
||||
self._input_shapes))
|
||||
self._run(flags_str, should_succeed=True)
|
||||
|
||||
def testObjectDetectionMLIR(self):
|
||||
"""Tests object detection model through MLIR converter."""
|
||||
self._initObjectDetectionArgs()
|
||||
custom_opdefs_str = (
|
||||
'name: \'TFLite_Detection_PostProcess\' '
|
||||
'input_arg: { name: \'raw_outputs/box_encodings\' type: DT_FLOAT } '
|
||||
'input_arg: { name: \'raw_outputs/class_predictions\' type: DT_FLOAT } '
|
||||
'input_arg: { name: \'anchors\' type: DT_FLOAT } '
|
||||
'output_arg: { name: \'TFLite_Detection_PostProcess\' type: DT_FLOAT } '
|
||||
'output_arg: { name: \'TFLite_Detection_PostProcess:1\' '
|
||||
'type: DT_FLOAT } '
|
||||
'output_arg: { name: \'TFLite_Detection_PostProcess:2\' '
|
||||
'type: DT_FLOAT } '
|
||||
'output_arg: { name: \'TFLite_Detection_PostProcess:3\' '
|
||||
'type: DT_FLOAT } '
|
||||
'attr : { name: \'h_scale\' type: \'float\'} '
|
||||
'attr : { name: \'max_classes_per_detection\' type: \'int\'} '
|
||||
'attr : { name: \'max_detections\' type: \'int\'} '
|
||||
'attr : { name: \'nms_iou_threshold\' type: \'float\'} '
|
||||
'attr : { name: \'nms_score_threshold\' type: \'float\'} '
|
||||
'attr : { name: \'num_classes\' type: \'int\'} '
|
||||
'attr : { name: \'w_scale\' type: \'int\'} '
|
||||
'attr : { name: \'x_scale\' type: \'int\'} '
|
||||
'attr : { name: \'y_scale\' type: \'int\'}')
|
||||
|
||||
flags_str = ('--graph_def_file={0} --input_arrays={1} '
|
||||
'--output_arrays={2} --input_shapes={3} '
|
||||
'--custom_opdefs="{4}"'.format(self._graph_def_file,
|
||||
self._input_arrays,
|
||||
self._output_arrays,
|
||||
self._input_shapes,
|
||||
custom_opdefs_str))
|
||||
|
||||
# Ensure --experimental_new_converter.
|
||||
flags_str_final = ('{} --allow_custom_ops').format(flags_str)
|
||||
self._run(flags_str_final, should_succeed=False)
|
||||
|
||||
# Ensure --allow_custom_ops.
|
||||
flags_str_final = ('{} --experimental_new_converter').format(flags_str)
|
||||
self._run(flags_str_final, should_succeed=False)
|
||||
|
||||
# Valid conversion.
|
||||
flags_str_final = ('{} --allow_custom_ops '
|
||||
'--experimental_new_converter').format(flags_str)
|
||||
self._run(flags_str_final, should_succeed=True)
|
||||
|
||||
|
||||
class TfLiteConvertV2Test(TestModels):
|
||||
|
||||
|
@ -173,6 +173,7 @@ struct ParsedTocoFlags {
|
||||
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
|
||||
Arg<bool> allow_custom_ops = Arg<bool>(false);
|
||||
Arg<bool> allow_dynamic_tensors = Arg<bool>(true);
|
||||
Arg<string> custom_opdefs;
|
||||
Arg<bool> post_training_quantize = Arg<bool>(false);
|
||||
Arg<bool> quantize_to_float16 = Arg<bool>(false);
|
||||
// Deprecated flags
|
||||
|
@ -124,6 +124,10 @@ bool ParseTocoFlagsFromCommandLineFlags(
|
||||
parsed_flags.allow_custom_ops.default_value(),
|
||||
"If true, allow TOCO to create TF Lite Custom operators for all the "
|
||||
"unsupported TensorFlow ops."),
|
||||
Flag("custom_opdefs", parsed_flags.custom_opdefs.bind(),
|
||||
parsed_flags.custom_opdefs.default_value(),
|
||||
"List of strings representing custom ops OpDefs that are included "
|
||||
"in the GraphDef."),
|
||||
Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(),
|
||||
parsed_flags.allow_dynamic_tensors.default_value(),
|
||||
"Boolean flag indicating whether the converter should allow models "
|
||||
|
@ -38,7 +38,7 @@ enum FileFormat {
|
||||
// of as properties of models, instead describing how models are to be
|
||||
// processed in the context of the present tooling job.
|
||||
//
|
||||
// Next ID to use: 32.
|
||||
// Next ID to use: 33.
|
||||
message TocoFlags {
|
||||
// Input file format
|
||||
optional FileFormat input_format = 1;
|
||||
@ -222,4 +222,8 @@ message TocoFlags {
|
||||
// Full filepath of the folder to dump conversion logs. This includes a global
|
||||
// view of the conversion process, and user can choose to submit those logs.
|
||||
optional string conversion_summary_dir = 31;
|
||||
|
||||
// String representing the custom ops OpDefs that are included in the
|
||||
// GraphDef.
|
||||
repeated string custom_opdefs = 32;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user