diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index b4e0dd0c3eb..4a2fc7ba12a 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -1,5 +1,3 @@ -load("//tensorflow:tensorflow.bzl", "py_test") - package( default_visibility = ["//tensorflow:internal"], licenses = ["notice"], # Apache 2.0 @@ -309,3 +307,22 @@ py_test( "//tensorflow/python/saved_model", ], ) + +py_binary( + name = "convert_file_to_c_source", + srcs = ["convert_file_to_c_source.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":lite", + ":util", + "@six_archive//:six", + ], +) + +sh_test( + name = "convert_file_to_c_source_test", + srcs = ["convert_file_to_c_source_test.sh"], + data = [":convert_file_to_c_source"], +) diff --git a/tensorflow/lite/python/convert_file_to_c_source.py b/tensorflow/lite/python/convert_file_to_c_source.py new file mode 100644 index 00000000000..c967f812f60 --- /dev/null +++ b/tensorflow/lite/python/convert_file_to_c_source.py @@ -0,0 +1,106 @@ +# Lint as: python2, python3 +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Python command line interface for converting TF Lite files into C source.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +from tensorflow.lite.python import util +from tensorflow.python.platform import app + + +def run_main(_): + """Main in convert_file_to_c_source.py.""" + + parser = argparse.ArgumentParser( + description=("Command line tool to run TensorFlow Lite Converter.")) + + parser.add_argument( + "--input_tflite_file", + type=str, + help="Full filepath of the input TensorFlow Lite file.", + required=True) + + parser.add_argument( + "--output_source_file", + type=str, + help="Full filepath of the output C source file.", + required=True) + + parser.add_argument( + "--output_header_file", + type=str, + help="Full filepath of the output C header file.", + required=True) + + parser.add_argument( + "--array_variable_name", + type=str, + help="Name to use for the C data array variable.", + required=True) + + parser.add_argument( + "--line_width", type=int, help="Width to use for formatting.", default=80) + + parser.add_argument( + "--include_guard", + type=str, + help="Name to use for the C header include guard.", + default=None) + + parser.add_argument( + "--include_path", + type=str, + help="Optional path to include in generated source file.", + default=None) + + parser.add_argument( + "--use_tensorflow_license", + dest="use_tensorflow_license", + help="Whether to prefix the generated files with the TF Apache2 license.", + action="store_true") + parser.set_defaults(use_tensorflow_license=False) + + flags, _ = parser.parse_known_args(args=sys.argv[1:]) + + with open(flags.input_tflite_file, "rb") as input_handle: + input_data = input_handle.read() + + source, header = util.convert_bytes_to_c_source( + data=input_data, + array_name=flags.array_variable_name, + 
max_line_width=flags.line_width, + include_guard=flags.include_guard, + include_path=flags.include_path, + use_tensorflow_license=flags.use_tensorflow_license) + + with open(flags.output_source_file, "w") as source_handle: + source_handle.write(source) + + with open(flags.output_header_file, "w") as header_handle: + header_handle.write(header) + + +def main(): + app.run(main=run_main, argv=sys.argv[:1]) + + +if __name__ == "__main__": + main() diff --git a/tensorflow/lite/python/convert_file_to_c_source_test.sh b/tensorflow/lite/python/convert_file_to_c_source_test.sh new file mode 100755 index 00000000000..1c738008a57 --- /dev/null +++ b/tensorflow/lite/python/convert_file_to_c_source_test.sh @@ -0,0 +1,87 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Bash unit tests for the TensorFlow Lite Micro project generator. + +set -e + +INPUT_FILE=${TEST_TMPDIR}/input.tflite +printf "\x00\x01\x02\x03" > ${INPUT_FILE} + +OUTPUT_SOURCE_FILE=${TEST_TMPDIR}/output_source.cc +OUTPUT_HEADER_FILE=${TEST_TMPDIR}/output_header.h + +# Needed for copybara compatibility. 
#!/bin/bash
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Bash unit tests for the convert_file_to_c_source command line tool.

