Propagate node debug information.

PiperOrigin-RevId: 257286387
This commit is contained in:
Nupur Garg 2019-07-09 15:39:18 -07:00 committed by TensorFlower Gardener
parent 1ef629438e
commit 8d9b34c4cd
9 changed files with 216 additions and 86 deletions

View File

@ -93,7 +93,11 @@ class ConverterError(Exception):
pass pass
def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): def toco_convert_protos(model_flags_str,
toco_flags_str,
input_data_str,
debug_info_str="",
enable_mlir_converter=False):
"""Convert `input_data_str` according to model and toco parameters. """Convert `input_data_str` according to model and toco parameters.
Unless you know what you are doing consider using Unless you know what you are doing consider using
@ -105,6 +109,10 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
toco_flags_str: Serialized proto describing conversion properties, see toco_flags_str: Serialized proto describing conversion properties, see
`toco/toco_flags.proto`. `toco/toco_flags.proto`.
input_data_str: Input data in serialized form (e.g. a graphdef is common) input_data_str: Input data in serialized form (e.g. a graphdef is common)
debug_info_str: Serialized `GraphDebugInfo` proto describing logging
information. (default "")
enable_mlir_converter: Enables the MLIR converter instead of the TOCO
converter. (default False)
Returns: Returns:
Converted model in serialized form (e.g. a TFLITE model is common). Converted model in serialized form (e.g. a TFLITE model is common).
Raises: Raises:
@ -118,10 +126,12 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
if not _toco_from_proto_bin: if not _toco_from_proto_bin:
try: try:
model_str = wrap_toco.wrapped_toco_convert(model_flags_str, model_str = wrap_toco.wrapped_toco_convert(model_flags_str,
toco_flags_str, input_data_str) toco_flags_str, input_data_str,
debug_info_str,
enable_mlir_converter)
return model_str return model_str
except Exception as e: except Exception as e:
raise ConverterError("TOCO failed: %s" % e) raise ConverterError(str(e))
# Windows and TemporaryFile are not that useful together, # Windows and TemporaryFile are not that useful together,
# since you cannot have two readers/writers. So we have to # since you cannot have two readers/writers. So we have to
@ -132,16 +142,17 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
# Build all input files # Build all input files
with _tempfile.NamedTemporaryFile(delete=False) as fp_toco, \ with _tempfile.NamedTemporaryFile(delete=False) as fp_toco, \
_tempfile.NamedTemporaryFile(delete=False) as fp_model, \ _tempfile.NamedTemporaryFile(delete=False) as fp_model, \
_tempfile.NamedTemporaryFile(delete=False) as fp_input: _tempfile.NamedTemporaryFile(delete=False) as fp_input, \
_tempfile.NamedTemporaryFile(delete=False) as fp_debug:
toco_filename = fp_toco.name toco_filename = fp_toco.name
input_filename = fp_input.name input_filename = fp_input.name
model_filename = fp_model.name model_filename = fp_model.name
debug_filename = fp_debug.name
fp_model.write(model_flags_str) fp_model.write(model_flags_str)
fp_toco.write(toco_flags_str) fp_toco.write(toco_flags_str)
fp_input.write(input_data_str) fp_input.write(input_data_str)
fp_model.flush() fp_debug.write(debug_info_str)
fp_toco.flush()
fp_input.flush()
# Reserve an output file # Reserve an output file
with _tempfile.NamedTemporaryFile(delete=False) as fp: with _tempfile.NamedTemporaryFile(delete=False) as fp:
@ -149,9 +160,15 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
# Run # Run
cmd = [ cmd = [
_toco_from_proto_bin, model_filename, toco_filename, input_filename, _toco_from_proto_bin,
output_filename model_filename,
toco_filename,
input_filename,
output_filename,
"--debug_proto_file={}".format(debug_filename),
] ]
if enable_mlir_converter:
cmd.append("--enable_mlir_converter")
cmdline = " ".join(cmd) cmdline = " ".join(cmd)
is_windows = _platform.system() == "Windows" is_windows = _platform.system() == "Windows"
proc = _subprocess.Popen( proc = _subprocess.Popen(
@ -168,8 +185,7 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
else: else:
stdout = _try_convert_to_unicode(stdout) stdout = _try_convert_to_unicode(stdout)
stderr = _try_convert_to_unicode(stderr) stderr = _try_convert_to_unicode(stderr)
raise ConverterError( raise ConverterError("See console for info.\n%s\n%s\n" % (stdout, stderr))
"TOCO failed. See console for info.\n%s\n%s\n" % (stdout, stderr))
finally: finally:
# Must manually cleanup files. # Must manually cleanup files.
for filename in [ for filename in [
@ -211,9 +227,9 @@ def build_toco_convert_protos(input_tensors,
output_tensors: List of output tensors (only .name is used from this). output_tensors: List of output tensors (only .name is used from this).
inference_type: Target data type of real-number arrays in the output file. inference_type: Target data type of real-number arrays in the output file.
Must be `{tf.float32, tf.uint8}`. (default tf.float32) Must be `{tf.float32, tf.uint8}`. (default tf.float32)
Must be `{tf.float32, tf.uint8}`. (default `inference_type`)
inference_input_type: Target data type of real-number input arrays. Allows inference_input_type: Target data type of real-number input arrays. Allows
for a different type for input arrays in the case of quantization. for a different type for input arrays in the case of quantization.
Must be `{tf.float32, tf.uint8}`. (default `inference_type`)
input_format: Type of data to read Currently must be input_format: Type of data to read Currently must be
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
input_shapes: Input array shape. It needs to be a list of the same length input_shapes: Input array shape. It needs to be a list of the same length
@ -330,7 +346,7 @@ def build_toco_convert_protos(input_tensors,
def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays, def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
*args, **kwargs): enable_mlir_converter, *args, **kwargs):
""""Convert a model using TOCO. """"Convert a model using TOCO.
This function is used to convert GraphDefs that cannot be loaded into This function is used to convert GraphDefs that cannot be loaded into
@ -347,6 +363,8 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
output_arrays: List of output tensors to freeze graph with. Use only when output_arrays: List of output tensors to freeze graph with. Use only when
graph cannot be loaded into TensorFlow and when `output_tensors` is None. graph cannot be loaded into TensorFlow and when `output_tensors` is None.
(default None) (default None)
enable_mlir_converter: Enables the MLIR converter instead of the TOCO
converter.
*args: See `build_toco_convert_protos`, *args: See `build_toco_convert_protos`,
**kwargs: See `build_toco_convert_protos`. **kwargs: See `build_toco_convert_protos`.
@ -375,14 +393,16 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
for name in output_arrays: for name in output_arrays:
model_flags.output_arrays.append(name) model_flags.output_arrays.append(name)
data = toco_convert_protos(model_flags.SerializeToString(), data = toco_convert_protos(
model_flags.SerializeToString(),
toco_flags.SerializeToString(), toco_flags.SerializeToString(),
input_data.SerializeToString()) input_data.SerializeToString(),
enable_mlir_converter=enable_mlir_converter)
return data return data
def toco_convert_impl(input_data, input_tensors, output_tensors, *args, def toco_convert_impl(input_data, input_tensors, output_tensors,
**kwargs): enable_mlir_converter, *args, **kwargs):
""""Convert a model using TOCO. """"Convert a model using TOCO.
Typically this function is used to convert from TensorFlow GraphDef to TFLite. Typically this function is used to convert from TensorFlow GraphDef to TFLite.
@ -394,6 +414,8 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
input_tensors: List of input tensors. Type and shape are computed using input_tensors: List of input tensors. Type and shape are computed using
`foo.shape` and `foo.dtype`. `foo.shape` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this). output_tensors: List of output tensors (only .name is used from this).
enable_mlir_converter: Enables the MLIR converter instead of the TOCO
converter.
*args: See `build_toco_convert_protos`, *args: See `build_toco_convert_protos`,
**kwargs: See `build_toco_convert_protos`. **kwargs: See `build_toco_convert_protos`.
@ -404,11 +426,15 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
Raises: Raises:
Defined in `build_toco_convert_protos`. Defined in `build_toco_convert_protos`.
""" """
model_flags, toco_flags, _ = build_toco_convert_protos( model_flags, toco_flags, debug_info = build_toco_convert_protos(
input_tensors, output_tensors, *args, **kwargs) input_tensors, output_tensors, *args, **kwargs)
data = toco_convert_protos(model_flags.SerializeToString(), debug_info_str = debug_info.SerializeToString() if debug_info else ""
data = toco_convert_protos(
model_flags.SerializeToString(),
toco_flags.SerializeToString(), toco_flags.SerializeToString(),
input_data.SerializeToString()) input_data.SerializeToString(),
debug_info_str=debug_info_str,
enable_mlir_converter=enable_mlir_converter)
return data return data
@ -437,5 +463,6 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
Raises: Raises:
Defined in `build_toco_convert_protos`. Defined in `build_toco_convert_protos`.
""" """
return toco_convert_impl(input_data, input_tensors, output_tensors, *args, enable_mlir_converter = kwargs.get("enable_mlir_converter", False)
**kwargs) return toco_convert_impl(input_data, input_tensors, output_tensors,
enable_mlir_converter, *args, **kwargs)

View File

@ -90,6 +90,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
tflite_model = convert.toco_convert_graph_def( tflite_model = convert.toco_convert_graph_def(
sess.graph_def, [("input", [1, 16, 16, 3])], ["add"], sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
enable_mlir_converter=False,
inference_type=lite_constants.FLOAT) inference_type=lite_constants.FLOAT)
self.assertTrue(tflite_model) self.assertTrue(tflite_model)
@ -126,6 +127,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
sess.graph_def, sess.graph_def,
input_arrays_map, input_arrays_map,
output_arrays, output_arrays,
enable_mlir_converter=False,
inference_type=lite_constants.QUANTIZED_UINT8, inference_type=lite_constants.QUANTIZED_UINT8,
quantized_input_stats=[(0., 1.), (0., 1.)]) quantized_input_stats=[(0., 1.), (0., 1.)])
self.assertTrue(tflite_model) self.assertTrue(tflite_model)
@ -171,6 +173,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
sess.graph_def, sess.graph_def,
input_arrays_map, input_arrays_map,
output_arrays, output_arrays,
enable_mlir_converter=False,
inference_type=lite_constants.QUANTIZED_UINT8) inference_type=lite_constants.QUANTIZED_UINT8)
self.assertEqual( self.assertEqual(
"std_dev and mean must be defined when inference_input_type is " "std_dev and mean must be defined when inference_input_type is "

View File

@ -233,6 +233,25 @@ class TFLiteConverterBase(object):
self.representative_dataset.input_gen, inference_input_type, self.representative_dataset.input_gen, inference_input_type,
inference_output_type, allow_float) inference_output_type, allow_float)
def _get_base_converter_args(self):
"""Returns the base converter args.
Returns:
{key str: val}
"""
float16_quantize = self._is_float16_quantize()
args = {
"input_format": constants.TENSORFLOW_GRAPHDEF,
"allow_custom_ops": self.allow_custom_ops,
"post_training_quantize": (self._is_int8_weight_only_quantize() or
float16_quantize),
"quantize_to_float16": float16_quantize,
"debug_info": self._debug_info,
"target_ops": self._target_ops,
"enable_mlir_converter": self.experimental_enable_mlir_converter,
}
return args
@_tf_export("lite.TFLiteConverter", v1=[]) @_tf_export("lite.TFLiteConverter", v1=[])
class TFLiteConverterV2(TFLiteConverterBase): class TFLiteConverterV2(TFLiteConverterBase):
@ -251,6 +270,8 @@ class TFLiteConverterV2(TFLiteConverterBase):
representative_dataset: A representative dataset that can be used to representative_dataset: A representative dataset that can be used to
generate input and output samples for the model. The converter can use the generate input and output samples for the model. The converter can use the
dataset to evaluate different optimizations. dataset to evaluate different optimizations.
experimental_enable_mlir_converter: Experimental flag, subject to change.
Enables the MLIR converter instead of the TOCO converter.
Example usage: Example usage:
@ -287,6 +308,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
self.allow_custom_ops = False self.allow_custom_ops = False
self.target_spec = TargetSpec() self.target_spec = TargetSpec()
self._debug_info = None self._debug_info = None
self.experimental_enable_mlir_converter = False
@classmethod @classmethod
def from_concrete_functions(cls, funcs): def from_concrete_functions(cls, funcs):
@ -414,23 +436,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
self._validate_representative_dataset() self._validate_representative_dataset()
self._debug_info = _get_debug_info( self._debug_info = _get_debug_info(
_build_debug_info_func(self._funcs[0].graph), graph_def) _build_debug_info_func(self._funcs[0].graph), graph_def)
converter_kwargs = self._get_base_converter_args()
float16_quantize = self._is_float16_quantize()
converter_kwargs = {
"input_format":
constants.TENSORFLOW_GRAPHDEF,
"allow_custom_ops":
self.allow_custom_ops,
"post_training_quantize":
self._is_int8_weight_only_quantize() or float16_quantize,
"quantize_to_float16":
float16_quantize,
"target_ops":
self.target_spec.supported_ops,
"debug_info":
self._debug_info
}
# Converts model. # Converts model.
result = _toco_convert_impl( result = _toco_convert_impl(
@ -522,6 +528,8 @@ class TFLiteConverter(TFLiteConverterBase):
representative_dataset: A representative dataset that can be used to representative_dataset: A representative dataset that can be used to
generate input and output samples for the model. The converter can use generate input and output samples for the model. The converter can use
the dataset to evaluate different optimizations. the dataset to evaluate different optimizations.
experimental_enable_mlir_converter: Experimental flag, subject to change.
Enables the MLIR converter instead of the TOCO converter.
Example usage: Example usage:
@ -597,6 +605,7 @@ class TFLiteConverter(TFLiteConverterBase):
self.target_spec = TargetSpec() self.target_spec = TargetSpec()
self._debug_info_func = experimental_debug_info_func self._debug_info_func = experimental_debug_info_func
self._debug_info = None self._debug_info = None
self.experimental_enable_mlir_converter = False
# Attributes are used by models that cannot be loaded into TensorFlow. # Attributes are used by models that cannot be loaded into TensorFlow.
if not self._has_valid_tensors(): if not self._has_valid_tensors():
@ -939,31 +948,11 @@ class TFLiteConverter(TFLiteConverterBase):
"Provide an inference_input_type and inference_output_type of type " "Provide an inference_input_type and inference_output_type of type "
"tf.float32.") "tf.float32.")
float16_quantize = self._is_float16_quantize()
if not post_training_optimize and self.inference_output_type is not None: if not post_training_optimize and self.inference_output_type is not None:
raise ValueError( raise ValueError(
"inference_output_type is currently not supported if optimizations " "inference_output_type is currently not supported if optimizations "
"are not enabled.") "are not enabled.")
converter_kwargs = {
"inference_type": self.inference_type,
"inference_input_type": toco_inference_input_type,
"input_format": constants.TENSORFLOW_GRAPHDEF,
"output_format": self.output_format,
"quantized_input_stats": quantized_stats,
"default_ranges_stats": self.default_ranges_stats,
"drop_control_dependency": self.drop_control_dependency,
"reorder_across_fake_quant": self.reorder_across_fake_quant,
"change_concat_input_ranges": self.change_concat_input_ranges,
"allow_custom_ops": self.allow_custom_ops,
"post_training_quantize": weight_only_quantize or float16_quantize,
"quantize_to_float16": float16_quantize,
"target_ops": self._target_ops,
"dump_graphviz_dir": self.dump_graphviz_dir,
"dump_graphviz_video": self.dump_graphviz_video
}
optimized_graph = self._graph_def optimized_graph = self._graph_def
if self.inference_type != constants.QUANTIZED_UINT8: if self.inference_type != constants.QUANTIZED_UINT8:
try: try:
@ -977,13 +966,26 @@ class TFLiteConverter(TFLiteConverterBase):
self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph) self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)
converter_kwargs = self._get_base_converter_args()
converter_kwargs.update({
"inference_type": self.inference_type,
"inference_input_type": toco_inference_input_type,
"output_format": self.output_format,
"quantized_input_stats": quantized_stats,
"default_ranges_stats": self.default_ranges_stats,
"drop_control_dependency": self.drop_control_dependency,
"reorder_across_fake_quant": self.reorder_across_fake_quant,
"change_concat_input_ranges": self.change_concat_input_ranges,
"dump_graphviz_dir": self.dump_graphviz_dir,
"dump_graphviz_video": self.dump_graphviz_video
})
# Converts model. # Converts model.
if self._has_valid_tensors(): if self._has_valid_tensors():
result = _toco_convert_impl( result = _toco_convert_impl(
input_data=optimized_graph, input_data=optimized_graph,
input_tensors=self._input_tensors, input_tensors=self._input_tensors,
output_tensors=self._output_tensors, output_tensors=self._output_tensors,
debug_info=self._debug_info,
**converter_kwargs) **converter_kwargs)
else: else:
result = _toco_convert_graph_def( result = _toco_convert_graph_def(

View File

@ -29,10 +29,16 @@ _toco_python = LazyLoader(
del LazyLoader del LazyLoader
def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str): def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str,
debug_info_str, enable_mlir_converter):
"""Wraps TocoConvert with lazy loader.""" """Wraps TocoConvert with lazy loader."""
return _toco_python.TocoConvert(model_flags_str, toco_flags_str, return _toco_python.TocoConvert(
input_data_str) model_flags_str,
toco_flags_str,
input_data_str,
False, # extended_return
debug_info_str,
enable_mlir_converter)
def wrapped_get_potentially_supported_ops(): def wrapped_get_potentially_supported_ops():

View File

@ -22,9 +22,11 @@ cc_library(
name = "toco_python_api", name = "toco_python_api",
srcs = ["toco_python_api.cc"], srcs = ["toco_python_api.cc"],
hdrs = ["toco_python_api.h"], hdrs = ["toco_python_api.h"],
features = ["no_layering_check"],
deps = [ deps = [
"//third_party/python_runtime:headers", "//third_party/python_runtime:headers",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/lite/python/interpreter_wrapper:python_utils", "//tensorflow/lite/python/interpreter_wrapper:python_utils",
"//tensorflow/lite/toco:model_flags_proto_cc", "//tensorflow/lite/toco:model_flags_proto_cc",
"//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc",

View File

@ -26,11 +26,16 @@ namespace toco {
// parameters (see relevant .protos for more information). Returns a string // parameters (see relevant .protos for more information). Returns a string
// representing the contents of the converted model. When extended_return // representing the contents of the converted model. When extended_return
// flag is set to true returns a dictionary that contains string representation // flag is set to true returns a dictionary that contains string representation
// of the converted model and some statitics like arithmetic ops count. // of the converted model and some statistics like arithmetic ops count.
// `debug_info_str` contains the `GraphDebugInfo` proto. When
// `enable_mlir_converter` is True, the MLIR converter is used instead of the
// TOCO converter.
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
PyObject* toco_flags_proto_txt_raw, PyObject* toco_flags_proto_txt_raw,
PyObject* input_contents_txt_raw, PyObject* input_contents_txt_raw,
bool extended_return = false); bool extended_return = false,
PyObject* debug_info_txt_raw = nullptr,
bool enable_mlir_converter = false);
// Returns a list of names of all ops potentially supported by tflite. // Returns a list of names of all ops potentially supported by tflite.
PyObject* TocoGetPotentiallySupportedOps(); PyObject* TocoGetPotentiallySupportedOps();

View File

@ -26,11 +26,30 @@ FLAGS = None
def execute(unused_args): def execute(unused_args):
model_str = open(FLAGS.model_proto_file, "rb").read() """Runs the converter."""
toco_str = open(FLAGS.toco_proto_file, "rb").read() with open(FLAGS.model_proto_file, "rb") as model_file:
input_str = open(FLAGS.model_input_file, "rb").read() model_str = model_file.read()
output_str = tensorflow_wrap_toco.TocoConvert(model_str, toco_str, input_str) with open(FLAGS.toco_proto_file, "rb") as toco_file:
toco_str = toco_file.read()
with open(FLAGS.model_input_file, "rb") as input_file:
input_str = input_file.read()
debug_info_str = ""
if FLAGS.debug_proto_file:
with open(FLAGS.debug_proto_file, "rb") as debug_info_file:
debug_info_str = debug_info_file.read()
enable_mlir_converter = FLAGS.enable_mlir_converter
output_str = tensorflow_wrap_toco.TocoConvert(
model_str,
toco_str,
input_str,
False, # extended_return
debug_info_str,
enable_mlir_converter)
open(FLAGS.model_output_file, "wb").write(output_str) open(FLAGS.model_output_file, "wb").write(output_str)
sys.exit(0) sys.exit(0)
@ -53,6 +72,17 @@ def main():
"model_output_file", "model_output_file",
type=str, type=str,
help="Result of applying TOCO conversion is written here.") help="Result of applying TOCO conversion is written here.")
parser.add_argument(
"--debug_proto_file",
type=str,
default="",
help=("File containing serialized `GraphDebugInfo` proto that describes "
"logging information."))
parser.add_argument(
"--enable_mlir_converter",
action="store_true",
help=("Boolean indiciating whether to enable the MLIR converter instead "
"of TOCO converter. (default False)"))
FLAGS, unparsed = parser.parse_known_args() FLAGS, unparsed = parser.parse_known_args()

View File

@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/lite/toco/python/toco_python_api.h"
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
@ -20,20 +22,27 @@ limitations under the License.
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/toco/import_tensorflow.h" #include "tensorflow/lite/toco/import_tensorflow.h"
#include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/python/toco_python_api.h"
#include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h" #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/lite/toco/toco_port.h" #include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_tooling.h" #include "tensorflow/lite/toco/toco_tooling.h"
#include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/toco_types.h"
#if defined(PLATFORM_GOOGLE)
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
#else
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#endif
namespace toco { namespace toco {
// NOTE(aselle): We are using raw PyObject's here because we want to make // NOTE(aselle): We are using raw PyObject's here because we want to make
// sure we input and output bytes rather than unicode strings for Python3. // sure we input and output bytes rather than unicode strings for Python3.
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
PyObject* toco_flags_proto_txt_raw, PyObject* toco_flags_proto_txt_raw,
PyObject* input_contents_txt_raw, bool extended_return) { PyObject* input_contents_txt_raw, bool extended_return,
PyObject* debug_info_txt_raw,
bool enable_mlir_converter) {
// Use Python C API to validate and convert arguments. In py3 (bytes), // Use Python C API to validate and convert arguments. In py3 (bytes),
// in py2 (str). // in py2 (str).
auto ConvertArg = [&](PyObject* obj, bool* error) { auto ConvertArg = [&](PyObject* obj, bool* error) {
@ -70,12 +79,35 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
// Use TOCO to produce new outputs. // Use TOCO to produce new outputs.
toco::ModelFlags model_flags; toco::ModelFlags model_flags;
if (!model_flags.ParseFromString(model_flags_proto_txt)) { if (!model_flags.ParseFromString(model_flags_proto_txt)) {
PyErr_SetString(PyExc_ValueError, "Model proto failed to parse."); PyErr_SetString(PyExc_ValueError,
"Failed to convert Model to Python String.");
return nullptr; return nullptr;
} }
toco::TocoFlags toco_flags; toco::TocoFlags toco_flags;
if (!toco_flags.ParseFromString(toco_flags_proto_txt)) { if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
PyErr_SetString(PyExc_ValueError, "Toco proto failed to parse."); PyErr_SetString(PyExc_ValueError,
"Failed to convert Toco to Python String.");
return nullptr;
}
tensorflow::GraphDebugInfo debug_info;
if (debug_info_txt_raw) {
std::string debug_info_txt = ConvertArg(debug_info_txt_raw, &error);
if (error) {
PyErr_SetString(PyExc_ValueError, "Input DebugInfo is invalid.");
return nullptr;
}
if (!debug_info.ParseFromString(debug_info_txt)) {
PyErr_SetString(PyExc_ValueError,
"Failed to convert DebugInfo to Python String.");
return nullptr;
}
}
tensorflow::GraphDef graph_def;
if (!graph_def.ParseFromString(input_contents_txt)) {
PyErr_SetString(PyExc_ValueError,
"Failed to convert GraphDef to Python String.");
return nullptr; return nullptr;
} }
@ -87,18 +119,36 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video(); dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video();
} }
// Convert model.
std::unique_ptr<toco::Model> model =
toco::Import(toco_flags, model_flags, input_contents_txt);
toco::Transform(toco_flags, model.get());
string output_file_contents_txt; string output_file_contents_txt;
auto status = Export(toco_flags, *model, toco_flags.allow_custom_ops(), tensorflow::Status status;
std::unique_ptr<toco::Model> model;
// Convert model.
if (enable_mlir_converter) {
#if defined(PLATFORM_GOOGLE)
status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
model_flags, toco_flags, debug_info, graph_def,
&output_file_contents_txt); &output_file_contents_txt);
#else
// TODO(b/124314620): Remove this condition.
PyErr_SetString(PyExc_Exception,
"This flag is not supported by this version of the "
"TFLite converter. This functionality is being "
"actively worked on.");
return nullptr;
#endif
} else {
model = toco::Import(toco_flags, model_flags, input_contents_txt);
toco::Transform(toco_flags, model.get());
status = Export(toco_flags, *model, toco_flags.allow_custom_ops(),
&output_file_contents_txt);
}
if (!status.ok()) { if (!status.ok()) {
PyErr_SetString(PyExc_Exception, status.error_message().c_str()); PyErr_SetString(PyExc_Exception, status.error_message().c_str());
return nullptr; return nullptr;
} }
if (extended_return) { if (extended_return && !enable_mlir_converter) {
PyObject* dict = PyDict_New(); PyObject* dict = PyDict_New();
PyDict_SetItemString( PyDict_SetItemString(
dict, "flatbuffer", dict, "flatbuffer",

View File

@ -25,11 +25,16 @@ namespace toco {
// parameters (see relevant .protos for more information). Returns a string // parameters (see relevant .protos for more information). Returns a string
// representing the contents of the converted model. When extended_return // representing the contents of the converted model. When extended_return
// flag is set to true returns a dictionary that contains string representation // flag is set to true returns a dictionary that contains string representation
// of the converted model and some statitics like arithmetic ops count. // of the converted model and some statistics like arithmetic ops count.
// `debug_info_str` contains the `GraphDebugInfo` proto. When
// `enable_mlir_converter` is True, the MLIR converter is used instead of the
// TOCO converter.
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
PyObject* toco_flags_proto_txt_raw, PyObject* toco_flags_proto_txt_raw,
PyObject* input_contents_txt_raw, PyObject* input_contents_txt_raw,
bool extended_return = false); bool extended_return = false,
PyObject* debug_info_txt_raw = nullptr,
bool enable_mlir_converter = false);
// Returns a list of names of all ops potentially supported by tflite. // Returns a list of names of all ops potentially supported by tflite.
PyObject* TocoGetPotentiallySupportedOps(); PyObject* TocoGetPotentiallySupportedOps();