diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index a4f236a6f6a..a41e399c4a3 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -81,7 +81,8 @@ py_binary( srcs_version = "PY2AND3", deps = [ ":flatbuffer_utils", - "//tensorflow/python:platform", + "@absl_py//absl:app", + "@absl_py//absl/flags", ], ) @@ -92,7 +93,8 @@ py_binary( srcs_version = "PY2AND3", deps = [ ":flatbuffer_utils", - "//tensorflow/python:platform", + "@absl_py//absl:app", + "@absl_py//absl/flags", ], ) @@ -103,7 +105,8 @@ py_binary( srcs_version = "PY2AND3", deps = [ ":flatbuffer_utils", - "//tensorflow/python:platform", + "@absl_py//absl:app", + "@absl_py//absl/flags", ], ) diff --git a/tensorflow/lite/tools/optimize/python/BUILD b/tensorflow/lite/tools/optimize/python/BUILD index 34f57cbecaf..ead5d49f13b 100644 --- a/tensorflow/lite/tools/optimize/python/BUILD +++ b/tensorflow/lite/tools/optimize/python/BUILD @@ -15,7 +15,8 @@ py_binary( deps = [ ":modify_model_interface_constants", ":modify_model_interface_lib", - "//tensorflow/python:platform", + "@absl_py//absl:app", + "@absl_py//absl/flags", ], ) diff --git a/tensorflow/lite/tools/optimize/python/modify_model_interface.py b/tensorflow/lite/tools/optimize/python/modify_model_interface.py index 938f353b0ae..1de9edd88ae 100644 --- a/tensorflow/lite/tools/optimize/python/modify_model_interface.py +++ b/tensorflow/lite/tools/optimize/python/modify_model_interface.py @@ -1,4 +1,3 @@ -# Lint as: python3 # Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,66 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -r"""Modify a quantized model's interface from float to integer. - -Example usage: -python modify_model_interface_main.py \ - --input_file=float_model.tflite \ - --output_file=int_model.tflite \ - --input_type=INT8 \ - --output_type=INT8 -""" +r"""Modify a quantized model's interface from float to integer.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse -import sys +from absl import app +from absl import flags from tensorflow.lite.tools.optimize.python import modify_model_interface_constants as mmi_constants from tensorflow.lite.tools.optimize.python import modify_model_interface_lib as mmi_lib -from tensorflow.python.platform import app + +FLAGS = flags.FLAGS + +flags.DEFINE_string('input_tflite_file', None, + 'Full path name to the input TFLite file.') +flags.DEFINE_string('output_tflite_file', None, + 'Full path name to the output TFLite file.') +flags.DEFINE_enum('input_type', mmi_constants.DEFAULT_STR_TYPE, + mmi_constants.STR_TYPES, + 'Modified input integer interface type.') +flags.DEFINE_enum('output_type', mmi_constants.DEFAULT_STR_TYPE, + mmi_constants.STR_TYPES, + 'Modified output integer interface type.') + +flags.mark_flag_as_required('input_tflite_file') +flags.mark_flag_as_required('output_tflite_file') def main(_): - """Application run loop.""" - parser = argparse.ArgumentParser( - description="Modify a quantized model's interface from float to integer.") - parser.add_argument( - '--input_file', - type=str, - required=True, - help='Full path name to the input tflite file.') - parser.add_argument( - '--output_file', - type=str, - required=True, - help='Full path name to the output tflite file.') - parser.add_argument( - '--input_type', - type=str.upper, - choices=mmi_constants.STR_TYPES, - default=mmi_constants.DEFAULT_STR_TYPE, - help='Modified input integer interface type.') - parser.add_argument( - '--output_type', - type=str.upper, - choices=mmi_constants.STR_TYPES, - default=mmi_constants.DEFAULT_STR_TYPE, - help='Modified output integer interface type.') - args = parser.parse_args() + input_type = mmi_constants.STR_TO_TFLITE_TYPES[FLAGS.input_type] + output_type = mmi_constants.STR_TO_TFLITE_TYPES[FLAGS.output_type] - input_type = mmi_constants.STR_TO_TFLITE_TYPES[args.input_type] - output_type = mmi_constants.STR_TO_TFLITE_TYPES[args.output_type] - - mmi_lib.modify_model_interface(args.input_file, args.output_file, input_type, - output_type) + mmi_lib.modify_model_interface(FLAGS.input_file, FLAGS.output_file, + input_type, output_type) print('Successfully modified the model input type from FLOAT to ' '{input_type} and output type from FLOAT to {output_type}.'.format( - input_type=args.input_type, output_type=args.output_type)) + input_type=FLAGS.input_type, output_type=FLAGS.output_type)) if __name__ == '__main__': - app.run(main=main, argv=sys.argv[:1]) + app.run(main) diff --git a/tensorflow/lite/tools/randomize_weights.py b/tensorflow/lite/tools/randomize_weights.py index b68bdbb180b..fdf7f637d18 100644 --- a/tensorflow/lite/tools/randomize_weights.py +++ b/tensorflow/lite/tools/randomize_weights.py @@ -12,53 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -r"""Randomize all weights in a tflite file. - -Example usage: -python randomize_weights.py \ - --input_tflite_file=foo.tflite \ - --output_tflite_file=foo_randomized.tflite -""" +r"""Randomize all weights in a tflite file.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse -import sys +from absl import app +from absl import flags from tensorflow.lite.tools import flatbuffer_utils -from tensorflow.python.platform import app + +FLAGS = flags.FLAGS + +flags.DEFINE_string('input_tflite_file', None, + 'Full path name to the input TFLite file.') +flags.DEFINE_string('output_tflite_file', None, + 'Full path name to the output randomized TFLite file.') +flags.DEFINE_integer('random_seed', 0, 'Input to the random number generator.') + +flags.mark_flag_as_required('input_tflite_file') +flags.mark_flag_as_required('output_tflite_file') def main(_): - parser = argparse.ArgumentParser( - description='Randomize weights in a tflite file.') - parser.add_argument( - '--input_tflite_file', - type=str, - required=True, - help='Full path name to the input tflite file.') - parser.add_argument( - '--output_tflite_file', - type=str, - required=True, - help='Full path name to the output randomized tflite file.') - parser.add_argument( - '--random_seed', - type=str, - required=False, - default=0, - help='Input to the random number generator. The default value is 0.') - args = parser.parse_args() - - # Read the model - model = flatbuffer_utils.read_model(args.input_tflite_file) - # Invoke the randomize weights function - flatbuffer_utils.randomize_weights(model, args.random_seed) - # Write the model - flatbuffer_utils.write_model(model, args.output_tflite_file) + model = flatbuffer_utils.read_model(FLAGS.input_tflite_file) + flatbuffer_utils.randomize_weights(model, FLAGS.random_seed) + flatbuffer_utils.write_model(model, FLAGS.output_tflite_file) if __name__ == '__main__': - app.run(main=main, argv=sys.argv[:1]) + app.run(main) diff --git a/tensorflow/lite/tools/reverse_xxd_dump_from_cc.py b/tensorflow/lite/tools/reverse_xxd_dump_from_cc.py index cb7c73b6a2a..8e9c2289fbd 100644 --- a/tensorflow/lite/tools/reverse_xxd_dump_from_cc.py +++ b/tensorflow/lite/tools/reverse_xxd_dump_from_cc.py @@ -12,57 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -r"""Reverses xxd dump from to binary file +r"""Reverses xxd dump, i.e, converts a C++ source file back to a TFLite file. -This script is used to convert models from C++ source file (dumped with xxd) to -the binary model weight file and analyze it with model visualizer like Netron -(https://github.com/lutzroeder/netron) or load the model in TensorFlow Python -API -to evaluate the results in Python. - -The command to dump binary file to C++ source file looks like +This script is used to convert a model from a C++ source file (dumped with xxd) +back to it's original TFLite file format in order to analyze it with either a +model visualizer like Netron (https://github.com/lutzroeder/netron) or to +evaluate the model using the Python TensorFlow Lite Interpreter API. +The xxd command to dump the TFLite file to a C++ source file looks like: xxd -i model_data.tflite > model_data.cc -Example usage: - -python reverse_xxd_dump_from_cc.py \ - --input_cc_file=model_data.cc \ - --output_tflite_file=model_data.tflite """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse -import sys +from absl import app +from absl import flags from tensorflow.lite.tools import flatbuffer_utils -from tensorflow.python.platform import app + +FLAGS = flags.FLAGS + +flags.DEFINE_string('input_cc_file', None, + 'Full path name to the input C++ source file.') +flags.DEFINE_string('output_tflite_file', None, + 'Full path name to the output TFLite file.') + +flags.mark_flag_as_required('input_cc_file') +flags.mark_flag_as_required('output_tflite_file') def main(_): - """Application run loop.""" - parser = argparse.ArgumentParser( - description='Reverses xxd dump from to binary file') - parser.add_argument( - '--input_cc_file', - type=str, - required=True, - help='Full path name to the input cc file.') - parser.add_argument( - '--output_tflite_file', - type=str, - required=True, - help='Full path name to the stripped output tflite file.') - - args = parser.parse_args() - - # Read the model from xxd output C++ source file - model = flatbuffer_utils.xxd_output_to_object(args.input_cc_file) - # Write the model - flatbuffer_utils.write_model(model, args.output_tflite_file) + model = flatbuffer_utils.xxd_output_to_object(FLAGS.input_cc_file) + flatbuffer_utils.write_model(model, FLAGS.output_tflite_file) if __name__ == '__main__': - app.run(main=main, argv=sys.argv[:1]) + app.run(main) diff --git a/tensorflow/lite/tools/strip_strings.py b/tensorflow/lite/tools/strip_strings.py index e24d2b737c5..b59465fd9ab 100644 --- a/tensorflow/lite/tools/strip_strings.py +++ b/tensorflow/lite/tools/strip_strings.py @@ -12,48 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -r"""Strips all nonessential strings from a tflite file. - -Example usage: -python strip_strings.py \ - --input_tflite_file=foo.tflite \ - --output_tflite_file=foo_stripped.tflite -""" +r"""Strips all nonessential strings from a TFLite file.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse -import sys +from absl import app +from absl import flags from tensorflow.lite.tools import flatbuffer_utils -from tensorflow.python.platform import app + +FLAGS = flags.FLAGS + +flags.DEFINE_string('input_tflite_file', None, + 'Full path name to the input TFLite file.') +flags.DEFINE_string('output_tflite_file', None, + 'Full path name to the output stripped TFLite file.') + +flags.mark_flag_as_required('input_tflite_file') +flags.mark_flag_as_required('output_tflite_file') def main(_): - """Application run loop.""" - parser = argparse.ArgumentParser( - description='Strips all nonessential strings from a tflite file.') - parser.add_argument( - '--input_tflite_file', - type=str, - required=True, - help='Full path name to the input tflite file.') - parser.add_argument( - '--output_tflite_file', - type=str, - required=True, - help='Full path name to the stripped output tflite file.') - args = parser.parse_args() - - # Read the model - model = flatbuffer_utils.read_model(args.input_tflite_file) - # Invoke the strip tflite file function + model = flatbuffer_utils.read_model(FLAGS.input_tflite_file) flatbuffer_utils.strip_strings(model) - # Write the model - flatbuffer_utils.write_model(model, args.output_tflite_file) + flatbuffer_utils.write_model(model, FLAGS.output_tflite_file) if __name__ == '__main__': - app.run(main=main, argv=sys.argv[:1]) + app.run(main)