Utility for converting TF Lite files to embedded C data arrays
PiperOrigin-RevId: 277200994 Change-Id: I152594d72e47a30fb6aefeec541dc96181453295
This commit is contained in:
parent
c7a4062cc2
commit
ba7ff487f0
@ -1,5 +1,3 @@
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
@ -309,3 +307,22 @@ py_test(
|
||||
"//tensorflow/python/saved_model",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "convert_file_to_c_source",
|
||||
srcs = ["convert_file_to_c_source.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":lite",
|
||||
":util",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
sh_test(
|
||||
name = "convert_file_to_c_source_test",
|
||||
srcs = ["convert_file_to_c_source_test.sh"],
|
||||
data = [":convert_file_to_c_source"],
|
||||
)
|
||||
|
106
tensorflow/lite/python/convert_file_to_c_source.py
Normal file
106
tensorflow/lite/python/convert_file_to_c_source.py
Normal file
@ -0,0 +1,106 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Python command line interface for converting TF Lite files into C source."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from tensorflow.lite.python import util
|
||||
from tensorflow.python.platform import app
|
||||
|
||||
|
||||
def run_main(_):
|
||||
"""Main in convert_file_to_c_source.py."""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=("Command line tool to run TensorFlow Lite Converter."))
|
||||
|
||||
parser.add_argument(
|
||||
"--input_tflite_file",
|
||||
type=str,
|
||||
help="Full filepath of the input TensorFlow Lite file.",
|
||||
required=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_source_file",
|
||||
type=str,
|
||||
help="Full filepath of the output C source file.",
|
||||
required=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_header_file",
|
||||
type=str,
|
||||
help="Full filepath of the output C header file.",
|
||||
required=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--array_variable_name",
|
||||
type=str,
|
||||
help="Name to use for the C data array variable.",
|
||||
required=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--line_width", type=int, help="Width to use for formatting.", default=80)
|
||||
|
||||
parser.add_argument(
|
||||
"--include_guard",
|
||||
type=str,
|
||||
help="Name to use for the C header include guard.",
|
||||
default=None)
|
||||
|
||||
parser.add_argument(
|
||||
"--include_path",
|
||||
type=str,
|
||||
help="Optional path to include in generated source file.",
|
||||
default=None)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_tensorflow_license",
|
||||
dest="use_tensorflow_license",
|
||||
help="Whether to prefix the generated files with the TF Apache2 license.",
|
||||
action="store_true")
|
||||
parser.set_defaults(use_tensorflow_license=False)
|
||||
|
||||
flags, _ = parser.parse_known_args(args=sys.argv[1:])
|
||||
|
||||
with open(flags.input_tflite_file, "rb") as input_handle:
|
||||
input_data = input_handle.read()
|
||||
|
||||
source, header = util.convert_bytes_to_c_source(
|
||||
data=input_data,
|
||||
array_name=flags.array_variable_name,
|
||||
max_line_width=flags.line_width,
|
||||
include_guard=flags.include_guard,
|
||||
include_path=flags.include_path,
|
||||
use_tensorflow_license=flags.use_tensorflow_license)
|
||||
|
||||
with open(flags.output_source_file, "w") as source_handle:
|
||||
source_handle.write(source)
|
||||
|
||||
with open(flags.output_header_file, "w") as header_handle:
|
||||
header_handle.write(header)
|
||||
|
||||
|
||||
def main():
|
||||
app.run(main=run_main, argv=sys.argv[:1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
87
tensorflow/lite/python/convert_file_to_c_source_test.sh
Executable file
87
tensorflow/lite/python/convert_file_to_c_source_test.sh
Executable file
@ -0,0 +1,87 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
#
|
||||
# Bash unit tests for the TensorFlow Lite Micro project generator.
|
||||
|
||||
set -e
|
||||
|
||||
INPUT_FILE=${TEST_TMPDIR}/input.tflite
|
||||
printf "\x00\x01\x02\x03" > ${INPUT_FILE}
|
||||
|
||||
OUTPUT_SOURCE_FILE=${TEST_TMPDIR}/output_source.cc
|
||||
OUTPUT_HEADER_FILE=${TEST_TMPDIR}/output_header.h
|
||||
|
||||
# Needed for copybara compatibility.
|
||||
SCRIPT_BASE_DIR=/org_"tensor"flow
|
||||
${TEST_SRCDIR}${SCRIPT_BASE_DIR}/tensorflow/lite/python/convert_file_to_c_source \
|
||||
--input_tflite_file="${INPUT_FILE}" \
|
||||
--output_source_file="${OUTPUT_SOURCE_FILE}" \
|
||||
--output_header_file="${OUTPUT_HEADER_FILE}" \
|
||||
--array_variable_name="g_some_array" \
|
||||
--line_width=80 \
|
||||
--include_guard="SOME_GUARD_H_" \
|
||||
--include_path="some/guard.h" \
|
||||
--use_tensorflow_license
|
||||
|
||||
if ! grep -q 'const unsigned char g_some_array' ${OUTPUT_SOURCE_FILE}; then
|
||||
echo "ERROR: No array found in output '${OUTPUT_SOURCE_FILE}'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! grep -q '0x00, 0x01, 0x02, 0x03' ${OUTPUT_SOURCE_FILE}; then
|
||||
echo "ERROR: No array values found in output '${OUTPUT_SOURCE_FILE}'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! grep -q 'const int g_some_array_len = 4;' ${OUTPUT_SOURCE_FILE}; then
|
||||
echo "ERROR: No array length found in output '${OUTPUT_SOURCE_FILE}'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! grep -q 'The TensorFlow Authors. All Rights Reserved' ${OUTPUT_SOURCE_FILE}; then
|
||||
echo "ERROR: No license found in output '${OUTPUT_SOURCE_FILE}'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! grep -q '\#include "some/guard\.h"' ${OUTPUT_SOURCE_FILE}; then
|
||||
echo "ERROR: No include found in output '${OUTPUT_SOURCE_FILE}'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
if ! grep -q '#ifndef SOME_GUARD_H_' ${OUTPUT_HEADER_FILE}; then
|
||||
echo "ERROR: No include guard found in output '${OUTPUT_HEADER_FILE}'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! grep -q 'extern const unsigned char g_some_array' ${OUTPUT_HEADER_FILE}; then
|
||||
echo "ERROR: No array found in output '${OUTPUT_HEADER_FILE}'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! grep -q 'extern const int g_some_array_len;' ${OUTPUT_HEADER_FILE}; then
|
||||
echo "ERROR: No array length found in output '${OUTPUT_HEADER_FILE}'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! grep -q 'The TensorFlow Authors. All Rights Reserved' ${OUTPUT_HEADER_FILE}; then
|
||||
echo "ERROR: No license found in output '${OUTPUT_HEADER_FILE}'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo
|
||||
echo "SUCCESS: convert_file_to_c_source test PASSED"
|
@ -19,6 +19,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import datetime
|
||||
import sys
|
||||
|
||||
import six
|
||||
@ -354,3 +355,123 @@ def get_debug_info(nodes_to_debug_info_func, converted_graph):
|
||||
|
||||
# Convert the nodes to the debug info proto object.
|
||||
return nodes_to_debug_info_func(original_nodes)
|
||||
|
||||
|
||||
def convert_bytes_to_c_source(data,
|
||||
array_name,
|
||||
max_line_width=80,
|
||||
include_guard=None,
|
||||
include_path=None,
|
||||
use_tensorflow_license=False):
|
||||
"""Returns strings representing a C constant array containing `data`.
|
||||
|
||||
Args:
|
||||
data: Byte array that will be converted into a C constant.
|
||||
array_name: String to use as the variable name for the constant array.
|
||||
max_line_width: The longest line length, for formatting purposes.
|
||||
include_guard: Name to use for the include guard macro definition.
|
||||
include_path: Optional path to include in the source file.
|
||||
use_tensorflow_license: Whether to include the standard TensorFlow Apache2
|
||||
license in the generated files.
|
||||
|
||||
Returns:
|
||||
Text that can be compiled as a C source file to link in the data as a
|
||||
literal array of values.
|
||||
Text that can be used as a C header file to reference the literal array.
|
||||
"""
|
||||
|
||||
starting_pad = " "
|
||||
array_lines = []
|
||||
array_line = starting_pad
|
||||
for value in bytearray(data):
|
||||
if (len(array_line) + 4) > max_line_width:
|
||||
array_lines.append(array_line + "\n")
|
||||
array_line = starting_pad
|
||||
array_line += " 0x%02x," % (value)
|
||||
if len(array_line) > len(starting_pad):
|
||||
array_lines.append(array_line + "\n")
|
||||
array_values = "".join(array_lines)
|
||||
|
||||
if include_guard is None:
|
||||
include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_"
|
||||
|
||||
if include_path is not None:
|
||||
include_line = "#include \"{include_path}\"\n".format(
|
||||
include_path=include_path)
|
||||
else:
|
||||
include_line = ""
|
||||
|
||||
if use_tensorflow_license:
|
||||
license_text = """
|
||||
/* Copyright {year} The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
""".format(year=datetime.date.today().year)
|
||||
else:
|
||||
license_text = ""
|
||||
|
||||
source_template = """{license_text}
|
||||
// This is a TensorFlow Lite model file that has been converted into a C data
|
||||
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
|
||||
// This form is useful for compiling into a binary for devices that don't have a
|
||||
// file system.
|
||||
|
||||
{include_line}
|
||||
// We need to keep the data array aligned on some architectures.
|
||||
#ifdef __has_attribute
|
||||
#define HAVE_ATTRIBUTE(x) __has_attribute(x)
|
||||
#else
|
||||
#define HAVE_ATTRIBUTE(x) 0
|
||||
#endif
|
||||
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
|
||||
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
|
||||
#else
|
||||
#define DATA_ALIGN_ATTRIBUTE
|
||||
#endif
|
||||
|
||||
const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{
|
||||
{array_values}}};
|
||||
const int {array_name}_len = {array_length};
|
||||
"""
|
||||
|
||||
source_text = source_template.format(
|
||||
array_name=array_name,
|
||||
array_length=len(data),
|
||||
array_values=array_values,
|
||||
license_text=license_text,
|
||||
include_line=include_line)
|
||||
|
||||
header_template = """
|
||||
{license_text}
|
||||
|
||||
// This is a TensorFlow Lite model file that has been converted into a C data
|
||||
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
|
||||
// This form is useful for compiling into a binary for devices that don't have a
|
||||
// file system.
|
||||
|
||||
#ifndef {include_guard}
|
||||
#define {include_guard}
|
||||
|
||||
extern const unsigned char {array_name}[];
|
||||
extern const int {array_name}_len;
|
||||
|
||||
#endif // {include_guard}
|
||||
"""
|
||||
|
||||
header_text = header_template.format(
|
||||
array_name=array_name,
|
||||
include_guard=include_guard,
|
||||
license_text=license_text)
|
||||
|
||||
return source_text, header_text
|
||||
|
@ -90,6 +90,36 @@ class UtilTest(test_util.TensorFlowTestCase):
|
||||
lower_using_switch_merge_is_removed = True
|
||||
self.assertEqual(lower_using_switch_merge_is_removed, True)
|
||||
|
||||
def testConvertBytes(self):
|
||||
source, header = util.convert_bytes_to_c_source(
|
||||
b"\x00\x01\x02\x23", "foo", 16, use_tensorflow_license=False)
|
||||
self.assertTrue(
|
||||
source.find("const unsigned char foo[] DATA_ALIGN_ATTRIBUTE = {"))
|
||||
self.assertTrue(source.find(""" 0x00, 0x01,
|
||||
0x02, 0x23,"""))
|
||||
self.assertNotEqual(-1, source.find("const int foo_len = 4;"))
|
||||
self.assertEqual(-1, source.find("/* Copyright"))
|
||||
self.assertEqual(-1, source.find("#include " ""))
|
||||
self.assertNotEqual(-1, header.find("extern const unsigned char foo[];"))
|
||||
self.assertNotEqual(-1, header.find("extern const int foo_len;"))
|
||||
self.assertEqual(-1, header.find("/* Copyright"))
|
||||
|
||||
source, header = util.convert_bytes_to_c_source(
|
||||
b"\xff\xfe\xfd\xfc",
|
||||
"bar",
|
||||
80,
|
||||
include_guard="MY_GUARD",
|
||||
include_path="my/guard.h",
|
||||
use_tensorflow_license=True)
|
||||
self.assertNotEqual(
|
||||
-1, source.find("const unsigned char bar[] DATA_ALIGN_ATTRIBUTE = {"))
|
||||
self.assertNotEqual(-1, source.find(""" 0xff, 0xfe, 0xfd, 0xfc,"""))
|
||||
self.assertNotEqual(-1, source.find("/* Copyright"))
|
||||
self.assertNotEqual(-1, source.find("#include \"my/guard.h\""))
|
||||
self.assertNotEqual(-1, header.find("#ifndef MY_GUARD"))
|
||||
self.assertNotEqual(-1, header.find("#define MY_GUARD"))
|
||||
self.assertNotEqual(-1, header.find("/* Copyright"))
|
||||
|
||||
|
||||
class TensorFunctionsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user