Added features to gen_selected_ops.
- Allow to put RegisterSelectedOps in a namespace. - Support versioned ops. - Support multiple input tflite models. PiperOrigin-RevId: 270053948
This commit is contained in:
parent
dfef3d744b
commit
0228bf0854
@ -608,23 +608,30 @@ def gen_zipped_test_file(name, file, toco, flags):
|
|||||||
srcs = [file],
|
srcs = [file],
|
||||||
)
|
)
|
||||||
|
|
||||||
def gen_selected_ops(name, model):
|
def gen_selected_ops(name, model, namespace = "", **kwargs):
|
||||||
"""Generate the library that includes only used ops.
|
"""Generate the library that includes only used ops.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Name of the generated library.
|
name: Name of the generated library.
|
||||||
model: TFLite model to interpret.
|
model: TFLite model to interpret.
|
||||||
|
namespace: Namespace in which to put RegisterSelectedOps.
|
||||||
|
**kwargs: Additional kwargs to pass to genrule.
|
||||||
"""
|
"""
|
||||||
out = name + "_registration.cc"
|
out = name + "_registration.cc"
|
||||||
tool = "//tensorflow/lite/tools:generate_op_registrations"
|
tool = "//tensorflow/lite/tools:generate_op_registrations"
|
||||||
tflite_path = "//tensorflow/lite"
|
tflite_path = "//tensorflow/lite"
|
||||||
|
|
||||||
|
# isinstance is not supported in skylark.
|
||||||
|
if type(model) != type([]):
|
||||||
|
model = [model]
|
||||||
native.genrule(
|
native.genrule(
|
||||||
name = name,
|
name = name,
|
||||||
srcs = [model],
|
srcs = model,
|
||||||
outs = [out],
|
outs = [out],
|
||||||
cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") %
|
cmd = ("$(location %s) --namespace=%s --output_registration=$(location %s) --tflite_path=%s $(SRCS)") %
|
||||||
(tool, model, out, tflite_path[2:]),
|
(tool, namespace, out, tflite_path[2:]),
|
||||||
tools = [tool],
|
tools = [tool],
|
||||||
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
def flex_dep(target_op_sets):
|
def flex_dep(target_op_sets):
|
||||||
|
BIN
tensorflow/lite/testdata/test_model_versioned_ops.bin
vendored
Normal file
BIN
tensorflow/lite/testdata/test_model_versioned_ops.bin
vendored
Normal file
Binary file not shown.
@ -54,6 +54,7 @@ cc_test(
|
|||||||
"//tensorflow/lite:testdata/empty_model.bin",
|
"//tensorflow/lite:testdata/empty_model.bin",
|
||||||
"//tensorflow/lite:testdata/test_model.bin",
|
"//tensorflow/lite:testdata/test_model.bin",
|
||||||
"//tensorflow/lite:testdata/test_model_broken.bin",
|
"//tensorflow/lite:testdata/test_model_broken.bin",
|
||||||
|
"//tensorflow/lite:testdata/test_model_versioned_ops.bin",
|
||||||
],
|
],
|
||||||
tags = [
|
tags = [
|
||||||
"tflite_not_portable_android",
|
"tflite_not_portable_android",
|
||||||
|
@ -29,17 +29,26 @@ string NormalizeCustomOpName(const string& op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ReadOpsFromModel(const ::tflite::Model* model,
|
void ReadOpsFromModel(const ::tflite::Model* model,
|
||||||
std::vector<string>* builtin_ops,
|
tflite::RegisteredOpMap* builtin_ops,
|
||||||
std::vector<string>* custom_ops) {
|
tflite::RegisteredOpMap* custom_ops) {
|
||||||
if (!model) return;
|
if (!model) return;
|
||||||
auto opcodes = model->operator_codes();
|
auto opcodes = model->operator_codes();
|
||||||
if (!opcodes) return;
|
if (!opcodes) return;
|
||||||
for (const auto* opcode : *opcodes) {
|
for (const auto* opcode : *opcodes) {
|
||||||
|
const int version = opcode->version();
|
||||||
if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
|
if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) {
|
||||||
builtin_ops->push_back(
|
auto iter_and_bool = builtin_ops->insert(std::make_pair(
|
||||||
tflite::EnumNameBuiltinOperator(opcode->builtin_code()));
|
tflite::EnumNameBuiltinOperator(opcode->builtin_code()),
|
||||||
|
std::make_pair(version, version)));
|
||||||
|
auto& versions = iter_and_bool.first->second;
|
||||||
|
versions.first = std::min(versions.first, version);
|
||||||
|
versions.second = std::max(versions.second, version);
|
||||||
} else {
|
} else {
|
||||||
custom_ops->push_back(opcode->custom_code()->c_str());
|
auto iter_and_bool = custom_ops->insert(std::make_pair(
|
||||||
|
opcode->custom_code()->c_str(), std::make_pair(version, version)));
|
||||||
|
auto& versions = iter_and_bool.first->second;
|
||||||
|
versions.first = std::min(versions.first, version);
|
||||||
|
versions.second = std::max(versions.second, version);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -27,12 +27,15 @@ namespace tflite {
|
|||||||
// Note "Register_" suffix will be added later in the tool.
|
// Note "Register_" suffix will be added later in the tool.
|
||||||
string NormalizeCustomOpName(const string& op);
|
string NormalizeCustomOpName(const string& op);
|
||||||
|
|
||||||
|
// A map from op name to {min_version, max_version}.
|
||||||
|
typedef std::map<string, std::pair<int, int>> RegisteredOpMap;
|
||||||
|
|
||||||
// Read ops from the TFLite model.
|
// Read ops from the TFLite model.
|
||||||
// Enum name of builtin ops will be stored, such as "CONV_2D".
|
// The builtin ops key is the enum name of builtin ops, such as "CONV_2D".
|
||||||
// Custom op name will be stored as it is.
|
// The custom ops key is stored as it is.
|
||||||
void ReadOpsFromModel(const ::tflite::Model* model,
|
void ReadOpsFromModel(const ::tflite::Model* model,
|
||||||
std::vector<string>* builtin_ops,
|
RegisteredOpMap* builtin_ops,
|
||||||
std::vector<string>* custom_ops);
|
RegisteredOpMap* custom_ops);
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
@ -21,11 +21,12 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/strings/strip.h"
|
#include "absl/strings/strip.h"
|
||||||
#include "tensorflow/lite/tools/gen_op_registration.h"
|
|
||||||
#include "tensorflow/core/platform/init_main.h"
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
#include "tensorflow/core/util/command_line_flags.h"
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/lite/tools/gen_op_registration.h"
|
||||||
|
|
||||||
const char kInputModelFlag[] = "input_model";
|
const char kInputModelFlag[] = "input_model";
|
||||||
|
const char kNamespace[] = "namespace";
|
||||||
const char kOutputRegistrationFlag[] = "output_registration";
|
const char kOutputRegistrationFlag[] = "output_registration";
|
||||||
const char kTfLitePathFlag[] = "tflite_path";
|
const char kTfLitePathFlag[] = "tflite_path";
|
||||||
|
|
||||||
@ -33,25 +34,29 @@ using tensorflow::Flag;
|
|||||||
using tensorflow::Flags;
|
using tensorflow::Flags;
|
||||||
using tensorflow::string;
|
using tensorflow::string;
|
||||||
|
|
||||||
void ParseFlagAndInit(int argc, char** argv, string* input_model,
|
void ParseFlagAndInit(int* argc, char** argv, string* input_model,
|
||||||
string* output_registration, string* tflite_path) {
|
string* output_registration, string* tflite_path,
|
||||||
|
string* namespace_flag) {
|
||||||
std::vector<tensorflow::Flag> flag_list = {
|
std::vector<tensorflow::Flag> flag_list = {
|
||||||
Flag(kInputModelFlag, input_model, "path to the tflite model"),
|
Flag(kInputModelFlag, input_model, "path to the tflite model"),
|
||||||
Flag(kOutputRegistrationFlag, output_registration,
|
Flag(kOutputRegistrationFlag, output_registration,
|
||||||
"filename for generated registration code"),
|
"filename for generated registration code"),
|
||||||
Flag(kTfLitePathFlag, tflite_path, "Path to tensorflow lite dir"),
|
Flag(kTfLitePathFlag, tflite_path, "Path to tensorflow lite dir"),
|
||||||
|
Flag(kNamespace, namespace_flag,
|
||||||
|
"Namespace in which to put RegisterSelectedOps."),
|
||||||
};
|
};
|
||||||
|
|
||||||
Flags::Parse(&argc, argv, flag_list);
|
Flags::Parse(argc, argv, flag_list);
|
||||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
tensorflow::port::InitMain(argv[0], argc, &argv);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void GenerateFileContent(const std::string& tflite_path,
|
void GenerateFileContent(const std::string& tflite_path,
|
||||||
const std::string& filename,
|
const std::string& filename,
|
||||||
const std::vector<string>& builtin_ops,
|
const std::string& namespace_flag,
|
||||||
const std::vector<string>& custom_ops) {
|
const tflite::RegisteredOpMap& builtin_ops,
|
||||||
|
const tflite::RegisteredOpMap& custom_ops) {
|
||||||
std::ofstream fout(filename);
|
std::ofstream fout(filename);
|
||||||
|
|
||||||
fout << "#include \"" << tflite_path << "/model.h\"\n";
|
fout << "#include \"" << tflite_path << "/model.h\"\n";
|
||||||
@ -63,7 +68,7 @@ void GenerateFileContent(const std::string& tflite_path,
|
|||||||
fout << "namespace builtin {\n";
|
fout << "namespace builtin {\n";
|
||||||
fout << "// Forward-declarations for the builtin ops.\n";
|
fout << "// Forward-declarations for the builtin ops.\n";
|
||||||
for (const auto& op : builtin_ops) {
|
for (const auto& op : builtin_ops) {
|
||||||
fout << "TfLiteRegistration* Register_" << op << "();\n";
|
fout << "TfLiteRegistration* Register_" << op.first << "();\n";
|
||||||
}
|
}
|
||||||
fout << "} // namespace builtin\n";
|
fout << "} // namespace builtin\n";
|
||||||
}
|
}
|
||||||
@ -73,45 +78,72 @@ void GenerateFileContent(const std::string& tflite_path,
|
|||||||
fout << "// Forward-declarations for the custom ops.\n";
|
fout << "// Forward-declarations for the custom ops.\n";
|
||||||
for (const auto& op : custom_ops) {
|
for (const auto& op : custom_ops) {
|
||||||
fout << "TfLiteRegistration* Register_"
|
fout << "TfLiteRegistration* Register_"
|
||||||
<< ::tflite::NormalizeCustomOpName(op) << "();\n";
|
<< ::tflite::NormalizeCustomOpName(op.first) << "();\n";
|
||||||
}
|
}
|
||||||
fout << "} // namespace custom\n";
|
fout << "} // namespace custom\n";
|
||||||
}
|
}
|
||||||
fout << "} // namespace ops\n";
|
fout << "} // namespace ops\n";
|
||||||
fout << "} // namespace tflite\n";
|
fout << "} // namespace tflite\n";
|
||||||
|
|
||||||
|
if (!namespace_flag.empty()) {
|
||||||
|
fout << "namespace " << namespace_flag << " {\n";
|
||||||
|
}
|
||||||
fout << "void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {\n";
|
fout << "void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {\n";
|
||||||
for (const auto& op : builtin_ops) {
|
for (const auto& op : builtin_ops) {
|
||||||
fout << " resolver->AddBuiltin(::tflite::BuiltinOperator_" << op
|
fout << " resolver->AddBuiltin(::tflite::BuiltinOperator_" << op.first
|
||||||
<< ", ::tflite::ops::builtin::Register_" << op << "());\n";
|
<< ", ::tflite::ops::builtin::Register_" << op.first << "()";
|
||||||
|
if (op.second.first != 1 || op.second.second != 1) {
|
||||||
|
fout << ", " << op.second.first << ", " << op.second.second;
|
||||||
|
}
|
||||||
|
fout << ");\n";
|
||||||
}
|
}
|
||||||
for (const auto& op : custom_ops) {
|
for (const auto& op : custom_ops) {
|
||||||
fout << " resolver->AddCustom(\"" << op
|
fout << " resolver->AddCustom(\"" << op.first
|
||||||
<< "\", ::tflite::ops::custom::Register_"
|
<< "\", ::tflite::ops::custom::Register_"
|
||||||
<< ::tflite::NormalizeCustomOpName(op) << "());\n";
|
<< ::tflite::NormalizeCustomOpName(op.first) << "()";
|
||||||
|
if (op.second.first != 1 || op.second.second != 1) {
|
||||||
|
fout << ", " << op.second.first << ", " << op.second.second;
|
||||||
|
}
|
||||||
|
fout << ");\n";
|
||||||
}
|
}
|
||||||
fout << "}\n";
|
fout << "}\n";
|
||||||
|
if (!namespace_flag.empty()) {
|
||||||
|
fout << "} // namespace " << namespace_flag << "\n";
|
||||||
|
}
|
||||||
fout.close();
|
fout.close();
|
||||||
}
|
}
|
||||||
} // namespace
|
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
void AddOpsFromModel(const string& input_model,
|
||||||
string input_model;
|
tflite::RegisteredOpMap* builtin_ops,
|
||||||
string output_registration;
|
tflite::RegisteredOpMap* custom_ops) {
|
||||||
string tflite_path;
|
|
||||||
ParseFlagAndInit(argc, argv, &input_model, &output_registration,
|
|
||||||
&tflite_path);
|
|
||||||
|
|
||||||
std::vector<string> builtin_ops;
|
|
||||||
std::vector<string> custom_ops;
|
|
||||||
std::ifstream fin(input_model);
|
std::ifstream fin(input_model);
|
||||||
std::stringstream content;
|
std::stringstream content;
|
||||||
content << fin.rdbuf();
|
content << fin.rdbuf();
|
||||||
// Need to store content data first, otherwise, it won't work in bazel.
|
// Need to store content data first, otherwise, it won't work in bazel.
|
||||||
string content_str = content.str();
|
string content_str = content.str();
|
||||||
const ::tflite::Model* model = ::tflite::GetModel(content_str.data());
|
const ::tflite::Model* model = ::tflite::GetModel(content_str.data());
|
||||||
::tflite::ReadOpsFromModel(model, &builtin_ops, &custom_ops);
|
::tflite::ReadOpsFromModel(model, builtin_ops, custom_ops);
|
||||||
GenerateFileContent(tflite_path, output_registration, builtin_ops,
|
}
|
||||||
custom_ops);
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
string input_model;
|
||||||
|
string output_registration;
|
||||||
|
string tflite_path;
|
||||||
|
string namespace_flag;
|
||||||
|
ParseFlagAndInit(&argc, argv, &input_model, &output_registration,
|
||||||
|
&tflite_path, &namespace_flag);
|
||||||
|
|
||||||
|
tflite::RegisteredOpMap builtin_ops;
|
||||||
|
tflite::RegisteredOpMap custom_ops;
|
||||||
|
if (!input_model.empty()) {
|
||||||
|
AddOpsFromModel(input_model, &builtin_ops, &custom_ops);
|
||||||
|
}
|
||||||
|
for (int i = 1; i < argc; i++) {
|
||||||
|
AddOpsFromModel(argv[i], &builtin_ops, &custom_ops);
|
||||||
|
}
|
||||||
|
GenerateFileContent(tflite_path, output_registration, namespace_flag,
|
||||||
|
builtin_ops, custom_ops);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -32,8 +32,8 @@ class GenOpRegistrationTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<string> builtin_ops_;
|
std::map<string, std::pair<int, int>> builtin_ops_;
|
||||||
std::vector<string> custom_ops_;
|
std::map<string, std::pair<int, int>> custom_ops_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(GenOpRegistrationTest, TestNonExistantFiles) {
|
TEST_F(GenOpRegistrationTest, TestNonExistantFiles) {
|
||||||
@ -44,8 +44,27 @@ TEST_F(GenOpRegistrationTest, TestNonExistantFiles) {
|
|||||||
|
|
||||||
TEST_F(GenOpRegistrationTest, TestModels) {
|
TEST_F(GenOpRegistrationTest, TestModels) {
|
||||||
ReadOps("tensorflow/lite/testdata/test_model.bin");
|
ReadOps("tensorflow/lite/testdata/test_model.bin");
|
||||||
EXPECT_THAT(builtin_ops_, ElementsAreArray({"CONV_2D"}));
|
RegisteredOpMap builtin_expected{{"CONV_2D", {1, 1}}};
|
||||||
EXPECT_THAT(custom_ops_, ElementsAreArray({"testing_op"}));
|
RegisteredOpMap custom_expected{{"testing_op", {1, 1}}};
|
||||||
|
EXPECT_THAT(builtin_ops_, ElementsAreArray(builtin_expected));
|
||||||
|
EXPECT_THAT(custom_ops_, ElementsAreArray(custom_expected));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GenOpRegistrationTest, TestVersionedModels) {
|
||||||
|
ReadOps("tensorflow/lite/testdata/test_model_versioned_ops.bin");
|
||||||
|
RegisteredOpMap builtin_expected{{"CONV_2D", {3, 3}}};
|
||||||
|
RegisteredOpMap custom_expected{{"testing_op", {2, 2}}};
|
||||||
|
EXPECT_THAT(builtin_ops_, ElementsAreArray(builtin_expected));
|
||||||
|
EXPECT_THAT(custom_ops_, ElementsAreArray(custom_expected));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GenOpRegistrationTest, TestBothModels) {
|
||||||
|
ReadOps("tensorflow/lite/testdata/test_model.bin");
|
||||||
|
ReadOps("tensorflow/lite/testdata/test_model_versioned_ops.bin");
|
||||||
|
RegisteredOpMap builtin_expected{{"CONV_2D", {1, 3}}};
|
||||||
|
RegisteredOpMap custom_expected{{"testing_op", {1, 2}}};
|
||||||
|
EXPECT_THAT(builtin_ops_, ElementsAreArray(builtin_expected));
|
||||||
|
EXPECT_THAT(custom_ops_, ElementsAreArray(custom_expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GenOpRegistrationTest, TestEmptyModels) {
|
TEST_F(GenOpRegistrationTest, TestEmptyModels) {
|
||||||
|
Loading…
Reference in New Issue
Block a user