Support TFLite in the tool to print selective registration header
PiperOrigin-RevId: 314282031 Change-Id: Ie71b434c177d03e246a5cfde3d067ac695b71299
This commit is contained in:
		
							parent
							
								
									e2aa757a55
								
							
						
					
					
						commit
						e9781e9b16
					
				| @ -486,6 +486,9 @@ class SingleOpModel { | ||||
|     return std::vector<T>(v, v + tensor_size); | ||||
|   } | ||||
| 
 | ||||
|   // Return the TFLite model buffer, only available after BuildInterpreter.
 | ||||
|   const uint8_t* GetModelBuffer() { return builder_.GetBufferPointer(); } | ||||
| 
 | ||||
|   std::vector<int> GetTensorShape(int index) { | ||||
|     std::vector<int> result; | ||||
|     TfLiteTensor* t = interpreter_->tensor(index); | ||||
|  | ||||
							
								
								
									
										
											BIN
										
									
								
								tensorflow/lite/testdata/softplus_flex.bin
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tensorflow/lite/testdata/softplus_flex.bin
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| @ -134,6 +134,7 @@ cc_binary( | ||||
|     deps = [ | ||||
|         ":command_line_flags", | ||||
|         ":gen_op_registration", | ||||
|         "//tensorflow/lite:util", | ||||
|         "@com_google_absl//absl/strings", | ||||
|     ], | ||||
| ) | ||||
| @ -252,6 +253,60 @@ cc_test( | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "list_flex_ops", | ||||
|     srcs = ["list_flex_ops.cc"], | ||||
|     hdrs = ["list_flex_ops.h"], | ||||
|     deps = [ | ||||
|         "//tensorflow/core:framework", | ||||
|         "//tensorflow/core:lib", | ||||
|         "//tensorflow/core:protos_all_cc", | ||||
|         "//tensorflow/core:tensorflow", | ||||
|         "//tensorflow/lite:framework", | ||||
|         "//tensorflow/lite:util", | ||||
|         "@com_google_absl//absl/strings", | ||||
|         "@flatbuffers", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| # This tool list flex ops and kernels inside a TFLite file. | ||||
| # It is used to generate header file for selective registration. | ||||
| cc_binary( | ||||
|     name = "list_flex_ops_main", | ||||
|     srcs = ["list_flex_ops_main.cc"], | ||||
|     visibility = ["//visibility:public"], | ||||
|     deps = [ | ||||
|         ":list_flex_ops", | ||||
|         "//tensorflow/lite/tools:command_line_flags", | ||||
|         "@com_google_absl//absl/strings", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_test( | ||||
|     name = "list_flex_ops_test", | ||||
|     srcs = ["list_flex_ops_test.cc"], | ||||
|     data = [ | ||||
|         "//tensorflow/lite:testdata/0_subgraphs.bin", | ||||
|         "//tensorflow/lite:testdata/multi_add_flex.bin", | ||||
|         "//tensorflow/lite:testdata/softplus_flex.bin", | ||||
|         "//tensorflow/lite:testdata/test_model.bin", | ||||
|         "//tensorflow/lite:testdata/test_model_broken.bin", | ||||
|     ], | ||||
|     tags = [ | ||||
|         "no_oss",  # Currently requires --config=monolithic, b/118895218. | ||||
|         "tflite_not_portable_android", | ||||
|         "tflite_not_portable_ios", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":list_flex_ops", | ||||
|         "//tensorflow/core:protos_all_cc", | ||||
|         "//tensorflow/core/platform:protobuf", | ||||
|         "//tensorflow/lite/kernels:test_util", | ||||
|         "@com_google_googletest//:gtest", | ||||
|         "@flatbuffers", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_binary( | ||||
|     name = "zip_files", | ||||
|     srcs = ["zip_files.py"], | ||||
|  | ||||
| @ -12,12 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| #include "tensorflow/lite/tools/gen_op_registration.h" | ||||
| 
 | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "re2/re2.h" | ||||
| #include "tensorflow/lite/model.h" | ||||
| #include "tensorflow/lite/tools/gen_op_registration.h" | ||||
| 
 | ||||
| namespace tflite { | ||||
| 
 | ||||
|  | ||||
| @ -23,6 +23,7 @@ limitations under the License. | ||||
| #include "absl/strings/strip.h" | ||||
| #include "tensorflow/lite/tools/command_line_flags.h" | ||||
| #include "tensorflow/lite/tools/gen_op_registration.h" | ||||
| #include "tensorflow/lite/util.h" | ||||
| 
 | ||||
| const char kInputModelFlag[] = "input_models"; | ||||
| const char kNamespace[] = "namespace"; | ||||
| @ -84,6 +85,8 @@ void GenerateFileContent(const std::string& tflite_path, | ||||
|     fout << "namespace custom {\n"; | ||||
|     fout << "// Forward-declarations for the custom ops.\n"; | ||||
|     for (const auto& op : custom_ops) { | ||||
|       // Skips Tensorflow ops, only TFLite custom ops can be registered here.
 | ||||
|       if (tflite::IsFlexOp(op.first.c_str())) continue; | ||||
|       fout << "TfLiteRegistration* Register_" | ||||
|            << ::tflite::NormalizeCustomOpName(op.first) << "();\n"; | ||||
|     } | ||||
| @ -115,6 +118,8 @@ void GenerateFileContent(const std::string& tflite_path, | ||||
|     fout << ");\n"; | ||||
|   } | ||||
|   for (const auto& op : custom_ops) { | ||||
|     // Skips Tensorflow ops, only TFLite custom ops can be registered here.
 | ||||
|     if (tflite::IsFlexOp(op.first.c_str())) continue; | ||||
|     fout << "  resolver->AddCustom(\"" << op.first | ||||
|          << "\", ::tflite::ops::custom::Register_" | ||||
|          << ::tflite::NormalizeCustomOpName(op.first) << "()"; | ||||
|  | ||||
							
								
								
									
										128
									
								
								tensorflow/lite/tools/list_flex_ops.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										128
									
								
								tensorflow/lite/tools/list_flex_ops.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,128 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| #include "tensorflow/lite/tools/list_flex_ops.h" | ||||
| 
 | ||||
| #include <fstream> | ||||
| #include <sstream> | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "absl/strings/str_cat.h" | ||||
| #include "absl/strings/str_join.h" | ||||
| #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
 | ||||
| #include "tensorflow/core/framework/node_def.pb.h" | ||||
| #include "tensorflow/core/framework/node_def_util.h" | ||||
| #include "tensorflow/core/framework/op.h" | ||||
| #include "tensorflow/core/framework/op_kernel.h" | ||||
| #include "tensorflow/core/platform/logging.h" | ||||
| #include "tensorflow/core/util/device_name_utils.h" | ||||
| #include "tensorflow/lite/util.h" | ||||
| 
 | ||||
| namespace tflite { | ||||
| namespace flex { | ||||
| 
 | ||||
| std::string OpListToJSONString(const OpKernelSet& flex_ops) { | ||||
|   return absl::StrCat("[", | ||||
|                       absl::StrJoin(flex_ops, ",\n", | ||||
|                                     [](std::string* out, const OpKernel& op) { | ||||
|                                       absl::StrAppend(out, "[\"", op.op_name, | ||||
|                                                       "\", \"", op.kernel_name, | ||||
|                                                       "\"]"); | ||||
|                                     }), | ||||
|                       "]"); | ||||
| } | ||||
| 
 | ||||
| // Find the class name of the op kernel described in the node_def from the pool
 | ||||
| // of registered ops. If no kernel class is found, return an empty string.
 | ||||
| string FindTensorflowKernelClass(tensorflow::NodeDef* node_def) { | ||||
|   if (!node_def || node_def->op().empty()) { | ||||
|     LOG(FATAL) << "Invalid NodeDef"; | ||||
|   } | ||||
| 
 | ||||
|   const tensorflow::OpRegistrationData* op_reg_data; | ||||
|   auto status = | ||||
|       tensorflow::OpRegistry::Global()->LookUp(node_def->op(), &op_reg_data); | ||||
|   if (!status.ok()) { | ||||
|     LOG(FATAL) << "Op " << node_def->op() << " not found: " << status; | ||||
|   } | ||||
|   AddDefaultsToNodeDef(op_reg_data->op_def, node_def); | ||||
| 
 | ||||
|   tensorflow::DeviceNameUtils::ParsedName parsed_name; | ||||
|   if (!tensorflow::DeviceNameUtils::ParseFullName(node_def->device(), | ||||
|                                                   &parsed_name)) { | ||||
|     LOG(FATAL) << "Failed to parse device from node_def: " | ||||
|                << node_def->ShortDebugString(); | ||||
|   } | ||||
|   string class_name; | ||||
|   if (!tensorflow::FindKernelDef( | ||||
|            tensorflow::DeviceType(parsed_name.type.c_str()), *node_def, | ||||
|            nullptr /* kernel_def */, &class_name) | ||||
|            .ok()) { | ||||
|     LOG(FATAL) << "Failed to find kernel class for op: " << node_def->op(); | ||||
|   } | ||||
|   return class_name; | ||||
| } | ||||
| 
 | ||||
| void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops) { | ||||
|   // Read flex ops.
 | ||||
|   auto* subgraphs = model->subgraphs(); | ||||
|   if (!subgraphs) return; | ||||
|   for (int subgraph_index = 0; subgraph_index < subgraphs->size(); | ||||
|        ++subgraph_index) { | ||||
|     const tflite::SubGraph* subgraph = subgraphs->Get(subgraph_index); | ||||
|     auto* operators = subgraph->operators(); | ||||
|     auto* opcodes = model->operator_codes(); | ||||
|     if (!operators || !opcodes) continue; | ||||
|     for (int i = 0; i < operators->size(); ++i) { | ||||
|       const tflite::Operator* op = operators->Get(i); | ||||
|       const tflite::OperatorCode* opcode = opcodes->Get(op->opcode_index()); | ||||
|       if (opcode->builtin_code() != tflite::BuiltinOperator_CUSTOM || | ||||
|           !tflite::IsFlexOp(opcode->custom_code()->c_str())) { | ||||
|         continue; | ||||
|       } | ||||
| 
 | ||||
|       // Remove the "Flex" prefix from op name.
 | ||||
|       std::string flex_op_name(opcode->custom_code()->c_str()); | ||||
|       std::string tf_op_name = | ||||
|           flex_op_name.substr(strlen(tflite::kFlexCustomCodePrefix)); | ||||
| 
 | ||||
|       // Read NodeDef and find the op kernel class.
 | ||||
|       if (op->custom_options_format() != | ||||
|           tflite::CustomOptionsFormat_FLEXBUFFERS) { | ||||
|         LOG(FATAL) << "Invalid CustomOptionsFormat"; | ||||
|       } | ||||
|       const flatbuffers::Vector<uint8_t>* custom_opt_bytes = | ||||
|           op->custom_options(); | ||||
|       if (custom_opt_bytes && custom_opt_bytes->size()) { | ||||
|         // NOLINTNEXTLINE: It is common to use references with flatbuffer.
 | ||||
|         const flexbuffers::Vector& v = | ||||
|             flexbuffers::GetRoot(custom_opt_bytes->data(), | ||||
|                                  custom_opt_bytes->size()) | ||||
|                 .AsVector(); | ||||
|         std::string nodedef_str = v[1].AsString().str(); | ||||
|         tensorflow::NodeDef nodedef; | ||||
|         if (nodedef_str.empty() || !nodedef.ParseFromString(nodedef_str)) { | ||||
|           LOG(FATAL) << "Failed to parse data into a valid NodeDef"; | ||||
|         } | ||||
|         // Flex delegate only supports running flex ops with CPU.
 | ||||
|         *nodedef.mutable_device() = "/CPU:0"; | ||||
|         std::string kernel_class = FindTensorflowKernelClass(&nodedef); | ||||
|         flex_ops->insert({tf_op_name, kernel_class}); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
| }  // namespace flex
 | ||||
| }  // namespace tflite
 | ||||
							
								
								
									
										55
									
								
								tensorflow/lite/tools/list_flex_ops.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								tensorflow/lite/tools/list_flex_ops.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,55 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| #ifndef TENSORFLOW_LITE_TOOLS_LIST_FLEX_OPS_H_ | ||||
| #define TENSORFLOW_LITE_TOOLS_LIST_FLEX_OPS_H_ | ||||
| 
 | ||||
| #include <set> | ||||
| #include <string> | ||||
| 
 | ||||
| #include "tensorflow/lite/model.h" | ||||
| 
 | ||||
| namespace tflite { | ||||
| namespace flex { | ||||
| 
 | ||||
| // Store the Op and Kernel name of an op as the key of a set or map.
 | ||||
| struct OpKernel { | ||||
|   std::string op_name; | ||||
|   std::string kernel_name; | ||||
| }; | ||||
| 
 | ||||
| // The comparison function for OpKernel.
 | ||||
| struct OpKernelCompare { | ||||
|   bool operator()(const OpKernel& lhs, const OpKernel& rhs) const { | ||||
|     if (lhs.op_name == rhs.op_name) { | ||||
|       return lhs.kernel_name < rhs.kernel_name; | ||||
|     } | ||||
|     return lhs.op_name < rhs.op_name; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| using OpKernelSet = std::set<OpKernel, OpKernelCompare>; | ||||
| 
 | ||||
| // Find flex ops and its kernel classes inside a TFLite model and add them to
 | ||||
| // the map flex_ops. The map stores
 | ||||
| void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops); | ||||
| 
 | ||||
| // Serialize the list op of to a json string. If flex_ops is empty, return an
 | ||||
| // empty string.
 | ||||
| std::string OpListToJSONString(const OpKernelSet& flex_ops); | ||||
| 
 | ||||
| }  // namespace flex
 | ||||
| }  // namespace tflite
 | ||||
| 
 | ||||
| #endif  // TENSORFLOW_LITE_TOOLS_LIST_FLEX_OPS_H_
 | ||||
							
								
								
									
										50
									
								
								tensorflow/lite/tools/list_flex_ops_main.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								tensorflow/lite/tools/list_flex_ops_main.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,50 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| #include <fstream> | ||||
| #include <iostream> | ||||
| #include <sstream> | ||||
| 
 | ||||
| #include "absl/strings/str_split.h" | ||||
| #include "tensorflow/lite/tools/command_line_flags.h" | ||||
| #include "tensorflow/lite/tools/list_flex_ops.h" | ||||
| 
 | ||||
| const char kInputModelsFlag[] = "graphs"; | ||||
| 
 | ||||
| int main(int argc, char** argv) { | ||||
|   std::string input_models; | ||||
|   std::vector<tflite::Flag> flag_list = { | ||||
|       tflite::Flag::CreateFlag(kInputModelsFlag, &input_models, | ||||
|                                "path to the tflite models, separated by comma.", | ||||
|                                tflite::Flag::kRequired), | ||||
|   }; | ||||
|   tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list); | ||||
| 
 | ||||
|   std::vector<std::string> models = absl::StrSplit(input_models, ','); | ||||
|   tflite::flex::OpKernelSet flex_ops; | ||||
|   for (const std::string& model_file : models) { | ||||
|     std::ifstream fin; | ||||
|     fin.exceptions(std::ifstream::failbit | std::ifstream::badbit); | ||||
|     fin.open(model_file); | ||||
|     std::stringstream content; | ||||
|     content << fin.rdbuf(); | ||||
| 
 | ||||
|     // Need to store content data first, otherwise, it won't work in bazel.
 | ||||
|     std::string content_str = content.str(); | ||||
|     const ::tflite::Model* model = ::tflite::GetModel(content_str.data()); | ||||
|     tflite::flex::AddFlexOpsFromModel(model, &flex_ops); | ||||
|   } | ||||
|   std::cout << tflite::flex::OpListToJSONString(flex_ops); | ||||
|   return 0; | ||||
| } | ||||
							
								
								
									
										203
									
								
								tensorflow/lite/tools/list_flex_ops_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										203
									
								
								tensorflow/lite/tools/list_flex_ops_test.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,203 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "tensorflow/lite/tools/list_flex_ops.h" | ||||
| 
 | ||||
| #include <cstdint> | ||||
| 
 | ||||
| #include <gmock/gmock.h> | ||||
| #include <gtest/gtest.h> | ||||
| #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
 | ||||
| #include "tensorflow/core/framework/node_def.pb.h" | ||||
| #include "tensorflow/core/platform/protobuf.h" | ||||
| #include "tensorflow/lite/kernels/test_util.h" | ||||
| 
 | ||||
| namespace tflite { | ||||
| namespace flex { | ||||
| 
 | ||||
| class FlexOpsListTest : public ::testing::Test { | ||||
|  protected: | ||||
|   FlexOpsListTest() {} | ||||
| 
 | ||||
|   void ReadOps(const string& model_path) { | ||||
|     auto model = FlatBufferModel::BuildFromFile(model_path.data()); | ||||
|     AddFlexOpsFromModel(model->GetModel(), &flex_ops_); | ||||
|     output_text_ = OpListToJSONString(flex_ops_); | ||||
|   } | ||||
| 
 | ||||
|   void ReadOps(const tflite::Model* model) { | ||||
|     AddFlexOpsFromModel(model, &flex_ops_); | ||||
|     output_text_ = OpListToJSONString(flex_ops_); | ||||
|   } | ||||
| 
 | ||||
|   std::string output_text_; | ||||
|   OpKernelSet flex_ops_; | ||||
| }; | ||||
| 
 | ||||
| TfLiteRegistration* Register_TEST() { | ||||
|   static TfLiteRegistration r = {nullptr, nullptr, nullptr, nullptr}; | ||||
|   return &r; | ||||
| } | ||||
| 
 | ||||
| std::vector<uint8_t> CreateFlexCustomOptions(std::string nodedef_raw_string) { | ||||
|   tensorflow::NodeDef node_def; | ||||
|   tensorflow::protobuf::TextFormat::ParseFromString(nodedef_raw_string, | ||||
|                                                     &node_def); | ||||
|   std::string node_def_str = node_def.SerializeAsString(); | ||||
|   auto flex_builder = std::make_unique<flexbuffers::Builder>(); | ||||
|   flex_builder->Vector([&]() { | ||||
|     flex_builder->String(node_def.op()); | ||||
|     flex_builder->String(node_def_str); | ||||
|   }); | ||||
|   flex_builder->Finish(); | ||||
|   return flex_builder->GetBuffer(); | ||||
| } | ||||
| 
 | ||||
| class FlexOpModel : public SingleOpModel { | ||||
|  public: | ||||
|   FlexOpModel(const std::string& op_name, const TensorData& input1, | ||||
|               const TensorData& input2, const TensorType& output, | ||||
|               const std::vector<uint8_t>& custom_options) { | ||||
|     input1_ = AddInput(input1); | ||||
|     input2_ = AddInput(input2); | ||||
|     output_ = AddOutput(output); | ||||
|     SetCustomOp(op_name, custom_options, Register_TEST); | ||||
|     BuildInterpreter({GetShape(input1_), GetShape(input2_)}); | ||||
|   } | ||||
| 
 | ||||
|  protected: | ||||
|   int input1_; | ||||
|   int input2_; | ||||
|   int output_; | ||||
| }; | ||||
| 
 | ||||
| TEST_F(FlexOpsListTest, TestModelsNoFlex) { | ||||
|   ReadOps("third_party/tensorflow/lite/testdata/test_model.bin"); | ||||
|   EXPECT_EQ(output_text_, "[]"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(FlexOpsListTest, TestBrokenModel) { | ||||
|   EXPECT_DEATH_IF_SUPPORTED( | ||||
|       ReadOps("third_party/tensorflow/lite/testdata/test_model_broken.bin"), | ||||
|       ""); | ||||
| } | ||||
| 
 | ||||
| TEST_F(FlexOpsListTest, TestZeroSubgraphs) { | ||||
|   ReadOps("third_party/tensorflow/lite/testdata/0_subgraphs.bin"); | ||||
|   EXPECT_EQ(output_text_, "[]"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(FlexOpsListTest, TestFlexAdd) { | ||||
|   ReadOps("third_party/tensorflow/lite/testdata/multi_add_flex.bin"); | ||||
|   EXPECT_EQ(output_text_, | ||||
|             "[[\"Add\", \"BinaryOp<CPUDevice, functor::add<float>>\"]]"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(FlexOpsListTest, TestTwoModel) { | ||||
|   ReadOps("third_party/tensorflow/lite/testdata/multi_add_flex.bin"); | ||||
|   ReadOps("third_party/tensorflow/lite/testdata/softplus_flex.bin"); | ||||
|   EXPECT_EQ(output_text_, | ||||
|             "[[\"Add\", \"BinaryOp<CPUDevice, " | ||||
|             "functor::add<float>>\"],\n[\"Softplus\", \"SoftplusOp<CPUDevice, " | ||||
|             "float>\"]]"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(FlexOpsListTest, TestDuplicatedOp) { | ||||
|   ReadOps("third_party/tensorflow/lite/testdata/multi_add_flex.bin"); | ||||
|   ReadOps("third_party/tensorflow/lite/testdata/multi_add_flex.bin"); | ||||
|   EXPECT_EQ(output_text_, | ||||
|             "[[\"Add\", \"BinaryOp<CPUDevice, functor::add<float>>\"]]"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(FlexOpsListTest, TestInvalidCustomOptions) { | ||||
|   // Using a invalid custom options, expected to fail.
 | ||||
|   std::vector<uint8_t> random_custom_options(20); | ||||
|   FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}}, | ||||
|                         {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32, | ||||
|                         random_custom_options); | ||||
|   EXPECT_DEATH_IF_SUPPORTED( | ||||
|       ReadOps(tflite::GetModel(max_model.GetModelBuffer())), | ||||
|       "Failed to parse data into a valid NodeDef"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(FlexOpsListTest, TestOpNameEmpty) { | ||||
|   // NodeDef with empty opname.
 | ||||
|   std::string nodedef_raw_str = | ||||
|       "name: \"node_1\"" | ||||
|       "op: \"\"" | ||||
|       "input: [ \"b\", \"c\" ]" | ||||
|       "attr: { key: \"T\" value: { type: DT_FLOAT } }"; | ||||
|   std::string random_fieldname = "random string"; | ||||
|   FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}}, | ||||
|                         {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32, | ||||
|                         CreateFlexCustomOptions(nodedef_raw_str)); | ||||
|   EXPECT_DEATH_IF_SUPPORTED( | ||||
|       ReadOps(tflite::GetModel(max_model.GetModelBuffer())), "Invalid NodeDef"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(FlexOpsListTest, TestOpNotFound) { | ||||
|   // NodeDef with invalid opname.
 | ||||
|   std::string nodedef_raw_str = | ||||
|       "name: \"node_1\"" | ||||
|       "op: \"FlexInvalidOp\"" | ||||
|       "input: [ \"b\", \"c\" ]" | ||||
|       "attr: { key: \"T\" value: { type: DT_FLOAT } }"; | ||||
| 
 | ||||
|   FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}}, | ||||
|                         {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32, | ||||
|                         CreateFlexCustomOptions(nodedef_raw_str)); | ||||
|   EXPECT_DEATH_IF_SUPPORTED( | ||||
|       ReadOps(tflite::GetModel(max_model.GetModelBuffer())), | ||||
|       "Op FlexInvalidOp not found"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(FlexOpsListTest, TestKernelNotFound) { | ||||
|   // NodeDef with non-supported type.
 | ||||
|   std::string nodedef_raw_str = | ||||
|       "name: \"node_1\"" | ||||
|       "op: \"Add\"" | ||||
|       "input: [ \"b\", \"c\" ]" | ||||
|       "attr: { key: \"T\" value: { type: DT_BOOL } }"; | ||||
| 
 | ||||
|   FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}}, | ||||
|                         {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32, | ||||
|                         CreateFlexCustomOptions(nodedef_raw_str)); | ||||
|   EXPECT_DEATH_IF_SUPPORTED( | ||||
|       ReadOps(tflite::GetModel(max_model.GetModelBuffer())), | ||||
|       "Failed to find kernel class for op: Add"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(FlexOpsListTest, TestFlexAddWithSingleOpModel) { | ||||
|   std::string nodedef_raw_str = | ||||
|       "name: \"node_1\"" | ||||
|       "op: \"Add\"" | ||||
|       "input: [ \"b\", \"c\" ]" | ||||
|       "attr: { key: \"T\" value: { type: DT_FLOAT } }"; | ||||
| 
 | ||||
|   FlexOpModel max_model("FlexAdd", {TensorType_FLOAT32, {3, 1, 2, 2}}, | ||||
|                         {TensorType_FLOAT32, {3, 1, 2, 1}}, TensorType_FLOAT32, | ||||
|                         CreateFlexCustomOptions(nodedef_raw_str)); | ||||
|   ReadOps(tflite::GetModel(max_model.GetModelBuffer())); | ||||
|   EXPECT_EQ(output_text_, | ||||
|             "[[\"Add\", \"BinaryOp<CPUDevice, functor::add<float>>\"]]"); | ||||
| } | ||||
| }  // namespace flex
 | ||||
| }  // namespace tflite
 | ||||
| 
 | ||||
| int main(int argc, char** argv) { | ||||
|   // On Linux, add: FLAGS_logtostderr = true;
 | ||||
|   ::testing::InitGoogleTest(&argc, argv); | ||||
|   return RUN_ALL_TESTS(); | ||||
| } | ||||
| @ -46,8 +46,10 @@ FLAGS = None | ||||
| 
 | ||||
| def main(unused_argv): | ||||
|   graphs = FLAGS.graphs.split(',') | ||||
|   print(selective_registration_header_lib.get_header( | ||||
|       graphs, FLAGS.proto_fileformat, FLAGS.default_ops)) | ||||
|   print( | ||||
|       selective_registration_header_lib.get_header(graphs, | ||||
|                                                    FLAGS.proto_fileformat, | ||||
|                                                    FLAGS.default_ops)) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| @ -63,7 +65,9 @@ if __name__ == '__main__': | ||||
|       '--proto_fileformat', | ||||
|       type=str, | ||||
|       default='rawproto', | ||||
|       help='Format of proto file, either textproto or rawproto.') | ||||
|       help='Format of proto file, either textproto, rawproto or ops_list. The ' | ||||
|       'ops_list is the file contains the list of ops in JSON format. Ex: ' | ||||
|       '"[["Add", "BinaryOp<CPUDevice, functor::add<float>>"]]".') | ||||
|   parser.add_argument( | ||||
|       '--default_ops', | ||||
|       type=str, | ||||
|  | ||||
| @ -93,6 +93,12 @@ class PrintOpFilegroupTest(test.TestCase): | ||||
|       fnames.append(fname) | ||||
|     return fnames | ||||
| 
 | ||||
|   def WriteTextFile(self, content): | ||||
|     fname = os.path.join(self.get_temp_dir(), 'text.txt') | ||||
|     with gfile.GFile(fname, 'w') as f: | ||||
|       f.write(content) | ||||
|     return [fname] | ||||
| 
 | ||||
|   def testGetOps(self): | ||||
|     default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp' | ||||
|     graphs = [ | ||||
| @ -136,6 +142,59 @@ class PrintOpFilegroupTest(test.TestCase): | ||||
|         ], | ||||
|         ops_and_kernels) | ||||
| 
 | ||||
|   def testGetOpsFromList(self): | ||||
|     default_ops = '' | ||||
|     # Test with 2 different ops. | ||||
|     ops_list = """[["Add", "BinaryOp<CPUDevice, functor::add<float>>"], | ||||
|         ["Softplus", "SoftplusOp<CPUDevice, float>"]]""" | ||||
|     ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels( | ||||
|         'ops_list', self.WriteTextFile(ops_list), default_ops) | ||||
|     self.assertListEqual([ | ||||
|         ('Add', 'BinaryOp<CPUDevice, functor::add<float>>'), | ||||
|         ('Softplus', 'SoftplusOp<CPUDevice, float>'), | ||||
|     ], ops_and_kernels) | ||||
| 
 | ||||
|     # Test with a single op. | ||||
|     ops_list = '[["Softplus", "SoftplusOp<CPUDevice, float>"]]' | ||||
|     ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels( | ||||
|         'ops_list', self.WriteTextFile(ops_list), default_ops) | ||||
|     self.assertListEqual([ | ||||
|         ('Softplus', 'SoftplusOp<CPUDevice, float>'), | ||||
|     ], ops_and_kernels) | ||||
| 
 | ||||
|     # Test with duplicated op. | ||||
|     ops_list = """[["Add", "BinaryOp<CPUDevice, functor::add<float>>"], | ||||
|         ["Add", "BinaryOp<CPUDevice, functor::add<float>>"]]""" | ||||
|     ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels( | ||||
|         'ops_list', self.WriteTextFile(ops_list), default_ops) | ||||
|     self.assertListEqual([ | ||||
|         ('Add', 'BinaryOp<CPUDevice, functor::add<float>>'), | ||||
|     ], ops_and_kernels) | ||||
| 
 | ||||
|     # Test op with no kernel. | ||||
|     ops_list = '[["Softplus", ""]]' | ||||
|     ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels( | ||||
|         'ops_list', self.WriteTextFile(ops_list), default_ops) | ||||
|     self.assertListEqual([ | ||||
|         ('Softplus', None), | ||||
|     ], ops_and_kernels) | ||||
| 
 | ||||
|     # Test two ops_list files. | ||||
|     ops_list = '[["Softplus", "SoftplusOp<CPUDevice, float>"]]' | ||||
|     ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels( | ||||
|         'ops_list', | ||||
|         self.WriteTextFile(ops_list) + self.WriteTextFile(ops_list), | ||||
|         default_ops) | ||||
|     self.assertListEqual([ | ||||
|         ('Softplus', 'SoftplusOp<CPUDevice, float>'), | ||||
|     ], ops_and_kernels) | ||||
| 
 | ||||
|     # Test empty file. | ||||
|     ops_list = '' | ||||
|     with self.assertRaises(Exception): | ||||
|       ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels( | ||||
|           'ops_list', self.WriteTextFile(ops_list), default_ops) | ||||
| 
 | ||||
|   def testAll(self): | ||||
|     default_ops = 'all' | ||||
|     graphs = [ | ||||
|  | ||||
| @ -22,11 +22,11 @@ from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import json | ||||
| import os | ||||
| import sys | ||||
| 
 | ||||
| from google.protobuf import text_format | ||||
| 
 | ||||
| from tensorflow.core.framework import graph_pb2 | ||||
| from tensorflow.python import _pywrap_kernel_registry | ||||
| from tensorflow.python.platform import gfile | ||||
| @ -41,6 +41,39 @@ OPS_WITHOUT_KERNEL_WHITELIST = frozenset([ | ||||
|     # core/common_runtime/accumulate_n_optimizer.cc. | ||||
|     'AccumulateNV2' | ||||
| ]) | ||||
| FLEX_PREFIX = b'Flex' | ||||
| FLEX_PREFIX_LENGTH = len(FLEX_PREFIX) | ||||
| 
 | ||||
| 
 | ||||
| def _get_ops_from_ops_list(input_file): | ||||
|   """Gets the ops and kernels needed from the ops list file.""" | ||||
|   ops = set() | ||||
|   ops_list_str = gfile.GFile(input_file, 'r').read() | ||||
|   if not ops_list_str: | ||||
|     raise Exception('Input file should not be empty') | ||||
|   ops_list = json.loads(ops_list_str) | ||||
|   for op, kernel in ops_list: | ||||
|     op_and_kernel = (op, kernel if kernel else None) | ||||
|     ops.add(op_and_kernel) | ||||
|   return ops | ||||
| 
 | ||||
| 
 | ||||
| def _get_ops_from_graphdef(graph_def): | ||||
|   """Gets the ops and kernels needed from the tensorflow model.""" | ||||
|   ops = set() | ||||
|   for node_def in graph_def.node: | ||||
|     if not node_def.device: | ||||
|       node_def.device = '/cpu:0' | ||||
|     kernel_class = _pywrap_kernel_registry.TryFindKernelClass( | ||||
|         node_def.SerializeToString()) | ||||
|     op = str(node_def.op) | ||||
|     if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST: | ||||
|       op_and_kernel = (op, str(kernel_class.decode('utf-8')) | ||||
|                        if kernel_class else None) | ||||
|       ops.add(op_and_kernel) | ||||
|     else: | ||||
|       print('Warning: no kernel found for op %s' % node_def.op, file=sys.stderr) | ||||
|   return ops | ||||
| 
 | ||||
| 
 | ||||
| def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str): | ||||
| @ -49,6 +82,11 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str): | ||||
| 
 | ||||
|   for proto_file in proto_files: | ||||
|     tf_logging.info('Loading proto file %s', proto_file) | ||||
|     # Load ops list file. | ||||
|     if proto_fileformat == 'ops_list': | ||||
|       ops = ops.union(_get_ops_from_ops_list(proto_file)) | ||||
|       continue | ||||
| 
 | ||||
|     # Load GraphDef. | ||||
|     file_data = gfile.GFile(proto_file, 'rb').read() | ||||
|     if proto_fileformat == 'rawproto': | ||||
| @ -56,22 +94,7 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str): | ||||
|     else: | ||||
|       assert proto_fileformat == 'textproto' | ||||
|       graph_def = text_format.Parse(file_data, graph_pb2.GraphDef()) | ||||
| 
 | ||||
|     # Find all ops and kernels used by the graph. | ||||
|     for node_def in graph_def.node: | ||||
|       if not node_def.device: | ||||
|         node_def.device = '/cpu:0' | ||||
|       kernel_class = _pywrap_kernel_registry.TryFindKernelClass( | ||||
|           node_def.SerializeToString()) | ||||
|       op = str(node_def.op) | ||||
|       if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST: | ||||
|         op_and_kernel = (op, str(kernel_class.decode('utf-8')) | ||||
|                          if kernel_class else None) | ||||
|         if op_and_kernel not in ops: | ||||
|           ops.add(op_and_kernel) | ||||
|       else: | ||||
|         print( | ||||
|             'Warning: no kernel found for op %s' % node_def.op, file=sys.stderr) | ||||
|     ops = ops.union(_get_ops_from_graphdef(graph_def)) | ||||
| 
 | ||||
|   # Add default ops. | ||||
|   if default_ops_str and default_ops_str != 'all': | ||||
| @ -91,7 +114,7 @@ def get_header_from_ops_and_kernels(ops_and_kernels, | ||||
|   Args: | ||||
|     ops_and_kernels: a set of (op_name, kernel_class_name) pairs to include. | ||||
|     include_all_ops_and_kernels: if True, ops_and_kernels is ignored and all op | ||||
|     kernels are included. | ||||
|       kernels are included. | ||||
| 
 | ||||
|   Returns: | ||||
|     the string of the header that should be written as ops_to_register.h. | ||||
| @ -112,7 +135,7 @@ def get_header_from_ops_and_kernels(ops_and_kernels, | ||||
|     append('#define SHOULD_REGISTER_OP_KERNEL(clz) true') | ||||
|     append('#define SHOULD_REGISTER_OP_GRADIENT true') | ||||
|   else: | ||||
|     line = ''' | ||||
|     line = """ | ||||
|     namespace { | ||||
|       constexpr const char* skip(const char* x) { | ||||
|         return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x; | ||||
| @ -138,10 +161,11 @@ def get_header_from_ops_and_kernels(ops_and_kernels, | ||||
|         } | ||||
|       }; | ||||
|     }  // end namespace | ||||
|     ''' | ||||
|     """ | ||||
|     line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n' | ||||
|     for _, kernel_class in ops_and_kernels: | ||||
|       if kernel_class is None: continue | ||||
|       if kernel_class is None: | ||||
|         continue | ||||
|       line += '"%s",\n' % kernel_class | ||||
|     line += '};' | ||||
|     append(line) | ||||
| @ -160,8 +184,8 @@ def get_header_from_ops_and_kernels(ops_and_kernels, | ||||
|     append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)') | ||||
|     append('') | ||||
| 
 | ||||
|     append('#define SHOULD_REGISTER_OP_GRADIENT ' + ( | ||||
|         'true' if 'SymbolicGradient' in ops else 'false')) | ||||
|     append('#define SHOULD_REGISTER_OP_GRADIENT ' + | ||||
|            ('true' if 'SymbolicGradient' in ops else 'false')) | ||||
| 
 | ||||
|   append('#endif') | ||||
|   return '\n'.join(result_list) | ||||
| @ -174,11 +198,13 @@ def get_header(graphs, | ||||
| 
 | ||||
|   Args: | ||||
|     graphs: a list of paths to GraphDef files to include. | ||||
|     proto_fileformat: optional format of proto file, either 'textproto' or | ||||
|       'rawproto' (default). | ||||
|     proto_fileformat: optional format of proto file, either 'textproto', | ||||
|       'rawproto' (default) or ops_list. The ops_list is the file contain the | ||||
|       list of ops in JSON format, Ex: "[["Transpose", "TransposeCpuOp"]]". | ||||
|     default_ops: optional comma-separated string of operator:kernel pairs to | ||||
|       always include implementation for. Pass 'all' to have all operators and | ||||
|       kernels included. Default: 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'. | ||||
| 
 | ||||
|   Returns: | ||||
|     the string of the header that should be written as ops_to_register.h. | ||||
|   """ | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user