Support TFLite in the tool to print selective registration header
PiperOrigin-RevId: 314282031
Change-Id: Ie71b434c177d03e246a5cfde3d067ac695b71299
parent: e2aa757a55
commit: e9781e9b16
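In short: the new list_flex_ops_main tool reads one or more TFLite models and prints the flex (select TensorFlow) ops they use, paired with the TensorFlow kernel class each op needs, as a JSON list. print_selective_registration_header can then consume that list through the new ops_list value of --proto_fileformat, so selective registration headers can be generated for TFLite flex models as well as for plain GraphDefs. A minimal sketch of the Python side, assuming the usual tensorflow.python.tools module path; flex_ops.json is a hypothetical file holding the JSON printed by list_flex_ops_main --graphs=model.tflite:

from tensorflow.python.tools import selective_registration_header_lib

# flex_ops.json is assumed to contain output such as:
#   [["Add", "BinaryOp<CPUDevice, functor::add<float>>"]]
header = selective_registration_header_lib.get_header(
    ['flex_ops.json'],            # input files to read
    proto_fileformat='ops_list',  # format added by this change
    default_ops='NoOp:NoOp,_Recv:RecvOp,_Send:SendOp')
print(header)  # the text that should be written as ops_to_register.h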
@@ -486,6 +486,9 @@ class SingleOpModel {
    return std::vector<T>(v, v + tensor_size);
  }

  // Return the TFLite model buffer, only available after BuildInterpreter.
  const uint8_t* GetModelBuffer() { return builder_.GetBufferPointer(); }

  std::vector<int> GetTensorShape(int index) {
    std::vector<int> result;
    TfLiteTensor* t = interpreter_->tensor(index);
BIN  tensorflow/lite/testdata/softplus_flex.bin (new binary file)
Binary file not shown.
@@ -134,6 +134,7 @@ cc_binary(
    deps = [
        ":command_line_flags",
        ":gen_op_registration",
        "//tensorflow/lite:util",
        "@com_google_absl//absl/strings",
    ],
)
@@ -252,6 +253,60 @@ cc_test(
    ],
)

cc_library(
    name = "list_flex_ops",
    srcs = ["list_flex_ops.cc"],
    hdrs = ["list_flex_ops.h"],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:tensorflow",
        "//tensorflow/lite:framework",
        "//tensorflow/lite:util",
        "@com_google_absl//absl/strings",
        "@flatbuffers",
    ],
)

# This tool lists the flex ops and kernels inside a TFLite file.
# It is used to generate a header file for selective registration.
cc_binary(
    name = "list_flex_ops_main",
    srcs = ["list_flex_ops_main.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":list_flex_ops",
        "//tensorflow/lite/tools:command_line_flags",
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "list_flex_ops_test",
    srcs = ["list_flex_ops_test.cc"],
    data = [
        "//tensorflow/lite:testdata/0_subgraphs.bin",
        "//tensorflow/lite:testdata/multi_add_flex.bin",
        "//tensorflow/lite:testdata/softplus_flex.bin",
        "//tensorflow/lite:testdata/test_model.bin",
        "//tensorflow/lite:testdata/test_model_broken.bin",
    ],
    tags = [
        "no_oss",  # Currently requires --config=monolithic, b/118895218.
        "tflite_not_portable_android",
        "tflite_not_portable_ios",
    ],
    deps = [
        ":list_flex_ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:protobuf",
        "//tensorflow/lite/kernels:test_util",
        "@com_google_googletest//:gtest",
        "@flatbuffers",
    ],
)

py_binary(
    name = "zip_files",
    srcs = ["zip_files.py"],
@@ -12,12 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/gen_op_registration.h"

#include <string>
#include <vector>

#include "re2/re2.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/tools/gen_op_registration.h"

namespace tflite {

@@ -23,6 +23,7 @@ limitations under the License.
#include "absl/strings/strip.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/gen_op_registration.h"
#include "tensorflow/lite/util.h"

const char kInputModelFlag[] = "input_models";
const char kNamespace[] = "namespace";
@@ -84,6 +85,8 @@ void GenerateFileContent(const std::string& tflite_path,
  fout << "namespace custom {\n";
  fout << "// Forward-declarations for the custom ops.\n";
  for (const auto& op : custom_ops) {
    // Skips Tensorflow ops, only TFLite custom ops can be registered here.
    if (tflite::IsFlexOp(op.first.c_str())) continue;
    fout << "TfLiteRegistration* Register_"
         << ::tflite::NormalizeCustomOpName(op.first) << "();\n";
  }
@@ -115,6 +118,8 @@ void GenerateFileContent(const std::string& tflite_path,
    fout << ");\n";
  }
  for (const auto& op : custom_ops) {
    // Skips Tensorflow ops, only TFLite custom ops can be registered here.
    if (tflite::IsFlexOp(op.first.c_str())) continue;
    fout << "  resolver->AddCustom(\"" << op.first
         << "\", ::tflite::ops::custom::Register_"
         << ::tflite::NormalizeCustomOpName(op.first) << "()";
tensorflow/lite/tools/list_flex_ops.cc (new file, 128 lines)
@@ -0,0 +1,128 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/list_flex_ops.h"

#include <fstream>
#include <sstream>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace flex {

std::string OpListToJSONString(const OpKernelSet& flex_ops) {
  return absl::StrCat("[",
                      absl::StrJoin(flex_ops, ",\n",
                                    [](std::string* out, const OpKernel& op) {
                                      absl::StrAppend(out, "[\"", op.op_name,
                                                      "\", \"", op.kernel_name,
                                                      "\"]");
                                    }),
                      "]");
}

// Find the class name of the op kernel described in the node_def from the pool
// of registered ops. If no kernel class is found, return an empty string.
string FindTensorflowKernelClass(tensorflow::NodeDef* node_def) {
  if (!node_def || node_def->op().empty()) {
    LOG(FATAL) << "Invalid NodeDef";
  }

  const tensorflow::OpRegistrationData* op_reg_data;
  auto status =
      tensorflow::OpRegistry::Global()->LookUp(node_def->op(), &op_reg_data);
  if (!status.ok()) {
    LOG(FATAL) << "Op " << node_def->op() << " not found: " << status;
  }
  AddDefaultsToNodeDef(op_reg_data->op_def, node_def);

  tensorflow::DeviceNameUtils::ParsedName parsed_name;
  if (!tensorflow::DeviceNameUtils::ParseFullName(node_def->device(),
                                                  &parsed_name)) {
    LOG(FATAL) << "Failed to parse device from node_def: "
               << node_def->ShortDebugString();
  }
  string class_name;
  if (!tensorflow::FindKernelDef(
           tensorflow::DeviceType(parsed_name.type.c_str()), *node_def,
           nullptr /* kernel_def */, &class_name)
           .ok()) {
    LOG(FATAL) << "Failed to find kernel class for op: " << node_def->op();
  }
  return class_name;
}

void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops) {
  // Read flex ops.
  auto* subgraphs = model->subgraphs();
  if (!subgraphs) return;
  for (int subgraph_index = 0; subgraph_index < subgraphs->size();
       ++subgraph_index) {
    const tflite::SubGraph* subgraph = subgraphs->Get(subgraph_index);
    auto* operators = subgraph->operators();
    auto* opcodes = model->operator_codes();
    if (!operators || !opcodes) continue;
    for (int i = 0; i < operators->size(); ++i) {
      const tflite::Operator* op = operators->Get(i);
      const tflite::OperatorCode* opcode = opcodes->Get(op->opcode_index());
      if (opcode->builtin_code() != tflite::BuiltinOperator_CUSTOM ||
          !tflite::IsFlexOp(opcode->custom_code()->c_str())) {
        continue;
      }

      // Remove the "Flex" prefix from op name.
      std::string flex_op_name(opcode->custom_code()->c_str());
      std::string tf_op_name =
          flex_op_name.substr(strlen(tflite::kFlexCustomCodePrefix));

      // Read NodeDef and find the op kernel class.
      if (op->custom_options_format() !=
          tflite::CustomOptionsFormat_FLEXBUFFERS) {
        LOG(FATAL) << "Invalid CustomOptionsFormat";
      }
      const flatbuffers::Vector<uint8_t>* custom_opt_bytes =
          op->custom_options();
      if (custom_opt_bytes && custom_opt_bytes->size()) {
        // NOLINTNEXTLINE: It is common to use references with flatbuffer.
        const flexbuffers::Vector& v =
            flexbuffers::GetRoot(custom_opt_bytes->data(),
                                 custom_opt_bytes->size())
                .AsVector();
        std::string nodedef_str = v[1].AsString().str();
        tensorflow::NodeDef nodedef;
        if (nodedef_str.empty() || !nodedef.ParseFromString(nodedef_str)) {
          LOG(FATAL) << "Failed to parse data into a valid NodeDef";
        }
        // Flex delegate only supports running flex ops with CPU.
        *nodedef.mutable_device() = "/CPU:0";
        std::string kernel_class = FindTensorflowKernelClass(&nodedef);
        flex_ops->insert({tf_op_name, kernel_class});
      }
    }
  }
}
}  // namespace flex
}  // namespace tflite
tensorflow/lite/tools/list_flex_ops.h (new file, 55 lines)
@@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_LIST_FLEX_OPS_H_
#define TENSORFLOW_LITE_TOOLS_LIST_FLEX_OPS_H_

#include <set>
#include <string>

#include "tensorflow/lite/model.h"

namespace tflite {
namespace flex {

// Store the Op and Kernel name of an op as the key of a set or map.
struct OpKernel {
  std::string op_name;
  std::string kernel_name;
};

// The comparison function for OpKernel.
struct OpKernelCompare {
  bool operator()(const OpKernel& lhs, const OpKernel& rhs) const {
    if (lhs.op_name == rhs.op_name) {
      return lhs.kernel_name < rhs.kernel_name;
    }
    return lhs.op_name < rhs.op_name;
  }
};

using OpKernelSet = std::set<OpKernel, OpKernelCompare>;

// Find flex ops and their kernel classes inside a TFLite model and add them
// to flex_ops. Each entry stores the op name and its TensorFlow kernel class.
void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops);

// Serialize the list of ops to a JSON string. If flex_ops is empty, return an
// empty string.
std::string OpListToJSONString(const OpKernelSet& flex_ops);

}  // namespace flex
}  // namespace tflite

#endif  // TENSORFLOW_LITE_TOOLS_LIST_FLEX_OPS_H_
tensorflow/lite/tools/list_flex_ops_main.cc (new file, 50 lines)
@@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <fstream>
#include <iostream>
#include <sstream>

#include "absl/strings/str_split.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/list_flex_ops.h"

const char kInputModelsFlag[] = "graphs";

int main(int argc, char** argv) {
  std::string input_models;
  std::vector<tflite::Flag> flag_list = {
      tflite::Flag::CreateFlag(kInputModelsFlag, &input_models,
                               "path to the tflite models, separated by comma.",
                               tflite::Flag::kRequired),
  };
  tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);

  std::vector<std::string> models = absl::StrSplit(input_models, ',');
  tflite::flex::OpKernelSet flex_ops;
  for (const std::string& model_file : models) {
    std::ifstream fin;
    fin.exceptions(std::ifstream::failbit | std::ifstream::badbit);
    fin.open(model_file);
    std::stringstream content;
    content << fin.rdbuf();

    // Need to store content data first, otherwise, it won't work in bazel.
    std::string content_str = content.str();
    const ::tflite::Model* model = ::tflite::GetModel(content_str.data());
    tflite::flex::AddFlexOpsFromModel(model, &flex_ops);
  }
  std::cout << tflite::flex::OpListToJSONString(flex_ops);
  return 0;
}
tensorflow/lite/tools/list_flex_ops_test.cc (new file, 203 lines)
@@ -0,0 +1,203 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/tools/list_flex_ops.h"

#include <cstdint>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/lite/kernels/test_util.h"

namespace tflite {
namespace flex {

class FlexOpsListTest : public ::testing::Test {
 protected:
  FlexOpsListTest() {}

  void ReadOps(const string& model_path) {
    auto model = FlatBufferModel::BuildFromFile(model_path.data());
    AddFlexOpsFromModel(model->GetModel(), &flex_ops_);
    output_text_ = OpListToJSONString(flex_ops_);
  }

  void ReadOps(const tflite::Model* model) {
    AddFlexOpsFromModel(model, &flex_ops_);
    output_text_ = OpListToJSONString(flex_ops_);
  }

  std::string output_text_;
  OpKernelSet flex_ops_;
};

TfLiteRegistration* Register_TEST() {
  static TfLiteRegistration r = {nullptr, nullptr, nullptr, nullptr};
  return &r;
}

std::vector<uint8_t> CreateFlexCustomOptions(std::string nodedef_raw_string) {
  tensorflow::NodeDef node_def;
  tensorflow::protobuf::TextFormat::ParseFromString(nodedef_raw_string,
                                                    &node_def);
  std::string node_def_str = node_def.SerializeAsString();
  auto flex_builder = std::make_unique<flexbuffers::Builder>();
  flex_builder->Vector([&]() {
    flex_builder->String(node_def.op());
    flex_builder->String(node_def_str);
  });
  flex_builder->Finish();
  return flex_builder->GetBuffer();
}

class FlexOpModel : public SingleOpModel {
 public:
  FlexOpModel(const std::string& op_name, const TensorData& input1,
              const TensorData& input2, const TensorType& output,
              const std::vector<uint8_t>& custom_options) {
    input1_ = AddInput(input1);
    input2_ = AddInput(input2);
    output_ = AddOutput(output);
    SetCustomOp(op_name, custom_options, Register_TEST);
    BuildInterpreter({GetShape(input1_), GetShape(input2_)});
  }

 protected:
  int input1_;
  int input2_;
  int output_;
};

TEST_F(FlexOpsListTest, TestModelsNoFlex) {
  ReadOps("third_party/tensorflow/lite/testdata/test_model.bin");
  EXPECT_EQ(output_text_, "[]");
}

TEST_F(FlexOpsListTest, TestBrokenModel) {
  EXPECT_DEATH_IF_SUPPORTED(
      ReadOps("third_party/tensorflow/lite/testdata/test_model_broken.bin"),
      "");
}

TEST_F(FlexOpsListTest, TestZeroSubgraphs) {
  ReadOps("third_party/tensorflow/lite/testdata/0_subgraphs.bin");
  EXPECT_EQ(output_text_, "[]");
}

TEST_F(FlexOpsListTest, TestFlexAdd) {
  ReadOps("third_party/tensorflow/lite/testdata/multi_add_flex.bin");
  EXPECT_EQ(output_text_,
            "[[\"Add\", \"BinaryOp<CPUDevice, functor::add<float>>\"]]");
}

TEST_F(FlexOpsListTest, TestTwoModel) {
  ReadOps("third_party/tensorflow/lite/testdata/multi_add_flex.bin");
  ReadOps("third_party/tensorflow/lite/testdata/softplus_flex.bin");
  EXPECT_EQ(output_text_,
            "[[\"Add\", \"BinaryOp<CPUDevice, "
            "functor::add<float>>\"],\n[\"Softplus\", \"SoftplusOp<CPUDevice, "
            "float>\"]]");
}

TEST_F(FlexOpsListTest, TestDuplicatedOp) {
  ReadOps("third_party/tensorflow/lite/testdata/multi_add_flex.bin");
  ReadOps("third_party/tensorflow/lite/testdata/multi_add_flex.bin");
  EXPECT_EQ(output_text_,
            "[[\"Add\", \"BinaryOp<CPUDevice, functor::add<float>>\"]]");
}

TEST_F(FlexOpsListTest, TestInvalidCustomOptions) {
  // Using a invalid custom options, expected to fail.
  std::vector<uint8_t> random_custom_options(20);
  FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}},
                        {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32,
                        random_custom_options);
  EXPECT_DEATH_IF_SUPPORTED(
      ReadOps(tflite::GetModel(max_model.GetModelBuffer())),
      "Failed to parse data into a valid NodeDef");
}

TEST_F(FlexOpsListTest, TestOpNameEmpty) {
  // NodeDef with empty opname.
  std::string nodedef_raw_str =
      "name: \"node_1\""
      "op: \"\""
      "input: [ \"b\", \"c\" ]"
      "attr: { key: \"T\" value: { type: DT_FLOAT } }";
  std::string random_fieldname = "random string";
  FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}},
                        {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32,
                        CreateFlexCustomOptions(nodedef_raw_str));
  EXPECT_DEATH_IF_SUPPORTED(
      ReadOps(tflite::GetModel(max_model.GetModelBuffer())), "Invalid NodeDef");
}

TEST_F(FlexOpsListTest, TestOpNotFound) {
  // NodeDef with invalid opname.
  std::string nodedef_raw_str =
      "name: \"node_1\""
      "op: \"FlexInvalidOp\""
      "input: [ \"b\", \"c\" ]"
      "attr: { key: \"T\" value: { type: DT_FLOAT } }";

  FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}},
                        {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32,
                        CreateFlexCustomOptions(nodedef_raw_str));
  EXPECT_DEATH_IF_SUPPORTED(
      ReadOps(tflite::GetModel(max_model.GetModelBuffer())),
      "Op FlexInvalidOp not found");
}

TEST_F(FlexOpsListTest, TestKernelNotFound) {
  // NodeDef with non-supported type.
  std::string nodedef_raw_str =
      "name: \"node_1\""
      "op: \"Add\""
      "input: [ \"b\", \"c\" ]"
      "attr: { key: \"T\" value: { type: DT_BOOL } }";

  FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}},
                        {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32,
                        CreateFlexCustomOptions(nodedef_raw_str));
  EXPECT_DEATH_IF_SUPPORTED(
      ReadOps(tflite::GetModel(max_model.GetModelBuffer())),
      "Failed to find kernel class for op: Add");
}

TEST_F(FlexOpsListTest, TestFlexAddWithSingleOpModel) {
  std::string nodedef_raw_str =
      "name: \"node_1\""
      "op: \"Add\""
      "input: [ \"b\", \"c\" ]"
      "attr: { key: \"T\" value: { type: DT_FLOAT } }";

  FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}},
                        {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32,
                        CreateFlexCustomOptions(nodedef_raw_str));
  ReadOps(tflite::GetModel(max_model.GetModelBuffer()));
  EXPECT_EQ(output_text_,
            "[[\"Add\", \"BinaryOp<CPUDevice, functor::add<float>>\"]]");
}
}  // namespace flex
}  // namespace tflite

int main(int argc, char** argv) {
  // On Linux, add: FLAGS_logtostderr = true;
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
@@ -46,8 +46,10 @@ FLAGS = None

def main(unused_argv):
  graphs = FLAGS.graphs.split(',')
  print(selective_registration_header_lib.get_header(
      graphs, FLAGS.proto_fileformat, FLAGS.default_ops))
  print(
      selective_registration_header_lib.get_header(graphs,
                                                    FLAGS.proto_fileformat,
                                                    FLAGS.default_ops))


if __name__ == '__main__':
@@ -63,7 +65,9 @@ if __name__ == '__main__':
      '--proto_fileformat',
      type=str,
      default='rawproto',
      help='Format of proto file, either textproto or rawproto.')
      help='Format of proto file, either textproto, rawproto or ops_list. The '
      'ops_list is the file contains the list of ops in JSON format. Ex: '
      '"[["Add", "BinaryOp<CPUDevice, functor::add<float>>"]]".')
  parser.add_argument(
      '--default_ops',
      type=str,
@@ -93,6 +93,12 @@ class PrintOpFilegroupTest(test.TestCase):
      fnames.append(fname)
    return fnames

  def WriteTextFile(self, content):
    fname = os.path.join(self.get_temp_dir(), 'text.txt')
    with gfile.GFile(fname, 'w') as f:
      f.write(content)
    return [fname]

  def testGetOps(self):
    default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
    graphs = [
@@ -136,6 +142,59 @@ class PrintOpFilegroupTest(test.TestCase):
        ],
        ops_and_kernels)

  def testGetOpsFromList(self):
    default_ops = ''
    # Test with 2 different ops.
    ops_list = """[["Add", "BinaryOp<CPUDevice, functor::add<float>>"],
        ["Softplus", "SoftplusOp<CPUDevice, float>"]]"""
    ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
        'ops_list', self.WriteTextFile(ops_list), default_ops)
    self.assertListEqual([
        ('Add', 'BinaryOp<CPUDevice, functor::add<float>>'),
        ('Softplus', 'SoftplusOp<CPUDevice, float>'),
    ], ops_and_kernels)

    # Test with a single op.
    ops_list = '[["Softplus", "SoftplusOp<CPUDevice, float>"]]'
    ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
        'ops_list', self.WriteTextFile(ops_list), default_ops)
    self.assertListEqual([
        ('Softplus', 'SoftplusOp<CPUDevice, float>'),
    ], ops_and_kernels)

    # Test with duplicated op.
    ops_list = """[["Add", "BinaryOp<CPUDevice, functor::add<float>>"],
        ["Add", "BinaryOp<CPUDevice, functor::add<float>>"]]"""
    ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
        'ops_list', self.WriteTextFile(ops_list), default_ops)
    self.assertListEqual([
        ('Add', 'BinaryOp<CPUDevice, functor::add<float>>'),
    ], ops_and_kernels)

    # Test op with no kernel.
    ops_list = '[["Softplus", ""]]'
    ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
        'ops_list', self.WriteTextFile(ops_list), default_ops)
    self.assertListEqual([
        ('Softplus', None),
    ], ops_and_kernels)

    # Test two ops_list files.
    ops_list = '[["Softplus", "SoftplusOp<CPUDevice, float>"]]'
    ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
        'ops_list',
        self.WriteTextFile(ops_list) + self.WriteTextFile(ops_list),
        default_ops)
    self.assertListEqual([
        ('Softplus', 'SoftplusOp<CPUDevice, float>'),
    ], ops_and_kernels)

    # Test empty file.
    ops_list = ''
    with self.assertRaises(Exception):
      ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
          'ops_list', self.WriteTextFile(ops_list), default_ops)

  def testAll(self):
    default_ops = 'all'
    graphs = [
@@ -22,11 +22,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import sys

from google.protobuf import text_format

from tensorflow.core.framework import graph_pb2
from tensorflow.python import _pywrap_kernel_registry
from tensorflow.python.platform import gfile
@@ -41,6 +41,39 @@ OPS_WITHOUT_KERNEL_WHITELIST = frozenset([
    # core/common_runtime/accumulate_n_optimizer.cc.
    'AccumulateNV2'
])
FLEX_PREFIX = b'Flex'
FLEX_PREFIX_LENGTH = len(FLEX_PREFIX)


def _get_ops_from_ops_list(input_file):
  """Gets the ops and kernels needed from the ops list file."""
  ops = set()
  ops_list_str = gfile.GFile(input_file, 'r').read()
  if not ops_list_str:
    raise Exception('Input file should not be empty')
  ops_list = json.loads(ops_list_str)
  for op, kernel in ops_list:
    op_and_kernel = (op, kernel if kernel else None)
    ops.add(op_and_kernel)
  return ops


def _get_ops_from_graphdef(graph_def):
  """Gets the ops and kernels needed from the tensorflow model."""
  ops = set()
  for node_def in graph_def.node:
    if not node_def.device:
      node_def.device = '/cpu:0'
    kernel_class = _pywrap_kernel_registry.TryFindKernelClass(
        node_def.SerializeToString())
    op = str(node_def.op)
    if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST:
      op_and_kernel = (op, str(kernel_class.decode('utf-8'))
                       if kernel_class else None)
      ops.add(op_and_kernel)
    else:
      print('Warning: no kernel found for op %s' % node_def.op, file=sys.stderr)
  return ops


def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
@@ -49,6 +82,11 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):

  for proto_file in proto_files:
    tf_logging.info('Loading proto file %s', proto_file)
    # Load ops list file.
    if proto_fileformat == 'ops_list':
      ops = ops.union(_get_ops_from_ops_list(proto_file))
      continue

    # Load GraphDef.
    file_data = gfile.GFile(proto_file, 'rb').read()
    if proto_fileformat == 'rawproto':
@@ -56,22 +94,7 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
    else:
      assert proto_fileformat == 'textproto'
      graph_def = text_format.Parse(file_data, graph_pb2.GraphDef())

    # Find all ops and kernels used by the graph.
    for node_def in graph_def.node:
      if not node_def.device:
        node_def.device = '/cpu:0'
      kernel_class = _pywrap_kernel_registry.TryFindKernelClass(
          node_def.SerializeToString())
      op = str(node_def.op)
      if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST:
        op_and_kernel = (op, str(kernel_class.decode('utf-8'))
                         if kernel_class else None)
        if op_and_kernel not in ops:
          ops.add(op_and_kernel)
      else:
        print(
            'Warning: no kernel found for op %s' % node_def.op, file=sys.stderr)
    ops = ops.union(_get_ops_from_graphdef(graph_def))

  # Add default ops.
  if default_ops_str and default_ops_str != 'all':
@@ -91,7 +114,7 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
  Args:
    ops_and_kernels: a set of (op_name, kernel_class_name) pairs to include.
    include_all_ops_and_kernels: if True, ops_and_kernels is ignored and all op
    kernels are included.
      kernels are included.

  Returns:
    the string of the header that should be written as ops_to_register.h.
@@ -112,7 +135,7 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
    append('#define SHOULD_REGISTER_OP_KERNEL(clz) true')
    append('#define SHOULD_REGISTER_OP_GRADIENT true')
  else:
    line = '''
    line = """
    namespace {
      constexpr const char* skip(const char* x) {
        return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x;
@@ -138,10 +161,11 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
      }
    };
    } // end namespace
    '''
    """
    line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n'
    for _, kernel_class in ops_and_kernels:
      if kernel_class is None: continue
      if kernel_class is None:
        continue
      line += '"%s",\n' % kernel_class
    line += '};'
    append(line)
@@ -160,8 +184,8 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
  append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)')
  append('')

  append('#define SHOULD_REGISTER_OP_GRADIENT ' + (
      'true' if 'SymbolicGradient' in ops else 'false'))
  append('#define SHOULD_REGISTER_OP_GRADIENT ' +
         ('true' if 'SymbolicGradient' in ops else 'false'))

  append('#endif')
  return '\n'.join(result_list)
@@ -174,11 +198,13 @@ def get_header(graphs,

  Args:
    graphs: a list of paths to GraphDef files to include.
    proto_fileformat: optional format of proto file, either 'textproto' or
      'rawproto' (default).
    proto_fileformat: optional format of proto file, either 'textproto',
      'rawproto' (default) or ops_list. The ops_list is the file contain the
      list of ops in JSON format, Ex: "[["Transpose", "TransposeCpuOp"]]".
    default_ops: optional comma-separated string of operator:kernel pairs to
      always include implementation for. Pass 'all' to have all operators and
      kernels included. Default: 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'.

  Returns:
    the string of the header that should be written as ops_to_register.h.
  """