Merge pull request #42837 from danielyou0230:tflite_reverse_xxd_dump

PiperOrigin-RevId: 329600652
Change-Id: Ibc5a29ab22d30e6052d30e751a42d1d90224d625
This commit is contained in:
TensorFlower Gardener 2020-09-01 16:12:37 -07:00
commit 1f40fe92b1
4 changed files with 166 additions and 1 deletions

View File

@ -84,6 +84,17 @@ py_binary(
],
)
py_binary(
name = "reverse_xxd_dump_from_cc",
srcs = ["reverse_xxd_dump_from_cc.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":flatbuffer_utils",
"//tensorflow/python:platform",
],
)
py_binary(
name = "randomize_weights",
srcs = ["randomize_weights.py"],

View File

@ -28,6 +28,7 @@ from __future__ import print_function
import copy
import os
import random
import re
import flatbuffers
from tensorflow.lite.python import schema_py_generated as schema_fb
@ -81,7 +82,7 @@ def read_model_with_mutable_tensors(input_tflite_file):
def convert_object_to_bytearray(model_object):
"""Converts a tflite model from an object to a bytearray."""
"""Converts a tflite model from an object to a immutable bytearray."""
# Initial size of the buffer, which will grow automatically if needed
builder = flatbuffers.Builder(1024)
model_offset = model_object.Pack(builder)
@ -153,3 +154,59 @@ def randomize_weights(model, random_seed=0):
# end up as denormalized or NaN/Inf floating point numbers.
for j in range(buffer_i_size):
buffer_i_data[j] = random.randint(0, 255)
def xxd_output_to_bytes(input_cc_file):
"""Converts xxd output C++ source file to bytes (immutable)
Args:
input_cc_file: Full path name to th C++ source file dumped by xxd
Raises:
RuntimeError: If input_cc_file path is invalid.
IOError: If input_cc_file cannot be opened.
Returns:
A bytearray corresponding to the input cc file array.
"""
# Match hex values in the string with comma as separator
pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*')
model_bytearray = bytearray()
with open(input_cc_file) as file_handle:
for line in file_handle:
values_match = pattern.match(line)
if values_match is None:
continue
# Match in the parentheses (hex array only)
list_text = values_match.group(1)
# Extract hex values (text) from the line
# e.g. 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c,
values_text = filter(None, list_text.split(','))
# Convert to hex
values = [int(x, base=16) for x in values_text]
model_bytearray.extend(values)
return bytes(model_bytearray)
def xxd_output_to_object(input_cc_file):
"""Converts xxd output C++ source file to object
Args:
input_cc_file: Full path name to th C++ source file dumped by xxd
Raises:
RuntimeError: If input_cc_file path is invalid.
IOError: If input_cc_file cannot be opened.
Returns:
A python object corresponding to the input tflite file.
"""
model_bytes = xxd_output_to_bytes(input_cc_file)
return convert_bytearray_to_object(model_bytes)

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import copy
import os
import subprocess
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.lite.tools import test_utils
@ -159,5 +160,33 @@ class RandomizeWeightsTest(test_util.TensorFlowTestCase):
self.assertNotEqual(initial_buffer.data[j], final_buffer.data[j])
class XxdOutputToBytesTest(test_util.TensorFlowTestCase):
def testXxdOutputToBytes(self):
# 1. SETUP
# Define the initial model
initial_model = test_utils.build_mock_model()
initial_bytes = flatbuffer_utils.convert_object_to_bytearray(initial_model)
# Define temporary files
tmp_dir = self.get_temp_dir()
model_filename = os.path.join(tmp_dir, 'model.tflite')
# 2. Write model to temporary file (will be used as input for xxd)
flatbuffer_utils.write_model(initial_model, model_filename)
# 3. DUMP WITH xxd
input_cc_file = os.path.join(tmp_dir, 'model.cc')
command = 'xxd -i {} > {}'.format(model_filename, input_cc_file)
subprocess.call(command, shell=True)
# 4. VALIDATE
final_bytes = flatbuffer_utils.xxd_output_to_bytes(input_cc_file)
# Validate that the initial and final bytearray are the same
self.assertEqual(initial_bytes, final_bytes)
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,68 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Reverses xxd dump from to binary file
This script is used to convert models from C++ source file (dumped with xxd) to
the binary model weight file and analyze it with model visualizer like Netron
(https://github.com/lutzroeder/netron) or load the model in TensorFlow Python
API
to evaluate the results in Python.
The command to dump binary file to C++ source file looks like
xxd -i model_data.tflite > model_data.cc
Example usage:
python reverse_xxd_dump_from_cc.py \
--input_cc_file=model_data.cc \
--output_tflite_file=model_data.tflite
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.platform import app
def main(_):
"""Application run loop."""
parser = argparse.ArgumentParser(
description='Reverses xxd dump from to binary file')
parser.add_argument(
'--input_cc_file',
type=str,
required=True,
help='Full path name to the input cc file.')
parser.add_argument(
'--output_tflite_file',
type=str,
required=True,
help='Full path name to the stripped output tflite file.')
args = parser.parse_args()
# Read the model from xxd output C++ source file
model = flatbuffer_utils.xxd_output_to_object(args.input_cc_file)
# Write the model
flatbuffer_utils.write_model(model, args.output_tflite_file)
if __name__ == '__main__':
app.run(main=main, argv=sys.argv[:1])