Propagate node debug information.
PiperOrigin-RevId: 257286387
This commit is contained in:
parent
1ef629438e
commit
8d9b34c4cd
@ -93,7 +93,11 @@ class ConverterError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
||||
def toco_convert_protos(model_flags_str,
|
||||
toco_flags_str,
|
||||
input_data_str,
|
||||
debug_info_str="",
|
||||
enable_mlir_converter=False):
|
||||
"""Convert `input_data_str` according to model and toco parameters.
|
||||
|
||||
Unless you know what you are doing consider using
|
||||
@ -105,6 +109,10 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
||||
toco_flags_str: Serialized proto describing conversion properties, see
|
||||
`toco/toco_flags.proto`.
|
||||
input_data_str: Input data in serialized form (e.g. a graphdef is common)
|
||||
debug_info_str: Serialized `GraphDebugInfo` proto describing logging
|
||||
information. (default "")
|
||||
enable_mlir_converter: Enables the MLIR converter instead of the TOCO
|
||||
converter. (default False)
|
||||
Returns:
|
||||
Converted model in serialized form (e.g. a TFLITE model is common).
|
||||
Raises:
|
||||
@ -118,10 +126,12 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
||||
if not _toco_from_proto_bin:
|
||||
try:
|
||||
model_str = wrap_toco.wrapped_toco_convert(model_flags_str,
|
||||
toco_flags_str, input_data_str)
|
||||
toco_flags_str, input_data_str,
|
||||
debug_info_str,
|
||||
enable_mlir_converter)
|
||||
return model_str
|
||||
except Exception as e:
|
||||
raise ConverterError("TOCO failed: %s" % e)
|
||||
raise ConverterError(str(e))
|
||||
|
||||
# Windows and TemporaryFile are not that useful together,
|
||||
# since you cannot have two readers/writers. So we have to
|
||||
@ -132,16 +142,17 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
||||
# Build all input files
|
||||
with _tempfile.NamedTemporaryFile(delete=False) as fp_toco, \
|
||||
_tempfile.NamedTemporaryFile(delete=False) as fp_model, \
|
||||
_tempfile.NamedTemporaryFile(delete=False) as fp_input:
|
||||
_tempfile.NamedTemporaryFile(delete=False) as fp_input, \
|
||||
_tempfile.NamedTemporaryFile(delete=False) as fp_debug:
|
||||
toco_filename = fp_toco.name
|
||||
input_filename = fp_input.name
|
||||
model_filename = fp_model.name
|
||||
debug_filename = fp_debug.name
|
||||
|
||||
fp_model.write(model_flags_str)
|
||||
fp_toco.write(toco_flags_str)
|
||||
fp_input.write(input_data_str)
|
||||
fp_model.flush()
|
||||
fp_toco.flush()
|
||||
fp_input.flush()
|
||||
fp_debug.write(debug_info_str)
|
||||
|
||||
# Reserve an output file
|
||||
with _tempfile.NamedTemporaryFile(delete=False) as fp:
|
||||
@ -149,9 +160,15 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
||||
|
||||
# Run
|
||||
cmd = [
|
||||
_toco_from_proto_bin, model_filename, toco_filename, input_filename,
|
||||
output_filename
|
||||
_toco_from_proto_bin,
|
||||
model_filename,
|
||||
toco_filename,
|
||||
input_filename,
|
||||
output_filename,
|
||||
"--debug_proto_file={}".format(debug_filename),
|
||||
]
|
||||
if enable_mlir_converter:
|
||||
cmd.append("--enable_mlir_converter")
|
||||
cmdline = " ".join(cmd)
|
||||
is_windows = _platform.system() == "Windows"
|
||||
proc = _subprocess.Popen(
|
||||
@ -168,8 +185,7 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
|
||||
else:
|
||||
stdout = _try_convert_to_unicode(stdout)
|
||||
stderr = _try_convert_to_unicode(stderr)
|
||||
raise ConverterError(
|
||||
"TOCO failed. See console for info.\n%s\n%s\n" % (stdout, stderr))
|
||||
raise ConverterError("See console for info.\n%s\n%s\n" % (stdout, stderr))
|
||||
finally:
|
||||
# Must manually cleanup files.
|
||||
for filename in [
|
||||
@ -211,9 +227,9 @@ def build_toco_convert_protos(input_tensors,
|
||||
output_tensors: List of output tensors (only .name is used from this).
|
||||
inference_type: Target data type of real-number arrays in the output file.
|
||||
Must be `{tf.float32, tf.uint8}`. (default tf.float32)
|
||||
Must be `{tf.float32, tf.uint8}`. (default `inference_type`)
|
||||
inference_input_type: Target data type of real-number input arrays. Allows
|
||||
for a different type for input arrays in the case of quantization.
|
||||
Must be `{tf.float32, tf.uint8}`. (default `inference_type`)
|
||||
input_format: Type of data to read Currently must be
|
||||
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
|
||||
input_shapes: Input array shape. It needs to be a list of the same length
|
||||
@ -266,7 +282,7 @@ def build_toco_convert_protos(input_tensors,
|
||||
|
||||
Returns:
|
||||
model_flags, toco_flags, debug_info: three protocol buffers describing the
|
||||
conversion process and debug information.
|
||||
conversion process and debug information.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
@ -330,7 +346,7 @@ def build_toco_convert_protos(input_tensors,
|
||||
|
||||
|
||||
def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
|
||||
*args, **kwargs):
|
||||
enable_mlir_converter, *args, **kwargs):
|
||||
""""Convert a model using TOCO.
|
||||
|
||||
This function is used to convert GraphDefs that cannot be loaded into
|
||||
@ -347,6 +363,8 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
|
||||
output_arrays: List of output tensors to freeze graph with. Use only when
|
||||
graph cannot be loaded into TensorFlow and when `output_tensors` is None.
|
||||
(default None)
|
||||
enable_mlir_converter: Enables the MLIR converter instead of the TOCO
|
||||
converter.
|
||||
*args: See `build_toco_convert_protos`,
|
||||
**kwargs: See `build_toco_convert_protos`.
|
||||
|
||||
@ -375,14 +393,16 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
|
||||
for name in output_arrays:
|
||||
model_flags.output_arrays.append(name)
|
||||
|
||||
data = toco_convert_protos(model_flags.SerializeToString(),
|
||||
toco_flags.SerializeToString(),
|
||||
input_data.SerializeToString())
|
||||
data = toco_convert_protos(
|
||||
model_flags.SerializeToString(),
|
||||
toco_flags.SerializeToString(),
|
||||
input_data.SerializeToString(),
|
||||
enable_mlir_converter=enable_mlir_converter)
|
||||
return data
|
||||
|
||||
|
||||
def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
|
||||
**kwargs):
|
||||
def toco_convert_impl(input_data, input_tensors, output_tensors,
|
||||
enable_mlir_converter, *args, **kwargs):
|
||||
""""Convert a model using TOCO.
|
||||
|
||||
Typically this function is used to convert from TensorFlow GraphDef to TFLite.
|
||||
@ -394,6 +414,8 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
|
||||
input_tensors: List of input tensors. Type and shape are computed using
|
||||
`foo.shape` and `foo.dtype`.
|
||||
output_tensors: List of output tensors (only .name is used from this).
|
||||
enable_mlir_converter: Enables the MLIR converter instead of the TOCO
|
||||
converter.
|
||||
*args: See `build_toco_convert_protos`,
|
||||
**kwargs: See `build_toco_convert_protos`.
|
||||
|
||||
@ -404,11 +426,15 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
|
||||
Raises:
|
||||
Defined in `build_toco_convert_protos`.
|
||||
"""
|
||||
model_flags, toco_flags, _ = build_toco_convert_protos(
|
||||
model_flags, toco_flags, debug_info = build_toco_convert_protos(
|
||||
input_tensors, output_tensors, *args, **kwargs)
|
||||
data = toco_convert_protos(model_flags.SerializeToString(),
|
||||
toco_flags.SerializeToString(),
|
||||
input_data.SerializeToString())
|
||||
debug_info_str = debug_info.SerializeToString() if debug_info else ""
|
||||
data = toco_convert_protos(
|
||||
model_flags.SerializeToString(),
|
||||
toco_flags.SerializeToString(),
|
||||
input_data.SerializeToString(),
|
||||
debug_info_str=debug_info_str,
|
||||
enable_mlir_converter=enable_mlir_converter)
|
||||
return data
|
||||
|
||||
|
||||
@ -437,5 +463,6 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
|
||||
Raises:
|
||||
Defined in `build_toco_convert_protos`.
|
||||
"""
|
||||
return toco_convert_impl(input_data, input_tensors, output_tensors, *args,
|
||||
**kwargs)
|
||||
enable_mlir_converter = kwargs.get("enable_mlir_converter", False)
|
||||
return toco_convert_impl(input_data, input_tensors, output_tensors,
|
||||
enable_mlir_converter, *args, **kwargs)
|
||||
|
@ -90,6 +90,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
||||
|
||||
tflite_model = convert.toco_convert_graph_def(
|
||||
sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
|
||||
enable_mlir_converter=False,
|
||||
inference_type=lite_constants.FLOAT)
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
@ -126,6 +127,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
||||
sess.graph_def,
|
||||
input_arrays_map,
|
||||
output_arrays,
|
||||
enable_mlir_converter=False,
|
||||
inference_type=lite_constants.QUANTIZED_UINT8,
|
||||
quantized_input_stats=[(0., 1.), (0., 1.)])
|
||||
self.assertTrue(tflite_model)
|
||||
@ -171,6 +173,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
|
||||
sess.graph_def,
|
||||
input_arrays_map,
|
||||
output_arrays,
|
||||
enable_mlir_converter=False,
|
||||
inference_type=lite_constants.QUANTIZED_UINT8)
|
||||
self.assertEqual(
|
||||
"std_dev and mean must be defined when inference_input_type is "
|
||||
|
@ -233,6 +233,25 @@ class TFLiteConverterBase(object):
|
||||
self.representative_dataset.input_gen, inference_input_type,
|
||||
inference_output_type, allow_float)
|
||||
|
||||
def _get_base_converter_args(self):
|
||||
"""Returns the base converter args.
|
||||
|
||||
Returns:
|
||||
{key str: val}
|
||||
"""
|
||||
float16_quantize = self._is_float16_quantize()
|
||||
args = {
|
||||
"input_format": constants.TENSORFLOW_GRAPHDEF,
|
||||
"allow_custom_ops": self.allow_custom_ops,
|
||||
"post_training_quantize": (self._is_int8_weight_only_quantize() or
|
||||
float16_quantize),
|
||||
"quantize_to_float16": float16_quantize,
|
||||
"debug_info": self._debug_info,
|
||||
"target_ops": self._target_ops,
|
||||
"enable_mlir_converter": self.experimental_enable_mlir_converter,
|
||||
}
|
||||
return args
|
||||
|
||||
|
||||
@_tf_export("lite.TFLiteConverter", v1=[])
|
||||
class TFLiteConverterV2(TFLiteConverterBase):
|
||||
@ -251,6 +270,8 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
||||
representative_dataset: A representative dataset that can be used to
|
||||
generate input and output samples for the model. The converter can use the
|
||||
dataset to evaluate different optimizations.
|
||||
experimental_enable_mlir_converter: Experimental flag, subject to change.
|
||||
Enables the MLIR converter instead of the TOCO converter.
|
||||
|
||||
Example usage:
|
||||
|
||||
@ -287,6 +308,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
||||
self.allow_custom_ops = False
|
||||
self.target_spec = TargetSpec()
|
||||
self._debug_info = None
|
||||
self.experimental_enable_mlir_converter = False
|
||||
|
||||
@classmethod
|
||||
def from_concrete_functions(cls, funcs):
|
||||
@ -414,23 +436,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
||||
self._validate_representative_dataset()
|
||||
self._debug_info = _get_debug_info(
|
||||
_build_debug_info_func(self._funcs[0].graph), graph_def)
|
||||
|
||||
float16_quantize = self._is_float16_quantize()
|
||||
|
||||
converter_kwargs = {
|
||||
"input_format":
|
||||
constants.TENSORFLOW_GRAPHDEF,
|
||||
"allow_custom_ops":
|
||||
self.allow_custom_ops,
|
||||
"post_training_quantize":
|
||||
self._is_int8_weight_only_quantize() or float16_quantize,
|
||||
"quantize_to_float16":
|
||||
float16_quantize,
|
||||
"target_ops":
|
||||
self.target_spec.supported_ops,
|
||||
"debug_info":
|
||||
self._debug_info
|
||||
}
|
||||
converter_kwargs = self._get_base_converter_args()
|
||||
|
||||
# Converts model.
|
||||
result = _toco_convert_impl(
|
||||
@ -522,6 +528,8 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
representative_dataset: A representative dataset that can be used to
|
||||
generate input and output samples for the model. The converter can use
|
||||
the dataset to evaluate different optimizations.
|
||||
experimental_enable_mlir_converter: Experimental flag, subject to change.
|
||||
Enables the MLIR converter instead of the TOCO converter.
|
||||
|
||||
Example usage:
|
||||
|
||||
@ -597,6 +605,7 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
self.target_spec = TargetSpec()
|
||||
self._debug_info_func = experimental_debug_info_func
|
||||
self._debug_info = None
|
||||
self.experimental_enable_mlir_converter = False
|
||||
|
||||
# Attributes are used by models that cannot be loaded into TensorFlow.
|
||||
if not self._has_valid_tensors():
|
||||
@ -939,31 +948,11 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
"Provide an inference_input_type and inference_output_type of type "
|
||||
"tf.float32.")
|
||||
|
||||
float16_quantize = self._is_float16_quantize()
|
||||
|
||||
if not post_training_optimize and self.inference_output_type is not None:
|
||||
raise ValueError(
|
||||
"inference_output_type is currently not supported if optimizations "
|
||||
"are not enabled.")
|
||||
|
||||
converter_kwargs = {
|
||||
"inference_type": self.inference_type,
|
||||
"inference_input_type": toco_inference_input_type,
|
||||
"input_format": constants.TENSORFLOW_GRAPHDEF,
|
||||
"output_format": self.output_format,
|
||||
"quantized_input_stats": quantized_stats,
|
||||
"default_ranges_stats": self.default_ranges_stats,
|
||||
"drop_control_dependency": self.drop_control_dependency,
|
||||
"reorder_across_fake_quant": self.reorder_across_fake_quant,
|
||||
"change_concat_input_ranges": self.change_concat_input_ranges,
|
||||
"allow_custom_ops": self.allow_custom_ops,
|
||||
"post_training_quantize": weight_only_quantize or float16_quantize,
|
||||
"quantize_to_float16": float16_quantize,
|
||||
"target_ops": self._target_ops,
|
||||
"dump_graphviz_dir": self.dump_graphviz_dir,
|
||||
"dump_graphviz_video": self.dump_graphviz_video
|
||||
}
|
||||
|
||||
optimized_graph = self._graph_def
|
||||
if self.inference_type != constants.QUANTIZED_UINT8:
|
||||
try:
|
||||
@ -977,13 +966,26 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
|
||||
self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)
|
||||
|
||||
converter_kwargs = self._get_base_converter_args()
|
||||
converter_kwargs.update({
|
||||
"inference_type": self.inference_type,
|
||||
"inference_input_type": toco_inference_input_type,
|
||||
"output_format": self.output_format,
|
||||
"quantized_input_stats": quantized_stats,
|
||||
"default_ranges_stats": self.default_ranges_stats,
|
||||
"drop_control_dependency": self.drop_control_dependency,
|
||||
"reorder_across_fake_quant": self.reorder_across_fake_quant,
|
||||
"change_concat_input_ranges": self.change_concat_input_ranges,
|
||||
"dump_graphviz_dir": self.dump_graphviz_dir,
|
||||
"dump_graphviz_video": self.dump_graphviz_video
|
||||
})
|
||||
|
||||
# Converts model.
|
||||
if self._has_valid_tensors():
|
||||
result = _toco_convert_impl(
|
||||
input_data=optimized_graph,
|
||||
input_tensors=self._input_tensors,
|
||||
output_tensors=self._output_tensors,
|
||||
debug_info=self._debug_info,
|
||||
**converter_kwargs)
|
||||
else:
|
||||
result = _toco_convert_graph_def(
|
||||
|
@ -29,10 +29,16 @@ _toco_python = LazyLoader(
|
||||
del LazyLoader
|
||||
|
||||
|
||||
def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str):
|
||||
def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str,
|
||||
debug_info_str, enable_mlir_converter):
|
||||
"""Wraps TocoConvert with lazy loader."""
|
||||
return _toco_python.TocoConvert(model_flags_str, toco_flags_str,
|
||||
input_data_str)
|
||||
return _toco_python.TocoConvert(
|
||||
model_flags_str,
|
||||
toco_flags_str,
|
||||
input_data_str,
|
||||
False, # extended_return
|
||||
debug_info_str,
|
||||
enable_mlir_converter)
|
||||
|
||||
|
||||
def wrapped_get_potentially_supported_ops():
|
||||
|
@ -22,9 +22,11 @@ cc_library(
|
||||
name = "toco_python_api",
|
||||
srcs = ["toco_python_api.cc"],
|
||||
hdrs = ["toco_python_api.h"],
|
||||
features = ["no_layering_check"],
|
||||
deps = [
|
||||
"//third_party/python_runtime:headers",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/lite/python/interpreter_wrapper:python_utils",
|
||||
"//tensorflow/lite/toco:model_flags_proto_cc",
|
||||
"//tensorflow/lite/toco:toco_flags_proto_cc",
|
||||
|
@ -26,13 +26,18 @@ namespace toco {
|
||||
// parameters (see relevant .protos for more information). Returns a string
|
||||
// representing the contents of the converted model. When extended_return
|
||||
// flag is set to true returns a dictionary that contains string representation
|
||||
// of the converted model and some statitics like arithmetic ops count.
|
||||
// of the converted model and some statistics like arithmetic ops count.
|
||||
// `debug_info_str` contains the `GraphDebugInfo` proto. When
|
||||
// `enable_mlir_converter` is True, the MLIR converter is used instead of the
|
||||
// TOCO converter.
|
||||
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||
PyObject* toco_flags_proto_txt_raw,
|
||||
PyObject* input_contents_txt_raw,
|
||||
bool extended_return = false);
|
||||
bool extended_return = false,
|
||||
PyObject* debug_info_txt_raw = nullptr,
|
||||
bool enable_mlir_converter = false);
|
||||
|
||||
// Returns a list of names of all ops potentially supported by tflite.
|
||||
PyObject* TocoGetPotentiallySupportedOps();
|
||||
|
||||
} // namespace toco
|
||||
} // namespace toco
|
||||
|
@ -26,11 +26,30 @@ FLAGS = None
|
||||
|
||||
|
||||
def execute(unused_args):
|
||||
model_str = open(FLAGS.model_proto_file, "rb").read()
|
||||
toco_str = open(FLAGS.toco_proto_file, "rb").read()
|
||||
input_str = open(FLAGS.model_input_file, "rb").read()
|
||||
"""Runs the converter."""
|
||||
with open(FLAGS.model_proto_file, "rb") as model_file:
|
||||
model_str = model_file.read()
|
||||
|
||||
output_str = tensorflow_wrap_toco.TocoConvert(model_str, toco_str, input_str)
|
||||
with open(FLAGS.toco_proto_file, "rb") as toco_file:
|
||||
toco_str = toco_file.read()
|
||||
|
||||
with open(FLAGS.model_input_file, "rb") as input_file:
|
||||
input_str = input_file.read()
|
||||
|
||||
debug_info_str = ""
|
||||
if FLAGS.debug_proto_file:
|
||||
with open(FLAGS.debug_proto_file, "rb") as debug_info_file:
|
||||
debug_info_str = debug_info_file.read()
|
||||
|
||||
enable_mlir_converter = FLAGS.enable_mlir_converter
|
||||
|
||||
output_str = tensorflow_wrap_toco.TocoConvert(
|
||||
model_str,
|
||||
toco_str,
|
||||
input_str,
|
||||
False, # extended_return
|
||||
debug_info_str,
|
||||
enable_mlir_converter)
|
||||
open(FLAGS.model_output_file, "wb").write(output_str)
|
||||
sys.exit(0)
|
||||
|
||||
@ -53,6 +72,17 @@ def main():
|
||||
"model_output_file",
|
||||
type=str,
|
||||
help="Result of applying TOCO conversion is written here.")
|
||||
parser.add_argument(
|
||||
"--debug_proto_file",
|
||||
type=str,
|
||||
default="",
|
||||
help=("File containing serialized `GraphDebugInfo` proto that describes "
|
||||
"logging information."))
|
||||
parser.add_argument(
|
||||
"--enable_mlir_converter",
|
||||
action="store_true",
|
||||
help=("Boolean indiciating whether to enable the MLIR converter instead "
|
||||
"of TOCO converter. (default False)"))
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
|
||||
|
@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/toco/python/toco_python_api.h"
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -20,20 +22,27 @@ limitations under the License.
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||
#include "tensorflow/lite/toco/import_tensorflow.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/python/toco_python_api.h"
|
||||
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
|
||||
#include "tensorflow/lite/toco/toco_port.h"
|
||||
#include "tensorflow/lite/toco/toco_tooling.h"
|
||||
#include "tensorflow/lite/toco/toco_types.h"
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
|
||||
#else
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#endif
|
||||
|
||||
namespace toco {
|
||||
|
||||
// NOTE(aselle): We are using raw PyObject's here because we want to make
|
||||
// sure we input and output bytes rather than unicode strings for Python3.
|
||||
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||
PyObject* toco_flags_proto_txt_raw,
|
||||
PyObject* input_contents_txt_raw, bool extended_return) {
|
||||
PyObject* input_contents_txt_raw, bool extended_return,
|
||||
PyObject* debug_info_txt_raw,
|
||||
bool enable_mlir_converter) {
|
||||
// Use Python C API to validate and convert arguments. In py3 (bytes),
|
||||
// in py2 (str).
|
||||
auto ConvertArg = [&](PyObject* obj, bool* error) {
|
||||
@ -70,12 +79,35 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||
// Use TOCO to produce new outputs.
|
||||
toco::ModelFlags model_flags;
|
||||
if (!model_flags.ParseFromString(model_flags_proto_txt)) {
|
||||
PyErr_SetString(PyExc_ValueError, "Model proto failed to parse.");
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Failed to convert Model to Python String.");
|
||||
return nullptr;
|
||||
}
|
||||
toco::TocoFlags toco_flags;
|
||||
if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
|
||||
PyErr_SetString(PyExc_ValueError, "Toco proto failed to parse.");
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Failed to convert Toco to Python String.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::GraphDebugInfo debug_info;
|
||||
if (debug_info_txt_raw) {
|
||||
std::string debug_info_txt = ConvertArg(debug_info_txt_raw, &error);
|
||||
if (error) {
|
||||
PyErr_SetString(PyExc_ValueError, "Input DebugInfo is invalid.");
|
||||
return nullptr;
|
||||
}
|
||||
if (!debug_info.ParseFromString(debug_info_txt)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Failed to convert DebugInfo to Python String.");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::GraphDef graph_def;
|
||||
if (!graph_def.ParseFromString(input_contents_txt)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Failed to convert GraphDef to Python String.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -87,18 +119,36 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||
dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video();
|
||||
}
|
||||
|
||||
// Convert model.
|
||||
std::unique_ptr<toco::Model> model =
|
||||
toco::Import(toco_flags, model_flags, input_contents_txt);
|
||||
toco::Transform(toco_flags, model.get());
|
||||
string output_file_contents_txt;
|
||||
auto status = Export(toco_flags, *model, toco_flags.allow_custom_ops(),
|
||||
&output_file_contents_txt);
|
||||
tensorflow::Status status;
|
||||
std::unique_ptr<toco::Model> model;
|
||||
|
||||
// Convert model.
|
||||
if (enable_mlir_converter) {
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
|
||||
model_flags, toco_flags, debug_info, graph_def,
|
||||
&output_file_contents_txt);
|
||||
#else
|
||||
// TODO(b/124314620): Remove this condition.
|
||||
PyErr_SetString(PyExc_Exception,
|
||||
"This flag is not supported by this version of the "
|
||||
"TFLite converter. This functionality is being "
|
||||
"actively worked on.");
|
||||
return nullptr;
|
||||
#endif
|
||||
} else {
|
||||
model = toco::Import(toco_flags, model_flags, input_contents_txt);
|
||||
toco::Transform(toco_flags, model.get());
|
||||
status = Export(toco_flags, *model, toco_flags.allow_custom_ops(),
|
||||
&output_file_contents_txt);
|
||||
}
|
||||
|
||||
if (!status.ok()) {
|
||||
PyErr_SetString(PyExc_Exception, status.error_message().c_str());
|
||||
return nullptr;
|
||||
}
|
||||
if (extended_return) {
|
||||
if (extended_return && !enable_mlir_converter) {
|
||||
PyObject* dict = PyDict_New();
|
||||
PyDict_SetItemString(
|
||||
dict, "flatbuffer",
|
||||
|
@ -25,11 +25,16 @@ namespace toco {
|
||||
// parameters (see relevant .protos for more information). Returns a string
|
||||
// representing the contents of the converted model. When extended_return
|
||||
// flag is set to true returns a dictionary that contains string representation
|
||||
// of the converted model and some statitics like arithmetic ops count.
|
||||
// of the converted model and some statistics like arithmetic ops count.
|
||||
// `debug_info_str` contains the `GraphDebugInfo` proto. When
|
||||
// `enable_mlir_converter` is True, the MLIR converter is used instead of the
|
||||
// TOCO converter.
|
||||
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||
PyObject* toco_flags_proto_txt_raw,
|
||||
PyObject* input_contents_txt_raw,
|
||||
bool extended_return = false);
|
||||
bool extended_return = false,
|
||||
PyObject* debug_info_txt_raw = nullptr,
|
||||
bool enable_mlir_converter = false);
|
||||
|
||||
// Returns a list of names of all ops potentially supported by tflite.
|
||||
PyObject* TocoGetPotentiallySupportedOps();
|
||||
|
Loading…
Reference in New Issue
Block a user