From a3bdb597ca28b61023a08d64f04f9b6924ad3c28 Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Fri, 2 Aug 2019 16:59:17 -0700 Subject: [PATCH] Export model's minimum required runtime string into flatbuffer's metadata field. Also add a method in FlatBufferModel to parse the runtime string. PiperOrigin-RevId: 261418624 --- tensorflow/lite/BUILD | 1 + tensorflow/lite/model.cc | 16 +++++++ tensorflow/lite/model.h | 6 +++ tensorflow/lite/model_test.cc | 16 +++++++ tensorflow/lite/testdata/test_min_runtime.bin | Bin 0 -> 580 bytes tensorflow/lite/toco/tflite/BUILD | 3 ++ tensorflow/lite/toco/tflite/export.cc | 25 ++++++++++- tensorflow/lite/toco/tflite/export_test.cc | 41 +++++++++++++++++- tensorflow/lite/toco/tflite/op_version.cc | 1 + tensorflow/lite/toco/tflite/op_version.h | 8 ++-- 10 files changed, 111 insertions(+), 6 deletions(-) create mode 100644 tensorflow/lite/testdata/test_min_runtime.bin diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 853ba3d473c..e353edd121e 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -328,6 +328,7 @@ cc_test( "testdata/2_subgraphs.bin", "testdata/empty_model.bin", "testdata/multi_add_flex.bin", + "testdata/test_min_runtime.bin", "testdata/test_model.bin", "testdata/test_model_broken.bin", ], diff --git a/tensorflow/lite/model.cc b/tensorflow/lite/model.cc index 2281835e55b..516ba693738 100644 --- a/tensorflow/lite/model.cc +++ b/tensorflow/lite/model.cc @@ -159,6 +159,22 @@ std::unique_ptr FlatBufferModel::BuildFromModel( return model; } +string FlatBufferModel::GetMinimumRuntime() const { + if (!model_ || !model_->metadata()) return ""; + + for (int i = 0; i < model_->metadata()->size(); ++i) { + auto metadata = model_->metadata()->Get(i); + if (metadata->name()->str() == "min_runtime_version") { + auto buf = metadata->buffer(); + auto* buffer = (*model_->buffers())[buf]; + auto* array = buffer->data(); + return string(reinterpret_cast(array->data()), + array->size()); + } + } + return ""; +} + bool FlatBufferModel::CheckModelIdentifier() const { if (!tflite::ModelBufferHasIdentifier(allocation_->base())) { const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base()); diff --git a/tensorflow/lite/model.h b/tensorflow/lite/model.h index 6c569470f34..06dd2e294f8 100644 --- a/tensorflow/lite/model.h +++ b/tensorflow/lite/model.h @@ -135,6 +135,12 @@ class FlatBufferModel { ErrorReporter* error_reporter() const { return error_reporter_; } const Allocation* allocation() const { return allocation_.get(); } + // Returns the minimum runtime version from the flatbuffer. This runtime + // version encodes the minimum required interpreter version to run the + // flatbuffer model. If the minimum version can't be determined, an empty + // string will be returned. + string GetMinimumRuntime() const; + /// Returns true if the model identifier is correct (otherwise false and /// reports an error). bool CheckModelIdentifier() const; diff --git a/tensorflow/lite/model_test.cc b/tensorflow/lite/model_test.cc index d58dbf4d45f..7dc582b8862 100644 --- a/tensorflow/lite/model_test.cc +++ b/tensorflow/lite/model_test.cc @@ -315,6 +315,22 @@ TEST(BasicFlatBufferModel, TestBuildFromModel) { ASSERT_NE(interpreter, nullptr); } +// Test reading the minimum runtime string from metadata in a Model flatbuffer. +TEST(BasicFlatBufferModel, TestReadRuntimeVersionFromModel) { + // First read a model that doesn't have the runtime string. + auto model1 = FlatBufferModel::BuildFromFile( + "tensorflow/lite/testdata/test_model.bin"); + ASSERT_TRUE(model1); + ASSERT_EQ(model1->GetMinimumRuntime(), ""); + + // Read a model that has minimum runtime string populated. + auto model2 = FlatBufferModel::BuildFromFile( + "tensorflow/lite/testdata/test_min_runtime.bin"); + ASSERT_TRUE(model2); + // Check that we have read the runtime string correctly. + ASSERT_EQ(model2->GetMinimumRuntime(), "1.10.0"); +} + // TODO(aselle): Add tests for serialization of builtin op data types. // These tests will occur with the evaluation tests of individual operators, // not here. diff --git a/tensorflow/lite/testdata/test_min_runtime.bin b/tensorflow/lite/testdata/test_min_runtime.bin new file mode 100644 index 0000000000000000000000000000000000000000..c68174390dea75e6fbfbf67d7ee18cc82a3eadd8 GIT binary patch literal 580 zcmb1PU| z7#J8C7#K33^aKV527Zt^{?7gi&iQ#|sYNBJDS8a73=9k=3=9k!VAB{F1fY7~{r~@8 zfq{V`;@|)OKmPsy&%nUOz`?-6z{bG9;K0bhQ1I{ne-PVH&(J{6fPtal|Ns9SP`$zo z3=FxMdGST1c_o>-sqtm0Ma7x)ZB5>F9FfcHH%m?`br0)yVolH;~L!plNRJ8w0|Pe$14BSgVsdImeojhi5!lZlGeBVk3Mr5}n4dsqGBDIT zLudvD1_nN`ix{9G$_EVth<*le2!rejF3klg2Kycqt{{C93>*yX(0BmpV_;xl0);)u zY>=H23=H6~bnGcAe%wq!oa{F4Kfqtevn_0 f^?}?7!XW!Vj$>f}L5P{4u!Y1IJ2<`=7#O?&*oiu! literal 0 HcmV?d00001 diff --git a/tensorflow/lite/toco/tflite/BUILD b/tensorflow/lite/toco/tflite/BUILD index 01850bf68bb..f27f0f999da 100644 --- a/tensorflow/lite/toco/tflite/BUILD +++ b/tensorflow/lite/toco/tflite/BUILD @@ -89,6 +89,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":op_version", ":operator", ":types", "//tensorflow/lite:schema_fbs_version", @@ -108,9 +109,11 @@ tf_cc_test( ], deps = [ ":export", + ":operator", "//tensorflow/core:ops", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest_main", + "@flatbuffers", ], ) diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index c32466bc1f3..227c6aada89 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/lite/context.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/toco/tflite/op_version.h" #include "tensorflow/lite/toco/tflite/operator.h" #include "tensorflow/lite/toco/tflite/types.h" #include "tensorflow/lite/toco/tooling_util.h" @@ -38,9 +39,11 @@ using ::tflite::BuiltinOperator_CUSTOM; using ::tflite::BuiltinOperator_MAX; using ::tflite::BuiltinOperator_MIN; using ::tflite::CreateBuffer; +using ::tflite::CreateMetadata; using ::tflite::CreateModel; using ::tflite::CreateOperator; using ::tflite::CreateTensor; +using ::tflite::Metadata; using ::tflite::Operator; using ::tflite::OperatorCode; using ::tflite::SubGraph; @@ -456,6 +459,17 @@ void ParseControlFlowErrors(std::set* custom_ops, } } +// Exports a string buffer that contains the model's minimum required runtime +// version. +void ExportModelVersionBuffer( + const Model& model, std::vector>>* buffers_to_write, + FlatBufferBuilder* builder) { + const std::string min_runtime = GetMinimumRuntimeVersionForModel(model); + buffers_to_write->push_back(builder->CreateVector( + reinterpret_cast(min_runtime.data()), + min_runtime.size())); +} + tensorflow::Status Export( const Model& model, string* output_file_contents, const ExportParams& params, @@ -612,11 +626,20 @@ tensorflow::Status Export( "not implemented yet."); } + // Write the minimum required runtime version into metadata. + auto metadata = + CreateMetadata(builder, builder.CreateString("min_runtime_version"), + buffers_to_write.size()); + ExportModelVersionBuffer(model, &buffers_to_write, &builder); + std::vector> metadatas = {metadata}; + auto buffers = ExportBuffers(model, buffers_to_write, &builder); auto description = builder.CreateString("TOCO Converted."); + auto new_model_location = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes, - builder.CreateVector(subgraphs), description, buffers); + builder.CreateVector(subgraphs), description, buffers, + /* metadata_buffer */ 0, builder.CreateVector(metadatas)); ::tflite::FinishModelBuffer(builder, new_model_location); if (params.quantize_weights == QuantizedBufferType::NONE) { diff --git a/tensorflow/lite/toco/tflite/export_test.cc b/tensorflow/lite/toco/tflite/export_test.cc index 0ae6104f8f9..bbb1c557f4d 100644 --- a/tensorflow/lite/toco/tflite/export_test.cc +++ b/tensorflow/lite/toco/tflite/export_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "flatbuffers/flatbuffers.h" // TF:flatbuffers #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -245,6 +246,44 @@ TEST_F(ExportTest, Export) { EXPECT_THAT(ExportAndGetOperatorIndices(params), ElementsAre(1, 0, 2, 3)); } +TEST_F(ExportTest, ExportMinRuntime) { + AddOperatorsByName({"Conv", "Add", "Sub"}); + + ExportParams params; + params.allow_custom_ops = true; + params.enable_select_tf_ops = false; + params.quantize_weights = QuantizedBufferType::NONE; + + string output; + auto status = Export(input_model_, &output, params); + auto* model = ::tflite::GetModel(output.data()); + EXPECT_EQ(model->metadata()->size(), 1); + EXPECT_EQ(model->metadata()->Get(0)->name()->str(), "min_runtime_version"); + auto buf = model->metadata()->Get(0)->buffer(); + auto* buffer = (*model->buffers())[buf]; + auto* array = buffer->data(); + string version(reinterpret_cast(array->data()), array->size()); + EXPECT_EQ(version, "1.6.0"); +} + +TEST_F(ExportTest, ExportEmptyMinRuntime) { + AddOperatorsByName({"Switch", "MyCustomOp", "Assert"}); + + ExportParams params; + params.allow_custom_ops = true; + + string output; + auto status = Export(input_model_, &output, params); + auto* model = ::tflite::GetModel(output.data()); + EXPECT_EQ(model->metadata()->size(), 1); + EXPECT_EQ(model->metadata()->Get(0)->name()->str(), "min_runtime_version"); + auto buf = model->metadata()->Get(0)->buffer(); + auto* buffer = (*model->buffers())[buf]; + auto* array = buffer->data(); + string version(reinterpret_cast(array->data()), array->size()); + EXPECT_EQ(version, ""); +} + TEST_F(ExportTest, UnsupportedControlFlowErrors) { AddOperatorsByName({"Conv", "Add", "Switch", "Merge"}); @@ -532,7 +571,7 @@ class VersionedOpExportTest : public ::testing::Test { auto* op = new ConvOperator; op->inputs.push_back("input"); op->inputs.push_back("filter"); - op->inputs.push_back("output"); + op->outputs.push_back("output"); op->padding.type = PaddingType::kSame; op->stride_width = 1; diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index f83edc87167..09ffa8c30fb 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -183,6 +183,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { op_signature.model = &model; string model_min_version; for (const auto& op : model.operators) { + if (op_types_map.find(op->type) == op_types_map.end()) continue; op_signature.op = op.get(); const int version = op_types_map.at(op->type)->GetVersion(op_signature); std::pair version_key = {op->type, version}; diff --git a/tensorflow/lite/toco/tflite/op_version.h b/tensorflow/lite/toco/tflite/op_version.h index 9c2b16723cc..54a77501b14 100644 --- a/tensorflow/lite/toco/tflite/op_version.h +++ b/tensorflow/lite/toco/tflite/op_version.h @@ -20,10 +20,10 @@ limitations under the License. namespace toco { namespace tflite { -// Get the minimum TF Lite runtime required to run a model. Each operator in -// the model will have its own minimum requirement of a runtime, and the model's -// minimum requirement of runtime is defined as the maximum of all the -// operators' minimum runtime. +// Get the minimum TF Lite runtime required to run a model. Each built-in +// operator in the model will have its own minimum requirement of a runtime, and +// the model's minimum requirement of runtime is defined as the maximum of all +// the built-in operators' minimum runtime. std::string GetMinimumRuntimeVersionForModel(const Model& model); } // namespace tflite