Added features to gen_selected_ops.
- Allow to put RegisterSelectedOps in a namespace. - Support versioned ops. - Support multiple input tflite models. PiperOrigin-RevId: 270053948
This commit is contained in:
parent
dfef3d744b
commit
0228bf0854
@ -608,23 +608,30 @@ def gen_zipped_test_file(name, file, toco, flags):
|
||||
srcs = [file],
|
||||
)
|
||||
|
||||
def gen_selected_ops(name, model):
|
||||
def gen_selected_ops(name, model, namespace = "", **kwargs):
|
||||
"""Generate the library that includes only used ops.
|
||||
|
||||
Args:
|
||||
name: Name of the generated library.
|
||||
model: TFLite model to interpret.
|
||||
namespace: Namespace in which to put RegisterSelectedOps.
|
||||
**kwargs: Additional kwargs to pass to genrule.
|
||||
"""
|
||||
out = name + "_registration.cc"
|
||||
tool = "//tensorflow/lite/tools:generate_op_registrations"
|
||||
tflite_path = "//tensorflow/lite"
|
||||
|
||||
# isinstance is not supported in skylark.
|
||||
if type(model) != type([]):
|
||||
model = [model]
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = [model],
|
||||
srcs = model,
|
||||
outs = [out],
|
||||
cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") %
|
||||
(tool, model, out, tflite_path[2:]),
|
||||
cmd = ("$(location %s) --namespace=%s --output_registration=$(location %s) --tflite_path=%s $(SRCS)") %
|
||||
(tool, namespace, out, tflite_path[2:]),
|
||||
tools = [tool],
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def flex_dep(target_op_sets):
|
||||
|
BIN
tensorflow/lite/testdata/test_model_versioned_ops.bin
vendored
Normal file
BIN
tensorflow/lite/testdata/test_model_versioned_ops.bin
vendored
Normal file
Binary file not shown.
@ -54,6 +54,7 @@ cc_test(
|
||||
"//tensorflow/lite:testdata/empty_model.bin",
|
||||
"//tensorflow/lite:testdata/test_model.bin",
|
||||
"//tensorflow/lite:testdata/test_model_broken.bin",
|
||||
"//tensorflow/lite:testdata/test_model_versioned_ops.bin",
|
||||
],
|
||||
tags = [
|
||||
"tflite_not_portable_android",
|
||||
|
@ -29,17 +29,26 @@ string NormalizeCustomOpName(const string& op) {
|
||||
}
|
||||
|
||||
void ReadOpsFromModel(const ::tflite::Model* model,
|
||||
std::vector<string>* builtin_ops,
|
||||
std::vector<string>* custom_ops) {
|
||||
tflite::RegisteredOpMap* builtin_ops,
|
||||
tflite::RegisteredOpMap* custom_ops) {
|
||||
if (!model) return;
|
||||
auto opcodes = model->operator_codes();
|
||||
if (!opcodes) return;
|
||||
for (const auto* opcode : *opcodes) {
|
||||
const int version = opcode->version();
|
||||
if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
|
||||
builtin_ops->push_back(
|
||||
tflite::EnumNameBuiltinOperator(opcode->builtin_code()));
|
||||
auto iter_and_bool = builtin_ops->insert(std::make_pair(
|
||||
tflite::EnumNameBuiltinOperator(opcode->builtin_code()),
|
||||
std::make_pair(version, version)));
|
||||
auto& versions = iter_and_bool.first->second;
|
||||
versions.first = std::min(versions.first, version);
|
||||
versions.second = std::max(versions.second, version);
|
||||
} else {
|
||||
custom_ops->push_back(opcode->custom_code()->c_str());
|
||||
auto iter_and_bool = custom_ops->insert(std::make_pair(
|
||||
opcode->custom_code()->c_str(), std::make_pair(version, version)));
|
||||
auto& versions = iter_and_bool.first->second;
|
||||
versions.first = std::min(versions.first, version);
|
||||
versions.second = std::max(versions.second, version);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -27,12 +27,15 @@ namespace tflite {
|
||||
// Note "Register_" suffix will be added later in the tool.
|
||||
string NormalizeCustomOpName(const string& op);
|
||||
|
||||
// A map from op name to {min_version, max_version}.
|
||||
typedef std::map<string, std::pair<int, int>> RegisteredOpMap;
|
||||
|
||||
// Read ops from the TFLite model.
|
||||
// Enum name of builtin ops will be stored, such as "CONV_2D".
|
||||
// Custom op name will be stored as it is.
|
||||
// The builtin ops key is the enum name of builtin ops, such as "CONV_2D".
|
||||
// The custom ops key is stored as it is.
|
||||
void ReadOpsFromModel(const ::tflite::Model* model,
|
||||
std::vector<string>* builtin_ops,
|
||||
std::vector<string>* custom_ops);
|
||||
RegisteredOpMap* builtin_ops,
|
||||
RegisteredOpMap* custom_ops);
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -21,11 +21,12 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/strip.h"
|
||||
#include "tensorflow/lite/tools/gen_op_registration.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/lite/tools/gen_op_registration.h"
|
||||
|
||||
const char kInputModelFlag[] = "input_model";
|
||||
const char kNamespace[] = "namespace";
|
||||
const char kOutputRegistrationFlag[] = "output_registration";
|
||||
const char kTfLitePathFlag[] = "tflite_path";
|
||||
|
||||
@ -33,25 +34,29 @@ using tensorflow::Flag;
|
||||
using tensorflow::Flags;
|
||||
using tensorflow::string;
|
||||
|
||||
void ParseFlagAndInit(int argc, char** argv, string* input_model,
|
||||
string* output_registration, string* tflite_path) {
|
||||
void ParseFlagAndInit(int* argc, char** argv, string* input_model,
|
||||
string* output_registration, string* tflite_path,
|
||||
string* namespace_flag) {
|
||||
std::vector<tensorflow::Flag> flag_list = {
|
||||
Flag(kInputModelFlag, input_model, "path to the tflite model"),
|
||||
Flag(kOutputRegistrationFlag, output_registration,
|
||||
"filename for generated registration code"),
|
||||
Flag(kTfLitePathFlag, tflite_path, "Path to tensorflow lite dir"),
|
||||
Flag(kNamespace, namespace_flag,
|
||||
"Namespace in which to put RegisterSelectedOps."),
|
||||
};
|
||||
|
||||
Flags::Parse(&argc, argv, flag_list);
|
||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
Flags::Parse(argc, argv, flag_list);
|
||||
tensorflow::port::InitMain(argv[0], argc, &argv);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void GenerateFileContent(const std::string& tflite_path,
|
||||
const std::string& filename,
|
||||
const std::vector<string>& builtin_ops,
|
||||
const std::vector<string>& custom_ops) {
|
||||
const std::string& namespace_flag,
|
||||
const tflite::RegisteredOpMap& builtin_ops,
|
||||
const tflite::RegisteredOpMap& custom_ops) {
|
||||
std::ofstream fout(filename);
|
||||
|
||||
fout << "#include \"" << tflite_path << "/model.h\"\n";
|
||||
@ -63,7 +68,7 @@ void GenerateFileContent(const std::string& tflite_path,
|
||||
fout << "namespace builtin {\n";
|
||||
fout << "// Forward-declarations for the builtin ops.\n";
|
||||
for (const auto& op : builtin_ops) {
|
||||
fout << "TfLiteRegistration* Register_" << op << "();\n";
|
||||
fout << "TfLiteRegistration* Register_" << op.first << "();\n";
|
||||
}
|
||||
fout << "} // namespace builtin\n";
|
||||
}
|
||||
@ -73,45 +78,72 @@ void GenerateFileContent(const std::string& tflite_path,
|
||||
fout << "// Forward-declarations for the custom ops.\n";
|
||||
for (const auto& op : custom_ops) {
|
||||
fout << "TfLiteRegistration* Register_"
|
||||
<< ::tflite::NormalizeCustomOpName(op) << "();\n";
|
||||
<< ::tflite::NormalizeCustomOpName(op.first) << "();\n";
|
||||
}
|
||||
fout << "} // namespace custom\n";
|
||||
}
|
||||
fout << "} // namespace ops\n";
|
||||
fout << "} // namespace tflite\n";
|
||||
|
||||
if (!namespace_flag.empty()) {
|
||||
fout << "namespace " << namespace_flag << " {\n";
|
||||
}
|
||||
fout << "void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {\n";
|
||||
for (const auto& op : builtin_ops) {
|
||||
fout << " resolver->AddBuiltin(::tflite::BuiltinOperator_" << op
|
||||
<< ", ::tflite::ops::builtin::Register_" << op << "());\n";
|
||||
fout << " resolver->AddBuiltin(::tflite::BuiltinOperator_" << op.first
|
||||
<< ", ::tflite::ops::builtin::Register_" << op.first << "()";
|
||||
if (op.second.first != 1 || op.second.second != 1) {
|
||||
fout << ", " << op.second.first << ", " << op.second.second;
|
||||
}
|
||||
fout << ");\n";
|
||||
}
|
||||
for (const auto& op : custom_ops) {
|
||||
fout << " resolver->AddCustom(\"" << op
|
||||
fout << " resolver->AddCustom(\"" << op.first
|
||||
<< "\", ::tflite::ops::custom::Register_"
|
||||
<< ::tflite::NormalizeCustomOpName(op) << "());\n";
|
||||
<< ::tflite::NormalizeCustomOpName(op.first) << "()";
|
||||
if (op.second.first != 1 || op.second.second != 1) {
|
||||
fout << ", " << op.second.first << ", " << op.second.second;
|
||||
}
|
||||
fout << ");\n";
|
||||
}
|
||||
fout << "}\n";
|
||||
if (!namespace_flag.empty()) {
|
||||
fout << "} // namespace " << namespace_flag << "\n";
|
||||
}
|
||||
fout.close();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
string input_model;
|
||||
string output_registration;
|
||||
string tflite_path;
|
||||
ParseFlagAndInit(argc, argv, &input_model, &output_registration,
|
||||
&tflite_path);
|
||||
|
||||
std::vector<string> builtin_ops;
|
||||
std::vector<string> custom_ops;
|
||||
void AddOpsFromModel(const string& input_model,
|
||||
tflite::RegisteredOpMap* builtin_ops,
|
||||
tflite::RegisteredOpMap* custom_ops) {
|
||||
std::ifstream fin(input_model);
|
||||
std::stringstream content;
|
||||
content << fin.rdbuf();
|
||||
// Need to store content data first, otherwise, it won't work in bazel.
|
||||
string content_str = content.str();
|
||||
const ::tflite::Model* model = ::tflite::GetModel(content_str.data());
|
||||
::tflite::ReadOpsFromModel(model, &builtin_ops, &custom_ops);
|
||||
GenerateFileContent(tflite_path, output_registration, builtin_ops,
|
||||
custom_ops);
|
||||
::tflite::ReadOpsFromModel(model, builtin_ops, custom_ops);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
string input_model;
|
||||
string output_registration;
|
||||
string tflite_path;
|
||||
string namespace_flag;
|
||||
ParseFlagAndInit(&argc, argv, &input_model, &output_registration,
|
||||
&tflite_path, &namespace_flag);
|
||||
|
||||
tflite::RegisteredOpMap builtin_ops;
|
||||
tflite::RegisteredOpMap custom_ops;
|
||||
if (!input_model.empty()) {
|
||||
AddOpsFromModel(input_model, &builtin_ops, &custom_ops);
|
||||
}
|
||||
for (int i = 1; i < argc; i++) {
|
||||
AddOpsFromModel(argv[i], &builtin_ops, &custom_ops);
|
||||
}
|
||||
GenerateFileContent(tflite_path, output_registration, namespace_flag,
|
||||
builtin_ops, custom_ops);
|
||||
return 0;
|
||||
}
|
||||
|
@ -32,8 +32,8 @@ class GenOpRegistrationTest : public ::testing::Test {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<string> builtin_ops_;
|
||||
std::vector<string> custom_ops_;
|
||||
std::map<string, std::pair<int, int>> builtin_ops_;
|
||||
std::map<string, std::pair<int, int>> custom_ops_;
|
||||
};
|
||||
|
||||
TEST_F(GenOpRegistrationTest, TestNonExistantFiles) {
|
||||
@ -44,8 +44,27 @@ TEST_F(GenOpRegistrationTest, TestNonExistantFiles) {
|
||||
|
||||
TEST_F(GenOpRegistrationTest, TestModels) {
|
||||
ReadOps("tensorflow/lite/testdata/test_model.bin");
|
||||
EXPECT_THAT(builtin_ops_, ElementsAreArray({"CONV_2D"}));
|
||||
EXPECT_THAT(custom_ops_, ElementsAreArray({"testing_op"}));
|
||||
RegisteredOpMap builtin_expected{{"CONV_2D", {1, 1}}};
|
||||
RegisteredOpMap custom_expected{{"testing_op", {1, 1}}};
|
||||
EXPECT_THAT(builtin_ops_, ElementsAreArray(builtin_expected));
|
||||
EXPECT_THAT(custom_ops_, ElementsAreArray(custom_expected));
|
||||
}
|
||||
|
||||
TEST_F(GenOpRegistrationTest, TestVersionedModels) {
|
||||
ReadOps("tensorflow/lite/testdata/test_model_versioned_ops.bin");
|
||||
RegisteredOpMap builtin_expected{{"CONV_2D", {3, 3}}};
|
||||
RegisteredOpMap custom_expected{{"testing_op", {2, 2}}};
|
||||
EXPECT_THAT(builtin_ops_, ElementsAreArray(builtin_expected));
|
||||
EXPECT_THAT(custom_ops_, ElementsAreArray(custom_expected));
|
||||
}
|
||||
|
||||
TEST_F(GenOpRegistrationTest, TestBothModels) {
|
||||
ReadOps("tensorflow/lite/testdata/test_model.bin");
|
||||
ReadOps("tensorflow/lite/testdata/test_model_versioned_ops.bin");
|
||||
RegisteredOpMap builtin_expected{{"CONV_2D", {1, 3}}};
|
||||
RegisteredOpMap custom_expected{{"testing_op", {1, 2}}};
|
||||
EXPECT_THAT(builtin_ops_, ElementsAreArray(builtin_expected));
|
||||
EXPECT_THAT(custom_ops_, ElementsAreArray(custom_expected));
|
||||
}
|
||||
|
||||
TEST_F(GenOpRegistrationTest, TestEmptyModels) {
|
||||
|
Loading…
Reference in New Issue
Block a user