From c3759a41301a93281ed2d50b8f4b8786609bf4b7 Mon Sep 17 00:00:00 2001 From: Meghna Natraj Date: Wed, 22 Apr 2020 10:23:35 -0700 Subject: [PATCH] Add file identifier and model version to python TFLite FlatBuffers PiperOrigin-RevId: 307843847 Change-Id: If867ffcf8e5770c257f818698acc221b2c91b29e --- tensorflow/lite/tools/flatbuffer_utils.py | 4 +++- tensorflow/lite/tools/test_utils.py | 3 +++ tensorflow/lite/tools/visualize_test.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/tools/flatbuffer_utils.py b/tensorflow/lite/tools/flatbuffer_utils.py index 5b513bbfef2..f80daad2519 100644 --- a/tensorflow/lite/tools/flatbuffer_utils.py +++ b/tensorflow/lite/tools/flatbuffer_utils.py @@ -31,6 +31,8 @@ import random from flatbuffers.python import flatbuffers from tensorflow.lite.python import schema_py_generated as schema_fb +TFLITE_FILE_IDENTIFIER = b'TFL3' + def read_model(input_tflite_file): """Reads and parses a tflite model. @@ -66,7 +68,7 @@ def write_model(model, output_tflite_file): # Initial size of the buffer, which will grow automatically if needed builder = flatbuffers.Builder(1024) model_offset = model.Pack(builder) - builder.Finish(model_offset) + builder.Finish(model_offset, file_identifier=TFLITE_FILE_IDENTIFIER) model_data = builder.Output() with open(output_tflite_file, 'wb') as out_file: out_file.write(model_data) diff --git a/tensorflow/lite/tools/test_utils.py b/tensorflow/lite/tools/test_utils.py index dfeb8a2fee2..3950e3de35e 100644 --- a/tensorflow/lite/tools/test_utils.py +++ b/tensorflow/lite/tools/test_utils.py @@ -24,6 +24,8 @@ from __future__ import print_function from flatbuffers.python import flatbuffers from tensorflow.lite.python import schema_py_generated as schema_fb +TFLITE_SCHEMA_VERSION = 3 + def build_mock_flatbuffer_model(): """Creates a flatbuffer containing an example model.""" @@ -194,6 +196,7 @@ def build_mock_flatbuffer_model(): string4_offset = builder.CreateString('model_description') schema_fb.ModelStart(builder) + schema_fb.ModelAddVersion(builder, TFLITE_SCHEMA_VERSION) schema_fb.ModelAddOperatorCodes(builder, codes_offset) schema_fb.ModelAddSubgraphs(builder, subgraphs_offset) schema_fb.ModelAddDescription(builder, string4_offset) diff --git a/tensorflow/lite/tools/visualize_test.py b/tensorflow/lite/tools/visualize_test.py index 6480b79e2dc..aa74891224d 100644 --- a/tensorflow/lite/tools/visualize_test.py +++ b/tensorflow/lite/tools/visualize_test.py @@ -37,7 +37,7 @@ class VisualizeTest(test_util.TensorFlowTestCase): def testFlatbufferToDict(self): model = test_utils.build_mock_flatbuffer_model() model_dict = visualize.CreateDictFromFlatbuffer(model) - self.assertEqual(0, model_dict['version']) + self.assertEqual(test_utils.TFLITE_SCHEMA_VERSION, model_dict['version']) self.assertEqual(1, len(model_dict['subgraphs'])) self.assertEqual(1, len(model_dict['operator_codes'])) self.assertEqual(3, len(model_dict['buffers']))