Added features to gen_selected_ops.

- Allow to  put RegisterSelectedOps in a namespace.
- Support versioned ops.
- Support multiple input tflite models.

PiperOrigin-RevId: 270053948
This commit is contained in:
A. Unique TensorFlower 2019-09-19 09:09:21 -07:00 committed by TensorFlower Gardener
parent dfef3d744b
commit 0228bf0854
7 changed files with 114 additions and 43 deletions

View File

@ -608,23 +608,30 @@ def gen_zipped_test_file(name, file, toco, flags):
srcs = [file],
)
def gen_selected_ops(name, model):
def gen_selected_ops(name, model, namespace = "", **kwargs):
"""Generate the library that includes only used ops.
Args:
name: Name of the generated library.
model: TFLite model to interpret.
namespace: Namespace in which to put RegisterSelectedOps.
**kwargs: Additional kwargs to pass to genrule.
"""
out = name + "_registration.cc"
tool = "//tensorflow/lite/tools:generate_op_registrations"
tflite_path = "//tensorflow/lite"
# isinstance is not supported in skylark.
if type(model) != type([]):
model = [model]
native.genrule(
name = name,
srcs = [model],
srcs = model,
outs = [out],
cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") %
(tool, model, out, tflite_path[2:]),
cmd = ("$(location %s) --namespace=%s --output_registration=$(location %s) --tflite_path=%s $(SRCS)") %
(tool, namespace, out, tflite_path[2:]),
tools = [tool],
**kwargs
)
def flex_dep(target_op_sets):

Binary file not shown.

View File

@ -54,6 +54,7 @@ cc_test(
"//tensorflow/lite:testdata/empty_model.bin",
"//tensorflow/lite:testdata/test_model.bin",
"//tensorflow/lite:testdata/test_model_broken.bin",
"//tensorflow/lite:testdata/test_model_versioned_ops.bin",
],
tags = [
"tflite_not_portable_android",

View File

@ -29,17 +29,26 @@ string NormalizeCustomOpName(const string& op) {
}
void ReadOpsFromModel(const ::tflite::Model* model,
std::vector<string>* builtin_ops,
std::vector<string>* custom_ops) {
tflite::RegisteredOpMap* builtin_ops,
tflite::RegisteredOpMap* custom_ops) {
if (!model) return;
auto opcodes = model->operator_codes();
if (!opcodes) return;
for (const auto* opcode : *opcodes) {
const int version = opcode->version();
if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
builtin_ops->push_back(
tflite::EnumNameBuiltinOperator(opcode->builtin_code()));
auto iter_and_bool = builtin_ops->insert(std::make_pair(
tflite::EnumNameBuiltinOperator(opcode->builtin_code()),
std::make_pair(version, version)));
auto& versions = iter_and_bool.first->second;
versions.first = std::min(versions.first, version);
versions.second = std::max(versions.second, version);
} else {
custom_ops->push_back(opcode->custom_code()->c_str());
auto iter_and_bool = custom_ops->insert(std::make_pair(
opcode->custom_code()->c_str(), std::make_pair(version, version)));
auto& versions = iter_and_bool.first->second;
versions.first = std::min(versions.first, version);
versions.second = std::max(versions.second, version);
}
}
}

View File

@ -27,12 +27,15 @@ namespace tflite {
// Note "Register_" suffix will be added later in the tool.
string NormalizeCustomOpName(const string& op);
// A map from op name to {min_version, max_version}.
typedef std::map<string, std::pair<int, int>> RegisteredOpMap;
// Read ops from the TFLite model.
// Enum name of builtin ops will be stored, such as "CONV_2D".
// Custom op name will be stored as it is.
// The builtin ops key is the enum name of builtin ops, such as "CONV_2D".
// The custom ops key is stored as it is.
void ReadOpsFromModel(const ::tflite::Model* model,
std::vector<string>* builtin_ops,
std::vector<string>* custom_ops);
RegisteredOpMap* builtin_ops,
RegisteredOpMap* custom_ops);
} // namespace tflite

