Check Flatbuffer identifier in both MetadataPopulator and MetadataExtractor
PiperOrigin-RevId: 304633026 Change-Id: Ib8edd9897c3ae7162aac1bba4aa676e99e9c0e07
This commit is contained in:
parent
569f3f82f6
commit
ac539cf044
@ -97,8 +97,9 @@ class MetadataPopulator(object):
|
||||
|
||||
Raises:
|
||||
IOError: File not found.
|
||||
ValueError: the model does not have the expected flatbuffer identifer.
|
||||
"""
|
||||
_assert_exist(model_file)
|
||||
_assert_model_file_identifier(model_file)
|
||||
self._model_file = model_file
|
||||
self._metadata_buf = None
|
||||
self._associated_files = set()
|
||||
@ -115,6 +116,7 @@ class MetadataPopulator(object):
|
||||
|
||||
Raises:
|
||||
IOError: File not found.
|
||||
ValueError: the model does not have the expected flatbuffer identifer.
|
||||
"""
|
||||
return cls(model_file)
|
||||
|
||||
@ -129,6 +131,9 @@ class MetadataPopulator(object):
|
||||
|
||||
Returns:
|
||||
A MetadataPopulator(_MetadataPopulatorWithBuffer) object.
|
||||
|
||||
Raises:
|
||||
ValueError: the model does not have the expected flatbuffer identifer.
|
||||
"""
|
||||
return _MetadataPopulatorWithBuffer(model_buf)
|
||||
|
||||
@ -211,12 +216,13 @@ class MetadataPopulator(object):
|
||||
metadata_buf: metadata buffer (in bytearray) to be populated.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
The metadata to be populated is empty.
|
||||
ValueError: The metadata to be populated is empty.
|
||||
ValueError: The metadata does not have the expected flatbuffer identifer.
|
||||
"""
|
||||
if not metadata_buf:
|
||||
raise ValueError("The metadata to be populated is empty.")
|
||||
|
||||
_assert_metadata_buffer_identifier(metadata_buf)
|
||||
self._metadata_buf = metadata_buf
|
||||
|
||||
def load_metadata_file(self, metadata_file):
|
||||
@ -226,8 +232,8 @@ class MetadataPopulator(object):
|
||||
metadata_file: path to the metadata file to be populated.
|
||||
|
||||
Raises:
|
||||
IOError:
|
||||
File not found.
|
||||
IOError: File not found.
|
||||
ValueError: The metadata does not have the expected flatbuffer identifer.
|
||||
"""
|
||||
_assert_exist(metadata_file)
|
||||
with open(metadata_file, "rb") as f:
|
||||
@ -391,6 +397,7 @@ class _MetadataPopulatorWithBuffer(MetadataPopulator):
|
||||
|
||||
Raises:
|
||||
ValueError: model_buf is empty.
|
||||
ValueError: model_buf does not have the expected flatbuffer identifer.
|
||||
"""
|
||||
if not model_buf:
|
||||
raise ValueError("model_buf cannot be empty.")
|
||||
@ -423,6 +430,8 @@ class MetadataDisplayer(object):
|
||||
metadata_file: valid path to the metadata file.
|
||||
associated_file_list: list of associate files in the model file.
|
||||
"""
|
||||
_assert_model_file_identifier(model_file)
|
||||
_assert_metadata_file_identifier(metadata_file)
|
||||
self._model_file = model_file
|
||||
self._metadata_file = metadata_file
|
||||
self._associated_file_list = associated_file_list
|
||||
@ -553,3 +562,32 @@ def _assert_exist(filename):
|
||||
"""Checks if a file exists."""
|
||||
if not os.path.exists(filename):
|
||||
raise IOError("File, '{0}', does not exist.".format(filename))
|
||||
|
||||
|
||||
def _assert_model_file_identifier(model_file):
|
||||
"""Checks if a model file has the expected TFLite schema identifier."""
|
||||
_assert_exist(model_file)
|
||||
with open(model_file, "rb") as f:
|
||||
model_buf = f.read()
|
||||
|
||||
if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0):
|
||||
raise ValueError(
|
||||
"The model provided does not have the expected identifier, and "
|
||||
"may not be a valid TFLite model.")
|
||||
|
||||
|
||||
def _assert_metadata_file_identifier(metadata_file):
|
||||
"""Checks if a metadata file has the expected Metadata schema identifier."""
|
||||
_assert_exist(metadata_file)
|
||||
with open(metadata_file, "rb") as f:
|
||||
metadata_buf = f.read()
|
||||
_assert_metadata_buffer_identifier(metadata_buf)
|
||||
|
||||
|
||||
def _assert_metadata_buffer_identifier(metadata_buf):
|
||||
"""Checks if a metadata buffer has the expected Metadata schema identifier."""
|
||||
if not _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier(
|
||||
metadata_buf, 0):
|
||||
raise ValueError(
|
||||
"The metadata buffer does not have the expected identifier, and may not"
|
||||
" be a valid TFLite Metadata.")
|
||||
|
@ -102,6 +102,39 @@ class MetadataTest(test_util.TensorFlowTestCase):
|
||||
f.write(b.Output())
|
||||
return metadata_file
|
||||
|
||||
def _create_model_buffer_with_wrong_identifier(self):
|
||||
wrong_identifier = b"widn"
|
||||
model = _schema_fb.ModelT()
|
||||
model_builder = flatbuffers.Builder(0)
|
||||
model_builder.Finish(model.Pack(model_builder), wrong_identifier)
|
||||
return model_builder.Output()
|
||||
|
||||
def _create_metadata_buffer_with_wrong_identifier(self):
|
||||
# Creates a metadata with wrong identifier
|
||||
wrong_identifier = b"widn"
|
||||
metadata = _metadata_fb.ModelMetadataT()
|
||||
metadata_builder = flatbuffers.Builder(0)
|
||||
metadata_builder.Finish(metadata.Pack(metadata_builder), wrong_identifier)
|
||||
return metadata_builder.Output()
|
||||
|
||||
def _populate_metadata_with_identifier(self, model_buf, metadata_buf,
|
||||
identifier):
|
||||
# For testing purposes only. MetadataPopulator cannot populate metadata with
|
||||
# wrong identifiers.
|
||||
model = _schema_fb.ModelT.InitFromObj(
|
||||
_schema_fb.Model.GetRootAsModel(model_buf, 0))
|
||||
buffer_field = _schema_fb.BufferT()
|
||||
buffer_field.data = metadata_buf
|
||||
model.buffers = [buffer_field]
|
||||
# Creates a new metadata field.
|
||||
metadata_field = _schema_fb.MetadataT()
|
||||
metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME
|
||||
metadata_field.buffer = len(model.buffers) - 1
|
||||
model.metadata = [metadata_field]
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(model.Pack(b), identifier)
|
||||
return b.Output()
|
||||
|
||||
|
||||
class MetadataPopulatorTest(MetadataTest):
|
||||
|
||||
@ -126,6 +159,14 @@ class MetadataPopulatorTest(MetadataTest):
|
||||
_metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf)
|
||||
self.assertEqual("model_buf cannot be empty.", str(error.exception))
|
||||
|
||||
def testToModelBufferWithWrongIdentifier(self):
|
||||
model_buf = self._create_model_buffer_with_wrong_identifier()
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataPopulator.with_model_buffer(model_buf)
|
||||
self.assertEqual(
|
||||
"The model provided does not have the expected identifier, and "
|
||||
"may not be a valid TFLite model.", str(error.exception))
|
||||
|
||||
def testSinglePopulateAssociatedFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
@ -228,6 +269,15 @@ class MetadataPopulatorTest(MetadataTest):
|
||||
"not been loaded into the populator.").format(
|
||||
os.path.basename(self._file2)), str(error.exception))
|
||||
|
||||
def testPopulateMetadataBufferWithWrongIdentifier(self):
|
||||
metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
|
||||
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.load_metadata_buffer(metadata_buf)
|
||||
self.assertEqual(
|
||||
"The metadata buffer does not have the expected identifier, and may not"
|
||||
" be a valid TFLite Metadata.", str(error.exception))
|
||||
|
||||
def _assert_golden_metadata(self, model_file):
|
||||
with open(model_file, "rb") as f:
|
||||
model_buf_from_file = f.read()
|
||||
@ -332,6 +382,34 @@ class MetadataDisplayerTest(MetadataTest):
|
||||
populator.populate()
|
||||
return model_file
|
||||
|
||||
def test_load_model_buffer_metadataBufferWithWrongIdentifier_throwsException(
|
||||
self):
|
||||
model_buf = self._create_model_buffer_with_wrong_identifier()
|
||||
metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
|
||||
model_buf = self._populate_metadata_with_identifier(
|
||||
model_buf, metadata_buf,
|
||||
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_buffer(model_buf)
|
||||
self.assertEqual(
|
||||
"The metadata buffer does not have the expected identifier, and may not"
|
||||
" be a valid TFLite Metadata.", str(error.exception))
|
||||
|
||||
def test_load_model_buffer_modelBufferWithWrongIdentifier_throwsException(
|
||||
self):
|
||||
model_buf = self._create_model_buffer_with_wrong_identifier()
|
||||
metadata_file = self._create_metadata_file()
|
||||
wrong_identifier = b"widn"
|
||||
with open(metadata_file, "rb") as f:
|
||||
metadata_buf = bytearray(f.read())
|
||||
model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf,
|
||||
wrong_identifier)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_buffer(model_buf)
|
||||
self.assertEqual(
|
||||
"The model provided does not have the expected identifier, and "
|
||||
"may not be a valid TFLite model.", str(error.exception))
|
||||
|
||||
def test_load_model_file_invalidModelFile_throwsException(self):
|
||||
with self.assertRaises(IOError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_file(self._invalid_file)
|
||||
|
8
third_party/flatbuffers/workspace.bzl
vendored
8
third_party/flatbuffers/workspace.bzl
vendored
@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
|
||||
def repo():
|
||||
third_party_http_archive(
|
||||
name = "flatbuffers",
|
||||
strip_prefix = "flatbuffers-a4b2884e4ed6116335d534af8f58a84678b74a17",
|
||||
sha256 = "6ff041dcaf873acbf0a93886e6b4f7704b68af1457e8b675cae88fbefe2de330",
|
||||
strip_prefix = "flatbuffers-1.12.0",
|
||||
sha256 = "62f2223fb9181d1d6338451375628975775f7522185266cd5296571ac152bc45",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/https://github.com/google/flatbuffers/archive/a4b2884e4ed6116335d534af8f58a84678b74a17.zip",
|
||||
"https://github.com/google/flatbuffers/archive/a4b2884e4ed6116335d534af8f58a84678b74a17.zip",
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.12.0.tar.gz",
|
||||
"https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz",
|
||||
],
|
||||
build_file = "//third_party/flatbuffers:BUILD.bazel",
|
||||
system_build_file = "//third_party/flatbuffers:BUILD.system",
|
||||
|
Loading…
Reference in New Issue
Block a user