Add vocab file type and update the versioning
PiperOrigin-RevId: 314639904 Change-Id: Id1e09f2f7590aadaa2c77abd0999458863960258
This commit is contained in:
parent
7c001cc50b
commit
716e8a092c
@ -10,7 +10,9 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
|
||||
"//tensorflow/lite/kernels/internal:compatibility",
|
||||
"//tensorflow/lite/tools:logging",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
@ -14,16 +14,171 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <array>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace metadata {
|
||||
namespace {
|
||||
|
||||
// Identifies schema members that were introduced after the initial 1.0.0
// release of the metadata schema. Each enumerator names one such addition.
enum class SchemaMembers {
  kAssociatedFileTypeVocabulary = 0,
};
|
||||
|
||||
// Helper class to compare semantic versions in terms of three integers, major,
|
||||
// minor, and patch.
|
||||
class Version {
|
||||
public:
|
||||
explicit Version(int major, int minor = 0, int patch = 0)
|
||||
: version_({major, minor, patch}) {}
|
||||
|
||||
explicit Version(const std::string& version) {
|
||||
const std::vector<std::string> vec = absl::StrSplit(version, '.');
|
||||
// The version string should always be less than four numbers.
|
||||
TFLITE_DCHECK(vec.size() <= kElementNumber && !vec.empty());
|
||||
version_[0] = std::stoi(vec[0]);
|
||||
version_[1] = vec.size() > 1 ? std::stoi(vec[1]) : 0;
|
||||
version_[2] = vec.size() > 2 ? std::stoi(vec[2]) : 0;
|
||||
}
|
||||
|
||||
// Compares two semantic version numbers.
|
||||
//
|
||||
// Example results when comparing two versions strings:
|
||||
// "1.9" precedes "1.14";
|
||||
// "1.14" precedes "1.14.1";
|
||||
// "1.14" and "1.14.0" are equal.
|
||||
//
|
||||
// Returns the value 0 if the two versions are equal; a value less than 0 if
|
||||
// *this precedes v; a value greater than 0 if v precedes *this.
|
||||
int Compare(const Version& v) {
|
||||
for (int i = 0; i < kElementNumber; ++i) {
|
||||
if (version_[i] != v.version_[i]) {
|
||||
return version_[i] < v.version_[i] ? -1 : 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Converts version_ into a version string.
|
||||
std::string ToString() { return absl::StrJoin(version_, "."); }
|
||||
|
||||
private:
|
||||
static constexpr int kElementNumber = 3;
|
||||
std::array<int, kElementNumber> version_;
|
||||
};
|
||||
|
||||
// Returns the schema version in which `member` was added.
//
// A member without a recorded version is reported through TFLITE_LOG(FATAL).
Version GetMemberVersion(SchemaMembers member) {
  switch (member) {
    case SchemaMembers::kAssociatedFileTypeVocabulary:
      return Version(1, 0, 1);
    default:
      TFLITE_LOG(FATAL) << "Unsupported schema member: "
                        << static_cast<int>(member);
  }
  // Should never happen: the fatal log above is presumably expected to
  // terminate the process. The explicit return guarantees that no path falls
  // off the end of this non-void function.
  return Version(0, 0, 0);
}
|
||||
|
||||
// Updates min_version if it precedes the new_version.
|
||||
inline void UpdateMinimumVersion(const Version& new_version,
|
||||
Version* min_version) {
|
||||
if (min_version->Compare(new_version) < 0) {
|
||||
*min_version = new_version;
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateMinimumVersionForAssociatedFile(
|
||||
const tflite::AssociatedFile* associated_file, Version* min_version) {
|
||||
if (associated_file == nullptr) return;
|
||||
|
||||
if (associated_file->type() == AssociatedFileType_VOCABULARY) {
|
||||
UpdateMinimumVersion(
|
||||
GetMemberVersion(SchemaMembers::kAssociatedFileTypeVocabulary),
|
||||
min_version);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateMinimumVersionForAssociatedFileArray(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
|
||||
associated_files,
|
||||
Version* min_version) {
|
||||
if (associated_files == nullptr) return;
|
||||
|
||||
for (int i = 0; i < associated_files->size(); ++i) {
|
||||
UpdateMinimumVersionForAssociatedFile(associated_files->Get(i),
|
||||
min_version);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateMinimumVersionForTensorMetadata(
|
||||
const tflite::TensorMetadata* tensor_metadata, Version* min_version) {
|
||||
if (tensor_metadata == nullptr) return;
|
||||
|
||||
// Checks the associated_files field.
|
||||
UpdateMinimumVersionForAssociatedFileArray(
|
||||
tensor_metadata->associated_files(), min_version);
|
||||
}
|
||||
|
||||
void UpdateMinimumVersionForTensorMetadataArray(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
|
||||
tensor_metadata_array,
|
||||
Version* min_version) {
|
||||
if (tensor_metadata_array == nullptr) return;
|
||||
|
||||
for (int i = 0; i < tensor_metadata_array->size(); ++i) {
|
||||
UpdateMinimumVersionForTensorMetadata(tensor_metadata_array->Get(i),
|
||||
min_version);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateMinimumVersionForSubGraphMetadata(
|
||||
const tflite::SubGraphMetadata* subgraph_metadata, Version* min_version) {
|
||||
if (subgraph_metadata == nullptr) return;
|
||||
|
||||
// Checks in the input/output metadata arrays.
|
||||
UpdateMinimumVersionForTensorMetadataArray(
|
||||
subgraph_metadata->input_tensor_metadata(), min_version);
|
||||
UpdateMinimumVersionForTensorMetadataArray(
|
||||
subgraph_metadata->output_tensor_metadata(), min_version);
|
||||
|
||||
// Checks the associated_files field.
|
||||
UpdateMinimumVersionForAssociatedFileArray(
|
||||
subgraph_metadata->associated_files(), min_version);
|
||||
}
|
||||
|
||||
void UpdateMinimumVersionForModelMetadata(
|
||||
const tflite::ModelMetadata& model_metadata, Version* min_version) {
|
||||
// Checks the subgraph_metadata field.
|
||||
if (model_metadata.subgraph_metadata() != nullptr) {
|
||||
for (int i = 0; i < model_metadata.subgraph_metadata()->size(); ++i) {
|
||||
UpdateMinimumVersionForSubGraphMetadata(
|
||||
model_metadata.subgraph_metadata()->Get(i), min_version);
|
||||
}
|
||||
}
|
||||
|
||||
// Checks the associated_files field.
|
||||
UpdateMinimumVersionForAssociatedFileArray(model_metadata.associated_files(),
|
||||
min_version);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
|
||||
size_t buffer_size,
|
||||
std::string* min_version) {
|
||||
std::string* min_version_str) {
|
||||
flatbuffers::Verifier verifier =
|
||||
flatbuffers::Verifier(buffer_data, buffer_size);
|
||||
if (!tflite::VerifyModelMetadataBuffer(verifier)) {
|
||||
@ -31,18 +186,27 @@ TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Returns the version as the initial default one, "1.0.0", because it is the
|
||||
// first version ever for metadata_schema.fbs.
|
||||
//
|
||||
// Later, when new fields are added to the schema, we'll update the logic of
|
||||
// getting the minimum metadata parser version. To be more specific, we'll
|
||||
// have a table that records the new fields and the versions of the schema
|
||||
// they are added to. And the minimum metadata parser version will be the
|
||||
// largest version number of all fields that has been added to a metadata
|
||||
// flatbuffer.
|
||||
// TODO(b/156539454): replace the hardcoded version with template + genrule.
|
||||
static constexpr char kDefaultVersion[] = "1.0.0";
|
||||
*min_version = kDefaultVersion;
|
||||
Version min_version = Version(kDefaultVersion);
|
||||
|
||||
// Checks if any member declared after 1.0.0 (such as those in
|
||||
// SchemaMembers) exists, and updates min_version accordingly. The minimum
|
||||
// metadata parser version will be the largest version number of all fields
|
||||
// that have been added to a metadata flatbuffer.
|
||||
const tflite::ModelMetadata* model_metadata = GetModelMetadata(buffer_data);
|
||||
|
||||
// All tables in the metadata schema should have their dedicated
|
||||
// UpdateMinimumVersionFor**() methods, respectively. We'll gradually add
|
||||
// these methods when new fields show up in later schema versions.
|
||||
//
|
||||
// UpdateMinimumVersionFor<Foo>() takes a const pointer of Foo. The pointer
|
||||
// can be a nullptr if Foo is not populated into the corresponding table of
|
||||
// the Flatbuffer object. In this case, UpdateMinimumVersionFor<Foo>() will be
|
||||
// skipped. An exception is UpdateMinimumVersionForModelMetadata(), where
|
||||
// ModelMetadata is the root table, and it won't be null.
|
||||
UpdateMinimumVersionForModelMetadata(*model_metadata, &min_version);
|
||||
|
||||
*min_version_str = min_version.ToString();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
||||
@ -15,6 +15,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
|
||||
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
@ -24,6 +26,7 @@ namespace metadata {
|
||||
namespace {
|
||||
|
||||
using ::testing::MatchesRegex;
|
||||
using ::testing::StrEq;
|
||||
|
||||
TEST(MetadataVersionTest,
|
||||
GetMinimumMetadataParserVersionSucceedsWithValidMetadata) {
|
||||
@ -45,7 +48,7 @@ TEST(MetadataVersionTest,
|
||||
}
|
||||
|
||||
TEST(MetadataVersionTest,
|
||||
GetMinimumMetadataParserVersionSucceedsWithInvalidIdentifier) {
|
||||
GetMinimumMetadataParserVersionFailsWithInvalidIdentifier) {
|
||||
// Creates a dummy metadata flatbuffer without identifier.
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
ModelMetadataBuilder metadata_builder(builder);
|
||||
@ -60,6 +63,125 @@ TEST(MetadataVersionTest,
|
||||
EXPECT_TRUE(min_version.empty());
|
||||
}
|
||||
|
||||
TEST(MetadataVersionTest,
     GetMinimumMetadataParserVersionForModelMetadataVocabAssociatedFiles) {
  // Builds a metadata flatbuffer whose ModelMetadata.associated_files field
  // holds a single file of the vocabulary type.
  flatbuffers::FlatBufferBuilder fbb(1024);
  AssociatedFileBuilder vocab_file_builder(fbb);
  vocab_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
  const flatbuffers::Offset<AssociatedFile> vocab_file =
      vocab_file_builder.Finish();
  const std::vector<flatbuffers::Offset<AssociatedFile>> file_offsets = {
      vocab_file};
  const auto files_vector = fbb.CreateVector(file_offsets);

  ModelMetadataBuilder model_builder(fbb);
  model_builder.add_associated_files(files_vector);
  FinishModelMetadataBuffer(fbb, model_builder.Finish());

  // Queries the minimum metadata parser version.
  std::string parser_version;
  const TfLiteStatus status = GetMinimumMetadataParserVersion(
      fbb.GetBufferPointer(), fbb.GetSize(), &parser_version);
  EXPECT_EQ(status, kTfLiteOk);
  // The reported version must be exactly 1.0.1.
  EXPECT_THAT(parser_version, StrEq("1.0.1"));
}
|
||||
|
||||
TEST(MetadataVersionTest,
     GetMinimumMetadataParserVersionForSubGraphMetadataVocabAssociatedFiles) {
  // Builds a metadata flatbuffer whose SubGraphMetadata.associated_files field
  // holds a single file of the vocabulary type.
  flatbuffers::FlatBufferBuilder fbb(1024);
  AssociatedFileBuilder vocab_file_builder(fbb);
  vocab_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
  const flatbuffers::Offset<AssociatedFile> vocab_file =
      vocab_file_builder.Finish();
  const std::vector<flatbuffers::Offset<AssociatedFile>> file_offsets = {
      vocab_file};
  const auto files_vector = fbb.CreateVector(file_offsets);

  // Wraps the files in a subgraph, and the subgraph in the model metadata.
  SubGraphMetadataBuilder subgraph_metadata_builder(fbb);
  subgraph_metadata_builder.add_associated_files(files_vector);
  const flatbuffers::Offset<SubGraphMetadata> subgraph =
      subgraph_metadata_builder.Finish();
  const std::vector<flatbuffers::Offset<SubGraphMetadata>> subgraph_offsets = {
      subgraph};
  const auto subgraphs_vector = fbb.CreateVector(subgraph_offsets);

  ModelMetadataBuilder model_builder(fbb);
  model_builder.add_subgraph_metadata(subgraphs_vector);
  FinishModelMetadataBuffer(fbb, model_builder.Finish());

  // Queries the minimum metadata parser version.
  std::string parser_version;
  const TfLiteStatus status = GetMinimumMetadataParserVersion(
      fbb.GetBufferPointer(), fbb.GetSize(), &parser_version);
  EXPECT_EQ(status, kTfLiteOk);
  // The reported version must be exactly 1.0.1.
  EXPECT_THAT(parser_version, StrEq("1.0.1"));
}
|
||||
|
||||
TEST(MetadataVersionTest,
     GetMinimumMetadataParserVersionForInputMetadataVocabAssociatedFiles) {
  // Builds a metadata flatbuffer whose
  // SubGraphMetadata.input_tensor_metadata.associated_files field holds a
  // single file of the vocabulary type.
  flatbuffers::FlatBufferBuilder fbb(1024);
  AssociatedFileBuilder vocab_file_builder(fbb);
  vocab_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
  const flatbuffers::Offset<AssociatedFile> vocab_file =
      vocab_file_builder.Finish();
  const std::vector<flatbuffers::Offset<AssociatedFile>> file_offsets = {
      vocab_file};
  const auto files_vector = fbb.CreateVector(file_offsets);

  // Attaches the files to a tensor.
  TensorMetadataBuilder tensor_metadata_builder(fbb);
  tensor_metadata_builder.add_associated_files(files_vector);
  const flatbuffers::Offset<TensorMetadata> tensor =
      tensor_metadata_builder.Finish();
  const std::vector<flatbuffers::Offset<TensorMetadata>> tensor_offsets = {
      tensor};
  const auto tensors_vector = fbb.CreateVector(tensor_offsets);

  // Registers the tensor as an input of a subgraph.
  SubGraphMetadataBuilder subgraph_metadata_builder(fbb);
  subgraph_metadata_builder.add_input_tensor_metadata(tensors_vector);
  const flatbuffers::Offset<SubGraphMetadata> subgraph =
      subgraph_metadata_builder.Finish();
  const std::vector<flatbuffers::Offset<SubGraphMetadata>> subgraph_offsets = {
      subgraph};
  const auto subgraphs_vector = fbb.CreateVector(subgraph_offsets);

  ModelMetadataBuilder model_builder(fbb);
  model_builder.add_subgraph_metadata(subgraphs_vector);
  FinishModelMetadataBuffer(fbb, model_builder.Finish());

  // Queries the minimum metadata parser version.
  std::string parser_version;
  const TfLiteStatus status = GetMinimumMetadataParserVersion(
      fbb.GetBufferPointer(), fbb.GetSize(), &parser_version);
  EXPECT_EQ(status, kTfLiteOk);
  // The reported version must be exactly 1.0.1.
  EXPECT_THAT(parser_version, StrEq("1.0.1"));
}
|
||||
|
||||
TEST(MetadataVersionTest,
     GetMinimumMetadataParserVersionForOutputMetadataVocabAssociatedFiles) {
  // Creates a metadata flatbuffer with the field,
  // SubGraphMetadata.output_tensor_metadata.associated_files, populated with
  // the vocabulary file type.
  flatbuffers::FlatBufferBuilder builder(1024);
  AssociatedFileBuilder associated_file_builder(builder);
  associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
  auto associated_files =
      builder.CreateVector(std::vector<flatbuffers::Offset<AssociatedFile>>{
          associated_file_builder.Finish()});
  TensorMetadataBuilder tensor_builder(builder);
  tensor_builder.add_associated_files(associated_files);
  auto tensors =
      builder.CreateVector(std::vector<flatbuffers::Offset<TensorMetadata>>{
          tensor_builder.Finish()});
  SubGraphMetadataBuilder subgraph_builder(builder);
  subgraph_builder.add_output_tensor_metadata(tensors);
  auto subgraphs =
      builder.CreateVector(std::vector<flatbuffers::Offset<SubGraphMetadata>>{
          subgraph_builder.Finish()});
  ModelMetadataBuilder metadata_builder(builder);
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());

  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
            kTfLiteOk);
  // Validates that the version is exactly 1.0.1. Uses the StrEq matcher so
  // that the assertion matches the other vocabulary tests in this file.
  EXPECT_THAT(min_version, StrEq("1.0.1"));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace metadata
|
||||
} // namespace tflite
|
||||
|
||||
@ -55,7 +55,7 @@ public class MetadataExtractor {
|
||||
// TODO(b/156539454): remove the hardcode versioning number and populate the version through
|
||||
// genrule.
|
||||
/** The version of the metadata parser that this {@link MetadataExtractor} library depends on. */
|
||||
public static final String METADATA_PARSER_VERSION = "1.0.0";
|
||||
public static final String METADATA_PARSER_VERSION = "1.0.1";
|
||||
|
||||
/** The helper class to load metadata from TFLite model FlatBuffer. */
|
||||
private final ModelInfo modelInfo;
|
||||
|
||||
@ -40,23 +40,28 @@ namespace tflite;
|
||||
// New fields and types will have associated comments with the schema version for
|
||||
// which they were added.
|
||||
//
|
||||
// Schema Semantic version: 1.0.0
|
||||
// Schema Semantic version: 1.0.1
|
||||
|
||||
// This indicates the flatbuffer compatibility. The number will bump up when a
|
||||
// breaking change is applied to the schema, such as removing fields or adding new
|
||||
// fields to the middle of a table.
|
||||
file_identifier "M001";
|
||||
|
||||
// History:
|
||||
// 1.0.1 - Added VOCABULARY type to AssociatedFileType.
|
||||
|
||||
// File extension of any written files.
|
||||
file_extension "tflitemeta";
|
||||
|
||||
// LINT.ThenChange(//tensorflow/lite/experimental/\
|
||||
// /supportmetadata/java/src/java/org/tensorflow/lite/support/metadata/\
|
||||
// /support/metadata/java/src/java/org/tensorflow/lite/support/metadata/\
|
||||
// MetadataExtractor.java)
|
||||
|
||||
// LINT.IfChange
|
||||
enum AssociatedFileType : byte {
|
||||
UNKNOWN = 0,
|
||||
// Files such as readme.txt
|
||||
|
||||
// Files such as readme.txt.
|
||||
DESCRIPTIONS = 1,
|
||||
|
||||
// Contains labels that annotate certain axis of the tensor. For example,
|
||||
@ -98,6 +103,11 @@ enum AssociatedFileType : byte {
|
||||
//
|
||||
// [1]: https://en.cppreference.com/w/c/string/byte/strtof
|
||||
TENSOR_AXIS_SCORE_CALIBRATION = 4,
|
||||
|
||||
// Contains a list of unique words (characters separated by "\n" or in lines)
|
||||
// that help to convert natural language words to embedding vectors.
|
||||
// Added in: 1.0.1
|
||||
VOCABULARY = 5,
|
||||
}
|
||||
|
||||
table AssociatedFile {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user