Check Flatbuffer identifier in both MetadataPopulator and MetadataExtractor

PiperOrigin-RevId: 304633026
Change-Id: Ib8edd9897c3ae7162aac1bba4aa676e99e9c0e07
This commit is contained in:
Lu Wang 2020-04-03 09:37:49 -07:00 committed by TensorFlower Gardener
parent 569f3f82f6
commit ac539cf044
3 changed files with 125 additions and 9 deletions

View File

@ -97,8 +97,9 @@ class MetadataPopulator(object):
Raises:
IOError: File not found.
ValueError: the model does not have the expected flatbuffer identifer.
"""
_assert_exist(model_file)
_assert_model_file_identifier(model_file)
self._model_file = model_file
self._metadata_buf = None
self._associated_files = set()
@ -115,6 +116,7 @@ class MetadataPopulator(object):
Raises:
IOError: File not found.
ValueError: the model does not have the expected flatbuffer identifer.
"""
return cls(model_file)
@ -129,6 +131,9 @@ class MetadataPopulator(object):
Returns:
A MetadataPopulator(_MetadataPopulatorWithBuffer) object.
Raises:
ValueError: the model does not have the expected flatbuffer identifer.
"""
return _MetadataPopulatorWithBuffer(model_buf)
@ -211,12 +216,13 @@ class MetadataPopulator(object):
metadata_buf: metadata buffer (in bytearray) to be populated.
Raises:
ValueError:
The metadata to be populated is empty.
ValueError: The metadata to be populated is empty.
ValueError: The metadata does not have the expected flatbuffer identifer.
"""
if not metadata_buf:
raise ValueError("The metadata to be populated is empty.")
_assert_metadata_buffer_identifier(metadata_buf)
self._metadata_buf = metadata_buf
def load_metadata_file(self, metadata_file):
@ -226,8 +232,8 @@ class MetadataPopulator(object):
metadata_file: path to the metadata file to be populated.
Raises:
IOError:
File not found.
IOError: File not found.
ValueError: The metadata does not have the expected flatbuffer identifer.
"""
_assert_exist(metadata_file)
with open(metadata_file, "rb") as f:
@ -391,6 +397,7 @@ class _MetadataPopulatorWithBuffer(MetadataPopulator):
Raises:
ValueError: model_buf is empty.
ValueError: model_buf does not have the expected flatbuffer identifer.
"""
if not model_buf:
raise ValueError("model_buf cannot be empty.")
@ -423,6 +430,8 @@ class MetadataDisplayer(object):
metadata_file: valid path to the metadata file.
associated_file_list: list of associate files in the model file.
"""
_assert_model_file_identifier(model_file)
_assert_metadata_file_identifier(metadata_file)
self._model_file = model_file
self._metadata_file = metadata_file
self._associated_file_list = associated_file_list
@ -553,3 +562,32 @@ def _assert_exist(filename):
"""Checks if a file exists."""
if not os.path.exists(filename):
raise IOError("File, '{0}', does not exist.".format(filename))
def _assert_model_file_identifier(model_file):
"""Checks if a model file has the expected TFLite schema identifier."""
_assert_exist(model_file)
with open(model_file, "rb") as f:
model_buf = f.read()
if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0):
raise ValueError(
"The model provided does not have the expected identifier, and "
"may not be a valid TFLite model.")
def _assert_metadata_file_identifier(metadata_file):
"""Checks if a metadata file has the expected Metadata schema identifier."""
_assert_exist(metadata_file)
with open(metadata_file, "rb") as f:
metadata_buf = f.read()
_assert_metadata_buffer_identifier(metadata_buf)
def _assert_metadata_buffer_identifier(metadata_buf):
"""Checks if a metadata buffer has the expected Metadata schema identifier."""
if not _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier(
metadata_buf, 0):
raise ValueError(
"The metadata buffer does not have the expected identifier, and may not"
" be a valid TFLite Metadata.")

View File

@ -102,6 +102,39 @@ class MetadataTest(test_util.TensorFlowTestCase):
f.write(b.Output())
return metadata_file
def _create_model_buffer_with_wrong_identifier(self):
wrong_identifier = b"widn"
model = _schema_fb.ModelT()
model_builder = flatbuffers.Builder(0)
model_builder.Finish(model.Pack(model_builder), wrong_identifier)
return model_builder.Output()
def _create_metadata_buffer_with_wrong_identifier(self):
# Creates a metadata with wrong identifier
wrong_identifier = b"widn"
metadata = _metadata_fb.ModelMetadataT()
metadata_builder = flatbuffers.Builder(0)
metadata_builder.Finish(metadata.Pack(metadata_builder), wrong_identifier)
return metadata_builder.Output()
def _populate_metadata_with_identifier(self, model_buf, metadata_buf,
identifier):
# For testing purposes only. MetadataPopulator cannot populate metadata with
# wrong identifiers.
model = _schema_fb.ModelT.InitFromObj(
_schema_fb.Model.GetRootAsModel(model_buf, 0))
buffer_field = _schema_fb.BufferT()
buffer_field.data = metadata_buf
model.buffers = [buffer_field]
# Creates a new metadata field.
metadata_field = _schema_fb.MetadataT()
metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME
metadata_field.buffer = len(model.buffers) - 1
model.metadata = [metadata_field]
b = flatbuffers.Builder(0)
b.Finish(model.Pack(b), identifier)
return b.Output()
class MetadataPopulatorTest(MetadataTest):
@ -126,6 +159,14 @@ class MetadataPopulatorTest(MetadataTest):
_metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf)
self.assertEqual("model_buf cannot be empty.", str(error.exception))
def testToModelBufferWithWrongIdentifier(self):
model_buf = self._create_model_buffer_with_wrong_identifier()
with self.assertRaises(ValueError) as error:
_metadata.MetadataPopulator.with_model_buffer(model_buf)
self.assertEqual(
"The model provided does not have the expected identifier, and "
"may not be a valid TFLite model.", str(error.exception))
def testSinglePopulateAssociatedFile(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
@ -228,6 +269,15 @@ class MetadataPopulatorTest(MetadataTest):
"not been loaded into the populator.").format(
os.path.basename(self._file2)), str(error.exception))
def testPopulateMetadataBufferWithWrongIdentifier(self):
metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
with self.assertRaises(ValueError) as error:
populator.load_metadata_buffer(metadata_buf)
self.assertEqual(
"The metadata buffer does not have the expected identifier, and may not"
" be a valid TFLite Metadata.", str(error.exception))
def _assert_golden_metadata(self, model_file):
with open(model_file, "rb") as f:
model_buf_from_file = f.read()
@ -332,6 +382,34 @@ class MetadataDisplayerTest(MetadataTest):
populator.populate()
return model_file
def test_load_model_buffer_metadataBufferWithWrongIdentifier_throwsException(
self):
model_buf = self._create_model_buffer_with_wrong_identifier()
metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
model_buf = self._populate_metadata_with_identifier(
model_buf, metadata_buf,
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
with self.assertRaises(ValueError) as error:
_metadata.MetadataDisplayer.with_model_buffer(model_buf)
self.assertEqual(
"The metadata buffer does not have the expected identifier, and may not"
" be a valid TFLite Metadata.", str(error.exception))
def test_load_model_buffer_modelBufferWithWrongIdentifier_throwsException(
self):
model_buf = self._create_model_buffer_with_wrong_identifier()
metadata_file = self._create_metadata_file()
wrong_identifier = b"widn"
with open(metadata_file, "rb") as f:
metadata_buf = bytearray(f.read())
model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf,
wrong_identifier)
with self.assertRaises(ValueError) as error:
_metadata.MetadataDisplayer.with_model_buffer(model_buf)
self.assertEqual(
"The model provided does not have the expected identifier, and "
"may not be a valid TFLite model.", str(error.exception))
def test_load_model_file_invalidModelFile_throwsException(self):
with self.assertRaises(IOError) as error:
_metadata.MetadataDisplayer.with_model_file(self._invalid_file)

View File

@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
def repo():
third_party_http_archive(
name = "flatbuffers",
strip_prefix = "flatbuffers-a4b2884e4ed6116335d534af8f58a84678b74a17",
sha256 = "6ff041dcaf873acbf0a93886e6b4f7704b68af1457e8b675cae88fbefe2de330",
strip_prefix = "flatbuffers-1.12.0",
sha256 = "62f2223fb9181d1d6338451375628975775f7522185266cd5296571ac152bc45",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/https://github.com/google/flatbuffers/archive/a4b2884e4ed6116335d534af8f58a84678b74a17.zip",
"https://github.com/google/flatbuffers/archive/a4b2884e4ed6116335d534af8f58a84678b74a17.zip",
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.12.0.tar.gz",
"https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz",
],
build_file = "//third_party/flatbuffers:BUILD.bazel",
system_build_file = "//third_party/flatbuffers:BUILD.system",