Support TFLite in the tool to print selective registration header

PiperOrigin-RevId: 314282031
Change-Id: Ie71b434c177d03e246a5cfde3d067ac695b71299
This commit is contained in:
Thai Nguyen 2020-06-02 00:45:48 -07:00 committed by TensorFlower Gardener
parent e2aa757a55
commit e9781e9b16
12 changed files with 618 additions and 29 deletions

View File

@ -486,6 +486,9 @@ class SingleOpModel {
return std::vector<T>(v, v + tensor_size);
}
// Return the TFLite model buffer, only available after BuildInterpreter.
const uint8_t* GetModelBuffer() { return builder_.GetBufferPointer(); }
std::vector<int> GetTensorShape(int index) {
std::vector<int> result;
TfLiteTensor* t = interpreter_->tensor(index);

Binary file not shown.

View File

@ -134,6 +134,7 @@ cc_binary(
deps = [
":command_line_flags",
":gen_op_registration",
"//tensorflow/lite:util",
"@com_google_absl//absl/strings",
],
)
@ -252,6 +253,60 @@ cc_test(
],
)
cc_library(
name = "list_flex_ops",
srcs = ["list_flex_ops.cc"],
hdrs = ["list_flex_ops.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/lite:framework",
"//tensorflow/lite:util",
"@com_google_absl//absl/strings",
"@flatbuffers",
],
)
# This tool list flex ops and kernels inside a TFLite file.
# It is used to generate header file for selective registration.
cc_binary(
name = "list_flex_ops_main",
srcs = ["list_flex_ops_main.cc"],
visibility = ["//visibility:public"],
deps = [
":list_flex_ops",
"//tensorflow/lite/tools:command_line_flags",
"@com_google_absl//absl/strings",
],
)
cc_test(
name = "list_flex_ops_test",
srcs = ["list_flex_ops_test.cc"],
data = [
"//tensorflow/lite:testdata/0_subgraphs.bin",
"//tensorflow/lite:testdata/multi_add_flex.bin",
"//tensorflow/lite:testdata/softplus_flex.bin",
"//tensorflow/lite:testdata/test_model.bin",
"//tensorflow/lite:testdata/test_model_broken.bin",
],
tags = [
"no_oss", # Currently requires --config=monolithic, b/118895218.
"tflite_not_portable_android",
"tflite_not_portable_ios",
],
deps = [
":list_flex_ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:protobuf",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
"@flatbuffers",
],
)
py_binary(
name = "zip_files",
srcs = ["zip_files.py"],

View File

@ -12,12 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/gen_op_registration.h"
#include <string>
#include <vector>
#include "re2/re2.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/tools/gen_op_registration.h"
namespace tflite {

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "absl/strings/strip.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/gen_op_registration.h"
#include "tensorflow/lite/util.h"
const char kInputModelFlag[] = "input_models";
const char kNamespace[] = "namespace";
@ -84,6 +85,8 @@ void GenerateFileContent(const std::string& tflite_path,
fout << "namespace custom {\n";
fout << "// Forward-declarations for the custom ops.\n";
for (const auto& op : custom_ops) {
// Skips Tensorflow ops, only TFLite custom ops can be registered here.
if (tflite::IsFlexOp(op.first.c_str())) continue;
fout << "TfLiteRegistration* Register_"
<< ::tflite::NormalizeCustomOpName(op.first) << "();\n";
}
@ -115,6 +118,8 @@ void GenerateFileContent(const std::string& tflite_path,
fout << ");\n";
}
for (const auto& op : custom_ops) {
// Skips Tensorflow ops, only TFLite custom ops can be registered here.
if (tflite::IsFlexOp(op.first.c_str())) continue;
fout << " resolver->AddCustom(\"" << op.first
<< "\", ::tflite::ops::custom::Register_"
<< ::tflite::NormalizeCustomOpName(op.first) << "()";

View File

@ -0,0 +1,128 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/list_flex_ops.h"
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/lite/util.h"
namespace tflite {
namespace flex {
std::string OpListToJSONString(const OpKernelSet& flex_ops) {
return absl::StrCat("[",
absl::StrJoin(flex_ops, ",\n",
[](std::string* out, const OpKernel& op) {
absl::StrAppend(out, "[\"", op.op_name,
"\", \"", op.kernel_name,
"\"]");
}),
"]");
}
// Find the class name of the op kernel described in the node_def from the pool
// of registered ops. If no kernel class is found, return an empty string.
string FindTensorflowKernelClass(tensorflow::NodeDef* node_def) {
if (!node_def || node_def->op().empty()) {
LOG(FATAL) << "Invalid NodeDef";
}
const tensorflow::OpRegistrationData* op_reg_data;
auto status =
tensorflow::OpRegistry::Global()->LookUp(node_def->op(), &op_reg_data);
if (!status.ok()) {
LOG(FATAL) << "Op " << node_def->op() << " not found: " << status;
}
AddDefaultsToNodeDef(op_reg_data->op_def, node_def);
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(node_def->device(),
&parsed_name)) {
LOG(FATAL) << "Failed to parse device from node_def: "
<< node_def->ShortDebugString();
}
string class_name;
if (!tensorflow::FindKernelDef(
tensorflow::DeviceType(parsed_name.type.c_str()), *node_def,
nullptr /* kernel_def */, &class_name)
.ok()) {
LOG(FATAL) << "Failed to find kernel class for op: " << node_def->op();
}
return class_name;
}
void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops) {
// Read flex ops.
auto* subgraphs = model->subgraphs();
if (!subgraphs) return;
for (int subgraph_index = 0; subgraph_index < subgraphs->size();
++subgraph_index) {
const tflite::SubGraph* subgraph = subgraphs->Get(subgraph_index);
auto* operators = subgraph->operators();
auto* opcodes = model->operator_codes();
if (!operators || !opcodes) continue;
for (int i = 0; i < operators->size(); ++i) {
const tflite::Operator* op = operators->Get(i);
const tflite::OperatorCode* opcode = opcodes->Get(op->opcode_index());
if (opcode->builtin_code() != tflite::BuiltinOperator_CUSTOM ||
!tflite::IsFlexOp(opcode->custom_code()->c_str())) {
continue;
}
// Remove the "Flex" prefix from op name.
std::string flex_op_name(opcode->custom_code()->c_str());
std::string tf_op_name =
flex_op_name.substr(strlen(tflite::kFlexCustomCodePrefix));
// Read NodeDef and find the op kernel class.
if (op->custom_options_format() !=
tflite::CustomOptionsFormat_FLEXBUFFERS) {
LOG(FATAL) << "Invalid CustomOptionsFormat";
}
const flatbuffers::Vector<uint8_t>* custom_opt_bytes =
op->custom_options();
if (custom_opt_bytes && custom_opt_bytes->size()) {
// NOLINTNEXTLINE: It is common to use references with flatbuffer.
const flexbuffers::Vector& v =
flexbuffers::GetRoot(custom_opt_bytes->data(),
custom_opt_bytes->size())
.AsVector();
std::string nodedef_str = v[1].AsString().str();
tensorflow::NodeDef nodedef;
if (nodedef_str.empty() || !nodedef.ParseFromString(nodedef_str)) {
LOG(FATAL) << "Failed to parse data into a valid NodeDef";
}
// Flex delegate only supports running flex ops with CPU.
*nodedef.mutable_device() = "/CPU:0";
std::string kernel_class = FindTensorflowKernelClass(&nodedef);
flex_ops->insert({tf_op_name, kernel_class});
}
}
}
}
} // namespace flex
} // namespace tflite

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_LIST_FLEX_OPS_H_
#define TENSORFLOW_LITE_TOOLS_LIST_FLEX_OPS_H_
#include <set>
#include <string>
#include "tensorflow/lite/model.h"
namespace tflite {
namespace flex {
// Store the Op and Kernel name of an op as the key of a set or map.
struct OpKernel {
std::string op_name;
std::string kernel_name;
};
// The comparison function for OpKernel.
struct OpKernelCompare {
bool operator()(const OpKernel& lhs, const OpKernel& rhs) const {
if (lhs.op_name == rhs.op_name) {
return lhs.kernel_name < rhs.kernel_name;
}
return lhs.op_name < rhs.op_name;
}
};
using OpKernelSet = std::set<OpKernel, OpKernelCompare>;
// Find flex ops and its kernel classes inside a TFLite model and add them to
// the map flex_ops. The map stores
void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops);
// Serialize the list op of to a json string. If flex_ops is empty, return an
// empty string.
std::string OpListToJSONString(const OpKernelSet& flex_ops);
} // namespace flex
} // namespace tflite
#endif // TENSORFLOW_LITE_TOOLS_LIST_FLEX_OPS_H_

View File

@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <fstream>
#include <iostream>
#include <sstream>
#include "absl/strings/str_split.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/list_flex_ops.h"
const char kInputModelsFlag[] = "graphs";
int main(int argc, char** argv) {
std::string input_models;
std::vector<tflite::Flag> flag_list = {
tflite::Flag::CreateFlag(kInputModelsFlag, &input_models,
"path to the tflite models, separated by comma.",
tflite::Flag::kRequired),
};
tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
std::vector<std::string> models = absl::StrSplit(input_models, ',');
tflite::flex::OpKernelSet flex_ops;
for (const std::string& model_file : models) {
std::ifstream fin;
fin.exceptions(std::ifstream::failbit | std::ifstream::badbit);
fin.open(model_file);
std::stringstream content;
content << fin.rdbuf();
// Need to store content data first, otherwise, it won't work in bazel.
std::string content_str = content.str();
const ::tflite::Model* model = ::tflite::GetModel(content_str.data());
tflite::flex::AddFlexOpsFromModel(model, &flex_ops);
}
std::cout << tflite::flex::OpListToJSONString(flex_ops);
return 0;
}

View File

@ -0,0 +1,203 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/list_flex_ops.h"
#include <cstdint>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/lite/kernels/test_util.h"
namespace tflite {
namespace flex {
class FlexOpsListTest : public ::testing::Test {
protected:
FlexOpsListTest() {}
void ReadOps(const string& model_path) {
auto model = FlatBufferModel::BuildFromFile(model_path.data());
AddFlexOpsFromModel(model->GetModel(), &flex_ops_);
output_text_ = OpListToJSONString(flex_ops_);
}
void ReadOps(const tflite::Model* model) {
AddFlexOpsFromModel(model, &flex_ops_);
output_text_ = OpListToJSONString(flex_ops_);
}
std::string output_text_;
OpKernelSet flex_ops_;
};
TfLiteRegistration* Register_TEST() {
static TfLiteRegistration r = {nullptr, nullptr, nullptr, nullptr};
return &r;
}
std::vector<uint8_t> CreateFlexCustomOptions(std::string nodedef_raw_string) {
tensorflow::NodeDef node_def;
tensorflow::protobuf::TextFormat::ParseFromString(nodedef_raw_string,
&node_def);
std::string node_def_str = node_def.SerializeAsString();
auto flex_builder = std::make_unique<flexbuffers::Builder>();
flex_builder->Vector([&]() {
flex_builder->String(node_def.op());
flex_builder->String(node_def_str);
});
flex_builder->Finish();
return flex_builder->GetBuffer();
}
class FlexOpModel : public SingleOpModel {
public:
FlexOpModel(const std::string& op_name, const TensorData& input1,
const TensorData& input2, const TensorType& output,
const std::vector<uint8_t>& custom_options) {
input1_ = AddInput(input1);
input2_ = AddInput(input2);
output_ = AddOutput(output);
SetCustomOp(op_name, custom_options, Register_TEST);
BuildInterpreter({GetShape(input1_), GetShape(input2_)});
}
protected:
int input1_;
int input2_;
int output_;
};
TEST_F(FlexOpsListTest, TestModelsNoFlex) {
ReadOps("third_party/tensorflow/lite/testdata/test_model.bin");
EXPECT_EQ(output_text_, "[]");
}
TEST_F(FlexOpsListTest, TestBrokenModel) {
EXPECT_DEATH_IF_SUPPORTED(
ReadOps("third_party/tensorflow/lite/testdata/test_model_broken.bin"),
"");
}
TEST_F(FlexOpsListTest, TestZeroSubgraphs) {
ReadOps("third_party/tensorflow/lite/testdata/0_subgraphs.bin");
EXPECT_EQ(output_text_, "[]");
}
TEST_F(FlexOpsListTest, TestFlexAdd) {
ReadOps("third_party/tensorflow/lite/testdata/multi_add_flex.bin");
EXPECT_EQ(output_text_,
"[[\"Add\", \"BinaryOp<CPUDevice, functor::add<float>>\"]]");
}
TEST_F(FlexOpsListTest, TestTwoModel) {
ReadOps("third_party/tensorflow/lite/testdata/multi_add_flex.bin");
ReadOps("third_party/tensorflow/lite/testdata/softplus_flex.bin");
EXPECT_EQ(output_text_,
"[[\"Add\", \"BinaryOp<CPUDevice, "
"functor::add<float>>\"],\n[\"Softplus\", \"SoftplusOp<CPUDevice, "
"float>\"]]");
}
TEST_F(FlexOpsListTest, TestDuplicatedOp) {
ReadOps("third_party/tensorflow/lite/testdata/multi_add_flex.bin");
ReadOps("third_party/tensorflow/lite/testdata/multi_add_flex.bin");
EXPECT_EQ(output_text_,
"[[\"Add\", \"BinaryOp<CPUDevice, functor::add<float>>\"]]");
}
TEST_F(FlexOpsListTest, TestInvalidCustomOptions) {
// Using a invalid custom options, expected to fail.
std::vector<uint8_t> random_custom_options(20);
FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}},
{TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32,
random_custom_options);
EXPECT_DEATH_IF_SUPPORTED(
ReadOps(tflite::GetModel(max_model.GetModelBuffer())),
"Failed to parse data into a valid NodeDef");
}
TEST_F(FlexOpsListTest, TestOpNameEmpty) {
// NodeDef with empty opname.
std::string nodedef_raw_str =
"name: \"node_1\""
"op: \"\""
"input: [ \"b\", \"c\" ]"
"attr: { key: \"T\" value: { type: DT_FLOAT } }";
std::string random_fieldname = "random string";
FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}},
{TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32,
CreateFlexCustomOptions(nodedef_raw_str));
EXPECT_DEATH_IF_SUPPORTED(
ReadOps(tflite::GetModel(max_model.GetModelBuffer())), "Invalid NodeDef");
}
TEST_F(FlexOpsListTest, TestOpNotFound) {
// NodeDef with invalid opname.
std::string nodedef_raw_str =
"name: \"node_1\""
"op: \"FlexInvalidOp\""
"input: [ \"b\", \"c\" ]"
"attr: { key: \"T\" value: { type: DT_FLOAT } }";
FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}},
{TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32,
CreateFlexCustomOptions(nodedef_raw_str));
EXPECT_DEATH_IF_SUPPORTED(
ReadOps(tflite::GetModel(max_model.GetModelBuffer())),
"Op FlexInvalidOp not found");
}
TEST_F(FlexOpsListTest, TestKernelNotFound) {
// NodeDef with non-supported type.
std::string nodedef_raw_str =
"name: \"node_1\""
"op: \"Add\""
"input: [ \"b\", \"c\" ]"
"attr: { key: \"T\" value: { type: DT_BOOL } }";
FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}},
{TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32,
CreateFlexCustomOptions(nodedef_raw_str));
EXPECT_DEATH_IF_SUPPORTED(
ReadOps(tflite::GetModel(max_model.GetModelBuffer())),
"Failed to find kernel class for op: Add");
}
TEST_F(FlexOpsListTest, TestFlexAddWithSingleOpModel) {
std::string nodedef_raw_str =
"name: \"node_1\""
"op: \"Add\""
"input: [ \"b\", \"c\" ]"
"attr: { key: \"T\" value: { type: DT_FLOAT } }";
FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}},
{TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32,
CreateFlexCustomOptions(nodedef_raw_str));
ReadOps(tflite::GetModel(max_model.GetModelBuffer()));
EXPECT_EQ(output_text_,
"[[\"Add\", \"BinaryOp<CPUDevice, functor::add<float>>\"]]");
}
} // namespace flex
} // namespace tflite
int main(int argc, char** argv) {
// On Linux, add: FLAGS_logtostderr = true;
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -46,8 +46,10 @@ FLAGS = None
def main(unused_argv):
graphs = FLAGS.graphs.split(',')
print(selective_registration_header_lib.get_header(
graphs, FLAGS.proto_fileformat, FLAGS.default_ops))
print(
selective_registration_header_lib.get_header(graphs,
FLAGS.proto_fileformat,
FLAGS.default_ops))
if __name__ == '__main__':
@ -63,7 +65,9 @@ if __name__ == '__main__':
'--proto_fileformat',
type=str,
default='rawproto',
help='Format of proto file, either textproto or rawproto.')
help='Format of proto file, either textproto, rawproto or ops_list. The '
'ops_list is the file contains the list of ops in JSON format. Ex: '
'"[["Add", "BinaryOp<CPUDevice, functor::add<float>>"]]".')
parser.add_argument(
'--default_ops',
type=str,

View File

@ -93,6 +93,12 @@ class PrintOpFilegroupTest(test.TestCase):
fnames.append(fname)
return fnames
def WriteTextFile(self, content):
fname = os.path.join(self.get_temp_dir(), 'text.txt')
with gfile.GFile(fname, 'w') as f:
f.write(content)
return [fname]
def testGetOps(self):
default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
graphs = [
@ -136,6 +142,59 @@ class PrintOpFilegroupTest(test.TestCase):
],
ops_and_kernels)
def testGetOpsFromList(self):
default_ops = ''
# Test with 2 different ops.
ops_list = """[["Add", "BinaryOp<CPUDevice, functor::add<float>>"],
["Softplus", "SoftplusOp<CPUDevice, float>"]]"""
ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
'ops_list', self.WriteTextFile(ops_list), default_ops)
self.assertListEqual([
('Add', 'BinaryOp<CPUDevice, functor::add<float>>'),
('Softplus', 'SoftplusOp<CPUDevice, float>'),
], ops_and_kernels)
# Test with a single op.
ops_list = '[["Softplus", "SoftplusOp<CPUDevice, float>"]]'
ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
'ops_list', self.WriteTextFile(ops_list), default_ops)
self.assertListEqual([
('Softplus', 'SoftplusOp<CPUDevice, float>'),
], ops_and_kernels)
# Test with duplicated op.
ops_list = """[["Add", "BinaryOp<CPUDevice, functor::add<float>>"],
["Add", "BinaryOp<CPUDevice, functor::add<float>>"]]"""
ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
'ops_list', self.WriteTextFile(ops_list), default_ops)
self.assertListEqual([
('Add', 'BinaryOp<CPUDevice, functor::add<float>>'),
], ops_and_kernels)
# Test op with no kernel.
ops_list = '[["Softplus", ""]]'
ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
'ops_list', self.WriteTextFile(ops_list), default_ops)
self.assertListEqual([
('Softplus', None),
], ops_and_kernels)
# Test two ops_list files.
ops_list = '[["Softplus", "SoftplusOp<CPUDevice, float>"]]'
ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
'ops_list',
self.WriteTextFile(ops_list) + self.WriteTextFile(ops_list),
default_ops)
self.assertListEqual([
('Softplus', 'SoftplusOp<CPUDevice, float>'),
], ops_and_kernels)
# Test empty file.
ops_list = ''
with self.assertRaises(Exception):
ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
'ops_list', self.WriteTextFile(ops_list), default_ops)
def testAll(self):
default_ops = 'all'
graphs = [

View File

@ -22,11 +22,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import sys
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python import _pywrap_kernel_registry
from tensorflow.python.platform import gfile
@ -41,6 +41,39 @@ OPS_WITHOUT_KERNEL_WHITELIST = frozenset([
# core/common_runtime/accumulate_n_optimizer.cc.
'AccumulateNV2'
])
FLEX_PREFIX = b'Flex'
FLEX_PREFIX_LENGTH = len(FLEX_PREFIX)
def _get_ops_from_ops_list(input_file):
"""Gets the ops and kernels needed from the ops list file."""
ops = set()
ops_list_str = gfile.GFile(input_file, 'r').read()
if not ops_list_str:
raise Exception('Input file should not be empty')
ops_list = json.loads(ops_list_str)
for op, kernel in ops_list:
op_and_kernel = (op, kernel if kernel else None)
ops.add(op_and_kernel)
return ops
def _get_ops_from_graphdef(graph_def):
"""Gets the ops and kernels needed from the tensorflow model."""
ops = set()
for node_def in graph_def.node:
if not node_def.device:
node_def.device = '/cpu:0'
kernel_class = _pywrap_kernel_registry.TryFindKernelClass(
node_def.SerializeToString())
op = str(node_def.op)
if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST:
op_and_kernel = (op, str(kernel_class.decode('utf-8'))
if kernel_class else None)
ops.add(op_and_kernel)
else:
print('Warning: no kernel found for op %s' % node_def.op, file=sys.stderr)
return ops
def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
@ -49,6 +82,11 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
for proto_file in proto_files:
tf_logging.info('Loading proto file %s', proto_file)
# Load ops list file.
if proto_fileformat == 'ops_list':
ops = ops.union(_get_ops_from_ops_list(proto_file))
continue
# Load GraphDef.
file_data = gfile.GFile(proto_file, 'rb').read()
if proto_fileformat == 'rawproto':
@ -56,22 +94,7 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
else:
assert proto_fileformat == 'textproto'
graph_def = text_format.Parse(file_data, graph_pb2.GraphDef())
# Find all ops and kernels used by the graph.
for node_def in graph_def.node:
if not node_def.device:
node_def.device = '/cpu:0'
kernel_class = _pywrap_kernel_registry.TryFindKernelClass(
node_def.SerializeToString())
op = str(node_def.op)
if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST:
op_and_kernel = (op, str(kernel_class.decode('utf-8'))
if kernel_class else None)
if op_and_kernel not in ops:
ops.add(op_and_kernel)
else:
print(
'Warning: no kernel found for op %s' % node_def.op, file=sys.stderr)
ops = ops.union(_get_ops_from_graphdef(graph_def))
# Add default ops.
if default_ops_str and default_ops_str != 'all':
@ -91,7 +114,7 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
Args:
ops_and_kernels: a set of (op_name, kernel_class_name) pairs to include.
include_all_ops_and_kernels: if True, ops_and_kernels is ignored and all op
kernels are included.
kernels are included.
Returns:
the string of the header that should be written as ops_to_register.h.
@ -112,7 +135,7 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
append('#define SHOULD_REGISTER_OP_KERNEL(clz) true')
append('#define SHOULD_REGISTER_OP_GRADIENT true')
else:
line = '''
line = """
namespace {
constexpr const char* skip(const char* x) {
return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x;
@ -138,10 +161,11 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
}
};
} // end namespace
'''
"""
line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n'
for _, kernel_class in ops_and_kernels:
if kernel_class is None: continue
if kernel_class is None:
continue
line += '"%s",\n' % kernel_class
line += '};'
append(line)
@ -160,8 +184,8 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)')
append('')
append('#define SHOULD_REGISTER_OP_GRADIENT ' + (
'true' if 'SymbolicGradient' in ops else 'false'))
append('#define SHOULD_REGISTER_OP_GRADIENT ' +
('true' if 'SymbolicGradient' in ops else 'false'))
append('#endif')
return '\n'.join(result_list)
@ -174,11 +198,13 @@ def get_header(graphs,
Args:
graphs: a list of paths to GraphDef files to include.
proto_fileformat: optional format of proto file, either 'textproto' or
'rawproto' (default).
proto_fileformat: optional format of proto file, either 'textproto',
'rawproto' (default) or ops_list. The ops_list is the file contain the
list of ops in JSON format, Ex: "[["Transpose", "TransposeCpuOp"]]".
default_ops: optional comma-separated string of operator:kernel pairs to
always include implementation for. Pass 'all' to have all operators and
kernels included. Default: 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'.
Returns:
the string of the header that should be written as ops_to_register.h.
"""