Add vocab file type and update the versioning

PiperOrigin-RevId: 314639904
Change-Id: Id1e09f2f7590aadaa2c77abd0999458863960258
This commit is contained in:
Lu Wang 2020-06-03 17:55:54 -07:00 committed by TensorFlower Gardener
parent 7c001cc50b
commit 716e8a092c
6 changed files with 318 additions and 17 deletions

View File

@ -10,7 +10,9 @@ cc_library(
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/tools:logging",
"@com_google_absl//absl/strings",
"@flatbuffers",
],
)

View File

@ -14,16 +14,171 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
#include <stddef.h>
#include <stdint.h>
#include <array>
#include <ostream>
#include <string>
#include <vector>
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/tools/logging.h"
namespace tflite {
namespace metadata {
namespace {
// Members that were added to the metadata schema after its initial version,
// 1.0.0. GetMemberVersion() maps each member to the schema version that
// introduced it.
enum class SchemaMembers {
// The VOCABULARY value of AssociatedFileType.
kAssociatedFileTypeVocabulary = 0,
};
// Helper class to compare semantic versions in terms of three integers: major,
// minor, and patch.
class Version {
 public:
  // Builds a version from its numeric components. Omitted components are 0.
  explicit Version(int major, int minor = 0, int patch = 0)
      : version_({major, minor, patch}) {}

  // Parses a dot-separated version string such as "1", "1.14" or "1.14.1".
  // Components that are not present in the string are set to 0.
  explicit Version(const std::string& version) {
    const std::vector<std::string> vec = absl::StrSplit(version, '.');
    // The version string should always hold between one and three numbers.
    TFLITE_DCHECK(!vec.empty() && vec.size() <= kElementNumber);
    for (size_t i = 0; i < kElementNumber; ++i) {
      version_[i] = i < vec.size() ? std::stoi(vec[i]) : 0;
    }
  }

  // Compares two semantic version numbers.
  //
  // Example results when comparing two versions strings:
  //   "1.9" precedes "1.14";
  //   "1.14" precedes "1.14.1";
  //   "1.14" and "1.14.0" are equal.
  //
  // Returns the value 0 if the two versions are equal; a value less than 0 if
  // *this precedes v; a value greater than 0 if v precedes *this.
  //
  // Declared const so that it can also be invoked on a const Version.
  int Compare(const Version& v) const {
    // Components are ordered from most to least significant, so the first
    // difference decides the result.
    for (size_t i = 0; i < kElementNumber; ++i) {
      if (version_[i] != v.version_[i]) {
        return version_[i] < v.version_[i] ? -1 : 1;
      }
    }
    return 0;
  }

  // Converts version_ into a version string, joining the three components
  // with ".".
  std::string ToString() const { return absl::StrJoin(version_, "."); }

