Propagate node debug information.
PiperOrigin-RevId: 257286387
This commit is contained in:
parent
1ef629438e
commit
8d9b34c4cd
@ -93,7 +93,11 @@ class ConverterError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
def toco_convert_protos(model_flags_str,
|
||||||
|
toco_flags_str,
|
||||||
|
input_data_str,
|
||||||
|
debug_info_str="",
|
||||||
|
enable_mlir_converter=False):
|
||||||
"""Convert `input_data_str` according to model and toco parameters.
|
"""Convert `input_data_str` according to model and toco parameters.
|
||||||
|
|
||||||
Unless you know what you are doing consider using
|
Unless you know what you are doing consider using
|
||||||
@ -105,6 +109,10 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
|||||||
toco_flags_str: Serialized proto describing conversion properties, see
|
toco_flags_str: Serialized proto describing conversion properties, see
|
||||||
`toco/toco_flags.proto`.
|
`toco/toco_flags.proto`.
|
||||||
input_data_str: Input data in serialized form (e.g. a graphdef is common)
|
input_data_str: Input data in serialized form (e.g. a graphdef is common)
|
||||||
|
debug_info_str: Serialized `GraphDebugInfo` proto describing logging
|
||||||
|
information. (default "")
|
||||||
|
enable_mlir_converter: Enables the MLIR converter instead of the TOCO
|
||||||
|
converter. (default False)
|
||||||
Returns:
|
Returns:
|
||||||
Converted model in serialized form (e.g. a TFLITE model is common).
|
Converted model in serialized form (e.g. a TFLITE model is common).
|
||||||
Raises:
|
Raises:
|
||||||
@ -118,10 +126,12 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
|||||||
if not _toco_from_proto_bin:
|
if not _toco_from_proto_bin:
|
||||||
try:
|
try:
|
||||||
model_str = wrap_toco.wrapped_toco_convert(model_flags_str,
|
model_str = wrap_toco.wrapped_toco_convert(model_flags_str,
|
||||||
toco_flags_str, input_data_str)
|
toco_flags_str, input_data_str,
|
||||||
|
debug_info_str,
|
||||||
|
enable_mlir_converter)
|
||||||
return model_str
|
return model_str
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ConverterError("TOCO failed: %s" % e)
|
raise ConverterError(str(e))
|
||||||
|
|
||||||
# Windows and TemporaryFile are not that useful together,
|
# Windows and TemporaryFile are not that useful together,
|
||||||
# since you cannot have two readers/writers. So we have to
|
# since you cannot have two readers/writers. So we have to
|
||||||
@ -132,16 +142,17 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
|||||||
# Build all input files
|
# Build all input files
|
||||||
with _tempfile.NamedTemporaryFile(delete=False) as fp_toco, \
|
with _tempfile.NamedTemporaryFile(delete=False) as fp_toco, \
|
||||||
_tempfile.NamedTemporaryFile(delete=False) as fp_model, \
|
_tempfile.NamedTemporaryFile(delete=False) as fp_model, \
|
||||||
_tempfile.NamedTemporaryFile(delete=False) as fp_input:
|
_tempfile.NamedTemporaryFile(delete=False) as fp_input, \
|
||||||
|
_tempfile.NamedTemporaryFile(delete=False) as fp_debug:
|
||||||
toco_filename = fp_toco.name
|
toco_filename = fp_toco.name
|
||||||
input_filename = fp_input.name
|
input_filename = fp_input.name
|
||||||
model_filename = fp_model.name
|
model_filename = fp_model.name
|
||||||
|
debug_filename = fp_debug.name
|
||||||
|
|
||||||
fp_model.write(model_flags_str)
|
fp_model.write(model_flags_str)
|
||||||
fp_toco.write(toco_flags_str)
|
fp_toco.write(toco_flags_str)
|
||||||
fp_input.write(input_data_str)
|
fp_input.write(input_data_str)
|
||||||
fp_model.flush()
|
fp_debug.write(debug_info_str)
|
||||||
fp_toco.flush()
|
|
||||||
fp_input.flush()
|
|
||||||
|
|
||||||
# Reserve an output file
|
# Reserve an output file
|
||||||
with _tempfile.NamedTemporaryFile(delete=False) as fp:
|
with _tempfile.NamedTemporaryFile(delete=False) as fp:
|
||||||
@ -149,9 +160,15 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
|||||||
|
|
||||||
# Run
|
# Run
|
||||||
cmd = [
|
cmd = [
|
||||||
_toco_from_proto_bin, model_filename, toco_filename, input_filename,
|
_toco_from_proto_bin,
|
||||||
output_filename
|
model_filename,
|
||||||
|
toco_filename,
|
||||||
|
input_filename,
|
||||||
|
output_filename,
|
||||||
|
"--debug_proto_file={}".format(debug_filename),
|
||||||
]
|
]
|
||||||
|
if enable_mlir_converter:
|
||||||
|
cmd.append("--enable_mlir_converter")
|
||||||
cmdline = " ".join(cmd)
|
cmdline = " ".join(cmd)
|
||||||
is_windows = _platform.system() == "Windows"
|
is_windows = _platform.system() == "Windows"
|
||||||
proc = _subprocess.Popen(
|
proc = _subprocess.Popen(
|
||||||
@ -168,8 +185,7 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
|||||||
else:
|
else:
|
||||||
stdout = _try_convert_to_unicode(stdout)
|
stdout = _try_convert_to_unicode(stdout)
|
||||||
stderr = _try_convert_to_unicode(stderr)
|
stderr = _try_convert_to_unicode(stderr)
|
||||||
raise ConverterError(
|
raise ConverterError("See console for info.\n%s\n%s\n" % (stdout, stderr))
|
||||||
"TOCO failed. See console for info.\n%s\n%s\n" % (stdout, stderr))
|
|
||||||
finally:
|
finally:
|
||||||
# Must manually cleanup files.
|
# Must manually cleanup files.
|
||||||
for filename in [
|
for filename in [
|
||||||
@ -211,9 +227,9 @@ def build_toco_convert_protos(input_tensors,
|
|||||||
output_tensors: List of output tensors (only .name is used from this).
|
output_tensors: List of output tensors (only .name is used from this).
|
||||||
inference_type: Target data type of real-number arrays in the output file.
|
inference_type: Target data type of real-number arrays in the output file.
|
||||||
Must be `{tf.float32, tf.uint8}`. (default tf.float32)
|
Must be `{tf.float32, tf.uint8}`. (default tf.float32)
|
||||||
|
Must be `{tf.float32, tf.uint8}`. (default `inference_type`)
|
||||||
inference_input_type: Target data type of real-number input arrays. Allows
|
inference_input_type: Target data type of real-number input arrays. Allows
|
||||||
for a different type for input arrays in the case of quantization.
|
for a different type for input arrays in the case of quantization.
|
||||||
Must be `{tf.float32, tf.uint8}`. (default `inference_type`)
|
|
||||||
input_format: Type of data to read Currently must be
|
input_format: Type of data to read Currently must be
|
||||||
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
|
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
|
||||||
input_shapes: Input array shape. It needs to be a list of the same length
|
input_shapes: Input array shape. It needs to be a list of the same length
|
||||||
@ -330,7 +346,7 @@ def build_toco_convert_protos(input_tensors,
|
|||||||
|
|
||||||
|
|
||||||
def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
|
def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
|
||||||
*args, **kwargs):
|
enable_mlir_converter, *args, **kwargs):
|
||||||
""""Convert a model using TOCO.
|
""""Convert a model using TOCO.
|
||||||
|
|
||||||
This function is used to convert GraphDefs that cannot be loaded into
|
This function is used to convert GraphDefs that cannot be loaded into
|
||||||
@ -347,6 +363,8 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
|
|||||||
output_arrays: List of output tensors to freeze graph with. Use only when
|
output_arrays: List of output tensors to freeze graph with. Use only when
|
||||||
graph cannot be loaded into TensorFlow and when `output_tensors` is None.
|
graph cannot be loaded into TensorFlow and when `output_tensors` is None.
|
||||||
(default None)
|
(default None)
|
||||||
|
enable_mlir_converter: Enables the MLIR converter instead of the TOCO
|
||||||
|
converter.
|
||||||
*args: See `build_toco_convert_protos`,
|
*args: See `build_toco_convert_protos`,
|
||||||
**kwargs: See `build_toco_convert_protos`.
|
**kwargs: See `build_toco_convert_protos`.
|
||||||
|
|
||||||
@ -375,14 +393,16 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
|
|||||||
for name in output_arrays:
|
for name in output_arrays:
|
||||||
model_flags.output_arrays.append(name)
|
model_flags.output_arrays.append(name)
|
||||||
|
|
||||||
data = toco_convert_protos(model_flags.SerializeToString(),
|
data = toco_convert_protos(
|
||||||
|
model_flags.SerializeToString(),
|
||||||
toco_flags.SerializeToString(),
|
toco_flags.SerializeToString(),
|
||||||
input_data.SerializeToString())
|
input_data.SerializeToString(),
|
||||||
|
enable_mlir_converter=enable_mlir_converter)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
|
def toco_convert_impl(input_data, input_tensors, output_tensors,
|
||||||
**kwargs):
|
enable_mlir_converter, *args, **kwargs):
|
||||||
""""Convert a model using TOCO.
|
""""Convert a model using TOCO.
|
||||||
|
|
||||||
Typically this function is used to convert from TensorFlow GraphDef to TFLite.
|
Typically this function is used to convert from TensorFlow GraphDef to TFLite.
|
||||||
@ -394,6 +414,8 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
|
|||||||
input_tensors: List of input tensors. Type and shape are computed using
|
input_tensors: List of input tensors. Type and shape are computed using
|
||||||
`foo.shape` and `foo.dtype`.
|
`foo.shape` and `foo.dtype`.
|
||||||
output_tensors: List of output tensors (only .name is used from this).
|
output_tensors: List of output tensors (only .name is used from this).
|
||||||
|
enable_mlir_converter: Enables the MLIR converter instead of the TOCO
|
||||||
|
converter.
|
||||||
*args: See `build_toco_convert_protos`,
|
*args: See `build_toco_convert_protos`,
|
||||||
**kwargs: See `build_toco_convert_protos`.
|
**kwargs: See `build_toco_convert_protos`.
|
||||||
|
|
||||||
@ -404,11 +426,15 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
|
|||||||
Raises:
|
Raises:
|
||||||
Defined in `build_toco_convert_protos`.
|
Defined in `build_toco_convert_protos`.
|
||||||
"""
|
"""
|
||||||
model_flags, toco_flags, _ = build_toco_convert_protos(
|
model_flags, toco_flags, debug_info = build_toco_convert_protos(
|
||||||
input_tensors, output_tensors, *args, **kwargs)
|
input_tensors, output_tensors, *args, **kwargs)
|
||||||
data = toco_convert_protos(model_flags.SerializeToString(),
|
debug_info_str = debug_info.SerializeToString() if debug_info else ""
|
||||||
|
data = toco_convert_protos(
|
||||||
|
model_flags.SerializeToString(),
|
||||||
toco_flags.SerializeToString(),
|
toco_flags.SerializeToString(),
|
||||||
input_data.SerializeToString())
|
input_data.SerializeToString(),
|
||||||
|
debug_info_str=debug_info_str,
|
||||||
|
enable_mlir_converter=enable_mlir_converter)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@ -437,5 +463,6 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
|
|||||||
Raises:
|
Raises:
|
||||||
Defined in `build_toco_convert_protos`.
|
Defined in `build_toco_convert_protos`.
|
||||||
"""
|
"""
|
||||||
return toco_convert_impl(input_data, input_tensors, output_tensors, *args,
|
enable_mlir_converter = kwargs.get("enable_mlir_converter", False)
|
||||||
**kwargs)
|
return toco_convert_impl(input_data, input_tensors, output_tensors,
|
||||||
|
enable_mlir_converter, *args, **kwargs)
|
||||||
|
@ -90,6 +90,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
tflite_model = convert.toco_convert_graph_def(
|
tflite_model = convert.toco_convert_graph_def(
|
||||||
sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
|
sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
|
||||||
|
enable_mlir_converter=False,
|
||||||
inference_type=lite_constants.FLOAT)
|
inference_type=lite_constants.FLOAT)
|
||||||
self.assertTrue(tflite_model)
|
self.assertTrue(tflite_model)
|
||||||
|
|
||||||
@ -126,6 +127,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
|||||||
sess.graph_def,
|
sess.graph_def,
|
||||||
input_arrays_map,
|
input_arrays_map,
|
||||||
output_arrays,
|
output_arrays,
|
||||||
|
enable_mlir_converter=False,
|
||||||
inference_type=lite_constants.QUANTIZED_UINT8,
|
inference_type=lite_constants.QUANTIZED_UINT8,
|
||||||
quantized_input_stats=[(0., 1.), (0., 1.)])
|
quantized_input_stats=[(0., 1.), (0., 1.)])
|
||||||
self.assertTrue(tflite_model)
|
self.assertTrue(tflite_model)
|
||||||
@ -171,6 +173,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
|||||||
sess.graph_def,
|
sess.graph_def,
|
||||||
input_arrays_map,
|
input_arrays_map,
|
||||||
output_arrays,
|
output_arrays,
|
||||||
|
enable_mlir_converter=False,
|
||||||
inference_type=lite_constants.QUANTIZED_UINT8)
|
inference_type=lite_constants.QUANTIZED_UINT8)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"std_dev and mean must be defined when inference_input_type is "
|
"std_dev and mean must be defined when inference_input_type is "
|
||||||
|
@ -233,6 +233,25 @@ class TFLiteConverterBase(object):
|
|||||||
self.representative_dataset.input_gen, inference_input_type,
|
self.representative_dataset.input_gen, inference_input_type,
|
||||||
inference_output_type, allow_float)
|
inference_output_type, allow_float)
|
||||||
|
|
||||||
|
def _get_base_converter_args(self):
|
||||||
|
"""Returns the base converter args.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{key str: val}
|
||||||
|
"""
|
||||||
|
float16_quantize = self._is_float16_quantize()
|
||||||
|
args = {
|
||||||
|
"input_format": constants.TENSORFLOW_GRAPHDEF,
|
||||||
|
"allow_custom_ops": self.allow_custom_ops,
|
||||||
|
"post_training_quantize": (self._is_int8_weight_only_quantize() or
|
||||||
|
float16_quantize),
|
||||||
|
"quantize_to_float16": float16_quantize,
|
||||||
|
"debug_info": self._debug_info,
|
||||||
|
"target_ops": self._target_ops,
|
||||||
|
"enable_mlir_converter": self.experimental_enable_mlir_converter,
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
@_tf_export("lite.TFLiteConverter", v1=[])
|
@_tf_export("lite.TFLiteConverter", v1=[])
|
||||||
class TFLiteConverterV2(TFLiteConverterBase):
|
class TFLiteConverterV2(TFLiteConverterBase):
|
||||||
@ -251,6 +270,8 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
|||||||
representative_dataset: A representative dataset that can be used to
|
representative_dataset: A representative dataset that can be used to
|
||||||
generate input and output samples for the model. The converter can use the
|
generate input and output samples for the model. The converter can use the
|
||||||
dataset to evaluate different optimizations.
|
dataset to evaluate different optimizations.
|
||||||
|
experimental_enable_mlir_converter: Experimental flag, subject to change.
|
||||||
|
Enables the MLIR converter instead of the TOCO converter.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
|
|
||||||
@ -287,6 +308,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
|||||||
self.allow_custom_ops = False
|
self.allow_custom_ops = False
|
||||||
self.target_spec = TargetSpec()
|
self.target_spec = TargetSpec()
|
||||||
self._debug_info = None
|
self._debug_info = None
|
||||||
|
self.experimental_enable_mlir_converter = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_concrete_functions(cls, funcs):
|
def from_concrete_functions(cls, funcs):
|
||||||
@ -414,23 +436,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
|||||||
self._validate_representative_dataset()
|
self._validate_representative_dataset()
|
||||||
self._debug_info = _get_debug_info(
|
self._debug_info = _get_debug_info(
|
||||||
_build_debug_info_func(self._funcs[0].graph), graph_def)
|
_build_debug_info_func(self._funcs[0].graph), graph_def)
|
||||||
|
converter_kwargs = self._get_base_converter_args()
|
||||||
float16_quantize = self._is_float16_quantize()
|
|
||||||
|
|
||||||
converter_kwargs = {
|
|
||||||
"input_format":
|
|
||||||
constants.TENSORFLOW_GRAPHDEF,
|
|
||||||
"allow_custom_ops":
|
|
||||||
self.allow_custom_ops,
|
|
||||||
"post_training_quantize":
|
|
||||||
self._is_int8_weight_only_quantize() or float16_quantize,
|
|
||||||
"quantize_to_float16":
|
|
||||||
float16_quantize,
|
|
||||||
"target_ops":
|
|
||||||
self.target_spec.supported_ops,
|
|
||||||
"debug_info":
|
|
||||||
self._debug_info
|
|
||||||
}
|
|
||||||
|
|
||||||
# Converts model.
|
# Converts model.
|
||||||
result = _toco_convert_impl(
|
result = _toco_convert_impl(
|
||||||
@ -522,6 +528,8 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
representative_dataset: A representative dataset that can be used to
|
representative_dataset: A representative dataset that can be used to
|
||||||
generate input and output samples for the model. The converter can use
|
generate input and output samples for the model. The converter can use
|
||||||
the dataset to evaluate different optimizations.
|
the dataset to evaluate different optimizations.
|
||||||
|
experimental_enable_mlir_converter: Experimental flag, subject to change.
|
||||||
|
Enables the MLIR converter instead of the TOCO converter.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
|
|
||||||
@ -597,6 +605,7 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
self.target_spec = TargetSpec()
|
self.target_spec = TargetSpec()
|
||||||
self._debug_info_func = experimental_debug_info_func
|
self._debug_info_func = experimental_debug_info_func
|
||||||
self._debug_info = None
|
self._debug_info = None
|
||||||
|
self.experimental_enable_mlir_converter = False
|
||||||
|
|
||||||
# Attributes are used by models that cannot be loaded into TensorFlow.
|
# Attributes are used by models that cannot be loaded into TensorFlow.
|
||||||
if not self._has_valid_tensors():
|
if not self._has_valid_tensors():
|
||||||
@ -939,31 +948,11 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
"Provide an inference_input_type and inference_output_type of type "
|
"Provide an inference_input_type and inference_output_type of type "
|
||||||
"tf.float32.")
|
"tf.float32.")
|
||||||
|
|
||||||
float16_quantize = self._is_float16_quantize()
|
|
||||||
|
|
||||||
if not post_training_optimize and self.inference_output_type is not None:
|
if not post_training_optimize and self.inference_output_type is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"inference_output_type is currently not supported if optimizations "
|
"inference_output_type is currently not supported if optimizations "
|
||||||
"are not enabled.")
|
"are not enabled.")
|
||||||
|
|
||||||
converter_kwargs = {
|
|
||||||
"inference_type": self.inference_type,
|
|
||||||
"inference_input_type": toco_inference_input_type,
|
|
||||||
"input_format": constants.TENSORFLOW_GRAPHDEF,
|
|
||||||
"output_format": self.output_format,
|
|
||||||
"quantized_input_stats": quantized_stats,
|
|
||||||
"default_ranges_stats": self.default_ranges_stats,
|
|
||||||
"drop_control_dependency": self.drop_control_dependency,
|
|
||||||
"reorder_across_fake_quant": self.reorder_across_fake_quant,
|
|
||||||
"change_concat_input_ranges": self.change_concat_input_ranges,
|
|
||||||
"allow_custom_ops": self.allow_custom_ops,
|
|
||||||
"post_training_quantize": weight_only_quantize or float16_quantize,
|
|
||||||
"quantize_to_float16": float16_quantize,
|
|
||||||
"target_ops": self._target_ops,
|
|
||||||
"dump_graphviz_dir": self.dump_graphviz_dir,
|
|
||||||
"dump_graphviz_video": self.dump_graphviz_video
|
|
||||||
}
|
|
||||||
|
|
||||||
optimized_graph = self._graph_def
|
optimized_graph = self._graph_def
|
||||||
if self.inference_type != constants.QUANTIZED_UINT8:
|
if self.inference_type != constants.QUANTIZED_UINT8:
|
||||||
try:
|
try:
|
||||||
@ -977,13 +966,26 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
|
|
||||||
self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)
|
self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)
|
||||||
|
|
||||||
|
converter_kwargs = self._get_base_converter_args()
|
||||||
|
converter_kwargs.update({
|
||||||
|
"inference_type": self.inference_type,
|
||||||
|
"inference_input_type": toco_inference_input_type,
|
||||||
|
"output_format": self.output_format,
|
||||||
|
"quantized_input_stats": quantized_stats,
|
||||||
|
"default_ranges_stats": self.default_ranges_stats,
|
||||||
|
"drop_control_dependency": self.drop_control_dependency,
|
||||||
|
"reorder_across_fake_quant": self.reorder_across_fake_quant,
|
||||||
|
"change_concat_input_ranges": self.change_concat_input_ranges,
|
||||||
|
"dump_graphviz_dir": self.dump_graphviz_dir,
|
||||||
|
"dump_graphviz_video": self.dump_graphviz_video
|
||||||
|
})
|
||||||
|
|
||||||
# Converts model.
|
# Converts model.
|
||||||
if self._has_valid_tensors():
|
if self._has_valid_tensors():
|
||||||
result = _toco_convert_impl(
|
result = _toco_convert_impl(
|
||||||
input_data=optimized_graph,
|
input_data=optimized_graph,
|
||||||
input_tensors=self._input_tensors,
|
input_tensors=self._input_tensors,
|
||||||
output_tensors=self._output_tensors,
|
output_tensors=self._output_tensors,
|
||||||
debug_info=self._debug_info,
|
|
||||||
**converter_kwargs)
|
**converter_kwargs)
|
||||||
else:
|
else:
|
||||||
result = _toco_convert_graph_def(
|
result = _toco_convert_graph_def(
|
||||||
|
@ -29,10 +29,16 @@ _toco_python = LazyLoader(
|
|||||||
del LazyLoader
|
del LazyLoader
|
||||||
|
|
||||||
|
|
||||||
def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str):
|
def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str,
|
||||||
|
debug_info_str, enable_mlir_converter):
|
||||||
"""Wraps TocoConvert with lazy loader."""
|
"""Wraps TocoConvert with lazy loader."""
|
||||||
return _toco_python.TocoConvert(model_flags_str, toco_flags_str,
|
return _toco_python.TocoConvert(
|
||||||
input_data_str)
|
model_flags_str,
|
||||||
|
toco_flags_str,
|
||||||
|
input_data_str,
|
||||||
|
False, # extended_return
|
||||||
|
debug_info_str,
|
||||||
|
enable_mlir_converter)
|
||||||
|
|
||||||
|
|
||||||
def wrapped_get_potentially_supported_ops():
|
def wrapped_get_potentially_supported_ops():
|
||||||
|
@ -22,9 +22,11 @@ cc_library(
|
|||||||
name = "toco_python_api",
|
name = "toco_python_api",
|
||||||
srcs = ["toco_python_api.cc"],
|
srcs = ["toco_python_api.cc"],
|
||||||
hdrs = ["toco_python_api.h"],
|
hdrs = ["toco_python_api.h"],
|
||||||
|
features = ["no_layering_check"],
|
||||||
deps = [
|
deps = [
|
||||||
"//third_party/python_runtime:headers",
|
"//third_party/python_runtime:headers",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/lite/python/interpreter_wrapper:python_utils",
|
"//tensorflow/lite/python/interpreter_wrapper:python_utils",
|
||||||
"//tensorflow/lite/toco:model_flags_proto_cc",
|
"//tensorflow/lite/toco:model_flags_proto_cc",
|
||||||
"//tensorflow/lite/toco:toco_flags_proto_cc",
|
"//tensorflow/lite/toco:toco_flags_proto_cc",
|
||||||
|
@ -26,11 +26,16 @@ namespace toco {
|
|||||||
// parameters (see relevant .protos for more information). Returns a string
|
// parameters (see relevant .protos for more information). Returns a string
|
||||||
// representing the contents of the converted model. When extended_return
|
// representing the contents of the converted model. When extended_return
|
||||||
// flag is set to true returns a dictionary that contains string representation
|
// flag is set to true returns a dictionary that contains string representation
|
||||||
// of the converted model and some statitics like arithmetic ops count.
|
// of the converted model and some statistics like arithmetic ops count.
|
||||||
|
// `debug_info_str` contains the `GraphDebugInfo` proto. When
|
||||||
|
// `enable_mlir_converter` is True, the MLIR converter is used instead of the
|
||||||
|
// TOCO converter.
|
||||||
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||||
PyObject* toco_flags_proto_txt_raw,
|
PyObject* toco_flags_proto_txt_raw,
|
||||||
PyObject* input_contents_txt_raw,
|
PyObject* input_contents_txt_raw,
|
||||||
bool extended_return = false);
|
bool extended_return = false,
|
||||||
|
PyObject* debug_info_txt_raw = nullptr,
|
||||||
|
bool enable_mlir_converter = false);
|
||||||
|
|
||||||
// Returns a list of names of all ops potentially supported by tflite.
|
// Returns a list of names of all ops potentially supported by tflite.
|
||||||
PyObject* TocoGetPotentiallySupportedOps();
|
PyObject* TocoGetPotentiallySupportedOps();
|
||||||
|
@ -26,11 +26,30 @@ FLAGS = None
|
|||||||
|
|
||||||
|
|
||||||
def execute(unused_args):
|
def execute(unused_args):
|
||||||
model_str = open(FLAGS.model_proto_file, "rb").read()
|
"""Runs the converter."""
|
||||||
toco_str = open(FLAGS.toco_proto_file, "rb").read()
|
with open(FLAGS.model_proto_file, "rb") as model_file:
|
||||||
input_str = open(FLAGS.model_input_file, "rb").read()
|
model_str = model_file.read()
|
||||||
|
|
||||||
output_str = tensorflow_wrap_toco.TocoConvert(model_str, toco_str, input_str)
|
with open(FLAGS.toco_proto_file, "rb") as toco_file:
|
||||||
|
toco_str = toco_file.read()
|
||||||
|
|
||||||
|
with open(FLAGS.model_input_file, "rb") as input_file:
|
||||||
|
input_str = input_file.read()
|
||||||
|
|
||||||
|
debug_info_str = ""
|
||||||
|
if FLAGS.debug_proto_file:
|
||||||
|
with open(FLAGS.debug_proto_file, "rb") as debug_info_file:
|
||||||
|
debug_info_str = debug_info_file.read()
|
||||||
|
|
||||||
|
enable_mlir_converter = FLAGS.enable_mlir_converter
|
||||||
|
|
||||||
|
output_str = tensorflow_wrap_toco.TocoConvert(
|
||||||
|
model_str,
|
||||||
|
toco_str,
|
||||||
|
input_str,
|
||||||
|
False, # extended_return
|
||||||
|
debug_info_str,
|
||||||
|
enable_mlir_converter)
|
||||||
open(FLAGS.model_output_file, "wb").write(output_str)
|
open(FLAGS.model_output_file, "wb").write(output_str)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
@ -53,6 +72,17 @@ def main():
|
|||||||
"model_output_file",
|
"model_output_file",
|
||||||
type=str,
|
type=str,
|
||||||
help="Result of applying TOCO conversion is written here.")
|
help="Result of applying TOCO conversion is written here.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug_proto_file",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help=("File containing serialized `GraphDebugInfo` proto that describes "
|
||||||
|
"logging information."))
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_mlir_converter",
|
||||||
|
action="store_true",
|
||||||
|
help=("Boolean indiciating whether to enable the MLIR converter instead "
|
||||||
|
"of TOCO converter. (default False)"))
|
||||||
|
|
||||||
FLAGS, unparsed = parser.parse_known_args()
|
FLAGS, unparsed = parser.parse_known_args()
|
||||||
|
|
||||||
|
@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/toco/python/toco_python_api.h"
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -20,20 +22,27 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||||
#include "tensorflow/lite/toco/import_tensorflow.h"
|
#include "tensorflow/lite/toco/import_tensorflow.h"
|
||||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||||
#include "tensorflow/lite/toco/python/toco_python_api.h"
|
|
||||||
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
||||||
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
|
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
|
||||||
#include "tensorflow/lite/toco/toco_port.h"
|
#include "tensorflow/lite/toco/toco_port.h"
|
||||||
#include "tensorflow/lite/toco/toco_tooling.h"
|
#include "tensorflow/lite/toco/toco_tooling.h"
|
||||||
#include "tensorflow/lite/toco/toco_types.h"
|
#include "tensorflow/lite/toco/toco_types.h"
|
||||||
|
|
||||||
|
#if defined(PLATFORM_GOOGLE)
|
||||||
|
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
|
||||||
|
#else
|
||||||
|
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace toco {
|
namespace toco {
|
||||||
|
|
||||||
// NOTE(aselle): We are using raw PyObject's here because we want to make
|
// NOTE(aselle): We are using raw PyObject's here because we want to make
|
||||||
// sure we input and output bytes rather than unicode strings for Python3.
|
// sure we input and output bytes rather than unicode strings for Python3.
|
||||||
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||||
PyObject* toco_flags_proto_txt_raw,
|
PyObject* toco_flags_proto_txt_raw,
|
||||||
PyObject* input_contents_txt_raw, bool extended_return) {
|
PyObject* input_contents_txt_raw, bool extended_return,
|
||||||
|
PyObject* debug_info_txt_raw,
|
||||||
|
bool enable_mlir_converter) {
|
||||||
// Use Python C API to validate and convert arguments. In py3 (bytes),
|
// Use Python C API to validate and convert arguments. In py3 (bytes),
|
||||||
// in py2 (str).
|
// in py2 (str).
|
||||||
auto ConvertArg = [&](PyObject* obj, bool* error) {
|
auto ConvertArg = [&](PyObject* obj, bool* error) {
|
||||||
@ -70,12 +79,35 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
|||||||
// Use TOCO to produce new outputs.
|
// Use TOCO to produce new outputs.
|
||||||
toco::ModelFlags model_flags;
|
toco::ModelFlags model_flags;
|
||||||
if (!model_flags.ParseFromString(model_flags_proto_txt)) {
|
if (!model_flags.ParseFromString(model_flags_proto_txt)) {
|
||||||
PyErr_SetString(PyExc_ValueError, "Model proto failed to parse.");
|
PyErr_SetString(PyExc_ValueError,
|
||||||
|
"Failed to convert Model to Python String.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
toco::TocoFlags toco_flags;
|
toco::TocoFlags toco_flags;
|
||||||
if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
|
if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
|
||||||
PyErr_SetString(PyExc_ValueError, "Toco proto failed to parse.");
|
PyErr_SetString(PyExc_ValueError,
|
||||||
|
"Failed to convert Toco to Python String.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::GraphDebugInfo debug_info;
|
||||||
|
if (debug_info_txt_raw) {
|
||||||
|
std::string debug_info_txt = ConvertArg(debug_info_txt_raw, &error);
|
||||||
|
if (error) {
|
||||||
|
PyErr_SetString(PyExc_ValueError, "Input DebugInfo is invalid.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (!debug_info.ParseFromString(debug_info_txt)) {
|
||||||
|
PyErr_SetString(PyExc_ValueError,
|
||||||
|
"Failed to convert DebugInfo to Python String.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::GraphDef graph_def;
|
||||||
|
if (!graph_def.ParseFromString(input_contents_txt)) {
|
||||||
|
PyErr_SetString(PyExc_ValueError,
|
||||||
|
"Failed to convert GraphDef to Python String.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,18 +119,36 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
|||||||
dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video();
|
dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert model.
|
|
||||||
std::unique_ptr<toco::Model> model =
|
|
||||||
toco::Import(toco_flags, model_flags, input_contents_txt);
|
|
||||||
toco::Transform(toco_flags, model.get());
|
|
||||||
string output_file_contents_txt;
|
string output_file_contents_txt;
|
||||||
auto status = Export(toco_flags, *model, toco_flags.allow_custom_ops(),
|
tensorflow::Status status;
|
||||||
|
std::unique_ptr<toco::Model> model;
|
||||||
|
|
||||||
|
// Convert model.
|
||||||
|
if (enable_mlir_converter) {
|
||||||
|
#if defined(PLATFORM_GOOGLE)
|
||||||
|
status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
|
||||||
|
model_flags, toco_flags, debug_info, graph_def,
|
||||||
&output_file_contents_txt);
|
&output_file_contents_txt);
|
||||||
|
#else
|
||||||
|
// TODO(b/124314620): Remove this condition.
|
||||||
|
PyErr_SetString(PyExc_Exception,
|
||||||
|
"This flag is not supported by this version of the "
|
||||||
|
"TFLite converter. This functionality is being "
|
||||||
|
"actively worked on.");
|
||||||
|
return nullptr;
|
||||||
|
#endif
|
||||||
|
} else {
|
||||||
|
model = toco::Import(toco_flags, model_flags, input_contents_txt);
|
||||||
|
toco::Transform(toco_flags, model.get());
|
||||||
|
status = Export(toco_flags, *model, toco_flags.allow_custom_ops(),
|
||||||
|
&output_file_contents_txt);
|
||||||
|
}
|
||||||
|
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
PyErr_SetString(PyExc_Exception, status.error_message().c_str());
|
PyErr_SetString(PyExc_Exception, status.error_message().c_str());
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (extended_return) {
|
if (extended_return && !enable_mlir_converter) {
|
||||||
PyObject* dict = PyDict_New();
|
PyObject* dict = PyDict_New();
|
||||||
PyDict_SetItemString(
|
PyDict_SetItemString(
|
||||||
dict, "flatbuffer",
|
dict, "flatbuffer",
|
||||||
|
@ -25,11 +25,16 @@ namespace toco {
|
|||||||
// parameters (see relevant .protos for more information). Returns a string
|
// parameters (see relevant .protos for more information). Returns a string
|
||||||
// representing the contents of the converted model. When extended_return
|
// representing the contents of the converted model. When extended_return
|
||||||
// flag is set to true returns a dictionary that contains string representation
|
// flag is set to true returns a dictionary that contains string representation
|
||||||
// of the converted model and some statitics like arithmetic ops count.
|
// of the converted model and some statistics like arithmetic ops count.
|
||||||
|
// `debug_info_str` contains the `GraphDebugInfo` proto. When
|
||||||
|
// `enable_mlir_converter` is True, the MLIR converter is used instead of the
|
||||||
|
// TOCO converter.
|
||||||
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||||
PyObject* toco_flags_proto_txt_raw,
|
PyObject* toco_flags_proto_txt_raw,
|
||||||
PyObject* input_contents_txt_raw,
|
PyObject* input_contents_txt_raw,
|
||||||
bool extended_return = false);
|
bool extended_return = false,
|
||||||
|
PyObject* debug_info_txt_raw = nullptr,
|
||||||
|
bool enable_mlir_converter = false);
|
||||||
|
|
||||||
// Returns a list of names of all ops potentially supported by tflite.
|
// Returns a list of names of all ops potentially supported by tflite.
|
||||||
PyObject* TocoGetPotentiallySupportedOps();
|
PyObject* TocoGetPotentiallySupportedOps();
|
||||||
|
Loading…
Reference in New Issue
Block a user