Automated rollback of commit 69753ba5db

PiperOrigin-RevId: 211541639
Andrew Selle 2018-09-04 16:01:54 -07:00 committed by TensorFlower Gardener
parent 69753ba5db
commit 0065d3389a
9 changed files with 0 additions and 1076 deletions

tensorflow/contrib/lite/experimental/writer/BUILD View File

@@ -1,64 +0,0 @@
package(default_visibility = [
"//visibility:public",
])
licenses(["notice"]) # Apache 2.0
cc_binary(
name = "option_writer_generator",
srcs = ["option_writer_generator.cc"],
deps = [
"//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
"@flatbuffers",
],
)
cc_library(
name = "writer_lib",
srcs = [
"enum_mapping.h",
"writer_lib.cc",
],
hdrs = [
"writer_lib.h",
],
textual_hdrs = ["option_writer_generated.h"],
deps = [
":option_writer_gen",
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:schema_fbs_version",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
],
)
cc_binary(
name = "writer",
srcs = ["writer.cc"],
deps = [
":writer_lib",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:builtin_ops",
],
)
cc_test(
name = "writer_lib_test",
size = "small",
srcs = ["writer_lib_test.cc"],
deps = [
":writer_lib",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/testing:util",
"//testing/base/public:gunit",
],
)
genrule(
name = "option_writer_gen",
outs = ["option_writer_generated.h"],
cmd = "$(location :option_writer_generator) $(@)",
tools = [":option_writer_generator"],
)

tensorflow/contrib/lite/experimental/writer/enum_mapping.h View File

@@ -1,116 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
// TODO(aselle): Ideally extract this from the schema.
namespace tflite {
inline ActivationFunctionType TfLiteActivationToSchemaActivation(
TfLiteFusedActivation act) {
switch (act) {
case kTfLiteActNone:
return ActivationFunctionType_NONE;
case kTfLiteActRelu:
return ActivationFunctionType_RELU;
case kTfLiteActRelu1:
return ActivationFunctionType_RELU_N1_TO_1;
case kTfLiteActRelu6:
return ActivationFunctionType_RELU6;
case kTfLiteActTanh:
return ActivationFunctionType_TANH;
case kTfLiteActSignBit:
return ActivationFunctionType_SIGN_BIT;
case kTfLiteActSigmoid:
return ActivationFunctionType_NONE; // TODO(aselle): Add to schema
}
return ActivationFunctionType_NONE;
}
inline Padding TfLitePaddingToSchemaPadding(TfLitePadding padding) {
switch (padding) {
case kTfLitePaddingUnknown:
return Padding_SAME; // TODO(aselle): Consider an error.
case kTfLitePaddingSame:
return Padding_SAME;
case kTfLitePaddingValid:
return Padding_VALID;
}
return Padding_SAME; // TODO(aselle): Consider an error.
}
inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
switch (type) {
// case kTfLiteNoType: return TensorType_NONE;
case kTfLiteNoType:
return TensorType_FLOAT32; // TODO(aselle): Consider an error.
case kTfLiteFloat32:
return TensorType_FLOAT32;
case kTfLiteInt32:
return TensorType_INT32;
case kTfLiteUInt8:
return TensorType_UINT8;
case kTfLiteInt64:
return TensorType_INT64;
case kTfLiteString:
return TensorType_STRING;
case kTfLiteBool:
return TensorType_BOOL;
case kTfLiteInt16:
return TensorType_INT16;
case kTfLiteComplex64:
return TensorType_COMPLEX64;
}
// TODO(aselle): consider an error
}
inline FullyConnectedOptionsWeightsFormat
FullyConnectedOptionsWeightsFormatToSchema(
TfLiteFullyConnectedWeightsFormat format) {
switch (format) {
case kTfLiteFullyConnectedWeightsFormatDefault:
return FullyConnectedOptionsWeightsFormat_DEFAULT;
case kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8:
return FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
}
}
inline LSTMKernelType LSTMKernelTypeToSchema(TfLiteLSTMKernelType type) {
switch (type) {
case kTfLiteLSTMFullKernel:
return LSTMKernelType_FULL;
case kTfLiteLSTMBasicKernel:
return LSTMKernelType_BASIC;
}
}
inline LSHProjectionType LSHProjectionTypeToSchema(
TfLiteLSHProjectionType type) {
switch (type) {
case kTfLiteLshProjectionUnknown:
return LSHProjectionType_UNKNOWN;
case kTfLiteLshProjectionSparse:
return LSHProjectionType_SPARSE;
case kTfLiteLshProjectionDense:
return LSHProjectionType_DENSE;
}
}
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
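These helpers are what the generated option-writer code calls when copying C params into schema options. A minimal illustrative use (not part of the original file):

// Illustrative only: map a fused activation to its schema enum.
tflite::ActivationFunctionType act =
    tflite::TfLiteActivationToSchemaActivation(kTfLiteActRelu);
// act == tflite::ActivationFunctionType_RELU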

tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc View File

@@ -1,370 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ctype.h>
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include "flatbuffers/minireflect.h" // flatbuffers
#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
namespace tflite {
namespace {
// This list was generated by grepping builtin_op_data.h:
// cat third_party/tensorflow/contrib/lite/builtin_op_data.h
// | grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/TfLite\1Params/g' | grep -v "^}"
static const char* param_structs[] = {"TfLiteConvParams",
"TfLitePoolParams",
"TfLiteDepthwiseConvParams",
"TfLiteSVDFParams",
"TfLiteRNNParams",
"TfLiteSequenceRNNParams",
"TfLiteFullyConnectedParams",
"TfLiteLSHProjectionParams",
"TfLiteSoftmaxParams",
"TfLiteConcatenationParams",
"TfLiteAddParams",
"TfLiteSpaceToBatchNDParams",
"TfLiteBatchToSpaceNDParams",
"TfLiteMulParams",
"TfLiteSubParams",
"TfLiteDivParams",
"TfLiteL2NormParams",
"TfLiteLocalResponseNormParams",
"TfLiteLSTMParams",
"TfLiteResizeBilinearParams",
"TfLitePadParams",
"TfLitePadV2Params",
"TfLiteReshapeParams",
"TfLiteSkipGramParams",
"TfLiteSpaceToDepthParams",
"TfLiteCastParams",
"TfLiteEmbeddingLookupSparseParams",
"TfLiteGatherParams",
"TfLiteTransposeParams",
"TfLiteReducerParams",
"TfLiteSplitParams",
"TfLiteSqueezeParams",
"TfLiteStridedSliceParams",
"TfLiteArgMaxParams",
"TfLiteArgMinParams",
"TfLiteTransposeConvParams",
"TfLiteSparseToDenseParams",
"TfLiteShapeParams",
"TfLiteFakeQuantParams",
"TfLitePackParams",
"TfLiteOneHotParams",
nullptr};
} // namespace
// Strip underscores and lowercase everything so that name matching works
// across variants like 3D vs 3d or RNN vs Rnn.
std::string ToCollapsed(const std::string& in) {
std::string out;
for (char c : in) {
if (c != '_') out.push_back(tolower(c));
}
return out;
}
// A collection of information about builtin ops.
class OpOptionData {
public:
OpOptionData() {
BuildOpList();
BuildOptionToTypeFunctionMap();
BuildOpToOptionMap();
}
// A list of builtin operations
const std::vector<std::string>& ops() const { return ops_; }
// Maps from operation name to option name (e.g. 'ADD' -> 'AddOptions')
const std::unordered_map<std::string, std::string>& op_to_option() {
return op_to_option_;
}
// Maps from option name to the C struct, e.g. 'AddOptions' -> 'TfLiteAddParams'
const std::unordered_map<std::string, std::string>& option_to_struct() {
return option_to_struct_;
}
// Maps from option to a flatbuffer type function that describes that option.
const std::unordered_map<std::string, flatbuffers::TypeFunction>&
option_to_type_function() {
return option_to_type_function_;
}
private:
void BuildOpList() {
for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr;
++curr) {
if (strlen(*curr) != 0) ops_.push_back(*curr);
}
}
void BuildOptionToTypeFunctionMap() {
auto d = tflite::BuiltinOptionsTypeTable();
for (int i = 0; i < d->num_elems; i++) {
flatbuffers::TypeCode code = d->type_codes[i];
if (code.sequence_ref != -1) {
option_to_type_function_.insert(
std::make_pair(d->names[i], d->type_refs[code.sequence_ref]));
}
}
}
void BuildOpToOptionMap() {
// Manually specified mappings between ops and options
op_to_option_["REDUCE_MAX"] = "ReducerOptions";
op_to_option_["REDUCE_MIN"] = "ReducerOptions";
op_to_option_["REDUCE_ANY"] = "ReducerOptions";
op_to_option_["UNPACK"] = "";
op_to_option_["SUM"] = "ReducerOptions";
op_to_option_["REDUCE_MAX"] = "ReducerOptions";
op_to_option_["REDUCE_PROD"] = "ReducerOptions";
op_to_option_["MEAN"] = "ReducerOptions";
op_to_option_["L2_POOL_2D"] = "Pool2DOptions";
op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions";
op_to_option_["MAX_POOL_2D"] = "Pool2DOptions";
op_to_option_["L2_NORMALIZATION"] = "L2NormOptions";
op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
// Manually specified mappings between ops and options (none)
op_to_option_["EMBEDDING_LOOKUP"] =
""; // TODO(aselle): maybe something else.
op_to_option_["FLOOR"] = "";
op_to_option_["HASHTABLE_LOOKUP"] =
""; // TODO(aselle): maybe something else.
op_to_option_["LOGISTIC"] = "";
op_to_option_["RELU"] = "";
op_to_option_["RELU_N1_TO_1"] = "";
op_to_option_["RELU6"] = "";
op_to_option_["TANH"] = "";
op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else.
op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else.
op_to_option_["PRELU"] = "";
op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
op_to_option_["SIN"] = "";
op_to_option_["LOG"] = "";
op_to_option_["SQRT"] = "";
op_to_option_["RSQRT"] = "";
// TODO(aselle): These are undesirable hacks. Consider changing C structs
option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
option_to_struct_["Conv2DOptions"] = "TfLiteConvParams";
option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams";
option_to_struct_["LocalResponseNormalizationOptions"] =
"TfLiteLocalResponseNormParams";
// Now for every op, try to find an option.
bool fatal = false;
for (auto op_name : ops_) {
bool found_option = false;
auto d = tflite::BuiltinOptionsTypeTable();
std::string collapsed_option_name_guess =
ToCollapsed(op_name) + "options";
// O(n^2), but n is small.
for (int i = 0; i < d->num_elems; i++) {
std::string option_name = d->names[i];
std::string collapsed_option_name = ToCollapsed(option_name);
if (collapsed_option_name_guess == collapsed_option_name) {
op_to_option_.insert(std::make_pair(op_name, option_name));
found_option = true;
break;
}
}
auto it = op_to_option_.find(op_name);
if (it == op_to_option_.end()) {
std::cerr << "Didn't find option for " << op_name << std::endl;
fatal = true;
} else if (!it->second.empty()) {
std::string option_name = it->second;
if (option_to_struct_.find(option_name) == option_to_struct_.end()) {
bool param_struct_found = false;
std::string params_guess = std::string("TfLite") + option_name;
size_t start = params_guess.find("Options");
size_t len = strlen("Options");
params_guess.replace(start, len, "Params");
for (auto* param = param_structs; *param != nullptr; param++) {
if (*param == params_guess) {
param_struct_found = true;
break;
}
}
if (!param_struct_found) {
std::cerr << "Failed to get param struct for option " << option_name
<< std::endl;
fatal = true;
} else {
option_to_struct_.insert(std::make_pair(option_name, params_guess));
}
}
}
}
}
private:
std::vector<std::string> ops_;
std::unordered_map<std::string, std::string> op_to_option_;
std::unordered_map<std::string, std::string> option_to_struct_;
std::unordered_map<std::string, flatbuffers::TypeFunction>
option_to_type_function_;
};
void GenerateImportForOp(FILE* fp, const std::string& op_name,
const std::string& option_name,
const std::string& option_type,
const flatbuffers::TypeTable* options,
const std::string& struct_name) {
// Skip tricky ones for now
if (struct_name == "TfLiteResizeBilinearParams") return;
if (struct_name == "TfLiteSqueezeParams") return;
if (struct_name == "TfLiteEmbeddingLookupSparseParams") return;
if (struct_name == "TfLiteReshapeParams") return;
fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
fprintf(fp,
" const auto* params = reinterpret_cast<const "
"%s*>(builtin_op_data);\n",
struct_name.c_str());
for (size_t i = 0; i < options->num_elems; i++) {
std::string elem_name = options->names[i];
// TODO(aselle): Irregular naming in builtins
if (elem_name == "fused_activation_function")
elem_name = "activation";
else if (elem_name == "stride_w")
elem_name = "stride_width";
else if (elem_name == "stride_h")
elem_name = "stride_height";
else if (elem_name == "dilation_h_factor")
elem_name = "dilation_height_factor";
else if (elem_name == "dilation_w_factor")
elem_name = "dilation_width_factor";
else if (elem_name == "new_shape")
elem_name = "shape";
flatbuffers::TypeCode code = options->type_codes[i];
auto contained_type = code.sequence_ref != -1
? options->type_refs[code.sequence_ref]
: nullptr;
std::string mapper = "";
if (contained_type == TensorTypeTypeTable) {
mapper = "TfLiteTypeToSchemaType";
} else if (contained_type == ActivationFunctionTypeTypeTable) {
mapper = "TfLiteActivationToSchemaActivation";
} else if (contained_type == PaddingTypeTable) {
mapper = "TfLitePaddingToSchemaPadding";
} else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) {
mapper = "FullyConnectedOptionsWeightsFormatToSchema";
} else if (contained_type == LSTMKernelTypeTypeTable) {
mapper = "LSTMKernelTypeToSchema";
} else if (contained_type == LSHProjectionTypeTypeTable) {
mapper = "LSHProjectionTypeToSchema";
}
fprintf(fp,
" auto val%zu = "
"%s(params->%s);\n",
i, mapper.c_str(), elem_name.c_str());
}
fprintf(fp, " auto union_type = Create%s(*fbb", option_name.c_str());
for (size_t i = 0; i < options->num_elems; i++) {
fprintf(fp, ", val%zu", i);
}
fprintf(fp, ").Union();\n");
fprintf(fp, " return std::make_pair(%s, union_type);\n",
option_type.c_str());
fprintf(fp, " }\n break;\n");
}
void GenerateImport(OpOptionData* option, FILE* fp) {
std::unordered_set<std::string> ignores;
ignores.insert("CONCAT_EMBEDDINGS");
ignores.insert("CALL");
// Emit one shared block of case labels, returning an empty options union,
// for every op that has no options struct, plus the ignored ops.
for (const auto& op_name : option->ops()) {
auto option_it = option->op_to_option().find(op_name);
if (!option_it->second.empty() && ignores.find(op_name) == ignores.end())
continue;
fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
}
fprintf(fp,
" return std::make_pair(BuiltinOptions_NONE, "
"flatbuffers::Offset<void>());\n break;\n");
// Now emit a full case for each op that has an options struct.
for (const auto& op_name : option->ops()) {
if (ignores.find(op_name) != ignores.end()) continue;
// Look up the option and struct names, continuing if not found.
auto option_it = option->op_to_option().find(op_name);
if (option_it->second.empty()) continue;
std::string option_name = option_it->second;
std::string option_type = "BuiltinOptions_" + option_name;
auto option_func_it = option->option_to_type_function().find(option_name);
if (option_func_it == option->option_to_type_function().end()) continue;
auto struct_name_it = option->option_to_struct().find(option_name);
if (struct_name_it == option->option_to_struct().end()) {
// If there is no C struct, the option must take no arguments.
auto type_info = option_func_it->second();
if (type_info->num_elems != 0) {
// The schema has non-zero arguments for this option, so a
// corresponding C struct should exist.
fprintf(stderr,
"Op %s uses option struct %s which has no builtin struct\n",
op_name.c_str(), option_name.c_str());
exit(1);
}
fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
fprintf(fp, " return std::make_pair(%s, Create%s(*fbb).Union());",
option_type.c_str(), option_name.c_str());
} else {
// If there is a C struct, we need to assign all of its properties.
auto struct_name = struct_name_it->second;
GenerateImportForOp(fp, op_name, option_name, option_type,
option_func_it->second(), struct_name);
}
}
// TODO(aselle): Handle unhandled cases more gracefully.
fprintf(fp,
"default: return std::make_pair(BuiltinOptions_NONE, "
"flatbuffers::Offset<void>());\n break;\n");
}
} // namespace tflite
int main(int argc, char* argv[]) {
tflite::OpOptionData option;
if (argc != 2) {
fprintf(stderr, "Usage: %s <fname out>\n", argv[0]);
return 1;
}
FILE* fp = fopen(argv[1], "w");
tflite::GenerateImport(&option, fp);
fclose(fp);
}
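For orientation, the fprintf templates above emit one switch case per op into option_writer_generated.h. A hedged sketch of what the generated case for ADD would look like (exact spacing and fields depend on the schema):

case BuiltinOperator_ADD: {
  const auto* params = reinterpret_cast<const TfLiteAddParams*>(builtin_op_data);
  auto val0 = TfLiteActivationToSchemaActivation(params->activation);
  auto union_type = CreateAddOptions(*fbb, val0).Union();
  return std::make_pair(BuiltinOptions_AddOptions, union_type);
} break;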

tensorflow/contrib/lite/experimental/writer/writer.cc View File

@@ -1,41 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Does a read/write round trip of the tflite file format, using the
// interpreter as the intermediate representation.
//
// Usage:
// writer <input tflite> <output tflite>
#include <iostream>
#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
int main(int argc, char* argv[]) {
if (argc != 3) {
fprintf(stderr, "Usage: %s input_file output_file\n", argv[0]);
return 1;
}
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromFile(argv[1]);
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
tflite::InterpreterWriter writer(interpreter.get());
writer.Write(argv[2]);
return 0;
}

tensorflow/contrib/lite/experimental/writer/writer_lib.cc View File

@@ -1,281 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
#include <cstdlib>
#include <cstring>
#include <unordered_map>
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
#include "tensorflow/contrib/lite/version.h"
namespace tflite {
template <class T>
using Offset = flatbuffers::Offset<T>;
template <class T>
using Vector = flatbuffers::Vector<T>;
using FlatBufferBuilder = flatbuffers::FlatBufferBuilder;
std::pair<BuiltinOptions, Offset<void>> CreateBuiltinUnion(
FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) {
switch (op) {
#include "tensorflow/contrib/lite/experimental/writer/option_writer_generated.h"
}
return std::make_pair(BuiltinOptions_NONE, Offset<void>());
}
template <class T_OUTPUT, class T_INPUT>
Offset<Vector<T_OUTPUT>> InterpreterWriter::ExportVector(FlatBufferBuilder* fbb,
const T_INPUT& v) {
std::vector<T_OUTPUT> inputs(v.begin(), v.end());
return fbb->template CreateVector<T_OUTPUT>(inputs);
}
Offset<Vector<Offset<Operator>>> InterpreterWriter::ExportOperators(
FlatBufferBuilder* fbb) {
std::vector<Offset<Operator>> operators;
std::vector<int> operator_to_opcode;
// TODO(aselle): Augment this once we put execution plan in schema.
operator_to_opcode.resize(interpreter_->nodes_size(), -1);
for (int op_index : interpreter_->execution_plan()) {
const auto* node_and_registration =
interpreter_->node_and_registration(op_index);
const TfLiteRegistration* registration = &node_and_registration->second;
if (!registration->custom_name) {
operator_to_opcode[op_index] =
GetOpCodeForBuiltin(registration->builtin_code);
} else {
operator_to_opcode[op_index] =
GetOpCodeForCustom(registration->custom_name);
}
}
// Second pass: serialize the operators.
for (int op_index : interpreter_->execution_plan()) {
const auto* node_and_registration =
interpreter_->node_and_registration(op_index);
const TfLiteNode& node = node_and_registration->first;
const TfLiteRegistration& registration = node_and_registration->second;
Offset<void> builtin_options;
BuiltinOptions builtin_options_type = BuiltinOptions_NONE;
// Custom data.
// TODO(aselle): The custom options format is not known by default;
// just assume FLEXBUFFERS for now.
auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
Offset<Vector<uint8_t>> custom_options = 0;
if (!registration.custom_name) {
// builtin
auto builtin_options_and_type = CreateBuiltinUnion(
fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
node.builtin_data);
builtin_options = builtin_options_and_type.second;
builtin_options_type = builtin_options_and_type.first;
} else {
auto custom_writer = custom_op_to_writer_.find(registration.custom_name);
if (custom_writer != custom_op_to_writer_.end() &&
custom_writer->second) {
// Delegate to the custom writer if one was registered.
custom_writer->second(fbb, interpreter_, op_index, &custom_options,
&custom_options_format);
} else {
// Otherwise copy the node's custom data verbatim.
custom_options = fbb->CreateVector(
reinterpret_cast<const uint8_t*>(node.custom_initial_data),
node.custom_initial_data_size);
}
}
int opcode_index = operator_to_opcode[op_index];
std::vector<int> written_inputs =
RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs));
std::vector<int> written_outputs =
RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs));
auto inputs = ExportVector<int32_t>(fbb, written_inputs);
auto outputs = ExportVector<int32_t>(fbb, written_outputs);
operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs,
builtin_options_type, builtin_options,
custom_options, custom_options_format));
}
return fbb->template CreateVector<Offset<Operator>>(operators);
}
Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors(
FlatBufferBuilder* fbb) {
tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1);
std::vector<Offset<Tensor>> tensors;
// Make a map from tensor index to whether the tensor is a temporary.
std::vector<bool> tensor_is_temporary(interpreter_->tensors_size(), false);
for (int op_index = 0; op_index < interpreter_->nodes_size(); ++op_index) {
const auto* node_and_registration =
interpreter_->node_and_registration(op_index);
for (auto tensor_index :
TfLiteIntArrayView(node_and_registration->first.temporaries))
tensor_is_temporary[tensor_index] = true;
}
// Now we need to remap all used tensor indices
int curr_output_index = 0;
for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
tensor_index++) {
if (!tensor_is_temporary[tensor_index]) {
tensor_to_written_tensor_[tensor_index] = curr_output_index++;
}
}
for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
++tensor_index) {
// Skip temporaries.
if (tensor_is_temporary[tensor_index]) continue;
if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) {
// Only serialize tensors with supported allocation types.
if (tensor->allocation_type != kTfLiteArenaRw &&
tensor->allocation_type != kTfLiteMmapRo &&
tensor->allocation_type != kTfLiteArenaRwPersistent)
continue;
// Allocate a buffer index; 0 is the reserved empty buffer.
int buffer_index = 0;
if (tensor->allocation_type == kTfLiteMmapRo) {
buffer_index = buffers_.size();
buffers_.push_back(std::make_pair(
reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
}
// Primitive type.
TensorType type = TfLiteTypeToSchemaType(tensor->type);
// Handle quantization
const Offset<Vector<float>> null_array;
Offset<Vector<float>> scale_array;
Offset<Vector<int64_t>> zero_point_array;
if (tensor->params.scale != 0.f) {
// Quantized tensor: make single-element scale/zero-point arrays
// (multi-channel quantization would need updating here).
scale_array = fbb->CreateVector<float>({tensor->params.scale});
zero_point_array =
fbb->CreateVector<int64_t>({tensor->params.zero_point});
}
Offset<QuantizationParameters> quantization_params =
CreateQuantizationParameters(*fbb, null_array, null_array,
scale_array, zero_point_array);
// Shape
TfLiteIntArrayView shape_view(tensor->dims);
std::vector<int> shape =
std::vector<int>(shape_view.begin(), shape_view.end());
tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape),
type, buffer_index,
fbb->CreateString(tensor->name),
quantization_params, tensor->is_variable));
}
}
return fbb->template CreateVector<Offset<Tensor>>(tensors);
}
Offset<Vector<Offset<Buffer>>> InterpreterWriter::ExportBuffers(
FlatBufferBuilder* fbb) {
std::vector<Offset<Buffer>> buffer_vector;
for (auto buffer : buffers_) {
auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
}
return fbb->template CreateVector<Offset<Buffer>>(buffer_vector);
}
Offset<Vector<Offset<OperatorCode>>> InterpreterWriter::CreateOpCodeTable(
FlatBufferBuilder* fbb) {
std::vector<Offset<OperatorCode>> codes;
for (auto it : opcodes_) {
const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
codes.push_back(CreateOperatorCodeDirect(
*fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
}
return fbb->template CreateVector<Offset<OperatorCode>>(codes);
}
template <class T>
std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten(
const T& input) {
std::vector<int> output;
output.reserve(input.size());
for (int x : input) {
output.push_back(tensor_to_written_tensor_[x]);
}
return output;
}
TfLiteStatus InterpreterWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
size_t* size) {
if (!out || !size) return kTfLiteError;
FlatBufferBuilder builder(/*initial_size=*/10240);
std::vector<Offset<SubGraph>> subgraphs_as_vector;
{ // Subgraph-specific fields.
auto tensors = ExportTensors(&builder);
std::vector<int> written_inputs =
RemapTensorIndicesToWritten(interpreter_->inputs());
std::vector<int> written_outputs =
RemapTensorIndicesToWritten(interpreter_->outputs());
auto inputs = ExportVector<int32_t>(&builder, written_inputs);
auto outputs = ExportVector<int32_t>(&builder, written_outputs);
auto ops = ExportOperators(&builder);
subgraphs_as_vector.push_back(
CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0));
}
Offset<Vector<Offset<Buffer>>> buffers = ExportBuffers(&builder);
auto description = builder.CreateString("Exported from Interpreter.");
auto op_codes = CreateOpCodeTable(&builder);
auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
builder.CreateVector(subgraphs_as_vector),
description, buffers);
::tflite::FinishModelBuffer(builder, model);
const uint8_t* buffer = builder.GetBufferPointer();
*size = builder.GetSize();
(*out).reset(new uint8_t[*size]);
memcpy(out->get(), buffer, *size);
return kTfLiteOk;
}
TfLiteStatus InterpreterWriter::Write(const std::string& filename) {
std::unique_ptr<uint8_t[]> buffer;
size_t size;
TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
FILE* fp = fopen(filename.c_str(), "wb");
if (!fp) return kTfLiteError;
if (fwrite(buffer.get(), 1, size, fp) != size) return kTfLiteError;
if (fclose(fp)) return kTfLiteError;
return kTfLiteOk;
}
TfLiteStatus InterpreterWriter::RegisterCustomWriter(
const std::string& custom_name, CustomWriter custom_writer) {
if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) {
return kTfLiteError;
}
custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer));
return kTfLiteOk;
}
} // namespace tflite
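As a usage note, GetBuffer() lets callers keep the serialized model in memory instead of writing a file. A minimal sketch, assuming `interpreter` is a fully built tflite::Interpreter:

tflite::InterpreterWriter writer(&interpreter);
std::unique_ptr<uint8_t[]> buffer;
size_t size = 0;
if (writer.GetBuffer(&buffer, &size) == kTfLiteOk) {
  // buffer.get() now holds `size` bytes of a valid .tflite flatbuffer that
  // could, e.g., be reloaded with FlatBufferModel::BuildFromBuffer(
  //     reinterpret_cast<const char*>(buffer.get()), size).
}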

tensorflow/contrib/lite/experimental/writer/writer_lib.h View File

@@ -1,126 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Writes a flatbuffer of a currently loaded TensorFlow Lite interpreter.
//
// Usage:
// From command line:
// bazel run third_party/tensorflow/contrib/lite/experimental/writer:writer
// -- foo.tflite foo.out.tflite
//
// From C++
// std::unique_ptr<Interpreter> interpreter;
// // Build Interpreter however
// // ... <omitted>
// InterpreterWriter(interpreter.get()).Write("output.tflite");
#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
#include <iostream>
#include <unordered_map>
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
#include "tensorflow/contrib/lite/version.h"
namespace tflite {
// Handles writing a running TensorFlow Lite interpreter out to the
// serialized TensorFlow Lite file format.
class InterpreterWriter {
public:
typedef flatbuffers::Offset<Operator> (*CustomWriter)(
flatbuffers::FlatBufferBuilder* fbb, Interpreter* interpreter,
int node_index,
flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
CustomOptionsFormat* custom_options_format);
// Constructs an interpreter writer for the specified `interpreter`. Then
// use .Write() or .GetBuffer(...) to extract the data.
explicit InterpreterWriter(Interpreter* interpreter)
: interpreter_(interpreter) {
buffers_.push_back(std::make_pair(nullptr, 0));
}
// Get a buffer and size of a serialized flatbuffer.
TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
// Write the serialized flatbuffer to the prescribed `filename`.
TfLiteStatus Write(const std::string& filename);
// Registers a custom writer for a custom op. The customization allows the
// caller to change the custom data.
TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
CustomWriter custom_writer);
private:
template <class T>
using Offset = flatbuffers::Offset<T>;
template <class T_OUTPUT, class T_INPUT>
Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector(
flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v);
Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors(
flatbuffers::FlatBufferBuilder* fbb);
Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators(
flatbuffers::FlatBufferBuilder* fbb);
Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
flatbuffers::FlatBufferBuilder* fbb);
Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
flatbuffers::FlatBufferBuilder* fbb);
template <class T>
std::vector<int> RemapTensorIndicesToWritten(const T& input);
int GetOpCodeForBuiltin(int builtin_op_index) {
std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result =
builtin_op_to_opcode_.insert(
std::make_pair(builtin_op_index, opcodes_.size()));
if (result.second) {
opcodes_.push_back({builtin_op_index, ""});
}
return result.first->second;
}
int GetOpCodeForCustom(const std::string& custom_name) {
std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
custom_op_to_opcode_.insert(
std::make_pair(custom_name, opcodes_.size()));
if (result.second) {
opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name});
}
return result.first->second;
}
// The interpreter we are writing
Interpreter* interpreter_;
// Keep track of byte buffers
std::vector<std::pair<const uint8_t*, size_t>> buffers_;
// List of op codes and mappings from builtin or custom op to opcode
struct OpCode {
int builtin;
std::string custom;
};
// Maps each tensor index in the interpreter to its index in the written
// file. These differ because temporary tensors are not written.
std::vector<int> tensor_to_written_tensor_;
// List of used opcodes
std::vector<OpCode> opcodes_;
std::unordered_map<int, int> builtin_op_to_opcode_;
std::unordered_map<std::string, int> custom_op_to_opcode_;
std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
};
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
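To illustrate the CustomWriter hook declared above, a hedged sketch of registering one; the op name "MY_CUSTOM_OP", the payload bytes, and the function names are hypothetical, not part of the library:

#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"

// Hypothetical custom writer: emits two bytes as the op's custom options.
flatbuffers::Offset<tflite::Operator> WriteMyCustomOp(
    flatbuffers::FlatBufferBuilder* fbb, tflite::Interpreter* interpreter,
    int node_index,
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
    tflite::CustomOptionsFormat* custom_options_format) {
  const uint8_t payload[] = {0x01, 0x02};
  *output_options = fbb->CreateVector(payload, sizeof(payload));
  *custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS;
  return 0;  // ExportOperators only consumes the two out-parameters.
}

void WriteWithCustomOp(tflite::Interpreter* interpreter) {
  tflite::InterpreterWriter writer(interpreter);
  writer.RegisterCustomWriter("MY_CUSTOM_OP", WriteMyCustomOp);
  writer.Write("/tmp/with_custom_op.tflite");
}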

tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc View File

@@ -1,62 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
// Builds a small interpreter with a single ADD node, writes it out, and
// reloads it from disk.
// TODO(b/113731921): add more tests.
TEST(Writer, BasicTest) {
Interpreter interpreter;
interpreter.AddTensors(3);
float foo[] = {1, 2, 3};
interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
TfLiteQuantizationParams());
interpreter.SetTensorParametersReadOnly(
1, kTfLiteFloat32, "b", {3}, TfLiteQuantizationParams(),
reinterpret_cast<char*>(foo), sizeof(foo));
interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
TfLiteQuantizationParams());
interpreter.SetInputs({0, 1});
interpreter.SetOutputs({2});
const char* initial_data = "";
tflite::ops::builtin::BuiltinOpResolver resolver;
TfLiteAddParams* builtin_data =
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
builtin_data->activation = kTfLiteActNone;
const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
reinterpret_cast<void*>(builtin_data), reg);
InterpreterWriter writer(&interpreter);
writer.Write("/tmp/test.tflite");
std::unique_ptr<FlatBufferModel> model =
FlatBufferModel::BuildFromFile("/tmp/test.tflite");
InterpreterBuilder builder(*model, resolver);
std::unique_ptr<Interpreter> new_interpreter;
builder(&new_interpreter);
}
} // namespace tflite
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

tensorflow/contrib/lite/mutable_op_resolver.cc View File

@@ -46,8 +46,6 @@ void MutableOpResolver::AddCustom(const char* name,
TfLiteRegistration* registration,
int min_version, int max_version) {
for (int version = min_version; version <= max_version; ++version) {
// TODO(aselle): This should verify that the incoming registration
// has the name in the registration already and it matches!!!
TfLiteRegistration new_registration = *registration;
new_registration.builtin_code = BuiltinOperator_CUSTOM;
new_registration.version = version;

tensorflow/contrib/lite/schema/BUILD View File

@@ -56,20 +56,6 @@ flatbuffer_cc_library(
srcs = ["schema.fbs"],
)
# Generic schema for on-device inference (reflection support makes the generated code bigger).
flatbuffer_cc_library(
name = "schema_fbs_with_reflection",
srcs = ["schema.fbs"],
flatc_args = [
"--reflect-types",
"--reflect-names",
"--no-union-value-namespacing",
"--gen-object-api",
],
gen_reflections = True,
out_prefix = "reflection/",
)
# Schema test to make sure we don't introduce backward-incompatible changes
# to schemas.
cc_test(