Minor updates to flatbuffer utilities
PiperOrigin-RevId: 307732210 Change-Id: I6b97ccdff0323dbf0fd20fc20d6bc7e49d5e08ad
This commit is contained in:
parent
4b700752f3
commit
47ea7eeb96
@ -31,7 +31,7 @@ class WriteReadModelTest(test_util.TensorFlowTestCase):
|
|||||||
def testWriteReadModel(self):
|
def testWriteReadModel(self):
|
||||||
# 1. SETUP
|
# 1. SETUP
|
||||||
# Define the initial model
|
# Define the initial model
|
||||||
initial_model = test_utils.build_mock_model_python_object()
|
initial_model = test_utils.build_mock_model()
|
||||||
# Define temporary files
|
# Define temporary files
|
||||||
tmp_dir = self.get_temp_dir()
|
tmp_dir = self.get_temp_dir()
|
||||||
model_filename = os.path.join(tmp_dir, 'model.tflite')
|
model_filename = os.path.join(tmp_dir, 'model.tflite')
|
||||||
@ -76,7 +76,7 @@ class StripStringsTest(test_util.TensorFlowTestCase):
|
|||||||
def testStripStrings(self):
|
def testStripStrings(self):
|
||||||
# 1. SETUP
|
# 1. SETUP
|
||||||
# Define the initial model
|
# Define the initial model
|
||||||
initial_model = test_utils.build_mock_model_python_object()
|
initial_model = test_utils.build_mock_model()
|
||||||
final_model = copy.deepcopy(initial_model)
|
final_model = copy.deepcopy(initial_model)
|
||||||
|
|
||||||
# 2. INVOKE
|
# 2. INVOKE
|
||||||
@ -121,7 +121,7 @@ class RandomizeWeightsTest(test_util.TensorFlowTestCase):
|
|||||||
def testRandomizeWeights(self):
|
def testRandomizeWeights(self):
|
||||||
# 1. SETUP
|
# 1. SETUP
|
||||||
# Define the initial model
|
# Define the initial model
|
||||||
initial_model = test_utils.build_mock_model_python_object()
|
initial_model = test_utils.build_mock_model()
|
||||||
final_model = copy.deepcopy(initial_model)
|
final_model = copy.deepcopy(initial_model)
|
||||||
|
|
||||||
# 2. INVOKE
|
# 2. INVOKE
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Utility functions that support testing.
|
"""Utility functions that support testing.
|
||||||
|
|
||||||
All functions that can be commonly used by various tests are in this file.
|
All functions that can be commonly used by various tests.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
@ -25,7 +25,7 @@ from flatbuffers.python import flatbuffers
|
|||||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||||
|
|
||||||
|
|
||||||
def build_mock_model():
|
def build_mock_flatbuffer_model():
|
||||||
"""Creates a flatbuffer containing an example model."""
|
"""Creates a flatbuffer containing an example model."""
|
||||||
builder = flatbuffers.Builder(1024)
|
builder = flatbuffers.Builder(1024)
|
||||||
|
|
||||||
@ -205,10 +205,14 @@ def build_mock_model():
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def build_mock_model_python_object():
|
def load_model_from_flatbuffer(flatbuffer_model):
|
||||||
"""Creates a python flatbuffer object containing an example model."""
|
"""Loads a model as a python object from a flatbuffer model."""
|
||||||
model_mock = build_mock_model()
|
model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0)
|
||||||
model_obj = schema_fb.Model.GetRootAsModel(model_mock, 0)
|
model = schema_fb.ModelT.InitFromObj(model)
|
||||||
model = schema_fb.ModelT.InitFromObj(model_obj)
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def build_mock_model():
|
||||||
|
"""Creates an object containing an example model."""
|
||||||
|
model = build_mock_flatbuffer_model()
|
||||||
|
return load_model_from_flatbuffer(model)
|
||||||
|
@ -35,8 +35,8 @@ class VisualizeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual('HASHTABLE_LOOKUP', visualize.BuiltinCodeToName(10))
|
self.assertEqual('HASHTABLE_LOOKUP', visualize.BuiltinCodeToName(10))
|
||||||
|
|
||||||
def testFlatbufferToDict(self):
|
def testFlatbufferToDict(self):
|
||||||
model_data = test_utils.build_mock_model()
|
model = test_utils.build_mock_flatbuffer_model()
|
||||||
model_dict = visualize.CreateDictFromFlatbuffer(model_data)
|
model_dict = visualize.CreateDictFromFlatbuffer(model)
|
||||||
self.assertEqual(0, model_dict['version'])
|
self.assertEqual(0, model_dict['version'])
|
||||||
self.assertEqual(1, len(model_dict['subgraphs']))
|
self.assertEqual(1, len(model_dict['subgraphs']))
|
||||||
self.assertEqual(1, len(model_dict['operator_codes']))
|
self.assertEqual(1, len(model_dict['operator_codes']))
|
||||||
@ -45,12 +45,11 @@ class VisualizeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(0, model_dict['subgraphs'][0]['tensors'][0]['buffer'])
|
self.assertEqual(0, model_dict['subgraphs'][0]['tensors'][0]['buffer'])
|
||||||
|
|
||||||
def testVisualize(self):
|
def testVisualize(self):
|
||||||
model_data = test_utils.build_mock_model()
|
model = test_utils.build_mock_flatbuffer_model()
|
||||||
|
|
||||||
tmp_dir = self.get_temp_dir()
|
tmp_dir = self.get_temp_dir()
|
||||||
model_filename = os.path.join(tmp_dir, 'model.tflite')
|
model_filename = os.path.join(tmp_dir, 'model.tflite')
|
||||||
with open(model_filename, 'wb') as model_file:
|
with open(model_filename, 'wb') as model_file:
|
||||||
model_file.write(model_data)
|
model_file.write(model)
|
||||||
html_filename = os.path.join(tmp_dir, 'visualization.html')
|
html_filename = os.path.join(tmp_dir, 'visualization.html')
|
||||||
|
|
||||||
visualize.CreateHtmlFile(model_filename, html_filename)
|
visualize.CreateHtmlFile(model_filename, html_filename)
|
||||||
|
Loading…
Reference in New Issue
Block a user