Utility for converting TF Lite files to embedded C data arrays

PiperOrigin-RevId: 277200994
Change-Id: I152594d72e47a30fb6aefeec541dc96181453295
This commit is contained in:
Pete Warden 2019-10-28 20:38:57 -07:00 committed by TensorFlower Gardener
parent c7a4062cc2
commit ba7ff487f0
5 changed files with 363 additions and 2 deletions

View File

@ -1,5 +1,3 @@
load("//tensorflow:tensorflow.bzl", "py_test")
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
@ -309,3 +307,22 @@ py_test(
"//tensorflow/python/saved_model",
],
)
py_binary(
name = "convert_file_to_c_source",
srcs = ["convert_file_to_c_source.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":lite",
":util",
"@six_archive//:six",
],
)
sh_test(
name = "convert_file_to_c_source_test",
srcs = ["convert_file_to_c_source_test.sh"],
data = [":convert_file_to_c_source"],
)

View File

@ -0,0 +1,106 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python command line interface for converting TF Lite files into C source."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from tensorflow.lite.python import util
from tensorflow.python.platform import app
def run_main(_):
"""Main in convert_file_to_c_source.py."""
parser = argparse.ArgumentParser(
description=("Command line tool to run TensorFlow Lite Converter."))
parser.add_argument(
"--input_tflite_file",
type=str,
help="Full filepath of the input TensorFlow Lite file.",
required=True)
parser.add_argument(
"--output_source_file",
type=str,
help="Full filepath of the output C source file.",
required=True)
parser.add_argument(
"--output_header_file",
type=str,
help="Full filepath of the output C header file.",
required=True)
parser.add_argument(
"--array_variable_name",
type=str,
help="Name to use for the C data array variable.",
required=True)
parser.add_argument(
"--line_width", type=int, help="Width to use for formatting.", default=80)
parser.add_argument(
"--include_guard",
type=str,
help="Name to use for the C header include guard.",
default=None)
parser.add_argument(
"--include_path",
type=str,
help="Optional path to include in generated source file.",
default=None)
parser.add_argument(
"--use_tensorflow_license",
dest="use_tensorflow_license",
help="Whether to prefix the generated files with the TF Apache2 license.",
action="store_true")
parser.set_defaults(use_tensorflow_license=False)
flags, _ = parser.parse_known_args(args=sys.argv[1:])
with open(flags.input_tflite_file, "rb") as input_handle:
input_data = input_handle.read()
source, header = util.convert_bytes_to_c_source(
data=input_data,
array_name=flags.array_variable_name,
max_line_width=flags.line_width,
include_guard=flags.include_guard,
include_path=flags.include_path,
use_tensorflow_license=flags.use_tensorflow_license)
with open(flags.output_source_file, "w") as source_handle:
source_handle.write(source)
with open(flags.output_header_file, "w") as header_handle:
header_handle.write(header)
def main():
app.run(main=run_main, argv=sys.argv[:1])
if __name__ == "__main__":
main()

View File

@ -0,0 +1,87 @@
#!/bin/bash
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Bash unit tests for the TensorFlow Lite Micro project generator.
set -e
INPUT_FILE=${TEST_TMPDIR}/input.tflite
printf "\x00\x01\x02\x03" > ${INPUT_FILE}
OUTPUT_SOURCE_FILE=${TEST_TMPDIR}/output_source.cc
OUTPUT_HEADER_FILE=${TEST_TMPDIR}/output_header.h
# Needed for copybara compatibility.
SCRIPT_BASE_DIR=/org_"tensor"flow
${TEST_SRCDIR}${SCRIPT_BASE_DIR}/tensorflow/lite/python/convert_file_to_c_source \
--input_tflite_file="${INPUT_FILE}" \
--output_source_file="${OUTPUT_SOURCE_FILE}" \
--output_header_file="${OUTPUT_HEADER_FILE}" \
--array_variable_name="g_some_array" \
--line_width=80 \
--include_guard="SOME_GUARD_H_" \
--include_path="some/guard.h" \
--use_tensorflow_license
if ! grep -q 'const unsigned char g_some_array' ${OUTPUT_SOURCE_FILE}; then
echo "ERROR: No array found in output '${OUTPUT_SOURCE_FILE}'"
exit 1
fi
if ! grep -q '0x00, 0x01, 0x02, 0x03' ${OUTPUT_SOURCE_FILE}; then
echo "ERROR: No array values found in output '${OUTPUT_SOURCE_FILE}'"
exit 1
fi
if ! grep -q 'const int g_some_array_len = 4;' ${OUTPUT_SOURCE_FILE}; then
echo "ERROR: No array length found in output '${OUTPUT_SOURCE_FILE}'"
exit 1
fi
if ! grep -q 'The TensorFlow Authors. All Rights Reserved' ${OUTPUT_SOURCE_FILE}; then
echo "ERROR: No license found in output '${OUTPUT_SOURCE_FILE}'"
exit 1
fi
if ! grep -q '\#include "some/guard\.h"' ${OUTPUT_SOURCE_FILE}; then
echo "ERROR: No include found in output '${OUTPUT_SOURCE_FILE}'"
exit 1
fi
if ! grep -q '#ifndef SOME_GUARD_H_' ${OUTPUT_HEADER_FILE}; then
echo "ERROR: No include guard found in output '${OUTPUT_HEADER_FILE}'"
exit 1
fi
if ! grep -q 'extern const unsigned char g_some_array' ${OUTPUT_HEADER_FILE}; then
echo "ERROR: No array found in output '${OUTPUT_HEADER_FILE}'"
exit 1
fi
if ! grep -q 'extern const int g_some_array_len;' ${OUTPUT_HEADER_FILE}; then
echo "ERROR: No array length found in output '${OUTPUT_HEADER_FILE}'"
exit 1
fi
if ! grep -q 'The TensorFlow Authors. All Rights Reserved' ${OUTPUT_HEADER_FILE}; then
echo "ERROR: No license found in output '${OUTPUT_HEADER_FILE}'"
exit 1
fi
echo
echo "SUCCESS: convert_file_to_c_source test PASSED"

