Fixes to experimental TFLite writer library:

1. Bring up to date with builtin-ops / builtin-options.
2. Add support for a few ops previously skipped.
3. Fix related to special optional tensor index.

PiperOrigin-RevId: 234188624
This commit is contained in:
A. Unique TensorFlower 2019-02-15 12:06:31 -08:00 committed by TensorFlower Gardener
parent 398fce0307
commit 8ade0b1b44
3 changed files with 127 additions and 59 deletions

View File

@ -112,5 +112,29 @@ inline LSHProjectionType LSHProjectionTypeToSchema(
}
}
inline MirrorPadMode MirrorPaddingModeToSchema(TfLiteMirrorPaddingMode mode) {
switch (mode) {
case kTfLiteMirrorPaddingUnknown:
return MirrorPadMode_REFLECT; // TODO(aselle): consider an error
case kTfLiteMirrorPaddingReflect:
return MirrorPadMode_REFLECT;
case kTfLiteMirrorPaddingSymmetric:
return MirrorPadMode_SYMMETRIC;
}
}
inline CombinerType CombinerTypeToSchema(TfLiteCombinerType type) {
switch (type) {
case kTfLiteCombinerTypeSum:
return CombinerType_SUM;
case kTfLiteCombinerTypeMean:
return CombinerType_MEAN;
case kTfLiteCombinerTypeSqrtn:
return CombinerType_SQRTN;
}
}
// int
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_

View File

