From 0228bf0854610559b0027ca6d3460dd5dec6d0f5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2019 09:09:21 -0700 Subject: [PATCH] Added features to gen_selected_ops. - Allow to put RegisterSelectedOps in a namespace. - Support versioned ops. - Support multiple input tflite models. PiperOrigin-RevId: 270053948 --- tensorflow/lite/build_def.bzl | 15 +++- .../testdata/test_model_versioned_ops.bin | Bin 0 -> 508 bytes tensorflow/lite/tools/BUILD | 1 + tensorflow/lite/tools/gen_op_registration.cc | 19 ++-- tensorflow/lite/tools/gen_op_registration.h | 11 ++- .../lite/tools/gen_op_registration_main.cc | 84 ++++++++++++------ .../lite/tools/gen_op_registration_test.cc | 27 +++++- 7 files changed, 114 insertions(+), 43 deletions(-) create mode 100644 tensorflow/lite/testdata/test_model_versioned_ops.bin diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 619f95fca2b..6a076e1758a 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -608,23 +608,30 @@ def gen_zipped_test_file(name, file, toco, flags): srcs = [file], ) -def gen_selected_ops(name, model): +def gen_selected_ops(name, model, namespace = "", **kwargs): """Generate the library that includes only used ops. Args: name: Name of the generated library. model: TFLite model to interpret. + namespace: Namespace in which to put RegisterSelectedOps. + **kwargs: Additional kwargs to pass to genrule. """ out = name + "_registration.cc" tool = "//tensorflow/lite/tools:generate_op_registrations" tflite_path = "//tensorflow/lite" + + # isinstance is not supported in skylark. + if type(model) != type([]): + model = [model] native.genrule( name = name, - srcs = [model], + srcs = model, outs = [out], - cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") % - (tool, model, out, tflite_path[2:]), + cmd = ("$(location %s) --namespace=%s --output_registration=$(location %s) --tflite_path=%s $(SRCS)") % + (tool, namespace, out, tflite_path[2:]), tools = [tool], + **kwargs ) def flex_dep(target_op_sets): diff --git a/tensorflow/lite/testdata/test_model_versioned_ops.bin b/tensorflow/lite/testdata/test_model_versioned_ops.bin new file mode 100644 index 0000000000000000000000000000000000000000..04aa014742b4769bbd7ebcc1adb0b0ac6a9ca9ea GIT binary patch literal 508 zcmZWlI}U7d}EV5k0^mnhCI{b%`mlB%}al)Ve%vOmRxd=(qQXHvCnMOkP9o zN!*FPX3D#?KtOAI20FlOB^!* literal 0 HcmV?d00001 diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index 38fc69e8408..60acf3a514d 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -54,6 +54,7 @@ cc_test( "//tensorflow/lite:testdata/empty_model.bin", "//tensorflow/lite:testdata/test_model.bin", "//tensorflow/lite:testdata/test_model_broken.bin", + "//tensorflow/lite:testdata/test_model_versioned_ops.bin", ], tags = [ "tflite_not_portable_android", diff --git a/tensorflow/lite/tools/gen_op_registration.cc b/tensorflow/lite/tools/gen_op_registration.cc index ca66eef4660..be08b6e0d31 100644 --- a/tensorflow/lite/tools/gen_op_registration.cc +++ b/tensorflow/lite/tools/gen_op_registration.cc @@ -29,17 +29,26 @@ string NormalizeCustomOpName(const string& op) { } void ReadOpsFromModel(const ::tflite::Model* model, - std::vector* builtin_ops, - std::vector* custom_ops) { + tflite::RegisteredOpMap* builtin_ops, + tflite::RegisteredOpMap* custom_ops) { if (!model) return; auto opcodes = model->operator_codes(); if (!opcodes) return; for (const auto* opcode : *opcodes) { + const int version = opcode->version(); if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) { - builtin_ops->push_back( - tflite::EnumNameBuiltinOperator(opcode->builtin_code())); + auto iter_and_bool = builtin_ops->insert(std::make_pair( + tflite::EnumNameBuiltinOperator(opcode->builtin_code()), + std::make_pair(version, version))); + auto& versions = iter_and_bool.first->second; + versions.first = std::min(versions.first, version); + versions.second = std::max(versions.second, version); } else { - custom_ops->push_back(opcode->custom_code()->c_str()); + auto iter_and_bool = custom_ops->insert(std::make_pair( + opcode->custom_code()->c_str(), std::make_pair(version, version))); + auto& versions = iter_and_bool.first->second; + versions.first = std::min(versions.first, version); + versions.second = std::max(versions.second, version); } } } diff --git a/tensorflow/lite/tools/gen_op_registration.h b/tensorflow/lite/tools/gen_op_registration.h index b01ede98292..edb4c98e9af 100644 --- a/tensorflow/lite/tools/gen_op_registration.h +++ b/tensorflow/lite/tools/gen_op_registration.h @@ -27,12 +27,15 @@ namespace tflite { // Note "Register_" suffix will be added later in the tool. string NormalizeCustomOpName(const string& op); +// A map from op name to {min_version, max_version}. +typedef std::map> RegisteredOpMap; + // Read ops from the TFLite model. -// Enum name of builtin ops will be stored, such as "CONV_2D". -// Custom op name will be stored as it is. +// The builtin ops key is the enum name of builtin ops, such as "CONV_2D". +// The custom ops key is stored as it is. void ReadOpsFromModel(const ::tflite::Model* model, - std::vector* builtin_ops, - std::vector* custom_ops); + RegisteredOpMap* builtin_ops, + RegisteredOpMap* custom_ops); } // namespace tflite diff --git a/tensorflow/lite/tools/gen_op_registration_main.cc b/tensorflow/lite/tools/gen_op_registration_main.cc index 090b709478d..796213846fc 100644 --- a/tensorflow/lite/tools/gen_op_registration_main.cc +++ b/tensorflow/lite/tools/gen_op_registration_main.cc @@ -21,11 +21,12 @@ limitations under the License. #include #include "absl/strings/strip.h" -#include "tensorflow/lite/tools/gen_op_registration.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/lite/tools/gen_op_registration.h" const char kInputModelFlag[] = "input_model"; +const char kNamespace[] = "namespace"; const char kOutputRegistrationFlag[] = "output_registration"; const char kTfLitePathFlag[] = "tflite_path"; @@ -33,25 +34,29 @@ using tensorflow::Flag; using tensorflow::Flags; using tensorflow::string; -void ParseFlagAndInit(int argc, char** argv, string* input_model, - string* output_registration, string* tflite_path) { +void ParseFlagAndInit(int* argc, char** argv, string* input_model, + string* output_registration, string* tflite_path, + string* namespace_flag) { std::vector flag_list = { Flag(kInputModelFlag, input_model, "path to the tflite model"), Flag(kOutputRegistrationFlag, output_registration, "filename for generated registration code"), Flag(kTfLitePathFlag, tflite_path, "Path to tensorflow lite dir"), + Flag(kNamespace, namespace_flag, + "Namespace in which to put RegisterSelectedOps."), }; - Flags::Parse(&argc, argv, flag_list); - tensorflow::port::InitMain(argv[0], &argc, &argv); + Flags::Parse(argc, argv, flag_list); + tensorflow::port::InitMain(argv[0], argc, &argv); } namespace { void GenerateFileContent(const std::string& tflite_path, const std::string& filename, - const std::vector& builtin_ops, - const std::vector& custom_ops) { + const std::string& namespace_flag, + const tflite::RegisteredOpMap& builtin_ops, + const tflite::RegisteredOpMap& custom_ops) { std::ofstream fout(filename); fout << "#include \"" << tflite_path << "/model.h\"\n"; @@ -63,7 +68,7 @@ void GenerateFileContent(const std::string& tflite_path, fout << "namespace builtin {\n"; fout << "// Forward-declarations for the builtin ops.\n"; for (const auto& op : builtin_ops) { - fout << "TfLiteRegistration* Register_" << op << "();\n"; + fout << "TfLiteRegistration* Register_" << op.first << "();\n"; } fout << "} // namespace builtin\n"; } @@ -73,45 +78,72 @@ void GenerateFileContent(const std::string& tflite_path, fout << "// Forward-declarations for the custom ops.\n"; for (const auto& op : custom_ops) { fout << "TfLiteRegistration* Register_" - << ::tflite::NormalizeCustomOpName(op) << "();\n"; + << ::tflite::NormalizeCustomOpName(op.first) << "();\n"; } fout << "} // namespace custom\n"; } fout << "} // namespace ops\n"; fout << "} // namespace tflite\n"; + if (!namespace_flag.empty()) { + fout << "namespace " << namespace_flag << " {\n"; + } fout << "void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {\n"; for (const auto& op : builtin_ops) { - fout << " resolver->AddBuiltin(::tflite::BuiltinOperator_" << op - << ", ::tflite::ops::builtin::Register_" << op << "());\n"; + fout << " resolver->AddBuiltin(::tflite::BuiltinOperator_" << op.first + << ", ::tflite::ops::builtin::Register_" << op.first << "()"; + if (op.second.first != 1 || op.second.second != 1) { + fout << ", " << op.second.first << ", " << op.second.second; + } + fout << ");\n"; } for (const auto& op : custom_ops) { - fout << " resolver->AddCustom(\"" << op + fout << " resolver->AddCustom(\"" << op.first << "\", ::tflite::ops::custom::Register_" - << ::tflite::NormalizeCustomOpName(op) << "());\n"; + << ::tflite::NormalizeCustomOpName(op.first) << "()"; + if (op.second.first != 1 || op.second.second != 1) { + fout << ", " << op.second.first << ", " << op.second.second; + } + fout << ");\n"; } fout << "}\n"; + if (!namespace_flag.empty()) { + fout << "} // namespace " << namespace_flag << "\n"; + } fout.close(); } -} // namespace -int main(int argc, char** argv) { - string input_model; - string output_registration; - string tflite_path; - ParseFlagAndInit(argc, argv, &input_model, &output_registration, - &tflite_path); - - std::vector builtin_ops; - std::vector custom_ops; +void AddOpsFromModel(const string& input_model, + tflite::RegisteredOpMap* builtin_ops, + tflite::RegisteredOpMap* custom_ops) { std::ifstream fin(input_model); std::stringstream content; content << fin.rdbuf(); // Need to store content data first, otherwise, it won't work in bazel. string content_str = content.str(); const ::tflite::Model* model = ::tflite::GetModel(content_str.data()); - ::tflite::ReadOpsFromModel(model, &builtin_ops, &custom_ops); - GenerateFileContent(tflite_path, output_registration, builtin_ops, - custom_ops); + ::tflite::ReadOpsFromModel(model, builtin_ops, custom_ops); +} + +} // namespace + +int main(int argc, char** argv) { + string input_model; + string output_registration; + string tflite_path; + string namespace_flag; + ParseFlagAndInit(&argc, argv, &input_model, &output_registration, + &tflite_path, &namespace_flag); + + tflite::RegisteredOpMap builtin_ops; + tflite::RegisteredOpMap custom_ops; + if (!input_model.empty()) { + AddOpsFromModel(input_model, &builtin_ops, &custom_ops); + } + for (int i = 1; i < argc; i++) { + AddOpsFromModel(argv[i], &builtin_ops, &custom_ops); + } + GenerateFileContent(tflite_path, output_registration, namespace_flag, + builtin_ops, custom_ops); return 0; } diff --git a/tensorflow/lite/tools/gen_op_registration_test.cc b/tensorflow/lite/tools/gen_op_registration_test.cc index 0ae91018ddf..e572d28d2e1 100644 --- a/tensorflow/lite/tools/gen_op_registration_test.cc +++ b/tensorflow/lite/tools/gen_op_registration_test.cc @@ -32,8 +32,8 @@ class GenOpRegistrationTest : public ::testing::Test { } } - std::vector builtin_ops_; - std::vector custom_ops_; + std::map> builtin_ops_; + std::map> custom_ops_; }; TEST_F(GenOpRegistrationTest, TestNonExistantFiles) { @@ -44,8 +44,27 @@ TEST_F(GenOpRegistrationTest, TestNonExistantFiles) { TEST_F(GenOpRegistrationTest, TestModels) { ReadOps("tensorflow/lite/testdata/test_model.bin"); - EXPECT_THAT(builtin_ops_, ElementsAreArray({"CONV_2D"})); - EXPECT_THAT(custom_ops_, ElementsAreArray({"testing_op"})); + RegisteredOpMap builtin_expected{{"CONV_2D", {1, 1}}}; + RegisteredOpMap custom_expected{{"testing_op", {1, 1}}}; + EXPECT_THAT(builtin_ops_, ElementsAreArray(builtin_expected)); + EXPECT_THAT(custom_ops_, ElementsAreArray(custom_expected)); +} + +TEST_F(GenOpRegistrationTest, TestVersionedModels) { + ReadOps("tensorflow/lite/testdata/test_model_versioned_ops.bin"); + RegisteredOpMap builtin_expected{{"CONV_2D", {3, 3}}}; + RegisteredOpMap custom_expected{{"testing_op", {2, 2}}}; + EXPECT_THAT(builtin_ops_, ElementsAreArray(builtin_expected)); + EXPECT_THAT(custom_ops_, ElementsAreArray(custom_expected)); +} + +TEST_F(GenOpRegistrationTest, TestBothModels) { + ReadOps("tensorflow/lite/testdata/test_model.bin"); + ReadOps("tensorflow/lite/testdata/test_model_versioned_ops.bin"); + RegisteredOpMap builtin_expected{{"CONV_2D", {1, 3}}}; + RegisteredOpMap custom_expected{{"testing_op", {1, 2}}}; + EXPECT_THAT(builtin_ops_, ElementsAreArray(builtin_expected)); + EXPECT_THAT(custom_ops_, ElementsAreArray(custom_expected)); } TEST_F(GenOpRegistrationTest, TestEmptyModels) {