View File

@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import datetime
import sys
import six
@ -354,3 +355,123 @@ def get_debug_info(nodes_to_debug_info_func, converted_graph):
# Convert the nodes to the debug info proto object.
return nodes_to_debug_info_func(original_nodes)
def convert_bytes_to_c_source(data,
array_name,
max_line_width=80,
include_guard=None,
include_path=None,
use_tensorflow_license=False):
"""Returns strings representing a C constant array containing `data`.
Args:
data: Byte array that will be converted into a C constant.
array_name: String to use as the variable name for the constant array.
max_line_width: The longest line length, for formatting purposes.
include_guard: Name to use for the include guard macro definition.
include_path: Optional path to include in the source file.
use_tensorflow_license: Whether to include the standard TensorFlow Apache2
license in the generated files.
Returns:
Text that can be compiled as a C source file to link in the data as a
literal array of values.
Text that can be used as a C header file to reference the literal array.
"""
starting_pad = " "
array_lines = []
array_line = starting_pad
for value in bytearray(data):
if (len(array_line) + 4) > max_line_width:
array_lines.append(array_line + "\n")
array_line = starting_pad
array_line += " 0x%02x," % (value)
if len(array_line) > len(starting_pad):
array_lines.append(array_line + "\n")
array_values = "".join(array_lines)
if include_guard is None:
include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_"
if include_path is not None:
include_line = "#include \"{include_path}\"\n".format(
include_path=include_path)
else:
include_line = ""
if use_tensorflow_license:
license_text = """
/* Copyright {year} The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
""".format(year=datetime.date.today().year)
else:
license_text = ""
source_template = """{license_text}
// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.
{include_line}
// We need to keep the data array aligned on some architectures.
#ifdef __has_attribute
#define HAVE_ATTRIBUTE(x) __has_attribute(x)
#else
#define HAVE_ATTRIBUTE(x) 0
#endif
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
#else
#define DATA_ALIGN_ATTRIBUTE
#endif
const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{
{array_values}}};
const int {array_name}_len = {array_length};
"""
source_text = source_template.format(
array_name=array_name,
array_length=len(data),
array_values=array_values,
license_text=license_text,
include_line=include_line)
header_template = """
{license_text}
// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.
#ifndef {include_guard}
#define {include_guard}
extern const unsigned char {array_name}[];
extern const int {array_name}_len;
#endif // {include_guard}
"""
header_text = header_template.format(
array_name=array_name,
include_guard=include_guard,
license_text=license_text)
return source_text, header_text

View File

@ -90,6 +90,36 @@ class UtilTest(test_util.TensorFlowTestCase):
lower_using_switch_merge_is_removed = True
self.assertEqual(lower_using_switch_merge_is_removed, True)
def testConvertBytes(self):
source, header = util.convert_bytes_to_c_source(
b"\x00\x01\x02\x23", "foo", 16, use_tensorflow_license=False)
self.assertTrue(
source.find("const unsigned char foo[] DATA_ALIGN_ATTRIBUTE = {"))
self.assertTrue(source.find(""" 0x00, 0x01,
0x02, 0x23,"""))
self.assertNotEqual(-1, source.find("const int foo_len = 4;"))
self.assertEqual(-1, source.find("/* Copyright"))
self.assertEqual(-1, source.find("#include " ""))
self.assertNotEqual(-1, header.find("extern const unsigned char foo[];"))
self.assertNotEqual(-1, header.find("extern const int foo_len;"))
self.assertEqual(-1, header.find("/* Copyright"))
source, header = util.convert_bytes_to_c_source(
b"\xff\xfe\xfd\xfc",
"bar",
80,
include_guard="MY_GUARD",
include_path="my/guard.h",
use_tensorflow_license=True)
self.assertNotEqual(
-1, source.find("const unsigned char bar[] DATA_ALIGN_ATTRIBUTE = {"))
self.assertNotEqual(-1, source.find(""" 0xff, 0xfe, 0xfd, 0xfc,"""))
self.assertNotEqual(-1, source.find("/* Copyright"))
self.assertNotEqual(-1, source.find("#include \"my/guard.h\""))
self.assertNotEqual(-1, header.find("#ifndef MY_GUARD"))
self.assertNotEqual(-1, header.find("#define MY_GUARD"))
self.assertNotEqual(-1, header.find("/* Copyright"))
class TensorFunctionsTest(test_util.TensorFlowTestCase):