Merge pull request #42837 from danielyou0230:tflite_reverse_xxd_dump
PiperOrigin-RevId: 329600652 Change-Id: Ibc5a29ab22d30e6052d30e751a42d1d90224d625
This commit is contained in:
commit
1f40fe92b1
@ -84,6 +84,17 @@ py_binary(
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "reverse_xxd_dump_from_cc",
|
||||
srcs = ["reverse_xxd_dump_from_cc.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":flatbuffer_utils",
|
||||
"//tensorflow/python:platform",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "randomize_weights",
|
||||
srcs = ["randomize_weights.py"],
|
||||
|
@ -28,6 +28,7 @@ from __future__ import print_function
|
||||
import copy
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
|
||||
import flatbuffers
|
||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||
@ -81,7 +82,7 @@ def read_model_with_mutable_tensors(input_tflite_file):
|
||||
|
||||
|
||||
def convert_object_to_bytearray(model_object):
|
||||
"""Converts a tflite model from an object to a bytearray."""
|
||||
"""Converts a tflite model from an object to a immutable bytearray."""
|
||||
# Initial size of the buffer, which will grow automatically if needed
|
||||
builder = flatbuffers.Builder(1024)
|
||||
model_offset = model_object.Pack(builder)
|
||||
@ -153,3 +154,59 @@ def randomize_weights(model, random_seed=0):
|
||||
# end up as denormalized or NaN/Inf floating point numbers.
|
||||
for j in range(buffer_i_size):
|
||||
buffer_i_data[j] = random.randint(0, 255)
|
||||
|
||||
|
||||
def xxd_output_to_bytes(input_cc_file):
|
||||
"""Converts xxd output C++ source file to bytes (immutable)
|
||||
|
||||
Args:
|
||||
input_cc_file: Full path name to th C++ source file dumped by xxd
|
||||
|
||||
Raises:
|
||||
RuntimeError: If input_cc_file path is invalid.
|
||||
IOError: If input_cc_file cannot be opened.
|
||||
|
||||
Returns:
|
||||
A bytearray corresponding to the input cc file array.
|
||||
"""
|
||||
# Match hex values in the string with comma as separator
|
||||
pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*')
|
||||
|
||||
model_bytearray = bytearray()
|
||||
|
||||
with open(input_cc_file) as file_handle:
|
||||
for line in file_handle:
|
||||
values_match = pattern.match(line)
|
||||
|
||||
if values_match is None:
|
||||
continue
|
||||
|
||||
# Match in the parentheses (hex array only)
|
||||
list_text = values_match.group(1)
|
||||
|
||||
# Extract hex values (text) from the line
|
||||
# e.g. 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c,
|
||||
values_text = filter(None, list_text.split(','))
|
||||
|
||||
# Convert to hex
|
||||
values = [int(x, base=16) for x in values_text]
|
||||
model_bytearray.extend(values)
|
||||
|
||||
return bytes(model_bytearray)
|
||||
|
||||
|
||||
def xxd_output_to_object(input_cc_file):
|
||||
"""Converts xxd output C++ source file to object
|
||||
|
||||
Args:
|
||||
input_cc_file: Full path name to th C++ source file dumped by xxd
|
||||
|
||||
Raises:
|
||||
RuntimeError: If input_cc_file path is invalid.
|
||||
IOError: If input_cc_file cannot be opened.
|
||||
|
||||
Returns:
|
||||
A python object corresponding to the input tflite file.
|
||||
"""
|
||||
model_bytes = xxd_output_to_bytes(input_cc_file)
|
||||
return convert_bytearray_to_object(model_bytes)
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from tensorflow.lite.tools import flatbuffer_utils
|
||||
from tensorflow.lite.tools import test_utils
|
||||
@ -159,5 +160,33 @@ class RandomizeWeightsTest(test_util.TensorFlowTestCase):
|
||||
self.assertNotEqual(initial_buffer.data[j], final_buffer.data[j])
|
||||
|
||||
|
||||
class XxdOutputToBytesTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testXxdOutputToBytes(self):
|
||||
# 1. SETUP
|
||||
# Define the initial model
|
||||
initial_model = test_utils.build_mock_model()
|
||||
initial_bytes = flatbuffer_utils.convert_object_to_bytearray(initial_model)
|
||||
|
||||
# Define temporary files
|
||||
tmp_dir = self.get_temp_dir()
|
||||
model_filename = os.path.join(tmp_dir, 'model.tflite')
|
||||
|
||||
# 2. Write model to temporary file (will be used as input for xxd)
|
||||
flatbuffer_utils.write_model(initial_model, model_filename)
|
||||
|
||||
# 3. DUMP WITH xxd
|
||||
input_cc_file = os.path.join(tmp_dir, 'model.cc')
|
||||
|
||||
command = 'xxd -i {} > {}'.format(model_filename, input_cc_file)
|
||||
subprocess.call(command, shell=True)
|
||||
|
||||
# 4. VALIDATE
|
||||
final_bytes = flatbuffer_utils.xxd_output_to_bytes(input_cc_file)
|
||||
|
||||
# Validate that the initial and final bytearray are the same
|
||||
self.assertEqual(initial_bytes, final_bytes)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
68
tensorflow/lite/tools/reverse_xxd_dump_from_cc.py
Normal file
68
tensorflow/lite/tools/reverse_xxd_dump_from_cc.py
Normal file
@ -0,0 +1,68 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""Reverses xxd dump from to binary file
|
||||
|
||||
This script is used to convert models from C++ source file (dumped with xxd) to
|
||||
the binary model weight file and analyze it with model visualizer like Netron
|
||||
(https://github.com/lutzroeder/netron) or load the model in TensorFlow Python
|
||||
API
|
||||
to evaluate the results in Python.
|
||||
|
||||
The command to dump binary file to C++ source file looks like
|
||||
|
||||
xxd -i model_data.tflite > model_data.cc
|
||||
|
||||
Example usage:
|
||||
|
||||
python reverse_xxd_dump_from_cc.py \
|
||||
--input_cc_file=model_data.cc \
|
||||
--output_tflite_file=model_data.tflite
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from tensorflow.lite.tools import flatbuffer_utils
|
||||
from tensorflow.python.platform import app
|
||||
|
||||
|
||||
def main(_):
|
||||
"""Application run loop."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Reverses xxd dump from to binary file')
|
||||
parser.add_argument(
|
||||
'--input_cc_file',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Full path name to the input cc file.')
|
||||
parser.add_argument(
|
||||
'--output_tflite_file',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Full path name to the stripped output tflite file.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Read the model from xxd output C++ source file
|
||||
model = flatbuffer_utils.xxd_output_to_object(args.input_cc_file)
|
||||
# Write the model
|
||||
flatbuffer_utils.write_model(model, args.output_tflite_file)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main=main, argv=sys.argv[:1])
|
Loading…
Reference in New Issue
Block a user