 private:
  // Number of components in a version: major, minor and patch. Unsigned so
  // that it compares cleanly against container sizes.
  static constexpr size_t kElementNumber = 3;
  std::array<int, kElementNumber> version_;
};
// Returns the schema version in which `member` was introduced.
//
// A value without an entry in the switch below is reported through a FATAL
// log message.
Version GetMemberVersion(SchemaMembers member) {
  switch (member) {
    case SchemaMembers::kAssociatedFileTypeVocabulary:
      return Version(1, 0, 1);
    default:
      TFLITE_LOG(FATAL) << "Unsupported schema member: "
                        << static_cast<int>(member);
  }
  // Not expected to be reached. The explicit return keeps every control path
  // of this non-void function well-defined, even if the FATAL log statement
  // above hands control back to this function.
  return Version(0, 0, 0);
}
// Raises *min_version to new_version when *min_version precedes it; otherwise
// leaves *min_version untouched.
inline void UpdateMinimumVersion(const Version& new_version,
                                 Version* min_version) {
  const bool needs_bump = min_version->Compare(new_version) < 0;
  if (!needs_bump) return;
  *min_version = new_version;
}
// Bumps *min_version when `associated_file` is of the VOCABULARY file type.
// A null `associated_file` is ignored.
void UpdateMinimumVersionForAssociatedFile(
    const tflite::AssociatedFile* associated_file, Version* min_version) {
  const bool is_vocabulary =
      associated_file != nullptr &&
      associated_file->type() == AssociatedFileType_VOCABULARY;
  if (!is_vocabulary) return;

  const Version vocabulary_version =
      GetMemberVersion(SchemaMembers::kAssociatedFileTypeVocabulary);
  UpdateMinimumVersion(vocabulary_version, min_version);
}
// Applies UpdateMinimumVersionForAssociatedFile() to every entry of
// `associated_files`. A null vector is ignored.
void UpdateMinimumVersionForAssociatedFileArray(
    const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
        associated_files,
    Version* min_version) {
  if (associated_files == nullptr) return;
  // Range-based iteration avoids comparing a signed index against the
  // vector's unsigned size().
  for (const tflite::AssociatedFile* associated_file : *associated_files) {
    UpdateMinimumVersionForAssociatedFile(associated_file, min_version);
  }
}
// Bumps *min_version according to the contents of `tensor_metadata`. Only its
// associated_files field is inspected; a null `tensor_metadata` is ignored.
void UpdateMinimumVersionForTensorMetadata(
    const tflite::TensorMetadata* tensor_metadata, Version* min_version) {
  if (tensor_metadata != nullptr) {
    const auto* tensor_files = tensor_metadata->associated_files();
    UpdateMinimumVersionForAssociatedFileArray(tensor_files, min_version);
  }
}
// Applies UpdateMinimumVersionForTensorMetadata() to every entry of
// `tensor_metadata_array`. A null vector is ignored.
void UpdateMinimumVersionForTensorMetadataArray(
    const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
        tensor_metadata_array,
    Version* min_version) {
  if (tensor_metadata_array == nullptr) return;
  // Range-based iteration avoids comparing a signed index against the
  // vector's unsigned size().
  for (const tflite::TensorMetadata* tensor_metadata :
       *tensor_metadata_array) {
    UpdateMinimumVersionForTensorMetadata(tensor_metadata, min_version);
  }
}
// Bumps *min_version according to the contents of `subgraph_metadata`: its
// input tensor metadata, its output tensor metadata, and its own associated
// files, in that order. A null `subgraph_metadata` is ignored.
void UpdateMinimumVersionForSubGraphMetadata(
    const tflite::SubGraphMetadata* subgraph_metadata, Version* min_version) {
  if (subgraph_metadata != nullptr) {
    // Input tensors first, then output tensors.
    const auto* input_tensors = subgraph_metadata->input_tensor_metadata();
    UpdateMinimumVersionForTensorMetadataArray(input_tensors, min_version);
    const auto* output_tensors = subgraph_metadata->output_tensor_metadata();
    UpdateMinimumVersionForTensorMetadataArray(output_tensors, min_version);

    // Files attached directly to the subgraph.
    const auto* subgraph_files = subgraph_metadata->associated_files();
    UpdateMinimumVersionForAssociatedFileArray(subgraph_files, min_version);
  }
}
// Bumps *min_version according to the contents of `model_metadata`: every
// entry of its subgraph_metadata field, followed by its own associated files.
void UpdateMinimumVersionForModelMetadata(
    const tflite::ModelMetadata& model_metadata, Version* min_version) {
  // Checks the subgraph_metadata field. The accessor result is fetched once,
  // and range-based iteration avoids comparing a signed index against the
  // vector's unsigned size().
  const auto* subgraphs = model_metadata.subgraph_metadata();
  if (subgraphs != nullptr) {
    for (const tflite::SubGraphMetadata* subgraph : *subgraphs) {
      UpdateMinimumVersionForSubGraphMetadata(subgraph, min_version);
    }
  }
  // Checks the associated_files field.
  UpdateMinimumVersionForAssociatedFileArray(model_metadata.associated_files(),
                                             min_version);
}
} // namespace
TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
size_t buffer_size,
std::string* min_version) {
std::string* min_version_str) {
flatbuffers::Verifier verifier =
flatbuffers::Verifier(buffer_data, buffer_size);
if (!tflite::VerifyModelMetadataBuffer(verifier)) {
@ -31,18 +186,27 @@ TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
return kTfLiteError;
}
// Returns the version as the initial default one, "1.0.0", because it is the
// first version ever for metadata_schema.fbs.
//
// Later, when new fields are added to the schema, we'll update the logic of
// getting the minimum metadata parser version. To be more specific, we'll
// have a table that records the new fields and the versions of the schema
// they are added to. And the minimum metadata parser version will be the
// largest version number of all fields that has been added to a metadata
// flatbuffer.
// TODO(b/156539454): replace the hardcoded version with template + genrule.
static constexpr char kDefaultVersion[] = "1.0.0";
*min_version = kDefaultVersion;
Version min_version = Version(kDefaultVersion);
// Checks if any member declared after 1.0.0 (such as those in
// SchemaMembers) exists, and updates min_version accordingly. The minimum
// metadata parser version will be the largest version number of all fields
// that have been added to the metadata flatbuffer.
const tflite::ModelMetadata* model_metadata = GetModelMetadata(buffer_data);
// All tables in the metadata schema should have their dedicated
// UpdateMinimumVersionFor**() methods, respectively. We'll gradually add
// these methods when new fields show up in later schema versions.
//
// UpdateMinimumVersionFor<Foo>() takes a const pointer of Foo. The pointer
// can be a nullptr if Foo is not populated into the corresponding table of
// the Flatbuffer object. In this case, UpdateMinimumVersionFor<Foo>() will be
// skipped. An exception is UpdateMinimumVersionForModelMetadata(), where
// ModelMetadata is the root table, and it won't be null.
UpdateMinimumVersionForModelMetadata(*model_metadata, &min_version);
*min_version_str = min_version.ToString();
return kTfLiteOk;
}

View File

@ -15,6 +15,9 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
#include <stddef.h>
#include <stdint.h>
#include <string>
#include "tensorflow/lite/c/common.h"

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
@ -24,6 +26,7 @@ namespace metadata {
namespace {
using ::testing::MatchesRegex;
using ::testing::StrEq;
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionSucceedsWithValidMetadata) {
@ -45,7 +48,7 @@ TEST(MetadataVersionTest,
}
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionSucceedsWithInvalidIdentifier) {
GetMinimumMetadataParserVersionFailsWithInvalidIdentifier) {
// Creates a dummy metadata flatbuffer without identifier.
flatbuffers::FlatBufferBuilder builder(1024);
ModelMetadataBuilder metadata_builder(builder);
@ -60,6 +63,125 @@ TEST(MetadataVersionTest,
EXPECT_TRUE(min_version.empty());
}
// Verifies that a vocabulary file attached directly to ModelMetadata raises
// the minimum metadata parser version to 1.0.1.
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionForModelMetadataVocabAssociatedFiles) {
// Creates a metadata flatbuffer with the field,
// ModelMetadata.associated_files, populated with the vocabulary file type.
flatbuffers::FlatBufferBuilder builder(1024);
AssociatedFileBuilder associated_file_builder(builder);
associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
auto associated_files =
builder.CreateVector(std::vector<flatbuffers::Offset<AssociatedFile>>{
associated_file_builder.Finish()});
ModelMetadataBuilder metadata_builder(builder);
metadata_builder.add_associated_files(associated_files);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
kTfLiteOk);
// Validates that the version is exactly 1.0.1.
EXPECT_THAT(min_version, StrEq("1.0.1"));
}
// Verifies that a vocabulary file attached to a SubGraphMetadata entry raises
// the minimum metadata parser version to 1.0.1.
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionForSubGraphMetadataVocabAssociatedFiles) {
// Creates a metadata flatbuffer with the field,
// SubGraphMetadata.associated_files, populated with the vocabulary file type.
flatbuffers::FlatBufferBuilder builder(1024);
AssociatedFileBuilder associated_file_builder(builder);
associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
auto associated_files =
builder.CreateVector(std::vector<flatbuffers::Offset<AssociatedFile>>{
associated_file_builder.Finish()});
SubGraphMetadataBuilder subgraph_builder(builder);
subgraph_builder.add_associated_files(associated_files);
auto subgraphs =
builder.CreateVector(std::vector<flatbuffers::Offset<SubGraphMetadata>>{
subgraph_builder.Finish()});
ModelMetadataBuilder metadata_builder(builder);
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
kTfLiteOk);
// Validates that the version is exactly 1.0.1.
EXPECT_THAT(min_version, StrEq("1.0.1"));
}
// Verifies that a vocabulary file attached to an input TensorMetadata entry
// raises the minimum metadata parser version to 1.0.1.
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionForInputMetadataVocabAssociatedFiles) {
// Creates a metadata flatbuffer with the field,
// SubGraphMetadata.input_tensor_metadata.associated_files, populated with the
// vocabulary file type.
flatbuffers::FlatBufferBuilder builder(1024);
AssociatedFileBuilder associated_file_builder(builder);
associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
auto associated_files =
builder.CreateVector(std::vector<flatbuffers::Offset<AssociatedFile>>{
associated_file_builder.Finish()});
TensorMetadataBuilder tensor_builder(builder);
tensor_builder.add_associated_files(associated_files);
auto tensors =
builder.CreateVector(std::vector<flatbuffers::Offset<TensorMetadata>>{
tensor_builder.Finish()});
SubGraphMetadataBuilder subgraph_builder(builder);
subgraph_builder.add_input_tensor_metadata(tensors);
auto subgraphs =
builder.CreateVector(std::vector<flatbuffers::Offset<SubGraphMetadata>>{
subgraph_builder.Finish()});
ModelMetadataBuilder metadata_builder(builder);
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
kTfLiteOk);
// Validates that the version is exactly 1.0.1.
EXPECT_THAT(min_version, StrEq("1.0.1"));
}
// Verifies that a vocabulary file attached to an output TensorMetadata entry
// raises the minimum metadata parser version to 1.0.1.
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionForOutputMetadataVocabAssociatedFiles) {
// Creates a metadata flatbuffer with the field,
// SubGraphMetadata.output_tensor_metadata.associated_files, populated with
// the vocabulary file type.
flatbuffers::FlatBufferBuilder builder(1024);
AssociatedFileBuilder associated_file_builder(builder);
associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
auto associated_files =
builder.CreateVector(std::vector<flatbuffers::Offset<AssociatedFile>>{
associated_file_builder.Finish()});
TensorMetadataBuilder tensor_builder(builder);
tensor_builder.add_associated_files(associated_files);
auto tensors =
builder.CreateVector(std::vector<flatbuffers::Offset<TensorMetadata>>{
tensor_builder.Finish()});
SubGraphMetadataBuilder subgraph_builder(builder);
subgraph_builder.add_output_tensor_metadata(tensors);
auto subgraphs =
builder.CreateVector(std::vector<flatbuffers::Offset<SubGraphMetadata>>{
subgraph_builder.Finish()});
ModelMetadataBuilder metadata_builder(builder);
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
kTfLiteOk);
// Validates that the version is exactly 1.0.1.
EXPECT_EQ(min_version, "1.0.1");
}
} // namespace
} // namespace metadata
} // namespace tflite