set -e

# Fails the test unless the literal string $1 occurs in file $2.
# $3 names what was expected and is used in the error message.
# -F makes grep match the text literally, so characters such as '.', '#' and
# '[' need no escaping.
function expect_in_file() {
  local expected_text="$1"
  local output_file="$2"
  local description="$3"
  if ! grep -qF -- "${expected_text}" "${output_file}"; then
    echo "ERROR: No ${description} found in output '${output_file}'"
    exit 1
  fi
}

# A tiny four-byte stand-in for a model file. Paths are quoted so the test
# still works when the test directories contain whitespace.
INPUT_FILE="${TEST_TMPDIR}/input.tflite"
printf "\x00\x01\x02\x03" > "${INPUT_FILE}"

OUTPUT_SOURCE_FILE="${TEST_TMPDIR}/output_source.cc"
OUTPUT_HEADER_FILE="${TEST_TMPDIR}/output_header.h"

# Needed for copybara compatibility.
SCRIPT_BASE_DIR=/org_"tensor"flow
"${TEST_SRCDIR}${SCRIPT_BASE_DIR}/tensorflow/lite/python/convert_file_to_c_source" \
  --input_tflite_file="${INPUT_FILE}" \
  --output_source_file="${OUTPUT_SOURCE_FILE}" \
  --output_header_file="${OUTPUT_HEADER_FILE}" \
  --array_variable_name="g_some_array" \
  --line_width=80 \
  --include_guard="SOME_GUARD_H_" \
  --include_path="some/guard.h" \
  --use_tensorflow_license

# Checks on the generated C source file.
expect_in_file 'const unsigned char g_some_array' "${OUTPUT_SOURCE_FILE}" "array"
expect_in_file '0x00, 0x01, 0x02, 0x03' "${OUTPUT_SOURCE_FILE}" "array values"
expect_in_file 'const int g_some_array_len = 4;' "${OUTPUT_SOURCE_FILE}" "array length"
expect_in_file 'The TensorFlow Authors. All Rights Reserved' "${OUTPUT_SOURCE_FILE}" "license"
expect_in_file '#include "some/guard.h"' "${OUTPUT_SOURCE_FILE}" "include"

# Checks on the generated C header file.
expect_in_file '#ifndef SOME_GUARD_H_' "${OUTPUT_HEADER_FILE}" "include guard"
expect_in_file 'extern const unsigned char g_some_array' "${OUTPUT_HEADER_FILE}" "array"
expect_in_file 'extern const int g_some_array_len;' "${OUTPUT_HEADER_FILE}" "array length"
expect_in_file 'The TensorFlow Authors. All Rights Reserved' "${OUTPUT_HEADER_FILE}" "license"

