Update the usage of the 'absl-py' python library

PiperOrigin-RevId: 344343142
Change-Id: Ibdd6351b714aedbda1c2b7230ab5819353d7cc41
This commit is contained in:
Meghna Natraj 2020-11-25 17:50:07 -08:00 committed by TensorFlower Gardener
parent 0ec320bdd9
commit 595d575c23
6 changed files with 88 additions and 155 deletions

View File

@ -81,7 +81,8 @@ py_binary(
srcs_version = "PY2AND3",
deps = [
":flatbuffer_utils",
"//tensorflow/python:platform",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
)
@ -92,7 +93,8 @@ py_binary(
srcs_version = "PY2AND3",
deps = [
":flatbuffer_utils",
"//tensorflow/python:platform",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
)
@ -103,7 +105,8 @@ py_binary(
srcs_version = "PY2AND3",
deps = [
":flatbuffer_utils",
"//tensorflow/python:platform",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
)

View File

@ -15,7 +15,8 @@ py_binary(
deps = [
":modify_model_interface_constants",
":modify_model_interface_lib",
"//tensorflow/python:platform",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
)

View File

@ -1,4 +1,3 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -13,66 +12,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Modify a quantized model's interface from float to integer.
Example usage:
python modify_model_interface_main.py \
--input_file=float_model.tflite \
--output_file=int_model.tflite \
--input_type=INT8 \
--output_type=INT8
"""
r"""Modify a quantized model's interface from float to integer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from absl import app
from absl import flags
from tensorflow.lite.tools.optimize.python import modify_model_interface_constants as mmi_constants
from tensorflow.lite.tools.optimize.python import modify_model_interface_lib as mmi_lib
from tensorflow.python.platform import app
FLAGS = flags.FLAGS
flags.DEFINE_string('input_tflite_file', None,
'Full path name to the input TFLite file.')
flags.DEFINE_string('output_tflite_file', None,
'Full path name to the output TFLite file.')
flags.DEFINE_enum('input_type', mmi_constants.DEFAULT_STR_TYPE,
mmi_constants.STR_TYPES,
'Modified input integer interface type.')
flags.DEFINE_enum('output_type', mmi_constants.DEFAULT_STR_TYPE,
mmi_constants.STR_TYPES,
'Modified output integer interface type.')
flags.mark_flag_as_required('input_tflite_file')
flags.mark_flag_as_required('output_tflite_file')
def main(_):
"""Application run loop."""
parser = argparse.ArgumentParser(
description="Modify a quantized model's interface from float to integer.")
parser.add_argument(
'--input_file',
type=str,
required=True,
help='Full path name to the input tflite file.')
parser.add_argument(
'--output_file',
type=str,
required=True,
help='Full path name to the output tflite file.')
parser.add_argument(
'--input_type',
type=str.upper,
choices=mmi_constants.STR_TYPES,
default=mmi_constants.DEFAULT_STR_TYPE,
help='Modified input integer interface type.')
parser.add_argument(
'--output_type',
type=str.upper,
choices=mmi_constants.STR_TYPES,
default=mmi_constants.DEFAULT_STR_TYPE,
help='Modified output integer interface type.')
args = parser.parse_args()
input_type = mmi_constants.STR_TO_TFLITE_TYPES[FLAGS.input_type]
output_type = mmi_constants.STR_TO_TFLITE_TYPES[FLAGS.output_type]
input_type = mmi_constants.STR_TO_TFLITE_TYPES[args.input_type]
output_type = mmi_constants.STR_TO_TFLITE_TYPES[args.output_type]
mmi_lib.modify_model_interface(args.input_file, args.output_file, input_type,
output_type)
mmi_lib.modify_model_interface(FLAGS.input_file, FLAGS.output_file,
input_type, output_type)
print('Successfully modified the model input type from FLOAT to '
'{input_type} and output type from FLOAT to {output_type}.'.format(
input_type=args.input_type, output_type=args.output_type))
input_type=FLAGS.input_type, output_type=FLAGS.output_type))
if __name__ == '__main__':
app.run(main=main, argv=sys.argv[:1])
app.run(main)

View File

@ -12,53 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Randomize all weights in a tflite file.
Example usage:
python randomize_weights.py \
--input_tflite_file=foo.tflite \
--output_tflite_file=foo_randomized.tflite
"""
r"""Randomize all weights in a tflite file."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from absl import app
from absl import flags
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.platform import app
FLAGS = flags.FLAGS
flags.DEFINE_string('input_tflite_file', None,
'Full path name to the input TFLite file.')
flags.DEFINE_string('output_tflite_file', None,
'Full path name to the output randomized TFLite file.')
flags.DEFINE_integer('random_seed', 0, 'Input to the random number generator.')
flags.mark_flag_as_required('input_tflite_file')
flags.mark_flag_as_required('output_tflite_file')
def main(_):
parser = argparse.ArgumentParser(
description='Randomize weights in a tflite file.')
parser.add_argument(
'--input_tflite_file',
type=str,
required=True,
help='Full path name to the input tflite file.')
parser.add_argument(
'--output_tflite_file',
type=str,
required=True,
help='Full path name to the output randomized tflite file.')
parser.add_argument(
'--random_seed',
type=str,
required=False,
default=0,
help='Input to the random number generator. The default value is 0.')
args = parser.parse_args()
# Read the model
model = flatbuffer_utils.read_model(args.input_tflite_file)
# Invoke the randomize weights function
flatbuffer_utils.randomize_weights(model, args.random_seed)
# Write the model
flatbuffer_utils.write_model(model, args.output_tflite_file)
model = flatbuffer_utils.read_model(FLAGS.input_tflite_file)
flatbuffer_utils.randomize_weights(model, FLAGS.random_seed)
flatbuffer_utils.write_model(model, FLAGS.output_tflite_file)
if __name__ == '__main__':
app.run(main=main, argv=sys.argv[:1])
app.run(main)

View File

@ -12,57 +12,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Reverses xxd dump from to binary file
r"""Reverses xxd dump, i.e, converts a C++ source file back to a TFLite file.
This script is used to convert models from C++ source file (dumped with xxd) to
the binary model weight file and analyze it with model visualizer like Netron
(https://github.com/lutzroeder/netron) or load the model in TensorFlow Python
API
to evaluate the results in Python.
The command to dump binary file to C++ source file looks like
This script is used to convert a model from a C++ source file (dumped with xxd)
back to it's original TFLite file format in order to analyze it with either a
model visualizer like Netron (https://github.com/lutzroeder/netron) or to
evaluate the model using the Python TensorFlow Lite Interpreter API.
The xxd command to dump the TFLite file to a C++ source file looks like:
xxd -i model_data.tflite > model_data.cc
Example usage:
python reverse_xxd_dump_from_cc.py \
--input_cc_file=model_data.cc \
--output_tflite_file=model_data.tflite
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from absl import app
from absl import flags
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.platform import app
FLAGS = flags.FLAGS
flags.DEFINE_string('input_cc_file', None,
'Full path name to the input C++ source file.')
flags.DEFINE_string('output_tflite_file', None,
'Full path name to the output TFLite file.')
flags.mark_flag_as_required('input_cc_file')
flags.mark_flag_as_required('output_tflite_file')
def main(_):
"""Application run loop."""
parser = argparse.ArgumentParser(
description='Reverses xxd dump from to binary file')
parser.add_argument(
'--input_cc_file',
type=str,
required=True,
help='Full path name to the input cc file.')
parser.add_argument(
'--output_tflite_file',
type=str,
required=True,
help='Full path name to the stripped output tflite file.')
args = parser.parse_args()
# Read the model from xxd output C++ source file
model = flatbuffer_utils.xxd_output_to_object(args.input_cc_file)
# Write the model
flatbuffer_utils.write_model(model, args.output_tflite_file)
model = flatbuffer_utils.xxd_output_to_object(FLAGS.input_cc_file)
flatbuffer_utils.write_model(model, FLAGS.output_tflite_file)
if __name__ == '__main__':
app.run(main=main, argv=sys.argv[:1])
app.run(main)

View File

@ -12,48 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Strips all nonessential strings from a tflite file.
Example usage:
python strip_strings.py \
--input_tflite_file=foo.tflite \
--output_tflite_file=foo_stripped.tflite
"""
r"""Strips all nonessential strings from a TFLite file."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from absl import app
from absl import flags
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.platform import app
FLAGS = flags.FLAGS
flags.DEFINE_string('input_tflite_file', None,
'Full path name to the input TFLite file.')
flags.DEFINE_string('output_tflite_file', None,
'Full path name to the output stripped TFLite file.')
flags.mark_flag_as_required('input_tflite_file')
flags.mark_flag_as_required('output_tflite_file')
def main(_):
"""Application run loop."""
parser = argparse.ArgumentParser(
description='Strips all nonessential strings from a tflite file.')
parser.add_argument(
'--input_tflite_file',
type=str,
required=True,
help='Full path name to the input tflite file.')
parser.add_argument(
'--output_tflite_file',
type=str,
required=True,
help='Full path name to the stripped output tflite file.')
args = parser.parse_args()
# Read the model
model = flatbuffer_utils.read_model(args.input_tflite_file)
# Invoke the strip tflite file function
model = flatbuffer_utils.read_model(FLAGS.input_tflite_file)
flatbuffer_utils.strip_strings(model)
# Write the model
flatbuffer_utils.write_model(model, args.output_tflite_file)
flatbuffer_utils.write_model(model, FLAGS.output_tflite_file)
if __name__ == '__main__':
app.run(main=main, argv=sys.argv[:1])
app.run(main)