parent
69753ba5db
commit
0065d3389a
@ -1,64 +0,0 @@
|
||||
package(default_visibility = [
|
||||
"//visibility:public",
|
||||
])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
cc_binary(
|
||||
name = "option_writer_generator",
|
||||
srcs = ["option_writer_generator.cc"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "writer_lib",
|
||||
srcs = [
|
||||
"enum_mapping.h",
|
||||
"writer_lib.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"writer_lib.h",
|
||||
],
|
||||
textual_hdrs = ["option_writer_generated.h"],
|
||||
deps = [
|
||||
":option_writer_gen",
|
||||
"//tensorflow/contrib/lite:builtin_op_data",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite:schema_fbs_version",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "writer",
|
||||
srcs = ["writer.cc"],
|
||||
deps = [
|
||||
":writer_lib",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "writer_lib_test",
|
||||
size = "small",
|
||||
srcs = ["writer_lib_test.cc"],
|
||||
deps = [
|
||||
":writer_lib",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/testing:util",
|
||||
"//testing/base/public:gunit",
|
||||
],
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "option_writer_gen",
|
||||
outs = ["option_writer_generated.h"],
|
||||
cmd = "$(location :option_writer_generator) $(@)",
|
||||
tools = [":option_writer_generator"],
|
||||
)
|
@ -1,116 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
|
||||
#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
|
||||
|
||||
#include "tensorflow/contrib/lite/builtin_op_data.h"
|
||||
#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
|
||||
|
||||
// TODO(aselle): Ideally extract this from the schema.
|
||||
|
||||
namespace tflite {
|
||||
|
||||
inline ActivationFunctionType TfLiteActivationToSchemaActivation(
|
||||
TfLiteFusedActivation act) {
|
||||
switch (act) {
|
||||
case kTfLiteActNone:
|
||||
return ActivationFunctionType_NONE;
|
||||
case kTfLiteActRelu:
|
||||
return ActivationFunctionType_RELU;
|
||||
case kTfLiteActRelu1:
|
||||
return ActivationFunctionType_RELU_N1_TO_1;
|
||||
case kTfLiteActRelu6:
|
||||
return ActivationFunctionType_RELU6;
|
||||
case kTfLiteActTanh:
|
||||
return ActivationFunctionType_TANH;
|
||||
case kTfLiteActSignBit:
|
||||
return ActivationFunctionType_SIGN_BIT;
|
||||
case kTfLiteActSigmoid:
|
||||
return ActivationFunctionType_NONE; // TODO(aselle): Add to schema
|
||||
}
|
||||
return ActivationFunctionType_NONE;
|
||||
}
|
||||
|
||||
inline Padding TfLitePaddingToSchemaPadding(TfLitePadding padding) {
|
||||
switch (padding) {
|
||||
case kTfLitePaddingUnknown:
|
||||
return Padding_SAME; // TODO(aselle): Consider an error.
|
||||
case kTfLitePaddingSame:
|
||||
return Padding_SAME;
|
||||
case kTfLitePaddingValid:
|
||||
return Padding_VALID;
|
||||
}
|
||||
return Padding_SAME; // TODO(aselle): Consider an error.
|
||||
}
|
||||
|
||||
inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
|
||||
switch (type) {
|
||||
// case kTfLiteNoType: return TensorType_NONE;
|
||||
case kTfLiteNoType:
|
||||
return TensorType_FLOAT32; // TODO(aselle): Consider an error.
|
||||
case kTfLiteFloat32:
|
||||
return TensorType_FLOAT32;
|
||||
case kTfLiteInt32:
|
||||
return TensorType_INT32;
|
||||
case kTfLiteUInt8:
|
||||
return TensorType_UINT8;
|
||||
case kTfLiteInt64:
|
||||
return TensorType_INT64;
|
||||
case kTfLiteString:
|
||||
return TensorType_STRING;
|
||||
case kTfLiteBool:
|
||||
return TensorType_BOOL;
|
||||
case kTfLiteInt16:
|
||||
return TensorType_INT16;
|
||||
case kTfLiteComplex64:
|
||||
return TensorType_COMPLEX64;
|
||||
}
|
||||
// TODO(aselle): consider an error
|
||||
}
|
||||
|
||||
inline FullyConnectedOptionsWeightsFormat
|
||||
FullyConnectedOptionsWeightsFormatToSchema(
|
||||
TfLiteFullyConnectedWeightsFormat format) {
|
||||
switch (format) {
|
||||
case kTfLiteFullyConnectedWeightsFormatDefault:
|
||||
return FullyConnectedOptionsWeightsFormat_DEFAULT;
|
||||
case kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8:
|
||||
return FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
|
||||
}
|
||||
}
|
||||
|
||||
inline LSTMKernelType LSTMKernelTypeToSchema(TfLiteLSTMKernelType type) {
|
||||
switch (type) {
|
||||
case kTfLiteLSTMFullKernel:
|
||||
return LSTMKernelType_FULL;
|
||||
case kTfLiteLSTMBasicKernel:
|
||||
return LSTMKernelType_BASIC;
|
||||
}
|
||||
}
|
||||
|
||||
inline LSHProjectionType LSHProjectionTypeToSchema(
|
||||
TfLiteLSHProjectionType type) {
|
||||
switch (type) {
|
||||
case kTfLiteLshProjectionUnknown:
|
||||
return LSHProjectionType_UNKNOWN;
|
||||
case kTfLiteLshProjectionSparse:
|
||||
return LSHProjectionType_SPARSE;
|
||||
case kTfLiteLshProjectionDense:
|
||||
return LSHProjectionType_DENSE;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
|
@ -1,370 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <ctype.h>
|
||||
#include <iostream>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include "flatbuffers/minireflect.h" // flatbuffers
|
||||
#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
// This is generated by grepping
|
||||
// cat third_party/tensorflow/contrib/lite/builtin_op_data.h
|
||||
//| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}"
|
||||
static const char* param_structs[] = {"TfLiteConvParams",
|
||||
"TfLitePoolParams",
|
||||
"TfLiteDepthwiseConvParams",
|
||||
"TfLiteSVDFParams",
|
||||
"TfLiteRNNParams",
|
||||
"TfLiteSequenceRNNParams",
|
||||
"TfLiteFullyConnectedParams",
|
||||
"TfLiteLSHProjectionParams",
|
||||
"TfLiteSoftmaxParams",
|
||||
"TfLiteConcatenationParams",
|
||||
"TfLiteAddParams",
|
||||
"TfLiteSpaceToBatchNDParams",
|
||||
"TfLiteBatchToSpaceNDParams",
|
||||
"TfLiteMulParams",
|
||||
"TfLiteSubParams",
|
||||
"TfLiteDivParams",
|
||||
"TfLiteL2NormParams",
|
||||
"TfLiteLocalResponseNormParams",
|
||||
"TfLiteLSTMParams",
|
||||
"TfLiteResizeBilinearParams",
|
||||
"TfLitePadParams",
|
||||
"TfLitePadV2Params",
|
||||
"TfLiteReshapeParams",
|
||||
"TfLiteSkipGramParams",
|
||||
"TfLiteSpaceToDepthParams",
|
||||
"TfLiteCastParams",
|
||||
"TfLiteEmbeddingLookupSparseParams",
|
||||
"TfLiteGatherParams",
|
||||
"TfLiteTransposeParams",
|
||||
"TfLiteReducerParams",
|
||||
"TfLiteSplitParams",
|
||||
"TfLiteSqueezeParams",
|
||||
"TfLiteStridedSliceParams",
|
||||
"TfLiteArgMaxParams",
|
||||
"TfLiteArgMinParams",
|
||||
"TfLiteTransposeConvParams",
|
||||
"TfLiteSparseToDenseParams",
|
||||
"TfLiteShapeParams",
|
||||
"TfLiteFakeQuantParams",
|
||||
"TfLitePackParams",
|
||||
"TfLiteOneHotParams",
|
||||
nullptr};
|
||||
} // namespace
|
||||
|
||||
// Get rid of all underscores and make everything lower case to make name
|
||||
// matching work for stuff like 3D vs 3d or RNN vs Rnn.
|
||||
std::string ToCollapsed(const std::string& in) {
|
||||
const char* s = in.c_str();
|
||||
bool first = true;
|
||||
std::string out;
|
||||
while (*s != '\0') {
|
||||
if (*s == '_') {
|
||||
first = true;
|
||||
} else if (first) {
|
||||
out.push_back(tolower(*s));
|
||||
first = false;
|
||||
} else {
|
||||
out.push_back(tolower(*s));
|
||||
}
|
||||
s++;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// A collection of information about builtin ops.
|
||||
class OpOptionData {
|
||||
public:
|
||||
OpOptionData() {
|
||||
BuildOpList();
|
||||
BuildOptionToTypeFunctionMap();
|
||||
BuildOpToOptionMap();
|
||||
}
|
||||
|
||||
// A list of builtin operations
|
||||
const std::vector<std::string>& ops() const { return ops_; }
|
||||
// Maps from operation name to option name (i.e. 'ADD' to 'AddOptions')
|
||||
const std::unordered_map<std::string, std::string>& op_to_option() {
|
||||
return op_to_option_;
|
||||
}
|
||||
// Maps from option to to C struct i.e. 'AddOptions' -> 'TfLiteAddOptions'
|
||||
const std::unordered_map<std::string, std::string>& option_to_struct() {
|
||||
return option_to_struct_;
|
||||
}
|
||||
// Maps from option to a flatbuffer type function that describes that option.
|
||||
const std::unordered_map<std::string, flatbuffers::TypeFunction>&
|
||||
option_to_type_function() {
|
||||
return option_to_type_function_;
|
||||
}
|
||||
|
||||
private:
|
||||
void BuildOpList() {
|
||||
for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr;
|
||||
++curr) {
|
||||
if (strlen(*curr) != 0) ops_.push_back(*curr);
|
||||
}
|
||||
}
|
||||
|
||||
void BuildOptionToTypeFunctionMap() {
|
||||
auto d = tflite::BuiltinOptionsTypeTable();
|
||||
for (int i = 0; i < d->num_elems; i++) {
|
||||
flatbuffers::TypeCode code = d->type_codes[i];
|
||||
if (code.sequence_ref != -1) {
|
||||
option_to_type_function_.insert(
|
||||
std::make_pair(d->names[i], d->type_refs[code.sequence_ref]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BuildOpToOptionMap() {
|
||||
// Manually specified mappings between ops and options
|
||||
op_to_option_["REDUCE_MAX"] = "ReducerOptions";
|
||||
op_to_option_["REDUCE_MIN"] = "ReducerOptions";
|
||||
op_to_option_["REDUCE_ANY"] = "ReducerOptions";
|
||||
op_to_option_["UNPACK"] = "";
|
||||
op_to_option_["SUM"] = "ReducerOptions";
|
||||
op_to_option_["REDUCE_MAX"] = "ReducerOptions";
|
||||
op_to_option_["REDUCE_PROD"] = "ReducerOptions";
|
||||
op_to_option_["MEAN"] = "ReducerOptions";
|
||||
op_to_option_["L2_POOL_2D"] = "Pool2DOptions";
|
||||
op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions";
|
||||
op_to_option_["MAX_POOL_2D"] = "Pool2DOptions";
|
||||
op_to_option_["L2_NORMALIZATION"] = "L2NormOptions";
|
||||
op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
|
||||
op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
|
||||
op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
|
||||
op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
|
||||
op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
|
||||
// Manually specified mappings between ops and options (none)
|
||||
op_to_option_["EMBEDDING_LOOKUP"] =
|
||||
""; // TODO(aselle): maybe something else.
|
||||
op_to_option_["FLOOR"] = "";
|
||||
op_to_option_["HASHTABLE_LOOKUP"] =
|
||||
""; // TODO(aselle): maybe something else.
|
||||
op_to_option_["LOGISTIC"] = "";
|
||||
op_to_option_["RELU"] = "";
|
||||
op_to_option_["RELU_N1_TO_1"] = "";
|
||||
op_to_option_["RELU6"] = "";
|
||||
op_to_option_["TANH"] = "";
|
||||
op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else.
|
||||
op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else.
|
||||
op_to_option_["PRELU"] = "";
|
||||
op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
|
||||
op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
|
||||
op_to_option_["SIN"] = "";
|
||||
op_to_option_["LOG"] = "";
|
||||
op_to_option_["SQRT"] = "";
|
||||
op_to_option_["RSQRT"] = "";
|
||||
|
||||
// TODO(aselle): These are undesirable hacks. Consider changing C structs
|
||||
option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
|
||||
option_to_struct_["Conv2DOptions"] = "TfLiteConvParams";
|
||||
option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams";
|
||||
option_to_struct_["LocalResponseNormalizationOptions"] =
|
||||
"TfLiteLocalResponseNormParams";
|
||||
// Now for every op, try to find an option.
|
||||
bool fatal = false;
|
||||
for (auto op_name : ops_) {
|
||||
bool found_option = false;
|
||||
auto d = tflite::BuiltinOptionsTypeTable();
|
||||
std::string collapsed_option_name_guess =
|
||||
ToCollapsed(op_name) + "options";
|
||||
// O(n^2) but not that big of n.
|
||||
for (int i = 0; i < d->num_elems; i++) {
|
||||
std::string option_name = d->names[i];
|
||||
std::string collapsed_option_name = ToCollapsed(option_name);
|
||||
if (collapsed_option_name_guess == collapsed_option_name) {
|
||||
op_to_option_.insert(std::make_pair(op_name, option_name));
|
||||
found_option = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto it = op_to_option_.find(op_name);
|
||||
if (it == op_to_option_.end()) {
|
||||
std::cerr << "Didn't find option for " << op_name << std::endl;
|
||||
fatal = true;
|
||||
} else if (!it->second.empty()) {
|
||||
std::string option_name = it->second;
|
||||
|
||||
if (option_to_struct_.find(option_name) == option_to_struct_.end()) {
|
||||
bool param_struct_found = false;
|
||||
std::string params_guess = std::string("TfLite") + option_name;
|
||||
size_t start = params_guess.find("Options");
|
||||
size_t len = strlen("Options");
|
||||
params_guess.replace(start, len, "Params");
|
||||
for (auto* param = param_structs; *param != nullptr; param++) {
|
||||
if (*param == params_guess) {
|
||||
param_struct_found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!param_struct_found) {
|
||||
std::cerr << "Failed to get param struct for option " << option_name
|
||||
<< std::endl;
|
||||
fatal = true;
|
||||
} else {
|
||||
option_to_struct_.insert(std::make_pair(option_name, params_guess));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::string> ops_;
|
||||
std::unordered_map<std::string, std::string> op_to_option_;
|
||||
std::unordered_map<std::string, std::string> option_to_struct_;
|
||||
std::unordered_map<std::string, flatbuffers::TypeFunction>
|
||||
option_to_type_function_;
|
||||
};
|
||||
|
||||
void GenerateImportForOp(FILE* fp, const std::string& op_name,
|
||||
const std::string& option_name,
|
||||
const std::string& option_type,
|
||||
const flatbuffers::TypeTable* options,
|
||||
const std::string& struct_name) {
|
||||
// Skip tricky ones for now
|
||||
if (struct_name == "TfLiteResizeBilinearParams") return;
|
||||
if (struct_name == "TfLiteSqueezeParams") return;
|
||||
if (struct_name == "TfLiteEmbeddingLookupSparseParams") return;
|
||||
if (struct_name == "TfLiteReshapeParams") return;
|
||||
|
||||
fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
|
||||
fprintf(fp,
|
||||
" const auto* params = reinterpret_cast<const "
|
||||
"%s*>(builtin_op_data);\n",
|
||||
struct_name.c_str());
|
||||
|
||||
for (size_t i = 0; i < options->num_elems; i++) {
|
||||
std::string elem_name = options->names[i];
|
||||
// TODO(aselle): Irregular naming in builtins
|
||||
if (elem_name == "fused_activation_function")
|
||||
elem_name = "activation";
|
||||
else if (elem_name == "stride_w")
|
||||
elem_name = "stride_width";
|
||||
else if (elem_name == "stride_h")
|
||||
elem_name = "stride_height";
|
||||
else if (elem_name == "dilation_h_factor")
|
||||
elem_name = "dilation_height_factor";
|
||||
else if (elem_name == "dilation_w_factor")
|
||||
elem_name = "dilation_width_factor";
|
||||
else if (elem_name == "new_shape")
|
||||
elem_name = "shape";
|
||||
|
||||
flatbuffers::TypeCode code = options->type_codes[i];
|
||||
auto contained_type = code.sequence_ref != -1
|
||||
? options->type_refs[code.sequence_ref]
|
||||
: nullptr;
|
||||
std::string mapper = "";
|
||||
if (contained_type == TensorTypeTypeTable) {
|
||||
mapper = "TfLiteTypeToSchemaType";
|
||||
} else if (contained_type == ActivationFunctionTypeTypeTable) {
|
||||
mapper = "TfLiteActivationToSchemaActivation";
|
||||
} else if (contained_type == PaddingTypeTable) {
|
||||
mapper = "TfLitePaddingToSchemaPadding";
|
||||
} else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) {
|
||||
mapper = "FullyConnectedOptionsWeightsFormatToSchema";
|
||||
} else if (contained_type == LSTMKernelTypeTypeTable) {
|
||||
mapper = "LSTMKernelTypeToSchema";
|
||||
} else if (contained_type == LSHProjectionTypeTypeTable) {
|
||||
mapper = "LSHProjectionTypeToSchema";
|
||||
}
|
||||
|
||||
fprintf(fp,
|
||||
" auto val%zu = "
|
||||
"%s(params->%s);\n",
|
||||
i, mapper.c_str(), elem_name.c_str());
|
||||
}
|
||||
fprintf(fp, " auto union_type = Create%s(*fbb", option_name.c_str());
|
||||
for (size_t i = 0; i < options->num_elems; i++) {
|
||||
fprintf(fp, ", val%zu", i);
|
||||
}
|
||||
fprintf(fp, ").Union();\n");
|
||||
fprintf(fp, " return std::make_pair(%s, union_type);\n",
|
||||
option_type.c_str());
|
||||
fprintf(fp, " }\n break;\n");
|
||||
}
|
||||
|
||||
void GenerateImport(OpOptionData* option, FILE* fp) {
|
||||
std::unordered_set<std::string> ignores;
|
||||
ignores.insert("CONCAT_EMBEDDINGS");
|
||||
ignores.insert("CALL");
|
||||
|
||||
// Allow any op that doesn't have an options struct to be blocked
|
||||
// together
|
||||
for (const auto& op_name : option->ops()) {
|
||||
auto option_it = option->op_to_option().find(op_name);
|
||||
if (!option_it->second.empty() && ignores.find(op_name) == ignores.end())
|
||||
continue;
|
||||
fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
|
||||
}
|
||||
fprintf(fp,
|
||||
" return std::make_pair(BuiltinOptions_NONE, "
|
||||
"flatbuffers::Offset<void>());\n break;\n");
|
||||
|
||||
// Iterate over each ops
|
||||
for (const auto& op_name : option->ops()) {
|
||||
if (ignores.find(op_name) != ignores.end()) continue;
|
||||
// Get to the option and struct names, continuing if not found.
|
||||
auto option_it = option->op_to_option().find(op_name);
|
||||
if (option_it->second.empty()) continue;
|
||||
std::string option_name = option_it->second;
|
||||
std::string option_type = "BuiltinOptions_" + option_name;
|
||||
auto option_func_it = option->option_to_type_function().find(option_name);
|
||||
if (option_func_it == option->option_to_type_function().end()) continue;
|
||||
auto struct_name_it = option->option_to_struct().find(option_name);
|
||||
if (struct_name_it == option->option_to_struct().end()) {
|
||||
// If no C struct, then it better have no arguments.
|
||||
auto type_info = option_func_it->second();
|
||||
if (type_info->num_elems != 0) {
|
||||
// We have non-zero arguments in the schema, this means there
|
||||
// should be a struct.
|
||||
fprintf(stderr,
|
||||
"Op %s uses option struct %s which has no builtin struct\n",
|
||||
op_name.c_str(), option_name.c_str());
|
||||
exit(1);
|
||||
}
|
||||
fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
|
||||
fprintf(fp, " return std::make_pair(%s, Create%s(*fbb).Union());",
|
||||
option_type.c_str(), option_name.c_str());
|
||||
} else {
|
||||
// If C struct, then we need to assign all properties
|
||||
auto struct_name = struct_name_it->second;
|
||||
GenerateImportForOp(fp, op_name, option_name, option_type,
|
||||
option_func_it->second(), struct_name);
|
||||
}
|
||||
}
|
||||
// TODO(aselle): Handle unhandled cases more gracefully.
|
||||
fprintf(fp,
|
||||
"default: return std::make_pair(BuiltinOptions_NONE, "
|
||||
"flatbuffers::Offset<void>());\n break;\n");
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
tflite::OpOptionData option;
|
||||
if (argc != 2) {
|
||||
fprintf(stderr, "Usage: %s <fname out>\n", argv[0]);
|
||||
return 1;
|
||||
}
|
||||
FILE* fp = fopen(argv[1], "w");
|
||||
tflite::GenerateImport(&option, fp);
|
||||
fclose(fp);
|
||||
}
|
@ -1,41 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Just does a read/write loop of tflite file format using the interpreter as
|
||||
// an intermediate.
|
||||
//
|
||||
// Usage:
|
||||
// writer <input tflite> <output tflite>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
|
||||
#include "tensorflow/contrib/lite/kernels/register.h"
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc != 3) {
|
||||
fprintf(stderr, "Usage: %s input_file output_file\n", argv[0]);
|
||||
return 1;
|
||||
}
|
||||
std::unique_ptr<tflite::FlatBufferModel> model =
|
||||
tflite::FlatBufferModel::BuildFromFile(argv[1]);
|
||||
std::unique_ptr<tflite::Interpreter> interpreter;
|
||||
tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
|
||||
tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
|
||||
tflite::InterpreterWriter writer(interpreter.get());
|
||||
writer.Write(argv[2]);
|
||||
|
||||
return 0;
|
||||
}
|
@ -1,281 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
#include "tensorflow/contrib/lite/builtin_op_data.h"
|
||||
#include "tensorflow/contrib/lite/context_util.h"
|
||||
#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
|
||||
#include "tensorflow/contrib/lite/interpreter.h"
|
||||
#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
|
||||
#include "tensorflow/contrib/lite/version.h"
|
||||
|
||||
namespace tflite {
|
||||
template <class T>
|
||||
using Offset = flatbuffers::Offset<T>;
|
||||
template <class T>
|
||||
using Vector = flatbuffers::Vector<T>;
|
||||
using FlatBufferBuilder = flatbuffers::FlatBufferBuilder;
|
||||
|
||||
std::pair<BuiltinOptions, Offset<void>> CreateBuiltinUnion(
|
||||
FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) {
|
||||
switch (op) {
|
||||
#include "tensorflow/contrib/lite/experimental/writer/option_writer_generated.h"
|
||||
}
|
||||
return std::make_pair(BuiltinOptions_NONE, Offset<void>());
|
||||
}
|
||||
|
||||
template <class T_OUTPUT, class T_INPUT>
|
||||
Offset<Vector<T_OUTPUT>> InterpreterWriter::ExportVector(FlatBufferBuilder* fbb,
|
||||
const T_INPUT& v) {
|
||||
std::vector<T_OUTPUT> inputs(v.begin(), v.end());
|
||||
return fbb->template CreateVector<T_OUTPUT>(inputs);
|
||||
}
|
||||
|
||||
Offset<Vector<Offset<Operator>>> InterpreterWriter::ExportOperators(
|
||||
FlatBufferBuilder* fbb) {
|
||||
std::vector<Offset<Operator>> operators;
|
||||
|
||||
std::vector<int> operator_to_opcode;
|
||||
// TODO(aselle): Augment this once we put execution plan in schema.
|
||||
operator_to_opcode.resize(interpreter_->nodes_size(), -1);
|
||||
for (int op_index : interpreter_->execution_plan()) {
|
||||
const auto* node_and_registration =
|
||||
interpreter_->node_and_registration(op_index);
|
||||
const TfLiteRegistration* registration = &node_and_registration->second;
|
||||
if (!registration->custom_name) {
|
||||
operator_to_opcode[op_index] =
|
||||
GetOpCodeForBuiltin(registration->builtin_code);
|
||||
} else {
|
||||
operator_to_opcode[op_index] =
|
||||
GetOpCodeForCustom(registration->custom_name);
|
||||
}
|
||||
}
|
||||
// second pass serialize operators
|
||||
for (int op_index : interpreter_->execution_plan()) {
|
||||
const auto* node_and_registration =
|
||||
interpreter_->node_and_registration(op_index);
|
||||
const TfLiteNode& node = node_and_registration->first;
|
||||
const TfLiteRegistration& registration = node_and_registration->second;
|
||||
Offset<void> builtin_options;
|
||||
BuiltinOptions builtin_options_type = BuiltinOptions_NONE;
|
||||
// Custom data
|
||||
// TODO(aselle): Custom options format is not known by default. Just assume
|
||||
// for now.
|
||||
auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
|
||||
Offset<Vector<uint8_t>> custom_options = 0;
|
||||
|
||||
if (!registration.custom_name) {
|
||||
// builtin
|
||||
auto builtin_options_and_type = CreateBuiltinUnion(
|
||||
fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
|
||||
node.builtin_data);
|
||||
builtin_options = builtin_options_and_type.second;
|
||||
builtin_options_type = builtin_options_and_type.first;
|
||||
} else {
|
||||
auto custom_writer = custom_op_to_writer_.find(registration.custom_name);
|
||||
if (custom_writer != custom_op_to_writer_.end() &&
|
||||
custom_writer->second) {
|
||||
// delegate to custom writer if it exists
|
||||
custom_writer->second(fbb, interpreter_, op_index, &custom_options,
|
||||
&custom_options_format);
|
||||
} else {
|
||||
// use the custom data as fact
|
||||
custom_options = fbb->CreateVector(
|
||||
reinterpret_cast<const uint8_t*>(node.custom_initial_data),
|
||||
node.custom_initial_data_size);
|
||||
}
|
||||
}
|
||||
|
||||
int opcode_index = operator_to_opcode[op_index];
|
||||
std::vector<int> written_inputs =
|
||||
RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs));
|
||||
std::vector<int> written_outputs =
|
||||
RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs));
|
||||
auto inputs = ExportVector<int32_t>(fbb, written_inputs);
|
||||
auto outputs = ExportVector<int32_t>(fbb, written_outputs);
|
||||
operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs,
|
||||
builtin_options_type, builtin_options,
|
||||
custom_options, custom_options_format));
|
||||
}
|
||||
|
||||
return fbb->template CreateVector<Offset<Operator>>(operators);
|
||||
}
|
||||
|
||||
Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors(
|
||||
FlatBufferBuilder* fbb) {
|
||||
tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1);
|
||||
|
||||
std::vector<Offset<Tensor>> tensors;
|
||||
|
||||
// Make a map from tensor index to whether the tensor is a temporary.
|
||||
std::vector<bool> tensor_is_temporary(interpreter_->tensors_size(), false);
|
||||
for (int op_index = 0; op_index < interpreter_->nodes_size(); ++op_index) {
|
||||
const auto* node_and_registration =
|
||||
interpreter_->node_and_registration(op_index);
|
||||
for (auto tensor_index :
|
||||
TfLiteIntArrayView(node_and_registration->first.temporaries))
|
||||
tensor_is_temporary[tensor_index] = true;
|
||||
}
|
||||
|
||||
// Now we need to remap all used tensor indices
|
||||
int curr_output_index = 0;
|
||||
for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
|
||||
tensor_index++) {
|
||||
if (!tensor_is_temporary[tensor_index]) {
|
||||
tensor_to_written_tensor_[tensor_index] = curr_output_index++;
|
||||
}
|
||||
}
|
||||
|
||||
for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
|
||||
++tensor_index) {
|
||||
// Skip temporaries.
|
||||
if (tensor_is_temporary[tensor_index]) continue;
|
||||
|
||||
if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) {
|
||||
// We only need to convert non temporaries
|
||||
if (tensor->allocation_type != kTfLiteArenaRw &&
|
||||
tensor->allocation_type != kTfLiteMmapRo &&
|
||||
tensor->allocation_type != kTfLiteArenaRwPersistent)
|
||||
continue;
|
||||
// Allocate a buffer index
|
||||
int buffer_index = 0; // This is null
|
||||
if (tensor->allocation_type == kTfLiteMmapRo) {
|
||||
buffer_index = buffers_.size();
|
||||
buffers_.push_back(std::make_pair(
|
||||
reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
|
||||
}
|
||||
// Primitive type.
|
||||
TensorType type = TfLiteTypeToSchemaType(tensor->type);
|
||||
// Handle quantization
|
||||
const Offset<Vector<float>> null_array;
|
||||
Offset<Vector<float>> scale_array;
|
||||
Offset<Vector<int64_t>> zero_point_array;
|
||||
if (tensor->params.scale != 0.f) {
|
||||
// We have quantization, make a single arugment array (multi channel
|
||||
// quant needs updating here).
|
||||
scale_array = fbb->CreateVector<float>({tensor->params.scale});
|
||||
zero_point_array =
|
||||
fbb->CreateVector<int64_t>({tensor->params.zero_point});
|
||||
}
|
||||
Offset<QuantizationParameters> quantization_params =
|
||||
CreateQuantizationParameters(*fbb, null_array, null_array,
|
||||
scale_array, zero_point_array);
|
||||
// Shape
|
||||
TfLiteIntArrayView shape_view(tensor->dims);
|
||||
std::vector<int> shape =
|
||||
std::vector<int>(shape_view.begin(), shape_view.end());
|
||||
|
||||
tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape),
|
||||
type, buffer_index,
|
||||
fbb->CreateString(tensor->name),
|
||||
quantization_params, tensor->is_variable));
|
||||
}
|
||||
}
|
||||
return fbb->template CreateVector<Offset<Tensor>>(tensors);
|
||||
}
|
||||
|
||||
Offset<Vector<Offset<Buffer>>> InterpreterWriter::ExportBuffers(
|
||||
FlatBufferBuilder* fbb) {
|
||||
std::vector<Offset<Buffer>> buffer_vector;
|
||||
for (auto buffer : buffers_) {
|
||||
auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
|
||||
buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
|
||||
}
|
||||
return fbb->template CreateVector<Offset<Buffer>>(buffer_vector);
|
||||
}
|
||||
|
||||
Offset<Vector<Offset<OperatorCode>>> InterpreterWriter::CreateOpCodeTable(
|
||||
FlatBufferBuilder* fbb) {
|
||||
std::vector<Offset<OperatorCode>> codes;
|
||||
for (auto it : opcodes_) {
|
||||
const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
|
||||
codes.push_back(CreateOperatorCodeDirect(
|
||||
*fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
|
||||
}
|
||||
return fbb->template CreateVector<Offset<OperatorCode>>(codes);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten(
|
||||
const T& input) {
|
||||
std::vector<int> output;
|
||||
output.reserve(input.size());
|
||||
for (int x : input) {
|
||||
output.push_back(tensor_to_written_tensor_[x]);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
TfLiteStatus InterpreterWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
|
||||
size_t* size) {
|
||||
if (!out || !size) return kTfLiteError;
|
||||
FlatBufferBuilder builder(/*initial_size=*/10240);
|
||||
|
||||
std::vector<Offset<SubGraph>> subgraphs_as_vector;
|
||||
{ // subgraph specific stuff
|
||||
auto tensors = ExportTensors(&builder);
|
||||
std::vector<int> written_inputs =
|
||||
RemapTensorIndicesToWritten(interpreter_->inputs());
|
||||
std::vector<int> written_outputs =
|
||||
RemapTensorIndicesToWritten(interpreter_->outputs());
|
||||
auto inputs = ExportVector<int32_t>(&builder, written_inputs);
|
||||
auto outputs = ExportVector<int32_t>(&builder, written_outputs);
|
||||
|
||||
auto ops = ExportOperators(&builder);
|
||||
subgraphs_as_vector.push_back(
|
||||
CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0));
|
||||
}
|
||||
Offset<Vector<Offset<Buffer>>> buffers = ExportBuffers(&builder);
|
||||
|
||||
auto description = builder.CreateString("Exported from Interpreter.");
|
||||
|
||||
auto op_codes = CreateOpCodeTable(&builder);
|
||||
auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
|
||||
builder.CreateVector(subgraphs_as_vector),
|
||||
description, buffers);
|
||||
::tflite::FinishModelBuffer(builder, model);
|
||||
const uint8_t* buffer = builder.GetBufferPointer();
|
||||
*size = builder.GetSize();
|
||||
(*out).reset(new uint8_t[*size]);
|
||||
memcpy(out->get(), buffer, *size);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus InterpreterWriter::Write(const std::string& filename) {
|
||||
std::unique_ptr<uint8_t[]> buffer;
|
||||
size_t size;
|
||||
TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
|
||||
|
||||
FILE* fp = fopen(filename.c_str(), "wb");
|
||||
if (!fp) return kTfLiteError;
|
||||
|
||||
if (fwrite(buffer.get(), 1, size, fp) != size) return kTfLiteError;
|
||||
if (fclose(fp)) return kTfLiteError;
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus InterpreterWriter::RegisterCustomWriter(
|
||||
const std::string& custom_name, CustomWriter custom_writer) {
|
||||
if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace tflite
|
@ -1,126 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Writes a flatbuffer of a currently loaded TensorFlow Lite interpreter.
|
||||
//
|
||||
// Usage:
|
||||
// From command line:
|
||||
// bazel run third_party/tensorflow/contrib/lite/experimental/writer:writer
|
||||
// -- foo.tflite foo.out.tflite
|
||||
//
|
||||
// From C++
|
||||
// std::unique_ptr<Interpreter> interpreter;
|
||||
// // Build Interpreter however
|
||||
// // ... <omitted>
|
||||
// InterpreterWriter(interpreter.get()).Write("output.tflite");
|
||||
#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
|
||||
#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
|
||||
#include <iostream>
|
||||
#include <unordered_map>
|
||||
#include "tensorflow/contrib/lite/builtin_op_data.h"
|
||||
#include "tensorflow/contrib/lite/context_util.h"
|
||||
#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
|
||||
#include "tensorflow/contrib/lite/interpreter.h"
|
||||
#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
|
||||
#include "tensorflow/contrib/lite/version.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Handles writing TensorFlow Lite running interpreter to a serialized TF lite
|
||||
// file format.
|
||||
class InterpreterWriter {
|
||||
public:
|
||||
typedef flatbuffers::Offset<Operator> (*CustomWriter)(
|
||||
flatbuffers::FlatBufferBuilder* fbb, Interpreter* interpreter,
|
||||
int node_index,
|
||||
flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
|
||||
CustomOptionsFormat* custom_options_format);
|
||||
|
||||
// Construct an interpreter writer for the specified `interpreter`. Then,
|
||||
// a uses .Write() or .GetBuffer(...) to extract the data.
|
||||
explicit InterpreterWriter(Interpreter* interpreter)
|
||||
: interpreter_(interpreter) {
|
||||
buffers_.push_back(std::make_pair(nullptr, 0));
|
||||
}
|
||||
|
||||
// Get a buffer and size of a serialized flatbuffer.
|
||||
TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
|
||||
// Write the serialized flatbuffer to the prescribed `filename`.
|
||||
TfLiteStatus Write(const std::string& filename);
|
||||
// Registers a custom writer for a custom op. The customization allows the
|
||||
// caller to change the custom data.
|
||||
TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
|
||||
CustomWriter custom_writer);
|
||||
|
||||
private:
|
||||
template <class T>
|
||||
using Offset = flatbuffers::Offset<T>;
|
||||
template <class T_OUTPUT, class T_INPUT>
|
||||
Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector(
|
||||
flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v);
|
||||
Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors(
|
||||
flatbuffers::FlatBufferBuilder* fbb);
|
||||
Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators(
|
||||
flatbuffers::FlatBufferBuilder* fbb);
|
||||
Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
|
||||
flatbuffers::FlatBufferBuilder* fbb);
|
||||
Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
|
||||
flatbuffers::FlatBufferBuilder* fbb);
|
||||
|
||||
template <class T>
|
||||
std::vector<int> RemapTensorIndicesToWritten(const T& input);
|
||||
|
||||
int GetOpCodeForBuiltin(int builtin_op_index) {
|
||||
// auto it = builtin_op_to_opcode_.find(builtin_op_index);
|
||||
std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result =
|
||||
builtin_op_to_opcode_.insert(
|
||||
std::make_pair(builtin_op_index, opcodes_.size()));
|
||||
if (result.second) {
|
||||
opcodes_.push_back({builtin_op_index, ""});
|
||||
}
|
||||
return result.first->second;
|
||||
}
|
||||
|
||||
int GetOpCodeForCustom(const std::string& custom_name) {
|
||||
std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
|
||||
custom_op_to_opcode_.insert(
|
||||
std::make_pair(custom_name, opcodes_.size()));
|
||||
if (result.second) {
|
||||
opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name});
|
||||
}
|
||||
return result.first->second;
|
||||
}
|
||||
|
||||
// The interpreter we are writing
|
||||
Interpreter* interpreter_;
|
||||
// Keep track of byte buffers
|
||||
std::vector<std::pair<const uint8_t*, size_t>> buffers_;
|
||||
// List of op codes and mappings from builtin or custom op to opcode
|
||||
struct OpCode {
|
||||
int builtin;
|
||||
std::string custom;
|
||||
};
|
||||
// For every tensor index in the interpreter, the index in the written.
|
||||
// This is different due to temporary tensors not being written.
|
||||
std::vector<int> tensor_to_written_tensor_;
|
||||
// List of used opcodes
|
||||
std::vector<OpCode> opcodes_;
|
||||
std::unordered_map<int, int> builtin_op_to_opcode_;
|
||||
std::unordered_map<std::string, int> custom_op_to_opcode_;
|
||||
std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
|
@ -1,62 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/contrib/lite/interpreter.h"
|
||||
#include "tensorflow/contrib/lite/kernels/register.h"
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
#include "tensorflow/contrib/lite/testing/util.h"
|
||||
|
||||
namespace tflite {
|
||||
// Make an interpreter that has no tensors and no nodes
|
||||
// TODO(b/113731921): add more tests.
|
||||
TEST(Writer, BasicTest) {
|
||||
Interpreter interpreter;
|
||||
interpreter.AddTensors(3);
|
||||
float foo[] = {1, 2, 3};
|
||||
interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
|
||||
TfLiteQuantizationParams());
|
||||
interpreter.SetTensorParametersReadOnly(
|
||||
1, kTfLiteFloat32, "b", {3}, TfLiteQuantizationParams(),
|
||||
reinterpret_cast<char*>(foo), sizeof(foo));
|
||||
interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
|
||||
TfLiteQuantizationParams());
|
||||
interpreter.SetInputs({0, 1});
|
||||
interpreter.SetOutputs({2});
|
||||
const char* initial_data = "";
|
||||
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||
TfLiteAddParams* builtin_data =
|
||||
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
|
||||
builtin_data->activation = kTfLiteActNone;
|
||||
const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
|
||||
interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
|
||||
reinterpret_cast<void*>(builtin_data), reg);
|
||||
|
||||
InterpreterWriter writer(&interpreter);
|
||||
writer.Write("/tmp/test.tflite");
|
||||
std::unique_ptr<FlatBufferModel> model =
|
||||
FlatBufferModel::BuildFromFile("/tmp/test.tflite");
|
||||
InterpreterBuilder builder(*model, resolver);
|
||||
std::unique_ptr<Interpreter> new_interpreter;
|
||||
builder(&new_interpreter);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -46,8 +46,6 @@ void MutableOpResolver::AddCustom(const char* name,
|
||||
TfLiteRegistration* registration,
|
||||
int min_version, int max_version) {
|
||||
for (int version = min_version; version <= max_version; ++version) {
|
||||
// TODO(aselle): This should verify that the incoming registration
|
||||
// has the name in the registration already and it matches!!!
|
||||
TfLiteRegistration new_registration = *registration;
|
||||
new_registration.builtin_code = BuiltinOperator_CUSTOM;
|
||||
new_registration.version = version;
|
||||
|
@ -56,20 +56,6 @@ flatbuffer_cc_library(
|
||||
srcs = ["schema.fbs"],
|
||||
)
|
||||
|
||||
# Generic schema for inference on device (but with reflections makes bigger).
|
||||
flatbuffer_cc_library(
|
||||
name = "schema_fbs_with_reflection",
|
||||
srcs = ["schema.fbs"],
|
||||
flatc_args = [
|
||||
"--reflect-types",
|
||||
"--reflect-names",
|
||||
"--no-union-value-namespacing",
|
||||
"--gen-object-api",
|
||||
],
|
||||
gen_reflections = True,
|
||||
out_prefix = "reflection/",
|
||||
)
|
||||
|
||||
# Schema test to make sure we don't introduce backward incompatible changes
|
||||
# to schemas.
|
||||
cc_test(
|
||||
|
Loading…
Reference in New Issue
Block a user