Update the usage of the 'absl-py' python library from TFLite Micro utils
PiperOrigin-RevId: 344322150 Change-Id: Ida045ffe25f630663ab91170c72d370e3f7f7634
This commit is contained in:
parent
aa7b7395e7
commit
c624153938
@ -433,9 +433,9 @@ py_binary(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":lite",
|
||||
":util",
|
||||
"@six_archive//:six",
|
||||
"@absl_py//absl:app",
|
||||
"@absl_py//absl/flags",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -13,94 +12,61 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Python command line interface for converting TF Lite files into C source."""
|
||||
"""Converts a TFLite model to a TFLite Micro model (C++ Source)."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from tensorflow.lite.python import util
|
||||
from tensorflow.python.platform import app
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("input_tflite_file", None,
|
||||
"Full path name to the input TFLite model file.")
|
||||
flags.DEFINE_string(
|
||||
"output_source_file", None,
|
||||
"Full path name to the output TFLite Micro model (C++ Source) file).")
|
||||
flags.DEFINE_string("output_header_file", None,
|
||||
"Full filepath of the output C header file.")
|
||||
flags.DEFINE_string("array_variable_name", None,
|
||||
"Name to use for the C data array variable.")
|
||||
flags.DEFINE_integer("line_width", 80, "Width to use for formatting.")
|
||||
flags.DEFINE_string("include_guard", None,
|
||||
"Name to use for the C header include guard.")
|
||||
flags.DEFINE_string("include_path", None,
|
||||
"Optional path to include in generated source file.")
|
||||
flags.DEFINE_boolean(
|
||||
"use_tensorflow_license", False,
|
||||
"Whether to prefix the generated files with the TF Apache2 license.")
|
||||
|
||||
flags.mark_flag_as_required("input_tflite_file")
|
||||
flags.mark_flag_as_required("output_source_file")
|
||||
flags.mark_flag_as_required("output_header_file")
|
||||
flags.mark_flag_as_required("array_variable_name")
|
||||
|
||||
|
||||
def run_main(_):
|
||||
"""Main in convert_file_to_c_source.py."""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=("Command line tool to run TensorFlow Lite Converter."))
|
||||
|
||||
parser.add_argument(
|
||||
"--input_tflite_file",
|
||||
type=str,
|
||||
help="Full filepath of the input TensorFlow Lite file.",
|
||||
required=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_source_file",
|
||||
type=str,
|
||||
help="Full filepath of the output C source file.",
|
||||
required=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_header_file",
|
||||
type=str,
|
||||
help="Full filepath of the output C header file.",
|
||||
required=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--array_variable_name",
|
||||
type=str,
|
||||
help="Name to use for the C data array variable.",
|
||||
required=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--line_width", type=int, help="Width to use for formatting.", default=80)
|
||||
|
||||
parser.add_argument(
|
||||
"--include_guard",
|
||||
type=str,
|
||||
help="Name to use for the C header include guard.",
|
||||
default=None)
|
||||
|
||||
parser.add_argument(
|
||||
"--include_path",
|
||||
type=str,
|
||||
help="Optional path to include in generated source file.",
|
||||
default=None)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_tensorflow_license",
|
||||
dest="use_tensorflow_license",
|
||||
help="Whether to prefix the generated files with the TF Apache2 license.",
|
||||
action="store_true")
|
||||
parser.set_defaults(use_tensorflow_license=False)
|
||||
|
||||
flags, _ = parser.parse_known_args(args=sys.argv[1:])
|
||||
|
||||
with open(flags.input_tflite_file, "rb") as input_handle:
|
||||
def main(_):
|
||||
with open(FLAGS.input_tflite_file, "rb") as input_handle:
|
||||
input_data = input_handle.read()
|
||||
|
||||
source, header = util.convert_bytes_to_c_source(
|
||||
data=input_data,
|
||||
array_name=flags.array_variable_name,
|
||||
max_line_width=flags.line_width,
|
||||
include_guard=flags.include_guard,
|
||||
include_path=flags.include_path,
|
||||
use_tensorflow_license=flags.use_tensorflow_license)
|
||||
array_name=FLAGS.array_variable_name,
|
||||
max_line_width=FLAGS.line_width,
|
||||
include_guard=FLAGS.include_guard,
|
||||
include_path=FLAGS.include_path,
|
||||
use_tensorflow_license=FLAGS.use_tensorflow_license)
|
||||
|
||||
with open(flags.output_source_file, "w") as source_handle:
|
||||
with open(FLAGS.output_source_file, "w") as source_handle:
|
||||
source_handle.write(source)
|
||||
|
||||
with open(flags.output_header_file, "w") as header_handle:
|
||||
with open(FLAGS.output_header_file, "w") as header_handle:
|
||||
header_handle.write(header)
|
||||
|
||||
|
||||
def main():
|
||||
app.run(main=run_main, argv=sys.argv[:1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
app.run(main)
|
||||
|
@ -34,7 +34,7 @@ ${TEST_SRCDIR}${SCRIPT_BASE_DIR}/tensorflow/lite/python/convert_file_to_c_source
|
||||
--line_width=80 \
|
||||
--include_guard="SOME_GUARD_H_" \
|
||||
--include_path="some/guard.h" \
|
||||
--use_tensorflow_license
|
||||
--use_tensorflow_license=True
|
||||
|
||||
if ! grep -q 'const unsigned char g_some_array' ${OUTPUT_SOURCE_FILE}; then
|
||||
echo "ERROR: No array found in output '${OUTPUT_SOURCE_FILE}'"
|
||||
|
@ -52,8 +52,9 @@ py_binary(
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:image_ops",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:session",
|
||||
"@absl_py//absl:app",
|
||||
"@absl_py//absl/flags",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -15,27 +14,35 @@
|
||||
# ==============================================================================
|
||||
r"""This tool converts an image file into a CSV data array.
|
||||
|
||||
Designed to help create test inputs that can be shared between Python and
|
||||
on-device test cases to investigate accuracy issues.
|
||||
Loads JPEG or PNG input files, resizes them, optionally converts to grayscale,
|
||||
and writes out as comma-separated variables, one image per row. Designed to
|
||||
help create test inputs that can be shared between Python and on-device test
|
||||
cases to investigate accuracy issues.
|
||||
|
||||
Example usage:
|
||||
|
||||
python convert_image_to_csv.py some_image.jpg --width=16 --height=20 \
|
||||
--want_grayscale
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework.errors_impl import NotFoundError
|
||||
from tensorflow.python.ops import image_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.platform import app
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_multi_string("image_file_names", None,
|
||||
"List of paths to the input images.")
|
||||
flags.DEFINE_integer("width", 96, "Width to scale images to.")
|
||||
flags.DEFINE_integer("height", 96, "Height to scale images to.")
|
||||
flags.DEFINE_boolean("want_grayscale", False,
|
||||
"Whether to convert the image to monochrome.")
|
||||
|
||||
|
||||
def get_image(width, height, want_grayscale, filepath):
|
||||
@ -55,10 +62,9 @@ def get_image(width, height, want_grayscale, filepath):
|
||||
with session.Session():
|
||||
file_data = io_ops.read_file(filepath)
|
||||
channels = 1 if want_grayscale else 3
|
||||
image_tensor = image_ops.decode_image(file_data,
|
||||
channels=channels).eval()
|
||||
resized_tensor = image_ops.resize_images_v2(
|
||||
image_tensor, (height, width)).eval()
|
||||
image_tensor = image_ops.decode_image(file_data, channels=channels).eval()
|
||||
resized_tensor = image_ops.resize_images_v2(image_tensor,
|
||||
(height, width)).eval()
|
||||
return resized_tensor
|
||||
|
||||
|
||||
@ -73,43 +79,19 @@ def array_to_int_csv(array_data):
|
||||
"""
|
||||
flattened_array = array_data.flatten()
|
||||
array_as_strings = [item.astype(int).astype(str) for item in flattened_array]
|
||||
return ','.join(array_as_strings)
|
||||
return ",".join(array_as_strings)
|
||||
|
||||
|
||||
def run_main(_):
|
||||
"""Application run loop."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Loads JPEG or PNG input files, resizes them, optionally'
|
||||
' converts to grayscale, and writes out as comma-separated variables,'
|
||||
' one image per row.')
|
||||
parser.add_argument(
|
||||
'image_file_names',
|
||||
type=str,
|
||||
nargs='+',
|
||||
help='List of paths to the input images.')
|
||||
parser.add_argument(
|
||||
'--width', type=int, default=96, help='Width to scale images to.')
|
||||
parser.add_argument(
|
||||
'--height', type=int, default=96, help='Height to scale images to.')
|
||||
parser.add_argument(
|
||||
'--want_grayscale',
|
||||
action='store_true',
|
||||
help='Whether to convert the image to monochrome.')
|
||||
args = parser.parse_args()
|
||||
|
||||
for image_file_name in args.image_file_names:
|
||||
def main(_):
|
||||
for image_file_name in FLAGS.image_file_names:
|
||||
try:
|
||||
image_data = get_image(args.width, args.height, args.want_grayscale,
|
||||
image_data = get_image(FLAGS.width, FLAGS.height, FLAGS.want_grayscale,
|
||||
image_file_name)
|
||||
print(array_to_int_csv(image_data))
|
||||
except NotFoundError:
|
||||
sys.stderr.write('Image file not found at {0}\n'.format(image_file_name))
|
||||
sys.stderr.write("Image file not found at {0}\n".format(image_file_name))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
app.run(main=run_main, argv=sys.argv[:1])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
Loading…
Reference in New Issue
Block a user