Minor updates to flatbuffer utilities

PiperOrigin-RevId: 307732210
Change-Id: I6b97ccdff0323dbf0fd20fc20d6bc7e49d5e08ad
This commit is contained in:
Meghna Natraj 2020-04-21 20:04:30 -07:00 committed by TensorFlower Gardener
parent 4b700752f3
commit 47ea7eeb96
3 changed files with 19 additions and 16 deletions

View File

@ -31,7 +31,7 @@ class WriteReadModelTest(test_util.TensorFlowTestCase):
def testWriteReadModel(self):
# 1. SETUP
# Define the initial model
initial_model = test_utils.build_mock_model_python_object()
initial_model = test_utils.build_mock_model()
# Define temporary files
tmp_dir = self.get_temp_dir()
model_filename = os.path.join(tmp_dir, 'model.tflite')
@ -76,7 +76,7 @@ class StripStringsTest(test_util.TensorFlowTestCase):
def testStripStrings(self):
# 1. SETUP
# Define the initial model
initial_model = test_utils.build_mock_model_python_object()
initial_model = test_utils.build_mock_model()
final_model = copy.deepcopy(initial_model)
# 2. INVOKE
@ -121,7 +121,7 @@ class RandomizeWeightsTest(test_util.TensorFlowTestCase):
def testRandomizeWeights(self):
# 1. SETUP
# Define the initial model
initial_model = test_utils.build_mock_model_python_object()
initial_model = test_utils.build_mock_model()
final_model = copy.deepcopy(initial_model)
# 2. INVOKE

View File

@ -14,7 +14,7 @@
# ==============================================================================
"""Utility functions that support testing.
All functions that can be commonly used by various tests are in this file.
All functions that can be commonly used by various tests.
"""
from __future__ import absolute_import
@ -25,7 +25,7 @@ from flatbuffers.python import flatbuffers
from tensorflow.lite.python import schema_py_generated as schema_fb
def build_mock_model():
def build_mock_flatbuffer_model():
"""Creates a flatbuffer containing an example model."""
builder = flatbuffers.Builder(1024)
@ -205,10 +205,14 @@ def build_mock_model():
return model
def build_mock_model_python_object():
"""Creates a python flatbuffer object containing an example model."""
model_mock = build_mock_model()
model_obj = schema_fb.Model.GetRootAsModel(model_mock, 0)
model = schema_fb.ModelT.InitFromObj(model_obj)
def load_model_from_flatbuffer(flatbuffer_model):
"""Loads a model as a python object from a flatbuffer model."""
model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0)
model = schema_fb.ModelT.InitFromObj(model)
return model
def build_mock_model():
"""Creates an object containing an example model."""
model = build_mock_flatbuffer_model()
return load_model_from_flatbuffer(model)

View File

@ -35,8 +35,8 @@ class VisualizeTest(test_util.TensorFlowTestCase):
self.assertEqual('HASHTABLE_LOOKUP', visualize.BuiltinCodeToName(10))
def testFlatbufferToDict(self):
model_data = test_utils.build_mock_model()
model_dict = visualize.CreateDictFromFlatbuffer(model_data)
model = test_utils.build_mock_flatbuffer_model()
model_dict = visualize.CreateDictFromFlatbuffer(model)
self.assertEqual(0, model_dict['version'])
self.assertEqual(1, len(model_dict['subgraphs']))
self.assertEqual(1, len(model_dict['operator_codes']))
@ -45,12 +45,11 @@ class VisualizeTest(test_util.TensorFlowTestCase):
self.assertEqual(0, model_dict['subgraphs'][0]['tensors'][0]['buffer'])
def testVisualize(self):
model_data = test_utils.build_mock_model()
model = test_utils.build_mock_flatbuffer_model()
tmp_dir = self.get_temp_dir()
model_filename = os.path.join(tmp_dir, 'model.tflite')
with open(model_filename, 'wb') as model_file:
model_file.write(model_data)
model_file.write(model)
html_filename = os.path.join(tmp_dir, 'visualization.html')
visualize.CreateHtmlFile(model_filename, html_filename)