From ac539cf0447b9d50841b727fb9808e30cdcd3484 Mon Sep 17 00:00:00 2001 From: Lu Wang Date: Fri, 3 Apr 2020 09:37:49 -0700 Subject: [PATCH] Check Flatbuffer identifier in both MetadataPopulator and MetadataExtractor PiperOrigin-RevId: 304633026 Change-Id: Ib8edd9897c3ae7162aac1bba4aa676e99e9c0e07 --- .../experimental/support/metadata/metadata.py | 48 ++++++++++-- .../support/metadata/metadata_test.py | 78 +++++++++++++++++++ third_party/flatbuffers/workspace.bzl | 8 +- 3 files changed, 125 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/experimental/support/metadata/metadata.py b/tensorflow/lite/experimental/support/metadata/metadata.py index 1b5380352b8..25ca57bb4cc 100644 --- a/tensorflow/lite/experimental/support/metadata/metadata.py +++ b/tensorflow/lite/experimental/support/metadata/metadata.py @@ -97,8 +97,9 @@ class MetadataPopulator(object): Raises: IOError: File not found. + ValueError: the model does not have the expected flatbuffer identifer. """ - _assert_exist(model_file) + _assert_model_file_identifier(model_file) self._model_file = model_file self._metadata_buf = None self._associated_files = set() @@ -115,6 +116,7 @@ class MetadataPopulator(object): Raises: IOError: File not found. + ValueError: the model does not have the expected flatbuffer identifer. """ return cls(model_file) @@ -129,6 +131,9 @@ class MetadataPopulator(object): Returns: A MetadataPopulator(_MetadataPopulatorWithBuffer) object. + + Raises: + ValueError: the model does not have the expected flatbuffer identifer. """ return _MetadataPopulatorWithBuffer(model_buf) @@ -211,12 +216,13 @@ class MetadataPopulator(object): metadata_buf: metadata buffer (in bytearray) to be populated. Raises: - ValueError: - The metadata to be populated is empty. + ValueError: The metadata to be populated is empty. + ValueError: The metadata does not have the expected flatbuffer identifer. """ if not metadata_buf: raise ValueError("The metadata to be populated is empty.") + _assert_metadata_buffer_identifier(metadata_buf) self._metadata_buf = metadata_buf def load_metadata_file(self, metadata_file): @@ -226,8 +232,8 @@ class MetadataPopulator(object): metadata_file: path to the metadata file to be populated. Raises: - IOError: - File not found. + IOError: File not found. + ValueError: The metadata does not have the expected flatbuffer identifer. """ _assert_exist(metadata_file) with open(metadata_file, "rb") as f: @@ -391,6 +397,7 @@ class _MetadataPopulatorWithBuffer(MetadataPopulator): Raises: ValueError: model_buf is empty. + ValueError: model_buf does not have the expected flatbuffer identifer. """ if not model_buf: raise ValueError("model_buf cannot be empty.") @@ -423,6 +430,8 @@ class MetadataDisplayer(object): metadata_file: valid path to the metadata file. associated_file_list: list of associate files in the model file. """ + _assert_model_file_identifier(model_file) + _assert_metadata_file_identifier(metadata_file) self._model_file = model_file self._metadata_file = metadata_file self._associated_file_list = associated_file_list @@ -553,3 +562,32 @@ def _assert_exist(filename): """Checks if a file exists.""" if not os.path.exists(filename): raise IOError("File, '{0}', does not exist.".format(filename)) + + +def _assert_model_file_identifier(model_file): + """Checks if a model file has the expected TFLite schema identifier.""" + _assert_exist(model_file) + with open(model_file, "rb") as f: + model_buf = f.read() + + if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0): + raise ValueError( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.") + + +def _assert_metadata_file_identifier(metadata_file): + """Checks if a metadata file has the expected Metadata schema identifier.""" + _assert_exist(metadata_file) + with open(metadata_file, "rb") as f: + metadata_buf = f.read() + _assert_metadata_buffer_identifier(metadata_buf) + + +def _assert_metadata_buffer_identifier(metadata_buf): + """Checks if a metadata buffer has the expected Metadata schema identifier.""" + if not _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier( + metadata_buf, 0): + raise ValueError( + "The metadata buffer does not have the expected identifier, and may not" + " be a valid TFLite Metadata.") diff --git a/tensorflow/lite/experimental/support/metadata/metadata_test.py b/tensorflow/lite/experimental/support/metadata/metadata_test.py index 30f6a73e070..81b3eef62f9 100644 --- a/tensorflow/lite/experimental/support/metadata/metadata_test.py +++ b/tensorflow/lite/experimental/support/metadata/metadata_test.py @@ -102,6 +102,39 @@ class MetadataTest(test_util.TensorFlowTestCase): f.write(b.Output()) return metadata_file + def _create_model_buffer_with_wrong_identifier(self): + wrong_identifier = b"widn" + model = _schema_fb.ModelT() + model_builder = flatbuffers.Builder(0) + model_builder.Finish(model.Pack(model_builder), wrong_identifier) + return model_builder.Output() + + def _create_metadata_buffer_with_wrong_identifier(self): + # Creates a metadata with wrong identifier + wrong_identifier = b"widn" + metadata = _metadata_fb.ModelMetadataT() + metadata_builder = flatbuffers.Builder(0) + metadata_builder.Finish(metadata.Pack(metadata_builder), wrong_identifier) + return metadata_builder.Output() + + def _populate_metadata_with_identifier(self, model_buf, metadata_buf, + identifier): + # For testing purposes only. MetadataPopulator cannot populate metadata with + # wrong identifiers. + model = _schema_fb.ModelT.InitFromObj( + _schema_fb.Model.GetRootAsModel(model_buf, 0)) + buffer_field = _schema_fb.BufferT() + buffer_field.data = metadata_buf + model.buffers = [buffer_field] + # Creates a new metadata field. + metadata_field = _schema_fb.MetadataT() + metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME + metadata_field.buffer = len(model.buffers) - 1 + model.metadata = [metadata_field] + b = flatbuffers.Builder(0) + b.Finish(model.Pack(b), identifier) + return b.Output() + class MetadataPopulatorTest(MetadataTest): @@ -126,6 +159,14 @@ class MetadataPopulatorTest(MetadataTest): _metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf) self.assertEqual("model_buf cannot be empty.", str(error.exception)) + def testToModelBufferWithWrongIdentifier(self): + model_buf = self._create_model_buffer_with_wrong_identifier() + with self.assertRaises(ValueError) as error: + _metadata.MetadataPopulator.with_model_buffer(model_buf) + self.assertEqual( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.", str(error.exception)) + def testSinglePopulateAssociatedFile(self): populator = _metadata.MetadataPopulator.with_model_buffer( self._empty_model_buf) @@ -228,6 +269,15 @@ class MetadataPopulatorTest(MetadataTest): "not been loaded into the populator.").format( os.path.basename(self._file2)), str(error.exception)) + def testPopulateMetadataBufferWithWrongIdentifier(self): + metadata_buf = self._create_metadata_buffer_with_wrong_identifier() + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(metadata_buf) + self.assertEqual( + "The metadata buffer does not have the expected identifier, and may not" + " be a valid TFLite Metadata.", str(error.exception)) + def _assert_golden_metadata(self, model_file): with open(model_file, "rb") as f: model_buf_from_file = f.read() @@ -332,6 +382,34 @@ class MetadataDisplayerTest(MetadataTest): populator.populate() return model_file + def test_load_model_buffer_metadataBufferWithWrongIdentifier_throwsException( + self): + model_buf = self._create_model_buffer_with_wrong_identifier() + metadata_buf = self._create_metadata_buffer_with_wrong_identifier() + model_buf = self._populate_metadata_with_identifier( + model_buf, metadata_buf, + _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER) + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(model_buf) + self.assertEqual( + "The metadata buffer does not have the expected identifier, and may not" + " be a valid TFLite Metadata.", str(error.exception)) + + def test_load_model_buffer_modelBufferWithWrongIdentifier_throwsException( + self): + model_buf = self._create_model_buffer_with_wrong_identifier() + metadata_file = self._create_metadata_file() + wrong_identifier = b"widn" + with open(metadata_file, "rb") as f: + metadata_buf = bytearray(f.read()) + model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf, + wrong_identifier) + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(model_buf) + self.assertEqual( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.", str(error.exception)) + def test_load_model_file_invalidModelFile_throwsException(self): with self.assertRaises(IOError) as error: _metadata.MetadataDisplayer.with_model_file(self._invalid_file) diff --git a/third_party/flatbuffers/workspace.bzl b/third_party/flatbuffers/workspace.bzl index dffc100bc22..d1d19a46134 100644 --- a/third_party/flatbuffers/workspace.bzl +++ b/third_party/flatbuffers/workspace.bzl @@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive") def repo(): third_party_http_archive( name = "flatbuffers", - strip_prefix = "flatbuffers-a4b2884e4ed6116335d534af8f58a84678b74a17", - sha256 = "6ff041dcaf873acbf0a93886e6b4f7704b68af1457e8b675cae88fbefe2de330", + strip_prefix = "flatbuffers-1.12.0", + sha256 = "62f2223fb9181d1d6338451375628975775f7522185266cd5296571ac152bc45", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/https://github.com/google/flatbuffers/archive/a4b2884e4ed6116335d534af8f58a84678b74a17.zip", - "https://github.com/google/flatbuffers/archive/a4b2884e4ed6116335d534af8f58a84678b74a17.zip", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.12.0.tar.gz", + "https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz", ], build_file = "//third_party/flatbuffers:BUILD.bazel", system_build_file = "//third_party/flatbuffers:BUILD.system",