View File

@ -55,7 +55,7 @@ public class MetadataExtractor {
// TODO(b/156539454): remove the hardcode versioning number and populate the version through
// genrule.
/** The version of the metadata parser that this {@link MetadataExtractor} library depends on. */
public static final String METADATA_PARSER_VERSION = "1.0.0";
public static final String METADATA_PARSER_VERSION = "1.0.1";
/** The helper class to load metadata from TFLite model FlatBuffer. */
private final ModelInfo modelInfo;

View File

@ -40,23 +40,28 @@ namespace tflite;
// New fields and types will have associated comments with the schema version for
// which they were added.
//
// Schema Semantic version: 1.0.0
// Schema Semantic version: 1.0.1
// This indicates the flatbuffer compatibility. The number will bump up when a
// breaking change is applied to the schema, such as removing fields or adding
// new fields to the middle of a table.
file_identifier "M001";
// History:
// 1.0.1 - Added VOCABULARY type to AssociatedFileType.
// File extension of any written files.
file_extension "tflitemeta";
// LINT.ThenChange(//tensorflow/lite/experimental/\
// /supportmetadata/java/src/java/org/tensorflow/lite/support/metadata/\
// /support/metadata/java/src/java/org/tensorflow/lite/support/metadata/\
// MetadataExtractor.java)
// LINT.IfChange
enum AssociatedFileType : byte {
UNKNOWN = 0,
// Files such as readme.txt
// Files such as readme.txt.
DESCRIPTIONS = 1,
// Contains labels that annotate certain axis of the tensor. For example,
@ -98,6 +103,11 @@ enum AssociatedFileType : byte {
//
// [1]: https://en.cppreference.com/w/c/string/byte/strtof
TENSOR_AXIS_SCORE_CALIBRATION = 4,
// Contains a list of unique words (characters separated by "\n" or in lines)
// that help to convert natural language words to embedding vectors.
// Added in: 1.0.1
VOCABULARY = 5,
}
table AssociatedFile {