diff --git a/tensorflow/contrib/lite/experimental/writer/BUILD b/tensorflow/contrib/lite/experimental/writer/BUILD new file mode 100644 index 00000000000..d43964208bc --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/BUILD @@ -0,0 +1,64 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +cc_binary( + name = "option_writer_generator", + srcs = ["option_writer_generator.cc"], + deps = [ + "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection", + "@flatbuffers", + ], +) + +cc_library( + name = "writer_lib", + srcs = [ + "enum_mapping.h", + "writer_lib.cc", + ], + hdrs = [ + "writer_lib.h", + ], + textual_hdrs = ["option_writer_generated.h"], + deps = [ + ":option_writer_gen", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection", + ], +) + +cc_binary( + name = "writer", + srcs = ["writer.cc"], + deps = [ + ":writer_lib", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) + +cc_test( + name = "writer_lib_test", + size = "small", + srcs = ["writer_lib_test.cc"], + deps = [ + ":writer_lib", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/testing:util", + "//testing/base/public:gunit", + ], +) + +genrule( + name = "option_writer_gen", + outs = ["option_writer_generated.h"], + cmd = "$(location :option_writer_generator) $(@)", + tools = [":option_writer_generator"], +) diff --git a/tensorflow/contrib/lite/experimental/writer/enum_mapping.h b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h new file mode 100644 index 00000000000..8bc464fd718 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h @@ -0,0 +1,116 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" + +// TODO(aselle): Ideally extract this from the schema. + +namespace tflite { + +inline ActivationFunctionType TfLiteActivationToSchemaActivation( + TfLiteFusedActivation act) { + switch (act) { + case kTfLiteActNone: + return ActivationFunctionType_NONE; + case kTfLiteActRelu: + return ActivationFunctionType_RELU; + case kTfLiteActRelu1: + return ActivationFunctionType_RELU_N1_TO_1; + case kTfLiteActRelu6: + return ActivationFunctionType_RELU6; + case kTfLiteActTanh: + return ActivationFunctionType_TANH; + case kTfLiteActSignBit: + return ActivationFunctionType_SIGN_BIT; + case kTfLiteActSigmoid: + return ActivationFunctionType_NONE; // TODO(aselle): Add to schema + } + return ActivationFunctionType_NONE; +} + +inline Padding TfLitePaddingToSchemaPadding(TfLitePadding padding) { + switch (padding) { + case kTfLitePaddingUnknown: + return Padding_SAME; // TODO(aselle): Consider an error. + case kTfLitePaddingSame: + return Padding_SAME; + case kTfLitePaddingValid: + return Padding_VALID; + } + return Padding_SAME; // TODO(aselle): Consider an error. +} + +inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { + switch (type) { + // case kTfLiteNoType: return TensorType_NONE; + case kTfLiteNoType: + return TensorType_FLOAT32; // TODO(aselle): Consider an error. + case kTfLiteFloat32: + return TensorType_FLOAT32; + case kTfLiteInt32: + return TensorType_INT32; + case kTfLiteUInt8: + return TensorType_UINT8; + case kTfLiteInt64: + return TensorType_INT64; + case kTfLiteString: + return TensorType_STRING; + case kTfLiteBool: + return TensorType_BOOL; + case kTfLiteInt16: + return TensorType_INT16; + case kTfLiteComplex64: + return TensorType_COMPLEX64; + } + // TODO(aselle): consider an error +} + +inline FullyConnectedOptionsWeightsFormat +FullyConnectedOptionsWeightsFormatToSchema( + TfLiteFullyConnectedWeightsFormat format) { + switch (format) { + case kTfLiteFullyConnectedWeightsFormatDefault: + return FullyConnectedOptionsWeightsFormat_DEFAULT; + case kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8: + return FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8; + } +} + +inline LSTMKernelType LSTMKernelTypeToSchema(TfLiteLSTMKernelType type) { + switch (type) { + case kTfLiteLSTMFullKernel: + return LSTMKernelType_FULL; + case kTfLiteLSTMBasicKernel: + return LSTMKernelType_BASIC; + } +} + +inline LSHProjectionType LSHProjectionTypeToSchema( + TfLiteLSHProjectionType type) { + switch (type) { + case kTfLiteLshProjectionUnknown: + return LSHProjectionType_UNKNOWN; + case kTfLiteLshProjectionSparse: + return LSHProjectionType_SPARSE; + case kTfLiteLshProjectionDense: + return LSHProjectionType_DENSE; + } +} + +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc new file mode 100644 index 00000000000..e6d5a776b32 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc @@ -0,0 +1,370 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <ctype.h> +#include <iostream> +#include <unordered_map> +#include <unordered_set> +#include "flatbuffers/minireflect.h" // flatbuffers +#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" + +namespace tflite { +namespace { +// This is generated by grepping +// cat third_party/tensorflow/contrib/lite/builtin_op_data.h +//| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}" +static const char* param_structs[] = {"TfLiteConvParams", + "TfLitePoolParams", + "TfLiteDepthwiseConvParams", + "TfLiteSVDFParams", + "TfLiteRNNParams", + "TfLiteSequenceRNNParams", + "TfLiteFullyConnectedParams", + "TfLiteLSHProjectionParams", + "TfLiteSoftmaxParams", + "TfLiteConcatenationParams", + "TfLiteAddParams", + "TfLiteSpaceToBatchNDParams", + "TfLiteBatchToSpaceNDParams", + "TfLiteMulParams", + "TfLiteSubParams", + "TfLiteDivParams", + "TfLiteL2NormParams", + "TfLiteLocalResponseNormParams", + "TfLiteLSTMParams", + "TfLiteResizeBilinearParams", + "TfLitePadParams", + "TfLitePadV2Params", + "TfLiteReshapeParams", + "TfLiteSkipGramParams", + "TfLiteSpaceToDepthParams", + "TfLiteCastParams", + "TfLiteEmbeddingLookupSparseParams", + "TfLiteGatherParams", + "TfLiteTransposeParams", + "TfLiteReducerParams", + "TfLiteSplitParams", + "TfLiteSqueezeParams", + "TfLiteStridedSliceParams", + "TfLiteArgMaxParams", + "TfLiteArgMinParams", + "TfLiteTransposeConvParams", + "TfLiteSparseToDenseParams", + "TfLiteShapeParams", + "TfLiteFakeQuantParams", + "TfLitePackParams", + "TfLiteOneHotParams", + nullptr}; +} // namespace + +// Get rid of all underscores and make everything lower case to make name +// matching work for stuff like 3D vs 3d or RNN vs Rnn. +std::string ToCollapsed(const std::string& in) { + const char* s = in.c_str(); + bool first = true; + std::string out; + while (*s != '\0') { + if (*s == '_') { + first = true; + } else if (first) { + out.push_back(tolower(*s)); + first = false; + } else { + out.push_back(tolower(*s)); + } + s++; + } + return out; +} + +// A collection of information about builtin ops. +class OpOptionData { + public: + OpOptionData() { + BuildOpList(); + BuildOptionToTypeFunctionMap(); + BuildOpToOptionMap(); + } + + // A list of builtin operations + const std::vector<std::string>& ops() const { return ops_; } + // Maps from operation name to option name (i.e. 'ADD' to 'AddOptions') + const std::unordered_map<std::string, std::string>& op_to_option() { + return op_to_option_; + } + // Maps from option to to C struct i.e. 'AddOptions' -> 'TfLiteAddOptions' + const std::unordered_map<std::string, std::string>& option_to_struct() { + return option_to_struct_; + } + // Maps from option to a flatbuffer type function that describes that option. + const std::unordered_map<std::string, flatbuffers::TypeFunction>& + option_to_type_function() { + return option_to_type_function_; + } + + private: + void BuildOpList() { + for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr; + ++curr) { + if (strlen(*curr) != 0) ops_.push_back(*curr); + } + } + + void BuildOptionToTypeFunctionMap() { + auto d = tflite::BuiltinOptionsTypeTable(); + for (int i = 0; i < d->num_elems; i++) { + flatbuffers::TypeCode code = d->type_codes[i]; + if (code.sequence_ref != -1) { + option_to_type_function_.insert( + std::make_pair(d->names[i], d->type_refs[code.sequence_ref])); + } + } + } + + void BuildOpToOptionMap() { + // Manually specified mappings between ops and options + op_to_option_["REDUCE_MAX"] = "ReducerOptions"; + op_to_option_["REDUCE_MIN"] = "ReducerOptions"; + op_to_option_["REDUCE_ANY"] = "ReducerOptions"; + op_to_option_["UNPACK"] = ""; + op_to_option_["SUM"] = "ReducerOptions"; + op_to_option_["REDUCE_MAX"] = "ReducerOptions"; + op_to_option_["REDUCE_PROD"] = "ReducerOptions"; + op_to_option_["MEAN"] = "ReducerOptions"; + op_to_option_["L2_POOL_2D"] = "Pool2DOptions"; + op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions"; + op_to_option_["MAX_POOL_2D"] = "Pool2DOptions"; + op_to_option_["L2_NORMALIZATION"] = "L2NormOptions"; + op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions"; + op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions"; + op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; + op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; + op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; + // Manually specified mappings between ops and options (none) + op_to_option_["EMBEDDING_LOOKUP"] = + ""; // TODO(aselle): maybe something else. + op_to_option_["FLOOR"] = ""; + op_to_option_["HASHTABLE_LOOKUP"] = + ""; // TODO(aselle): maybe something else. + op_to_option_["LOGISTIC"] = ""; + op_to_option_["RELU"] = ""; + op_to_option_["RELU_N1_TO_1"] = ""; + op_to_option_["RELU6"] = ""; + op_to_option_["TANH"] = ""; + op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else. + op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else. + op_to_option_["PRELU"] = ""; + op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions + op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions + op_to_option_["SIN"] = ""; + op_to_option_["LOG"] = ""; + op_to_option_["SQRT"] = ""; + op_to_option_["RSQRT"] = ""; + + // TODO(aselle): These are undesirable hacks. Consider changing C structs + option_to_struct_["Pool2DOptions"] = "TfLitePoolParams"; + option_to_struct_["Conv2DOptions"] = "TfLiteConvParams"; + option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams"; + option_to_struct_["LocalResponseNormalizationOptions"] = + "TfLiteLocalResponseNormParams"; + // Now for every op, try to find an option. + bool fatal = false; + for (auto op_name : ops_) { + bool found_option = false; + auto d = tflite::BuiltinOptionsTypeTable(); + std::string collapsed_option_name_guess = + ToCollapsed(op_name) + "options"; + // O(n^2) but not that big of n. + for (int i = 0; i < d->num_elems; i++) { + std::string option_name = d->names[i]; + std::string collapsed_option_name = ToCollapsed(option_name); + if (collapsed_option_name_guess == collapsed_option_name) { + op_to_option_.insert(std::make_pair(op_name, option_name)); + found_option = true; + break; + } + } + auto it = op_to_option_.find(op_name); + if (it == op_to_option_.end()) { + std::cerr << "Didn't find option for " << op_name << std::endl; + fatal = true; + } else if (!it->second.empty()) { + std::string option_name = it->second; + + if (option_to_struct_.find(option_name) == option_to_struct_.end()) { + bool param_struct_found = false; + std::string params_guess = std::string("TfLite") + option_name; + size_t start = params_guess.find("Options"); + size_t len = strlen("Options"); + params_guess.replace(start, len, "Params"); + for (auto* param = param_structs; *param != nullptr; param++) { + if (*param == params_guess) { + param_struct_found = true; + break; + } + } + if (!param_struct_found) { + std::cerr << "Failed to get param struct for option " << option_name + << std::endl; + fatal = true; + } else { + option_to_struct_.insert(std::make_pair(option_name, params_guess)); + } + } + } + } + } + + private: + std::vector<std::string> ops_; + std::unordered_map<std::string, std::string> op_to_option_; + std::unordered_map<std::string, std::string> option_to_struct_; + std::unordered_map<std::string, flatbuffers::TypeFunction> + option_to_type_function_; +}; + +void GenerateImportForOp(FILE* fp, const std::string& op_name, + const std::string& option_name, + const std::string& option_type, + const flatbuffers::TypeTable* options, + const std::string& struct_name) { + // Skip tricky ones for now + if (struct_name == "TfLiteResizeBilinearParams") return; + if (struct_name == "TfLiteSqueezeParams") return; + if (struct_name == "TfLiteEmbeddingLookupSparseParams") return; + if (struct_name == "TfLiteReshapeParams") return; + + fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str()); + fprintf(fp, + " const auto* params = reinterpret_cast<const " + "%s*>(builtin_op_data);\n", + struct_name.c_str()); + + for (size_t i = 0; i < options->num_elems; i++) { + std::string elem_name = options->names[i]; + // TODO(aselle): Irregular naming in builtins + if (elem_name == "fused_activation_function") + elem_name = "activation"; + else if (elem_name == "stride_w") + elem_name = "stride_width"; + else if (elem_name == "stride_h") + elem_name = "stride_height"; + else if (elem_name == "dilation_h_factor") + elem_name = "dilation_height_factor"; + else if (elem_name == "dilation_w_factor") + elem_name = "dilation_width_factor"; + else if (elem_name == "new_shape") + elem_name = "shape"; + + flatbuffers::TypeCode code = options->type_codes[i]; + auto contained_type = code.sequence_ref != -1 + ? options->type_refs[code.sequence_ref] + : nullptr; + std::string mapper = ""; + if (contained_type == TensorTypeTypeTable) { + mapper = "TfLiteTypeToSchemaType"; + } else if (contained_type == ActivationFunctionTypeTypeTable) { + mapper = "TfLiteActivationToSchemaActivation"; + } else if (contained_type == PaddingTypeTable) { + mapper = "TfLitePaddingToSchemaPadding"; + } else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) { + mapper = "FullyConnectedOptionsWeightsFormatToSchema"; + } else if (contained_type == LSTMKernelTypeTypeTable) { + mapper = "LSTMKernelTypeToSchema"; + } else if (contained_type == LSHProjectionTypeTypeTable) { + mapper = "LSHProjectionTypeToSchema"; + } + + fprintf(fp, + " auto val%zu = " + "%s(params->%s);\n", + i, mapper.c_str(), elem_name.c_str()); + } + fprintf(fp, " auto union_type = Create%s(*fbb", option_name.c_str()); + for (size_t i = 0; i < options->num_elems; i++) { + fprintf(fp, ", val%zu", i); + } + fprintf(fp, ").Union();\n"); + fprintf(fp, " return std::make_pair(%s, union_type);\n", + option_type.c_str()); + fprintf(fp, " }\n break;\n"); +} + +void GenerateImport(OpOptionData* option, FILE* fp) { + std::unordered_set<std::string> ignores; + ignores.insert("CONCAT_EMBEDDINGS"); + ignores.insert("CALL"); + + // Allow any op that doesn't have an options struct to be blocked + // together + for (const auto& op_name : option->ops()) { + auto option_it = option->op_to_option().find(op_name); + if (!option_it->second.empty() && ignores.find(op_name) == ignores.end()) + continue; + fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str()); + } + fprintf(fp, + " return std::make_pair(BuiltinOptions_NONE, " + "flatbuffers::Offset<void>());\n break;\n"); + + // Iterate over each ops + for (const auto& op_name : option->ops()) { + if (ignores.find(op_name) != ignores.end()) continue; + // Get to the option and struct names, continuing if not found. + auto option_it = option->op_to_option().find(op_name); + if (option_it->second.empty()) continue; + std::string option_name = option_it->second; + std::string option_type = "BuiltinOptions_" + option_name; + auto option_func_it = option->option_to_type_function().find(option_name); + if (option_func_it == option->option_to_type_function().end()) continue; + auto struct_name_it = option->option_to_struct().find(option_name); + if (struct_name_it == option->option_to_struct().end()) { + // If no C struct, then it better have no arguments. + auto type_info = option_func_it->second(); + if (type_info->num_elems != 0) { + // We have non-zero arguments in the schema, this means there + // should be a struct. + fprintf(stderr, + "Op %s uses option struct %s which has no builtin struct\n", + op_name.c_str(), option_name.c_str()); + exit(1); + } + fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str()); + fprintf(fp, " return std::make_pair(%s, Create%s(*fbb).Union());", + option_type.c_str(), option_name.c_str()); + } else { + // If C struct, then we need to assign all properties + auto struct_name = struct_name_it->second; + GenerateImportForOp(fp, op_name, option_name, option_type, + option_func_it->second(), struct_name); + } + } + // TODO(aselle): Handle unhandled cases more gracefully. + fprintf(fp, + "default: return std::make_pair(BuiltinOptions_NONE, " + "flatbuffers::Offset<void>());\n break;\n"); +} + +} // namespace tflite + +int main(int argc, char* argv[]) { + tflite::OpOptionData option; + if (argc != 2) { + fprintf(stderr, "Usage: %s <fname out>\n", argv[0]); + return 1; + } + FILE* fp = fopen(argv[1], "w"); + tflite::GenerateImport(&option, fp); + fclose(fp); +} diff --git a/tensorflow/contrib/lite/experimental/writer/writer.cc b/tensorflow/contrib/lite/experimental/writer/writer.cc new file mode 100644 index 00000000000..20ede214fba --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/writer.cc @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Just does a read/write loop of tflite file format using the interpreter as +// an intermediate. +// +// Usage: +// writer <input tflite> <output tflite> + +#include <iostream> + +#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" + +int main(int argc, char* argv[]) { + if (argc != 3) { + fprintf(stderr, "Usage: %s input_file output_file\n", argv[0]); + return 1; + } + std::unique_ptr<tflite::FlatBufferModel> model = + tflite::FlatBufferModel::BuildFromFile(argv[1]); + std::unique_ptr<tflite::Interpreter> interpreter; + tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver; + tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter); + tflite::InterpreterWriter writer(interpreter.get()); + writer.Write(argv[2]); + + return 0; +} diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc new file mode 100644 index 00000000000..52b17faf82e --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc @@ -0,0 +1,281 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" +#include <cstdlib> +#include <cstring> +#include <unordered_map> +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { +template <class T> +using Offset = flatbuffers::Offset<T>; +template <class T> +using Vector = flatbuffers::Vector<T>; +using FlatBufferBuilder = flatbuffers::FlatBufferBuilder; + +std::pair<BuiltinOptions, Offset<void>> CreateBuiltinUnion( + FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) { + switch (op) { +#include "tensorflow/contrib/lite/experimental/writer/option_writer_generated.h" + } + return std::make_pair(BuiltinOptions_NONE, Offset<void>()); +} + +template <class T_OUTPUT, class T_INPUT> +Offset<Vector<T_OUTPUT>> InterpreterWriter::ExportVector(FlatBufferBuilder* fbb, + const T_INPUT& v) { + std::vector<T_OUTPUT> inputs(v.begin(), v.end()); + return fbb->template CreateVector<T_OUTPUT>(inputs); +} + +Offset<Vector<Offset<Operator>>> InterpreterWriter::ExportOperators( + FlatBufferBuilder* fbb) { + std::vector<Offset<Operator>> operators; + + std::vector<int> operator_to_opcode; + // TODO(aselle): Augment this once we put execution plan in schema. + operator_to_opcode.resize(interpreter_->nodes_size(), -1); + for (int op_index : interpreter_->execution_plan()) { + const auto* node_and_registration = + interpreter_->node_and_registration(op_index); + const TfLiteRegistration* registration = &node_and_registration->second; + if (!registration->custom_name) { + operator_to_opcode[op_index] = + GetOpCodeForBuiltin(registration->builtin_code); + } else { + operator_to_opcode[op_index] = + GetOpCodeForCustom(registration->custom_name); + } + } + // second pass serialize operators + for (int op_index : interpreter_->execution_plan()) { + const auto* node_and_registration = + interpreter_->node_and_registration(op_index); + const TfLiteNode& node = node_and_registration->first; + const TfLiteRegistration& registration = node_and_registration->second; + Offset<void> builtin_options; + BuiltinOptions builtin_options_type = BuiltinOptions_NONE; + // Custom data + // TODO(aselle): Custom options format is not known by default. Just assume + // for now. + auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS; + Offset<Vector<uint8_t>> custom_options = 0; + + if (!registration.custom_name) { + // builtin + auto builtin_options_and_type = CreateBuiltinUnion( + fbb, static_cast<enum BuiltinOperator>(registration.builtin_code), + node.builtin_data); + builtin_options = builtin_options_and_type.second; + builtin_options_type = builtin_options_and_type.first; + } else { + auto custom_writer = custom_op_to_writer_.find(registration.custom_name); + if (custom_writer != custom_op_to_writer_.end() && + custom_writer->second) { + // delegate to custom writer if it exists + custom_writer->second(fbb, interpreter_, op_index, &custom_options, + &custom_options_format); + } else { + // use the custom data as fact + custom_options = fbb->CreateVector( + reinterpret_cast<const uint8_t*>(node.custom_initial_data), + node.custom_initial_data_size); + } + } + + int opcode_index = operator_to_opcode[op_index]; + std::vector<int> written_inputs = + RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs)); + std::vector<int> written_outputs = + RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs)); + auto inputs = ExportVector<int32_t>(fbb, written_inputs); + auto outputs = ExportVector<int32_t>(fbb, written_outputs); + operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs, + builtin_options_type, builtin_options, + custom_options, custom_options_format)); + } + + return fbb->template CreateVector<Offset<Operator>>(operators); +} + +Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors( + FlatBufferBuilder* fbb) { + tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1); + + std::vector<Offset<Tensor>> tensors; + + // Make a map from tensor index to whether the tensor is a temporary. + std::vector<bool> tensor_is_temporary(interpreter_->tensors_size(), false); + for (int op_index = 0; op_index < interpreter_->nodes_size(); ++op_index) { + const auto* node_and_registration = + interpreter_->node_and_registration(op_index); + for (auto tensor_index : + TfLiteIntArrayView(node_and_registration->first.temporaries)) + tensor_is_temporary[tensor_index] = true; + } + + // Now we need to remap all used tensor indices + int curr_output_index = 0; + for (int tensor_index = 0; tensor_index < interpreter_->tensors_size(); + tensor_index++) { + if (!tensor_is_temporary[tensor_index]) { + tensor_to_written_tensor_[tensor_index] = curr_output_index++; + } + } + + for (int tensor_index = 0; tensor_index < interpreter_->tensors_size(); + ++tensor_index) { + // Skip temporaries. + if (tensor_is_temporary[tensor_index]) continue; + + if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) { + // We only need to convert non temporaries + if (tensor->allocation_type != kTfLiteArenaRw && + tensor->allocation_type != kTfLiteMmapRo && + tensor->allocation_type != kTfLiteArenaRwPersistent) + continue; + // Allocate a buffer index + int buffer_index = 0; // This is null + if (tensor->allocation_type == kTfLiteMmapRo) { + buffer_index = buffers_.size(); + buffers_.push_back(std::make_pair( + reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes)); + } + // Primitive type. + TensorType type = TfLiteTypeToSchemaType(tensor->type); + // Handle quantization + const Offset<Vector<float>> null_array; + Offset<Vector<float>> scale_array; + Offset<Vector<int64_t>> zero_point_array; + if (tensor->params.scale != 0.f) { + // We have quantization, make a single arugment array (multi channel + // quant needs updating here). + scale_array = fbb->CreateVector<float>({tensor->params.scale}); + zero_point_array = + fbb->CreateVector<int64_t>({tensor->params.zero_point}); + } + Offset<QuantizationParameters> quantization_params = + CreateQuantizationParameters(*fbb, null_array, null_array, + scale_array, zero_point_array); + // Shape + TfLiteIntArrayView shape_view(tensor->dims); + std::vector<int> shape = + std::vector<int>(shape_view.begin(), shape_view.end()); + + tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape), + type, buffer_index, + fbb->CreateString(tensor->name), + quantization_params, tensor->is_variable)); + } + } + return fbb->template CreateVector<Offset<Tensor>>(tensors); +} + +Offset<Vector<Offset<Buffer>>> InterpreterWriter::ExportBuffers( + FlatBufferBuilder* fbb) { + std::vector<Offset<Buffer>> buffer_vector; + for (auto buffer : buffers_) { + auto data_offset = fbb->CreateVector(buffer.first, buffer.second); + buffer_vector.push_back(CreateBuffer(*fbb, data_offset)); + } + return fbb->template CreateVector<Offset<Buffer>>(buffer_vector); +} + +Offset<Vector<Offset<OperatorCode>>> InterpreterWriter::CreateOpCodeTable( + FlatBufferBuilder* fbb) { + std::vector<Offset<OperatorCode>> codes; + for (auto it : opcodes_) { + const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str(); + codes.push_back(CreateOperatorCodeDirect( + *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name)); + } + return fbb->template CreateVector<Offset<OperatorCode>>(codes); +} + +template <class T> +std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten( + const T& input) { + std::vector<int> output; + output.reserve(input.size()); + for (int x : input) { + output.push_back(tensor_to_written_tensor_[x]); + } + return output; +} + +TfLiteStatus InterpreterWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out, + size_t* size) { + if (!out || !size) return kTfLiteError; + FlatBufferBuilder builder(/*initial_size=*/10240); + + std::vector<Offset<SubGraph>> subgraphs_as_vector; + { // subgraph specific stuff + auto tensors = ExportTensors(&builder); + std::vector<int> written_inputs = + RemapTensorIndicesToWritten(interpreter_->inputs()); + std::vector<int> written_outputs = + RemapTensorIndicesToWritten(interpreter_->outputs()); + auto inputs = ExportVector<int32_t>(&builder, written_inputs); + auto outputs = ExportVector<int32_t>(&builder, written_outputs); + + auto ops = ExportOperators(&builder); + subgraphs_as_vector.push_back( + CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0)); + } + Offset<Vector<Offset<Buffer>>> buffers = ExportBuffers(&builder); + + auto description = builder.CreateString("Exported from Interpreter."); + + auto op_codes = CreateOpCodeTable(&builder); + auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes, + builder.CreateVector(subgraphs_as_vector), + description, buffers); + ::tflite::FinishModelBuffer(builder, model); + const uint8_t* buffer = builder.GetBufferPointer(); + *size = builder.GetSize(); + (*out).reset(new uint8_t[*size]); + memcpy(out->get(), buffer, *size); + return kTfLiteOk; +} + +TfLiteStatus InterpreterWriter::Write(const std::string& filename) { + std::unique_ptr<uint8_t[]> buffer; + size_t size; + TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size)); + + FILE* fp = fopen(filename.c_str(), "wb"); + if (!fp) return kTfLiteError; + + if (fwrite(buffer.get(), 1, size, fp) != size) return kTfLiteError; + if (fclose(fp)) return kTfLiteError; + + return kTfLiteOk; +} + +TfLiteStatus InterpreterWriter::RegisterCustomWriter( + const std::string& custom_name, CustomWriter custom_writer) { + if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) { + return kTfLiteError; + } + custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer)); + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/contrib/lite/experimental/writer/writer_lib.h new file mode 100644 index 00000000000..a98108b4960 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.h @@ -0,0 +1,126 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Writes a flatbuffer of a currently loaded TensorFlow Lite interpreter. +// +// Usage: +// From command line: +// bazel run third_party/tensorflow/contrib/lite/experimental/writer:writer +// -- foo.tflite foo.out.tflite +// +// From C++ +// std::unique_ptr<Interpreter> interpreter; +// // Build Interpreter however +// // ... <omitted> +// InterpreterWriter(interpreter.get()).Write("output.tflite"); +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ +#include <iostream> +#include <unordered_map> +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { + +// Handles writing TensorFlow Lite running interpreter to a serialized TF lite +// file format. +class InterpreterWriter { + public: + typedef flatbuffers::Offset<Operator> (*CustomWriter)( + flatbuffers::FlatBufferBuilder* fbb, Interpreter* interpreter, + int node_index, + flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options, + CustomOptionsFormat* custom_options_format); + + // Construct an interpreter writer for the specified `interpreter`. Then, + // a uses .Write() or .GetBuffer(...) to extract the data. + explicit InterpreterWriter(Interpreter* interpreter) + : interpreter_(interpreter) { + buffers_.push_back(std::make_pair(nullptr, 0)); + } + + // Get a buffer and size of a serialized flatbuffer. + TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size); + // Write the serialized flatbuffer to the prescribed `filename`. + TfLiteStatus Write(const std::string& filename); + // Registers a custom writer for a custom op. The customization allows the + // caller to change the custom data. + TfLiteStatus RegisterCustomWriter(const std::string& custom_name, + CustomWriter custom_writer); + + private: + template <class T> + using Offset = flatbuffers::Offset<T>; + template <class T_OUTPUT, class T_INPUT> + Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector( + flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v); + Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors( + flatbuffers::FlatBufferBuilder* fbb); + Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators( + flatbuffers::FlatBufferBuilder* fbb); + Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable( + flatbuffers::FlatBufferBuilder* fbb); + Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers( + flatbuffers::FlatBufferBuilder* fbb); + + template <class T> + std::vector<int> RemapTensorIndicesToWritten(const T& input); + + int GetOpCodeForBuiltin(int builtin_op_index) { + // auto it = builtin_op_to_opcode_.find(builtin_op_index); + std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result = + builtin_op_to_opcode_.insert( + std::make_pair(builtin_op_index, opcodes_.size())); + if (result.second) { + opcodes_.push_back({builtin_op_index, ""}); + } + return result.first->second; + } + + int GetOpCodeForCustom(const std::string& custom_name) { + std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result = + custom_op_to_opcode_.insert( + std::make_pair(custom_name, opcodes_.size())); + if (result.second) { + opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name}); + } + return result.first->second; + } + + // The interpreter we are writing + Interpreter* interpreter_; + // Keep track of byte buffers + std::vector<std::pair<const uint8_t*, size_t>> buffers_; + // List of op codes and mappings from builtin or custom op to opcode + struct OpCode { + int builtin; + std::string custom; + }; + // For every tensor index in the interpreter, the index in the written. + // This is different due to temporary tensors not being written. + std::vector<int> tensor_to_written_tensor_; + // List of used opcodes + std::vector<OpCode> opcodes_; + std::unordered_map<int, int> builtin_op_to_opcode_; + std::unordered_map<std::string, int> custom_op_to_opcode_; + std::unordered_map<std::string, CustomWriter> custom_op_to_writer_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc new file mode 100644 index 00000000000..49194a76c8c --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +// Make an interpreter that has no tensors and no nodes +// TODO(b/113731921): add more tests. +TEST(Writer, BasicTest) { + Interpreter interpreter; + interpreter.AddTensors(3); + float foo[] = {1, 2, 3}; + interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3}, + TfLiteQuantizationParams()); + interpreter.SetTensorParametersReadOnly( + 1, kTfLiteFloat32, "b", {3}, TfLiteQuantizationParams(), + reinterpret_cast<char*>(foo), sizeof(foo)); + interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3}, + TfLiteQuantizationParams()); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({2}); + const char* initial_data = ""; + tflite::ops::builtin::BuiltinOpResolver resolver; + TfLiteAddParams* builtin_data = + reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams))); + builtin_data->activation = kTfLiteActNone; + const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1); + interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0, + reinterpret_cast<void*>(builtin_data), reg); + + InterpreterWriter writer(&interpreter); + writer.Write("/tmp/test.tflite"); + std::unique_ptr<FlatBufferModel> model = + FlatBufferModel::BuildFromFile("/tmp/test.tflite"); + InterpreterBuilder builder(*model, resolver); + std::unique_ptr<Interpreter> new_interpreter; + builder(&new_interpreter); +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/op_resolver.cc b/tensorflow/contrib/lite/op_resolver.cc index f6e435e9824..a9885f77371 100644 --- a/tensorflow/contrib/lite/op_resolver.cc +++ b/tensorflow/contrib/lite/op_resolver.cc @@ -46,6 +46,8 @@ void MutableOpResolver::AddCustom(const char* name, TfLiteRegistration* registration, int min_version, int max_version) { for (int version = min_version; version <= max_version; ++version) { + // TODO(aselle): This should verify that the incoming registration + // has the name in the registration already and it matches!!! TfLiteRegistration new_registration = *registration; new_registration.builtin_code = BuiltinOperator_CUSTOM; new_registration.version = version; diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD index 28a7e500034..55bf2c48b97 100644 --- a/tensorflow/contrib/lite/schema/BUILD +++ b/tensorflow/contrib/lite/schema/BUILD @@ -56,6 +56,20 @@ flatbuffer_cc_library( srcs = ["schema.fbs"], ) +# Generic schema for inference on device (but with reflections makes bigger). +flatbuffer_cc_library( + name = "schema_fbs_with_reflection", + srcs = ["schema.fbs"], + flatc_args = [ + "--reflect-types", + "--reflect-names", + "--no-union-value-namespacing", + "--gen-object-api", + ], + gen_reflections = True, + out_prefix = "reflection/", +) + # Schema test to make sure we don't introduce backward incompatible changes # to schemas. cc_test(