Stamp the minimum metadata parser version in MetadataPopulator.

PiperOrigin-RevId: 313264741
Change-Id: I823cff6f816aa8667ac351ca0fbb0f72178617b3
This commit is contained in:
Lu Wang 2020-05-26 14:10:23 -07:00 committed by TensorFlower Gardener
parent 68adba436c
commit 8182ab3bfc
11 changed files with 308 additions and 5 deletions

View File

@ -62,6 +62,7 @@ py_library(
deps = [
":metadata_schema_py",
":schema_py",
"//tensorflow/lite/experimental/support/metadata/cc/python:_pywrap_metadata_version",
"//tensorflow/lite/experimental/support/metadata/flatbuffers_lib:_pywrap_flatbuffers",
"//tensorflow/python:platform",
"@flatbuffers//:runtime_py",

View File

@ -0,0 +1,16 @@
package(
default_visibility = ["//tensorflow/lite/experimental/support:users"],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "metadata_version",
srcs = ["metadata_version.cc"],
hdrs = ["metadata_version.h"],
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
"//tensorflow/lite/tools:logging",
"@flatbuffers",
],
)

View File

@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
#include "tensorflow/lite/tools/logging.h"
namespace tflite {
namespace metadata {
TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
size_t buffer_size,
std::string* min_version) {
flatbuffers::Verifier verifier =
flatbuffers::Verifier(buffer_data, buffer_size);
if (!tflite::VerifyModelMetadataBuffer(verifier)) {
TFLITE_LOG(ERROR) << "The model metadata is not a valid FlatBuffer buffer.";
return kTfLiteError;
}
// Returns the version as the initial default one, "1.0.0", because it is the
// first version ever for metadata_schema.fbs.
//
// Later, when new fields are added to the schema, we'll update the logic of
// getting the minimum metadata parser version. To be more specific, we'll
// have a table that records the new fields and the versions of the schema
// they are added to. And the minimum metadata parser version will be the
// largest version number of all fields that has been added to a metadata
// flatbuffer.
// TODO(b/156539454): replace the hardcoded version with template + genrule.
static constexpr char kDefaultVersion[] = "1.0.0";
*min_version = kDefaultVersion;
return kTfLiteOk;
}
} // namespace metadata
} // namespace tflite

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
#include <string>
#include "tensorflow/lite/c/common.h"
namespace tflite {
namespace metadata {
// Gets the minimum metadata parser version that can fully understand all fields
// in a given metadata flatbuffer. TFLite Metadata follows Semantic Versioning
// 2.0. Each release version has the form MAJOR.MINOR.PATCH.
TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
size_t buffer_size,
std::string* min_version);
} // namespace metadata
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_

View File

@ -0,0 +1,22 @@
load("//tensorflow:tensorflow.bzl", "pybind_extension")
package(
default_visibility = [
"//tensorflow/lite/experimental/support/metadata:__pkg__",
],
licenses = ["notice"], # Apache 2.0
)
pybind_extension(
name = "_pywrap_metadata_version",
srcs = [
"metadata_version.cc",
],
features = ["-use_header_modules"],
module_name = "_pywrap_metadata_version",
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/experimental/support/metadata/cc:metadata_version",
"@pybind11",
],
)

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
#include "pybind11/pybind11.h"
#include "tensorflow/lite/c/common.h"
namespace tflite {
namespace metadata {
PYBIND11_MODULE(_pywrap_metadata_version, m) {
m.doc() = R"pbdoc(
_pywrap_metadata_version
A module that returns the minimum metadata parser version of a given
metadata flatbuffer.
)pbdoc";
// Using pybind11 type conversions to convert between Python and native
// C++ types. There are other options to provide access to native Python types
// in C++ and vice versa. See the pybind 11 instrcution [1] for more details.
// Type converstions is recommended by pybind11, though the main downside
// is that a copy of the data must be made on every Python to C++ transition:
// this is needed since the C++ and Python versions of the same type generally
// wont have the same memory layout.
//
// [1]: https://pybind11.readthedocs.io/en/stable/advanced/cast/index.html
m.def("GetMinimumMetadataParserVersion",
[](const std::string& buffer_data) -> std::string {
std::string min_version;
if (GetMinimumMetadataParserVersion(
reinterpret_cast<const uint8_t*>(buffer_data.c_str()),
buffer_data.length(), &min_version) != kTfLiteOk) {
pybind11::value_error(
"Error occurred when getting the minimum metadata parser "
"version of the metadata flatbuffer.");
}
return min_version;
});
}
} // namespace metadata
} // namespace tflite

View File

@ -0,0 +1,15 @@
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
cc_test(
name = "metadata_version_test",
srcs = ["metadata_version_test.cc"],
deps = [
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
"//tensorflow/lite/experimental/support/metadata/cc:metadata_version",
"@com_google_googletest//:gtest_main",
"@flatbuffers",
],
)

View File

