Add file identifier and model version to python TFLite FlatBuffers

PiperOrigin-RevId: 307843847
Change-Id: If867ffcf8e5770c257f818698acc221b2c91b29e
This commit is contained in:
Meghna Natraj 2020-04-22 10:23:35 -07:00 committed by TensorFlower Gardener
parent 173035836c
commit c3759a4130
3 changed files with 7 additions and 2 deletions

View File

@ -31,6 +31,8 @@ import random
from flatbuffers.python import flatbuffers
from tensorflow.lite.python import schema_py_generated as schema_fb
TFLITE_FILE_IDENTIFIER = b'TFL3'
def read_model(input_tflite_file):
"""Reads and parses a tflite model.
@ -66,7 +68,7 @@ def write_model(model, output_tflite_file):
# Initial size of the buffer, which will grow automatically if needed
builder = flatbuffers.Builder(1024)
model_offset = model.Pack(builder)
builder.Finish(model_offset)
builder.Finish(model_offset, file_identifier=TFLITE_FILE_IDENTIFIER)
model_data = builder.Output()
with open(output_tflite_file, 'wb') as out_file:
out_file.write(model_data)

View File

@ -24,6 +24,8 @@ from __future__ import print_function
from flatbuffers.python import flatbuffers
from tensorflow.lite.python import schema_py_generated as schema_fb
TFLITE_SCHEMA_VERSION = 3
def build_mock_flatbuffer_model():
"""Creates a flatbuffer containing an example model."""
@ -194,6 +196,7 @@ def build_mock_flatbuffer_model():
string4_offset = builder.CreateString('model_description')
schema_fb.ModelStart(builder)
schema_fb.ModelAddVersion(builder, TFLITE_SCHEMA_VERSION)
schema_fb.ModelAddOperatorCodes(builder, codes_offset)
schema_fb.ModelAddSubgraphs(builder, subgraphs_offset)
schema_fb.ModelAddDescription(builder, string4_offset)

View File

@ -37,7 +37,7 @@ class VisualizeTest(test_util.TensorFlowTestCase):
def testFlatbufferToDict(self):
model = test_utils.build_mock_flatbuffer_model()
model_dict = visualize.CreateDictFromFlatbuffer(model)
self.assertEqual(0, model_dict['version'])
self.assertEqual(test_utils.TFLITE_SCHEMA_VERSION, model_dict['version'])
self.assertEqual(1, len(model_dict['subgraphs']))
self.assertEqual(1, len(model_dict['operator_codes']))
self.assertEqual(3, len(model_dict['buffers']))