Add file identifier and model version to python TFLite FlatBuffers
PiperOrigin-RevId: 307843847 Change-Id: If867ffcf8e5770c257f818698acc221b2c91b29e
This commit is contained in:
parent
173035836c
commit
c3759a4130
@ -31,6 +31,8 @@ import random
|
|||||||
from flatbuffers.python import flatbuffers
|
from flatbuffers.python import flatbuffers
|
||||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||||
|
|
||||||
|
TFLITE_FILE_IDENTIFIER = b'TFL3'
|
||||||
|
|
||||||
|
|
||||||
def read_model(input_tflite_file):
|
def read_model(input_tflite_file):
|
||||||
"""Reads and parses a tflite model.
|
"""Reads and parses a tflite model.
|
||||||
@ -66,7 +68,7 @@ def write_model(model, output_tflite_file):
|
|||||||
# Initial size of the buffer, which will grow automatically if needed
|
# Initial size of the buffer, which will grow automatically if needed
|
||||||
builder = flatbuffers.Builder(1024)
|
builder = flatbuffers.Builder(1024)
|
||||||
model_offset = model.Pack(builder)
|
model_offset = model.Pack(builder)
|
||||||
builder.Finish(model_offset)
|
builder.Finish(model_offset, file_identifier=TFLITE_FILE_IDENTIFIER)
|
||||||
model_data = builder.Output()
|
model_data = builder.Output()
|
||||||
with open(output_tflite_file, 'wb') as out_file:
|
with open(output_tflite_file, 'wb') as out_file:
|
||||||
out_file.write(model_data)
|
out_file.write(model_data)
|
||||||
|
@ -24,6 +24,8 @@ from __future__ import print_function
|
|||||||
from flatbuffers.python import flatbuffers
|
from flatbuffers.python import flatbuffers
|
||||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||||
|
|
||||||
|
TFLITE_SCHEMA_VERSION = 3
|
||||||
|
|
||||||
|
|
||||||
def build_mock_flatbuffer_model():
|
def build_mock_flatbuffer_model():
|
||||||
"""Creates a flatbuffer containing an example model."""
|
"""Creates a flatbuffer containing an example model."""
|
||||||
@ -194,6 +196,7 @@ def build_mock_flatbuffer_model():
|
|||||||
|
|
||||||
string4_offset = builder.CreateString('model_description')
|
string4_offset = builder.CreateString('model_description')
|
||||||
schema_fb.ModelStart(builder)
|
schema_fb.ModelStart(builder)
|
||||||
|
schema_fb.ModelAddVersion(builder, TFLITE_SCHEMA_VERSION)
|
||||||
schema_fb.ModelAddOperatorCodes(builder, codes_offset)
|
schema_fb.ModelAddOperatorCodes(builder, codes_offset)
|
||||||
schema_fb.ModelAddSubgraphs(builder, subgraphs_offset)
|
schema_fb.ModelAddSubgraphs(builder, subgraphs_offset)
|
||||||
schema_fb.ModelAddDescription(builder, string4_offset)
|
schema_fb.ModelAddDescription(builder, string4_offset)
|
||||||
|
@ -37,7 +37,7 @@ class VisualizeTest(test_util.TensorFlowTestCase):
|
|||||||
def testFlatbufferToDict(self):
|
def testFlatbufferToDict(self):
|
||||||
model = test_utils.build_mock_flatbuffer_model()
|
model = test_utils.build_mock_flatbuffer_model()
|
||||||
model_dict = visualize.CreateDictFromFlatbuffer(model)
|
model_dict = visualize.CreateDictFromFlatbuffer(model)
|
||||||
self.assertEqual(0, model_dict['version'])
|
self.assertEqual(test_utils.TFLITE_SCHEMA_VERSION, model_dict['version'])
|
||||||
self.assertEqual(1, len(model_dict['subgraphs']))
|
self.assertEqual(1, len(model_dict['subgraphs']))
|
||||||
self.assertEqual(1, len(model_dict['operator_codes']))
|
self.assertEqual(1, len(model_dict['operator_codes']))
|
||||||
self.assertEqual(3, len(model_dict['buffers']))
|
self.assertEqual(3, len(model_dict['buffers']))
|
||||||
|
Loading…
Reference in New Issue
Block a user