Update the usage of the 'absl-py' python library from TFLite Micro utils

PiperOrigin-RevId: 344322150
Change-Id: Ida045ffe25f630663ab91170c72d370e3f7f7634
This commit is contained in:
Meghna Natraj 2020-11-25 15:01:43 -08:00 committed by TensorFlower Gardener
parent aa7b7395e7
commit c624153938
5 changed files with 69 additions and 120 deletions

View File

@ -433,9 +433,9 @@ py_binary(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":lite",
":util",
"@six_archive//:six",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
)

View File

@ -1,4 +1,3 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -13,94 +12,61 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python command line interface for converting TF Lite files into C source."""
"""Converts a TFLite model to a TFLite Micro model (C++ Source)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from absl import app
from absl import flags
from tensorflow.lite.python import util
from tensorflow.python.platform import app
FLAGS = flags.FLAGS
flags.DEFINE_string("input_tflite_file", None,
"Full path name to the input TFLite model file.")
flags.DEFINE_string(
"output_source_file", None,
"Full path name to the output TFLite Micro model (C++ Source) file).")
flags.DEFINE_string("output_header_file", None,
"Full filepath of the output C header file.")
flags.DEFINE_string("array_variable_name", None,
"Name to use for the C data array variable.")
flags.DEFINE_integer("line_width", 80, "Width to use for formatting.")
flags.DEFINE_string("include_guard", None,
"Name to use for the C header include guard.")
flags.DEFINE_string("include_path", None,
"Optional path to include in generated source file.")
flags.DEFINE_boolean(
"use_tensorflow_license", False,
"Whether to prefix the generated files with the TF Apache2 license.")
flags.mark_flag_as_required("input_tflite_file")
flags.mark_flag_as_required("output_source_file")
flags.mark_flag_as_required("output_header_file")
flags.mark_flag_as_required("array_variable_name")
def run_main(_):
"""Main in convert_file_to_c_source.py."""
parser = argparse.ArgumentParser(
description=("Command line tool to run TensorFlow Lite Converter."))
parser.add_argument(
"--input_tflite_file",
type=str,
help="Full filepath of the input TensorFlow Lite file.",
required=True)
parser.add_argument(
"--output_source_file",
type=str,
help="Full filepath of the output C source file.",
required=True)
parser.add_argument(
"--output_header_file",
type=str,
help="Full filepath of the output C header file.",
required=True)
parser.add_argument(
"--array_variable_name",
type=str,
help="Name to use for the C data array variable.",
required=True)
parser.add_argument(
"--line_width", type=int, help="Width to use for formatting.", default=80)
parser.add_argument(
"--include_guard",
type=str,
help="Name to use for the C header include guard.",
default=None)
parser.add_argument(
"--include_path",
type=str,
help="Optional path to include in generated source file.",
default=None)
parser.add_argument(
"--use_tensorflow_license",
dest="use_tensorflow_license",
help="Whether to prefix the generated files with the TF Apache2 license.",
action="store_true")
parser.set_defaults(use_tensorflow_license=False)
flags, _ = parser.parse_known_args(args=sys.argv[1:])
with open(flags.input_tflite_file, "rb") as input_handle:
def main(_):
with open(FLAGS.input_tflite_file, "rb") as input_handle:
input_data = input_handle.read()
source, header = util.convert_bytes_to_c_source(
data=input_data,
array_name=flags.array_variable_name,
max_line_width=flags.line_width,
include_guard=flags.include_guard,
include_path=flags.include_path,
use_tensorflow_license=flags.use_tensorflow_license)
array_name=FLAGS.array_variable_name,
max_line_width=FLAGS.line_width,
include_guard=FLAGS.include_guard,
include_path=FLAGS.include_path,
use_tensorflow_license=FLAGS.use_tensorflow_license)
with open(flags.output_source_file, "w") as source_handle:
with open(FLAGS.output_source_file, "w") as source_handle:
source_handle.write(source)
with open(flags.output_header_file, "w") as header_handle:
with open(FLAGS.output_header_file, "w") as header_handle:
header_handle.write(header)
def main():
app.run(main=run_main, argv=sys.argv[:1])
if __name__ == "__main__":
main()
app.run(main)

