Add file identifier and model version to python TFLite FlatBuffers
PiperOrigin-RevId: 307843847 Change-Id: If867ffcf8e5770c257f818698acc221b2c91b29e
This commit is contained in:
parent
173035836c
commit
c3759a4130
@ -31,6 +31,8 @@ import random
|
||||
from flatbuffers.python import flatbuffers
|
||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||
|
||||
TFLITE_FILE_IDENTIFIER = b'TFL3'
|
||||
|
||||
|
||||
def read_model(input_tflite_file):
|
||||
"""Reads and parses a tflite model.
|
||||
@ -66,7 +68,7 @@ def write_model(model, output_tflite_file):
|
||||
# Initial size of the buffer, which will grow automatically if needed
|
||||
builder = flatbuffers.Builder(1024)
|
||||
model_offset = model.Pack(builder)
|
||||
builder.Finish(model_offset)
|
||||
builder.Finish(model_offset, file_identifier=TFLITE_FILE_IDENTIFIER)
|
||||
model_data = builder.Output()
|
||||
with open(output_tflite_file, 'wb') as out_file:
|
||||
out_file.write(model_data)
|
||||
|
@ -24,6 +24,8 @@ from __future__ import print_function
|
||||
from flatbuffers.python import flatbuffers
|
||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||
|
||||
TFLITE_SCHEMA_VERSION = 3
|
||||
|
||||
|
||||
def build_mock_flatbuffer_model():
|
||||
"""Creates a flatbuffer containing an example model."""
|
||||
@ -194,6 +196,7 @@ def build_mock_flatbuffer_model():
|
||||
|
||||
string4_offset = builder.CreateString('model_description')
|
||||
schema_fb.ModelStart(builder)
|
||||
schema_fb.ModelAddVersion(builder, TFLITE_SCHEMA_VERSION)
|
||||
schema_fb.ModelAddOperatorCodes(builder, codes_offset)
|
||||
schema_fb.ModelAddSubgraphs(builder, subgraphs_offset)
|
||||
schema_fb.ModelAddDescription(builder, string4_offset)
|
||||
|
@ -37,7 +37,7 @@ class VisualizeTest(test_util.TensorFlowTestCase):
|
||||
def testFlatbufferToDict(self):
|
||||
model = test_utils.build_mock_flatbuffer_model()
|
||||
model_dict = visualize.CreateDictFromFlatbuffer(model)
|
||||
self.assertEqual(0, model_dict['version'])
|
||||
self.assertEqual(test_utils.TFLITE_SCHEMA_VERSION, model_dict['version'])
|
||||
self.assertEqual(1, len(model_dict['subgraphs']))
|
||||
self.assertEqual(1, len(model_dict['operator_codes']))
|
||||
self.assertEqual(3, len(model_dict['buffers']))
|
||||
|
Loading…
Reference in New Issue
Block a user