@ -0,0 +1,65 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace metadata {
namespace {
using ::testing::MatchesRegex;
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionSucceedsWithValidMetadata) {
// Creates a dummy metadata flatbuffer for test.
flatbuffers::FlatBufferBuilder builder(1024);
auto name = builder.CreateString("Foo");
ModelMetadataBuilder metadata_builder(builder);
metadata_builder.add_name(name);
auto metadata = metadata_builder.Finish();
FinishModelMetadataBuffer(builder, metadata);
// Gets the mimimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
kTfLiteOk);
// Validates that the version is well-formed (x.y.z).
EXPECT_THAT(min_version, MatchesRegex("[0-9]*\\.[0-9]*\\.[0-9]"));
}
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionSucceedsWithInvalidIdentifier) {
// Creates a dummy metadata flatbuffer without identifier.
flatbuffers::FlatBufferBuilder builder(1024);
ModelMetadataBuilder metadata_builder(builder);
auto metadata = metadata_builder.Finish();
builder.Finish(metadata);
// Gets the mimimum metadata parser version and triggers error.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
kTfLiteError);
EXPECT_TRUE(min_version.empty());
}
} // namespace
} // namespace metadata
} // namespace tflite

View File

@ -28,6 +28,7 @@ import zipfile
from flatbuffers.python import flatbuffers
from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb
from tensorflow.lite.experimental.support.metadata.cc.python import _pywrap_metadata_version
from tensorflow.lite.experimental.support.metadata.flatbuffers_lib import _pywrap_flatbuffers
from tensorflow.python.platform import resource_loader
@ -55,7 +56,7 @@ class MetadataPopulator(object):
classifer model using Flatbuffers API. Attach the label file onto the ouput
tensor (the tensor of probabilities) in the metadata.
Then, pack the metadata and lable file into the model as follows.
Then, pack the metadata and label file into the model as follows.
```python
# Populating a metadata file (or a metadta buffer) and associated files to
@ -78,6 +79,9 @@ class MetadataPopulator(object):
with open("updated_model.tflite", "wb") as f:
f.write(updated_model_buf)
```
Note that existing metadata buffer (if applied) will be overridden by the new
metadata buffer.
"""
# As Zip API is used to concatenate associated files after tflite model file,
# the populating operation is developed based on a model file. For in-memory
@ -218,12 +222,27 @@ class MetadataPopulator(object):
Raises:
ValueError: The metadata to be populated is empty.
ValueError: The metadata does not have the expected flatbuffer identifer.
ValueError: Error occurs when getting the minimum metadata parser version.
"""
if not metadata_buf:
raise ValueError("The metadata to be populated is empty.")
_assert_metadata_buffer_identifier(metadata_buf)
self._metadata_buf = metadata_buf
# Gets the minimum metadata parser version of the metadata_buf.
min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion(
bytes(metadata_buf))
# Inserts in the minimum metadata parser version into the metadata_buf.
metadata = _metadata_fb.ModelMetadataT.InitFromObj(
_metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
metadata.minParserVersion = min_version
b = flatbuffers.Builder(0)
b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER)
metadata_buf_with_version = b.Output()
self._metadata_buf = metadata_buf_with_version
def load_metadata_file(self, metadata_file):
"""Loads the metadata file to be populated.
@ -325,6 +344,9 @@ class MetadataPopulator(object):
Inserts metadata_buf into the metadata field of schema.Model. If the
MetadataPopulator object is created using the method,
with_model_file(model_file), the model file will be updated.
Existing metadata buffer (if applied) will be overridden by the new metadata
buffer.
"""
with open(self._model_file, "rb") as f:

View File

@ -43,6 +43,8 @@ class MetadataTest(test_util.TensorFlowTestCase):
f.write(self._empty_model_buf)
self._model_file = self._create_model_file_with_metadata_and_buf_fields()
self._metadata_file = self._create_metadata_file()
self._metadata_file_with_version = self._create_metadata_file_with_version(
self._metadata_file, "1.0.0")
self._file1 = self.create_tempfile("file1").full_path
self._file2 = self.create_tempfile("file2").full_path
self._file3 = self.create_tempfile("file3").full_path
@ -135,6 +137,25 @@ class MetadataTest(test_util.TensorFlowTestCase):
b.Finish(model.Pack(b), identifier)
return b.Output()
def _create_metadata_file_with_version(self, metadata_file, min_version):
# Creates a new metadata file with the specified min_version for testing
# purposes.
with open(metadata_file, "rb") as f:
metadata_buf = bytearray(f.read())
metadata = _metadata_fb.ModelMetadataT.InitFromObj(
_metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
metadata.minParserVersion = min_version
b = flatbuffers.Builder(0)
b.Finish(
metadata.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_file_with_version = self.create_tempfile().full_path
with open(metadata_file_with_version, "wb") as f:
f.write(b.Output())
return metadata_file_with_version
class MetadataPopulatorTest(MetadataTest):
@ -245,7 +266,7 @@ class MetadataPopulatorTest(MetadataTest):
buffer_data = model.Buffers(buffer_index)
metadata_buf_np = buffer_data.DataAsNumpy()
metadata_buf = metadata_buf_np.tobytes()
with open(self._metadata_file, "rb") as f:
with open(self._metadata_file_with_version, "rb") as f:
expected_metadata_buf = bytearray(f.read())
self.assertEqual(metadata_buf, expected_metadata_buf)
@ -293,7 +314,7 @@ class MetadataPopulatorTest(MetadataTest):
buffer_data = model.Buffers(buffer_index)
metadata_buf_np = buffer_data.DataAsNumpy()
metadata_buf = metadata_buf_np.tobytes()
with open(self._metadata_file, "rb") as f:
with open(self._metadata_file_with_version, "rb") as f:
expected_metadata_buf = bytearray(f.read())
self.assertEqual(metadata_buf, expected_metadata_buf)

View File

@ -17,5 +17,6 @@
{
"name": "file1"
}
]
],
"min_parser_version": "1.0.0"
}