Stamp the minimum metadata parser version in MetadataPopulator.
PiperOrigin-RevId: 313264741 Change-Id: I823cff6f816aa8667ac351ca0fbb0f72178617b3
This commit is contained in:
parent
68adba436c
commit
8182ab3bfc
@ -62,6 +62,7 @@ py_library(
|
||||
deps = [
|
||||
":metadata_schema_py",
|
||||
":schema_py",
|
||||
"//tensorflow/lite/experimental/support/metadata/cc/python:_pywrap_metadata_version",
|
||||
"//tensorflow/lite/experimental/support/metadata/flatbuffers_lib:_pywrap_flatbuffers",
|
||||
"//tensorflow/python:platform",
|
||||
"@flatbuffers//:runtime_py",
|
||||
|
16
tensorflow/lite/experimental/support/metadata/cc/BUILD
Normal file
16
tensorflow/lite/experimental/support/metadata/cc/BUILD
Normal file
@ -0,0 +1,16 @@
|
||||
package(
|
||||
default_visibility = ["//tensorflow/lite/experimental/support:users"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "metadata_version",
|
||||
srcs = ["metadata_version.cc"],
|
||||
hdrs = ["metadata_version.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
|
||||
"//tensorflow/lite/tools:logging",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
@ -0,0 +1,50 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
|
||||
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace metadata {
|
||||
|
||||
TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
|
||||
size_t buffer_size,
|
||||
std::string* min_version) {
|
||||
flatbuffers::Verifier verifier =
|
||||
flatbuffers::Verifier(buffer_data, buffer_size);
|
||||
if (!tflite::VerifyModelMetadataBuffer(verifier)) {
|
||||
TFLITE_LOG(ERROR) << "The model metadata is not a valid FlatBuffer buffer.";
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Returns the version as the initial default one, "1.0.0", because it is the
|
||||
// first version ever for metadata_schema.fbs.
|
||||
//
|
||||
// Later, when new fields are added to the schema, we'll update the logic of
|
||||
// getting the minimum metadata parser version. To be more specific, we'll
|
||||
// have a table that records the new fields and the versions of the schema
|
||||
// they are added to. And the minimum metadata parser version will be the
|
||||
// largest version number of all fields that has been added to a metadata
|
||||
// flatbuffer.
|
||||
// TODO(b/156539454): replace the hardcoded version with template + genrule.
|
||||
static constexpr char kDefaultVersion[] = "1.0.0";
|
||||
*min_version = kDefaultVersion;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace metadata
|
||||
} // namespace tflite
|
@ -0,0 +1,35 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace metadata {
|
||||
|
||||
// Gets the minimum metadata parser version that can fully understand all fields
|
||||
// in a given metadata flatbuffer. TFLite Metadata follows Semantic Versioning
|
||||
// 2.0. Each release version has the form MAJOR.MINOR.PATCH.
|
||||
TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
|
||||
size_t buffer_size,
|
||||
std::string* min_version);
|
||||
|
||||
} // namespace metadata
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
|
@ -0,0 +1,22 @@
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/lite/experimental/support/metadata:__pkg__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_pywrap_metadata_version",
|
||||
srcs = [
|
||||
"metadata_version.cc",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_pywrap_metadata_version",
|
||||
deps = [
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/experimental/support/metadata/cc:metadata_version",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace metadata {
|
||||
|
||||
PYBIND11_MODULE(_pywrap_metadata_version, m) {
|
||||
m.doc() = R"pbdoc(
|
||||
_pywrap_metadata_version
|
||||
A module that returns the minimum metadata parser version of a given
|
||||
metadata flatbuffer.
|
||||
)pbdoc";
|
||||
|
||||
// Using pybind11 type conversions to convert between Python and native
|
||||
// C++ types. There are other options to provide access to native Python types
|
||||
// in C++ and vice versa. See the pybind 11 instrcution [1] for more details.
|
||||
// Type converstions is recommended by pybind11, though the main downside
|
||||
// is that a copy of the data must be made on every Python to C++ transition:
|
||||
// this is needed since the C++ and Python versions of the same type generally
|
||||
// won’t have the same memory layout.
|
||||
//
|
||||
// [1]: https://pybind11.readthedocs.io/en/stable/advanced/cast/index.html
|
||||
m.def("GetMinimumMetadataParserVersion",
|
||||
[](const std::string& buffer_data) -> std::string {
|
||||
std::string min_version;
|
||||
if (GetMinimumMetadataParserVersion(
|
||||
reinterpret_cast<const uint8_t*>(buffer_data.c_str()),
|
||||
buffer_data.length(), &min_version) != kTfLiteOk) {
|
||||
pybind11::value_error(
|
||||
"Error occurred when getting the minimum metadata parser "
|
||||
"version of the metadata flatbuffer.");
|
||||
}
|
||||
return min_version;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace metadata
|
||||
} // namespace tflite
|
15
tensorflow/lite/experimental/support/metadata/cc/test/BUILD
Normal file
15
tensorflow/lite/experimental/support/metadata/cc/test/BUILD
Normal file
@ -0,0 +1,15 @@
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "metadata_version_test",
|
||||
srcs = ["metadata_version_test.cc"],
|
||||
deps = [
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
|
||||
"//tensorflow/lite/experimental/support/metadata/cc:metadata_version",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
@ -0,0 +1,65 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace metadata {
|
||||
namespace {
|
||||
|
||||
using ::testing::MatchesRegex;
|
||||
|
||||
TEST(MetadataVersionTest,
|
||||
GetMinimumMetadataParserVersionSucceedsWithValidMetadata) {
|
||||
// Creates a dummy metadata flatbuffer for test.
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto name = builder.CreateString("Foo");
|
||||
ModelMetadataBuilder metadata_builder(builder);
|
||||
metadata_builder.add_name(name);
|
||||
auto metadata = metadata_builder.Finish();
|
||||
FinishModelMetadataBuffer(builder, metadata);
|
||||
|
||||
// Gets the mimimum metadata parser version.
|
||||
std::string min_version;
|
||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||
builder.GetSize(), &min_version),
|
||||
kTfLiteOk);
|
||||
// Validates that the version is well-formed (x.y.z).
|
||||
EXPECT_THAT(min_version, MatchesRegex("[0-9]*\\.[0-9]*\\.[0-9]"));
|
||||
}
|
||||
|
||||
TEST(MetadataVersionTest,
|
||||
GetMinimumMetadataParserVersionSucceedsWithInvalidIdentifier) {
|
||||
// Creates a dummy metadata flatbuffer without identifier.
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
ModelMetadataBuilder metadata_builder(builder);
|
||||
auto metadata = metadata_builder.Finish();
|
||||
builder.Finish(metadata);
|
||||
|
||||
// Gets the mimimum metadata parser version and triggers error.
|
||||
std::string min_version;
|
||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||
builder.GetSize(), &min_version),
|
||||
kTfLiteError);
|
||||
EXPECT_TRUE(min_version.empty());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace metadata
|
||||
} // namespace tflite
|
@ -28,6 +28,7 @@ import zipfile
|
||||
from flatbuffers.python import flatbuffers
|
||||
from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb
|
||||
from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb
|
||||
from tensorflow.lite.experimental.support.metadata.cc.python import _pywrap_metadata_version
|
||||
from tensorflow.lite.experimental.support.metadata.flatbuffers_lib import _pywrap_flatbuffers
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
@ -55,7 +56,7 @@ class MetadataPopulator(object):
|
||||
classifer model using Flatbuffers API. Attach the label file onto the ouput
|
||||
tensor (the tensor of probabilities) in the metadata.
|
||||
|
||||
Then, pack the metadata and lable file into the model as follows.
|
||||
Then, pack the metadata and label file into the model as follows.
|
||||
|
||||
```python
|
||||
# Populating a metadata file (or a metadta buffer) and associated files to
|
||||
@ -78,6 +79,9 @@ class MetadataPopulator(object):
|
||||
with open("updated_model.tflite", "wb") as f:
|
||||
f.write(updated_model_buf)
|
||||
```
|
||||
|
||||
Note that existing metadata buffer (if applied) will be overridden by the new
|
||||
metadata buffer.
|
||||
"""
|
||||
# As Zip API is used to concatenate associated files after tflite model file,
|
||||
# the populating operation is developed based on a model file. For in-memory
|
||||
@ -218,12 +222,27 @@ class MetadataPopulator(object):
|
||||
Raises:
|
||||
ValueError: The metadata to be populated is empty.
|
||||
ValueError: The metadata does not have the expected flatbuffer identifer.
|
||||
ValueError: Error occurs when getting the minimum metadata parser version.
|
||||
"""
|
||||
if not metadata_buf:
|
||||
raise ValueError("The metadata to be populated is empty.")
|
||||
|
||||
_assert_metadata_buffer_identifier(metadata_buf)
|
||||
self._metadata_buf = metadata_buf
|
||||
|
||||
# Gets the minimum metadata parser version of the metadata_buf.
|
||||
min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion(
|
||||
bytes(metadata_buf))
|
||||
|
||||
# Inserts in the minimum metadata parser version into the metadata_buf.
|
||||
metadata = _metadata_fb.ModelMetadataT.InitFromObj(
|
||||
_metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
|
||||
metadata.minParserVersion = min_version
|
||||
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER)
|
||||
metadata_buf_with_version = b.Output()
|
||||
|
||||
self._metadata_buf = metadata_buf_with_version
|
||||
|
||||
def load_metadata_file(self, metadata_file):
|
||||
"""Loads the metadata file to be populated.
|
||||
@ -325,6 +344,9 @@ class MetadataPopulator(object):
|
||||
Inserts metadata_buf into the metadata field of schema.Model. If the
|
||||
MetadataPopulator object is created using the method,
|
||||
with_model_file(model_file), the model file will be updated.
|
||||
|
||||
Existing metadata buffer (if applied) will be overridden by the new metadata
|
||||
buffer.
|
||||
"""
|
||||
|
||||
with open(self._model_file, "rb") as f:
|
||||
|
@ -43,6 +43,8 @@ class MetadataTest(test_util.TensorFlowTestCase):
|
||||
f.write(self._empty_model_buf)
|
||||
self._model_file = self._create_model_file_with_metadata_and_buf_fields()
|
||||
self._metadata_file = self._create_metadata_file()
|
||||
self._metadata_file_with_version = self._create_metadata_file_with_version(
|
||||
self._metadata_file, "1.0.0")
|
||||
self._file1 = self.create_tempfile("file1").full_path
|
||||
self._file2 = self.create_tempfile("file2").full_path
|
||||
self._file3 = self.create_tempfile("file3").full_path
|
||||
@ -135,6 +137,25 @@ class MetadataTest(test_util.TensorFlowTestCase):
|
||||
b.Finish(model.Pack(b), identifier)
|
||||
return b.Output()
|
||||
|
||||
def _create_metadata_file_with_version(self, metadata_file, min_version):
|
||||
# Creates a new metadata file with the specified min_version for testing
|
||||
# purposes.
|
||||
with open(metadata_file, "rb") as f:
|
||||
metadata_buf = bytearray(f.read())
|
||||
|
||||
metadata = _metadata_fb.ModelMetadataT.InitFromObj(
|
||||
_metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
|
||||
metadata.minParserVersion = min_version
|
||||
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(
|
||||
metadata.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
|
||||
metadata_file_with_version = self.create_tempfile().full_path
|
||||
with open(metadata_file_with_version, "wb") as f:
|
||||
f.write(b.Output())
|
||||
return metadata_file_with_version
|
||||
|
||||
|
||||
class MetadataPopulatorTest(MetadataTest):
|
||||
|
||||
@ -245,7 +266,7 @@ class MetadataPopulatorTest(MetadataTest):
|
||||
buffer_data = model.Buffers(buffer_index)
|
||||
metadata_buf_np = buffer_data.DataAsNumpy()
|
||||
metadata_buf = metadata_buf_np.tobytes()
|
||||
with open(self._metadata_file, "rb") as f:
|
||||
with open(self._metadata_file_with_version, "rb") as f:
|
||||
expected_metadata_buf = bytearray(f.read())
|
||||
self.assertEqual(metadata_buf, expected_metadata_buf)
|
||||
|
||||
@ -293,7 +314,7 @@ class MetadataPopulatorTest(MetadataTest):
|
||||
buffer_data = model.Buffers(buffer_index)
|
||||
metadata_buf_np = buffer_data.DataAsNumpy()
|
||||
metadata_buf = metadata_buf_np.tobytes()
|
||||
with open(self._metadata_file, "rb") as f:
|
||||
with open(self._metadata_file_with_version, "rb") as f:
|
||||
expected_metadata_buf = bytearray(f.read())
|
||||
self.assertEqual(metadata_buf, expected_metadata_buf)
|
||||
|
||||
|
@ -17,5 +17,6 @@
|
||||
{
|
||||
"name": "file1"
|
||||
}
|
||||
]
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user