@ -22,54 +22,59 @@ limitations under the License.
namespace tflite {
namespace {
// This is generated by grepping
// cat third_party/tensorflow/lite/builtin_op_data.h
//| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}"
static const char* param_structs[] = {"TfLiteConvParams",
"TfLitePoolParams",
"TfLiteDepthwiseConvParams",
"TfLiteSVDFParams",
"TfLiteRNNParams",
"TfLiteSequenceRNNParams",
"TfLiteFullyConnectedParams",
"TfLiteLSHProjectionParams",
"TfLiteSoftmaxParams",
"TfLiteConcatenationParams",
"TfLiteAddParams",
"TfLiteSpaceToBatchNDParams",
// cat third_party/tensorflow/lite/c/builtin_op_data.h | grep "^} TfLite" |
// sed 's/^} \(TfLite.*\)Params;/\1Params/g' | grep -v "^}" | sed
// 's/\(.*\)/"\1",/g' | sort
static const char* param_structs[] = {"TfLiteAddParams",
"TfLiteArgMaxParams",
"TfLiteArgMinParams",
"TfLiteBatchToSpaceNDParams",
"TfLiteMulParams",
"TfLiteSubParams",
"TfLiteBidirectionalSequenceLSTMParams",
"TfLiteBidirectionalSequenceRNNParams",
"TfLiteCastParams",
"TfLiteConcatenationParams",
"TfLiteConvParams",
"TfLiteDepthwiseConvParams",
"TfLiteDivParams",
"TfLiteEmbeddingLookupSparseParams",
"TfLiteFakeQuantParams",
"TfLiteFullyConnectedParams",
"TfLiteGatherParams",
"TfLiteL2NormParams",
"TfLiteLeakyReluParams",
"TfLiteLocalResponseNormParams",
"TfLiteLSHProjectionParams",
"TfLiteLSTMParams",
"TfLiteResizeBilinearParams",
"TfLiteResizeNearestNeighborParams",
"TfLiteMirrorPaddingParams",
"TfLiteMulParams",
"TfLiteOneHotParams",
"TfLitePackParams",
"TfLitePadParams",
"TfLitePadV2Params",
"TfLiteReshapeParams",
"TfLiteSkipGramParams",
"TfLiteSpaceToDepthParams",
"TfLiteCastParams",
"TfLiteEmbeddingLookupSparseParams",
"TfLiteGatherParams",
"TfLiteTransposeParams",
"TfLitePoolParams",
"TfLiteReducerParams",
"TfLiteReshapeParams",
"TfLiteResizeBilinearParams",
"TfLiteResizeNearestNeighborParams",
"TfLiteRNNParams",
"TfLiteSequenceRNNParams",
"TfLiteShapeParams",
"TfLiteSkipGramParams",
"TfLiteSoftmaxParams",
"TfLiteSpaceToBatchNDParams",
"TfLiteSpaceToDepthParams",
"TfLiteSparseToDenseParams",
"TfLiteSplitParams",
"TfLiteSplitVParams",
"TfLiteSqueezeParams",
"TfLiteStridedSliceParams",
"TfLiteArgMaxParams",
"TfLiteArgMinParams",
"TfLiteSubParams",
"TfLiteSVDFParams",
"TfLiteTransposeConvParams",
"TfLiteSparseToDenseParams",
"TfLiteShapeParams",
"TfLiteFakeQuantParams",
"TfLitePackParams",
"TfLiteOneHotParams",
"TfLiteLeakyReluParams",
"TfLiteMirrorPaddingParams",
"TfLiteTransposeParams",
"TfLiteUnidirectionalSequenceLSTMParams",
"TfLiteUniqueParams",
"TfLiteUnpackParams",
nullptr};
} // namespace
@ -142,7 +147,6 @@ class OpOptionData {
op_to_option_["REDUCE_MAX"] = "ReducerOptions";
op_to_option_["REDUCE_MIN"] = "ReducerOptions";
op_to_option_["REDUCE_ANY"] = "ReducerOptions";
op_to_option_["UNPACK"] = "";
op_to_option_["SUM"] = "ReducerOptions";
op_to_option_["REDUCE_MAX"] = "ReducerOptions";
op_to_option_["REDUCE_PROD"] = "ReducerOptions";
@ -151,35 +155,30 @@ class OpOptionData {
op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions";
op_to_option_["MAX_POOL_2D"] = "Pool2DOptions";
op_to_option_["L2_NORMALIZATION"] = "L2NormOptions";
op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
op_to_option_["MIRROR_PAD"] = ""; // TODO(karimnosseir): MirrorPadOptions.
op_to_option_["UNIQUE"] = ""; // TODO(karimnosseir): UniqueOptions.
// Manually specified mappings between ops and options (none)
op_to_option_["EMBEDDING_LOOKUP"] =
""; // TODO(aselle): maybe something else.
op_to_option_["MAXIMUM"] = "MaximumMinimumOptions";
op_to_option_["MINIMUM"] = "MaximumMinimumOptions";
op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else.
op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else.
// Manually specified mappings between ops to "none" options -- these are
// ops without a corresponding Options message in schema as yet. If these
// options do get assigned an Options message in future, they need to be
// updated here as well.
op_to_option_["EMBEDDING_LOOKUP"] = "";
op_to_option_["FLOOR"] = "";
op_to_option_["CEIL"] = "";
op_to_option_["HASHTABLE_LOOKUP"] =
""; // TODO(aselle): maybe something else.
op_to_option_["HASHTABLE_LOOKUP"] = "";
op_to_option_["LOGISTIC"] = "";
op_to_option_["RELU"] = "";
op_to_option_["RELU_N1_TO_1"] = "";
op_to_option_["RELU6"] = "";
op_to_option_["TANH"] = "";
op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else.
op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else.
op_to_option_["PRELU"] = "";
op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
op_to_option_["SIN"] = "";
op_to_option_["LOG"] = "";
op_to_option_["SQRT"] = "";
op_to_option_["RSQRT"] = "";
op_to_option_["Rank"] = "";
// TODO(aselle): These are undesirable hacks. Consider changing C structs
option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
@ -187,6 +186,7 @@ class OpOptionData {
option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams";
option_to_struct_["LocalResponseNormalizationOptions"] =
"TfLiteLocalResponseNormParams";
option_to_struct_["MirrorPadOptions"] = "TfLiteMirrorPaddingParams";
// Now for every op, try to find an option.
bool fatal = false;
for (auto op_name : ops_) {
@ -226,13 +226,15 @@ class OpOptionData {
if (!param_struct_found) {
std::cerr << "Failed to get param struct for option " << option_name
<< std::endl;
fatal = true;
} else {
option_to_struct_.insert(std::make_pair(option_name, params_guess));
}
}
}
}
if (fatal) {
exit(1);
}
}
private:
@ -243,16 +245,28 @@ class OpOptionData {
option_to_type_function_;
};
void GenerateImportForResizeBilinearOp(FILE* fp) {
fprintf(fp,
" case BuiltinOperator_RESIZE_BILINEAR: {\n"
" const auto* params = reinterpret_cast<const "
"TfLiteResizeBilinearParams*>(builtin_op_data);\n"
" auto union_type = CreateResizeBilinearOptions(*fbb, "
"params->align_corners).Union();\n"
" return std::make_pair(BuiltinOptions_ResizeBilinearOptions, "
"union_type);\n"
" }\n break;\n");
}
void GenerateImportForOp(FILE* fp, const std::string& op_name,
const std::string& option_name,
const std::string& option_type,
const flatbuffers::TypeTable* options,
const std::string& struct_name) {
// Skip tricky ones for now
if (struct_name == "TfLiteResizeBilinearParams") return;
if (struct_name == "TfLiteSqueezeParams") return;
if (struct_name == "TfLiteEmbeddingLookupSparseParams") return;
if (struct_name == "TfLiteReshapeParams") return;
// Special-case ResizeBilinear which has some deprecated fields.
if (struct_name == "TfLiteResizeBilinearParams") {
GenerateImportForResizeBilinearOp(fp);
return;
}
fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
fprintf(fp,
@ -262,6 +276,9 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name,
for (size_t i = 0; i < options->num_elems; i++) {
std::string elem_name = options->names[i];
bool is_int_vector = false;
std::string vector_name = elem_name;
std::string vector_size;
// TODO(aselle): Irregular naming in builtins
if (elem_name == "fused_activation_function")
elem_name = "activation";
@ -273,8 +290,26 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name,
elem_name = "dilation_height_factor";
else if (elem_name == "dilation_w_factor")
elem_name = "dilation_width_factor";
else if (elem_name == "new_shape")
elem_name = "shape";
else if (elem_name == "idx_out_type")
elem_name = "index_out_type";
// Vector fields treated specially.
if (elem_name == "new_shape") {
is_int_vector = true;
vector_name = "shape";
vector_size = "num_dimensions";
} else if (elem_name == "squeeze_dims") {
is_int_vector = true;
vector_size = "num_squeeze_dims";
}
if (is_int_vector) {
fprintf(fp,
" auto val%zu = fbb->CreateVector("
"std::vector<int>(params->%s, params->%s + params->%s));\n",
i, vector_name.c_str(), vector_name.c_str(), vector_size.c_str());
continue;
}
flatbuffers::TypeCode code = options->type_codes[i];
auto contained_type = code.sequence_ref != -1
@ -293,6 +328,10 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name,
mapper = "LSTMKernelTypeToSchema";
} else if (contained_type == LSHProjectionTypeTypeTable) {
mapper = "LSHProjectionTypeToSchema";
} else if (contained_type == MirrorPadModeTypeTable) {
mapper = "MirrorPaddingModeToSchema";
} else if (contained_type == CombinerTypeTypeTable) {
mapper = "CombinerTypeToSchema";
}
fprintf(fp,

View File

@ -219,6 +219,11 @@ std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten(
std::vector<int> output;
output.reserve(input.size());
for (int x : input) {
// Special value representing an optional tensor which is not present.
if (x == -1) {
output.push_back(x);
continue;
}
if (tensor_to_written_tensor_[x] != -1) {
output.push_back(tensor_to_written_tensor_[x]);
}