echo
echo "SUCCESS: convert_file_to_c_source test PASSED"
def convert_bytes_to_c_source(data,
                              array_name,
                              max_line_width=80,
                              include_guard=None,
                              include_path=None,
                              use_tensorflow_license=False):
  """Returns strings representing a C constant array containing `data`.

  Args:
    data: Byte array that will be converted into a C constant.
    array_name: String to use as the variable name for the constant array.
    max_line_width: The longest line length, for formatting purposes. Lines of
      array values never exceed this width, except that every line carries at
      least one value even when the width is too small to hold it.
    include_guard: Name to use for the include guard macro definition. When
      None, a name is derived from the upper-cased `array_name`.
    include_path: Optional path to include in the source file.
    use_tensorflow_license: Whether to include the standard TensorFlow Apache2
      license in the generated files.

  Returns:
    A `(source_text, header_text)` tuple:
    Text that can be compiled as a C source file to link in the data as a
    literal array of values.
    Text that can be used as a C header file to reference the literal array.
  """
  # NOTE(review): whitespace inside literals was collapsed in the reviewed
  # text; a three-space pad (four columns of indent together with the leading
  # space of each value) is assumed here -- confirm against the original.
  starting_pad = "   "

  # Each byte is rendered as " 0xNN,", which is six characters wide. The line
  # budget has to use that real width; budgeting fewer columns per value lets
  # lines run past `max_line_width`.
  value_texts = [" 0x%02x," % value for value in bytearray(data)]
  value_width = len(" 0x00,")
  # Always place at least one value per line so a very small width cannot
  # produce lines that contain only padding.
  values_per_line = max(1, (max_line_width - len(starting_pad)) // value_width)
  array_values = "".join(
      starting_pad + "".join(value_texts[start:start + values_per_line]) + "\n"
      for start in range(0, len(value_texts), values_per_line))

  if include_guard is None:
    include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_"

  if include_path is not None:
    include_line = "#include \"{include_path}\"\n".format(
        include_path=include_path)
  else:
    include_line = ""

  if use_tensorflow_license:
    # Imported here so this function does not depend on a module-level import.
    import datetime  # pylint: disable=g-import-not-at-top

    license_text = """
/* Copyright {year} The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
""".format(year=datetime.date.today().year)
  else:
    license_text = ""

  # Doubled braces are literal braces in str.format templates.
  source_template = """{license_text}
// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

{include_line}
// We need to keep the data array aligned on some architectures.
#ifdef __has_attribute
#define HAVE_ATTRIBUTE(x) __has_attribute(x)
#else
#define HAVE_ATTRIBUTE(x) 0
#endif
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
#else
#define DATA_ALIGN_ATTRIBUTE
#endif

const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{
{array_values}}};
const int {array_name}_len = {array_length};
"""

  source_text = source_template.format(
      array_name=array_name,
      array_length=len(data),
      array_values=array_values,
      license_text=license_text,
      include_line=include_line)

  header_template = """
{license_text}

// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

#ifndef {include_guard}
#define {include_guard}

extern const unsigned char {array_name}[];
extern const int {array_name}_len;

#endif  // {include_guard}
"""

  header_text = header_template.format(
      array_name=array_name,
      include_guard=include_guard,
      license_text=license_text)

  return source_text, header_text
def testConvertBytes(self):
  """Checks the source and header text from convert_bytes_to_c_source()."""
  # Membership assertions are used throughout. Asserting on the truthiness of
  # str.find() is wrong: it returns -1 (truthy) on a miss and 0 (falsy) when
  # the match is at the very start.
  source, header = util.convert_bytes_to_c_source(
      b"\x00\x01\x02\x23", "foo", 16, use_tensorflow_license=False)
  self.assertIn("const unsigned char foo[] DATA_ALIGN_ATTRIBUTE = {", source)
  # A width of 16 leaves room for two values per line, so the four bytes must
  # wrap onto two lines. The fragments avoid depending on the indent width.
  self.assertIn(" 0x00, 0x01,\n", source)
  self.assertIn(" 0x02, 0x23,\n", source)
  self.assertIn("const int foo_len = 4;", source)
  self.assertNotIn("/* Copyright", source)
  self.assertNotIn("#include ", source)
  self.assertIn("extern const unsigned char foo[];", header)
  self.assertIn("extern const int foo_len;", header)
  self.assertNotIn("/* Copyright", header)

  source, header = util.convert_bytes_to_c_source(
      b"\xff\xfe\xfd\xfc",
      "bar",
      80,
      include_guard="MY_GUARD",
      include_path="my/guard.h",
      use_tensorflow_license=True)
  self.assertIn("const unsigned char bar[] DATA_ALIGN_ATTRIBUTE = {", source)
  self.assertIn(" 0xff, 0xfe, 0xfd, 0xfc,", source)
  self.assertIn("/* Copyright", source)
  self.assertIn("#include \"my/guard.h\"", source)
  self.assertIn("#ifndef MY_GUARD", header)
  self.assertIn("#define MY_GUARD", header)
  self.assertIn("/* Copyright", header)