Update the usage of the 'absl-py' python library from TFLite Micro utils
PiperOrigin-RevId: 344322150 Change-Id: Ida045ffe25f630663ab91170c72d370e3f7f7634
This commit is contained in:
parent
aa7b7395e7
commit
c624153938
@ -433,9 +433,9 @@ py_binary(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":lite",
|
|
||||||
":util",
|
":util",
|
||||||
"@six_archive//:six",
|
"@absl_py//absl:app",
|
||||||
|
"@absl_py//absl/flags",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# Lint as: python2, python3
|
|
||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -13,94 +12,61 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Python command line interface for converting TF Lite files into C source."""
|
"""Converts a TFLite model to a TFLite Micro model (C++ Source)."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
from absl import app
|
||||||
import sys
|
from absl import flags
|
||||||
|
|
||||||
from tensorflow.lite.python import util
|
from tensorflow.lite.python import util
|
||||||
from tensorflow.python.platform import app
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
flags.DEFINE_string("input_tflite_file", None,
|
||||||
|
"Full path name to the input TFLite model file.")
|
||||||
|
flags.DEFINE_string(
|
||||||
|
"output_source_file", None,
|
||||||
|
"Full path name to the output TFLite Micro model (C++ Source) file).")
|
||||||
|
flags.DEFINE_string("output_header_file", None,
|
||||||
|
"Full filepath of the output C header file.")
|
||||||
|
flags.DEFINE_string("array_variable_name", None,
|
||||||
|
"Name to use for the C data array variable.")
|
||||||
|
flags.DEFINE_integer("line_width", 80, "Width to use for formatting.")
|
||||||
|
flags.DEFINE_string("include_guard", None,
|
||||||
|
"Name to use for the C header include guard.")
|
||||||
|
flags.DEFINE_string("include_path", None,
|
||||||
|
"Optional path to include in generated source file.")
|
||||||
|
flags.DEFINE_boolean(
|
||||||
|
"use_tensorflow_license", False,
|
||||||
|
"Whether to prefix the generated files with the TF Apache2 license.")
|
||||||
|
|
||||||
|
flags.mark_flag_as_required("input_tflite_file")
|
||||||
|
flags.mark_flag_as_required("output_source_file")
|
||||||
|
flags.mark_flag_as_required("output_header_file")
|
||||||
|
flags.mark_flag_as_required("array_variable_name")
|
||||||
|
|
||||||
|
|
||||||
def run_main(_):
|
def main(_):
|
||||||
"""Main in convert_file_to_c_source.py."""
|
with open(FLAGS.input_tflite_file, "rb") as input_handle:
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description=("Command line tool to run TensorFlow Lite Converter."))
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--input_tflite_file",
|
|
||||||
type=str,
|
|
||||||
help="Full filepath of the input TensorFlow Lite file.",
|
|
||||||
required=True)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--output_source_file",
|
|
||||||
type=str,
|
|
||||||
help="Full filepath of the output C source file.",
|
|
||||||
required=True)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--output_header_file",
|
|
||||||
type=str,
|
|
||||||
help="Full filepath of the output C header file.",
|
|
||||||
required=True)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--array_variable_name",
|
|
||||||
type=str,
|
|
||||||
help="Name to use for the C data array variable.",
|
|
||||||
required=True)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--line_width", type=int, help="Width to use for formatting.", default=80)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--include_guard",
|
|
||||||
type=str,
|
|
||||||
help="Name to use for the C header include guard.",
|
|
||||||
default=None)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--include_path",
|
|
||||||
type=str,
|
|
||||||
help="Optional path to include in generated source file.",
|
|
||||||
default=None)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_tensorflow_license",
|
|
||||||
dest="use_tensorflow_license",
|
|
||||||
help="Whether to prefix the generated files with the TF Apache2 license.",
|
|
||||||
action="store_true")
|
|
||||||
parser.set_defaults(use_tensorflow_license=False)
|
|
||||||
|
|
||||||
flags, _ = parser.parse_known_args(args=sys.argv[1:])
|
|
||||||
|
|
||||||
with open(flags.input_tflite_file, "rb") as input_handle:
|
|
||||||
input_data = input_handle.read()
|
input_data = input_handle.read()
|
||||||
|
|
||||||
source, header = util.convert_bytes_to_c_source(
|
source, header = util.convert_bytes_to_c_source(
|
||||||
data=input_data,
|
data=input_data,
|
||||||
array_name=flags.array_variable_name,
|
array_name=FLAGS.array_variable_name,
|
||||||
max_line_width=flags.line_width,
|
max_line_width=FLAGS.line_width,
|
||||||
include_guard=flags.include_guard,
|
include_guard=FLAGS.include_guard,
|
||||||
include_path=flags.include_path,
|
include_path=FLAGS.include_path,
|
||||||
use_tensorflow_license=flags.use_tensorflow_license)
|
use_tensorflow_license=FLAGS.use_tensorflow_license)
|
||||||
|
|
||||||
with open(flags.output_source_file, "w") as source_handle:
|
with open(FLAGS.output_source_file, "w") as source_handle:
|
||||||
source_handle.write(source)
|
source_handle.write(source)
|
||||||
|
|
||||||
with open(flags.output_header_file, "w") as header_handle:
|
with open(FLAGS.output_header_file, "w") as header_handle:
|
||||||
header_handle.write(header)
|
header_handle.write(header)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
app.run(main=run_main, argv=sys.argv[:1])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
app.run(main)
|
||||||
|
@ -34,7 +34,7 @@ ${TEST_SRCDIR}${SCRIPT_BASE_DIR}/tensorflow/lite/python/convert_file_to_c_source
|
|||||||
--line_width=80 \
|
--line_width=80 \
|
||||||
--include_guard="SOME_GUARD_H_" \
|
--include_guard="SOME_GUARD_H_" \
|
||||||
--include_path="some/guard.h" \
|
--include_path="some/guard.h" \
|
||||||
--use_tensorflow_license
|
--use_tensorflow_license=True
|
||||||
|
|
||||||
if ! grep -q 'const unsigned char g_some_array' ${OUTPUT_SOURCE_FILE}; then
|
if ! grep -q 'const unsigned char g_some_array' ${OUTPUT_SOURCE_FILE}; then
|
||||||
echo "ERROR: No array found in output '${OUTPUT_SOURCE_FILE}'"
|
echo "ERROR: No array found in output '${OUTPUT_SOURCE_FILE}'"
|
||||||
|
@ -52,8 +52,9 @@ py_binary(
|
|||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:image_ops",
|
"//tensorflow/python:image_ops",
|
||||||
"//tensorflow/python:io_ops",
|
"//tensorflow/python:io_ops",
|
||||||
"//tensorflow/python:platform",
|
|
||||||
"//tensorflow/python:session",
|
"//tensorflow/python:session",
|
||||||
|
"@absl_py//absl:app",
|
||||||
|
"@absl_py//absl/flags",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# Lint as: python2, python3
|
|
||||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -15,27 +14,35 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
r"""This tool converts an image file into a CSV data array.
|
r"""This tool converts an image file into a CSV data array.
|
||||||
|
|
||||||
Designed to help create test inputs that can be shared between Python and
|
Loads JPEG or PNG input files, resizes them, optionally converts to grayscale,
|
||||||
on-device test cases to investigate accuracy issues.
|
and writes out as comma-separated variables, one image per row. Designed to
|
||||||
|
help create test inputs that can be shared between Python and on-device test
|
||||||
|
cases to investigate accuracy issues.
|
||||||
|
|
||||||
Example usage:
|
|
||||||
|
|
||||||
python convert_image_to_csv.py some_image.jpg --width=16 --height=20 \
|
|
||||||
--want_grayscale
|
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework.errors_impl import NotFoundError
|
from tensorflow.python.framework.errors_impl import NotFoundError
|
||||||
from tensorflow.python.ops import image_ops
|
from tensorflow.python.ops import image_ops
|
||||||
from tensorflow.python.ops import io_ops
|
from tensorflow.python.ops import io_ops
|
||||||
from tensorflow.python.platform import app
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
flags.DEFINE_multi_string("image_file_names", None,
|
||||||
|
"List of paths to the input images.")
|
||||||
|
flags.DEFINE_integer("width", 96, "Width to scale images to.")
|
||||||
|
flags.DEFINE_integer("height", 96, "Height to scale images to.")
|
||||||
|
flags.DEFINE_boolean("want_grayscale", False,
|
||||||
|
"Whether to convert the image to monochrome.")
|
||||||
|
|
||||||
|
|
||||||
def get_image(width, height, want_grayscale, filepath):
|
def get_image(width, height, want_grayscale, filepath):
|
||||||
@ -55,10 +62,9 @@ def get_image(width, height, want_grayscale, filepath):
|
|||||||
with session.Session():
|
with session.Session():
|
||||||
file_data = io_ops.read_file(filepath)
|
file_data = io_ops.read_file(filepath)
|
||||||
channels = 1 if want_grayscale else 3
|
channels = 1 if want_grayscale else 3
|
||||||
image_tensor = image_ops.decode_image(file_data,
|
image_tensor = image_ops.decode_image(file_data, channels=channels).eval()
|
||||||
channels=channels).eval()
|
resized_tensor = image_ops.resize_images_v2(image_tensor,
|
||||||
resized_tensor = image_ops.resize_images_v2(
|
(height, width)).eval()
|
||||||
image_tensor, (height, width)).eval()
|
|
||||||
return resized_tensor
|
return resized_tensor
|
||||||
|
|
||||||
|
|
||||||
@ -73,43 +79,19 @@ def array_to_int_csv(array_data):
|
|||||||
"""
|
"""
|
||||||
flattened_array = array_data.flatten()
|
flattened_array = array_data.flatten()
|
||||||
array_as_strings = [item.astype(int).astype(str) for item in flattened_array]
|
array_as_strings = [item.astype(int).astype(str) for item in flattened_array]
|
||||||
return ','.join(array_as_strings)
|
return ",".join(array_as_strings)
|
||||||
|
|
||||||
|
|
||||||
def run_main(_):
|
def main(_):
|
||||||
"""Application run loop."""
|
for image_file_name in FLAGS.image_file_names:
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description='Loads JPEG or PNG input files, resizes them, optionally'
|
|
||||||
' converts to grayscale, and writes out as comma-separated variables,'
|
|
||||||
' one image per row.')
|
|
||||||
parser.add_argument(
|
|
||||||
'image_file_names',
|
|
||||||
type=str,
|
|
||||||
nargs='+',
|
|
||||||
help='List of paths to the input images.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--width', type=int, default=96, help='Width to scale images to.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--height', type=int, default=96, help='Height to scale images to.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--want_grayscale',
|
|
||||||
action='store_true',
|
|
||||||
help='Whether to convert the image to monochrome.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
for image_file_name in args.image_file_names:
|
|
||||||
try:
|
try:
|
||||||
image_data = get_image(args.width, args.height, args.want_grayscale,
|
image_data = get_image(FLAGS.width, FLAGS.height, FLAGS.want_grayscale,
|
||||||
image_file_name)
|
image_file_name)
|
||||||
print(array_to_int_csv(image_data))
|
print(array_to_int_csv(image_data))
|
||||||
except NotFoundError:
|
except NotFoundError:
|
||||||
sys.stderr.write('Image file not found at {0}\n'.format(image_file_name))
|
sys.stderr.write("Image file not found at {0}\n".format(image_file_name))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
if __name__ == "__main__":
|
||||||
app.run(main=run_main, argv=sys.argv[:1])
|
app.run(main)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user