Export model's minimum required runtime string into flatbuffer's metadata field. Also add a method in FlatBufferModel to parse the runtime string.
PiperOrigin-RevId: 261418624
This commit is contained in:
parent
7ebd2fb017
commit
a3bdb597ca
@ -328,6 +328,7 @@ cc_test(
|
||||
"testdata/2_subgraphs.bin",
|
||||
"testdata/empty_model.bin",
|
||||
"testdata/multi_add_flex.bin",
|
||||
"testdata/test_min_runtime.bin",
|
||||
"testdata/test_model.bin",
|
||||
"testdata/test_model_broken.bin",
|
||||
],
|
||||
|
||||
@ -159,6 +159,22 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
|
||||
return model;
|
||||
}
|
||||
|
||||
string FlatBufferModel::GetMinimumRuntime() const {
|
||||
if (!model_ || !model_->metadata()) return "";
|
||||
|
||||
for (int i = 0; i < model_->metadata()->size(); ++i) {
|
||||
auto metadata = model_->metadata()->Get(i);
|
||||
if (metadata->name()->str() == "min_runtime_version") {
|
||||
auto buf = metadata->buffer();
|
||||
auto* buffer = (*model_->buffers())[buf];
|
||||
auto* array = buffer->data();
|
||||
return string(reinterpret_cast<const char*>(array->data()),
|
||||
array->size());
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
bool FlatBufferModel::CheckModelIdentifier() const {
|
||||
if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
|
||||
const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
|
||||
|
||||
@ -135,6 +135,12 @@ class FlatBufferModel {
|
||||
ErrorReporter* error_reporter() const { return error_reporter_; }
|
||||
const Allocation* allocation() const { return allocation_.get(); }
|
||||
|
||||
// Returns the minimum runtime version from the flatbuffer. This runtime
|
||||
// version encodes the minimum required interpreter version to run the
|
||||
// flatbuffer model. If the minimum version can't be determined, an empty
|
||||
// string will be returned.
|
||||
string GetMinimumRuntime() const;
|
||||
|
||||
/// Returns true if the model identifier is correct (otherwise false and
|
||||
/// reports an error).
|
||||
bool CheckModelIdentifier() const;
|
||||
|
||||
@ -315,6 +315,22 @@ TEST(BasicFlatBufferModel, TestBuildFromModel) {
|
||||
ASSERT_NE(interpreter, nullptr);
|
||||
}
|
||||
|
||||
// Test reading the minimum runtime string from metadata in a Model flatbuffer.
|
||||
TEST(BasicFlatBufferModel, TestReadRuntimeVersionFromModel) {
|
||||
// First read a model that doesn't have the runtime string.
|
||||
auto model1 = FlatBufferModel::BuildFromFile(
|
||||
"tensorflow/lite/testdata/test_model.bin");
|
||||
ASSERT_TRUE(model1);
|
||||
ASSERT_EQ(model1->GetMinimumRuntime(), "");
|
||||
|
||||
// Read a model that has minimum runtime string populated.
|
||||
auto model2 = FlatBufferModel::BuildFromFile(
|
||||
"tensorflow/lite/testdata/test_min_runtime.bin");
|
||||
ASSERT_TRUE(model2);
|
||||
// Check that we have read the runtime string correctly.
|
||||
ASSERT_EQ(model2->GetMinimumRuntime(), "1.10.0");
|
||||
}
|
||||
|
||||
// TODO(aselle): Add tests for serialization of builtin op data types.
|
||||
// These tests will occur with the evaluation tests of individual operators,
|
||||
// not here.
|
||||
|
||||
BIN
tensorflow/lite/testdata/test_min_runtime.bin
vendored
Normal file
BIN
tensorflow/lite/testdata/test_min_runtime.bin
vendored
Normal file
Binary file not shown.
@ -89,6 +89,7 @@ cc_library(
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":op_version",
|
||||
":operator",
|
||||
":types",
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
@ -108,9 +109,11 @@ tf_cc_test(
|
||||
],
|
||||
deps = [
|
||||
":export",
|
||||
":operator",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/lite/context.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/toco/tflite/op_version.h"
|
||||
#include "tensorflow/lite/toco/tflite/operator.h"
|
||||
#include "tensorflow/lite/toco/tflite/types.h"
|
||||
#include "tensorflow/lite/toco/tooling_util.h"
|
||||
@ -38,9 +39,11 @@ using ::tflite::BuiltinOperator_CUSTOM;
|
||||
using ::tflite::BuiltinOperator_MAX;
|
||||
using ::tflite::BuiltinOperator_MIN;
|
||||
using ::tflite::CreateBuffer;
|
||||
using ::tflite::CreateMetadata;
|
||||
using ::tflite::CreateModel;
|
||||
using ::tflite::CreateOperator;
|
||||
using ::tflite::CreateTensor;
|
||||
using ::tflite::Metadata;
|
||||
using ::tflite::Operator;
|
||||
using ::tflite::OperatorCode;
|
||||
using ::tflite::SubGraph;
|
||||
@ -456,6 +459,17 @@ void ParseControlFlowErrors(std::set<string>* custom_ops,
|
||||
}
|
||||
}
|
||||
|
||||
// Exports a string buffer that contains the model's minimum required runtime
|
||||
// version.
|
||||
void ExportModelVersionBuffer(
|
||||
const Model& model, std::vector<Offset<Vector<uint8_t>>>* buffers_to_write,
|
||||
FlatBufferBuilder* builder) {
|
||||
const std::string min_runtime = GetMinimumRuntimeVersionForModel(model);
|
||||
buffers_to_write->push_back(builder->CreateVector(
|
||||
reinterpret_cast<const uint8_t*>(min_runtime.data()),
|
||||
min_runtime.size()));
|
||||
}
|
||||
|
||||
tensorflow::Status Export(
|
||||
const Model& model, string* output_file_contents,
|
||||
const ExportParams& params,
|
||||
@ -612,11 +626,20 @@ tensorflow::Status Export(
|
||||
"not implemented yet.");
|
||||
}
|
||||
|
||||
// Write the minimum required runtime version into metadata.
|
||||
auto metadata =
|
||||
CreateMetadata(builder, builder.CreateString("min_runtime_version"),
|
||||
buffers_to_write.size());
|
||||
ExportModelVersionBuffer(model, &buffers_to_write, &builder);
|
||||
std::vector<flatbuffers::Offset<Metadata>> metadatas = {metadata};
|
||||
|
||||
auto buffers = ExportBuffers(model, buffers_to_write, &builder);
|
||||
auto description = builder.CreateString("TOCO Converted.");
|
||||
|
||||
auto new_model_location =
|
||||
CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
|
||||
builder.CreateVector(subgraphs), description, buffers);
|
||||
builder.CreateVector(subgraphs), description, buffers,
|
||||
/* metadata_buffer */ 0, builder.CreateVector(metadatas));
|
||||
::tflite::FinishModelBuffer(builder, new_model_location);
|
||||
|
||||
if (params.quantize_weights == QuantizedBufferType::NONE) {
|
||||
|
||||
@ -16,6 +16,7 @@ limitations under the License.
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "flatbuffers/flatbuffers.h" // TF:flatbuffers
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
@ -245,6 +246,44 @@ TEST_F(ExportTest, Export) {
|
||||
EXPECT_THAT(ExportAndGetOperatorIndices(params), ElementsAre(1, 0, 2, 3));
|
||||
}
|
||||
|
||||
TEST_F(ExportTest, ExportMinRuntime) {
|
||||
AddOperatorsByName({"Conv", "Add", "Sub"});
|
||||
|
||||
ExportParams params;
|
||||
params.allow_custom_ops = true;
|
||||
params.enable_select_tf_ops = false;
|
||||
params.quantize_weights = QuantizedBufferType::NONE;
|
||||
|
||||
string output;
|
||||
auto status = Export(input_model_, &output, params);
|
||||
auto* model = ::tflite::GetModel(output.data());
|
||||
EXPECT_EQ(model->metadata()->size(), 1);
|
||||
EXPECT_EQ(model->metadata()->Get(0)->name()->str(), "min_runtime_version");
|
||||
auto buf = model->metadata()->Get(0)->buffer();
|
||||
auto* buffer = (*model->buffers())[buf];
|
||||
auto* array = buffer->data();
|
||||
string version(reinterpret_cast<const char*>(array->data()), array->size());
|
||||
EXPECT_EQ(version, "1.6.0");
|
||||
}
|
||||
|
||||
TEST_F(ExportTest, ExportEmptyMinRuntime) {
|
||||
AddOperatorsByName({"Switch", "MyCustomOp", "Assert"});
|
||||
|
||||
ExportParams params;
|
||||
params.allow_custom_ops = true;
|
||||
|
||||
string output;
|
||||
auto status = Export(input_model_, &output, params);
|
||||
auto* model = ::tflite::GetModel(output.data());
|
||||
EXPECT_EQ(model->metadata()->size(), 1);
|
||||
EXPECT_EQ(model->metadata()->Get(0)->name()->str(), "min_runtime_version");
|
||||
auto buf = model->metadata()->Get(0)->buffer();
|
||||
auto* buffer = (*model->buffers())[buf];
|
||||
auto* array = buffer->data();
|
||||
string version(reinterpret_cast<const char*>(array->data()), array->size());
|
||||
EXPECT_EQ(version, "");
|
||||
}
|
||||
|
||||
TEST_F(ExportTest, UnsupportedControlFlowErrors) {
|
||||
AddOperatorsByName({"Conv", "Add", "Switch", "Merge"});
|
||||
|
||||
@ -532,7 +571,7 @@ class VersionedOpExportTest : public ::testing::Test {
|
||||
auto* op = new ConvOperator;
|
||||
op->inputs.push_back("input");
|
||||
op->inputs.push_back("filter");
|
||||
op->inputs.push_back("output");
|
||||
op->outputs.push_back("output");
|
||||
|
||||
op->padding.type = PaddingType::kSame;
|
||||
op->stride_width = 1;
|
||||
|
||||
@ -183,6 +183,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
|
||||
op_signature.model = &model;
|
||||
string model_min_version;
|
||||
for (const auto& op : model.operators) {
|
||||
if (op_types_map.find(op->type) == op_types_map.end()) continue;
|
||||
op_signature.op = op.get();
|
||||
const int version = op_types_map.at(op->type)->GetVersion(op_signature);
|
||||
std::pair<OperatorType, int> version_key = {op->type, version};
|
||||
|
||||
@ -20,10 +20,10 @@ limitations under the License.
|
||||
namespace toco {
|
||||
namespace tflite {
|
||||
|
||||
// Get the minimum TF Lite runtime required to run a model. Each operator in
|
||||
// the model will have its own minimum requirement of a runtime, and the model's
|
||||
// minimum requirement of runtime is defined as the maximum of all the
|
||||
// operators' minimum runtime.
|
||||
// Get the minimum TF Lite runtime required to run a model. Each built-in
|
||||
// operator in the model will have its own minimum requirement of a runtime, and
|
||||
// the model's minimum requirement of runtime is defined as the maximum of all
|
||||
// the built-in operators' minimum runtime.
|
||||
std::string GetMinimumRuntimeVersionForModel(const Model& model);
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user