Export model's minimum required runtime string into flatbuffer's metadata field. Also add a method in FlatBufferModel to parse the runtime string.

PiperOrigin-RevId: 261418624
This commit is contained in:
Haoliang Zhang 2019-08-02 16:59:17 -07:00 committed by TensorFlower Gardener
parent 7ebd2fb017
commit a3bdb597ca
10 changed files with 111 additions and 6 deletions

View File

@ -328,6 +328,7 @@ cc_test(
"testdata/2_subgraphs.bin",
"testdata/empty_model.bin",
"testdata/multi_add_flex.bin",
"testdata/test_min_runtime.bin",
"testdata/test_model.bin",
"testdata/test_model_broken.bin",
],

View File

@ -159,6 +159,22 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
return model;
}
string FlatBufferModel::GetMinimumRuntime() const {
if (!model_ || !model_->metadata()) return "";
for (int i = 0; i < model_->metadata()->size(); ++i) {
auto metadata = model_->metadata()->Get(i);
if (metadata->name()->str() == "min_runtime_version") {
auto buf = metadata->buffer();
auto* buffer = (*model_->buffers())[buf];
auto* array = buffer->data();
return string(reinterpret_cast<const char*>(array->data()),
array->size());
}
}
return "";
}
bool FlatBufferModel::CheckModelIdentifier() const {
if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());

View File