View File

@ -21,11 +21,12 @@ limitations under the License.
#include <vector>
#include "absl/strings/strip.h"
#include "tensorflow/lite/tools/gen_op_registration.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/tools/gen_op_registration.h"
const char kInputModelFlag[] = "input_model";
const char kNamespace[] = "namespace";
const char kOutputRegistrationFlag[] = "output_registration";
const char kTfLitePathFlag[] = "tflite_path";
@ -33,25 +34,29 @@ using tensorflow::Flag;
using tensorflow::Flags;
using tensorflow::string;
void ParseFlagAndInit(int argc, char** argv, string* input_model,
string* output_registration, string* tflite_path) {
void ParseFlagAndInit(int* argc, char** argv, string* input_model,
string* output_registration, string* tflite_path,
string* namespace_flag) {
std::vector<tensorflow::Flag> flag_list = {
Flag(kInputModelFlag, input_model, "path to the tflite model"),
Flag(kOutputRegistrationFlag, output_registration,
"filename for generated registration code"),
Flag(kTfLitePathFlag, tflite_path, "Path to tensorflow lite dir"),
Flag(kNamespace, namespace_flag,
"Namespace in which to put RegisterSelectedOps."),
};
Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(argv[0], &argc, &argv);
Flags::Parse(argc, argv, flag_list);
tensorflow::port::InitMain(argv[0], argc, &argv);
}
namespace {
void GenerateFileContent(const std::string& tflite_path,
const std::string& filename,
const std::vector<string>& builtin_ops,
const std::vector<string>& custom_ops) {
const std::string& namespace_flag,
const tflite::RegisteredOpMap& builtin_ops,
const tflite::RegisteredOpMap& custom_ops) {
std::ofstream fout(filename);
fout << "#include \"" << tflite_path << "/model.h\"\n";
@ -63,7 +68,7 @@ void GenerateFileContent(const std::string& tflite_path,
fout << "namespace builtin {\n";
fout << "// Forward-declarations for the builtin ops.\n";
for (const auto& op : builtin_ops) {
fout << "TfLiteRegistration* Register_" << op << "();\n";
fout << "TfLiteRegistration* Register_" << op.first << "();\n";
}
fout << "} // namespace builtin\n";
}
@ -73,45 +78,72 @@ void GenerateFileContent(const std::string& tflite_path,
fout << "// Forward-declarations for the custom ops.\n";
for (const auto& op : custom_ops) {
fout << "TfLiteRegistration* Register_"
<< ::tflite::NormalizeCustomOpName(op) << "();\n";
<< ::tflite::NormalizeCustomOpName(op.first) << "();\n";
}
fout << "} // namespace custom\n";
}
fout << "} // namespace ops\n";
fout << "} // namespace tflite\n";
if (!namespace_flag.empty()) {
fout << "namespace " << namespace_flag << " {\n";
}
fout << "void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {\n";
for (const auto& op : builtin_ops) {
fout << " resolver->AddBuiltin(::tflite::BuiltinOperator_" << op
<< ", ::tflite::ops::builtin::Register_" << op << "());\n";
fout << " resolver->AddBuiltin(::tflite::BuiltinOperator_" << op.first
<< ", ::tflite::ops::builtin::Register_" << op.first << "()";
if (op.second.first != 1 || op.second.second != 1) {
fout << ", " << op.second.first << ", " << op.second.second;
}
fout << ");\n";
}
for (const auto& op : custom_ops) {
fout << " resolver->AddCustom(\"" << op
fout << " resolver->AddCustom(\"" << op.first
<< "\", ::tflite::ops::custom::Register_"
<< ::tflite::NormalizeCustomOpName(op) << "());\n";
<< ::tflite::NormalizeCustomOpName(op.first) << "()";
if (op.second.first != 1 || op.second.second != 1) {
fout << ", " << op.second.first << ", " << op.second.second;
}
fout << ");\n";
}
fout << "}\n";
if (!namespace_flag.empty()) {
fout << "} // namespace " << namespace_flag << "\n";
}
fout.close();
}
} // namespace
int main(int argc, char** argv) {
string input_model;
string output_registration;
string tflite_path;
ParseFlagAndInit(argc, argv, &input_model, &output_registration,
&tflite_path);
std::vector<string> builtin_ops;
std::vector<string> custom_ops;
void AddOpsFromModel(const string& input_model,
tflite::RegisteredOpMap* builtin_ops,
tflite::RegisteredOpMap* custom_ops) {
std::ifstream fin(input_model);
std::stringstream content;
content << fin.rdbuf();
// Need to store content data first, otherwise, it won't work in bazel.
string content_str = content.str();
const ::tflite::Model* model = ::tflite::GetModel(content_str.data());
::tflite::ReadOpsFromModel(model, &builtin_ops, &custom_ops);
GenerateFileContent(tflite_path, output_registration, builtin_ops,
custom_ops);
::tflite::ReadOpsFromModel(model, builtin_ops, custom_ops);
}
} // namespace
int main(int argc, char** argv) {
string input_model;
string output_registration;
string tflite_path;
string namespace_flag;
ParseFlagAndInit(&argc, argv, &input_model, &output_registration,
&tflite_path, &namespace_flag);
tflite::RegisteredOpMap builtin_ops;
tflite::RegisteredOpMap custom_ops;
if (!input_model.empty()) {
AddOpsFromModel(input_model, &builtin_ops, &custom_ops);
}
for (int i = 1; i < argc; i++) {
AddOpsFromModel(argv[i], &builtin_ops, &custom_ops);
}
GenerateFileContent(tflite_path, output_registration, namespace_flag,
builtin_ops, custom_ops);
return 0;
}

View File

@ -32,8 +32,8 @@ class GenOpRegistrationTest : public ::testing::Test {
}
}
std::vector<string> builtin_ops_;
std::vector<string> custom_ops_;
std::map<string, std::pair<int, int>> builtin_ops_;
std::map<string, std::pair<int, int>> custom_ops_;
};
TEST_F(GenOpRegistrationTest, TestNonExistantFiles) {
@ -44,8 +44,27 @@ TEST_F(GenOpRegistrationTest, TestNonExistantFiles) {
TEST_F(GenOpRegistrationTest, TestModels) {
ReadOps("tensorflow/lite/testdata/test_model.bin");
EXPECT_THAT(builtin_ops_, ElementsAreArray({"CONV_2D"}));
EXPECT_THAT(custom_ops_, ElementsAreArray({"testing_op"}));
RegisteredOpMap builtin_expected{{"CONV_2D", {1, 1}}};
RegisteredOpMap custom_expected{{"testing_op", {1, 1}}};
EXPECT_THAT(builtin_ops_, ElementsAreArray(builtin_expected));
EXPECT_THAT(custom_ops_, ElementsAreArray(custom_expected));
}
TEST_F(GenOpRegistrationTest, TestVersionedModels) {
ReadOps("tensorflow/lite/testdata/test_model_versioned_ops.bin");
RegisteredOpMap builtin_expected{{"CONV_2D", {3, 3}}};
RegisteredOpMap custom_expected{{"testing_op", {2, 2}}};
EXPECT_THAT(builtin_ops_, ElementsAreArray(builtin_expected));
EXPECT_THAT(custom_ops_, ElementsAreArray(custom_expected));
}
TEST_F(GenOpRegistrationTest, TestBothModels) {
ReadOps("tensorflow/lite/testdata/test_model.bin");
ReadOps("tensorflow/lite/testdata/test_model_versioned_ops.bin");
RegisteredOpMap builtin_expected{{"CONV_2D", {1, 3}}};
RegisteredOpMap custom_expected{{"testing_op", {1, 2}}};
EXPECT_THAT(builtin_ops_, ElementsAreArray(builtin_expected));
EXPECT_THAT(custom_ops_, ElementsAreArray(custom_expected));
}
TEST_F(GenOpRegistrationTest, TestEmptyModels) {