Check Flatbuffer identifier in both MetadataPopulator and MetadataExtractor
PiperOrigin-RevId: 304633026 Change-Id: Ib8edd9897c3ae7162aac1bba4aa676e99e9c0e07
This commit is contained in:
parent
569f3f82f6
commit
ac539cf044
@ -97,8 +97,9 @@ class MetadataPopulator(object):
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
IOError: File not found.
|
IOError: File not found.
|
||||||
|
ValueError: the model does not have the expected flatbuffer identifer.
|
||||||
"""
|
"""
|
||||||
_assert_exist(model_file)
|
_assert_model_file_identifier(model_file)
|
||||||
self._model_file = model_file
|
self._model_file = model_file
|
||||||
self._metadata_buf = None
|
self._metadata_buf = None
|
||||||
self._associated_files = set()
|
self._associated_files = set()
|
||||||
@ -115,6 +116,7 @@ class MetadataPopulator(object):
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
IOError: File not found.
|
IOError: File not found.
|
||||||
|
ValueError: the model does not have the expected flatbuffer identifer.
|
||||||
"""
|
"""
|
||||||
return cls(model_file)
|
return cls(model_file)
|
||||||
|
|
||||||
@ -129,6 +131,9 @@ class MetadataPopulator(object):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A MetadataPopulator(_MetadataPopulatorWithBuffer) object.
|
A MetadataPopulator(_MetadataPopulatorWithBuffer) object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: the model does not have the expected flatbuffer identifer.
|
||||||
"""
|
"""
|
||||||
return _MetadataPopulatorWithBuffer(model_buf)
|
return _MetadataPopulatorWithBuffer(model_buf)
|
||||||
|
|
||||||
@ -211,12 +216,13 @@ class MetadataPopulator(object):
|
|||||||
metadata_buf: metadata buffer (in bytearray) to be populated.
|
metadata_buf: metadata buffer (in bytearray) to be populated.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError:
|
ValueError: The metadata to be populated is empty.
|
||||||
The metadata to be populated is empty.
|
ValueError: The metadata does not have the expected flatbuffer identifer.
|
||||||
"""
|
"""
|
||||||
if not metadata_buf:
|
if not metadata_buf:
|
||||||
raise ValueError("The metadata to be populated is empty.")
|
raise ValueError("The metadata to be populated is empty.")
|
||||||
|
|
||||||
|
_assert_metadata_buffer_identifier(metadata_buf)
|
||||||
self._metadata_buf = metadata_buf
|
self._metadata_buf = metadata_buf
|
||||||
|
|
||||||
def load_metadata_file(self, metadata_file):
|
def load_metadata_file(self, metadata_file):
|
||||||
@ -226,8 +232,8 @@ class MetadataPopulator(object):
|
|||||||
metadata_file: path to the metadata file to be populated.
|
metadata_file: path to the metadata file to be populated.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
IOError:
|
IOError: File not found.
|
||||||
File not found.
|
ValueError: The metadata does not have the expected flatbuffer identifer.
|
||||||
"""
|
"""
|
||||||
_assert_exist(metadata_file)
|
_assert_exist(metadata_file)
|
||||||
with open(metadata_file, "rb") as f:
|
with open(metadata_file, "rb") as f:
|
||||||
@ -391,6 +397,7 @@ class _MetadataPopulatorWithBuffer(MetadataPopulator):
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: model_buf is empty.
|
ValueError: model_buf is empty.
|
||||||
|
ValueError: model_buf does not have the expected flatbuffer identifer.
|
||||||
"""
|
"""
|
||||||
if not model_buf:
|
if not model_buf:
|
||||||
raise ValueError("model_buf cannot be empty.")
|
raise ValueError("model_buf cannot be empty.")
|
||||||
@ -423,6 +430,8 @@ class MetadataDisplayer(object):
|
|||||||
metadata_file: valid path to the metadata file.
|
metadata_file: valid path to the metadata file.
|
||||||
associated_file_list: list of associate files in the model file.
|
associated_file_list: list of associate files in the model file.
|
||||||
"""
|
"""
|
||||||
|
_assert_model_file_identifier(model_file)
|
||||||
|
_assert_metadata_file_identifier(metadata_file)
|
||||||
self._model_file = model_file
|
self._model_file = model_file
|
||||||
self._metadata_file = metadata_file
|
self._metadata_file = metadata_file
|
||||||
self._associated_file_list = associated_file_list
|
self._associated_file_list = associated_file_list
|
||||||
@ -553,3 +562,32 @@ def _assert_exist(filename):
|
|||||||
"""Checks if a file exists."""
|
"""Checks if a file exists."""
|
||||||
if not os.path.exists(filename):
|
if not os.path.exists(filename):
|
||||||
raise IOError("File, '{0}', does not exist.".format(filename))
|
raise IOError("File, '{0}', does not exist.".format(filename))
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_model_file_identifier(model_file):
|
||||||
|
"""Checks if a model file has the expected TFLite schema identifier."""
|
||||||
|
_assert_exist(model_file)
|
||||||
|
with open(model_file, "rb") as f:
|
||||||
|
model_buf = f.read()
|
||||||
|
|
||||||
|
if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0):
|
||||||
|
raise ValueError(
|
||||||
|
"The model provided does not have the expected identifier, and "
|
||||||
|
"may not be a valid TFLite model.")
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_metadata_file_identifier(metadata_file):
|
||||||
|
"""Checks if a metadata file has the expected Metadata schema identifier."""
|
||||||
|
_assert_exist(metadata_file)
|
||||||
|
with open(metadata_file, "rb") as f:
|
||||||
|
metadata_buf = f.read()
|
||||||
|
_assert_metadata_buffer_identifier(metadata_buf)
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_metadata_buffer_identifier(metadata_buf):
|
||||||
|
"""Checks if a metadata buffer has the expected Metadata schema identifier."""
|
||||||
|
if not _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier(
|
||||||
|
metadata_buf, 0):
|
||||||
|
raise ValueError(
|
||||||
|
"The metadata buffer does not have the expected identifier, and may not"
|
||||||
|
" be a valid TFLite Metadata.")
|
||||||
|
@ -102,6 +102,39 @@ class MetadataTest(test_util.TensorFlowTestCase):
|
|||||||
f.write(b.Output())
|
f.write(b.Output())
|
||||||
return metadata_file
|
return metadata_file
|
||||||
|
|
||||||
|
def _create_model_buffer_with_wrong_identifier(self):
|
||||||
|
wrong_identifier = b"widn"
|
||||||
|
model = _schema_fb.ModelT()
|
||||||
|
model_builder = flatbuffers.Builder(0)
|
||||||
|
model_builder.Finish(model.Pack(model_builder), wrong_identifier)
|
||||||
|
return model_builder.Output()
|
||||||
|
|
||||||
|
def _create_metadata_buffer_with_wrong_identifier(self):
|
||||||
|
# Creates a metadata with wrong identifier
|
||||||
|
wrong_identifier = b"widn"
|
||||||
|
metadata = _metadata_fb.ModelMetadataT()
|
||||||
|
metadata_builder = flatbuffers.Builder(0)
|
||||||
|
metadata_builder.Finish(metadata.Pack(metadata_builder), wrong_identifier)
|
||||||
|
return metadata_builder.Output()
|
||||||
|
|
||||||
|
def _populate_metadata_with_identifier(self, model_buf, metadata_buf,
|
||||||
|
identifier):
|
||||||
|
# For testing purposes only. MetadataPopulator cannot populate metadata with
|
||||||
|
# wrong identifiers.
|
||||||
|
model = _schema_fb.ModelT.InitFromObj(
|
||||||
|
_schema_fb.Model.GetRootAsModel(model_buf, 0))
|
||||||
|
buffer_field = _schema_fb.BufferT()
|
||||||
|
buffer_field.data = metadata_buf
|
||||||
|
model.buffers = [buffer_field]
|
||||||
|
# Creates a new metadata field.
|
||||||
|
metadata_field = _schema_fb.MetadataT()
|
||||||
|
metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME
|
||||||
|
metadata_field.buffer = len(model.buffers) - 1
|
||||||
|
model.metadata = [metadata_field]
|
||||||
|
b = flatbuffers.Builder(0)
|
||||||
|
b.Finish(model.Pack(b), identifier)
|
||||||
|
return b.Output()
|
||||||
|
|
||||||
|
|
||||||
class MetadataPopulatorTest(MetadataTest):
|
class MetadataPopulatorTest(MetadataTest):
|
||||||
|
|
||||||
@ -126,6 +159,14 @@ class MetadataPopulatorTest(MetadataTest):
|
|||||||
_metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf)
|
_metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf)
|
||||||
self.assertEqual("model_buf cannot be empty.", str(error.exception))
|
self.assertEqual("model_buf cannot be empty.", str(error.exception))
|
||||||
|
|
||||||
|
def testToModelBufferWithWrongIdentifier(self):
|
||||||
|
model_buf = self._create_model_buffer_with_wrong_identifier()
|
||||||
|
with self.assertRaises(ValueError) as error:
|
||||||
|
_metadata.MetadataPopulator.with_model_buffer(model_buf)
|
||||||
|
self.assertEqual(
|
||||||
|
"The model provided does not have the expected identifier, and "
|
||||||
|
"may not be a valid TFLite model.", str(error.exception))
|
||||||
|
|
||||||
def testSinglePopulateAssociatedFile(self):
|
def testSinglePopulateAssociatedFile(self):
|
||||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||||
self._empty_model_buf)
|
self._empty_model_buf)
|
||||||
@ -228,6 +269,15 @@ class MetadataPopulatorTest(MetadataTest):
|
|||||||
"not been loaded into the populator.").format(
|
"not been loaded into the populator.").format(
|
||||||
os.path.basename(self._file2)), str(error.exception))
|
os.path.basename(self._file2)), str(error.exception))
|
||||||
|
|
||||||
|
def testPopulateMetadataBufferWithWrongIdentifier(self):
|
||||||
|
metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
|
||||||
|
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||||
|
with self.assertRaises(ValueError) as error:
|
||||||
|
populator.load_metadata_buffer(metadata_buf)
|
||||||
|
self.assertEqual(
|
||||||
|
"The metadata buffer does not have the expected identifier, and may not"
|
||||||
|
" be a valid TFLite Metadata.", str(error.exception))
|
||||||
|
|
||||||
def _assert_golden_metadata(self, model_file):
|
def _assert_golden_metadata(self, model_file):
|
||||||
with open(model_file, "rb") as f:
|
with open(model_file, "rb") as f:
|
||||||
model_buf_from_file = f.read()
|
model_buf_from_file = f.read()
|
||||||
@ -332,6 +382,34 @@ class MetadataDisplayerTest(MetadataTest):
|
|||||||
populator.populate()
|
populator.populate()
|
||||||
return model_file
|
return model_file
|
||||||
|
|
||||||
|
def test_load_model_buffer_metadataBufferWithWrongIdentifier_throwsException(
|
||||||
|
self):
|
||||||
|
model_buf = self._create_model_buffer_with_wrong_identifier()
|
||||||
|
metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
|
||||||
|
model_buf = self._populate_metadata_with_identifier(
|
||||||
|
model_buf, metadata_buf,
|
||||||
|
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
|
||||||
|
with self.assertRaises(ValueError) as error:
|
||||||
|
_metadata.MetadataDisplayer.with_model_buffer(model_buf)
|
||||||
|
self.assertEqual(
|
||||||
|
"The metadata buffer does not have the expected identifier, and may not"
|
||||||
|
" be a valid TFLite Metadata.", str(error.exception))
|
||||||
|
|
||||||
|
def test_load_model_buffer_modelBufferWithWrongIdentifier_throwsException(
|
||||||
|
self):
|
||||||
|
model_buf = self._create_model_buffer_with_wrong_identifier()
|
||||||
|
metadata_file = self._create_metadata_file()
|
||||||
|
wrong_identifier = b"widn"
|
||||||
|
with open(metadata_file, "rb") as f:
|
||||||
|
metadata_buf = bytearray(f.read())
|
||||||
|
model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf,
|
||||||
|
wrong_identifier)
|
||||||
|
with self.assertRaises(ValueError) as error:
|
||||||
|
_metadata.MetadataDisplayer.with_model_buffer(model_buf)
|
||||||
|
self.assertEqual(
|
||||||
|
"The model provided does not have the expected identifier, and "
|
||||||
|
"may not be a valid TFLite model.", str(error.exception))
|
||||||
|
|
||||||
def test_load_model_file_invalidModelFile_throwsException(self):
|
def test_load_model_file_invalidModelFile_throwsException(self):
|
||||||
with self.assertRaises(IOError) as error:
|
with self.assertRaises(IOError) as error:
|
||||||
_metadata.MetadataDisplayer.with_model_file(self._invalid_file)
|
_metadata.MetadataDisplayer.with_model_file(self._invalid_file)
|
||||||
|
8
third_party/flatbuffers/workspace.bzl
vendored
8
third_party/flatbuffers/workspace.bzl
vendored
@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
|
|||||||
def repo():
|
def repo():
|
||||||
third_party_http_archive(
|
third_party_http_archive(
|
||||||
name = "flatbuffers",
|
name = "flatbuffers",
|
||||||
strip_prefix = "flatbuffers-a4b2884e4ed6116335d534af8f58a84678b74a17",
|
strip_prefix = "flatbuffers-1.12.0",
|
||||||
sha256 = "6ff041dcaf873acbf0a93886e6b4f7704b68af1457e8b675cae88fbefe2de330",
|
sha256 = "62f2223fb9181d1d6338451375628975775f7522185266cd5296571ac152bc45",
|
||||||
urls = [
|
urls = [
|
||||||
"https://storage.googleapis.com/mirror.tensorflow.org/https://github.com/google/flatbuffers/archive/a4b2884e4ed6116335d534af8f58a84678b74a17.zip",
|
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.12.0.tar.gz",
|
||||||
"https://github.com/google/flatbuffers/archive/a4b2884e4ed6116335d534af8f58a84678b74a17.zip",
|
"https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz",
|
||||||
],
|
],
|
||||||
build_file = "//third_party/flatbuffers:BUILD.bazel",
|
build_file = "//third_party/flatbuffers:BUILD.bazel",
|
||||||
system_build_file = "//third_party/flatbuffers:BUILD.system",
|
system_build_file = "//third_party/flatbuffers:BUILD.system",
|
||||||
|
Loading…
Reference in New Issue
Block a user