Minor updates to flatbuffer utilities
PiperOrigin-RevId: 307732210 Change-Id: I6b97ccdff0323dbf0fd20fc20d6bc7e49d5e08ad
This commit is contained in:
parent
4b700752f3
commit
47ea7eeb96
@ -31,7 +31,7 @@ class WriteReadModelTest(test_util.TensorFlowTestCase):
|
||||
def testWriteReadModel(self):
|
||||
# 1. SETUP
|
||||
# Define the initial model
|
||||
initial_model = test_utils.build_mock_model_python_object()
|
||||
initial_model = test_utils.build_mock_model()
|
||||
# Define temporary files
|
||||
tmp_dir = self.get_temp_dir()
|
||||
model_filename = os.path.join(tmp_dir, 'model.tflite')
|
||||
@ -76,7 +76,7 @@ class StripStringsTest(test_util.TensorFlowTestCase):
|
||||
def testStripStrings(self):
|
||||
# 1. SETUP
|
||||
# Define the initial model
|
||||
initial_model = test_utils.build_mock_model_python_object()
|
||||
initial_model = test_utils.build_mock_model()
|
||||
final_model = copy.deepcopy(initial_model)
|
||||
|
||||
# 2. INVOKE
|
||||
@ -121,7 +121,7 @@ class RandomizeWeightsTest(test_util.TensorFlowTestCase):
|
||||
def testRandomizeWeights(self):
|
||||
# 1. SETUP
|
||||
# Define the initial model
|
||||
initial_model = test_utils.build_mock_model_python_object()
|
||||
initial_model = test_utils.build_mock_model()
|
||||
final_model = copy.deepcopy(initial_model)
|
||||
|
||||
# 2. INVOKE
|
||||
|
@ -14,7 +14,7 @@
|
||||
# ==============================================================================
|
||||
"""Utility functions that support testing.
|
||||
|
||||
All functions that can be commonly used by various tests are in this file.
|
||||
All functions that can be commonly used by various tests.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -25,7 +25,7 @@ from flatbuffers.python import flatbuffers
|
||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||
|
||||
|
||||
def build_mock_model():
|
||||
def build_mock_flatbuffer_model():
|
||||
"""Creates a flatbuffer containing an example model."""
|
||||
builder = flatbuffers.Builder(1024)
|
||||
|
||||
@ -205,10 +205,14 @@ def build_mock_model():
|
||||
return model
|
||||
|
||||
|
||||
def build_mock_model_python_object():
|
||||
"""Creates a python flatbuffer object containing an example model."""
|
||||
model_mock = build_mock_model()
|
||||
model_obj = schema_fb.Model.GetRootAsModel(model_mock, 0)
|
||||
model = schema_fb.ModelT.InitFromObj(model_obj)
|
||||
|
||||
def load_model_from_flatbuffer(flatbuffer_model):
|
||||
"""Loads a model as a python object from a flatbuffer model."""
|
||||
model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0)
|
||||
model = schema_fb.ModelT.InitFromObj(model)
|
||||
return model
|
||||
|
||||
|
||||
def build_mock_model():
|
||||
"""Creates an object containing an example model."""
|
||||
model = build_mock_flatbuffer_model()
|
||||
return load_model_from_flatbuffer(model)
|
||||
|
@ -35,8 +35,8 @@ class VisualizeTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual('HASHTABLE_LOOKUP', visualize.BuiltinCodeToName(10))
|
||||
|
||||
def testFlatbufferToDict(self):
|
||||
model_data = test_utils.build_mock_model()
|
||||
model_dict = visualize.CreateDictFromFlatbuffer(model_data)
|
||||
model = test_utils.build_mock_flatbuffer_model()
|
||||
model_dict = visualize.CreateDictFromFlatbuffer(model)
|
||||
self.assertEqual(0, model_dict['version'])
|
||||
self.assertEqual(1, len(model_dict['subgraphs']))
|
||||
self.assertEqual(1, len(model_dict['operator_codes']))
|
||||
@ -45,12 +45,11 @@ class VisualizeTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(0, model_dict['subgraphs'][0]['tensors'][0]['buffer'])
|
||||
|
||||
def testVisualize(self):
|
||||
model_data = test_utils.build_mock_model()
|
||||
|
||||
model = test_utils.build_mock_flatbuffer_model()
|
||||
tmp_dir = self.get_temp_dir()
|
||||
model_filename = os.path.join(tmp_dir, 'model.tflite')
|
||||
with open(model_filename, 'wb') as model_file:
|
||||
model_file.write(model_data)
|
||||
model_file.write(model)
|
||||
html_filename = os.path.join(tmp_dir, 'visualization.html')
|
||||
|
||||
visualize.CreateHtmlFile(model_filename, html_filename)
|
||||
|
Loading…
Reference in New Issue
Block a user