@ -135,6 +135,12 @@ class FlatBufferModel {
ErrorReporter* error_reporter() const { return error_reporter_; }
const Allocation* allocation() const { return allocation_.get(); }
// Returns the minimum runtime version from the flatbuffer. This runtime
// version encodes the minimum required interpreter version to run the
// flatbuffer model. If the minimum version can't be determined, an empty
// string will be returned.
string GetMinimumRuntime() const;
/// Returns true if the model identifier is correct (otherwise false and
/// reports an error).
bool CheckModelIdentifier() const;

View File

@ -315,6 +315,22 @@ TEST(BasicFlatBufferModel, TestBuildFromModel) {
ASSERT_NE(interpreter, nullptr);
}
// Test reading the minimum runtime string from metadata in a Model flatbuffer.
TEST(BasicFlatBufferModel, TestReadRuntimeVersionFromModel) {
// First read a model that doesn't have the runtime string.
auto model1 = FlatBufferModel::BuildFromFile(
"tensorflow/lite/testdata/test_model.bin");
ASSERT_TRUE(model1);
ASSERT_EQ(model1->GetMinimumRuntime(), "");
// Read a model that has minimum runtime string populated.
auto model2 = FlatBufferModel::BuildFromFile(
"tensorflow/lite/testdata/test_min_runtime.bin");
ASSERT_TRUE(model2);
// Check that we have read the runtime string correctly.
ASSERT_EQ(model2->GetMinimumRuntime(), "1.10.0");
}
// TODO(aselle): Add tests for serialization of builtin op data types.
// These tests will occur with the evaluation tests of individual operators,
// not here.

Binary file not shown.

View File

@ -89,6 +89,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
":op_version",
":operator",
":types",
"//tensorflow/lite:schema_fbs_version",
@ -108,9 +109,11 @@ tf_cc_test(
],
deps = [
":export",
":operator",
"//tensorflow/core:ops",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_googletest//:gtest_main",
"@flatbuffers",
],
)

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/tflite/op_version.h"
#include "tensorflow/lite/toco/tflite/operator.h"
#include "tensorflow/lite/toco/tflite/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
@ -38,9 +39,11 @@ using ::tflite::BuiltinOperator_CUSTOM;
using ::tflite::BuiltinOperator_MAX;
using ::tflite::BuiltinOperator_MIN;
using ::tflite::CreateBuffer;
using ::tflite::CreateMetadata;
using ::tflite::CreateModel;
using ::tflite::CreateOperator;
using ::tflite::CreateTensor;
using ::tflite::Metadata;
using ::tflite::Operator;
using ::tflite::OperatorCode;
using ::tflite::SubGraph;
@ -456,6 +459,17 @@ void ParseControlFlowErrors(std::set<string>* custom_ops,
}
}
// Exports a string buffer that contains the model's minimum required runtime
// version.
void ExportModelVersionBuffer(
const Model& model, std::vector<Offset<Vector<uint8_t>>>* buffers_to_write,
FlatBufferBuilder* builder) {
const std::string min_runtime = GetMinimumRuntimeVersionForModel(model);
buffers_to_write->push_back(builder->CreateVector(
reinterpret_cast<const uint8_t*>(min_runtime.data()),
min_runtime.size()));
}
tensorflow::Status Export(
const Model& model, string* output_file_contents,
const ExportParams& params,
@ -612,11 +626,20 @@ tensorflow::Status Export(
"not implemented yet.");
}
// Write the minimum required runtime version into metadata.
auto metadata =
CreateMetadata(builder, builder.CreateString("min_runtime_version"),
buffers_to_write.size());
ExportModelVersionBuffer(model, &buffers_to_write, &builder);
std::vector<flatbuffers::Offset<Metadata>> metadatas = {metadata};
auto buffers = ExportBuffers(model, buffers_to_write, &builder);
auto description = builder.CreateString("TOCO Converted.");
auto new_model_location =
CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
builder.CreateVector(subgraphs), description, buffers);
builder.CreateVector(subgraphs), description, buffers,
/* metadata_buffer */ 0, builder.CreateVector(metadatas));
::tflite::FinishModelBuffer(builder, new_model_location);
if (params.quantize_weights == QuantizedBufferType::NONE) {

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h" // TF:flatbuffers
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -245,6 +246,44 @@ TEST_F(ExportTest, Export) {
EXPECT_THAT(ExportAndGetOperatorIndices(params), ElementsAre(1, 0, 2, 3));
}
TEST_F(ExportTest, ExportMinRuntime) {
AddOperatorsByName({"Conv", "Add", "Sub"});
ExportParams params;
params.allow_custom_ops = true;
params.enable_select_tf_ops = false;
params.quantize_weights = QuantizedBufferType::NONE;
string output;
auto status = Export(input_model_, &output, params);
auto* model = ::tflite::GetModel(output.data());
EXPECT_EQ(model->metadata()->size(), 1);
EXPECT_EQ(model->metadata()->Get(0)->name()->str(), "min_runtime_version");
auto buf = model->metadata()->Get(0)->buffer();
auto* buffer = (*model->buffers())[buf];
auto* array = buffer->data();
string version(reinterpret_cast<const char*>(array->data()), array->size());
EXPECT_EQ(version, "1.6.0");
}
TEST_F(ExportTest, ExportEmptyMinRuntime) {
AddOperatorsByName({"Switch", "MyCustomOp", "Assert"});
ExportParams params;
params.allow_custom_ops = true;
string output;
auto status = Export(input_model_, &output, params);
auto* model = ::tflite::GetModel(output.data());
EXPECT_EQ(model->metadata()->size(), 1);
EXPECT_EQ(model->metadata()->Get(0)->name()->str(), "min_runtime_version");
auto buf = model->metadata()->Get(0)->buffer();
auto* buffer = (*model->buffers())[buf];
auto* array = buffer->data();
string version(reinterpret_cast<const char*>(array->data()), array->size());
EXPECT_EQ(version, "");
}
TEST_F(ExportTest, UnsupportedControlFlowErrors) {
AddOperatorsByName({"Conv", "Add", "Switch", "Merge"});
@ -532,7 +571,7 @@ class VersionedOpExportTest : public ::testing::Test {
auto* op = new ConvOperator;
op->inputs.push_back("input");
op->inputs.push_back("filter");
op->inputs.push_back("output");
op->outputs.push_back("output");
op->padding.type = PaddingType::kSame;
op->stride_width = 1;

View File

@ -183,6 +183,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
op_signature.model = &model;
string model_min_version;
for (const auto& op : model.operators) {
if (op_types_map.find(op->type) == op_types_map.end()) continue;
op_signature.op = op.get();
const int version = op_types_map.at(op->type)->GetVersion(op_signature);
std::pair<OperatorType, int> version_key = {op->type, version};

View File

@ -20,10 +20,10 @@ limitations under the License.
namespace toco {
namespace tflite {
// Get the minimum TF Lite runtime required to run a model. Each operator in
// the model will have its own minimum requirement of a runtime, and the model's
// minimum requirement of runtime is defined as the maximum of all the
// operators' minimum runtime.
// Get the minimum TF Lite runtime required to run a model. Each built-in
// operator in the model will have its own minimum requirement of a runtime, and
// the model's minimum requirement of runtime is defined as the maximum of all
// the built-in operators' minimum runtime.
std::string GetMinimumRuntimeVersionForModel(const Model& model);
} // namespace tflite