View File

@ -34,7 +34,7 @@ ${TEST_SRCDIR}${SCRIPT_BASE_DIR}/tensorflow/lite/python/convert_file_to_c_source
--line_width=80 \
--include_guard="SOME_GUARD_H_" \
--include_path="some/guard.h" \
--use_tensorflow_license
--use_tensorflow_license=True
if ! grep -q 'const unsigned char g_some_array' ${OUTPUT_SOURCE_FILE}; then
echo "ERROR: No array found in output '${OUTPUT_SOURCE_FILE}'"

View File

@ -52,8 +52,9 @@ py_binary(
"//tensorflow/python:framework_ops",
"//tensorflow/python:image_ops",
"//tensorflow/python:io_ops",
"//tensorflow/python:platform",
"//tensorflow/python:session",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
)

View File

@ -1,4 +1,3 @@
# Lint as: python2, python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -15,27 +14,35 @@
# ==============================================================================
r"""This tool converts an image file into a CSV data array.
Designed to help create test inputs that can be shared between Python and
on-device test cases to investigate accuracy issues.
Loads JPEG or PNG input files, resizes them, optionally converts to grayscale,
and writes out as comma-separated variables, one image per row. Designed to
help create test inputs that can be shared between Python and on-device test
cases to investigate accuracy issues.
Example usage:
python convert_image_to_csv.py some_image.jpg --width=16 --height=20 \
--want_grayscale
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from absl import app
from absl import flags
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework.errors_impl import NotFoundError
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import app
FLAGS = flags.FLAGS
flags.DEFINE_multi_string("image_file_names", None,
"List of paths to the input images.")
flags.DEFINE_integer("width", 96, "Width to scale images to.")
flags.DEFINE_integer("height", 96, "Height to scale images to.")
flags.DEFINE_boolean("want_grayscale", False,
"Whether to convert the image to monochrome.")
def get_image(width, height, want_grayscale, filepath):
@ -55,10 +62,9 @@ def get_image(width, height, want_grayscale, filepath):
with session.Session():
file_data = io_ops.read_file(filepath)
channels = 1 if want_grayscale else 3
image_tensor = image_ops.decode_image(file_data,
channels=channels).eval()
resized_tensor = image_ops.resize_images_v2(
image_tensor, (height, width)).eval()
image_tensor = image_ops.decode_image(file_data, channels=channels).eval()
resized_tensor = image_ops.resize_images_v2(image_tensor,
(height, width)).eval()
return resized_tensor
@ -73,43 +79,19 @@ def array_to_int_csv(array_data):
"""
flattened_array = array_data.flatten()
array_as_strings = [item.astype(int).astype(str) for item in flattened_array]
return ','.join(array_as_strings)
return ",".join(array_as_strings)
def run_main(_):
"""Application run loop."""
parser = argparse.ArgumentParser(
description='Loads JPEG or PNG input files, resizes them, optionally'
' converts to grayscale, and writes out as comma-separated variables,'
' one image per row.')
parser.add_argument(
'image_file_names',
type=str,
nargs='+',
help='List of paths to the input images.')
parser.add_argument(
'--width', type=int, default=96, help='Width to scale images to.')
parser.add_argument(
'--height', type=int, default=96, help='Height to scale images to.')
parser.add_argument(
'--want_grayscale',
action='store_true',
help='Whether to convert the image to monochrome.')
args = parser.parse_args()
for image_file_name in args.image_file_names:
def main(_):
for image_file_name in FLAGS.image_file_names:
try:
image_data = get_image(args.width, args.height, args.want_grayscale,
image_data = get_image(FLAGS.width, FLAGS.height, FLAGS.want_grayscale,
image_file_name)
print(array_to_int_csv(image_data))
except NotFoundError:
sys.stderr.write('Image file not found at {0}\n'.format(image_file_name))
sys.stderr.write("Image file not found at {0}\n".format(image_file_name))
sys.exit(1)
def main():
app.run(main=run_main, argv=sys.argv[:1])
if __name__ == '__main__':
main()
if __name__ == "__main__":
app.run(main)